From 0fc07ea55897a9d74380da2767b9bfa25e71cbd3 Mon Sep 17 00:00:00 2001 From: Mikey Date: Tue, 2 Jul 2024 00:12:01 +0800 Subject: [PATCH] feat: add support for Claude 3 tool use (function calling) (#1587) * feat: add tool support for AWS & Claude * fix: add {} for openai compatibility in streaming tool_use --- relay/adaptor/anthropic/main.go | 121 +++++++++++++++++++++++++++++-- relay/adaptor/anthropic/model.go | 21 ++++++ relay/adaptor/aws/main.go | 24 +++++- relay/adaptor/aws/model.go | 3 + relay/model/message.go | 9 ++- relay/model/tool.go | 4 +- 6 files changed, 168 insertions(+), 14 deletions(-) diff --git a/relay/adaptor/anthropic/main.go b/relay/adaptor/anthropic/main.go index c817a9d1ed..d3e306c88c 100644 --- a/relay/adaptor/anthropic/main.go +++ b/relay/adaptor/anthropic/main.go @@ -29,12 +29,30 @@ func stopReasonClaude2OpenAI(reason *string) string { return "stop" case "max_tokens": return "length" + case "tool_use": + return "tool_calls" default: return *reason } } func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { + claudeTools := make([]Tool, 0, len(textRequest.Tools)) + + for _, tool := range textRequest.Tools { + if params, ok := tool.Function.Parameters.(map[string]any); ok { + claudeTools = append(claudeTools, Tool{ + Name: tool.Function.Name, + Description: tool.Function.Description, + InputSchema: InputSchema{ + Type: params["type"].(string), + Properties: params["properties"], + Required: params["required"], + }, + }) + } + } + claudeRequest := Request{ Model: textRequest.Model, MaxTokens: textRequest.MaxTokens, @@ -42,6 +60,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { TopP: textRequest.TopP, TopK: textRequest.TopK, Stream: textRequest.Stream, + Tools: claudeTools, + } + if len(claudeTools) > 0 { + claudeToolChoice := struct { + Type string `json:"type"` + Name string `json:"name,omitempty"` + }{Type: "auto"} // default value https://docs.anthropic.com/en/docs/build-with-claude/tool-use#controlling-claudes-output + if choice, ok := textRequest.ToolChoice.(map[string]any); ok { + if function, ok := choice["function"].(map[string]any); ok { + claudeToolChoice.Type = "tool" + claudeToolChoice.Name = function["name"].(string) + } + } else if toolChoiceType, ok := textRequest.ToolChoice.(string); ok { + if toolChoiceType == "any" { + claudeToolChoice.Type = toolChoiceType + } + } + claudeRequest.ToolChoice = claudeToolChoice } if claudeRequest.MaxTokens == 0 { claudeRequest.MaxTokens = 4096 @@ -64,7 +100,24 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request { if message.IsStringContent() { content.Type = "text" content.Text = message.StringContent() + if message.Role == "tool" { + claudeMessage.Role = "user" + content.Type = "tool_result" + content.Content = content.Text + content.Text = "" + content.ToolUseId = message.ToolCallId + } claudeMessage.Content = append(claudeMessage.Content, content) + for i := range message.ToolCalls { + inputParam := make(map[string]any) + _ = json.Unmarshal([]byte(message.ToolCalls[i].Function.Arguments.(string)), &inputParam) + claudeMessage.Content = append(claudeMessage.Content, Content{ + Type: "tool_use", + Id: message.ToolCalls[i].Id, + Name: message.ToolCalls[i].Function.Name, + Input: inputParam, + }) + } claudeRequest.Messages = append(claudeRequest.Messages, claudeMessage) continue } @@ -97,16 +150,35 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo var response *Response var responseText string var stopReason string + tools := make([]model.Tool, 0) + switch claudeResponse.Type { case "message_start": return nil, claudeResponse.Message case "content_block_start": if claudeResponse.ContentBlock != nil { responseText = claudeResponse.ContentBlock.Text + if claudeResponse.ContentBlock.Type == "tool_use" { + tools = append(tools, model.Tool{ + Id: claudeResponse.ContentBlock.Id, + Type: "function", + Function: model.Function{ + Name: claudeResponse.ContentBlock.Name, + Arguments: "", + }, + }) + } } case "content_block_delta": if claudeResponse.Delta != nil { responseText = claudeResponse.Delta.Text + if claudeResponse.Delta.Type == "input_json_delta" { + tools = append(tools, model.Tool{ + Function: model.Function{ + Arguments: claudeResponse.Delta.PartialJson, + }, + }) + } } case "message_delta": if claudeResponse.Usage != nil { @@ -120,6 +192,10 @@ func StreamResponseClaude2OpenAI(claudeResponse *StreamResponse) (*openai.ChatCo } var choice openai.ChatCompletionsStreamResponseChoice choice.Delta.Content = responseText + if len(tools) > 0 { + choice.Delta.Content = nil // compatible with other OpenAI derivative applications, like LobeOpenAICompatibleFactory ... + choice.Delta.ToolCalls = tools + } choice.Delta.Role = "assistant" finishReason := stopReasonClaude2OpenAI(&stopReason) if finishReason != "null" { @@ -136,12 +212,27 @@ func ResponseClaude2OpenAI(claudeResponse *Response) *openai.TextResponse { if len(claudeResponse.Content) > 0 { responseText = claudeResponse.Content[0].Text } + tools := make([]model.Tool, 0) + for _, v := range claudeResponse.Content { + if v.Type == "tool_use" { + args, _ := json.Marshal(v.Input) + tools = append(tools, model.Tool{ + Id: v.Id, + Type: "function", // compatible with other OpenAI derivative applications + Function: model.Function{ + Name: v.Name, + Arguments: string(args), + }, + }) + } + } choice := openai.TextResponseChoice{ Index: 0, Message: model.Message{ - Role: "assistant", - Content: responseText, - Name: nil, + Role: "assistant", + Content: responseText, + Name: nil, + ToolCalls: tools, }, FinishReason: stopReasonClaude2OpenAI(claudeResponse.StopReason), } @@ -176,6 +267,7 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC var usage model.Usage var modelName string var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice for scanner.Scan() { data := scanner.Text() @@ -196,9 +288,20 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - modelName = meta.Model - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - continue + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. + modelName = meta.Model + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + continue + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } } if response == nil { continue @@ -207,6 +310,12 @@ func StreamHandler(c *gin.Context, resp *http.Response) (*model.ErrorWithStatusC response.Id = id response.Model = modelName response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } err = render.ObjectData(c, response) if err != nil { logger.SysError(err.Error()) diff --git a/relay/adaptor/anthropic/model.go b/relay/adaptor/anthropic/model.go index 32b187cd12..47f766291d 100644 --- a/relay/adaptor/anthropic/model.go +++ b/relay/adaptor/anthropic/model.go @@ -16,6 +16,12 @@ type Content struct { Type string `json:"type"` Text string `json:"text,omitempty"` Source *ImageSource `json:"source,omitempty"` + // tool_calls + Id string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input any `json:"input,omitempty"` + Content string `json:"content,omitempty"` + ToolUseId string `json:"tool_use_id,omitempty"` } type Message struct { @@ -23,6 +29,18 @@ type Message struct { Content []Content `json:"content"` } +type Tool struct { + Name string `json:"name"` + Description string `json:"description,omitempty"` + InputSchema InputSchema `json:"input_schema"` +} + +type InputSchema struct { + Type string `json:"type"` + Properties any `json:"properties,omitempty"` + Required any `json:"required,omitempty"` +} + type Request struct { Model string `json:"model"` Messages []Message `json:"messages"` @@ -33,6 +51,8 @@ type Request struct { Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` + Tools []Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` //Metadata `json:"metadata,omitempty"` } @@ -61,6 +81,7 @@ type Response struct { type Delta struct { Type string `json:"type"` Text string `json:"text"` + PartialJson string `json:"partial_json,omitempty"` StopReason *string `json:"stop_reason"` StopSequence *string `json:"stop_sequence"` } diff --git a/relay/adaptor/aws/main.go b/relay/adaptor/aws/main.go index 5d29597ce8..72f40ddcdf 100644 --- a/relay/adaptor/aws/main.go +++ b/relay/adaptor/aws/main.go @@ -6,6 +6,7 @@ import ( "encoding/json" "fmt" "github.com/songquanpeng/one-api/common/ctxkey" + "github.com/songquanpeng/one-api/relay/adaptor/openai" "io" "net/http" @@ -143,6 +144,8 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E c.Writer.Header().Set("Content-Type", "text/event-stream") var usage relaymodel.Usage var id string + var lastToolCallChoice openai.ChatCompletionsStreamResponseChoice + c.Stream(func(w io.Writer) bool { event, ok := <-stream.Events() if !ok { @@ -163,8 +166,19 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E if meta != nil { usage.PromptTokens += meta.Usage.InputTokens usage.CompletionTokens += meta.Usage.OutputTokens - id = fmt.Sprintf("chatcmpl-%s", meta.Id) - return true + if len(meta.Id) > 0 { // only message_start has an id, otherwise it's a finish_reason event. + id = fmt.Sprintf("chatcmpl-%s", meta.Id) + return true + } else { // finish_reason case + if len(lastToolCallChoice.Delta.ToolCalls) > 0 { + lastArgs := &lastToolCallChoice.Delta.ToolCalls[len(lastToolCallChoice.Delta.ToolCalls)-1].Function + if len(lastArgs.Arguments.(string)) == 0 { // compatible with OpenAI sending an empty object `{}` when no arguments. + lastArgs.Arguments = "{}" + response.Choices[len(response.Choices)-1].Delta.Content = nil + response.Choices[len(response.Choices)-1].Delta.ToolCalls = lastToolCallChoice.Delta.ToolCalls + } + } + } } if response == nil { return true @@ -172,6 +186,12 @@ func StreamHandler(c *gin.Context, awsCli *bedrockruntime.Client) (*relaymodel.E response.Id = id response.Model = c.GetString(ctxkey.OriginalModel) response.Created = createdTime + + for _, choice := range response.Choices { + if len(choice.Delta.ToolCalls) > 0 { + lastToolCallChoice = choice + } + } jsonStr, err := json.Marshal(response) if err != nil { logger.SysError("error marshalling stream response: " + err.Error()) diff --git a/relay/adaptor/aws/model.go b/relay/adaptor/aws/model.go index bcbfb584c2..6d00b68865 100644 --- a/relay/adaptor/aws/model.go +++ b/relay/adaptor/aws/model.go @@ -9,9 +9,12 @@ type Request struct { // AnthropicVersion should be "bedrock-2023-05-31" AnthropicVersion string `json:"anthropic_version"` Messages []anthropic.Message `json:"messages"` + System string `json:"system,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` Temperature float64 `json:"temperature,omitempty"` TopP float64 `json:"top_p,omitempty"` TopK int `json:"top_k,omitempty"` StopSequences []string `json:"stop_sequences,omitempty"` + Tools []anthropic.Tool `json:"tools,omitempty"` + ToolChoice any `json:"tool_choice,omitempty"` } diff --git a/relay/model/message.go b/relay/model/message.go index 32a1055b92..b908f98912 100644 --- a/relay/model/message.go +++ b/relay/model/message.go @@ -1,10 +1,11 @@ package model type Message struct { - Role string `json:"role,omitempty"` - Content any `json:"content,omitempty"` - Name *string `json:"name,omitempty"` - ToolCalls []Tool `json:"tool_calls,omitempty"` + Role string `json:"role,omitempty"` + Content any `json:"content,omitempty"` + Name *string `json:"name,omitempty"` + ToolCalls []Tool `json:"tool_calls,omitempty"` + ToolCallId string `json:"tool_call_id,omitempty"` } func (m Message) IsStringContent() bool { diff --git a/relay/model/tool.go b/relay/model/tool.go index 253dca3537..75dbb8f723 100644 --- a/relay/model/tool.go +++ b/relay/model/tool.go @@ -2,13 +2,13 @@ package model type Tool struct { Id string `json:"id,omitempty"` - Type string `json:"type"` + Type string `json:"type,omitempty"` // when splicing claude tools stream messages, it is empty Function Function `json:"function"` } type Function struct { Description string `json:"description,omitempty"` - Name string `json:"name"` + Name string `json:"name,omitempty"` // when splicing claude tools stream messages, it is empty Parameters any `json:"parameters,omitempty"` // request Arguments any `json:"arguments,omitempty"` // response }