From 628df97f96ae786b82f6aba883a4bbf2a12743ba Mon Sep 17 00:00:00 2001 From: Buer <42402987+MartialBE@users.noreply.github.com> Date: Tue, 23 Apr 2024 19:57:14 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat:=20support=20other=20OpenAI=20?= =?UTF-8?q?APIs=20(#165)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ✨ feat: support other OpenAI APIs * 🔖 chore: Update English translation --- i18n/en.json | 4 +- middleware/auth.go | 17 ++++++++ providers/base/common.go | 4 ++ providers/base/interface.go | 3 +- providers/openai/base.go | 24 +++++++---- relay/common.go | 14 +++++- relay/main.go | 7 +-- relay/relay.go | 85 +++++++++++++++++++++++++++++++++++++ router/relay-router.go | 52 ++++++----------------- 9 files changed, 153 insertions(+), 57 deletions(-) create mode 100644 relay/relay.go diff --git a/i18n/en.json b/i18n/en.json index 97218b96c..e3d4dc27a 100644 --- a/i18n/en.json +++ b/i18n/en.json @@ -1122,5 +1122,7 @@ "模型名称映射, 你可以取一个容易记忆的名字来代替coze-{bot_id},例如:": "Model name mapping, you can take an easy-to-remember name to replace coze-{bot_id}, for example: ", ",注意:如果使用了模型映射,那么上面的模型名称必须使用映射前的名称,上述例子中,你应该在模型中填入coze-translate(如果已经使用了coze-*,可以忽略)。": ", Note: If a model mapping is used, then the model name above must use the name before the mapping. In the example above, you should fill in coze-translate in the model (if coze-* has been used, it can be ignored).", "位置/区域": "Location/Region", - "请输入你 Speech Studio 的位置/区域,例如:eastasia": "Please enter the location/region of your Speech Studio, for example: eastasia" + "请输入你 Speech Studio 的位置/区域,例如:eastasia": "Please enter the location/region of your Speech Studio, for example: eastasia", + "必须指定渠道": "Channel must be specified", + "中继": "Relay" } diff --git a/middleware/auth.go b/middleware/auth.go index 30ccc6ef1..4964a4dd6 100644 --- a/middleware/auth.go +++ b/middleware/auth.go @@ -114,6 +114,10 @@ func tokenAuth(c *gin.Context, key string) { return } c.Set("specific_channel_id", channelId) + if len(parts) == 3 && parts[2] == "ignore" { + c.Set("specific_channel_id_ignore", true) + } + } else { abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道") return @@ -135,3 +139,16 @@ func MjAuth() func(c *gin.Context) { tokenAuth(c, key) } } + +func SpecifiedChannel() func(c *gin.Context) { + return func(c *gin.Context) { + channelId := c.GetInt("specific_channel_id") + c.Set("specific_channel_id_ignore", false) + + if channelId <= 0 { + abortWithMessage(c, http.StatusForbidden, "必须指定渠道") + return + } + c.Next() + } +} diff --git a/providers/base/common.go b/providers/base/common.go index 710b49cf0..1c486ab35 100644 --- a/providers/base/common.go +++ b/providers/base/common.go @@ -143,3 +143,7 @@ func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types return } + +func (p *BaseProvider) GetRequester() *requester.HTTPRequester { + return p.Requester +} diff --git a/providers/base/interface.go b/providers/base/interface.go index 868f57118..5618c154b 100644 --- a/providers/base/interface.go +++ b/providers/base/interface.go @@ -20,7 +20,7 @@ type ProviderInterface interface { // 获取完整请求URL // GetFullRequestURL(requestURL string, modelName string) string // 获取请求头 - // GetRequestHeaders() (headers map[string]string) + GetRequestHeaders() map[string]string // 获取用量 GetUsage() *types.Usage // 设置用量 @@ -35,6 +35,7 @@ type ProviderInterface interface { // SupportAPI(relayMode int) bool GetChannel() *model.Channel ModelMappingHandler(modelName string) (string, error) + GetRequester() *requester.HTTPRequester } // 完成接口 diff --git a/providers/openai/base.go b/providers/openai/base.go index 3f11e977e..9121af47d 100644 --- a/providers/openai/base.go +++ b/providers/openai/base.go @@ -85,22 +85,28 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string) if p.IsAzure { apiVersion := p.Channel.Other - // 检测模型是是否包含 . 如果有则直接去掉 - modelName = strings.Replace(modelName, ".", "", -1) - - if modelName == "dall-e-2" { - // 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本 - // 已经没有dall-e-2了,所以暂时写死 - requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL) + if modelName != "" { + // 检测模型是是否包含 . 如果有则直接去掉 + modelName = strings.Replace(modelName, ".", "", -1) + + if modelName == "dall-e-2" { + // 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本 + // 已经没有dall-e-2了,所以暂时写死 + requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL) + } else { + requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) + } } else { - requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion) + requestURL = strings.TrimPrefix(requestURL, "/v1") + requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion) } } if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") { if p.IsAzure { - requestURL = strings.TrimPrefix(requestURL, "/openai/deployments") + requestURL = strings.TrimPrefix(requestURL, "/openai") + requestURL = strings.TrimPrefix(requestURL, "/deployments") } else { requestURL = strings.TrimPrefix(requestURL, "/v1") } diff --git a/relay/common.go b/relay/common.go index 9289da8f7..bbaaa209a 100644 --- a/relay/common.go +++ b/relay/common.go @@ -78,7 +78,8 @@ func GetProvider(c *gin.Context, modeName string) (provider providersBase.Provid func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) { channelId := c.GetInt("specific_channel_id") - if channelId > 0 { + ignore := c.GetBool("specific_channel_id_ignore") + if channelId > 0 && !ignore { return fetchChannelById(channelId) } @@ -206,7 +207,8 @@ func responseCache(c *gin.Context, response string) { func shouldRetry(c *gin.Context, statusCode int) bool { channelId := c.GetInt("specific_channel_id") - if channelId > 0 { + ignore := c.GetBool("specific_channel_id_ignore") + if channelId > 0 && !ignore { return false } if statusCode == http.StatusTooManyRequests { @@ -230,3 +232,11 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st controller.DisableChannel(channelId, channelName, err.Message, true) } } + +func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) { + requestId := c.GetString(common.RequestIdKey) + err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId) + c.JSON(err.StatusCode, gin.H{ + "error": err.OpenAIError, + }) +} diff --git a/relay/main.go b/relay/main.go index 3630ba2ba..4c983a42d 100644 --- a/relay/main.go +++ b/relay/main.go @@ -75,15 +75,10 @@ func Relay(c *gin.Context) { } if apiErr != nil { - requestId := c.GetString(common.RequestIdKey) if apiErr.StatusCode == http.StatusTooManyRequests { apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试" } - apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId) - c.JSON(apiErr.StatusCode, gin.H{ - "error": apiErr.OpenAIError, - }) - + relayResponseWithErr(c, apiErr) } } diff --git a/relay/relay.go b/relay/relay.go new file mode 100644 index 000000000..a0b493b83 --- /dev/null +++ b/relay/relay.go @@ -0,0 +1,85 @@ +package relay + +import ( + "net/http" + "one-api/common" + "one-api/model" + "one-api/providers/azure" + "one-api/providers/openai" + "strings" + "time" + + "github.com/gin-gonic/gin" +) + +func RelayOnly(c *gin.Context) { + provider, _, fail := GetProvider(c, "") + if fail != nil { + common.AbortWithMessage(c, http.StatusServiceUnavailable, fail.Error()) + return + } + + channel := provider.GetChannel() + if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure { + common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai") + return + } + + // 获取请求的path + url := "" + path := c.Request.URL.Path + openAIProvider, ok := provider.(*openai.OpenAIProvider) + if !ok { + azureProvider, ok := provider.(*azure.AzureProvider) + if !ok { + common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type openai") + return + } + url = azureProvider.GetFullRequestURL(path, "") + } else { + url = openAIProvider.GetFullRequestURL(path, "") + } + + headers := c.Request.Header + mapHeaders := provider.GetRequestHeaders() + // 设置请求头 + for k, v := range headers { + if _, ok := mapHeaders[k]; ok { + continue + } + mapHeaders[k] = strings.Join(v, ", ") + } + + requester := provider.GetRequester() + req, err := requester.NewRequest(c.Request.Method, url, requester.WithBody(c.Request.Body), requester.WithHeader(mapHeaders)) + if err != nil { + common.AbortWithMessage(c, http.StatusBadRequest, err.Error()) + return + } + + defer req.Body.Close() + + response, errWithCode := requester.SendRequestRaw(req) + if errWithCode != nil { + relayResponseWithErr(c, errWithCode) + return + } + + errWithCode = responseMultipart(c, response) + + if errWithCode != nil { + relayResponseWithErr(c, errWithCode) + return + } + + requestTime := 0 + requestStartTimeValue := c.Request.Context().Value("requestStartTime") + if requestStartTimeValue != nil { + requestStartTime, ok := requestStartTimeValue.(time.Time) + if ok { + requestTime = int(time.Since(requestStartTime).Milliseconds()) + } + } + model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime) + +} diff --git a/router/relay-router.go b/router/relay-router.go index 0d4416fa5..38c4f8623 100644 --- a/router/relay-router.go +++ b/router/relay-router.go @@ -1,7 +1,6 @@ package router import ( - "one-api/controller" "one-api/middleware" "one-api/relay" "one-api/relay/midjourney" @@ -38,43 +37,20 @@ func setOpenAIRouter(router *gin.Engine) { relayV1Router.POST("/audio/translations", relay.Relay) relayV1Router.POST("/audio/speech", relay.Relay) relayV1Router.POST("/moderations", relay.Relay) - relayV1Router.GET("/files", controller.RelayNotImplemented) - relayV1Router.POST("/files", controller.RelayNotImplemented) - relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented) - relayV1Router.GET("/files/:id", controller.RelayNotImplemented) - relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented) - relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented) - relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented) - relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented) - relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented) - relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented) - relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented) - relayV1Router.POST("/assistants", controller.RelayNotImplemented) - relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented) - relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented) - relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented) - relayV1Router.GET("/assistants", controller.RelayNotImplemented) - relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented) - relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented) - relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented) - relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented) - relayV1Router.POST("/threads", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id", controller.RelayNotImplemented) - relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented) - relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented) - relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented) + + relayV1Router.Use(middleware.SpecifiedChannel()) + { + relayV1Router.Any("/files", relay.RelayOnly) + relayV1Router.Any("/files/*any", relay.RelayOnly) + relayV1Router.Any("/fine_tuning/*any", relay.RelayOnly) + relayV1Router.Any("/assistants", relay.RelayOnly) + relayV1Router.Any("/assistants/*any", relay.RelayOnly) + relayV1Router.Any("/threads", relay.RelayOnly) + relayV1Router.Any("/threads/*any", relay.RelayOnly) + relayV1Router.Any("/batches/*any", relay.RelayOnly) + relayV1Router.Any("/vector_stores/*any", relay.RelayOnly) + relayV1Router.DELETE("/models/:model", relay.RelayOnly) + } } }