From 6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076 Mon Sep 17 00:00:00 2001 From: Steph Milovic Date: Thu, 10 Oct 2024 15:59:10 -0600 Subject: [PATCH] [Security GenAI] Fix `VertexChatAI` tool calling (#195689) --- .../chat_vertex/chat_vertex.test.ts | 33 ++++++++++++++++++- .../language_models/chat_vertex/connection.ts | 17 ++++++++++ 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts index 37506922ff69b..07fe252bd5074 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts @@ -12,6 +12,7 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages'; import { ActionsClientChatVertexAI } from './chat_vertex'; import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager'; +import { GeminiContent } from '@langchain/google-common'; const connectorId = 'mock-connector-id'; @@ -54,8 +55,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => { }; }); +const systemInstruction = 'Answer the following questions truthfully and as best you can.'; + const callMessages = [ - new SystemMessage('Answer the following questions truthfully and as best you can.'), + new SystemMessage(systemInstruction), new HumanMessage('Question: Do you know my name?\n\n'), ] as unknown as BaseMessage[]; @@ -196,4 +199,32 @@ describe('ActionsClientChatVertexAI', () => { expect(handleLLMNewToken).toHaveBeenCalledWith('token3'); }); }); + + describe('message formatting', () => { + it('Properly sorts out the system role', async () => { + const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs); + + await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager); + const params = actionsClient.execute.mock.calls[0][0].params.subActionParams as unknown as { + messages: GeminiContent[]; + systemInstruction: string; + }; + expect(params.messages.length).toEqual(1); + expect(params.messages[0].parts.length).toEqual(1); + expect(params.systemInstruction).toEqual(systemInstruction); + }); + it('Handles 2 messages in a row from the same role', async () => { + const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs); + + await actionsClientChatVertexAI._generate( + [...callMessages, new HumanMessage('Oh boy, another')], + callOptions, + callRunManager + ); + const { messages } = actionsClient.execute.mock.calls[0][0].params + .subActionParams as unknown as { messages: GeminiContent[] }; + expect(messages.length).toEqual(1); + expect(messages[0].parts.length).toEqual(2); + }); + }); }); diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts index 0340d71b438db..dd3c1e1abdda0 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts @@ -7,6 +7,7 @@ import { ChatConnection, + GeminiContent, GoogleAbstractedClient, GoogleAIBaseLLMInput, GoogleLLMResponse, @@ -39,6 +40,22 @@ export class ActionsClientChatConnection extends ChatConnection { this.caller = caller; this.#model = fields.model; this.temperature = fields.temperature ?? 0; + const nativeFormatData = this.formatData.bind(this); + this.formatData = async (data, options) => { + const result = await nativeFormatData(data, options); + if (result?.contents != null && result?.contents.length) { + // ensure there are not 2 messages in a row from the same role, + // if there are combine them + result.contents = result.contents.reduce((acc: GeminiContent[], currentEntry) => { + if (currentEntry.role === acc[acc.length - 1]?.role) { + acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts); + return acc; + } + return [...acc, currentEntry]; + }, []); + } + return result; + }; } async _request(