diff --git a/conversation/anthropic/anthropic.go b/conversation/anthropic/anthropic.go index 344c458d03..4273e57f05 100644 --- a/conversation/anthropic/anthropic.go +++ b/conversation/anthropic/anthropic.go @@ -92,7 +92,13 @@ func (a *Anthropic) Converse(ctx context.Context, r *conversation.ConversationRe }) } - resp, err := a.llm.GenerateContent(ctx, messages) + opts := []llms.CallOption{} + + if r.Temperature > 0 { + opts = append(opts, conversation.LangchainTemperature(r.Temperature)) + } + + resp, err := a.llm.GenerateContent(ctx, messages, opts...) if err != nil { return nil, err } diff --git a/conversation/aws/bedrock/bedrock.go b/conversation/aws/bedrock/bedrock.go index 709e74da00..2da96d5383 100644 --- a/conversation/aws/bedrock/bedrock.go +++ b/conversation/aws/bedrock/bedrock.go @@ -104,7 +104,13 @@ func (b *AWSBedrock) Converse(ctx context.Context, r *conversation.ConversationR }) } - resp, err := b.llm.GenerateContent(ctx, messages) + opts := []llms.CallOption{} + + if r.Temperature > 0 { + opts = append(opts, conversation.LangchainTemperature(r.Temperature)) + } + + resp, err := b.llm.GenerateContent(ctx, messages, opts...) if err != nil { return nil, err } diff --git a/conversation/converse.go b/conversation/converse.go index 16fce9f6d6..6d3183bc4c 100644 --- a/conversation/converse.go +++ b/conversation/converse.go @@ -42,6 +42,7 @@ type ConversationRequest struct { Inputs []ConversationInput `json:"inputs"` Parameters map[string]*anypb.Any `json:"parameters"` ConversationContext string `json:"conversationContext"` + Temperature float64 `json:"temperature"` // from metadata Key string `json:"key"` diff --git a/conversation/openai/openai.go b/conversation/openai/openai.go index 02606aca28..beb164db2e 100644 --- a/conversation/openai/openai.go +++ b/conversation/openai/openai.go @@ -111,8 +111,9 @@ func (o *OpenAI) Converse(ctx context.Context, r *conversation.ConversationReque } req := openai.ChatCompletionRequest{ - Model: o.model, - Messages: messages, + Model: o.model, + Messages: messages, + Temperature: float32(r.Temperature), } // TODO: support ConversationContext diff --git a/conversation/temperature.go b/conversation/temperature.go new file mode 100644 index 0000000000..8e4bfa96e2 --- /dev/null +++ b/conversation/temperature.go @@ -0,0 +1,22 @@ +/* +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +package conversation + +import "github.com/tmc/langchaingo/llms" + +// LangchainTemperature returns a langchain compliant LLM temperature +func LangchainTemperature(temperature float64) llms.CallOption { + return llms.WithTemperature(temperature) +}