From 77f9279e5bac2cc763594c61d9c9cea1d65d8d70 Mon Sep 17 00:00:00 2001 From: Kibana Machine <42973632+kibanamachine@users.noreply.github.com> Date: Thu, 27 Apr 2023 14:33:32 -0400 Subject: [PATCH] [8.8] [Enterprise Search] Add Production Model Name to Allowed List for ELSER API Endpoints (#156020) (#156062) # Backport This will backport the following commits from `main` to `8.8`: - [[Enterprise Search] Add Production Model Name to Allowed List for ELSER API Endpoints (#156020)](https://github.com/elastic/kibana/pull/156020) ### Questions ? Please refer to the [Backport tool documentation](https://github.com/sqren/backport) Co-authored-by: Mark J. Hoy --- .../lib/ml/ml_model_deployment_common.ts | 2 +- .../lib/ml/start_ml_model_deployment.test.ts | 34 +++++++++++++++---- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts b/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts index f97362725bac5..9465a94301443 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/ml_model_deployment_common.ts @@ -11,7 +11,7 @@ import { isResourceNotFoundException, } from '../../utils/identify_exceptions'; -export const acceptableModelNames = ['.elser_model_1_SNAPSHOT']; +export const acceptableModelNames = ['.elser_model_1', '.elser_model_1_SNAPSHOT']; export function isNotFoundExceptionError(error: unknown): boolean { return ( diff --git a/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts b/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts index 827f41b83f575..ae11a89ed5ac0 100644 --- a/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts +++ b/x-pack/plugins/enterprise_search/server/lib/ml/start_ml_model_deployment.test.ts @@ -15,7 +15,8 @@ import * as mockGetStatus from './get_ml_model_deployment_status'; import { startMlModelDeployment } from './start_ml_model_deployment'; describe('startMlModelDeployment', () => { - const knownModelName = '.elser_model_1_SNAPSHOT'; + const productionModelName = '.elser_model_1'; + const snapshotModelName = '.elser_model_1_SNAPSHOT'; const mockTrainedModelsProvider = { getTrainedModels: jest.fn(), getTrainedModelsStats: jest.fn(), @@ -27,7 +28,7 @@ describe('startMlModelDeployment', () => { }); it('should error when there is no trained model provider', () => { - expect(() => startMlModelDeployment(knownModelName, undefined)).rejects.toThrowError( + expect(() => startMlModelDeployment(productionModelName, undefined)).rejects.toThrowError( 'Machine Learning is not enabled' ); }); @@ -49,7 +50,7 @@ describe('startMlModelDeployment', () => { jest.spyOn(mockGetStatus, 'getMlModelDeploymentStatus').mockReturnValueOnce( Promise.resolve({ deploymentState: MlModelDeploymentState.Starting, - modelId: knownModelName, + modelId: productionModelName, nodeAllocationCount: 0, startTime: 123456, targetAllocationCount: 3, @@ -57,7 +58,26 @@ describe('startMlModelDeployment', () => { ); const response = await startMlModelDeployment( - knownModelName, + productionModelName, + mockTrainedModelsProvider as unknown as MlTrainedModels + ); + + expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting); + }); + + it('should return the deployment state if not "downloaded" for snapshot model', async () => { + jest.spyOn(mockGetStatus, 'getMlModelDeploymentStatus').mockReturnValueOnce( + Promise.resolve({ + deploymentState: MlModelDeploymentState.Starting, + modelId: snapshotModelName, + nodeAllocationCount: 0, + startTime: 123456, + targetAllocationCount: 3, + }) + ); + + const response = await startMlModelDeployment( + snapshotModelName, mockTrainedModelsProvider as unknown as MlTrainedModels ); @@ -70,7 +90,7 @@ describe('startMlModelDeployment', () => { .mockReturnValueOnce( Promise.resolve({ deploymentState: MlModelDeploymentState.Downloaded, - modelId: knownModelName, + modelId: productionModelName, nodeAllocationCount: 0, startTime: 123456, targetAllocationCount: 3, @@ -79,7 +99,7 @@ describe('startMlModelDeployment', () => { .mockReturnValueOnce( Promise.resolve({ deploymentState: MlModelDeploymentState.Starting, - modelId: knownModelName, + modelId: productionModelName, nodeAllocationCount: 0, startTime: 123456, targetAllocationCount: 3, @@ -88,7 +108,7 @@ describe('startMlModelDeployment', () => { mockTrainedModelsProvider.startTrainedModelDeployment.mockImplementation(async () => {}); const response = await startMlModelDeployment( - knownModelName, + productionModelName, mockTrainedModelsProvider as unknown as MlTrainedModels ); expect(response.deploymentState).toEqual(MlModelDeploymentState.Starting);