Skip to content

Commit

Permalink
[ML] API integration tests for start and stop model deployment (#168460)
Browse files Browse the repository at this point in the history
## Summary

Part of #164562

Adds API integration tests for `_start` and `_stop` trained model
deployment.

### Checklist

- [x] [Unit or functional
tests](https://www.elastic.co/guide/en/kibana/master/development-tests.html)
were updated or added to match the most common scenarios
  • Loading branch information
darnautov authored Oct 10, 2023
1 parent ae07584 commit 9ba0e71
Show file tree
Hide file tree
Showing 4 changed files with 228 additions and 11 deletions.
6 changes: 3 additions & 3 deletions x-pack/plugins/ml/server/routes/schemas/inference_schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ export const modelAndDeploymentIdSchema = schema.object({

export const threadingParamsSchema = schema.maybe(
schema.object({
number_of_allocations: schema.number(),
threads_per_allocation: schema.number(),
priority: schema.oneOf([schema.literal('low'), schema.literal('normal')]),
number_of_allocations: schema.maybe(schema.number()),
threads_per_allocation: schema.maybe(schema.number()),
priority: schema.maybe(schema.oneOf([schema.literal('low'), schema.literal('normal')])),
deployment_id: schema.maybe(schema.string()),
})
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ export default function ({ loadTestFile }: FtrProviderContext) {
loadTestFile(require.resolve('./get_model_stats'));
loadTestFile(require.resolve('./get_model_pipelines'));
loadTestFile(require.resolve('./delete_model'));
loadTestFile(require.resolve('./start_stop_deployment'));
});
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import expect from '@kbn/expect';
import type { MlGetTrainedModelsStatsResponse } from '@elastic/elasticsearch/lib/api/typesWithBodyKey';
import { SUPPORTED_TRAINED_MODELS } from '../../../../functional/services/ml/api';
import { FtrProviderContext } from '../../../ftr_provider_context';
import { USER } from '../../../../functional/services/ml/security_common';
import { getCommonRequestHeader } from '../../../../functional/services/ml/common_api';

export default ({ getService }: FtrProviderContext) => {
const supertest = getService('supertestWithoutAuth');
const ml = getService('ml');

const testModel = {
...SUPPORTED_TRAINED_MODELS.TINY_NER,
id: SUPPORTED_TRAINED_MODELS.TINY_NER.name,
};

const customDeploymentId = 'my_deployment_id';

describe('Start and stop deployment tests', () => {
before(async () => {
await ml.api.importTrainedModel(testModel.id, testModel.name);
await ml.testResources.setKibanaTimeZoneToUTC();

// Make sure the .ml-stats index is created in advance, see https://github.com/elastic/elasticsearch/issues/65846
await ml.api.assureMlStatsIndexExists();
});

after(async () => {
await ml.api.stopAllTrainedModelDeploymentsES();
await ml.api.deleteAllTrainedModelsES();
await ml.api.cleanMlIndices();
await ml.testResources.cleanMLSavedObjects();
});

it('does not allow to start trained model deployment if the user does not have required permissions', async () => {
const { body: startResponseBody, status: startResponseStatus } = await supertest
.post(`/internal/ml/trained_models/${testModel.id}/deployment/_start`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(403, startResponseStatus, startResponseBody);

// verify that model deployment has not been started
const { body: statsResponse, status: statsResponseStatus } = await supertest
.get(`/internal/ml/trained_models/${testModel.id}/_stats`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, statsResponseStatus, statsResponse);

const deploymentStats = (
statsResponse as MlGetTrainedModelsStatsResponse
).trained_model_stats.find((v) => v.deployment_stats?.deployment_id === testModel.id);

expect(deploymentStats).to.be(undefined);
});

it('starts trained model deployment with the default ID', async () => {
const { body: startResponseBody, status: deleteResponseStatus } = await supertest
.post(`/internal/ml/trained_models/${testModel.id}/deployment/_start`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, deleteResponseStatus, startResponseBody);

expect(startResponseBody.assignment.assignment_state).to.eql('started');
expect(startResponseBody.assignment.task_parameters.threads_per_allocation).to.eql(1);
expect(startResponseBody.assignment.task_parameters.priority).to.eql('normal');
expect(startResponseBody.assignment.task_parameters.deployment_id).to.eql(testModel.id);

// check deployment status
const { body: statsResponse, status: statsResponseStatus } = await supertest
.get(`/internal/ml/trained_models/${testModel.id}/_stats`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, statsResponseStatus, statsResponse);

const modelStats = (
statsResponse as MlGetTrainedModelsStatsResponse
).trained_model_stats.find((v) => v.deployment_stats?.deployment_id === testModel.id);

expect(modelStats!.deployment_stats!.allocation_status.state).to.match(
/\bstarted\b|\bfully_allocated\b/
);
});

it('starts trained model deployment with provided deployment ID', async () => {
const { body: startResponseBody, status: deleteResponseStatus } = await supertest
.post(`/internal/ml/trained_models/${testModel.id}/deployment/_start`)
.query({ deployment_id: customDeploymentId })
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, deleteResponseStatus, startResponseBody);

expect(startResponseBody.assignment.assignment_state).to.eql('started');
expect(startResponseBody.assignment.task_parameters.deployment_id).to.eql(customDeploymentId);

// check deployment status
const { body: statsResponse, status: statsResponseStatus } = await supertest
.get(`/internal/ml/trained_models/${testModel.id}/_stats`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, statsResponseStatus, statsResponse);

const modelStats = (
statsResponse as MlGetTrainedModelsStatsResponse
).trained_model_stats.find((v) => v.deployment_stats?.deployment_id === customDeploymentId);

expect(modelStats!.deployment_stats!.allocation_status.state).to.match(
/\bstarted\b|\bfully_allocated\b/
);
});

it('returns 404 if requested trained model does not exist', async () => {
const { body, status } = await supertest
.post(`/internal/ml/trained_models/not_existing_model/deployment/_start`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(404, status, body);
});

it('does not allow to stop trained model deployment if the user does not have required permissions', async () => {
const { body: stopResponseBody, status: stopResponseStatus } = await supertest
.post(`/internal/ml/trained_models/${testModel.id}/${testModel.id}/deployment/_stop`)
.auth(USER.ML_VIEWER, ml.securityCommon.getPasswordForUser(USER.ML_VIEWER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(403, stopResponseStatus, stopResponseBody);

// verify that model deployment has not been started
const { body: statsResponse, status: statsResponseStatus } = await supertest
.get(`/internal/ml/trained_models/${testModel.id}/_stats`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, statsResponseStatus, statsResponse);

const modelStats = (
statsResponse as MlGetTrainedModelsStatsResponse
).trained_model_stats.find((v) => v.deployment_stats?.deployment_id === testModel.id);

expect(modelStats!.deployment_stats!.allocation_status.state).to.match(
/\bstarted\b|\bfully_allocated\b/
);
});

it('stops trained model deployment with the default ID', async () => {
const { body: stopResponseBody, status: stopResponseStatus } = await supertest
.post(`/internal/ml/trained_models/${testModel.id}/${testModel.id}/deployment/_stop`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, stopResponseStatus, stopResponseBody);

expect(stopResponseBody).to.eql({
[testModel.id]: {
success: true,
},
});

// check deployment status
const { body: statsResponse, status: statsResponseStatus } = await supertest
.get(`/internal/ml/trained_models/${testModel.id}/_stats`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, statsResponseStatus, statsResponse);

const deploymentStats = (
statsResponse as MlGetTrainedModelsStatsResponse
).trained_model_stats.find((v) => v.deployment_stats?.deployment_id === testModel.id);

expect(deploymentStats).to.be(undefined);
});

it('stops trained model deployment with provided deployment ID', async () => {
const { body: stopResponseBody, status: stopResponseStatus } = await supertest
.post(`/internal/ml/trained_models/${testModel.id}/${customDeploymentId}/deployment/_stop`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, stopResponseStatus, stopResponseBody);

expect(stopResponseBody).to.eql({
[customDeploymentId]: {
success: true,
},
});

// check deployment status
const { body: statsResponse, status: statsResponseStatus } = await supertest
.get(`/internal/ml/trained_models/${testModel.id}/_stats`)
.auth(USER.ML_POWERUSER, ml.securityCommon.getPasswordForUser(USER.ML_POWERUSER))
.set(getCommonRequestHeader('1'));
ml.api.assertResponseStatusCode(200, statsResponseStatus, statsResponse);

const deploymentStats = (
statsResponse as MlGetTrainedModelsStatsResponse
).trained_model_stats.find((v) => v.deployment_stats?.deployment_id === customDeploymentId);

expect(deploymentStats).to.be(undefined);
});
});
};
29 changes: 21 additions & 8 deletions x-pack/test/functional/services/ml/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,15 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
return body;
},

async getTrainedModelStatsES(): Promise<estypes.MlGetTrainedModelsStatsResponse> {
log.debug(`Getting trained models stats`);
const { body, status } = await esSupertest.get(`/_ml/trained_models/_stats`);
this.assertResponseStatusCode(200, status, body);

log.debug('> Trained model stats fetched');
return body;
},

async deleteTrainedModelES(modelId: string) {
log.debug(`Deleting trained model with id "${modelId}"`);
const { body, status } = await esSupertest
Expand All @@ -1363,10 +1372,10 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
}
},

async stopTrainedModelDeploymentES(modelId: string) {
log.debug(`Stopping trained model deployment with id "${modelId}"`);
async stopTrainedModelDeploymentES(deploymentId: string) {
log.debug(`Stopping trained model deployment with id "${deploymentId}"`);
const { body, status } = await esSupertest.post(
`/_ml/trained_models/${modelId}/deployment/_stop`
`/_ml/trained_models/${deploymentId}/deployment/_stop`
);
this.assertResponseStatusCode(200, status, body);

Expand All @@ -1375,13 +1384,17 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {

async stopAllTrainedModelDeploymentsES() {
log.debug(`Stopping all trained model deployments`);
const getModelsRsp = await this.getTrainedModelsES();
for (const model of getModelsRsp.trained_model_configs) {
if (this.isInternalModelId(model.model_id)) {
log.debug(`> Skipping internal ${model.model_id}`);
const getModelsRsp = await this.getTrainedModelStatsES();
for (const modelStats of getModelsRsp.trained_model_stats) {
if (this.isInternalModelId(modelStats.model_id)) {
log.debug(`> Skipping internal ${modelStats.model_id}`);
continue;
}
if (modelStats.deployment_stats === undefined) {
log.debug(`> Skipping, no deployment stats for ${modelStats.model_id} found`);
continue;
}
await this.stopTrainedModelDeploymentES(model.model_id);
await this.stopTrainedModelDeploymentES(modelStats.deployment_stats.deployment_id);
}
},

Expand Down

0 comments on commit 9ba0e71

Please sign in to comment.