
Commit

fix
henomis committed Dec 21, 2023
1 parent 5d52a23 commit 4724622
Showing 3 changed files with 29 additions and 26 deletions.
3 changes: 3 additions & 0 deletions llm/openai/formatters.go
@@ -5,6 +5,7 @@ import (
"github.com/sashabaranov/go-openai"
)

+//nolint:gocognit
func threadToChatCompletionMessages(t *thread.Thread) []openai.ChatCompletionMessage {
chatCompletionMessages := make([]openai.ChatCompletionMessage, len(t.Messages))
for i, message := range t.Messages {
@@ -77,6 +78,8 @@ func threadContentsToChatMessageParts(m *thread.Message) []openai.ChatMessagePart
Detail: openai.ImageURLDetailAuto,
},
}
+case thread.ContentTypeToolCall, thread.ContentTypeToolResponse:
+continue
default:
continue
}
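The new case arm makes the skip of tool-call and tool-response contents explicit rather than letting them fall through to `default`. Below is a minimal, self-contained sketch of that dispatch pattern; the `Content` type and `chatMessageParts` helper here are illustrative stand-ins, not lingoose's actual definitions (the real `ContentType*` constants live in its `thread` package):

```go
package main

import "fmt"

// Hypothetical stand-ins for lingoose's thread content types.
type ContentType string

const (
	ContentTypeText         ContentType = "text"
	ContentTypeImage        ContentType = "image"
	ContentTypeToolCall     ContentType = "tool_call"
	ContentTypeToolResponse ContentType = "tool_response"
)

type Content struct {
	Type ContentType
	Data interface{}
}

// chatMessageParts mirrors the switch in threadContentsToChatMessageParts:
// tool calls and tool responses are skipped explicitly, since they are
// carried by dedicated request fields rather than by message parts.
func chatMessageParts(contents []Content) []string {
	var parts []string
	for _, c := range contents {
		switch c.Type {
		case ContentTypeText:
			parts = append(parts, fmt.Sprintf("text:%v", c.Data))
		case ContentTypeImage:
			parts = append(parts, fmt.Sprintf("image:%v", c.Data))
		case ContentTypeToolCall, ContentTypeToolResponse:
			continue // handled elsewhere, never rendered as a message part
		default:
			continue // unknown content types are ignored
		}
	}
	return parts
}

func main() {
	contents := []Content{
		{Type: ContentTypeText, Data: "hello"},
		{Type: ContentTypeToolCall, Data: `{"name":"sum"}`},
	}
	fmt.Println(chatMessageParts(contents)) // [text:hello]
}
```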
6 changes: 3 additions & 3 deletions llm/openai/function.go
@@ -45,7 +45,7 @@ func bindFunction(
}, nil
}

-func (o *OpenAILegacy) BindFunction(
+func (o *Legacy) BindFunction(
fn interface{},
name string,
description string,
@@ -77,7 +77,7 @@ func (o *OpenAI) BindFunction(
return nil
}

-func (o *OpenAILegacy) getFunctions() []openai.FunctionDefinition {
+func (o *Legacy) getFunctions() []openai.FunctionDefinition {
functions := []openai.FunctionDefinition{}

for _, function := range o.functions {
@@ -197,7 +197,7 @@ func callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, error)
return "", nil
}

-func (o *OpenAILegacy) functionCall(response openai.ChatCompletionResponse) (string, error) {
+func (o *Legacy) functionCall(response openai.ChatCompletionResponse) (string, error) {
fn, ok := o.functions[response.Choices[0].Message.FunctionCall.Name]
if !ok {
return "", fmt.Errorf("%w: unknown function %s", ErrOpenAIChat, response.Choices[0].Message.FunctionCall.Name)
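The `functionCall` path above looks up the bound function by name and invokes it through `callFnWithArgumentAsJSON(fn interface{}, argumentAsJSON string) (string, error)`. A hedged sketch of that reflection technique follows; `callWithJSONArg` and `sumInput` are hypothetical names, and this is not lingoose's exact implementation:

```go
package main

import (
	"encoding/json"
	"fmt"
	"reflect"
)

// callWithJSONArg illustrates the technique: unmarshal a JSON string into
// the function's single argument type via reflection, invoke the function,
// and marshal its first return value back to JSON.
func callWithJSONArg(fn interface{}, argJSON string) (string, error) {
	fnValue := reflect.ValueOf(fn)
	if fnValue.Kind() != reflect.Func || fnValue.Type().NumIn() != 1 {
		return "", fmt.Errorf("expected a function with exactly one argument")
	}

	// Build a pointer to a zero value of the argument type and decode into it.
	argPtr := reflect.New(fnValue.Type().In(0))
	if err := json.Unmarshal([]byte(argJSON), argPtr.Interface()); err != nil {
		return "", err
	}

	results := fnValue.Call([]reflect.Value{argPtr.Elem()})
	if len(results) == 0 {
		return "", nil
	}
	out, err := json.Marshal(results[0].Interface())
	return string(out), err
}

type sumInput struct {
	A int `json:"a"`
	B int `json:"b"`
}

func main() {
	sum := func(in sumInput) int { return in.A + in.B }
	out, err := callWithJSONArg(sum, `{"a":2,"b":3}`)
	fmt.Println(out, err) // prints: 5 <nil>
}
```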
46 changes: 23 additions & 23 deletions llm/openai/legacy.go
@@ -16,7 +16,7 @@ import (
"github.com/sashabaranov/go-openai"
)

-type OpenAILegacy struct {
+type Legacy struct {
openAIClient *openai.Client
model Model
temperature float32
@@ -32,59 +32,59 @@ type OpenAILegacy struct {
}

// WithModel sets the model to use for the OpenAI instance.
-func (o *OpenAILegacy) WithModel(model Model) *OpenAILegacy {
+func (o *Legacy) WithModel(model Model) *Legacy {
o.model = model
return o
}

// WithTemperature sets the temperature to use for the OpenAI instance.
-func (o *OpenAILegacy) WithTemperature(temperature float32) *OpenAILegacy {
+func (o *Legacy) WithTemperature(temperature float32) *Legacy {
o.temperature = temperature
return o
}

// WithMaxTokens sets the max tokens to use for the OpenAI instance.
-func (o *OpenAILegacy) WithMaxTokens(maxTokens int) *OpenAILegacy {
+func (o *Legacy) WithMaxTokens(maxTokens int) *Legacy {
o.maxTokens = maxTokens
return o
}

// WithUsageCallback sets the usage callback to use for the OpenAI instance.
-func (o *OpenAILegacy) WithCallback(callback UsageCallback) *OpenAILegacy {
+func (o *Legacy) WithCallback(callback UsageCallback) *Legacy {
o.usageCallback = callback
return o
}

// WithStop sets the stop sequences to use for the OpenAI instance.
-func (o *OpenAILegacy) WithStop(stop []string) *OpenAILegacy {
+func (o *Legacy) WithStop(stop []string) *Legacy {
o.stop = stop
return o
}

// WithClient sets the client to use for the OpenAI instance.
-func (o *OpenAILegacy) WithClient(client *openai.Client) *OpenAILegacy {
+func (o *Legacy) WithClient(client *openai.Client) *Legacy {
o.openAIClient = client
return o
}

// WithVerbose sets the verbose flag to use for the OpenAI instance.
-func (o *OpenAILegacy) WithVerbose(verbose bool) *OpenAILegacy {
+func (o *Legacy) WithVerbose(verbose bool) *Legacy {
o.verbose = verbose
return o
}

// WithCache sets the cache to use for the OpenAI instance.
-func (o *OpenAILegacy) WithCompletionCache(cache *cache.Cache) *OpenAILegacy {
+func (o *Legacy) WithCompletionCache(cache *cache.Cache) *Legacy {
o.cache = cache
return o
}

// SetStop sets the stop sequences for the completion.
-func (o *OpenAILegacy) SetStop(stop []string) {
+func (o *Legacy) SetStop(stop []string) {
o.stop = stop
}

-func (o *OpenAILegacy) setUsageMetadata(usage openai.Usage) {
+func (o *Legacy) setUsageMetadata(usage openai.Usage) {
callbackMetadata := make(types.Meta)

err := mapstructure.Decode(usage, &callbackMetadata)
@@ -95,10 +95,10 @@ func (o *OpenAILegacy) setUsageMetadata(usage openai.Usage) {
o.usageCallback(callbackMetadata)
}

-func NewLegacy(model Model, temperature float32, maxTokens int, verbose bool) *OpenAILegacy {
+func NewLegacy(model Model, temperature float32, maxTokens int, verbose bool) *Legacy {
openAIKey := os.Getenv("OPENAI_API_KEY")

-return &OpenAILegacy{
+return &Legacy{
openAIClient: openai.NewClient(openAIKey),
model: model,
temperature: temperature,
@@ -110,16 +110,16 @@ func NewLegacy(model Model, temperature float32, maxTokens int, verbose bool) *OpenAILegacy {
}

// CalledFunctionName returns the name of the function that was called.
-func (o *OpenAILegacy) CalledFunctionName() *string {
+func (o *Legacy) CalledFunctionName() *string {
return o.calledFunctionName
}

// FinishReason returns the LLM finish reason.
-func (o *OpenAILegacy) FinishReason() string {
+func (o *Legacy) FinishReason() string {
return o.finishReason
}

-func NewCompletion() *OpenAILegacy {
+func NewCompletion() *Legacy {
return NewLegacy(
GPT3Dot5TurboInstruct,
DefaultOpenAITemperature,
@@ -128,7 +128,7 @@ func NewCompletion() *OpenAILegacy {
)
}

-func NewChat() *OpenAILegacy {
+func NewChat() *Legacy {
return NewLegacy(
GPT3Dot5Turbo,
DefaultOpenAITemperature,
Expand All @@ -138,7 +138,7 @@ func NewChat() *OpenAILegacy {
}

// Completion returns a single completion for the given prompt.
-func (o *OpenAILegacy) Completion(ctx context.Context, prompt string) (string, error) {
+func (o *Legacy) Completion(ctx context.Context, prompt string) (string, error) {
var cacheResult *cache.Result
var err error

@@ -167,7 +167,7 @@ func (o *OpenAILegacy) Completion(ctx context.Context, prompt string) (string, error) {
}

// BatchCompletion returns multiple completions for the given prompts.
-func (o *OpenAILegacy) BatchCompletion(ctx context.Context, prompts []string) ([]string, error) {
+func (o *Legacy) BatchCompletion(ctx context.Context, prompts []string) ([]string, error) {
response, err := o.openAIClient.CreateCompletion(
ctx,
openai.CompletionRequest{
@@ -206,12 +206,12 @@ func (o *OpenAILegacy) BatchCompletion(ctx context.Context, prompts []string) ([]string, error) {
}

// CompletionStream returns a single completion stream for the given prompt.
-func (o *OpenAILegacy) CompletionStream(ctx context.Context, callbackFn StreamCallback, prompt string) error {
+func (o *Legacy) CompletionStream(ctx context.Context, callbackFn StreamCallback, prompt string) error {
return o.BatchCompletionStream(ctx, []StreamCallback{callbackFn}, []string{prompt})
}

// BatchCompletionStream returns multiple completion streams for the given prompts.
-func (o *OpenAILegacy) BatchCompletionStream(ctx context.Context, callbackFn []StreamCallback, prompts []string) error {
+func (o *Legacy) BatchCompletionStream(ctx context.Context, callbackFn []StreamCallback, prompts []string) error {
stream, err := o.openAIClient.CreateCompletionStream(
ctx,
openai.CompletionRequest{
@@ -263,7 +263,7 @@ func (o *OpenAILegacy) BatchCompletionStream(ctx context.Context, callbackFn []StreamCallback, prompts []string) error {
}

// Chat returns a single chat completion for the given prompt.
-func (o *OpenAILegacy) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
+func (o *Legacy) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
messages, err := buildMessages(prompt)
if err != nil {
return "", fmt.Errorf("%w: %w", ErrOpenAIChat, err)
@@ -324,7 +324,7 @@ func (o *OpenAILegacy) Chat(ctx context.Context, prompt *chat.Chat) (string, error) {
}

// ChatStream returns a single chat stream for the given prompt.
-func (o *OpenAILegacy) ChatStream(ctx context.Context, callbackFn StreamCallback, prompt *chat.Chat) error {
+func (o *Legacy) ChatStream(ctx context.Context, callbackFn StreamCallback, prompt *chat.Chat) error {
messages, err := buildMessages(prompt)
if err != nil {
return fmt.Errorf("%w: %w", ErrOpenAIChat, err)
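Since every With* method now returns *Legacy, caller-side builder chains compile unchanged as long as they never named the OpenAILegacy type explicitly. A usage sketch under that assumption; the import path is inferred from the file paths in this diff and the signatures come from the hunks above:

```go
package main

import (
	"context"
	"fmt"

	// Assumed import path for the package shown in this diff.
	openaillm "github.com/henomis/lingoose/llm/openai"
)

func main() {
	// NewCompletion() now returns *Legacy instead of *OpenAILegacy; the
	// fluent With* chain is otherwise untouched by this commit.
	// Requires OPENAI_API_KEY to be set in the environment.
	llm := openaillm.NewCompletion().
		WithMaxTokens(256).
		WithTemperature(0.2)

	response, err := llm.Completion(context.Background(), "Say hello in Italian.")
	if err != nil {
		panic(err)
	}
	fmt.Println(response)
}
```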
