Skip to content

Commit

Permalink
[Security GenAI] Fix VertexChatAI tool calling (#195689)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic authored Oct 10, 2024
1 parent cd217c0 commit 6ff2d87
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act
import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ActionsClientChatVertexAI } from './chat_vertex';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { GeminiContent } from '@langchain/google-common';

const connectorId = 'mock-connector-id';

Expand Down Expand Up @@ -54,8 +55,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => {
};
});

const systemInstruction = 'Answer the following questions truthfully and as best you can.';

const callMessages = [
new SystemMessage('Answer the following questions truthfully and as best you can.'),
new SystemMessage(systemInstruction),
new HumanMessage('Question: Do you know my name?\n\n'),
] as unknown as BaseMessage[];

Expand Down Expand Up @@ -196,4 +199,32 @@ describe('ActionsClientChatVertexAI', () => {
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});

describe('message formatting', () => {
it('Properly sorts out the system role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager);
const params = actionsClient.execute.mock.calls[0][0].params.subActionParams as unknown as {
messages: GeminiContent[];
systemInstruction: string;
};
expect(params.messages.length).toEqual(1);
expect(params.messages[0].parts.length).toEqual(1);
expect(params.systemInstruction).toEqual(systemInstruction);
});
it('Handles 2 messages in a row from the same role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await actionsClientChatVertexAI._generate(
[...callMessages, new HumanMessage('Oh boy, another')],
callOptions,
callRunManager
);
const { messages } = actionsClient.execute.mock.calls[0][0].params
.subActionParams as unknown as { messages: GeminiContent[] };
expect(messages.length).toEqual(1);
expect(messages[0].parts.length).toEqual(2);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import {
ChatConnection,
GeminiContent,
GoogleAbstractedClient,
GoogleAIBaseLLMInput,
GoogleLLMResponse,
Expand Down Expand Up @@ -39,6 +40,22 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
this.caller = caller;
this.#model = fields.model;
this.temperature = fields.temperature ?? 0;
// Wrap the inherited formatData so the request payload can be post-processed
// before it is sent. NOTE(review): merging consecutive same-role entries is
// presumably required because the downstream Gemini API rejects two
// `contents` entries in a row with the same role — confirm against the
// connector/API contract.
const nativeFormatData = this.formatData.bind(this);
this.formatData = async (data, options) => {
const result = await nativeFormatData(data, options);
if (result?.contents != null && result?.contents.length) {
// ensure there are not 2 messages in a row from the same role,
// if there are combine them
result.contents = result.contents.reduce((acc: GeminiContent[], currentEntry) => {
if (currentEntry.role === acc[acc.length - 1]?.role) {
// Same role as the previous accumulated entry: append this entry's
// parts onto it (mutates the entry object in place) and drop the
// current entry from the output.
acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts);
return acc;
}
// Role changed (or accumulator is empty): keep the entry as-is.
return [...acc, currentEntry];
}, []);
}
return result;
};
}

async _request(
Expand Down

0 comments on commit 6ff2d87

Please sign in to comment.