diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx index 65b8183b60a0b..2f46e99d12b07 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.test.tsx @@ -126,4 +126,88 @@ describe('fetchConnectorExecuteAction', () => { expect(result).toBe('Test response'); }); + + it('returns the value of the action_input property when assistantLangChain is true, and `content` has properly prefixed and suffixed JSON with the action_input property', async () => { + const content = '```json\n{"action_input": "value from action_input"}\n```'; + + (mockHttp.fetch as jest.Mock).mockResolvedValue({ + status: 'ok', + data: { + choices: [ + { + message: { + content, + }, + }, + ], + }, + }); + + const testProps: FetchConnectorExecuteAction = { + assistantLangChain: true, // <-- requires response parsing + http: mockHttp, + messages, + apiConfig, + }; + + const result = await fetchConnectorExecuteAction(testProps); + + expect(result).toBe('value from action_input'); + }); + + it('returns the original content when assistantLangChain is true, and `content` has properly formatted JSON WITHOUT the action_input property', async () => { + const content = '```json\n{"some_key": "some value"}\n```'; + + (mockHttp.fetch as jest.Mock).mockResolvedValue({ + status: 'ok', + data: { + choices: [ + { + message: { + content, + }, + }, + ], + }, + }); + + const testProps: FetchConnectorExecuteAction = { + assistantLangChain: true, // <-- requires response parsing + http: mockHttp, + messages, + apiConfig, + }; + + const result = await fetchConnectorExecuteAction(testProps); + + expect(result).toBe(content); + }); + + it('returns the original when assistantLangChain is true, and `content` is not JSON', async () => { + const content = 'plain text content'; + + (mockHttp.fetch as jest.Mock).mockResolvedValue({ + status: 'ok', + data: { + choices: [ + { + message: { + content, + }, + }, + ], + }, + }); + + const testProps: FetchConnectorExecuteAction = { + assistantLangChain: true, // <-- requires response parsing + http: mockHttp, + messages, + apiConfig, + }; + + const result = await fetchConnectorExecuteAction(testProps); + + expect(result).toBe(content); + }); }); diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx index 511b5aa585af0..6d3452b6f7880 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/api.tsx @@ -12,6 +12,7 @@ import { HttpSetup, IHttpFetchError } from '@kbn/core-http-browser'; import type { Conversation, Message } from '../assistant_context/types'; import { API_ERROR } from './translations'; import { MODEL_GPT_3_5_TURBO } from '../connectorland/models/model_selector/model_selector'; +import { getFormattedMessageContent } from './helpers'; export interface FetchConnectorExecuteAction { assistantLangChain: boolean; @@ -78,7 +79,8 @@ export const fetchConnectorExecuteAction = async ({ if (data.choices && data.choices.length > 0 && data.choices[0].message.content) { const result = data.choices[0].message.content.trim(); - return result; + + return assistantLangChain ? getFormattedMessageContent(result) : result; } else { return API_ERROR; } diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.test.ts b/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.test.ts index 69bed887e730e..f2b89a07c319e 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.test.ts +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.test.ts @@ -5,7 +5,11 @@ * 2.0. */ -import { getDefaultConnector, getBlockBotConversation } from './helpers'; +import { + getBlockBotConversation, + getDefaultConnector, + getFormattedMessageContent, +} from './helpers'; import { enterpriseMessaging } from './use_conversation/sample_conversations'; import { ActionConnector } from '@kbn/triggers-actions-ui-plugin/public'; @@ -190,4 +194,41 @@ describe('getBlockBotConversation', () => { expect(result).toBeUndefined(); }); }); + + describe('getFormattedMessageContent', () => { + it('returns the value of the action_input property when `content` has properly prefixed and suffixed JSON with the action_input property', () => { + const content = '```json\n{"action_input": "value from action_input"}\n```'; + + expect(getFormattedMessageContent(content)).toBe('value from action_input'); + }); + + it('returns the original content when `content` has properly formatted JSON WITHOUT the action_input property', () => { + const content = '```json\n{"some_key": "some value"}\n```'; + expect(getFormattedMessageContent(content)).toBe(content); + }); + + it('returns the original content when `content` has improperly formatted JSON', () => { + const content = '```json\n{"action_input": "value from action_input",}\n```'; // <-- the trailing comma makes it invalid + + expect(getFormattedMessageContent(content)).toBe(content); + }); + + it('returns the original content when `content` is missing the prefix', () => { + const content = '{"action_input": "value from action_input"}\n```'; // <-- missing prefix + + expect(getFormattedMessageContent(content)).toBe(content); + }); + + it('returns the original content when `content` is missing the suffix', () => { + const content = '```json\n{"action_input": "value from action_input"}'; // <-- missing suffix + + expect(getFormattedMessageContent(content)).toBe(content); + }); + + it('returns the original content when `content` does NOT contain a JSON string', () => { + const content = 'plain text content'; + + expect(getFormattedMessageContent(content)).toBe(content); + }); + }); }); diff --git a/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.ts b/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.ts index b01c9001e8319..2b2c5b76851f7 100644 --- a/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.ts +++ b/x-pack/packages/kbn-elastic-assistant/impl/assistant/helpers.ts @@ -59,3 +59,24 @@ export const getDefaultConnector = ( connectors: Array, Record>> | undefined ): ActionConnector, Record> | undefined => connectors?.length === 1 ? connectors[0] : undefined; + +/** + * When `content` is a JSON string, prefixed with "```json\n" + * and suffixed with "\n```", this function will attempt to parse it and return + * the `action_input` property if it exists. + */ +export const getFormattedMessageContent = (content: string): string => { + const formattedContentMatch = content.match(/```json\n([\s\S]+)\n```/); + + if (formattedContentMatch) { + try { + const parsedContent = JSON.parse(formattedContentMatch[1]); + + return parsedContent.action_input ?? content; + } catch { + // we don't want to throw an error here, so we'll fall back to the original content + } + } + + return content; +}; diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts index be1adbc2e1ce4..67fb3859b9943 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.test.ts @@ -12,7 +12,7 @@ import { ResponseBody } from '../helpers'; import { ActionsClientLlm } from '../llm/actions_client_llm'; import { mockActionResultData } from '../../../__mocks__/action_result_data'; import { langChainMessages } from '../../../__mocks__/lang_chain_messages'; -import { executeCustomLlmChain } from '.'; +import { callAgentExecutor } from '.'; import { loggerMock } from '@kbn/logging-mocks'; import { elasticsearchServiceMock } from '@kbn/core-elasticsearch-server-mocks'; @@ -23,11 +23,18 @@ const mockConversationChain = { }; jest.mock('langchain/chains', () => ({ - ConversationalRetrievalQAChain: { + RetrievalQAChain: { fromLLM: jest.fn().mockImplementation(() => mockConversationChain), }, })); +const mockCall = jest.fn(); +jest.mock('langchain/agents', () => ({ + initializeAgentExecutorWithOptions: jest.fn().mockImplementation(() => ({ + call: mockCall, + })), +})); + const mockConnectorId = 'mock-connector-id'; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -42,7 +49,7 @@ const mockActions: ActionsPluginStart = {} as ActionsPluginStart; const mockLogger = loggerMock.create(); const esClientMock = elasticsearchServiceMock.createScopedClusterClient().asCurrentUser; -describe('executeCustomLlmChain', () => { +describe('callAgentExecutor', () => { beforeEach(() => { jest.clearAllMocks(); @@ -52,7 +59,7 @@ describe('executeCustomLlmChain', () => { }); it('creates an instance of ActionsClientLlm with the expected context from the request', async () => { - await executeCustomLlmChain({ + await callAgentExecutor({ actions: mockActions, connectorId: mockConnectorId, esClient: esClientMock, @@ -70,7 +77,7 @@ describe('executeCustomLlmChain', () => { }); it('kicks off the chain with (only) the last message', async () => { - await executeCustomLlmChain({ + await callAgentExecutor({ actions: mockActions, connectorId: mockConnectorId, esClient: esClientMock, @@ -79,15 +86,15 @@ describe('executeCustomLlmChain', () => { request: mockRequest, }); - expect(mockConversationChain.call).toHaveBeenCalledWith({ - question: '\n\nDo you know my name?', + expect(mockCall).toHaveBeenCalledWith({ + input: '\n\nDo you know my name?', }); }); it('kicks off the chain with the expected message when langChainMessages has only one entry', async () => { const onlyOneMessage = [langChainMessages[0]]; - await executeCustomLlmChain({ + await callAgentExecutor({ actions: mockActions, connectorId: mockConnectorId, esClient: esClientMock, @@ -96,13 +103,13 @@ describe('executeCustomLlmChain', () => { request: mockRequest, }); - expect(mockConversationChain.call).toHaveBeenCalledWith({ - question: 'What is my name?', + expect(mockCall).toHaveBeenCalledWith({ + input: 'What is my name?', }); }); it('returns the expected response body', async () => { - const result: ResponseBody = await executeCustomLlmChain({ + const result: ResponseBody = await callAgentExecutor({ actions: mockActions, connectorId: mockConnectorId, esClient: esClientMock, diff --git a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts index 5a65b1589b21e..b6a768ad69598 100644 --- a/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts +++ b/x-pack/plugins/elastic_assistant/server/lib/langchain/execute_custom_llm_chain/index.ts @@ -7,16 +7,18 @@ import { ElasticsearchClient, KibanaRequest, Logger } from '@kbn/core/server'; import type { PluginStartContract as ActionsPluginStart } from '@kbn/actions-plugin/server'; +import { initializeAgentExecutorWithOptions } from 'langchain/agents'; +import { RetrievalQAChain } from 'langchain/chains'; import { BufferMemory, ChatMessageHistory } from 'langchain/memory'; import { BaseMessage } from 'langchain/schema'; +import { ChainTool, Tool } from 'langchain/tools'; -import { ConversationalRetrievalQAChain } from 'langchain/chains'; +import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; import { ResponseBody } from '../helpers'; import { ActionsClientLlm } from '../llm/actions_client_llm'; -import { ElasticsearchStore } from '../elasticsearch_store/elasticsearch_store'; import { KNOWLEDGE_BASE_INDEX_PATTERN } from '../../../routes/knowledge_base/constants'; -export const executeCustomLlmChain = async ({ +export const callAgentExecutor = async ({ actions, connectorId, esClient, @@ -34,31 +36,38 @@ export const executeCustomLlmChain = async ({ }): Promise => { const llm = new ActionsClientLlm({ actions, connectorId, request, logger }); - // Chat History Memory: in-memory memory, from client local storage, first message is the system prompt const pastMessages = langChainMessages.slice(0, -1); // all but the last message const latestMessage = langChainMessages.slice(-1); // the last message + const memory = new BufferMemory({ chatHistory: new ChatMessageHistory(pastMessages), - memoryKey: 'chat_history', + memoryKey: 'chat_history', // this is the key expected by https://github.com/langchain-ai/langchainjs/blob/a13a8969345b0f149c1ca4a120d63508b06c52a5/langchain/src/agents/initialize.ts#L166 + inputKey: 'input', + outputKey: 'output', + returnMessages: true, }); // ELSER backed ElasticsearchStore for Knowledge Base const esStore = new ElasticsearchStore(esClient, KNOWLEDGE_BASE_INDEX_PATTERN, logger); + const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever()); + + const tools: Tool[] = [ + new ChainTool({ + name: 'esql-language-knowledge-base', + description: + 'Call this for knowledge on how to build an ESQL query, or answer questions about the ES|QL query language.', + chain, + }), + ]; - // Chain w/ chat history memory and knowledge base retriever - const chain = ConversationalRetrievalQAChain.fromLLM(llm, esStore.asRetriever(), { + const executor = await initializeAgentExecutorWithOptions(tools, llm, { + agentType: 'chat-conversational-react-description', memory, - // See `qaChainOptions` from https://js.langchain.com/docs/modules/chains/popular/chat_vector_db - qaChainOptions: { type: 'stuff' }, + verbose: false, }); - await chain.call({ question: latestMessage[0].content }); - // Chain w/ just knowledge base retriever - // const chain = RetrievalQAChain.fromLLM(llm, esStore.asRetriever()); - // await chain.call({ query: latestMessage[0].content }); + await executor.call({ input: latestMessage[0].content }); - // The assistant (on the client side) expects the same response returned - // from the actions framework, so we need to return the same shape of data: return { connector_id: connectorId, data: llm.getActionResultData(), // the response from the actions framework diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts index 2e6709a6e33c2..57f2b25f5a65f 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.test.ts @@ -20,7 +20,7 @@ jest.mock('../lib/build_response', () => ({ })); jest.mock('../lib/langchain/execute_custom_llm_chain', () => ({ - executeCustomLlmChain: jest.fn().mockImplementation( + callAgentExecutor: jest.fn().mockImplementation( async ({ connectorId, }: { diff --git a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts index 1043f68f0f9c1..bbb1c76e3e579 100644 --- a/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts +++ b/x-pack/plugins/elastic_assistant/server/routes/post_actions_connector_execute.ts @@ -20,7 +20,7 @@ import { PostActionsConnectorExecutePathParams, } from '../schemas/post_actions_connector_execute'; import { ElasticAssistantRequestHandlerContext } from '../types'; -import { executeCustomLlmChain } from '../lib/langchain/execute_custom_llm_chain'; +import { callAgentExecutor } from '../lib/langchain/execute_custom_llm_chain'; export const postActionsConnectorExecuteRoute = ( router: IRouter @@ -53,7 +53,7 @@ export const postActionsConnectorExecuteRoute = ( // convert the assistant messages to LangChain messages: const langChainMessages = getLangChainMessages(assistantMessages); - const langChainResponseBody = await executeCustomLlmChain({ + const langChainResponseBody = await callAgentExecutor({ actions, connectorId, esClient,