diff --git a/lib/client_factory.ts b/lib/client_factory.ts index 8bbc688ef..3be99b554 100644 --- a/lib/client_factory.ts +++ b/lib/client_factory.ts @@ -57,7 +57,8 @@ export const getOptimizelyInstance = (config: OptimizelyFactoryConfig): Optimize retryConfig: { maxRetries: DEFAULT_CMAB_RETRIES, backoffProvider: () => new ConstantBackoff(DEFAULT_CMAB_BACKOFF_MS), - } + }, + predictionEndpointTemplate: config.cmab?.predictionEndpointTemplate, }); const cmabCache: CacheWithRemove = config.cmab?.cache ? diff --git a/lib/core/decision_service/cmab/cmab_client.spec.ts b/lib/core/decision_service/cmab/cmab_client.spec.ts index 04c7246ca..88d258892 100644 --- a/lib/core/decision_service/cmab/cmab_client.spec.ts +++ b/lib/core/decision_service/cmab/cmab_client.spec.ts @@ -354,4 +354,50 @@ describe('DefaultCmabClient', () => { await expect(cmabClient.fetchDecision(ruleId, userId, attributes, cmabUuid)).rejects.toThrow('error'); }); + + it('should use custom prediction endpoint template when provided', async () => { + const requestHandler = getMockRequestHandler(); + + const mockMakeRequest: MockInstance = requestHandler.makeRequest; + mockMakeRequest.mockReturnValue(getMockAbortableRequest(mockSuccessResponse('var456'))); + + const customEndpoint = 'https://custom.example.com/predict/%s'; + const cmabClient = new DefaultCmabClient({ + requestHandler, + predictionEndpointTemplate: customEndpoint, + }); + const ruleId = '789'; + const userId = 'user789'; + const attributes = { + browser: 'firefox', + }; + const cmabUuid = 'uuid789'; + const variation = await cmabClient.fetchDecision(ruleId, userId, attributes, cmabUuid); + const [requestUrl] = mockMakeRequest.mock.calls[0]; + + expect(variation).toBe('var456'); + expect(mockMakeRequest.mock.calls.length).toBe(1); + expect(requestUrl).toBe('https://custom.example.com/predict/789'); + }); + + it('should use default prediction endpoint template when not provided', async () => { + const requestHandler = getMockRequestHandler(); + const mockMakeRequest: MockInstance = requestHandler.makeRequest; + mockMakeRequest.mockReturnValue(getMockAbortableRequest(mockSuccessResponse('var999'))); + const cmabClient = new DefaultCmabClient({ + requestHandler, + }); + const ruleId = '555'; + const userId = 'user555'; + const attributes = { + browser: 'safari', + }; + const cmabUuid = 'uuid555'; + const variation = await cmabClient.fetchDecision(ruleId, userId, attributes, cmabUuid); + const [requestUrl] = mockMakeRequest.mock.calls[0]; + + expect(variation).toBe('var999'); + expect(mockMakeRequest.mock.calls.length).toBe(1); + expect(requestUrl).toBe('https://prediction.cmab.optimizely.com/predict/555'); + }); }); diff --git a/lib/core/decision_service/cmab/cmab_client.ts b/lib/core/decision_service/cmab/cmab_client.ts index efe3a72ed..a6925713a 100644 --- a/lib/core/decision_service/cmab/cmab_client.ts +++ b/lib/core/decision_service/cmab/cmab_client.ts @@ -33,7 +33,7 @@ export interface CmabClient { ): Promise } -const CMAB_PREDICTION_ENDPOINT = 'https://prediction.cmab.optimizely.com/predict/%s'; +const DEFAULT_CMAB_PREDICTION_ENDPOINT = 'https://prediction.cmab.optimizely.com/predict/%s'; export type RetryConfig = { maxRetries: number, @@ -41,17 +41,22 @@ export type RetryConfig = { } export type CmabClientConfig = { - requestHandler: RequestHandler, + requestHandler: RequestHandler; retryConfig?: RetryConfig; + predictionEndpointTemplate?: string; } export class DefaultCmabClient implements CmabClient { private requestHandler: RequestHandler; private retryConfig?: RetryConfig; + private predictionEndpointTemplate: string = DEFAULT_CMAB_PREDICTION_ENDPOINT; constructor(config: CmabClientConfig) { this.requestHandler = config.requestHandler; this.retryConfig = config.retryConfig; + if (config.predictionEndpointTemplate) { + this.predictionEndpointTemplate = config.predictionEndpointTemplate; + } } async fetchDecision( @@ -60,7 +65,7 @@ export class DefaultCmabClient implements CmabClient { attributes: UserAttributes, cmabUuid: string, ): Promise { - const url = sprintf(CMAB_PREDICTION_ENDPOINT, ruleId); + const url = sprintf(this.predictionEndpointTemplate, ruleId); const cmabAttributes = Object.keys(attributes).map((key) => ({ id: key, diff --git a/lib/shared_types.ts b/lib/shared_types.ts index ef4221db3..1b450b1dd 100644 --- a/lib/shared_types.ts +++ b/lib/shared_types.ts @@ -402,6 +402,7 @@ export interface Config { cacheSize?: number; cacheTtl?: number; cache?: CacheWithRemove; + predictionEndpointTemplate?: string; } }