Skip to content

Commit

Permalink
[Security Solution][Elastic AI Assistant] Fixes get Knowledge Base st…
Browse files Browse the repository at this point in the history
…atus not using the default model (elastic#168040)

## Summary

After the introduction of elastic#167522
which fetches the default ELSER model, we weren't instantiating the
`esStore` within the get Knowledge Base status route with the default
model, so it was falling back to `.elser_model_1` and failing to report
the correct status in the UI (even though all documents were loaded).

Also updated the evaluation endpoint to use the `getElser` default model
when instantiating agents to evaluate.
  • Loading branch information
spong authored Oct 4, 2023
1 parent fc434d1 commit 9926628
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ describe('ElasticsearchStore', () => {
inference: {
field_map: { text: 'text_field' },
inference_config: { text_expansion: { results_field: 'tokens' } },
model_id: '.elser_model_1',
model_id: '.elser_model_2',
target_field: 'vector',
},
},
Expand All @@ -130,12 +130,12 @@ describe('ElasticsearchStore', () => {
trained_model_configs: [{ fully_defined: true }],
} as MlGetTrainedModelsResponse);

const isInstalled = await esStore.isModelInstalled('.elser_model_1');
const isInstalled = await esStore.isModelInstalled('.elser_model_2');

expect(isInstalled).toBe(true);
expect(mockEsClient.ml.getTrainedModels).toHaveBeenCalledWith({
include: 'definition_status',
model_id: '.elser_model_1',
model_id: '.elser_model_2',
});
});
});
Expand Down Expand Up @@ -217,7 +217,7 @@ describe('ElasticsearchStore', () => {
},
vector: {
tokens: {},
model_id: '.elser_model_1',
model_id: '.elser_model_2',
},
text: 'documents',
},
Expand All @@ -242,7 +242,7 @@ describe('ElasticsearchStore', () => {
{
text_expansion: {
'vector.tokens': {
model_id: '.elser_model_1',
model_id: '.elser_model_2',
model_text: query,
},
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ export class ElasticsearchStore extends VectorStore {
this.esClient = esClient;
this.index = index ?? KNOWLEDGE_BASE_INDEX_PATTERN;
this.logger = logger;
this.model = model ?? '.elser_model_1';
this.model = model ?? '.elser_model_2';
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export class ElasticsearchEmbeddings extends Embeddings {
/**
* TODO: Use inference API if not re-indexing to create embedding vectors, e.g.
*
* POST _ml/trained_models/.elser_model_1/_infer
* POST _ml/trained_models/.elser_model_2/_infer
* {
* "docs":[{"text_field": "The fool doth think he is wise, but the wise man knows himself to be a fool."}]
* }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ export const callOpenAIFunctionsExecutor = async ({
actions,
connectorId,
esClient,
elserId,
langChainMessages,
llmType,
logger,
Expand All @@ -44,7 +45,7 @@ export const callOpenAIFunctionsExecutor = async ({
});

// ELSER backed ElasticsearchStore for Knowledge Base
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger);
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger, elserId);
const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever());

const tools: Tool[] = [
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugins/elastic_assistant/server/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ export class ElasticAssistantPlugin
// Actions Connector Execute (LLM Wrapper)
postActionsConnectorExecuteRoute(router, getElserId);
// Evaluate
postEvaluateRoute(router);
postEvaluateRoute(router, getElserId);
return {
actions: plugins.actions,
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import { v4 as uuidv4 } from 'uuid';
import { buildResponse } from '../../lib/build_response';
import { buildRouteValidation } from '../../schemas/common';

import { ElasticAssistantRequestHandlerContext } from '../../types';
import { ElasticAssistantRequestHandlerContext, GetElser } from '../../types';
import { EVALUATE } from '../../../common/constants';
import { PostEvaluateBody, PostEvaluatePathQuery } from '../../schemas/evaluate/post_evaluate';
import { performEvaluation } from '../../lib/model_evaluator/evaluation';
Expand All @@ -36,7 +36,10 @@ const AGENT_EXECUTOR_MAP: Record<string, AgentExecutor> = {
OpenAIFunctionsExecutor: callOpenAIFunctionsExecutor,
};

export const postEvaluateRoute = (router: IRouter<ElasticAssistantRequestHandlerContext>) => {
export const postEvaluateRoute = (
router: IRouter<ElasticAssistantRequestHandlerContext>,
getElser: GetElser
) => {
router.post(
{
path: EVALUATE,
Expand Down Expand Up @@ -89,6 +92,9 @@ export const postEvaluateRoute = (router: IRouter<ElasticAssistantRequestHandler
// writing results to the output index
const esClient = (await context.core).elasticsearch.client.asCurrentUser;

// Default ELSER model
const elserId = await getElser(request, (await context.core).savedObjects.getClient());

// Skeleton request to satisfy `subActionParams` spread in `ActionsClientLlm`
const skeletonRequest: KibanaRequest<unknown, unknown, RequestBody> = {
...request,
Expand All @@ -115,6 +121,7 @@ export const postEvaluateRoute = (router: IRouter<ElasticAssistantRequestHandler
actions,
connectorId,
esClient,
elserId,
langChainMessages,
llmType,
logger,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,16 @@ export const getKnowledgeBaseStatusRoute = (

// Get a scoped esClient for finding the status of the Knowledge Base index, pipeline, and documents
const esClient = (await context.core).elasticsearch.client.asCurrentUser;
const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger);
const elserId = await getElser(request, (await context.core).savedObjects.getClient());
const esStore = new ElasticsearchStore(
esClient,
KNOWLEDGE_BASE_INDEX_PATTERN,
logger,
elserId
);

const indexExists = await esStore.indexExists();
const pipelineExists = await esStore.pipelineExists();

const elserId = await getElser(request, (await context.core).savedObjects.getClient());
const modelExists = await esStore.isModelInstalled(elserId);

const body: GetKnowledgeBaseStatusResponse = {
Expand Down

0 comments on commit 9926628

Please sign in to comment.