From 4e0d168cb9e90189650f3913e7f9161e53932798 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 7 Jun 2024 15:34:58 +0200 Subject: [PATCH 1/5] chore: add cohere tools --- assistant/assistant.go | 2 +- go.mod | 2 +- go.sum | 4 +- llm/cohere/cohere.go | 184 +++++++++++++++++++++++++++++++++++--- llm/cohere/formatter.go | 103 +++++++++++++++++++--- llm/cohere/function.go | 191 ++++++++++++++++++++++++++++++++++++++++ 6 files changed, 460 insertions(+), 26 deletions(-) create mode 100644 llm/cohere/function.go diff --git a/assistant/assistant.go b/assistant/assistant.go index e591ebd1..acb1e25b 100644 --- a/assistant/assistant.go +++ b/assistant/assistant.go @@ -24,7 +24,7 @@ type observer interface { } const ( - DefaultMaxIterations = 3 + DefaultMaxIterations = 1 ) type Assistant struct { diff --git a/go.mod b/go.mod index 4131f84e..b3f09e66 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require github.com/mitchellh/mapstructure v1.5.0 require ( github.com/RediSearch/redisearch-go/v2 v2.1.1 github.com/google/uuid v1.6.0 - github.com/henomis/cohere-go v1.1.2 + github.com/henomis/cohere-go v1.2.0-beta.6 github.com/henomis/langfuse-go v0.0.3 github.com/henomis/milvus-go v0.0.4 github.com/henomis/pinecone-go/v2 v2.0.0 diff --git a/go.sum b/go.sum index a79137ce..62d2a4bd 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/henomis/cohere-go v1.1.2 h1:rzEA1JRm26RnaQValoVeVQ1y1nTofuhr/z/fPyhUW/4= -github.com/henomis/cohere-go v1.1.2/go.mod h1:z+UIgBbNCnLH5M47FYqg4h3CDWxjUv2VDfEwdTb95MU= +github.com/henomis/cohere-go v1.2.0-beta.6 h1:zhehuI2mMS7u8Pcy6dpHgnTcO6lIfKLZmuCTO9996fs= +github.com/henomis/cohere-go v1.2.0-beta.6/go.mod h1:z+UIgBbNCnLH5M47FYqg4h3CDWxjUv2VDfEwdTb95MU= github.com/henomis/langfuse-go v0.0.3 h1:Z5Mqlnj1fsok7eAT7jt+N4tegjbhY5HnZQ7wu1iVKss= github.com/henomis/langfuse-go v0.0.3/go.mod h1:gSRuO3nvjAvk/mgmb7b+9BcoN9s64GvXeaSN7PfVEKQ= github.com/henomis/milvus-go v0.0.4 h1:ArddXRJx/EGdQ75gB7TyEzD4Z4BqXzW16p8jRcjJOhs= diff --git a/llm/cohere/cohere.go b/llm/cohere/cohere.go index 1c53a04e..ff0f0180 100644 --- a/llm/cohere/cohere.go +++ b/llm/cohere/cohere.go @@ -2,11 +2,13 @@ package cohere import ( "context" + "encoding/json" "errors" "fmt" "os" "strings" + "github.com/google/uuid" coherego "github.com/henomis/cohere-go" "github.com/henomis/cohere-go/model" "github.com/henomis/cohere-go/request" @@ -31,6 +33,8 @@ const ( ModelCommandNightly Model = model.ModelCommandNightly ModelCommandLight Model = model.ModelCommandLight ModelCommandLightNightly Model = model.ModelCommandLightNightly + ModelCommandR Model = model.ModelCommandR + ModelCommandRPlus Model = model.ModelCommandRPlus ) const ( @@ -51,8 +55,7 @@ type Cohere struct { cache *cache.Cache streamCallbackFn StreamCallbackFn name string - observer llmobserver.LLMObserver - observerTraceID string + functions map[string]Function } func (c *Cohere) WithCache(cache *cache.Cache) *Cohere { @@ -72,6 +75,7 @@ func New() *Cohere { temperature: DefaultTemperature, maxTokens: DefaultMaxTokens, name: "cohere", + functions: make(map[string]Function), } } @@ -116,12 +120,6 @@ func (c *Cohere) WithStream(callbackFn StreamCallbackFn) *Cohere { return c } -func (c *Cohere) WithObserver(observer llmobserver.LLMObserver, traceID string) *Cohere { - c.observer = observer - c.observerTraceID = traceID - return c -} - // Completion returns the completion for the given prompt func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) { resp := &response.Generate{} @@ -219,6 +217,10 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { chatRequest := c.buildChatCompletionRequest(t) + if len(c.functions) > 0 { + chatRequest.Tools = c.getChatCompletionRequestTools() + } + generation, err := c.startObserveGeneration(ctx, t) if err != nil { return fmt.Errorf("%w: %w", ErrCohereChat, err) @@ -260,9 +262,19 @@ func (c *Cohere) generate(ctx context.Context, t *thread.Thread, chatRequest *re return fmt.Errorf("%w: %w", ErrCohereChat, err) } - t.AddMessage(thread.NewAssistantMessage().AddContent( - thread.NewTextContent(response.Text), - )) + var messages []*thread.Message + if len(response.ToolCalls) > 0 { + messages = append(messages, toolCallsToToolCallMessage(response.ToolCalls)) + messages = append(messages, c.callTools(response.ToolCalls)...) + } else { + messages = []*thread.Message{ + thread.NewAssistantMessage().AddContent( + thread.NewTextContent(response.Text), + ), + } + } + + t.Messages = append(t.Messages, messages...) return nil } @@ -317,3 +329,153 @@ func (c *Cohere) stopObserveGeneration( messages, ) } + +func (o *Cohere) getChatCompletionRequestTools() []model.Tool { + tools := []model.Tool{} + + for _, function := range o.functions { + tool := model.Tool{ + Name: function.Name, + Description: function.Description, + ParameterDefinitions: make(map[string]model.ToolParameterDefinition), + } + + functionProperties, ok := function.Parameters["properties"] + if !ok { + continue + } + + functionPropertiesAsMap, isMap := functionProperties.(map[string]interface{}) + if !isMap { + continue + } + + for k, v := range functionPropertiesAsMap { + valueAsMap, isValueMap := v.(map[string]interface{}) + if !isValueMap { + continue + } + + description := "" + descriptionValue, isDescriptionValue := valueAsMap["description"] + if isDescriptionValue { + descriptionAsString, isDescriptionString := descriptionValue.(string) + if isDescriptionString { + description = descriptionAsString + } + } + + argType := "" + argTypeValue, isArgTypeValue := valueAsMap["type"] + if isArgTypeValue { + argTypeAsString, isArgTypeString := argTypeValue.(string) + if isArgTypeString { + argType = argTypeAsString + } + } + + required := false + requiredValue, isRequiredValue := valueAsMap["required"] + if isRequiredValue { + requiredAsBool, isRequiredBool := requiredValue.(bool) + if isRequiredBool { + required = requiredAsBool + } + } + + if description == "" || argType == "" { + continue + } + + tool.ParameterDefinitions[k] = model.ToolParameterDefinition{ + Description: description, + Type: argType, + Required: required, + } + } + + tools = append(tools, tool) + } + + return tools +} + +func toolCallsToToolCallMessage(toolCalls []model.ToolCall) *thread.Message { + if len(toolCalls) == 0 { + return nil + } + + var toolCallData []thread.ToolCallData + for i, toolCall := range toolCalls { + parametersAsString, err := json.Marshal(toolCall.Parameters) + if err != nil { + continue + } + + if toolCalls[i].ID == "" { + toolCalls[i].ID = uuid.New().String() + } + + toolCallData = append(toolCallData, thread.ToolCallData{ + ID: toolCalls[i].ID, + Name: toolCall.Name, + Arguments: string(parametersAsString), + }) + } + + return thread.NewAssistantMessage().AddContent( + thread.NewToolCallContent( + toolCallData, + ), + ) +} + +func toolCallResultToThreadMessage(toolCall model.ToolCall, result string) *thread.Message { + return thread.NewToolMessage().AddContent( + thread.NewToolResponseContent( + thread.ToolResponseData{ + ID: toolCall.ID, + Name: toolCall.Name, + Result: result, + }, + ), + ) +} + +func (o *Cohere) callTools(toolCalls []model.ToolCall) []*thread.Message { + if len(o.functions) == 0 || len(toolCalls) == 0 { + return nil + } + + var messages []*thread.Message + for _, toolCall := range toolCalls { + result, err := o.callTool(toolCall) + if err != nil { + result = fmt.Sprintf("error: %s", err) + } + + messages = append(messages, toolCallResultToThreadMessage(toolCall, result)) + } + + return messages +} + +func (o *Cohere) callTool(toolCall model.ToolCall) (string, error) { + fn, ok := o.functions[toolCall.Name] + if !ok { + return "", fmt.Errorf("unknown function %s", toolCall.Name) + } + + parameters, err := json.Marshal(toolCall.Parameters) + if err != nil { + return "", err + } + + resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, string(parameters)) + //resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, "{}") + if err != nil { + return "", err + } + + return resultAsJSON, nil +} diff --git a/llm/cohere/formatter.go b/llm/cohere/formatter.go index f39bd16f..750faab9 100644 --- a/llm/cohere/formatter.go +++ b/llm/cohere/formatter.go @@ -1,6 +1,8 @@ package cohere import ( + "encoding/json" + "github.com/henomis/cohere-go/model" "github.com/henomis/cohere-go/request" "github.com/henomis/lingoose/thread" @@ -10,39 +12,61 @@ var threadRoleToCohereRole = map[thread.Role]model.ChatMessageRole{ thread.RoleSystem: model.ChatMessageRoleChatbot, thread.RoleUser: model.ChatMessageRoleUser, thread.RoleAssistant: model.ChatMessageRoleChatbot, + thread.RoleTool: model.ChatMessageRoleTool, } func (c *Cohere) buildChatCompletionRequest(t *thread.Thread) *request.Chat { - message, history := threadToChatMessages(t) + message, history, toolResults := threadToChatMessages(t) return &request.Chat{ Model: c.model, ChatHistory: history, Message: message, + ToolResults: toolResults, } } -func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage) { +func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage, []model.ToolResult) { var history []model.ChatMessage + var toolResults []model.ToolResult var message string - for _, m := range t.Messages { - chatMessage := model.ChatMessage{ - Role: threadRoleToCohereRole[m.Role], - } + for i, m := range t.Messages { switch m.Role { - case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant: + case thread.RoleUser, thread.RoleSystem: + chatMessage := model.ChatMessage{ + Role: threadRoleToCohereRole[m.Role], + } + + for _, content := range m.Contents { + if content.Type == thread.ContentTypeText { + chatMessage.Message += content.Data.(string) + "\n" + } + } + + history = append(history, chatMessage) + + case thread.RoleAssistant: + // check if the message is a tool call + if data, isTollCallData := m.Contents[0].Data.([]thread.ToolCallData); isTollCallData { + toolResults = append(toolResults, threadToChatMessagesTool(data, t, i)...) + continue + } + + chatMessage := model.ChatMessage{ + Role: threadRoleToCohereRole[m.Role], + } + for _, content := range m.Contents { if content.Type == thread.ContentTypeText { chatMessage.Message += content.Data.(string) + "\n" } } - case thread.RoleTool: - continue + + history = append(history, chatMessage) } - history = append(history, chatMessage) } lastMessage := t.LastMessage() @@ -56,5 +80,62 @@ func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage) { history = history[:len(history)-1] } - return message, history + return message, history, toolResults +} + +func threadToChatMessagesTool(data []thread.ToolCallData, t *thread.Thread, index int) []model.ToolResult { + var toolResults []model.ToolResult + + for nToolCall, toolCallData := range data { + toolResult := extractToolResultFromThread(t, toolCallData, index, nToolCall) + if toolResult != nil { + toolResults = append(toolResults, *toolResult) + } + } + + return toolResults +} + +func extractToolResultFromThread(t *thread.Thread, toolCallData thread.ToolCallData, index, nToolCall int) *model.ToolResult { + messageIndex := index + 1 + nToolCall + + // check if the message index is within the bounds of the thread + if messageIndex >= len(t.Messages) { + return nil + } + + message := t.Messages[messageIndex] + + // check if the message role is a tool + if message.Role != thread.RoleTool { + return nil + } + + if data, isTollResponseData := message.Contents[0].Data.(thread.ToolResponseData); isTollResponseData { + + argumentsAsMap := make(map[string]interface{}) + err := json.Unmarshal([]byte(toolCallData.Arguments), &argumentsAsMap) + if err != nil { + return nil + } + + resultsAsMap := make(map[string]interface{}) + err = json.Unmarshal([]byte(data.Result), &resultsAsMap) + if err != nil { + return nil + } + + return &model.ToolResult{ + Call: model.ToolCall{ + Name: toolCallData.Name, + Parameters: argumentsAsMap, + }, + Outputs: []interface{}{ + resultsAsMap, + }, + } + + } + + return nil } diff --git a/llm/cohere/function.go b/llm/cohere/function.go new file mode 100644 index 00000000..b562dc14 --- /dev/null +++ b/llm/cohere/function.go @@ -0,0 +1,191 @@ +package cohere + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/invopop/jsonschema" +) + +type Function struct { + Name string + Description string + Parameters map[string]interface{} + Fn interface{} +} + +type FunctionParameterOption func(map[string]interface{}) error + +func bindFunction( + fn interface{}, + name string, + description string, + functionParameterOptions ...FunctionParameterOption, +) (*Function, error) { + parameter, err := extractFunctionParameter(fn) + if err != nil { + return nil, err + } + + for _, option := range functionParameterOptions { + err = option(parameter) + if err != nil { + return nil, err + } + } + + return &Function{ + Name: name, + Description: description, + Parameters: parameter, + Fn: fn, + }, nil +} + +func (o *Cohere) BindFunction( + fn interface{}, + name string, + description string, + functionParameterOptions ...FunctionParameterOption, +) error { + function, err := bindFunction(fn, name, description, functionParameterOptions...) + if err != nil { + return err + } + + o.functions[name] = *function + + return nil +} + +type Tool interface { + Description() string + Name() string + Fn() any +} + +func (o *Cohere) WithTools(tools ...Tool) *Cohere { + for _, tool := range tools { + function, err := bindFunction(tool.Fn(), tool.Name(), tool.Description()) + if err != nil { + fmt.Println(err) + } + + o.functions[tool.Name()] = *function + } + + return o +} + +func extractFunctionParameter(f interface{}) (map[string]interface{}, error) { + // Get the type of the input function + fnType := reflect.TypeOf(f) + + if fnType.Kind() != reflect.Func { + return nil, errors.New("input must be a function") + } + + // Check that the function only has one argument + if fnType.NumIn() != 1 { + return nil, errors.New("function must have exactly one argument") + } + + // Check that the argument is of type struct + argType := fnType.In(0) + if argType.Kind() != reflect.Struct { + return nil, errors.New("argument must be of type struct") + } + + // Create a new instance of the argument type + argValue := reflect.New(argType).Elem().Interface() + + parameter, err := structAsJSONSchema(argValue) + if err != nil { + return nil, err + } + + return parameter, nil +} + +func structAsJSONSchema(v interface{}) (map[string]interface{}, error) { + r := new(jsonschema.Reflector) + r.DoNotReference = true + schema := r.Reflect(v) + + b, err := json.Marshal(schema) + if err != nil { + return nil, err + } + + var jsonSchema map[string]interface{} + err = json.Unmarshal(b, &jsonSchema) + if err != nil { + return nil, err + } + + delete(jsonSchema, "$schema") + + return jsonSchema, nil +} + +func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, error) { + // Get the type of the input function + fnType := reflect.TypeOf(fn) + + // Check that the function has one argument + if fnType.NumIn() != 1 { + return "", fmt.Errorf("function must have one argument") + } + + // Check that the argument is a struct + argType := fnType.In(0) + if argType.Kind() != reflect.Struct { + return "", fmt.Errorf("argument must be a struct") + } + + // Create a slice to hold the function argument + args := make([]reflect.Value, 1) + + // Unmarshal the JSON string into an interface{} value + var argValue interface{} + err := json.Unmarshal([]byte(argumentAsJSON), &argValue) + if err != nil { + return "", fmt.Errorf("error unmarshaling argument: %w", err) + } + + // Convert the argument value to the correct type + argValueReflect := reflect.New(argType).Elem() + jsonData, err := json.Marshal(argValue) + if err != nil { + return "", fmt.Errorf("error marshaling argument: %w", err) + } + err = json.Unmarshal(jsonData, argValueReflect.Addr().Interface()) + if err != nil { + return "", fmt.Errorf("error unmarshaling argument: %w", err) + } + + // Add the argument value to the slice + args[0] = argValueReflect + + // Call the function with the argument + fnValue := reflect.ValueOf(fn) + result := fnValue.Call(args) + + // Marshal the function result to JSON + if len(result) > 0 { + var resultBytes bytes.Buffer + enc := json.NewEncoder(&resultBytes) + enc.SetEscapeHTML(false) + err = enc.Encode(result[0].Interface()) + if err != nil { + return "", fmt.Errorf("error marshaling result: %w", err) + } + return strings.TrimSpace(resultBytes.String()), nil + } + + return "", nil +} From 97be46168095e3b4d5e20643a8f231191840b92f Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 7 Jun 2024 15:51:22 +0200 Subject: [PATCH 2/5] chore: refactor agent to proper behaviour --- assistant/assistant.go | 20 ++++++++++---------- examples/assistant/{agent => tools}/main.go | 2 +- llm/cohere/cohere.go | 4 +++- 3 files changed, 14 insertions(+), 12 deletions(-) rename examples/assistant/{agent => tools}/main.go (98%) diff --git a/assistant/assistant.go b/assistant/assistant.go index acb1e25b..660afc66 100644 --- a/assistant/assistant.go +++ b/assistant/assistant.go @@ -24,15 +24,15 @@ type observer interface { } const ( - DefaultMaxIterations = 1 + DefaultToolsMaxIterations = 1 ) type Assistant struct { - llm LLM - rag RAG - thread *thread.Thread - parameters Parameters - maxIterations uint + llm LLM + rag RAG + thread *thread.Thread + parameters Parameters + toolsMaxIterations uint } type LLM interface { @@ -54,7 +54,7 @@ func New(llm LLM) *Assistant { CompanyName: defaultCompanyName, CompanyDescription: defaultCompanyDescription, }, - maxIterations: DefaultMaxIterations, + toolsMaxIterations: DefaultToolsMaxIterations, } return assistant @@ -94,7 +94,7 @@ func (a *Assistant) Run(ctx context.Context) error { a.injectSystemMessage() } - for i := 0; i < int(a.maxIterations); i++ { + for i := 0; i < int(a.toolsMaxIterations); i++ { err = a.runIteration(ctx, i) if err != nil { return err @@ -181,8 +181,8 @@ func (a *Assistant) generateRAGMessage(ctx context.Context) error { return nil } -func (a *Assistant) WithMaxIterations(maxIterations uint) *Assistant { - a.maxIterations = maxIterations +func (a *Assistant) WithToolsMaxIterations(maxIterations uint) *Assistant { + a.toolsMaxIterations = maxIterations return a } diff --git a/examples/assistant/agent/main.go b/examples/assistant/tools/main.go similarity index 98% rename from examples/assistant/agent/main.go rename to examples/assistant/tools/main.go index 96441d56..c6b3e59c 100644 --- a/examples/assistant/agent/main.go +++ b/examples/assistant/tools/main.go @@ -46,7 +46,7 @@ func main() { thread.NewTextContent("search the top 3 italian dishes and then their costs, then ask the user's budget in euros and calculate how many guests can be invited for each dish"), ), ), - ).WithMaxIterations(10) + ).WithToolsMaxIterations(10) err = myAssistant.Run(ctx) if err != nil { diff --git a/llm/cohere/cohere.go b/llm/cohere/cohere.go index ff0f0180..3ea8f0b9 100644 --- a/llm/cohere/cohere.go +++ b/llm/cohere/cohere.go @@ -226,6 +226,8 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { return fmt.Errorf("%w: %w", ErrCohereChat, err) } + nMessageBeforeGeneration := len(t.Messages) + if c.streamCallbackFn != nil { err = c.stream(ctx, t, chatRequest) } else { @@ -235,7 +237,7 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { return err } - err = c.stopObserveGeneration(ctx, generation, []*thread.Message{t.LastMessage()}) + err = c.stopObserveGeneration(ctx, generation, t.Messages[nMessageBeforeGeneration:]) if err != nil { return fmt.Errorf("%w: %w", ErrCohereChat, err) } From 2074c1a8828692eec4bce7df3493a7ea49ef4215 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 7 Jun 2024 16:27:46 +0200 Subject: [PATCH 3/5] chore: fix --- .golangci.yml | 14 -------------- examples/assistant/tools/main.go | 2 +- llm/cohere/formatter.go | 11 +++++++---- 3 files changed, 8 insertions(+), 19 deletions(-) diff --git a/.golangci.yml b/.golangci.yml index abd0e282..5249134b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,17 +1,3 @@ -# Copyright 2013-2023 The Cobra Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - run: deadline: 5m diff --git a/examples/assistant/tools/main.go b/examples/assistant/tools/main.go index c6b3e59c..e6904c1f 100644 --- a/examples/assistant/tools/main.go +++ b/examples/assistant/tools/main.go @@ -43,7 +43,7 @@ func main() { ).WithThread( thread.New().AddMessages( thread.NewUserMessage().AddContent( - thread.NewTextContent("search the top 3 italian dishes and then their costs, then ask the user's budget in euros and calculate how many guests can be invited for each dish"), + thread.NewTextContent("search the top 3 italian meals and then their costs, then ask the user's budget in euros and calculate how many guests can be invited for each meal"), ), ), ).WithToolsMaxIterations(10) diff --git a/llm/cohere/formatter.go b/llm/cohere/formatter.go index 750faab9..a1297292 100644 --- a/llm/cohere/formatter.go +++ b/llm/cohere/formatter.go @@ -19,10 +19,13 @@ func (c *Cohere) buildChatCompletionRequest(t *thread.Thread) *request.Chat { message, history, toolResults := threadToChatMessages(t) return &request.Chat{ - Model: c.model, - ChatHistory: history, - Message: message, - ToolResults: toolResults, + Model: c.model, + ChatHistory: history, + Message: message, + ToolResults: toolResults, + Temperature: &c.temperature, + MaxTokens: &c.maxTokens, + StopSequences: c.stop, } } From 4b729e8f06b54fc9bd33fb81ddfc97de54d7db62 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Fri, 7 Jun 2024 16:45:35 +0200 Subject: [PATCH 4/5] chore: update cohere-go version --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index b3f09e66..56e33b5f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require github.com/mitchellh/mapstructure v1.5.0 require ( github.com/RediSearch/redisearch-go/v2 v2.1.1 github.com/google/uuid v1.6.0 - github.com/henomis/cohere-go v1.2.0-beta.6 + github.com/henomis/cohere-go v1.2.0 github.com/henomis/langfuse-go v0.0.3 github.com/henomis/milvus-go v0.0.4 github.com/henomis/pinecone-go/v2 v2.0.0 diff --git a/go.sum b/go.sum index 62d2a4bd..34b87c26 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/henomis/cohere-go v1.2.0-beta.6 h1:zhehuI2mMS7u8Pcy6dpHgnTcO6lIfKLZmuCTO9996fs= -github.com/henomis/cohere-go v1.2.0-beta.6/go.mod h1:z+UIgBbNCnLH5M47FYqg4h3CDWxjUv2VDfEwdTb95MU= +github.com/henomis/cohere-go v1.2.0 h1:rHi0JDdR/LYQPtQJrP0wKD6cLg8rFdJTW7KF5KjMG0s= +github.com/henomis/cohere-go v1.2.0/go.mod h1:z+UIgBbNCnLH5M47FYqg4h3CDWxjUv2VDfEwdTb95MU= github.com/henomis/langfuse-go v0.0.3 h1:Z5Mqlnj1fsok7eAT7jt+N4tegjbhY5HnZQ7wu1iVKss= github.com/henomis/langfuse-go v0.0.3/go.mod h1:gSRuO3nvjAvk/mgmb7b+9BcoN9s64GvXeaSN7PfVEKQ= github.com/henomis/milvus-go v0.0.4 h1:ArddXRJx/EGdQ75gB7TyEzD4Z4BqXzW16p8jRcjJOhs= From e0c63cbda11aaff5046d148abddab9c9f6b7b026 Mon Sep 17 00:00:00 2001 From: Simone Vellei Date: Wed, 19 Jun 2024 11:56:17 +0200 Subject: [PATCH 5/5] fix linting --- llm/cohere/cohere.go | 131 ++++++---------------------------------- llm/cohere/formatter.go | 130 +++++++++++++++++++++++++++++++++++++-- llm/cohere/function.go | 10 +-- tool/python/python.go | 1 + 4 files changed, 151 insertions(+), 121 deletions(-) diff --git a/llm/cohere/cohere.go b/llm/cohere/cohere.go index 3ea8f0b9..712e9bfa 100644 --- a/llm/cohere/cohere.go +++ b/llm/cohere/cohere.go @@ -8,7 +8,6 @@ import ( "os" "strings" - "github.com/google/uuid" coherego "github.com/henomis/cohere-go" "github.com/henomis/cohere-go/model" "github.com/henomis/cohere-go/request" @@ -121,6 +120,7 @@ func (c *Cohere) WithStream(callbackFn StreamCallbackFn) *Cohere { } // Completion returns the completion for the given prompt +// nolint:staticcheck func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) { resp := &response.Generate{} err := c.client.Generate( @@ -218,7 +218,7 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { chatRequest := c.buildChatCompletionRequest(t) if len(c.functions) > 0 { - chatRequest.Tools = c.getChatCompletionRequestTools() + chatRequest.Tools = c.buildChatCompletionRequestTools() } generation, err := c.startObserveGeneration(ctx, t) @@ -255,6 +255,9 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { func (c *Cohere) generate(ctx context.Context, t *thread.Thread, chatRequest *request.Chat) error { var response response.Chat + b, _ := json.MarshalIndent(chatRequest, "", " ") + fmt.Println(string(b)) + err := c.client.Chat( ctx, chatRequest, @@ -262,6 +265,11 @@ func (c *Cohere) generate(ctx context.Context, t *thread.Thread, chatRequest *re ) if err != nil { return fmt.Errorf("%w: %w", ErrCohereChat, err) + } else if !response.IsSuccess() { + if response.RawBody == nil { + return fmt.Errorf("%w: %d", ErrCohereChat, response.Code) + } + return fmt.Errorf("%w: %s", ErrCohereChat, *response.RawBody) } var messages []*thread.Message @@ -332,126 +340,27 @@ func (c *Cohere) stopObserveGeneration( ) } -func (o *Cohere) getChatCompletionRequestTools() []model.Tool { +func (c *Cohere) buildChatCompletionRequestTools() []model.Tool { tools := []model.Tool{} - for _, function := range o.functions { - tool := model.Tool{ - Name: function.Name, - Description: function.Description, - ParameterDefinitions: make(map[string]model.ToolParameterDefinition), - } - - functionProperties, ok := function.Parameters["properties"] - if !ok { - continue - } - - functionPropertiesAsMap, isMap := functionProperties.(map[string]interface{}) - if !isMap { - continue - } - - for k, v := range functionPropertiesAsMap { - valueAsMap, isValueMap := v.(map[string]interface{}) - if !isValueMap { - continue - } - - description := "" - descriptionValue, isDescriptionValue := valueAsMap["description"] - if isDescriptionValue { - descriptionAsString, isDescriptionString := descriptionValue.(string) - if isDescriptionString { - description = descriptionAsString - } - } - - argType := "" - argTypeValue, isArgTypeValue := valueAsMap["type"] - if isArgTypeValue { - argTypeAsString, isArgTypeString := argTypeValue.(string) - if isArgTypeString { - argType = argTypeAsString - } - } - - required := false - requiredValue, isRequiredValue := valueAsMap["required"] - if isRequiredValue { - requiredAsBool, isRequiredBool := requiredValue.(bool) - if isRequiredBool { - required = requiredAsBool - } - } - - if description == "" || argType == "" { - continue - } - - tool.ParameterDefinitions[k] = model.ToolParameterDefinition{ - Description: description, - Type: argType, - Required: required, - } + for _, function := range c.functions { + tool := buildChatCompletionRequestTool(function) + if tool != nil { + tools = append(tools, *tool) } - - tools = append(tools, tool) } return tools } -func toolCallsToToolCallMessage(toolCalls []model.ToolCall) *thread.Message { - if len(toolCalls) == 0 { - return nil - } - - var toolCallData []thread.ToolCallData - for i, toolCall := range toolCalls { - parametersAsString, err := json.Marshal(toolCall.Parameters) - if err != nil { - continue - } - - if toolCalls[i].ID == "" { - toolCalls[i].ID = uuid.New().String() - } - - toolCallData = append(toolCallData, thread.ToolCallData{ - ID: toolCalls[i].ID, - Name: toolCall.Name, - Arguments: string(parametersAsString), - }) - } - - return thread.NewAssistantMessage().AddContent( - thread.NewToolCallContent( - toolCallData, - ), - ) -} - -func toolCallResultToThreadMessage(toolCall model.ToolCall, result string) *thread.Message { - return thread.NewToolMessage().AddContent( - thread.NewToolResponseContent( - thread.ToolResponseData{ - ID: toolCall.ID, - Name: toolCall.Name, - Result: result, - }, - ), - ) -} - -func (o *Cohere) callTools(toolCalls []model.ToolCall) []*thread.Message { - if len(o.functions) == 0 || len(toolCalls) == 0 { +func (c *Cohere) callTools(toolCalls []model.ToolCall) []*thread.Message { + if len(c.functions) == 0 || len(toolCalls) == 0 { return nil } var messages []*thread.Message for _, toolCall := range toolCalls { - result, err := o.callTool(toolCall) + result, err := c.callTool(toolCall) if err != nil { result = fmt.Sprintf("error: %s", err) } @@ -462,8 +371,8 @@ func (o *Cohere) callTools(toolCalls []model.ToolCall) []*thread.Message { return messages } -func (o *Cohere) callTool(toolCall model.ToolCall) (string, error) { - fn, ok := o.functions[toolCall.Name] +func (c *Cohere) callTool(toolCall model.ToolCall) (string, error) { + fn, ok := c.functions[toolCall.Name] if !ok { return "", fmt.Errorf("unknown function %s", toolCall.Name) } diff --git a/llm/cohere/formatter.go b/llm/cohere/formatter.go index a1297292..5239ecb7 100644 --- a/llm/cohere/formatter.go +++ b/llm/cohere/formatter.go @@ -3,6 +3,7 @@ package cohere import ( "encoding/json" + "github.com/google/uuid" "github.com/henomis/cohere-go/model" "github.com/henomis/cohere-go/request" "github.com/henomis/lingoose/thread" @@ -29,14 +30,17 @@ func (c *Cohere) buildChatCompletionRequest(t *thread.Thread) *request.Chat { } } +// nolint:gocognit func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage, []model.ToolResult) { var history []model.ChatMessage var toolResults []model.ToolResult var message string for i, m := range t.Messages { - switch m.Role { + case thread.RoleTool: + //TODO: INSERT HERE TOOL RESULTS + continue case thread.RoleUser, thread.RoleSystem: chatMessage := model.ChatMessage{ Role: threadRoleToCohereRole[m.Role], @@ -51,6 +55,7 @@ func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage, []mode history = append(history, chatMessage) case thread.RoleAssistant: + // INSERT HERE TOOL CALLS // check if the message is a tool call if data, isTollCallData := m.Contents[0].Data.([]thread.ToolCallData); isTollCallData { toolResults = append(toolResults, threadToChatMessagesTool(data, t, i)...) @@ -69,7 +74,6 @@ func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage, []mode history = append(history, chatMessage) } - } lastMessage := t.LastMessage() @@ -99,7 +103,12 @@ func threadToChatMessagesTool(data []thread.ToolCallData, t *thread.Thread, inde return toolResults } -func extractToolResultFromThread(t *thread.Thread, toolCallData thread.ToolCallData, index, nToolCall int) *model.ToolResult { +func extractToolResultFromThread( + t *thread.Thread, + toolCallData thread.ToolCallData, + index, + nToolCall int, +) *model.ToolResult { messageIndex := index + 1 + nToolCall // check if the message index is within the bounds of the thread @@ -115,7 +124,6 @@ func extractToolResultFromThread(t *thread.Thread, toolCallData thread.ToolCallD } if data, isTollResponseData := message.Contents[0].Data.(thread.ToolResponseData); isTollResponseData { - argumentsAsMap := make(map[string]interface{}) err := json.Unmarshal([]byte(toolCallData.Arguments), &argumentsAsMap) if err != nil { @@ -137,8 +145,120 @@ func extractToolResultFromThread(t *thread.Thread, toolCallData thread.ToolCallD resultsAsMap, }, } - } return nil } + +func buildChatCompletionRequestTool(function Function) *model.Tool { + tool := model.Tool{ + Name: function.Name, + Description: function.Description, + ParameterDefinitions: make(map[string]model.ToolParameterDefinition), + } + + functionProperties, ok := function.Parameters["properties"] + if !ok { + return nil + } + + functionPropertiesAsMap, isMap := functionProperties.(map[string]interface{}) + if !isMap { + return nil + } + + for k, v := range functionPropertiesAsMap { + toolParameterDefinitions := buildFunctionParameterDefinitions(v) + if toolParameterDefinitions != nil { + tool.ParameterDefinitions[k] = *toolParameterDefinitions + } + } + + return &tool +} + +func buildFunctionParameterDefinitions(v any) *model.ToolParameterDefinition { + valueAsMap, isValueMap := v.(map[string]interface{}) + if !isValueMap { + return nil + } + + description := "" + descriptionValue, isDescriptionValue := valueAsMap["description"] + if isDescriptionValue { + descriptionAsString, isDescriptionString := descriptionValue.(string) + if isDescriptionString { + description = descriptionAsString + } + } + + argType := "" + argTypeValue, isArgTypeValue := valueAsMap["type"] + if isArgTypeValue { + argTypeAsString, isArgTypeString := argTypeValue.(string) + if isArgTypeString { + argType = argTypeAsString + } + } + + required := false + requiredValue, isRequiredValue := valueAsMap["required"] + if isRequiredValue { + requiredAsBool, isRequiredBool := requiredValue.(bool) + if isRequiredBool { + required = requiredAsBool + } + } + + if description == "" || argType == "" { + return nil + } + + return &model.ToolParameterDefinition{ + Description: description, + Type: argType, + Required: required, + } +} + +func toolCallsToToolCallMessage(toolCalls []model.ToolCall) *thread.Message { + if len(toolCalls) == 0 { + return nil + } + + var toolCallData []thread.ToolCallData + for i, toolCall := range toolCalls { + parametersAsString, err := json.Marshal(toolCall.Parameters) + if err != nil { + continue + } + + if toolCalls[i].ID == "" { + toolCalls[i].ID = uuid.New().String() + } + + toolCallData = append(toolCallData, thread.ToolCallData{ + ID: toolCalls[i].ID, + Name: toolCall.Name, + Arguments: string(parametersAsString), + }) + } + + return thread.NewAssistantMessage().AddContent( + thread.NewToolCallContent( + toolCallData, + ), + ) +} + +func toolCallResultToThreadMessage(toolCall model.ToolCall, result string) *thread.Message { + return thread.NewToolMessage().AddContent( + thread.NewToolResponseContent( + thread.ToolResponseData{ + ID: toolCall.ID, + Name: toolCall.Name, + Result: result, + }, + ), + ) +} diff --git a/llm/cohere/function.go b/llm/cohere/function.go index b562dc14..5cd8f0a9 100644 --- a/llm/cohere/function.go +++ b/llm/cohere/function.go @@ -46,7 +46,7 @@ func bindFunction( }, nil } -func (o *Cohere) BindFunction( +func (c *Cohere) BindFunction( fn interface{}, name string, description string, @@ -57,7 +57,7 @@ func (o *Cohere) BindFunction( return err } - o.functions[name] = *function + c.functions[name] = *function return nil } @@ -68,17 +68,17 @@ type Tool interface { Fn() any } -func (o *Cohere) WithTools(tools ...Tool) *Cohere { +func (c *Cohere) WithTools(tools ...Tool) *Cohere { for _, tool := range tools { function, err := bindFunction(tool.Fn(), tool.Name(), tool.Description()) if err != nil { fmt.Println(err) } - o.functions[tool.Name()] = *function + c.functions[tool.Name()] = *function } - return o + return c } func extractFunctionParameter(f interface{}) (map[string]interface{}, error) { diff --git a/tool/python/python.go b/tool/python/python.go index 3fd6b7ad..9c6153c4 100644 --- a/tool/python/python.go +++ b/tool/python/python.go @@ -22,6 +22,7 @@ func (t *Tool) WithPythonPath(pythonPath string) *Tool { } type Input struct { + // nolint:lll PythonCode string `json:"python_code" jsonschema:"description=python code that uses print() to print the final result to stdout."` }