Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Security solution] Bedrock chat fix #192013

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,44 @@ export class ActionsClientBedrockChatModel extends _BedrockChat {
}
}

const prepareMessages = (messages: Array<{ role: string; content: string[] }>) =>
const prepareMessages = (
messages: Array<{ role: string; content: string | unknown[] }>
): Array<{ role: string; content?: string; rawContent?: unknown[] }> =>
messages.reduce((acc, { role, content }) => {
const lastMessage = acc[acc.length - 1];

// If there's no last message or the role is different, create a new entry
if (!lastMessage || lastMessage.role !== role) {
acc.push({ role, content });
acc.push({ role, ...getNewMessageFormat(content) });
return acc;
}

if (lastMessage.role === role) {
acc[acc.length - 1].content = lastMessage.content.concat(content);
return acc;
// If the role is the same, merge the content based on the type
if (typeof content === 'string') {
if (lastMessage.content) {
lastMessage.content += content;
} else {
acc[acc.length - 1] = { role, content };
}
} else if (Array.isArray(content)) {
if (lastMessage.rawContent) {
lastMessage.rawContent = lastMessage.rawContent.concat(content);
} else {
acc[acc.length - 1] = { role, rawContent: content };
}
}
}

return acc;
}, [] as Array<{ role: string; content: string[] }>);
}, [] as Array<{ role: string; content?: string; rawContent?: unknown[] }>);

// Helper function to format the new message
const getNewMessageFormat = (content: string | unknown[]) => {
if (typeof content === 'string') {
return { content };
} else if (Array.isArray(content)) {
return { rawContent: content };
}
return {}; // Return an empty object if content is in an unexpected format
};
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ import { Message, MessageRole } from '../../../../common/chat_complete';
import { createInferenceInternalError } from '../../../../common/errors';
import { ToolChoiceType, type ToolOptions } from '../../../../common/chat_complete/tools';
import { InferenceConnectorAdapter } from '../../types';
import type { BedRockMessage, BedrockToolChoice } from './types';
import type { BedrockMessage, BedrockToolChoice } from './types';
import {
BedrockChunkMember,
serdeEventstreamIntoObservable,
Expand Down Expand Up @@ -97,8 +97,8 @@ const toolsToBedrock = (tools: ToolOptions['tools']) => {
: undefined;
};

const messagesToBedrock = (messages: Message[]): BedRockMessage[] => {
return messages.map<BedRockMessage>((message) => {
const messagesToBedrock = (messages: Message[]): BedrockMessage[] => {
return messages.map<BedrockMessage>((message) => {
switch (message.role) {
case MessageRole.User:
return {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@
*/

/**
* BedRock message as expected by the bedrock connector
* Bedrock message as expected by the bedrock connector
*/
export interface BedRockMessage {
export interface BedrockMessage {
role: 'user' | 'assistant';
content?: string;
rawContent?: BedRockMessagePart[];
rawContent?: BedrockMessagePart[];
}

/**
* Bedrock message parts
*/
export type BedRockMessagePart =
export type BedrockMessagePart =
| { type: 'text'; text: string }
| {
type: 'tool_use';
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugins/stack_connectors/common/bedrock/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ export type RunActionResponse = TypeOf<typeof RunActionResponseSchema>;
export type StreamingResponse = TypeOf<typeof StreamingResponseSchema>;
export type DashboardActionParams = TypeOf<typeof DashboardActionParamsSchema>;
export type DashboardActionResponse = TypeOf<typeof DashboardActionResponseSchema>;
export type BedRockMessage = TypeOf<typeof BedrockMessageSchema>;
export type BedrockMessage = TypeOf<typeof BedrockMessageSchema>;
export type BedrockToolChoice = TypeOf<typeof BedrockToolChoiceSchema>;
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import type {
InvokeAIRawActionParams,
InvokeAIRawActionResponse,
RunApiLatestResponse,
BedRockMessage,
BedrockMessage,
BedrockToolChoice,
} from '../../../common/bedrock/types';
import {
Expand Down Expand Up @@ -392,16 +392,18 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
): Promise<InvokeAIRawActionResponse> {
const res = await this.runApi(
{
body: JSON.stringify({
messages,
stop_sequences: stopSequences,
system,
temperature,
max_tokens: maxTokens,
tools,
tool_choice: toolChoice,
anthropic_version: anthropicVersion,
}),
body: JSON.stringify(
formatBedrockBody({
messages,
stopSequences,
system,
temperature,
maxTokens,
tools,
toolChoice,
anthropicVersion,
})
),
model,
signal,
timeout,
Expand All @@ -414,6 +416,7 @@ The Kibana Connector in use may need to be reconfigured with an updated Amazon B
}

const formatBedrockBody = ({
anthropicVersion,
messages,
stopSequences,
temperature = 0,
Expand All @@ -422,7 +425,8 @@ const formatBedrockBody = ({
tools,
toolChoice,
}: {
messages: BedRockMessage[];
anthropicVersion?: string;
messages: BedrockMessage[];
stopSequences?: string[];
temperature?: number;
maxTokens?: number;
Expand All @@ -431,7 +435,7 @@ const formatBedrockBody = ({
tools?: Array<{ name: string; description: string }>;
toolChoice?: BedrockToolChoice;
}) => ({
anthropic_version: 'bedrock-2023-05-31',
anthropic_version: anthropicVersion ?? 'bedrock-2023-05-31',
...ensureMessageFormat(messages, system),
max_tokens: maxTokens,
stop_sequences: stopSequences,
Expand All @@ -440,9 +444,9 @@ const formatBedrockBody = ({
tool_choice: toolChoice,
});

interface FormattedBedRockMessage {
interface FormattedBedrockMessage {
role: string;
content: string | BedRockMessage['rawContent'];
content: string | BedrockMessage['rawContent'];
}

/**
Expand All @@ -452,15 +456,15 @@ interface FormattedBedRockMessage {
* @param messages
*/
const ensureMessageFormat = (
messages: BedRockMessage[],
messages: BedrockMessage[],
systemPrompt?: string
): {
messages: FormattedBedRockMessage[];
messages: FormattedBedrockMessage[];
system?: string;
} => {
let system = systemPrompt ? systemPrompt : '';

const newMessages = messages.reduce<FormattedBedRockMessage[]>((acc, m) => {
const newMessages = messages.reduce<FormattedBedrockMessage[]>((acc, m) => {
if (m.role === 'system') {
system = `${system.length ? `${system}\n` : ''}${m.content}`;
return acc;
Expand Down