From b3cfaf142059a2349b7179365809496ec77aaca0 Mon Sep 17 00:00:00 2001 From: Dogtiti <499960698@qq.com> Date: Wed, 17 Jul 2024 20:05:34 +0800 Subject: [PATCH] refator: google --- app/api/google/[...path]/route.ts | 116 +++++++++++++++++------------- app/client/platforms/google.ts | 69 +++++++++--------- app/constant.ts | 3 +- 3 files changed, 101 insertions(+), 87 deletions(-) diff --git a/app/api/google/[...path]/route.ts b/app/api/google/[...path]/route.ts index 81e50538a56..83a7ce794c1 100644 --- a/app/api/google/[...path]/route.ts +++ b/app/api/google/[...path]/route.ts @@ -1,7 +1,15 @@ import { NextRequest, NextResponse } from "next/server"; import { auth } from "../../auth"; import { getServerSideConfig } from "@/app/config/server"; -import { GEMINI_BASE_URL, Google, ModelProvider } from "@/app/constant"; +import { + ApiPath, + GEMINI_BASE_URL, + Google, + ModelProvider, +} from "@/app/constant"; +import { prettyObject } from "@/app/utils/format"; + +const serverConfig = getServerSideConfig(); async function handle( req: NextRequest, @@ -13,32 +21,6 @@ async function handle( return NextResponse.json({ body: "OK" }, { status: 200 }); } - const controller = new AbortController(); - - const serverConfig = getServerSideConfig(); - - let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL; - - if (!baseUrl.startsWith("http")) { - baseUrl = `https://${baseUrl}`; - } - - if (baseUrl.endsWith("/")) { - baseUrl = baseUrl.slice(0, -1); - } - - let path = `${req.nextUrl.pathname}`.replaceAll("/api/google/", ""); - - console.log("[Proxy] ", path); - console.log("[Base Url]", baseUrl); - - const timeoutId = setTimeout( - () => { - controller.abort(); - }, - 10 * 60 * 1000, - ); - const authResult = auth(req, ModelProvider.GeminiPro); if (authResult.error) { return NextResponse.json(authResult, { @@ -49,9 +31,9 @@ async function handle( const bearToken = req.headers.get("Authorization") ?? ""; const token = bearToken.trim().replaceAll("Bearer ", "").trim(); - const key = token ? token : serverConfig.googleApiKey; + const apiKey = token ? token : serverConfig.googleApiKey; - if (!key) { + if (!apiKey) { return NextResponse.json( { error: true, @@ -62,10 +44,63 @@ async function handle( }, ); } + try { + const response = await request(req, apiKey); + return response; + } catch (e) { + console.error("[Google] ", e); + return NextResponse.json(prettyObject(e)); + } +} - const fetchUrl = `${baseUrl}/${path}?key=${key}${ - req?.nextUrl?.searchParams?.get("alt") == "sse" ? "&alt=sse" : "" +export const GET = handle; +export const POST = handle; + +export const runtime = "edge"; +export const preferredRegion = [ + "bom1", + "cle1", + "cpt1", + "gru1", + "hnd1", + "iad1", + "icn1", + "kix1", + "pdx1", + "sfo1", + "sin1", + "syd1", +]; + +async function request(req: NextRequest, apiKey: string) { + const controller = new AbortController(); + + let baseUrl = serverConfig.googleUrl || GEMINI_BASE_URL; + + let path = `${req.nextUrl.pathname}`.replaceAll(ApiPath.Google, ""); + + if (!baseUrl.startsWith("http")) { + baseUrl = `https://${baseUrl}`; + } + + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, -1); + } + + console.log("[Proxy] ", path); + console.log("[Base Url]", baseUrl); + + const timeoutId = setTimeout( + () => { + controller.abort(); + }, + 10 * 60 * 1000, + ); + const fetchUrl = `${baseUrl}${path}?key=${apiKey}${ + req?.nextUrl?.searchParams?.get("alt") === "sse" ? "&alt=sse" : "" }`; + + console.log("[Fetch Url] ", fetchUrl); const fetchOptions: RequestInit = { headers: { "Content-Type": "application/json", @@ -97,22 +132,3 @@ async function handle( clearTimeout(timeoutId); } } - -export const GET = handle; -export const POST = handle; - -export const runtime = "edge"; -export const preferredRegion = [ - "bom1", - "cle1", - "cpt1", - "gru1", - "hnd1", - "iad1", - "icn1", - "kix1", - "pdx1", - "sfo1", - "sin1", - "syd1", -]; diff --git a/app/client/platforms/google.ts b/app/client/platforms/google.ts index 6054c7a476e..8acde1a83f1 100644 --- a/app/client/platforms/google.ts +++ b/app/client/platforms/google.ts @@ -1,4 +1,4 @@ -import { Google, REQUEST_TIMEOUT_MS } from "@/app/constant"; +import { ApiPath, Google, REQUEST_TIMEOUT_MS } from "@/app/constant"; import { ChatOptions, getHeaders, LLMApi, LLMModel, LLMUsage } from "../api"; import { useAccessStore, useAppConfig, useChatStore } from "@/app/store"; import { getClientConfig } from "@/app/config/client"; @@ -16,6 +16,34 @@ import { } from "@/app/utils"; export class GeminiProApi implements LLMApi { + path(path: string): string { + const accessStore = useAccessStore.getState(); + + let baseUrl = ""; + if (accessStore.useCustomConfig) { + baseUrl = accessStore.googleUrl; + } + + if (baseUrl.length === 0) { + const isApp = !!getClientConfig()?.isApp; + baseUrl = isApp + ? DEFAULT_API_HOST + `/api/proxy/google?key=${accessStore.googleApiKey}` + : ApiPath.Google; + } + if (baseUrl.endsWith("/")) { + baseUrl = baseUrl.slice(0, baseUrl.length - 1); + } + if (!baseUrl.startsWith("http") && !baseUrl.startsWith(ApiPath.Google)) { + baseUrl = "https://" + baseUrl; + } + + console.log("[Proxy Endpoint] ", baseUrl, path); + + let chatPath = [baseUrl, path].join("/"); + + chatPath += chatPath.includes("?") ? "&alt=sse" : "?alt=sse"; + return chatPath; + } extractMessage(res: any) { console.log("[Response] gemini-pro response: ", res); @@ -108,30 +136,13 @@ export class GeminiProApi implements LLMApi { ], }; - const accessStore = useAccessStore.getState(); - - let baseUrl = ""; - - if (accessStore.useCustomConfig) { - baseUrl = accessStore.googleUrl; - } - - const isApp = !!getClientConfig()?.isApp; - let shouldStream = !!options.config.stream; const controller = new AbortController(); options.onController?.(controller); try { - if (!baseUrl && isApp) { - baseUrl = DEFAULT_API_HOST + "/api/proxy/google/"; - } - baseUrl = `${baseUrl}/${Google.ChatPath(modelConfig.model)}`.replaceAll( - "//", - "/", - ); - if (isApp) { - baseUrl += `?key=${accessStore.googleApiKey}`; - } + // https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb + const chatPath = this.path(Google.ChatPath(modelConfig.model)); + const chatPayload = { method: "POST", body: JSON.stringify(requestPayload), @@ -181,10 +192,6 @@ export class GeminiProApi implements LLMApi { controller.signal.onabort = finish; - // https://github.com/google-gemini/cookbook/blob/main/quickstarts/rest/Streaming_REST.ipynb - const chatPath = - baseUrl.replace("generateContent", "streamGenerateContent") + - (baseUrl.indexOf("?") > -1 ? "&alt=sse" : "?alt=sse"); fetchEventSource(chatPath, { ...chatPayload, async onopen(res) { @@ -259,7 +266,7 @@ export class GeminiProApi implements LLMApi { openWhenHidden: true, }); } else { - const res = await fetch(baseUrl, chatPayload); + const res = await fetch(chatPath, chatPayload); clearTimeout(requestTimeoutId); const resJson = await res.json(); if (resJson?.promptFeedback?.blockReason) { @@ -285,14 +292,4 @@ export class GeminiProApi implements LLMApi { async models(): Promise { return []; } - path(path: string): string { - return "/api/google/" + path; - } -} - -function ensureProperEnding(str: string) { - if (str.startsWith("[") && !str.endsWith("]")) { - return str + "]"; - } - return str; } diff --git a/app/constant.ts b/app/constant.ts index a146200d63b..bdc30067b70 100644 --- a/app/constant.ts +++ b/app/constant.ts @@ -118,7 +118,8 @@ export const Azure = { export const Google = { ExampleEndpoint: "https://generativelanguage.googleapis.com/", - ChatPath: (modelName: string) => `v1beta/models/${modelName}:generateContent`, + ChatPath: (modelName: string) => + `v1beta/models/${modelName}:streamGenerateContent`, }; export const Baidu = {