Skip to content

Commit

Permalink
feat(adapters): add embedding support for Bedrock
Browse files Browse the repository at this point in the history
Ref: #176
Signed-off-by: Tomas Dvorak <[email protected]>
  • Loading branch information
Tomas2D committed Dec 13, 2024
1 parent 4673b5e commit 0f60ec0
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 4 deletions.
27 changes: 23 additions & 4 deletions src/adapters/bedrock/chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import { Emitter } from "@/emitter/emitter.js";
import type { AwsCredentialIdentity, Provider } from "@aws-sdk/types";
import {
BedrockRuntimeClient as Client,
InvokeModelCommand,
ConverseCommand,
ConverseCommandOutput,
ConverseStreamCommand,
Expand All @@ -42,10 +43,13 @@ import {
} from "@aws-sdk/client-bedrock-runtime";
import { GetRunContext } from "@/context.js";
import { Serializer } from "@/serializer/serializer.js";
import { NotImplementedError } from "@/errors.js";

type Response = ContentBlockDeltaEvent | ConverseCommandOutput;

/**
 * Options for Bedrock embedding requests.
 */
export interface BedrockEmbeddingOptions extends EmbeddingOptions {
  /**
   * Extra model-specific fields merged into the `InvokeModel` request body
   * (e.g. `input_type` for Cohere models). Values override the defaults
   * built by `embed()`. `unknown` (not `any`) keeps reads type-safe while
   * remaining assignable for all callers.
   */
  body?: Record<string, unknown>;
}

export class ChatBedrockOutput extends ChatLLMOutput {
public readonly responses: Response[];

Expand Down Expand Up @@ -204,9 +208,24 @@ export class BedrockChatLLM extends ChatLLM<ChatBedrockOutput> {
};
}

// eslint-disable-next-line unused-imports/no-unused-vars
async embed(input: BaseMessage[][], options?: EmbeddingOptions): Promise<EmbeddingOutput> {
throw new NotImplementedError();
/**
 * Embeds batches of messages via the Bedrock `InvokeModel` API.
 *
 * @param input - One message array per document to embed; the texts of each
 *   inner array are concatenated into a single document.
 * @param options - Abort signal plus optional raw request-body overrides.
 * @returns The parsed model response, expected to contain `embeddings`.
 */
async embed(
  input: BaseMessage[][],
  options: BedrockEmbeddingOptions = {},
): Promise<EmbeddingOutput> {
  // Bedrock embedding request schemas (e.g. Cohere's) take `texts: string[]`;
  // join each inner message array so we send a flat list of strings instead
  // of a nested `string[][]`.
  const texts = input.map((msgs) => msgs.map((msg) => msg.text).join(""));

  const command = new InvokeModelCommand({
    modelId: this.modelId,
    contentType: "application/json",
    accept: "application/json",
    body: JSON.stringify({
      texts,
      input_type: "search_document",
      // Caller-supplied fields win, allowing model-specific overrides.
      ...options.body,
    }),
  });

  // Propagate the caller's abort signal so in-flight requests can be cancelled.
  const response = await this.client.send(command, { abortSignal: options.signal });
  const jsonString = new TextDecoder().decode(response.body);
  // NOTE(review): the parsed payload is assumed to match EmbeddingOutput
  // ({ embeddings: number[][] }); the exact response shape varies per model —
  // verify against the target model's schema before relying on it.
  return JSON.parse(jsonString);
}

async tokenize(input: BaseMessage[]): Promise<BaseLLMTokenizeOutput> {
Expand Down
19 changes: 19 additions & 0 deletions tests/e2e/adapters/bedrock/chat.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import { BedrockChatLLM } from "@/adapters/bedrock/chat.js";
import { BaseMessage } from "@/llms/primitives/message.js";

// E2E: hits the live Bedrock endpoint, so it only runs when AWS_REGION is set.
// NOTE(review): `amazon.titan-embed-text-v1` uses an `inputText` request schema,
// while the adapter sends Cohere-style `texts`/`input_type` — confirm the model
// id and expected 512-dim output against the deployed adapter.
describe.runIf(Boolean(process.env.AWS_REGION))("Bedrock Chat LLM", () => {
  it("Embeds", async () => {
    const llm = new BedrockChatLLM({
      region: process.env.AWS_REGION,
      modelId: "amazon.titan-embed-text-v1",
    });

    // One single-message "document" per text to embed.
    const inputs = [`Hello world!`, `Hello family!`].map((text) => [
      BaseMessage.of({ role: "user", text }),
    ]);
    const response = await llm.embed(inputs);

    expect(response.embeddings).toHaveLength(2);
    for (const embedding of response.embeddings) {
      expect(embedding).toHaveLength(512);
    }
  });
});

0 comments on commit 0f60ec0

Please sign in to comment.