diff --git a/.golangci.yml b/.golangci.yml index abd0e282..5249134b 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,17 +1,3 @@ -# Copyright 2013-2023 The Cobra Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - run: deadline: 5m diff --git a/assistant/assistant.go b/assistant/assistant.go index e591ebd1..660afc66 100644 --- a/assistant/assistant.go +++ b/assistant/assistant.go @@ -24,15 +24,15 @@ type observer interface { } const ( - DefaultMaxIterations = 3 + DefaultToolsMaxIterations = 1 ) type Assistant struct { - llm LLM - rag RAG - thread *thread.Thread - parameters Parameters - maxIterations uint + llm LLM + rag RAG + thread *thread.Thread + parameters Parameters + toolsMaxIterations uint } type LLM interface { @@ -54,7 +54,7 @@ func New(llm LLM) *Assistant { CompanyName: defaultCompanyName, CompanyDescription: defaultCompanyDescription, }, - maxIterations: DefaultMaxIterations, + toolsMaxIterations: DefaultToolsMaxIterations, } return assistant @@ -94,7 +94,7 @@ func (a *Assistant) Run(ctx context.Context) error { a.injectSystemMessage() } - for i := 0; i < int(a.maxIterations); i++ { + for i := 0; i < int(a.toolsMaxIterations); i++ { err = a.runIteration(ctx, i) if err != nil { return err @@ -181,8 +181,8 @@ func (a *Assistant) generateRAGMessage(ctx context.Context) error { return nil } -func (a *Assistant) WithMaxIterations(maxIterations uint) *Assistant { - a.maxIterations = maxIterations +func (a *Assistant) WithToolsMaxIterations(maxIterations uint) *Assistant { + a.toolsMaxIterations = maxIterations return a } diff --git a/examples/assistant/agent/main.go b/examples/assistant/tools/main.go similarity index 86% rename from examples/assistant/agent/main.go rename to examples/assistant/tools/main.go index 96441d56..e6904c1f 100644 --- a/examples/assistant/agent/main.go +++ b/examples/assistant/tools/main.go @@ -43,10 +43,10 @@ func main() { ).WithThread( thread.New().AddMessages( thread.NewUserMessage().AddContent( - thread.NewTextContent("search the top 3 italian dishes and then their costs, then ask the user's budget in euros and calculate how many guests can be invited for each dish"), + thread.NewTextContent("search the top 3 italian meals and then their costs, then ask the user's budget in euros and calculate how many guests can be invited for each meal"), ), ), - ).WithMaxIterations(10) + ).WithToolsMaxIterations(10) err = myAssistant.Run(ctx) if err != nil { diff --git a/go.mod b/go.mod index 4131f84e..56e33b5f 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require github.com/mitchellh/mapstructure v1.5.0 require ( github.com/RediSearch/redisearch-go/v2 v2.1.1 github.com/google/uuid v1.6.0 - github.com/henomis/cohere-go v1.1.2 + github.com/henomis/cohere-go v1.2.0 github.com/henomis/langfuse-go v0.0.3 github.com/henomis/milvus-go v0.0.4 github.com/henomis/pinecone-go/v2 v2.0.0 diff --git a/go.sum b/go.sum index a79137ce..34b87c26 100644 --- a/go.sum +++ b/go.sum @@ -7,8 +7,8 @@ github.com/gomodule/redigo v1.8.9 h1:Sl3u+2BI/kk+VEatbj0scLdrFhjPmbxOc1myhDP41ws github.com/gomodule/redigo v1.8.9/go.mod h1:7ArFNvsTjH8GMMzB4uy1snslv2BwmginuMs06a1uzZE= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= -github.com/henomis/cohere-go v1.1.2 h1:rzEA1JRm26RnaQValoVeVQ1y1nTofuhr/z/fPyhUW/4= -github.com/henomis/cohere-go v1.1.2/go.mod h1:z+UIgBbNCnLH5M47FYqg4h3CDWxjUv2VDfEwdTb95MU= +github.com/henomis/cohere-go v1.2.0 h1:rHi0JDdR/LYQPtQJrP0wKD6cLg8rFdJTW7KF5KjMG0s= +github.com/henomis/cohere-go v1.2.0/go.mod h1:z+UIgBbNCnLH5M47FYqg4h3CDWxjUv2VDfEwdTb95MU= github.com/henomis/langfuse-go v0.0.3 h1:Z5Mqlnj1fsok7eAT7jt+N4tegjbhY5HnZQ7wu1iVKss= github.com/henomis/langfuse-go v0.0.3/go.mod h1:gSRuO3nvjAvk/mgmb7b+9BcoN9s64GvXeaSN7PfVEKQ= github.com/henomis/milvus-go v0.0.4 h1:ArddXRJx/EGdQ75gB7TyEzD4Z4BqXzW16p8jRcjJOhs= diff --git a/llm/cohere/cohere.go b/llm/cohere/cohere.go index 1c53a04e..712e9bfa 100644 --- a/llm/cohere/cohere.go +++ b/llm/cohere/cohere.go @@ -2,6 +2,7 @@ package cohere import ( "context" + "encoding/json" "errors" "fmt" "os" @@ -31,6 +32,8 @@ const ( ModelCommandNightly Model = model.ModelCommandNightly ModelCommandLight Model = model.ModelCommandLight ModelCommandLightNightly Model = model.ModelCommandLightNightly + ModelCommandR Model = model.ModelCommandR + ModelCommandRPlus Model = model.ModelCommandRPlus ) const ( @@ -51,8 +54,7 @@ type Cohere struct { cache *cache.Cache streamCallbackFn StreamCallbackFn name string - observer llmobserver.LLMObserver - observerTraceID string + functions map[string]Function } func (c *Cohere) WithCache(cache *cache.Cache) *Cohere { @@ -72,6 +74,7 @@ func New() *Cohere { temperature: DefaultTemperature, maxTokens: DefaultMaxTokens, name: "cohere", + functions: make(map[string]Function), } } @@ -116,13 +119,8 @@ func (c *Cohere) WithStream(callbackFn StreamCallbackFn) *Cohere { return c } -func (c *Cohere) WithObserver(observer llmobserver.LLMObserver, traceID string) *Cohere { - c.observer = observer - c.observerTraceID = traceID - return c -} - // Completion returns the completion for the given prompt +// nolint:staticcheck func (c *Cohere) Completion(ctx context.Context, prompt string) (string, error) { resp := &response.Generate{} err := c.client.Generate( @@ -219,11 +217,17 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { chatRequest := c.buildChatCompletionRequest(t) + if len(c.functions) > 0 { + chatRequest.Tools = c.buildChatCompletionRequestTools() + } + generation, err := c.startObserveGeneration(ctx, t) if err != nil { return fmt.Errorf("%w: %w", ErrCohereChat, err) } + nMessageBeforeGeneration := len(t.Messages) + if c.streamCallbackFn != nil { err = c.stream(ctx, t, chatRequest) } else { @@ -233,7 +237,7 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { return err } - err = c.stopObserveGeneration(ctx, generation, []*thread.Message{t.LastMessage()}) + err = c.stopObserveGeneration(ctx, generation, t.Messages[nMessageBeforeGeneration:]) if err != nil { return fmt.Errorf("%w: %w", ErrCohereChat, err) } @@ -251,6 +255,9 @@ func (c *Cohere) Generate(ctx context.Context, t *thread.Thread) error { func (c *Cohere) generate(ctx context.Context, t *thread.Thread, chatRequest *request.Chat) error { var response response.Chat + b, _ := json.MarshalIndent(chatRequest, "", " ") + fmt.Println(string(b)) + err := c.client.Chat( ctx, chatRequest, @@ -258,11 +265,26 @@ func (c *Cohere) generate(ctx context.Context, t *thread.Thread, chatRequest *re ) if err != nil { return fmt.Errorf("%w: %w", ErrCohereChat, err) + } else if !response.IsSuccess() { + if response.RawBody == nil { + return fmt.Errorf("%w: %d", ErrCohereChat, response.Code) + } + return fmt.Errorf("%w: %s", ErrCohereChat, *response.RawBody) } - t.AddMessage(thread.NewAssistantMessage().AddContent( - thread.NewTextContent(response.Text), - )) + var messages []*thread.Message + if len(response.ToolCalls) > 0 { + messages = append(messages, toolCallsToToolCallMessage(response.ToolCalls)) + messages = append(messages, c.callTools(response.ToolCalls)...) + } else { + messages = []*thread.Message{ + thread.NewAssistantMessage().AddContent( + thread.NewTextContent(response.Text), + ), + } + } + + t.Messages = append(t.Messages, messages...) return nil } @@ -317,3 +339,54 @@ func (c *Cohere) stopObserveGeneration( messages, ) } + +func (c *Cohere) buildChatCompletionRequestTools() []model.Tool { + tools := []model.Tool{} + + for _, function := range c.functions { + tool := buildChatCompletionRequestTool(function) + if tool != nil { + tools = append(tools, *tool) + } + } + + return tools +} + +func (c *Cohere) callTools(toolCalls []model.ToolCall) []*thread.Message { + if len(c.functions) == 0 || len(toolCalls) == 0 { + return nil + } + + var messages []*thread.Message + for _, toolCall := range toolCalls { + result, err := c.callTool(toolCall) + if err != nil { + result = fmt.Sprintf("error: %s", err) + } + + messages = append(messages, toolCallResultToThreadMessage(toolCall, result)) + } + + return messages +} + +func (c *Cohere) callTool(toolCall model.ToolCall) (string, error) { + fn, ok := c.functions[toolCall.Name] + if !ok { + return "", fmt.Errorf("unknown function %s", toolCall.Name) + } + + parameters, err := json.Marshal(toolCall.Parameters) + if err != nil { + return "", err + } + + resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, string(parameters)) + //resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, "{}") + if err != nil { + return "", err + } + + return resultAsJSON, nil +} diff --git a/llm/cohere/formatter.go b/llm/cohere/formatter.go index f39bd16f..5239ecb7 100644 --- a/llm/cohere/formatter.go +++ b/llm/cohere/formatter.go @@ -1,6 +1,9 @@ package cohere import ( + "encoding/json" + + "github.com/google/uuid" "github.com/henomis/cohere-go/model" "github.com/henomis/cohere-go/request" "github.com/henomis/lingoose/thread" @@ -10,39 +13,67 @@ var threadRoleToCohereRole = map[thread.Role]model.ChatMessageRole{ thread.RoleSystem: model.ChatMessageRoleChatbot, thread.RoleUser: model.ChatMessageRoleUser, thread.RoleAssistant: model.ChatMessageRoleChatbot, + thread.RoleTool: model.ChatMessageRoleTool, } func (c *Cohere) buildChatCompletionRequest(t *thread.Thread) *request.Chat { - message, history := threadToChatMessages(t) + message, history, toolResults := threadToChatMessages(t) return &request.Chat{ - Model: c.model, - ChatHistory: history, - Message: message, + Model: c.model, + ChatHistory: history, + Message: message, + ToolResults: toolResults, + Temperature: &c.temperature, + MaxTokens: &c.maxTokens, + StopSequences: c.stop, } } -func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage) { +// nolint:gocognit +func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage, []model.ToolResult) { var history []model.ChatMessage + var toolResults []model.ToolResult var message string - for _, m := range t.Messages { - chatMessage := model.ChatMessage{ - Role: threadRoleToCohereRole[m.Role], - } - + for i, m := range t.Messages { switch m.Role { - case thread.RoleUser, thread.RoleSystem, thread.RoleAssistant: + case thread.RoleTool: + //TODO: INSERT HERE TOOL RESULTS + continue + case thread.RoleUser, thread.RoleSystem: + chatMessage := model.ChatMessage{ + Role: threadRoleToCohereRole[m.Role], + } + + for _, content := range m.Contents { + if content.Type == thread.ContentTypeText { + chatMessage.Message += content.Data.(string) + "\n" + } + } + + history = append(history, chatMessage) + + case thread.RoleAssistant: + // INSERT HERE TOOL CALLS + // check if the message is a tool call + if data, isTollCallData := m.Contents[0].Data.([]thread.ToolCallData); isTollCallData { + toolResults = append(toolResults, threadToChatMessagesTool(data, t, i)...) + continue + } + + chatMessage := model.ChatMessage{ + Role: threadRoleToCohereRole[m.Role], + } + for _, content := range m.Contents { if content.Type == thread.ContentTypeText { chatMessage.Message += content.Data.(string) + "\n" } } - case thread.RoleTool: - continue - } - history = append(history, chatMessage) + history = append(history, chatMessage) + } } lastMessage := t.LastMessage() @@ -56,5 +87,178 @@ func threadToChatMessages(t *thread.Thread) (string, []model.ChatMessage) { history = history[:len(history)-1] } - return message, history + return message, history, toolResults +} + +func threadToChatMessagesTool(data []thread.ToolCallData, t *thread.Thread, index int) []model.ToolResult { + var toolResults []model.ToolResult + + for nToolCall, toolCallData := range data { + toolResult := extractToolResultFromThread(t, toolCallData, index, nToolCall) + if toolResult != nil { + toolResults = append(toolResults, *toolResult) + } + } + + return toolResults +} + +func extractToolResultFromThread( + t *thread.Thread, + toolCallData thread.ToolCallData, + index, + nToolCall int, +) *model.ToolResult { + messageIndex := index + 1 + nToolCall + + // check if the message index is within the bounds of the thread + if messageIndex >= len(t.Messages) { + return nil + } + + message := t.Messages[messageIndex] + + // check if the message role is a tool + if message.Role != thread.RoleTool { + return nil + } + + if data, isTollResponseData := message.Contents[0].Data.(thread.ToolResponseData); isTollResponseData { + argumentsAsMap := make(map[string]interface{}) + err := json.Unmarshal([]byte(toolCallData.Arguments), &argumentsAsMap) + if err != nil { + return nil + } + + resultsAsMap := make(map[string]interface{}) + err = json.Unmarshal([]byte(data.Result), &resultsAsMap) + if err != nil { + return nil + } + + return &model.ToolResult{ + Call: model.ToolCall{ + Name: toolCallData.Name, + Parameters: argumentsAsMap, + }, + Outputs: []interface{}{ + resultsAsMap, + }, + } + } + + return nil +} + +func buildChatCompletionRequestTool(function Function) *model.Tool { + tool := model.Tool{ + Name: function.Name, + Description: function.Description, + ParameterDefinitions: make(map[string]model.ToolParameterDefinition), + } + + functionProperties, ok := function.Parameters["properties"] + if !ok { + return nil + } + + functionPropertiesAsMap, isMap := functionProperties.(map[string]interface{}) + if !isMap { + return nil + } + + for k, v := range functionPropertiesAsMap { + toolParameterDefinitions := buildFunctionParameterDefinitions(v) + if toolParameterDefinitions != nil { + tool.ParameterDefinitions[k] = *toolParameterDefinitions + } + } + + return &tool +} + +func buildFunctionParameterDefinitions(v any) *model.ToolParameterDefinition { + valueAsMap, isValueMap := v.(map[string]interface{}) + if !isValueMap { + return nil + } + + description := "" + descriptionValue, isDescriptionValue := valueAsMap["description"] + if isDescriptionValue { + descriptionAsString, isDescriptionString := descriptionValue.(string) + if isDescriptionString { + description = descriptionAsString + } + } + + argType := "" + argTypeValue, isArgTypeValue := valueAsMap["type"] + if isArgTypeValue { + argTypeAsString, isArgTypeString := argTypeValue.(string) + if isArgTypeString { + argType = argTypeAsString + } + } + + required := false + requiredValue, isRequiredValue := valueAsMap["required"] + if isRequiredValue { + requiredAsBool, isRequiredBool := requiredValue.(bool) + if isRequiredBool { + required = requiredAsBool + } + } + + if description == "" || argType == "" { + return nil + } + + return &model.ToolParameterDefinition{ + Description: description, + Type: argType, + Required: required, + } +} + +func toolCallsToToolCallMessage(toolCalls []model.ToolCall) *thread.Message { + if len(toolCalls) == 0 { + return nil + } + + var toolCallData []thread.ToolCallData + for i, toolCall := range toolCalls { + parametersAsString, err := json.Marshal(toolCall.Parameters) + if err != nil { + continue + } + + if toolCalls[i].ID == "" { + toolCalls[i].ID = uuid.New().String() + } + + toolCallData = append(toolCallData, thread.ToolCallData{ + ID: toolCalls[i].ID, + Name: toolCall.Name, + Arguments: string(parametersAsString), + }) + } + + return thread.NewAssistantMessage().AddContent( + thread.NewToolCallContent( + toolCallData, + ), + ) +} + +func toolCallResultToThreadMessage(toolCall model.ToolCall, result string) *thread.Message { + return thread.NewToolMessage().AddContent( + thread.NewToolResponseContent( + thread.ToolResponseData{ + ID: toolCall.ID, + Name: toolCall.Name, + Result: result, + }, + ), + ) } diff --git a/llm/cohere/function.go b/llm/cohere/function.go new file mode 100644 index 00000000..5cd8f0a9 --- /dev/null +++ b/llm/cohere/function.go @@ -0,0 +1,191 @@ +package cohere + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "reflect" + "strings" + + "github.com/invopop/jsonschema" +) + +type Function struct { + Name string + Description string + Parameters map[string]interface{} + Fn interface{} +} + +type FunctionParameterOption func(map[string]interface{}) error + +func bindFunction( + fn interface{}, + name string, + description string, + functionParameterOptions ...FunctionParameterOption, +) (*Function, error) { + parameter, err := extractFunctionParameter(fn) + if err != nil { + return nil, err + } + + for _, option := range functionParameterOptions { + err = option(parameter) + if err != nil { + return nil, err + } + } + + return &Function{ + Name: name, + Description: description, + Parameters: parameter, + Fn: fn, + }, nil +} + +func (c *Cohere) BindFunction( + fn interface{}, + name string, + description string, + functionParameterOptions ...FunctionParameterOption, +) error { + function, err := bindFunction(fn, name, description, functionParameterOptions...) + if err != nil { + return err + } + + c.functions[name] = *function + + return nil +} + +type Tool interface { + Description() string + Name() string + Fn() any +} + +func (c *Cohere) WithTools(tools ...Tool) *Cohere { + for _, tool := range tools { + function, err := bindFunction(tool.Fn(), tool.Name(), tool.Description()) + if err != nil { + fmt.Println(err) + } + + c.functions[tool.Name()] = *function + } + + return c +} + +func extractFunctionParameter(f interface{}) (map[string]interface{}, error) { + // Get the type of the input function + fnType := reflect.TypeOf(f) + + if fnType.Kind() != reflect.Func { + return nil, errors.New("input must be a function") + } + + // Check that the function only has one argument + if fnType.NumIn() != 1 { + return nil, errors.New("function must have exactly one argument") + } + + // Check that the argument is of type struct + argType := fnType.In(0) + if argType.Kind() != reflect.Struct { + return nil, errors.New("argument must be of type struct") + } + + // Create a new instance of the argument type + argValue := reflect.New(argType).Elem().Interface() + + parameter, err := structAsJSONSchema(argValue) + if err != nil { + return nil, err + } + + return parameter, nil +} + +func structAsJSONSchema(v interface{}) (map[string]interface{}, error) { + r := new(jsonschema.Reflector) + r.DoNotReference = true + schema := r.Reflect(v) + + b, err := json.Marshal(schema) + if err != nil { + return nil, err + } + + var jsonSchema map[string]interface{} + err = json.Unmarshal(b, &jsonSchema) + if err != nil { + return nil, err + } + + delete(jsonSchema, "$schema") + + return jsonSchema, nil +} + +func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, error) { + // Get the type of the input function + fnType := reflect.TypeOf(fn) + + // Check that the function has one argument + if fnType.NumIn() != 1 { + return "", fmt.Errorf("function must have one argument") + } + + // Check that the argument is a struct + argType := fnType.In(0) + if argType.Kind() != reflect.Struct { + return "", fmt.Errorf("argument must be a struct") + } + + // Create a slice to hold the function argument + args := make([]reflect.Value, 1) + + // Unmarshal the JSON string into an interface{} value + var argValue interface{} + err := json.Unmarshal([]byte(argumentAsJSON), &argValue) + if err != nil { + return "", fmt.Errorf("error unmarshaling argument: %w", err) + } + + // Convert the argument value to the correct type + argValueReflect := reflect.New(argType).Elem() + jsonData, err := json.Marshal(argValue) + if err != nil { + return "", fmt.Errorf("error marshaling argument: %w", err) + } + err = json.Unmarshal(jsonData, argValueReflect.Addr().Interface()) + if err != nil { + return "", fmt.Errorf("error unmarshaling argument: %w", err) + } + + // Add the argument value to the slice + args[0] = argValueReflect + + // Call the function with the argument + fnValue := reflect.ValueOf(fn) + result := fnValue.Call(args) + + // Marshal the function result to JSON + if len(result) > 0 { + var resultBytes bytes.Buffer + enc := json.NewEncoder(&resultBytes) + enc.SetEscapeHTML(false) + err = enc.Encode(result[0].Interface()) + if err != nil { + return "", fmt.Errorf("error marshaling result: %w", err) + } + return strings.TrimSpace(resultBytes.String()), nil + } + + return "", nil +} diff --git a/tool/python/python.go b/tool/python/python.go index 3fd6b7ad..9c6153c4 100644 --- a/tool/python/python.go +++ b/tool/python/python.go @@ -22,6 +22,7 @@ func (t *Tool) WithPythonPath(pythonPath string) *Tool { } type Input struct { + // nolint:lll PythonCode string `json:"python_code" jsonschema:"description=python code that uses print() to print the final result to stdout."` }