Skip to content

Commit

Permalink
Merge pull request #253 from mercury7720/main
Browse files Browse the repository at this point in the history
feat: support spark desk new model
  • Loading branch information
zmh-program authored Oct 16, 2024
2 parents 93b4cef + 1ba6faa commit 3f6026d
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 33 deletions.
42 changes: 31 additions & 11 deletions adapter/sparkdesk/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ import (
)

var FunctionCallingModels = []string{
globals.SparkDeskV3,
globals.SparkDeskV35,
globals.SparkDeskMax,
globals.SparkDeskV4Ultra,
}

func GetToken(props *adaptercommon.ChatProps) *int {
Expand All @@ -19,19 +19,31 @@ func GetToken(props *adaptercommon.ChatProps) *int {
}

switch props.Model {
case globals.SparkDeskV2, globals.SparkDeskV3, globals.SparkDeskV35:
if *props.MaxTokens > 8192 {
return utils.ToPtr(8192)
}
case globals.SparkDesk:
case globals.SparkDeskLite, globals.SparkDeskPro128K:
if *props.MaxTokens > 4096 {
return utils.ToPtr(4096)
}
case globals.SparkDeskPro, globals.SparkDeskMax, globals.SparkDeskMax32K, globals.SparkDeskV4Ultra:
if *props.MaxTokens > 8192 {
return utils.ToPtr(8192)
}
}

return props.MaxTokens
}

func GetTopK(props *adaptercommon.ChatProps) *int {
if props.TopK == nil {
return nil
}
// topk max value is 6
if *props.TopK > 6 {
return utils.ToPtr(6)
}

return props.TopK
}

func (c *ChatInstance) GetMessages(props *adaptercommon.ChatProps) []Message {
var messages []Message
for _, message := range props.Message {
Expand Down Expand Up @@ -103,8 +115,13 @@ func getChoice(form *ChatResponse) *globals.Chunk {
}

func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, hook globals.Hook) error {
endpoint := fmt.Sprintf("%s/%s/chat", c.Endpoint, TransformAddr(props.Model))

var endpoint string
switch props.Model {
case globals.SparkDeskPro128K, globals.SparkDeskMax32K:
endpoint = fmt.Sprintf("%s/chat/%s", c.Endpoint, TransformModel(props.Model))
default:
endpoint = fmt.Sprintf("%s/%s/chat", c.Endpoint, TransformAddr(props.Model))
}
var conn *utils.WebSocket
if conn = utils.NewWebsocketClient(c.GenerateUrl(endpoint)); conn == nil {
return fmt.Errorf("sparkdesk error: websocket connection failed")
Expand All @@ -121,10 +138,13 @@ func (c *ChatInstance) CreateStreamChatRequest(props *adaptercommon.ChatProps, h
},
Functions: c.GetFunctionCalling(props),
},

Parameter: RequestParameter{
Chat: ChatParameter{
Domain: TransformModel(props.Model),
MaxToken: GetToken(props),
Domain: TransformModel(props.Model),
MaxToken: GetToken(props),
Temperature: props.Temperature,
TopK: GetTopK(props),
},
},
}); err != nil {
Expand Down
24 changes: 14 additions & 10 deletions adapter/sparkdesk/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,33 @@ type ChatInstance struct {

func TransformAddr(model string) string {
switch model {
case globals.SparkDesk:
case globals.SparkDeskLite:
return "v1.1"
case globals.SparkDeskV2:
return "v2.1"
case globals.SparkDeskV3:
case globals.SparkDeskPro:
return "v3.1"
case globals.SparkDeskV35:
case globals.SparkDeskMax:
return "v3.5"
case globals.SparkDeskV4Ultra:
return "v4.0"
default:
return "v1.1"
}
}

func TransformModel(model string) string {
switch model {
case globals.SparkDesk:
case globals.SparkDeskLite:
return "general"
case globals.SparkDeskV2:
return "generalv2"
case globals.SparkDeskV3:
case globals.SparkDeskPro:
return "generalv3"
case globals.SparkDeskV35:
case globals.SparkDeskPro128K:
return "pro-128k"
case globals.SparkDeskMax:
return "generalv3.5"
case globals.SparkDeskMax32K:
return "max-32k"
case globals.SparkDeskV4Ultra:
return "4.0Ultra"
default:
return "general"
}
Expand Down
10 changes: 6 additions & 4 deletions app/src/admin/channel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,12 @@ export const ChannelInfos: Record<string, ChannelInfo> = {
endpoint: "wss://spark-api.xf-yun.com",
format: "<app-id>|<api-secret>|<api-key>",
models: [
"spark-desk-v1.5",
"spark-desk-v2",
"spark-desk-v3",
"spark-desk-v3.5",
"spark-desk-lite",
"spark-desk-pro",
"spark-desk-pro-128k",
"spark-desk-max",
"spark-desk-max-32k",
"spark-desk-4.0-ultra",
],
},
chatglm: {
Expand Down
20 changes: 16 additions & 4 deletions app/src/admin/datasets/charge.ts
Original file line number Diff line number Diff line change
Expand Up @@ -155,17 +155,29 @@ export const pricing: PricingDataset = [
billing_type: timesBilling,
},
{
models: ["spark-desk-v1.5"],
input: 0.015,
output: 0.015,
models: ["spark-desk-lite"], // free
input: 0.001,
output: 0.001,
currency: Currency.CNY,
},
{
models: ["spark-desk-v2", "spark-desk-v3", "spark-desk-v3.5"],
models: ["spark-desk-pro", "spark-desk-pro-128k","spark-desk-max"],
input: 0.03,
output: 0.03,
currency: Currency.CNY,
},
{
models: ["spark-desk-max-32k"],
input: 0.032,
output: 0.032,
currency: Currency.CNY,
},
{
models: ["spark-desk-4.0-ultra"],
input: 0.1,
output: 0.1,
currency: Currency.CNY,
},
{
models: ["moonshot-v1-8k"],
input: 0.012,
Expand Down
10 changes: 6 additions & 4 deletions globals/variables.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,12 @@ const (
Claude2200k = "claude-2.1"
Claude3 = "claude-3"
ClaudeSlack = "claude-slack"
SparkDesk = "spark-desk-v1.5"
SparkDeskV2 = "spark-desk-v2"
SparkDeskV3 = "spark-desk-v3"
SparkDeskV35 = "spark-desk-v3.5"
SparkDeskLite = "spark-desk-lite"
SparkDeskPro = "spark-desk-pro"
SparkDeskPro128K = "spark-desk-pro-128k"
SparkDeskMax = "spark-desk-max"
SparkDeskMax32K = "spark-desk-max-32k"
SparkDeskV4Ultra = "spark-desk-4.0-ultra"
ChatBison001 = "chat-bison-001"
GeminiPro = "gemini-pro"
GeminiProVision = "gemini-pro-vision"
Expand Down

0 comments on commit 3f6026d

Please sign in to comment.