Skip to content

Commit

Permalink
[Obs AI Assistant] knowledge base integration tests (#189000)
Browse files Browse the repository at this point in the history
Closes #188999

- integration tests for knowledge base api
- adds new config field `modelId`, for internal use, to override elser
model id
- refactors `knowledgeBaseService.setup()` to fix bug where if the model
failed to install when calling ml.putTrainedModel, we dont get stuck
polling and retrying the install. We were assuming that the first error
that gets throw when the model is exists would only happen once and the
return true or false and poll for whether its done installing. But the
installation could fail itself causing getTrainedModelsStats to
continuously throw and try to install the model. Now user immediately
gets error if model fails to install and polling does not happen.

---------

Co-authored-by: James Gowdy <[email protected]>
  • Loading branch information
neptunian and jgowdyelastic authored Aug 5, 2024
1 parent d79bdfd commit f18224c
Show file tree
Hide file tree
Showing 10 changed files with 470 additions and 274 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import { schema, type TypeOf } from '@kbn/config-schema';

export const config = schema.object({
enabled: schema.boolean({ defaultValue: true }),
modelId: schema.maybe(schema.string()),
});

export type ObservabilityAIAssistantConfig = TypeOf<typeof config>;
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,12 @@ export class ObservabilityAIAssistantPlugin
>
{
logger: Logger;
config: ObservabilityAIAssistantConfig;
service: ObservabilityAIAssistantService | undefined;

constructor(context: PluginInitializerContext<ObservabilityAIAssistantConfig>) {
this.logger = context.logger.get();
this.config = context.config.get<ObservabilityAIAssistantConfig>();
initLangtrace();
}
public setup(
Expand Down Expand Up @@ -112,10 +114,14 @@ export class ObservabilityAIAssistantPlugin

// Using once to make sure the same model ID is used during service init and Knowledge base setup
const getModelId = once(async () => {
const configModelId = this.config.modelId;
if (configModelId) {
return configModelId;
}
const defaultModelId = '.elser_model_2';
const [_, pluginsStart] = await core.getStartServices();
// Wait for the license to be available so the ML plugin's guards pass once we ask for ELSER stats
const license = await firstValueFrom(pluginsStart.licensing.license$);

if (!license.hasAtLeast('enterprise')) {
return defaultModelId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
* 2.0.
*/
import { errors } from '@elastic/elasticsearch';
import { serverUnavailable, gatewayTimeout } from '@hapi/boom';
import { serverUnavailable, gatewayTimeout, badRequest } from '@hapi/boom';
import type { ElasticsearchClient, IUiSettingsClient } from '@kbn/core/server';
import type { Logger } from '@kbn/logging';
import type { TaskManagerStartContract } from '@kbn/task-manager-plugin/server';
Expand Down Expand Up @@ -39,14 +39,20 @@ export interface RecalledEntry {
labels?: Record<string, string>;
}

function isAlreadyExistsError(error: Error) {
function isModelMissingOrUnavailableError(error: Error) {
return (
error instanceof errors.ResponseError &&
(error.body.error.type === 'resource_not_found_exception' ||
error.body.error.type === 'status_exception')
);
}

function isCreateModelValidationError(error: Error) {
return (
error instanceof errors.ResponseError &&
error.statusCode === 400 &&
error.body?.error?.type === 'action_request_validation_exception'
);
}
function throwKnowledgeBaseNotReady(body: any) {
throw serverUnavailable(`Knowledge base is not ready yet`, body);
}
Expand Down Expand Up @@ -84,52 +90,73 @@ export class KnowledgeBaseService {
const elserModelId = await this.dependencies.getModelId();

const retryOptions = { factor: 1, minTimeout: 10000, retries: 12 };

const installModel = async () => {
this.dependencies.logger.info('Installing ELSER model');
await this.dependencies.esClient.asInternalUser.ml.putTrainedModel(
{
model_id: elserModelId,
input: {
field_names: ['text_field'],
},
wait_for_completion: true,
},
{ requestTimeout: '20m' }
);
this.dependencies.logger.info('Finished installing ELSER model');
};

const getIsModelInstalled = async () => {
const getResponse = await this.dependencies.esClient.asInternalUser.ml.getTrainedModels({
const getModelInfo = async () => {
return await this.dependencies.esClient.asInternalUser.ml.getTrainedModels({
model_id: elserModelId,
include: 'definition_status',
});
};

this.dependencies.logger.debug(
() => 'Model definition status:\n' + JSON.stringify(getResponse.trained_model_configs[0])
);
const isModelInstalledAndReady = async () => {
try {
const getResponse = await getModelInfo();
this.dependencies.logger.debug(
() => 'Model definition status:\n' + JSON.stringify(getResponse.trained_model_configs[0])
);

return Boolean(getResponse.trained_model_configs[0]?.fully_defined);
return Boolean(getResponse.trained_model_configs[0]?.fully_defined);
} catch (error) {
if (isModelMissingOrUnavailableError(error)) {
return false;
} else {
throw error;
}
}
};

await pRetry(async () => {
let isModelInstalled: boolean = false;
const installModelIfDoesNotExist = async () => {
const modelInstalledAndReady = await isModelInstalledAndReady();
if (!modelInstalledAndReady) {
await installModel();
}
};

const installModel = async () => {
this.dependencies.logger.info('Installing ELSER model');
try {
isModelInstalled = await getIsModelInstalled();
await this.dependencies.esClient.asInternalUser.ml.putTrainedModel(
{
model_id: elserModelId,
input: {
field_names: ['text_field'],
},
wait_for_completion: true,
},
{ requestTimeout: '20m' }
);
} catch (error) {
if (isAlreadyExistsError(error)) {
await installModel();
isModelInstalled = await getIsModelInstalled();
if (isCreateModelValidationError(error)) {
throw badRequest(error);
} else {
throw error;
}
}
this.dependencies.logger.info('Finished installing ELSER model');
};

if (!isModelInstalled) {
throwKnowledgeBaseNotReady({
message: 'Model is not fully defined',
});
}
}, retryOptions);
const pollForModelInstallCompleted = async () => {
await pRetry(async () => {
this.dependencies.logger.info('Polling installation of ELSER model');
const modelInstalledAndReady = await isModelInstalledAndReady();
if (!modelInstalledAndReady) {
throwKnowledgeBaseNotReady({
message: 'Model is not fully defined',
});
}
}, retryOptions);
};
await installModelIfDoesNotExist();
await pollForModelInstallCompleted();

try {
await this.dependencies.esClient.asInternalUser.ml.startTrainedModelDeployment({
Expand All @@ -139,7 +166,7 @@ export class KnowledgeBaseService {
} catch (error) {
this.dependencies.logger.debug('Error starting model deployment');
this.dependencies.logger.debug(error);
if (!isAlreadyExistsError(error)) {
if (!isModelMissingOrUnavailableError(error)) {
throw error;
}
}
Expand Down Expand Up @@ -380,7 +407,7 @@ export class KnowledgeBaseService {
namespace,
modelId,
}).catch((error) => {
if (isAlreadyExistsError(error)) {
if (isModelMissingOrUnavailableError(error)) {
throwKnowledgeBaseNotReady(error.body);
}
throw error;
Expand Down Expand Up @@ -521,7 +548,7 @@ export class KnowledgeBaseService {
})),
};
} catch (error) {
if (isAlreadyExistsError(error)) {
if (isModelMissingOrUnavailableError(error)) {
throwKnowledgeBaseNotReady(error.body);
}
throw error;
Expand Down Expand Up @@ -588,7 +615,7 @@ export class KnowledgeBaseService {

return Promise.resolve();
} catch (error) {
if (isAlreadyExistsError(error)) {
if (isModelMissingOrUnavailableError(error)) {
throwKnowledgeBaseNotReady(error.body);
}
throw error;
Expand Down
19 changes: 13 additions & 6 deletions x-pack/test/functional/services/ml/api.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1473,11 +1473,13 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
}
},

async stopTrainedModelDeploymentES(deploymentId: string) {
async stopTrainedModelDeploymentES(deploymentId: string, force: boolean = false) {
log.debug(`Stopping trained model deployment with id "${deploymentId}"`);
const { body, status } = await esSupertest.post(
`/_ml/trained_models/${deploymentId}/deployment/_stop`
);
const url = `/_ml/trained_models/${deploymentId}/deployment/_stop${
force ? '?force=true' : ''
}`;

const { body, status } = await esSupertest.post(url);
this.assertResponseStatusCode(200, status, body);

log.debug('> Trained model deployment stopped');
Expand Down Expand Up @@ -1570,8 +1572,13 @@ export function MachineLearningAPIProvider({ getService }: FtrProviderContext) {
);
},

async importTrainedModel(modelId: string, modelName: SupportedTrainedModelNamesType) {
await this.createTrainedModel(modelId, this.getTrainedModelConfig(modelName));
async importTrainedModel(
modelId: string,
modelName: SupportedTrainedModelNamesType,
config?: PutTrainedModelConfig
) {
const trainedModelConfig = config ?? this.getTrainedModelConfig(modelName);
await this.createTrainedModel(modelId, trainedModelConfig);
await this.createTrainedModelVocabularyES(modelId, this.getTrainedModelVocabulary(modelName));
await this.uploadTrainedModelDefinitionES(
modelId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import { mapValues } from 'lodash';
import path from 'path';
import { createTestConfig, CreateTestConfig } from '../common/config';
import { SUPPORTED_TRAINED_MODELS } from '../../functional/services/ml/api';

export const observabilityAIAssistantDebugLogger = {
name: 'plugins.observabilityAIAssistant',
Expand All @@ -30,6 +31,7 @@ export const observabilityAIAssistantFtrConfigs = {
__dirname,
'../../../../test/analytics/plugins/analytics_ftr_helpers'
),
'xpack.observabilityAIAssistant.modelId': SUPPORTED_TRAINED_MODELS.TINY_ELSER.name,
},
},
};
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

import { MachineLearningProvider } from '../../../api_integration/services/ml';
import { SUPPORTED_TRAINED_MODELS } from '../../../functional/services/ml/api';

export const TINY_ELSER = {
...SUPPORTED_TRAINED_MODELS.TINY_ELSER,
id: SUPPORTED_TRAINED_MODELS.TINY_ELSER.name,
};

export async function createKnowledgeBaseModel(ml: ReturnType<typeof MachineLearningProvider>) {
const config = {
...ml.api.getTrainedModelConfig(TINY_ELSER.name),
input: {
field_names: ['text_field'],
},
};
await ml.api.importTrainedModel(TINY_ELSER.name, TINY_ELSER.id, config);
await ml.api.assureMlStatsIndexExists();
}
export async function deleteKnowledgeBaseModel(ml: ReturnType<typeof MachineLearningProvider>) {
await ml.api.stopTrainedModelDeploymentES(TINY_ELSER.id, true);
await ml.api.deleteTrainedModelES(TINY_ELSER.id);
await ml.api.cleanMlIndices();
await ml.testResources.cleanMLSavedObjects();
}
Loading

0 comments on commit f18224c

Please sign in to comment.