Skip to content

Commit

Permalink
chore: use openai tools (#156)
Browse files Browse the repository at this point in the history
  • Loading branch information
henomis authored Dec 17, 2023
1 parent cfbe53f commit 0ee4e34
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 42 deletions.
2 changes: 1 addition & 1 deletion examples/pipeline/summarize/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
func main() {

summarize := summarizepipeline.New(
openai.NewCompletion().WithMaxTokens(1000).WithVerbose(true).WithModel(openai.GPT3TextDavinci002),
openai.NewCompletion().WithMaxTokens(1000).WithVerbose(true).WithModel(openai.GPT3Dot5TurboInstruct),
loader.NewTextLoader("state_of_the_union.txt", nil).
WithTextSplitter(textsplitter.NewRecursiveCharacterTextSplitter(2000, 0)),
)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/henomis/qdrant-go v1.1.0
github.com/invopop/jsonschema v0.7.0
github.com/pkoukk/tiktoken-go v0.1.1
github.com/sashabaranov/go-openai v1.12.0
github.com/sashabaranov/go-openai v1.17.9
)

require (
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ github.com/pkoukk/tiktoken-go v0.1.1 h1:jtkYlIECjyM9OW1w4rjPmTohK4arORP9V25y6TM6
github.com/pkoukk/tiktoken-go v0.1.1/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/sashabaranov/go-openai v1.12.0 h1:aRNHH0gtVfrpIaEolD0sWrLLRnYQNK4cH/bIAHwL8Rk=
github.com/sashabaranov/go-openai v1.12.0/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/sashabaranov/go-openai v1.17.9 h1:QEoBiGKWW68W79YIfXWEFZ7l5cEgZBV4/Ow3uy+5hNY=
github.com/sashabaranov/go-openai v1.17.9/go.mod h1:lj5b/K+zjTSFxVLijLSTDZuP7adOgerWeFyZLUhAKRg=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/testify v1.3.1-0.20190311161405-34c6fa2dc709/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
Expand Down
27 changes: 17 additions & 10 deletions llm/openai/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,18 +49,21 @@ func (o *OpenAI) BindFunction(
return nil
}

func (o *OpenAI) getFunctions() []openai.FunctionDefinition {
functions := []openai.FunctionDefinition{}
func (o *OpenAI) getFunctions() []openai.Tool {
tools := []openai.Tool{}

for _, function := range o.functions {
functions = append(functions, openai.FunctionDefinition{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
tools = append(tools, openai.Tool{
Type: "function",
Function: openai.FunctionDefinition{
Name: function.Name,
Description: function.Description,
Parameters: function.Parameters,
},
})
}

return functions
return tools
}

func extractFunctionParameter(f interface{}) (map[string]interface{}, error) {
Expand Down Expand Up @@ -170,12 +173,16 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, er
}

func (o *OpenAI) functionCall(response openai.ChatCompletionResponse) (string, error) {
fn, ok := o.functions[response.Choices[0].Message.FunctionCall.Name]
fn, ok := o.functions[response.Choices[0].Message.ToolCalls[0].Function.Name]
if !ok {
return "", fmt.Errorf("%w: unknown function %s", ErrOpenAIChat, response.Choices[0].Message.FunctionCall.Name)
return "", fmt.Errorf(
"%w: unknown function %s",
ErrOpenAIChat,
response.Choices[0].Message.ToolCalls[0].Function.Name,
)
}

resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, response.Choices[0].Message.FunctionCall.Arguments)
resultAsJSON, err := callFnWithArgumentAsJSON(fn.Fn, response.Choices[0].Message.ToolCalls[0].Function.Arguments)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}
Expand Down
72 changes: 44 additions & 28 deletions llm/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,29 +32,29 @@ const (
type Model string

const (
GPT432K0613 Model = openai.GPT432K0613
GPT432K0314 Model = openai.GPT432K0314
GPT432K Model = openai.GPT432K
GPT40613 Model = openai.GPT40613
GPT40314 Model = openai.GPT40314
GPT4 Model = openai.GPT4
GPT3Dot5Turbo0613 Model = openai.GPT3Dot5Turbo0613
GPT3Dot5Turbo0301 Model = openai.GPT3Dot5Turbo0301
GPT3Dot5Turbo16K Model = openai.GPT3Dot5Turbo16K
GPT3Dot5Turbo16K0613 Model = openai.GPT3Dot5Turbo16K0613
GPT3Dot5Turbo Model = openai.GPT3Dot5Turbo
GPT3TextDavinci003 Model = openai.GPT3TextDavinci003
GPT3TextDavinci002 Model = openai.GPT3TextDavinci002
GPT3TextCurie001 Model = openai.GPT3TextCurie001
GPT3TextBabbage001 Model = openai.GPT3TextBabbage001
GPT3TextAda001 Model = openai.GPT3TextAda001
GPT3TextDavinci001 Model = openai.GPT3TextDavinci001
GPT3DavinciInstructBeta Model = openai.GPT3DavinciInstructBeta
GPT3Davinci Model = openai.GPT3Davinci
GPT3CurieInstructBeta Model = openai.GPT3CurieInstructBeta
GPT3Curie Model = openai.GPT3Curie
GPT3Ada Model = openai.GPT3Ada
GPT3Babbage Model = openai.GPT3Babbage
GPT432K0613 Model = openai.GPT432K0613
GPT432K0314 Model = openai.GPT432K0314
GPT432K Model = openai.GPT432K
GPT40613 Model = openai.GPT40613
GPT40314 Model = openai.GPT40314
GPT4TurboPreview Model = openai.GPT4TurboPreview
GPT4VisionPreview Model = openai.GPT4VisionPreview
GPT4 Model = openai.GPT4
GPT3Dot5Turbo1106 Model = openai.GPT3Dot5Turbo1106
GPT3Dot5Turbo0613 Model = openai.GPT3Dot5Turbo0613
GPT3Dot5Turbo0301 Model = openai.GPT3Dot5Turbo0301
GPT3Dot5Turbo16K Model = openai.GPT3Dot5Turbo16K
GPT3Dot5Turbo16K0613 Model = openai.GPT3Dot5Turbo16K0613
GPT3Dot5Turbo Model = openai.GPT3Dot5Turbo
GPT3Dot5TurboInstruct Model = openai.GPT3Dot5TurboInstruct
GPT3Davinci Model = openai.GPT3Davinci
GPT3Davinci002 Model = openai.GPT3Davinci002
GPT3Curie Model = openai.GPT3Curie
GPT3Curie002 Model = openai.GPT3Curie002
GPT3Ada Model = openai.GPT3Ada
GPT3Ada002 Model = openai.GPT3Ada002
GPT3Babbage Model = openai.GPT3Babbage
GPT3Babbage002 Model = openai.GPT3Babbage002
)

type UsageCallback func(types.Meta)
Expand All @@ -70,6 +70,7 @@ type OpenAI struct {
usageCallback UsageCallback
functions map[string]Function
functionsMaxIterations uint
toolChoice *string
calledFunctionName *string
finishReason string
cache *cache.Cache
Expand Down Expand Up @@ -137,6 +138,11 @@ func (o *OpenAI) WithCompletionCache(cache *cache.Cache) *OpenAI {
return o
}

func (o *OpenAI) WithToolChoice(toolChoice string) *OpenAI {
o.toolChoice = &toolChoice
return o
}

// CalledFunctionName returns the name of the function that was called.
func (o *OpenAI) CalledFunctionName() *string {
return o.calledFunctionName
Expand All @@ -149,7 +155,7 @@ func (o *OpenAI) FinishReason() string {

func NewCompletion() *OpenAI {
return New(
GPT3TextDavinci003,
GPT3Dot5TurboInstruct,
DefaultOpenAITemperature,
DefaultOpenAIMaxTokens,
false,
Expand Down Expand Up @@ -308,7 +314,17 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
}

if len(o.functions) > 0 {
chatCompletionRequest.Functions = o.getFunctions()
chatCompletionRequest.Tools = o.getFunctions()
if o.toolChoice != nil {
chatCompletionRequest.ToolChoice = openai.ToolChoice{
Type: openai.ToolTypeFunction,
Function: openai.ToolFunction{
Name: *o.toolChoice,
},
}
} else {
chatCompletionRequest.ToolChoice = "auto"
}
}

response, err := o.openAIClient.CreateChatCompletion(
Expand All @@ -332,10 +348,10 @@ func (o *OpenAI) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {

o.finishReason = string(response.Choices[0].FinishReason)
o.calledFunctionName = nil
if response.Choices[0].FinishReason == "function_call" && len(o.functions) > 0 {
if len(response.Choices[0].Message.ToolCalls) > 0 && len(o.functions) > 0 {
if o.verbose {
fmt.Printf("Calling function %s\n", response.Choices[0].Message.FunctionCall.Name)
fmt.Printf("Function call arguments: %s\n", response.Choices[0].Message.FunctionCall.Arguments)
fmt.Printf("Calling function %s\n", response.Choices[0].Message.ToolCalls[0].Function.Name)
fmt.Printf("Function call arguments: %s\n", response.Choices[0].Message.ToolCalls[0].Function.Arguments)
}

content, err = o.functionCall(response)
Expand Down

0 comments on commit 0ee4e34

Please sign in to comment.