diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts b/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts index f97362725bac5..9465a94301443 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts @@ -11,7 +11,7 @@ import { isResourceNotFoundException, } from '../../utils/identify_exceptions'; -export const acceptableModelNames = ['.elser_model_1_SNAPSHOT']; +export const acceptableModelNames = ['.elser_model_1', '.elser_model_1_SNAPSHOT']; export function isNotFoundExceptionError(error: unknown): boolean { return ( diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts b/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts index 827f41b83f575..ae11a89ed5ac0 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts @@ -15,7 +15,8 @@ import * as mockGetStatus from './get_ml_model_deployment_status'; import { startMlModelDeployment } from './start_ml_model_deployment'; describe('startMlModelDeployment', () => { - const knownModelName = '.elser_model_1_SNAPSHOT'; + const productionModelName = '.elser_model_1'; + const snapshotModelName = '.elser_model_1_SNAPSHOT'; const mockTrainedModelsProvider = { getTrainedModels: jest.fn(), getTrainedModelsStats: jest.fn(), @@ -27,7 +28,7 @@ describe('startMlModelDeployment', () => { }); it('should error when there is no trained model provider', () => { - expect(() => startMlModelDeployment(knownModelName, undefined)).rejects.toThrowError( + expect(() => startMlModelDeployment(productionModelName, undefined)).rejects.toThrowError( 'Machine Learning is not enabled' ); }); @@ -49,7 +50,7 @@ describe('startMlModelDeployment', () => { jest.spyOn(mockGetStatus, 'getMlModelDeploymentStatus').mockReturnValueOnce( Promise.resolve({ deploymentState: MlModelDeploymentState.Starting, - modelId: knownModelName, + modelId: productionModelName, nodeAllocationCount: 0, startTime: 123456, targetAllocationCount: 3, @@ -57,7 +58,26 @@ describe('startMlModelDeployment', () => { ); const response = await startMlModelDeployment( - knownModelName, + productionModelName, + mockTrainedModelsProvider as unknown as MlTrainedModels + ); + + expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting); + }); + + it('should return the deployment state if not "downloaded" for snapshot model', async () => { + jest.spyOn(mockGetStatus, 'getMlModelDeploymentStatus').mockReturnValueOnce( + Promise.resolve({ + deploymentState: MlModelDeploymentState.Starting, + modelId: snapshotModelName, + nodeAllocationCount: 0, + startTime: 123456, + targetAllocationCount: 3, + }) + ); + + const response = await startMlModelDeployment( + snapshotModelName, mockTrainedModelsProvider as unknown as MlTrainedModels ); @@ -70,7 +90,7 @@ describe('startMlModelDeployment', () => { .mockReturnValueOnce( Promise.resolve({ deploymentState: MlModelDeploymentState.Downloaded, - modelId: knownModelName, + modelId: productionModelName, nodeAllocationCount: 0, startTime: 123456, targetAllocationCount: 3, @@ -79,7 +99,7 @@ describe('startMlModelDeployment', () => { .mockReturnValueOnce( Promise.resolve({ deploymentState: MlModelDeploymentState.Starting, - modelId: knownModelName, + modelId: productionModelName, nodeAllocationCount: 0, startTime: 123456, targetAllocationCount: 3, @@ -88,7 +108,7 @@ describe('startMlModelDeployment', () => { mockTrainedModelsProvider.startTrainedModelDeployment.mockImplementation(async () => {}); const response = await startMlModelDeployment( - knownModelName, + productionModelName, mockTrainedModelsProvider as unknown as MlTrainedModels ); expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);