forked from songquanpeng/one-api
feat: add Google Gemini Pro support (songquanpeng#826)
* feat: Add Google Gemini Pro, fix songquanpeng#810
* feat: Add tooling to Gemini; add OpenAI-like system prompt to Gemini
* refactor: Remove unused if statement
* feat: Add dummy model message for system message in Gemini model
* chore: Update implementation

Co-authored-by: JustSong <[email protected]>
1 parent 39086ab · commit 128c44b · 12 changed files with 353 additions and 3 deletions.
@@ -0,0 +1,281 @@
package controller

import (
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"one-api/common"

	"github.com/gin-gonic/gin"
)

type GeminiChatRequest struct {
	Contents         []GeminiChatContent        `json:"contents"`
	SafetySettings   []GeminiChatSafetySettings `json:"safety_settings,omitempty"`
	GenerationConfig GeminiChatGenerationConfig `json:"generation_config,omitempty"`
	Tools            []GeminiChatTools          `json:"tools,omitempty"`
}

type GeminiInlineData struct {
	MimeType string `json:"mimeType"`
	Data     string `json:"data"`
}

type GeminiPart struct {
	Text       string            `json:"text,omitempty"`
	InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}

type GeminiChatContent struct {
	Role  string       `json:"role,omitempty"`
	Parts []GeminiPart `json:"parts"`
}

type GeminiChatSafetySettings struct {
	Category  string `json:"category"`
	Threshold string `json:"threshold"`
}

type GeminiChatTools struct {
	FunctionDeclarations any `json:"functionDeclarations,omitempty"`
}

type GeminiChatGenerationConfig struct {
	Temperature     float64  `json:"temperature,omitempty"`
	TopP            float64  `json:"topP,omitempty"`
	TopK            float64  `json:"topK,omitempty"`
	MaxOutputTokens int      `json:"maxOutputTokens,omitempty"`
	CandidateCount  int      `json:"candidateCount,omitempty"`
	StopSequences   []string `json:"stopSequences,omitempty"`
}

// Setting safety to the lowest possible values since Gemini is already powerless enough
func requestOpenAI2Gemini(textRequest GeneralOpenAIRequest) *GeminiChatRequest {
	geminiRequest := GeminiChatRequest{
		Contents: make([]GeminiChatContent, 0, len(textRequest.Messages)),
		//SafetySettings: []GeminiChatSafetySettings{
		//	{
		//		Category:  "HARM_CATEGORY_HARASSMENT",
		//		Threshold: "BLOCK_ONLY_HIGH",
		//	},
		//	{
		//		Category:  "HARM_CATEGORY_HATE_SPEECH",
		//		Threshold: "BLOCK_ONLY_HIGH",
		//	},
		//	{
		//		Category:  "HARM_CATEGORY_SEXUALLY_EXPLICIT",
		//		Threshold: "BLOCK_ONLY_HIGH",
		//	},
		//	{
		//		Category:  "HARM_CATEGORY_DANGEROUS_CONTENT",
		//		Threshold: "BLOCK_ONLY_HIGH",
		//	},
		//},
		GenerationConfig: GeminiChatGenerationConfig{
			Temperature:     textRequest.Temperature,
			TopP:            textRequest.TopP,
			MaxOutputTokens: textRequest.MaxTokens,
		},
	}
	if textRequest.Functions != nil {
		geminiRequest.Tools = []GeminiChatTools{
			{
				FunctionDeclarations: textRequest.Functions,
			},
		}
	}
	shouldAddDummyModelMessage := false
	for _, message := range textRequest.Messages {
		content := GeminiChatContent{
			Role: message.Role,
			Parts: []GeminiPart{
				{
					Text: message.StringContent(),
				},
			},
		}
		// There is no assistant role in Gemini, and the API rejects any role other than user or model
		if content.Role == "assistant" {
			content.Role = "model"
		}
		// Convert the system prompt to a user prompt for the same reason
		if content.Role == "system" {
			content.Role = "user"
			shouldAddDummyModelMessage = true
		}
		geminiRequest.Contents = append(geminiRequest.Contents, content)

		// After a system message, add a dummy model message to keep Gemini happy
		if shouldAddDummyModelMessage {
			geminiRequest.Contents = append(geminiRequest.Contents, GeminiChatContent{
				Role: "model",
				Parts: []GeminiPart{
					{
						Text: "ok",
					},
				},
			})
			shouldAddDummyModelMessage = false
		}
	}

	return &geminiRequest
}
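
// Illustrative sketch (not part of this commit): given an OpenAI-style request with a
// system prompt, requestOpenAI2Gemini sends the system text as a user turn followed by
// a dummy "ok" model turn, so Gemini sees the alternating user/model roles it expects.
// Assumes Message.Content accepts a plain string, matching its use elsewhere in this file.
func exampleSystemPromptMapping() {
	request := GeneralOpenAIRequest{
		Messages: []Message{
			{Role: "system", Content: "You are a terse assistant."},
			{Role: "user", Content: "Hello"},
		},
	}
	geminiRequest := requestOpenAI2Gemini(request)
	for _, content := range geminiRequest.Contents {
		fmt.Println(content.Role, content.Parts[0].Text)
	}
	// Expected order:
	// user You are a terse assistant.
	// model ok
	// user Hello
}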

type GeminiChatResponse struct {
	Candidates     []GeminiChatCandidate    `json:"candidates"`
	PromptFeedback GeminiChatPromptFeedback `json:"promptFeedback"`
}

type GeminiChatCandidate struct {
	Content       GeminiChatContent        `json:"content"`
	FinishReason  string                   `json:"finishReason"`
	Index         int64                    `json:"index"`
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

type GeminiChatSafetyRating struct {
	Category    string `json:"category"`
	Probability string `json:"probability"`
}

type GeminiChatPromptFeedback struct {
	SafetyRatings []GeminiChatSafetyRating `json:"safetyRatings"`
}

func responseGeminiChat2OpenAI(response *GeminiChatResponse) *OpenAITextResponse {
	fullTextResponse := OpenAITextResponse{
		Id:      fmt.Sprintf("chatcmpl-%s", common.GetUUID()),
		Object:  "chat.completion",
		Created: common.GetTimestamp(),
		Choices: make([]OpenAITextResponseChoice, 0, len(response.Candidates)),
	}
	for i, candidate := range response.Candidates {
		choice := OpenAITextResponseChoice{
			Index: i,
			Message: Message{
				Role:    "assistant",
				Content: candidate.Content.Parts[0].Text,
			},
			FinishReason: stopFinishReason,
		}
		fullTextResponse.Choices = append(fullTextResponse.Choices, choice)
	}
	return &fullTextResponse
}
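
// Illustrative sketch (not part of this commit): each Gemini candidate becomes one
// OpenAI choice with role "assistant" carrying the text of the candidate's first part.
func exampleNonStreamConversion() {
	geminiResponse := GeminiChatResponse{
		Candidates: []GeminiChatCandidate{
			{Content: GeminiChatContent{Role: "model", Parts: []GeminiPart{{Text: "Hi there"}}}},
		},
	}
	openaiResponse := responseGeminiChat2OpenAI(&geminiResponse)
	fmt.Println(openaiResponse.Choices[0].Message.Content) // Hi there
}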
|
||
func streamResponseGeminiChat2OpenAI(geminiResponse *GeminiChatResponse) *ChatCompletionsStreamResponse { | ||
var choice ChatCompletionsStreamResponseChoice | ||
if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 { | ||
choice.Delta.Content = geminiResponse.Candidates[0].Content.Parts[0].Text | ||
} | ||
choice.FinishReason = &stopFinishReason | ||
var response ChatCompletionsStreamResponse | ||
response.Object = "chat.completion.chunk" | ||
response.Model = "gemini" | ||
response.Choices = []ChatCompletionsStreamResponseChoice{choice} | ||
return &response | ||
} | ||
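
// Illustrative sketch (not part of this commit): the stream converter emits a single
// chat.completion.chunk-shaped delta carrying the first candidate's first part, with
// the model name hardcoded to "gemini".
func exampleStreamChunkConversion() {
	geminiResponse := GeminiChatResponse{
		Candidates: []GeminiChatCandidate{
			{Content: GeminiChatContent{Parts: []GeminiPart{{Text: "Hello"}}}},
		},
	}
	chunk := streamResponseGeminiChat2OpenAI(&geminiResponse)
	data, _ := json.Marshal(chunk)
	fmt.Println(string(data))
}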
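
// geminiChatStreamHandler adapts Gemini's reply to OpenAI's SSE format. Note that the
// upstream body is read in full inside the goroutine and converted into a single
// chat.completion.chunk event, after which "data: [DONE]" closes the stream.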
func geminiChatStreamHandler(c *gin.Context, resp *http.Response) (*OpenAIErrorWithStatusCode, string) {
	responseText := ""
	responseId := fmt.Sprintf("chatcmpl-%s", common.GetUUID())
	createdTime := common.GetTimestamp()
	dataChan := make(chan string)
	stopChan := make(chan bool)
	go func() {
		responseBody, err := io.ReadAll(resp.Body)
		if err != nil {
			common.SysError("error reading stream response: " + err.Error())
			stopChan <- true
			return
		}
		err = resp.Body.Close()
		if err != nil {
			common.SysError("error closing stream response: " + err.Error())
			stopChan <- true
			return
		}
		var geminiResponse GeminiChatResponse
		err = json.Unmarshal(responseBody, &geminiResponse)
		if err != nil {
			common.SysError("error unmarshalling stream response: " + err.Error())
			stopChan <- true
			return
		}
		fullTextResponse := streamResponseGeminiChat2OpenAI(&geminiResponse)
		fullTextResponse.Id = responseId
		fullTextResponse.Created = createdTime
		if len(geminiResponse.Candidates) > 0 && len(geminiResponse.Candidates[0].Content.Parts) > 0 {
			responseText += geminiResponse.Candidates[0].Content.Parts[0].Text
		}
		jsonResponse, err := json.Marshal(fullTextResponse)
		if err != nil {
			common.SysError("error marshalling stream response: " + err.Error())
			stopChan <- true
			return
		}
		dataChan <- string(jsonResponse)
		stopChan <- true
	}()
	setEventStreamHeaders(c)
	c.Stream(func(w io.Writer) bool {
		select {
		case data := <-dataChan:
			c.Render(-1, common.CustomEvent{Data: "data: " + data})
			return true
		case <-stopChan:
			c.Render(-1, common.CustomEvent{Data: "data: [DONE]"})
			return false
		}
	})
	err := resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), ""
	}
	return nil, responseText
}
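
// geminiChatHandler handles the non-stream case: it converts the Gemini response to the
// OpenAI format, counts completion tokens from the first candidate's text, and relays
// the JSON body with the upstream status code.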
func geminiChatHandler(c *gin.Context, resp *http.Response, promptTokens int, model string) (*OpenAIErrorWithStatusCode, *Usage) {
	responseBody, err := io.ReadAll(resp.Body)
	if err != nil {
		return errorWrapper(err, "read_response_body_failed", http.StatusInternalServerError), nil
	}
	err = resp.Body.Close()
	if err != nil {
		return errorWrapper(err, "close_response_body_failed", http.StatusInternalServerError), nil
	}
	var geminiResponse GeminiChatResponse
	err = json.Unmarshal(responseBody, &geminiResponse)
	if err != nil {
		return errorWrapper(err, "unmarshal_response_body_failed", http.StatusInternalServerError), nil
	}
	if len(geminiResponse.Candidates) == 0 {
		return &OpenAIErrorWithStatusCode{
			OpenAIError: OpenAIError{
				Message: "No candidates returned",
				Type:    "server_error",
				Param:   "",
				Code:    500,
			},
			StatusCode: resp.StatusCode,
		}, nil
	}
	fullTextResponse := responseGeminiChat2OpenAI(&geminiResponse)
	completionTokens := countTokenText(geminiResponse.Candidates[0].Content.Parts[0].Text, model)
	usage := Usage{
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		TotalTokens:      promptTokens + completionTokens,
	}
	fullTextResponse.Usage = usage
	jsonResponse, err := json.Marshal(fullTextResponse)
	if err != nil {
		return errorWrapper(err, "marshal_response_body_failed", http.StatusInternalServerError), nil
	}
	c.Writer.Header().Set("Content-Type", "application/json")
	c.Writer.WriteHeader(resp.StatusCode)
	_, err = c.Writer.Write(jsonResponse)
	return nil, &usage
}