Skip to content

Commit

Permalink
[8.x] [Security GenAI] Fix `VertexChatAI` tool calling (#19…
Browse files Browse the repository at this point in the history
…5689) (#195832)

# Backport

This will backport the following commits from `main` to `8.x`:
- [[Security GenAI] Fix `VertexChatAI` tool calling
(#195689)](#195689)

<!--- Backport version: 9.4.3 -->

### Questions ?
Please refer to the [Backport tool
documentation](https://github.com/sqren/backport)

<!--BACKPORT [{"author":{"name":"Steph
Milovic","email":"[email protected]"},"sourceCommit":{"committedDate":"2024-10-10T21:59:10Z","message":"[Security
GenAI] Fix `VertexChatAI` tool calling
(#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076","branchLabelMapping":{"^v9.0.0$":"main","^v8.16.0$":"8.x","^v(\\d+).(\\d+).\\d+$":"$1.$2"}},"sourcePullRequest":{"labels":["release_note:skip","v9.0.0","Team:
SecuritySolution","backport:prev-minor","Team:Security Generative
AI","v8.16.0"],"title":"[Security GenAI] Fix `VertexChatAI` tool
calling","number":195689,"url":"https://github.com/elastic/kibana/pull/195689","mergeCommit":{"message":"[Security
GenAI] Fix `VertexChatAI` tool calling
(#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076"}},"sourceBranch":"main","suggestedTargetBranches":["8.x"],"targetPullRequestStates":[{"branch":"main","label":"v9.0.0","branchLabelMappingKey":"^v9.0.0$","isSourceBranch":true,"state":"MERGED","url":"https://github.com/elastic/kibana/pull/195689","number":195689,"mergeCommit":{"message":"[Security
GenAI] Fix `VertexChatAI` tool calling
(#195689)","sha":"6ff2d87b5c8ed48ccfaa66f9cc8d712ae161a076"}},{"branch":"8.x","label":"v8.16.0","branchLabelMappingKey":"^v8.16.0$","isSourceBranch":false,"state":"NOT_CREATED"}]}]
BACKPORT-->

Co-authored-by: Steph Milovic <[email protected]>
  • Loading branch information
kibanamachine and stephmilovic authored Oct 10, 2024
1 parent 6910f15 commit afebfae
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import { actionsClientMock } from '@kbn/actions-plugin/server/actions_client/act
import { BaseMessage, HumanMessage, SystemMessage } from '@langchain/core/messages';
import { ActionsClientChatVertexAI } from './chat_vertex';
import { CallbackManagerForLLMRun } from '@langchain/core/callbacks/manager';
import { GeminiContent } from '@langchain/google-common';

const connectorId = 'mock-connector-id';

Expand Down Expand Up @@ -54,8 +55,10 @@ const mockStreamExecute = jest.fn().mockImplementation(() => {
};
});

const systemInstruction = 'Answer the following questions truthfully and as best you can.';

const callMessages = [
new SystemMessage('Answer the following questions truthfully and as best you can.'),
new SystemMessage(systemInstruction),
new HumanMessage('Question: Do you know my name?\n\n'),
] as unknown as BaseMessage[];

Expand Down Expand Up @@ -196,4 +199,32 @@ describe('ActionsClientChatVertexAI', () => {
expect(handleLLMNewToken).toHaveBeenCalledWith('token3');
});
});

describe('message formatting', () => {
it('Properly sorts out the system role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await actionsClientChatVertexAI._generate(callMessages, callOptions, callRunManager);
const params = actionsClient.execute.mock.calls[0][0].params.subActionParams as unknown as {
messages: GeminiContent[];
systemInstruction: string;
};
expect(params.messages.length).toEqual(1);
expect(params.messages[0].parts.length).toEqual(1);
expect(params.systemInstruction).toEqual(systemInstruction);
});
it('Handles 2 messages in a row from the same role', async () => {
const actionsClientChatVertexAI = new ActionsClientChatVertexAI(defaultArgs);

await actionsClientChatVertexAI._generate(
[...callMessages, new HumanMessage('Oh boy, another')],
callOptions,
callRunManager
);
const { messages } = actionsClient.execute.mock.calls[0][0].params
.subActionParams as unknown as { messages: GeminiContent[] };
expect(messages.length).toEqual(1);
expect(messages[0].parts.length).toEqual(2);
});
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import {
ChatConnection,
GeminiContent,
GoogleAbstractedClient,
GoogleAIBaseLLMInput,
GoogleLLMResponse,
Expand Down Expand Up @@ -39,6 +40,22 @@ export class ActionsClientChatConnection<Auth> extends ChatConnection<Auth> {
this.caller = caller;
this.#model = fields.model;
this.temperature = fields.temperature ?? 0;
const nativeFormatData = this.formatData.bind(this);
this.formatData = async (data, options) => {
const result = await nativeFormatData(data, options);
if (result?.contents != null && result?.contents.length) {
// ensure there are not 2 messages in a row from the same role,
// if there are combine them
result.contents = result.contents.reduce((acc: GeminiContent[], currentEntry) => {
if (currentEntry.role === acc[acc.length - 1]?.role) {
acc[acc.length - 1].parts = acc[acc.length - 1].parts.concat(currentEntry.parts);
return acc;
}
return [...acc, currentEntry];
}, []);
}
return result;
};
}

async _request(
Expand Down

0 comments on commit afebfae

Please sign in to comment.