diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts index 37506922ff69b..dbec91e872547 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/chat_vertex.test.ts @@ -54,8 +54,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => { }; }); +const systemInstruction = 'Answer the following questions truthfully and as best you can.'; + const callMessages = [ - new SystemMessage('Answer the following questions truthfully and as best you can.'), + new SystemMessage(systemInstruction), new HumanMessage('Question: Do you know my name?\n\n'), ] as unknown as BaseMessage[]; @@ -196,4 +198,29 @@ describe('ActionsClientChatVertexAI', () => { expect(handleLLMNewToken).toHaveBeenCalledWith('token3'); }); }); + + describe('message formatting', () => { + it('Properly sorts out the system role', async () => { + const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs); + + await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager); + const params = actionsClient.execute.mock.calls[0][0].params.subActionParams; + console.log(actionsClient.execute.mock.calls[0][0].params.subActionParams); + expect(params.messages.length).toEqual(1); + expect(params.messages[0].parts.length).toEqual(1); + expect(params.systemInstruction).toEqual(systemInstruction); + }); + it('Handles 2 messages in a row from the same role', async () => { + const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs); + + await actionsClientChatVertexAI._generate( + [...callMessages, new HumanMessage('Oh boy, another')], + callOptions, + callRunManager + ); + const messages = actionsClient.execute.mock.calls[0][0].params.subActionParams.messages; + expect(messages.length).toEqual(1); + expect(messages[0].parts.length).toEqual(2); + }); + }); }); diff --git a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts index 0340d71b438db..bec3662a5ee48 100644 --- a/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts +++ b/x-pack/packages/kbn-langchain/server/language_models/chat_vertex/connection.ts @@ -39,6 +39,22 @@ export class ActionsClientChatConnection extends ChatConnection { this.caller = caller; this.#model = fields.model; this.temperature = fields.temperature ?? 0; + const nativeFormatData = this.formatData.bind(this); + this.formatData = async (data, options) => { + const result = await nativeFormatData(data, options); + if (result?.contents.length) { + // ensure there are not 2 messages in a row from the same role, + // if there are combine them + result.contents = result.contents.reduce((acc, currentEntry) => { + if (currentEntry.role === acc[acc.length - 1]?.role) { + acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts); + return acc; + } + return [...acc, currentEntry]; + }, []); + } + return result; + }; } async _request(