Skip to content

Commit

Permalink
✨ feat: support other OpenAI APIs (#165)
Browse files Browse the repository at this point in the history
* ✨ feat: support other OpenAI APIs

* 🔖 chore: Update English translation
  • Loading branch information
MartialBE authored Apr 23, 2024
1 parent f91b985 commit 628df97
Show file tree
Hide file tree
Showing 9 changed files with 153 additions and 57 deletions.
4 changes: 3 additions & 1 deletion i18n/en.json
Original file line number Diff line number Diff line change
Expand Up @@ -1122,5 +1122,7 @@
"模型名称映射, 你可以取一个容易记忆的名字来代替coze-{bot_id},例如:": "Model name mapping, you can take an easy-to-remember name to replace coze-{bot_id}, for example: ",
",注意:如果使用了模型映射,那么上面的模型名称必须使用映射前的名称,上述例子中,你应该在模型中填入coze-translate(如果已经使用了coze-*,可以忽略)。": ", Note: If a model mapping is used, then the model name above must use the name before the mapping. In the example above, you should fill in coze-translate in the model (if coze-* has been used, it can be ignored).",
"位置/区域": "Location/Region",
"请输入你 Speech Studio 的位置/区域,例如:eastasia": "Please enter the location/region of your Speech Studio, for example: eastasia"
"请输入你 Speech Studio 的位置/区域,例如:eastasia": "Please enter the location/region of your Speech Studio, for example: eastasia",
"必须指定渠道": "Channel must be specified",
"中继": "Relay"
}
17 changes: 17 additions & 0 deletions middleware/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ func tokenAuth(c *gin.Context, key string) {
return
}
c.Set("specific_channel_id", channelId)
if len(parts) == 3 && parts[2] == "ignore" {
c.Set("specific_channel_id_ignore", true)
}

} else {
abortWithMessage(c, http.StatusForbidden, "普通用户不支持指定渠道")
return
Expand All @@ -135,3 +139,16 @@ func MjAuth() func(c *gin.Context) {
tokenAuth(c, key)
}
}

func SpecifiedChannel() func(c *gin.Context) {
return func(c *gin.Context) {
channelId := c.GetInt("specific_channel_id")
c.Set("specific_channel_id_ignore", false)

if channelId <= 0 {
abortWithMessage(c, http.StatusForbidden, "必须指定渠道")
return
}
c.Next()
}
}
4 changes: 4 additions & 0 deletions providers/base/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,7 @@ func (p *BaseProvider) GetSupportedAPIUri(relayMode int) (url string, err *types

return
}

func (p *BaseProvider) GetRequester() *requester.HTTPRequester {
return p.Requester
}
3 changes: 2 additions & 1 deletion providers/base/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ type ProviderInterface interface {
// 获取完整请求URL
// GetFullRequestURL(requestURL string, modelName string) string
// 获取请求头
// GetRequestHeaders() (headers map[string]string)
GetRequestHeaders() map[string]string
// 获取用量
GetUsage() *types.Usage
// 设置用量
Expand All @@ -35,6 +35,7 @@ type ProviderInterface interface {
// SupportAPI(relayMode int) bool
GetChannel() *model.Channel
ModelMappingHandler(modelName string) (string, error)
GetRequester() *requester.HTTPRequester
}

// 完成接口
Expand Down
24 changes: 15 additions & 9 deletions providers/openai/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,22 +85,28 @@ func (p *OpenAIProvider) GetFullRequestURL(requestURL string, modelName string)

if p.IsAzure {
apiVersion := p.Channel.Other
// 检测模型是是否包含 . 如果有则直接去掉
modelName = strings.Replace(modelName, ".", "", -1)

if modelName == "dall-e-2" {
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
// 已经没有dall-e-2了,所以暂时写死
requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL)
if modelName != "" {
// 检测模型是是否包含 . 如果有则直接去掉
modelName = strings.Replace(modelName, ".", "", -1)

if modelName == "dall-e-2" {
// 因为dall-e-3需要api-version=2023-12-01-preview,但是该版本
// 已经没有dall-e-2了,所以暂时写死
requestURL = fmt.Sprintf("/openai/%s:submit?api-version=2023-09-01-preview", requestURL)
} else {
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
}
} else {
requestURL = fmt.Sprintf("/openai/deployments/%s%s?api-version=%s", modelName, requestURL, apiVersion)
requestURL = strings.TrimPrefix(requestURL, "/v1")
requestURL = fmt.Sprintf("/openai%s?api-version=%s", requestURL, apiVersion)
}

}

if strings.HasPrefix(baseURL, "https://gateway.ai.cloudflare.com") {
if p.IsAzure {
requestURL = strings.TrimPrefix(requestURL, "/openai/deployments")
requestURL = strings.TrimPrefix(requestURL, "/openai")
requestURL = strings.TrimPrefix(requestURL, "/deployments")
} else {
requestURL = strings.TrimPrefix(requestURL, "/v1")
}
Expand Down
14 changes: 12 additions & 2 deletions relay/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ func GetProvider(c *gin.Context, modeName string) (provider providersBase.Provid

func fetchChannel(c *gin.Context, modelName string) (channel *model.Channel, fail error) {
channelId := c.GetInt("specific_channel_id")
if channelId > 0 {
ignore := c.GetBool("specific_channel_id_ignore")
if channelId > 0 && !ignore {
return fetchChannelById(channelId)
}

Expand Down Expand Up @@ -206,7 +207,8 @@ func responseCache(c *gin.Context, response string) {

func shouldRetry(c *gin.Context, statusCode int) bool {
channelId := c.GetInt("specific_channel_id")
if channelId > 0 {
ignore := c.GetBool("specific_channel_id_ignore")
if channelId > 0 && !ignore {
return false
}
if statusCode == http.StatusTooManyRequests {
Expand All @@ -230,3 +232,11 @@ func processChannelRelayError(ctx context.Context, channelId int, channelName st
controller.DisableChannel(channelId, channelName, err.Message, true)
}
}

func relayResponseWithErr(c *gin.Context, err *types.OpenAIErrorWithStatusCode) {
requestId := c.GetString(common.RequestIdKey)
err.OpenAIError.Message = common.MessageWithRequestId(err.OpenAIError.Message, requestId)
c.JSON(err.StatusCode, gin.H{
"error": err.OpenAIError,
})
}
7 changes: 1 addition & 6 deletions relay/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,10 @@ func Relay(c *gin.Context) {
}

if apiErr != nil {
requestId := c.GetString(common.RequestIdKey)
if apiErr.StatusCode == http.StatusTooManyRequests {
apiErr.OpenAIError.Message = "当前分组上游负载已饱和,请稍后再试"
}
apiErr.OpenAIError.Message = common.MessageWithRequestId(apiErr.OpenAIError.Message, requestId)
c.JSON(apiErr.StatusCode, gin.H{
"error": apiErr.OpenAIError,
})

relayResponseWithErr(c, apiErr)
}
}

Expand Down
85 changes: 85 additions & 0 deletions relay/relay.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
package relay

import (
"net/http"
"one-api/common"
"one-api/model"
"one-api/providers/azure"
"one-api/providers/openai"
"strings"
"time"

"github.com/gin-gonic/gin"
)

func RelayOnly(c *gin.Context) {
provider, _, fail := GetProvider(c, "")
if fail != nil {
common.AbortWithMessage(c, http.StatusServiceUnavailable, fail.Error())
return
}

channel := provider.GetChannel()
if channel.Type != common.ChannelTypeOpenAI && channel.Type != common.ChannelTypeAzure {
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type azureopenai or openai")
return
}

// 获取请求的path
url := ""
path := c.Request.URL.Path
openAIProvider, ok := provider.(*openai.OpenAIProvider)
if !ok {
azureProvider, ok := provider.(*azure.AzureProvider)
if !ok {
common.AbortWithMessage(c, http.StatusServiceUnavailable, "provider must be of type openai")
return
}
url = azureProvider.GetFullRequestURL(path, "")
} else {
url = openAIProvider.GetFullRequestURL(path, "")
}

headers := c.Request.Header
mapHeaders := provider.GetRequestHeaders()
// 设置请求头
for k, v := range headers {
if _, ok := mapHeaders[k]; ok {
continue
}
mapHeaders[k] = strings.Join(v, ", ")
}

requester := provider.GetRequester()
req, err := requester.NewRequest(c.Request.Method, url, requester.WithBody(c.Request.Body), requester.WithHeader(mapHeaders))
if err != nil {
common.AbortWithMessage(c, http.StatusBadRequest, err.Error())
return
}

defer req.Body.Close()

response, errWithCode := requester.SendRequestRaw(req)
if errWithCode != nil {
relayResponseWithErr(c, errWithCode)
return
}

errWithCode = responseMultipart(c, response)

if errWithCode != nil {
relayResponseWithErr(c, errWithCode)
return
}

requestTime := 0
requestStartTimeValue := c.Request.Context().Value("requestStartTime")
if requestStartTimeValue != nil {
requestStartTime, ok := requestStartTimeValue.(time.Time)
if ok {
requestTime = int(time.Since(requestStartTime).Milliseconds())
}
}
model.RecordConsumeLog(c.Request.Context(), c.GetInt("id"), c.GetInt("channel_id"), 0, 0, "", c.GetString("token_name"), 0, "中继:"+path, requestTime)

}
52 changes: 14 additions & 38 deletions router/relay-router.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package router

import (
"one-api/controller"
"one-api/middleware"
"one-api/relay"
"one-api/relay/midjourney"
Expand Down Expand Up @@ -38,43 +37,20 @@ func setOpenAIRouter(router *gin.Engine) {
relayV1Router.POST("/audio/translations", relay.Relay)
relayV1Router.POST("/audio/speech", relay.Relay)
relayV1Router.POST("/moderations", relay.Relay)
relayV1Router.GET("/files", controller.RelayNotImplemented)
relayV1Router.POST("/files", controller.RelayNotImplemented)
relayV1Router.DELETE("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id", controller.RelayNotImplemented)
relayV1Router.GET("/files/:id/content", controller.RelayNotImplemented)
relayV1Router.POST("/fine_tuning/jobs", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs/:id", controller.RelayNotImplemented)
relayV1Router.POST("/fine_tuning/jobs/:id/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/fine_tuning/jobs/:id/events", controller.RelayNotImplemented)
relayV1Router.DELETE("/models/:model", controller.RelayNotImplemented)
relayV1Router.POST("/assistants", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.POST("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.DELETE("/assistants/:id", controller.RelayNotImplemented)
relayV1Router.GET("/assistants", controller.RelayNotImplemented)
relayV1Router.POST("/assistants/:id/files", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
relayV1Router.DELETE("/assistants/:id/files/:fileId", controller.RelayNotImplemented)
relayV1Router.GET("/assistants/:id/files", controller.RelayNotImplemented)
relayV1Router.POST("/threads", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id", controller.RelayNotImplemented)
relayV1Router.DELETE("/threads/:id", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/messages", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/messages/:messageId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/messages/:messageId/files/:filesId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/messages/:messageId/files", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs/:runsId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs/:runsId/submit_tool_outputs", controller.RelayNotImplemented)
relayV1Router.POST("/threads/:id/runs/:runsId/cancel", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs/:runsId/steps/:stepId", controller.RelayNotImplemented)
relayV1Router.GET("/threads/:id/runs/:runsId/steps", controller.RelayNotImplemented)

relayV1Router.Use(middleware.SpecifiedChannel())
{
relayV1Router.Any("/files", relay.RelayOnly)
relayV1Router.Any("/files/*any", relay.RelayOnly)
relayV1Router.Any("/fine_tuning/*any", relay.RelayOnly)
relayV1Router.Any("/assistants", relay.RelayOnly)
relayV1Router.Any("/assistants/*any", relay.RelayOnly)
relayV1Router.Any("/threads", relay.RelayOnly)
relayV1Router.Any("/threads/*any", relay.RelayOnly)
relayV1Router.Any("/batches/*any", relay.RelayOnly)
relayV1Router.Any("/vector_stores/*any", relay.RelayOnly)
relayV1Router.DELETE("/models/:model", relay.RelayOnly)
}
}
}

Expand Down

0 comments on commit 628df97

Please sign in to comment.