Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(js): adding Anthropic tool use to vertexai plugin #573

Merged
merged 6 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
267 changes: 188 additions & 79 deletions js/plugins/vertexai/src/anthropic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,18 @@
*/

import {
ContentBlock as AnthropicContent,
ImageBlockParam,
Message,
MessageCreateParamsBase,
MessageParam,
TextBlock,
TextBlockParam,
TextDelta,
Tool,
ToolResultBlockParam,
ToolUseBlock,
ToolUseBlockParam,
} from '@anthropic-ai/sdk/resources/messages';
import { AnthropicVertex } from '@anthropic-ai/vertex-sdk';
import {
Expand All @@ -31,6 +36,7 @@ import {
GenerationCommonConfigSchema,
Part as GenkitPart,
ModelReference,
Part,
defineModel,
getBasicUsageStats,
modelRef,
Expand All @@ -40,12 +46,12 @@ import { GENKIT_CLIENT_HEADER } from '@genkit-ai/core';
export const claude35Sonnet = modelRef({
name: 'vertexai/claude-3-5-sonnet',
info: {
label: 'Vertex AI Model Garden - Claude 35 Sonnet',
label: 'Vertex AI Model Garden - Claude 3.5 Sonnet',
versions: ['claude-3-5-sonnet@20240620'],
supports: {
multiturn: true,
media: true,
tools: false,
tools: true,
systemRole: true,
output: ['text'],
},
Expand All @@ -61,7 +67,7 @@ export const claude3Sonnet = modelRef({
supports: {
multiturn: true,
media: true,
tools: false,
tools: true,
systemRole: true,
output: ['text'],
},
Expand All @@ -77,7 +83,7 @@ export const claude3Haiku = modelRef({
supports: {
multiturn: true,
media: true,
tools: false,
tools: true,
systemRole: true,
output: ['text'],
},
Expand All @@ -93,7 +99,7 @@ export const claude3Opus = modelRef({
supports: {
multiturn: true,
media: true,
tools: false,
tools: true,
systemRole: true,
output: ['text'],
},
Expand All @@ -111,60 +117,6 @@ export const SUPPORTED_ANTHROPIC_MODELS: Record<
'claude-3-haiku': claude3Haiku,
};

export function anthropicModel(
modelName: string,
projectId: string,
region: string
) {
const client = new AnthropicVertex({
region,
projectId,
defaultHeaders: {
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
},
});
const model = SUPPORTED_ANTHROPIC_MODELS[modelName];
if (!model) {
throw new Error(`unsupported Anthropic model name ${modelName}`);
}

return defineModel(
{
name: model.name,
label: model.info?.label,
configSchema: GenerationCommonConfigSchema,
supports: model.info?.supports,
versions: model.info?.versions,
},
async (input, streamingCallback) => {
if (!streamingCallback) {
const response = await client.messages.create({
...toAnthropicRequest(input.config?.version ?? modelName, input),
stream: false,
});
return fromAnthropicResponse(input, response);
} else {
const stream = await client.messages.stream(
toAnthropicRequest(input.config?.version ?? modelName, input)
);
for await (const event of stream) {
if (event.type === 'content_block_delta') {
streamingCallback({
index: 0,
content: [
{
text: (event.delta as TextDelta).text,
},
],
});
}
}
return fromAnthropicResponse(input, await stream.finalMessage());
}
}
);
}

export function toAnthropicRequest(
model: string,
input: GenerateRequest<typeof GenerationCommonConfigSchema>
Expand All @@ -183,6 +135,14 @@ export function toAnthropicRequest(
return c.text;
})
.join();
}
// If the last message is a tool response, we need to add a user message.
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use#handling-tool-use-and-tool-result-content-blocks
else if (msg.content[msg.content.length - 1].toolResponse) {
messages.push({
role: 'user',
content: toAnthropicContent(msg.content),
});
} else {
messages.push({
role: toAnthropicRole(msg.role),
Expand All @@ -199,6 +159,15 @@ export function toAnthropicRequest(
if (system) {
request['system'] = system;
}
if (input.tools) {
request.tools = input.tools?.map((tool) => {
return {
name: tool.name,
description: tool.description,
input_schema: tool.inputSchema,
};
}) as Array<Tool>;
}
if (input.config?.stopSequences) {
request.stop_sequences = input.config?.stopSequences;
}
Expand All @@ -216,7 +185,9 @@ export function toAnthropicRequest(

function toAnthropicContent(
content: GenkitPart[]
): Array<TextBlockParam | ImageBlockParam> {
): Array<
TextBlockParam | ImageBlockParam | ToolUseBlockParam | ToolResultBlockParam
> {
return content.map((p) => {
if (p.text) {
return {
Expand All @@ -243,7 +214,13 @@ function toAnthropicContent(
},
};
}
throw new Error(`Unsupported content type: ${p}`);
if (p.toolRequest) {
return toAnthropicToolRequest(p.toolRequest);
}
if (p.toolResponse) {
return toAnthropicToolResponse(p);
}
throw new Error(`Unsupported content type: ${JSON.stringify(p)}`);
});
}

Expand All @@ -254,30 +231,68 @@ function toAnthropicRole(role): 'user' | 'assistant' {
if (role === 'user') {
return 'user';
}
if (role === 'tool') {
return 'assistant';
}
throw new Error(`Unsupported role type ${role}`);
}

function fromAnthropicTextPart(part: TextBlock): Part {
return {
text: part.text,
};
}

function fromAnthropicToolCallPart(part: ToolUseBlock): Part {
return {
toolRequest: {
name: part.name,
input: part.input,
ref: part.id,
},
};
}

// Converts an Anthropic part to a Genkit part.
function fromAnthropicPart(part: AnthropicContent): Part {
if (part.type === 'text') return fromAnthropicTextPart(part);
if (part.type === 'tool_use') return fromAnthropicToolCallPart(part);
throw new Error(
'Part type is unsupported/corrupted. Either data is missing or type cannot be inferred from type.'
);
}

// Converts an Anthropic candidate to a Genkit candidate.
function fromAnthropicCandidate(candidate: Message): CandidateData {
const parts = candidate.content as AnthropicContent[];
const genkitCandidate: CandidateData = {
index: 0,
message: {
role: 'model',
content: parts.map(fromAnthropicPart),
},
finishReason: toGenkitFinishReason(
candidate.stop_reason as
| 'end_turn'
| 'max_tokens'
| 'stop_sequence'
| 'tool_use'
| null
),
custom: {
id: candidate.id,
model: candidate.model,
type: candidate.type,
},
};
return genkitCandidate;
}

export function fromAnthropicResponse(
input: GenerateRequest<typeof GenerationCommonConfigSchema>,
response: Message
): GenerateResponseData {
const candidates: CandidateData[] = [
{
index: 0,
finishReason: toGenkitFinishReason(
response.stop_reason as 'end_turn' | 'max_tokens' | 'stop_sequence'
),
custom: {
id: response.id,
model: response.model,
type: response.type,
},
message: {
role: 'model',
content: response.content.map((c) => ({ text: (c as TextBlock).text })),
},
},
];
const candidates: CandidateData[] = [fromAnthropicCandidate(response)];
return {
candidates,
usage: {
Expand All @@ -289,7 +304,7 @@ export function fromAnthropicResponse(
}

function toGenkitFinishReason(
reason: 'end_turn' | 'max_tokens' | 'stop_sequence' | null
reason: 'end_turn' | 'max_tokens' | 'stop_sequence' | 'tool_use' | null
): CandidateData['finishReason'] {
switch (reason) {
case 'end_turn':
Expand All @@ -298,9 +313,103 @@ function toGenkitFinishReason(
return 'length';
case 'stop_sequence':
return 'stop';
case 'tool_use':
return 'stop';
case null:
return 'unknown';
default:
return 'other';
}
}

function toAnthropicToolRequest(tool: Record<string, any>): ToolUseBlock {
if (!tool.name) {
throw new Error('Tool name is required');
}
// Validate the tool name, Anthropic only supports letters, numbers, and underscores.
// https://docs.anthropic.com/en/docs/build-with-claude/tool-use#specifying-tools
if (!/^[a-zA-Z0-9_-]{1,64}$/.test(tool.name)) {
throw new Error(
`Tool name ${tool.name} contains invalid characters.
Only letters, numbers, and underscores are allowed,
and the name must be between 1 and 64 characters long.`
);
}
const declaration: ToolUseBlock = {
type: 'tool_use',
id: tool.ref,
name: tool.name,
input: tool.input,
};
return declaration;
}

function toAnthropicToolResponse(part: Part): ToolResultBlockParam {
if (!part.toolResponse?.ref) {
throw new Error('Tool response reference is required');
}

if (!part.toolResponse.output) {
throw new Error('Tool response output is required');
}

return {
type: 'tool_result',
tool_use_id: part.toolResponse.ref,
content: JSON.stringify(part.toolResponse.output),
};
}

export function anthropicModel(
modelName: string,
projectId: string,
region: string
) {
const client = new AnthropicVertex({
region,
projectId,
defaultHeaders: {
'X-Goog-Api-Client': GENKIT_CLIENT_HEADER,
},
});
const model = SUPPORTED_ANTHROPIC_MODELS[modelName];
if (!model) {
throw new Error(`unsupported Anthropic model name ${modelName}`);
}

return defineModel(
{
name: model.name,
label: model.info?.label,
configSchema: GenerationCommonConfigSchema,
supports: model.info?.supports,
versions: model.info?.versions,
},
async (input, streamingCallback) => {
if (!streamingCallback) {
const response = await client.messages.create({
...toAnthropicRequest(input.config?.version ?? modelName, input),
stream: false,
});
return fromAnthropicResponse(input, response);
} else {
const stream = await client.messages.stream(
toAnthropicRequest(input.config?.version ?? modelName, input)
);
for await (const event of stream) {
if (event.type === 'content_block_delta') {
streamingCallback({
index: 0,
content: [
{
text: (event.delta as TextDelta).text,
},
],
});
}
}
return fromAnthropicResponse(input, await stream.finalMessage());
}
}
);
}
Loading
Loading