diff --git a/config.yaml b/config.yaml index a35be3d5..4ed9afb0 100644 --- a/config.yaml +++ b/config.yaml @@ -28,6 +28,12 @@ google: # - category: HARM_CATEGORY_HARASSMENT # threshold: BLOCK_NONE +claude: + pad: 0 + cookies: + - 'xxx1' + - 'xxx2' + you: helper: 8082 cookies: @@ -58,7 +64,7 @@ serverless: interpreter: baseUrl: http://127.0.0.1:8000 echoCode: false -# reverseUrl: ws://127.0.0.1:8000/ws TODO - +# reverse: true TODO - custom-llm: baseUrl: http://127.0.0.1:8080 diff --git a/internal/plugin/llm/claude/adapter.go b/internal/plugin/llm/claude/adapter.go index 009c8fc9..b0c209cd 100644 --- a/internal/plugin/llm/claude/adapter.go +++ b/internal/plugin/llm/claude/adapter.go @@ -5,21 +5,36 @@ import ( "github.com/bincooo/chatgpt-adapter/internal/gin.handler/response" "github.com/bincooo/chatgpt-adapter/internal/plugin" "github.com/bincooo/chatgpt-adapter/logger" + "github.com/bincooo/chatgpt-adapter/pkg" claude3 "github.com/bincooo/claude-api" "github.com/gin-gonic/gin" "strings" + "time" ) var ( Adapter = API{} Model = "claude" padMaxCount = 25000 + + claudeRollContainer *common.PollContainer[string] ) type API struct { plugin.BaseAdapter } +func init() { + common.AddInitialized(func() { + cookies := pkg.Config.GetStringSlice("claude.cookies") + if len(cookies) == 0 { + return + } + claudeRollContainer = common.NewPollContainer[string](cookies, 30*time.Minute) + claudeRollContainer.Condition = Condition + }) +} + func (API) Match(_ *gin.Context, model string) bool { switch model { case "claude", @@ -66,7 +81,6 @@ func (API) Models() []plugin.Model { func (API) Completion(ctx *gin.Context) { var ( - cookie = ctx.GetString("token") completion = common.GetGinCompletion(ctx) matchers = common.GetGinMatchers(ctx) model = "" @@ -78,6 +92,14 @@ func (API) Completion(ctx *gin.Context) { } } + cookie, err := claudeRollContainer.Poll() + if err != nil { + logger.Error(err) + response.Error(ctx, -1, err) + return + } + + defer resetMarker(cookie) options, err := claude3.NewDefaultOptions(cookie, model) if err != nil { logger.Error(err) @@ -114,3 +136,30 @@ func (API) Completion(ctx *gin.Context) { response.Error(ctx, -1, "EMPTY RESPONSE") } } + +func Condition(cookie string) bool { + marker, err := claudeRollContainer.GetMarker(cookie) + if err != nil { + logger.Error(err) + return false + } + + return marker == 0 +} + +func resetMarker(cookie string) { + marker, e := claudeRollContainer.GetMarker(cookie) + if e != nil { + logger.Error(e) + return + } + + if marker == 0 || marker > 1 { + return + } + + e = claudeRollContainer.SetMarker(cookie, 0) + if e != nil { + logger.Error(e) + } +} diff --git a/internal/plugin/llm/claude/message.go b/internal/plugin/llm/claude/message.go index 1eabed1e..ecdb0046 100644 --- a/internal/plugin/llm/claude/message.go +++ b/internal/plugin/llm/claude/message.go @@ -160,7 +160,11 @@ func mergeMessages(ctx *gin.Context, messages []pkg.Keyv[interface{}]) (attachme nMessages, _ := common.TextMessageCombiner(messages, iterator) join := strings.Join(nMessages, "\n\n") if ctx.GetBool("pad") { - join = common.PadJunkMessage(padMaxCount-len(join), join) + count := ctx.GetInt("claude.pad") + if count == 0 { + count = padMaxCount + } + join = common.PadJunkMessage(count-len(join), join) } tokens = common.CalcTokens(join) diff --git a/internal/plugin/llm/claude/toolcall.go b/internal/plugin/llm/claude/toolcall.go index 094c0242..6ce682e9 100644 --- a/internal/plugin/llm/claude/toolcall.go +++ b/internal/plugin/llm/claude/toolcall.go @@ -32,7 +32,11 @@ func completeToolCalls(ctx *gin.Context, cookie string, completion pkg.ChatCompl } if ctx.GetBool("pad") { - message = common.PadJunkMessage(padMaxCount-len(message), message) + count := ctx.GetInt("claude.pad") + if count == 0 { + count = padMaxCount + } + message = common.PadJunkMessage(count-len(message), message) } chatResponse, err := chat.Reply(common.GetGinContext(ctx), "", []claude3.Attachment{