fix: change optional field to pointer type #1907

Merged · 2 commits · Nov 9, 2024
20 changes: 20 additions & 0 deletions common/helper/helper.go
@@ -137,3 +137,23 @@ func String2Int(str string) int {
}
return num
}

func Float64PtrMax(p *float64, maxValue float64) *float64 {
if p == nil {
return nil
}
if *p > maxValue {
return &maxValue
}
return p
}

func Float64PtrMin(p *float64, minValue float64) *float64 {
if p == nil {
return nil
}
if *p < minValue {
return &minValue
}
return p
}
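
These two helpers clamp through the pointer while letting nil pass untouched, so a field the caller never set stays unset instead of being coerced to a default. A minimal usage sketch (sample values are illustrative, not from this PR):

```go
package main

import (
	"fmt"

	"github.com/songquanpeng/one-api/common/helper"
)

func main() {
	topP := 1.5
	capped := helper.Float64PtrMax(&topP, 0.9999)
	fmt.Println(*capped) // 0.9999: values above the cap are replaced

	var unset *float64 // simulates a request that never set the field
	fmt.Println(helper.Float64PtrMin(unset, 0.01) == nil) // true: nil passes through
}
```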
4 changes: 1 addition & 3 deletions relay/adaptor/ali/main.go
@@ -36,9 +36,7 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
enableSearch = true
aliModel = strings.TrimSuffix(aliModel, EnableSearchModelSuffix)
}
if request.TopP >= 1 {
request.TopP = 0.9999
}
request.TopP = helper.Float64PtrMax(request.TopP, 0.9999)
return &ChatRequest{
Model: aliModel,
Input: Input{
4 changes: 2 additions & 2 deletions relay/adaptor/ali/model.go
@@ -16,13 +16,13 @@ type Input struct {
}

type Parameters struct {
TopP float64 `json:"top_p,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Seed uint64 `json:"seed,omitempty"`
EnableSearch bool `json:"enable_search,omitempty"`
IncrementalOutput bool `json:"incremental_output,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
ResultFormat string `json:"result_format,omitempty"`
Tools []model.Tool `json:"tools,omitempty"`
}
4 changes: 2 additions & 2 deletions relay/adaptor/anthropic/model.go
@@ -48,8 +48,8 @@ type Request struct {
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
4 changes: 2 additions & 2 deletions relay/adaptor/aws/claude/model.go
@@ -11,8 +11,8 @@ type Request struct {
Messages []anthropic.Message `json:"messages"`
System string `json:"system,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
8 changes: 4 additions & 4 deletions relay/adaptor/aws/llama3/model.go
@@ -4,10 +4,10 @@ package aws
//
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-meta.html
type Request struct {
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Prompt string `json:"prompt"`
MaxGenLen int `json:"max_gen_len,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
}

// Response is the response from AWS Llama3
6 changes: 3 additions & 3 deletions relay/adaptor/baidu/main.go
@@ -35,9 +35,9 @@ type Message struct {

type ChatRequest struct {
Messages []Message `json:"messages"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
PenaltyScore float64 `json:"penalty_score,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
PenaltyScore *float64 `json:"penalty_score,omitempty"`
Stream bool `json:"stream,omitempty"`
System string `json:"system,omitempty"`
DisableSearch bool `json:"disable_search,omitempty"`
2 changes: 1 addition & 1 deletion relay/adaptor/cloudflare/model.go
@@ -9,5 +9,5 @@ type Request struct {
Prompt string `json:"prompt,omitempty"`
Raw bool `json:"raw,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
}
2 changes: 1 addition & 1 deletion relay/adaptor/cohere/main.go
@@ -43,7 +43,7 @@ func ConvertRequest(textRequest model.GeneralOpenAIRequest) *Request {
K: textRequest.TopK,
Stream: textRequest.Stream,
FrequencyPenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.FrequencyPenalty,
PresencePenalty: textRequest.PresencePenalty,
Seed: int(textRequest.Seed),
}
if cohereRequest.Model == "" {
8 changes: 4 additions & 4 deletions relay/adaptor/cohere/model.go
@@ -10,15 +10,15 @@ type Request struct {
PromptTruncation string `json:"prompt_truncation,omitempty"` // defaults to "AUTO"
Connectors []Connector `json:"connectors,omitempty"`
Documents []Document `json:"documents,omitempty"`
Temperature float64 `json:"temperature,omitempty"` // defaults to 0.3
Temperature *float64 `json:"temperature,omitempty"` // defaults to 0.3
MaxTokens int `json:"max_tokens,omitempty"`
MaxInputTokens int `json:"max_input_tokens,omitempty"`
K int `json:"k,omitempty"` // defaults to 0
P float64 `json:"p,omitempty"` // defaults to 0.75
P *float64 `json:"p,omitempty"` // defaults to 0.75
Seed int `json:"seed,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"` // defaults to 0.0
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"` // defaults to 0.0
PresencePenalty float64 `json:"presence_penalty,omitempty"` // defaults to 0.0
PresencePenalty *float64 `json:"presence_penalty,omitempty"` // defaults to 0.0
Tools []Tool `json:"tools,omitempty"`
ToolResults []ToolResult `json:"tool_results,omitempty"`
}
4 changes: 2 additions & 2 deletions relay/adaptor/gemini/model.go
@@ -67,8 +67,8 @@ type ChatTools struct {
type ChatGenerationConfig struct {
ResponseMimeType string `json:"responseMimeType,omitempty"`
ResponseSchema any `json:"responseSchema,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
MaxOutputTokens int `json:"maxOutputTokens,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
16 changes: 8 additions & 8 deletions relay/adaptor/ollama/model.go
@@ -1,14 +1,14 @@
package ollama

type Options struct {
Seed int `json:"seed,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP float64 `json:"top_p,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
Seed int `json:"seed,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
NumPredict int `json:"num_predict,omitempty"`
NumCtx int `json:"num_ctx,omitempty"`
}

type Message struct {
10 changes: 5 additions & 5 deletions relay/adaptor/palm/model.go
@@ -19,11 +19,11 @@ type Prompt struct {
}

type ChatRequest struct {
Prompt Prompt `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
Prompt Prompt `json:"prompt"`
Temperature *float64 `json:"temperature,omitempty"`
CandidateCount int `json:"candidateCount,omitempty"`
TopP *float64 `json:"topP,omitempty"`
TopK int `json:"topK,omitempty"`
}

type Error struct {
4 changes: 2 additions & 2 deletions relay/adaptor/tencent/main.go
@@ -39,8 +39,8 @@ func ConvertRequest(request model.GeneralOpenAIRequest) *ChatRequest {
Model: &request.Model,
Stream: &request.Stream,
Messages: messages,
TopP: &request.TopP,
Temperature: &request.Temperature,
TopP: request.TopP,
Temperature: request.Temperature,
}
}

4 changes: 2 additions & 2 deletions relay/adaptor/vertexai/claude/model.go
@@ -11,8 +11,8 @@ type Request struct {
MaxTokens int `json:"max_tokens,omitempty"`
StopSequences []string `json:"stop_sequences,omitempty"`
Stream bool `json:"stream,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []anthropic.Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
10 changes: 5 additions & 5 deletions relay/adaptor/xunfei/model.go
@@ -19,11 +19,11 @@ type ChatRequest struct {
} `json:"header"`
Parameter struct {
Chat struct {
Domain string `json:"domain,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
Domain string `json:"domain,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopK int `json:"top_k,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
Auditing bool `json:"auditing,omitempty"`
} `json:"chat"`
} `json:"parameter"`
Payload struct {
14 changes: 7 additions & 7 deletions relay/adaptor/zhipu/adaptor.go
@@ -4,13 +4,13 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/songquanpeng/one-api/common/helper"
"github.com/songquanpeng/one-api/relay/adaptor"
"github.com/songquanpeng/one-api/relay/adaptor/openai"
"github.com/songquanpeng/one-api/relay/meta"
"github.com/songquanpeng/one-api/relay/model"
"github.com/songquanpeng/one-api/relay/relaymode"
"io"
"math"
"net/http"
"strings"
)
@@ -65,13 +65,13 @@ func (a *Adaptor) ConvertRequest(c *gin.Context, relayMode int, request *model.G
baiduEmbeddingRequest, err := ConvertEmbeddingRequest(*request)
return baiduEmbeddingRequest, err
default:
// TopP (0.0, 1.0)
request.TopP = math.Min(0.99, request.TopP)
request.TopP = math.Max(0.01, request.TopP)
// TopP [0.0, 1.0]
request.TopP = helper.Float64PtrMax(request.TopP, 1)
request.TopP = helper.Float64PtrMin(request.TopP, 0)

// Temperature (0.0, 1.0)
request.Temperature = math.Min(0.99, request.Temperature)
request.Temperature = math.Max(0.01, request.Temperature)
// Temperature [0.0, 1.0]
request.Temperature = helper.Float64PtrMax(request.Temperature, 1)
request.Temperature = helper.Float64PtrMin(request.Temperature, 0)
a.SetVersionByModeName(request.Model)
if a.APIVersion == "v4" {
return request, nil
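One behavioral nuance in this hunk: with the old value-based clamp, an unset TopP arrived as the zero value 0 and was forced up to 0.01, so a top_p was always sent upstream; with pointers, nil survives both helpers and omitempty drops the field entirely. A small before/after sketch, assuming the helper import added above:

```go
package main

import (
	"fmt"
	"math"

	"github.com/songquanpeng/one-api/common/helper"
)

func main() {
	// Before: float64 cannot express "unset", so 0 was clamped to 0.01
	// and always serialized.
	fmt.Println(math.Max(0.01, math.Min(0.99, 0))) // 0.01

	// After: a nil *float64 passes through both clamps unchanged,
	// so omitempty omits top_p from the upstream request.
	var topP *float64
	topP = helper.Float64PtrMax(topP, 1)
	topP = helper.Float64PtrMin(topP, 0)
	fmt.Println(topP == nil) // true
}
```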
4 changes: 2 additions & 2 deletions relay/adaptor/zhipu/model.go
@@ -12,8 +12,8 @@ type Message struct {

type Request struct {
Prompt []Message `json:"prompt"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
RequestId string `json:"request_id,omitempty"`
Incremental bool `json:"incremental,omitempty"`
}
8 changes: 4 additions & 4 deletions relay/model/general.go
@@ -26,17 +26,17 @@ type GeneralOpenAIRequest struct {
Model string `json:"model,omitempty"`
Modalities []string `json:"modalities,omitempty"`
Audio *Audio `json:"audio,omitempty"`
FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
FrequencyPenalty *float64 `json:"frequency_penalty,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
N int `json:"n,omitempty"`
PresencePenalty float64 `json:"presence_penalty,omitempty"`
PresencePenalty *float64 `json:"presence_penalty,omitempty"`
ResponseFormat *ResponseFormat `json:"response_format,omitempty"`
Seed float64 `json:"seed,omitempty"`
Stop any `json:"stop,omitempty"`
Stream bool `json:"stream,omitempty"`
StreamOptions *StreamOptions `json:"stream_options,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"top_p,omitempty"`
Temperature *float64 `json:"temperature,omitempty"`
TopP *float64 `json:"top_p,omitempty"`
TopK int `json:"top_k,omitempty"`
Tools []Tool `json:"tools,omitempty"`
ToolChoice any `json:"tool_choice,omitempty"`
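The core motivation for the pointer migration, in one standalone sketch: with value fields, encoding/json's omitempty drops an explicit zero, so a caller asking for temperature 0 silently got the upstream default; with *float64, only a nil pointer is omitted. (Struct trimmed to one field for illustration.)

```go
package main

import (
	"encoding/json"
	"fmt"
)

type valueReq struct {
	Temperature float64 `json:"temperature,omitempty"` // old style
}

type ptrReq struct {
	Temperature *float64 `json:"temperature,omitempty"` // new style
}

func main() {
	zero := 0.0

	v, _ := json.Marshal(valueReq{Temperature: 0})
	fmt.Println(string(v)) // {}: the explicit 0 is lost

	p, _ := json.Marshal(ptrReq{Temperature: &zero})
	fmt.Println(string(p)) // {"temperature":0}: the explicit 0 survives

	u, _ := json.Marshal(ptrReq{})
	fmt.Println(string(u)) // {}: unset is still omitted
}
```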