diff --git a/src/adapters/bedrock/chat.ts b/src/adapters/bedrock/chat.ts index d70082a5..e7338d19 100644 --- a/src/adapters/bedrock/chat.ts +++ b/src/adapters/bedrock/chat.ts @@ -32,6 +32,7 @@ import { Emitter } from "@/emitter/emitter.js"; import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types"; import { BedrockRuntimeClient as Client, + InvokeModelCommand, ConverseCommand, ConverseCommandOutput, ConverseStreamCommand, @@ -42,10 +43,13 @@ import { } from "@aws-sdk/client-bedrock-runtime"; import { GetRunContext } from "@/context.js"; import { Serializer } from "@/serializer/serializer.js"; -import { NotImplementedError } from "@/errors.js"; type Response = ContentBlockDeltaEvent | ConverseCommandOutput; +export interface BedrockEmbeddingOptions extends EmbeddingOptions { + body?: Record; +} + export class ChatBedrockOutput extends ChatLLMOutput { public readonly responses: Response[]; @@ -204,9 +208,24 @@ export class BedrockChatLLM extends ChatLLM { }; } - // eslint-disable-next-line unused-imports/no-unused-vars - async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise { - throw new NotImplementedError(); + async embed( + input: BaseMessage[][], + options: BedrockEmbeddingOptions = {}, + ): Promise { + const command = new InvokeModelCommand({ + modelId: this.modelId, + contentType: "application/json", + accept: "application/json", + body: JSON.stringify({ + texts: input.map((msgs) => msgs.map((msg) => msg.text)), + input_type: "search_document", + ...options?.body, + }), + }); + + const response = await this.client.send(command, { abortSignal: options?.signal }); + const jsonString = new TextDecoder().decode(response.body); + return JSON.parse(jsonString); } async tokenize(input: BaseMessage[]): Promise { diff --git a/tests/e2e/adapters/bedrock/chat.test.ts b/tests/e2e/adapters/bedrock/chat.test.ts new file mode 100644 index 00000000..fdf5f7f9 --- /dev/null +++ b/tests/e2e/adapters/bedrock/chat.test.ts @@ -0,0 +1,19 @@ +import { BedrockChatLLM } from "@/adapters/bedrock/chat.js"; +import { BaseMessage } from "@/llms/primitives/message.js"; + +describe.runIf([process.env.AWS_REGION].every((env) => Boolean(env)))("Bedrock Chat LLM", () => { + it("Embeds", async () => { + const llm = new BedrockChatLLM({ + region: process.env.AWS_REGION, + modelId: "amazon.titan-embed-text-v1", + }); + + const response = await llm.embed([ + [BaseMessage.of({ role: "user", text: `Hello world!` })], + [BaseMessage.of({ role: "user", text: `Hello family!` })], + ]); + expect(response.embeddings.length).toBe(2); + expect(response.embeddings[0].length).toBe(512); + expect(response.embeddings[1].length).toBe(512); + }); +});