Skip to content

Commit

Permalink
Fix function calling for streaming (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshaylb authored Apr 29, 2024
1 parent 21dd700 commit 2ca67ff
Show file tree
Hide file tree
Showing 3 changed files with 133 additions and 17 deletions.
74 changes: 74 additions & 0 deletions examples/llm/openai/thread-stream/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package main

import (
"context"
"fmt"

"github.com/henomis/lingoose/llm/openai"
"github.com/henomis/lingoose/thread"
)

type Answer struct {
Answer string `json:"answer" jsonschema:"description=the pirate answer"`
}

func getAnswer(a Answer) string {
return "🦜 ☠️ " + a.Answer
}

func newStr(str string) *string {
return &str
}

func main() {
openaillm := openai.New()
openaillm.WithToolChoice(newStr("getPirateAnswer"))
err := openaillm.BindFunction(
getAnswer,
"getPirateAnswer",
"use this function to get the pirate answer",
)
if err != nil {
panic(err)
}
openaillm.WithStream(true, func(a string) {
if a == openai.EOS {
fmt.Printf("\n")
return
}
fmt.Printf("%s", a)
})

t := thread.New().AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("Hello, I'm a user"),
).AddContent(
thread.NewTextContent("Can you greet me?"),
),
).AddMessage(
thread.NewUserMessage().AddContent(
thread.NewTextContent("please greet me as a pirate."),
),
)

fmt.Println(t)

err = openaillm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

t.AddMessage(thread.NewUserMessage().AddContent(
thread.NewTextContent("now translate to italian as a poem"),
))

fmt.Println(t)
// disable functions
openaillm.WithToolChoice(nil)
err = openaillm.Generate(context.Background(), t)
if err != nil {
panic(err)
}

fmt.Println(t)
}
14 changes: 7 additions & 7 deletions llm/openai/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ func bindFunction(
fn interface{},
name string,
description string,
functionParamenterOptions ...FunctionParameterOption,
functionParameterOptions ...FunctionParameterOption,
) (*Function, error) {
parameter, err := extractFunctionParameter(fn)
if err != nil {
return nil, err
}

for _, option := range functionParamenterOptions {
for _, option := range functionParameterOptions {
err = option(parameter)
if err != nil {
return nil, err
Expand All @@ -49,9 +49,9 @@ func (o *Legacy) BindFunction(
fn interface{},
name string,
description string,
functionParamenterOptions ...FunctionParameterOption,
functionParameterOptions ...FunctionParameterOption,
) error {
function, err := bindFunction(fn, name, description, functionParamenterOptions...)
function, err := bindFunction(fn, name, description, functionParameterOptions...)
if err != nil {
return err
}
Expand All @@ -65,9 +65,9 @@ func (o *OpenAI) BindFunction(
fn interface{},
name string,
description string,
functionParamenterOptions ...FunctionParameterOption,
functionParameterOptions ...FunctionParameterOption,
) error {
function, err := bindFunction(fn, name, description, functionParamenterOptions...)
function, err := bindFunction(fn, name, description, functionParameterOptions...)
if err != nil {
return err
}
Expand All @@ -78,7 +78,7 @@ func (o *OpenAI) BindFunction(
}

func (o *Legacy) getFunctions() []openai.FunctionDefinition {
functions := []openai.FunctionDefinition{}
var functions []openai.FunctionDefinition

for _, function := range o.functions {
functions = append(functions, openai.FunctionDefinition{
Expand Down
62 changes: 52 additions & 10 deletions llm/openai/openai.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,46 @@ func (o *OpenAI) Generate(ctx context.Context, t *thread.Thread) error {
return nil
}

func isStreamToolCallResponse(response *openai.ChatCompletionStreamResponse) bool {
return response.Choices[0].FinishReason == openai.FinishReasonToolCalls || len(response.Choices[0].Delta.ToolCalls) > 0
}

func handleStreamToolCallResponse(
response *openai.ChatCompletionStreamResponse,
currentToolCall *openai.ToolCall,
) (updatedToolCall openai.ToolCall, isNewTool bool) {
if len(response.Choices[0].Delta.ToolCalls) == 0 {
return
}
if response.Choices[0].Delta.ToolCalls[0].ID != "" {
isNewTool = true
updatedToolCall = response.Choices[0].Delta.ToolCalls[0]
} else {
currentToolCall.Function.Arguments += response.Choices[0].Delta.ToolCalls[0].Function.Arguments
}
return
}

func (o *OpenAI) handleEndOfStream(
messages []*thread.Message,
content string,
currentToolCall *openai.ToolCall,
allToolCalls []openai.ToolCall,
) []*thread.Message {
o.streamCallbackFn(EOS)
if len(content) > 0 {
messages = append(messages, thread.NewAssistantMessage().AddContent(
thread.NewTextContent(content),
))
}
if currentToolCall.ID != "" {
allToolCalls = append(allToolCalls, *currentToolCall)
messages = append(messages, toolCallsToToolCallMessage(allToolCalls))
messages = append(messages, o.callTools(allToolCalls)...)
}
return messages
}

func (o *OpenAI) stream(
ctx context.Context,
t *thread.Thread,
Expand All @@ -219,27 +259,29 @@ func (o *OpenAI) stream(
return fmt.Errorf("%w: %w", ErrOpenAIChat, err)
}

var messages []*thread.Message
var content string
var messages []*thread.Message
var allToolCalls []openai.ToolCall
var currentToolCall openai.ToolCall
for {
response, errRecv := stream.Recv()
if errors.Is(errRecv, io.EOF) {
o.streamCallbackFn(EOS)

if len(content) > 0 {
messages = append(messages, thread.NewAssistantMessage().AddContent(
thread.NewTextContent(content),
))
}
messages = o.handleEndOfStream(messages, content, &currentToolCall, allToolCalls)
break
}

if len(response.Choices) == 0 {
return fmt.Errorf("%w: no choices returned", ErrOpenAIChat)
}

if response.Choices[0].FinishReason == "tool_calls" || len(response.Choices[0].Delta.ToolCalls) > 0 {
messages = append(messages, o.callTools(response.Choices[0].Delta.ToolCalls)...)
if isStreamToolCallResponse(&response) {
updatedToolCall, isNewTool := handleStreamToolCallResponse(&response, &currentToolCall)
if isNewTool {
if currentToolCall.ID != "" {
allToolCalls = append(allToolCalls, currentToolCall)
}
currentToolCall = updatedToolCall
}
} else {
content += response.Choices[0].Delta.Content
}
Expand Down

0 comments on commit 2ca67ff

Please sign in to comment.