Skip to content

Commit

Permalink
feat: add gemini model
Browse files Browse the repository at this point in the history
  • Loading branch information
Mrxyy committed Dec 30, 2023
1 parent 1bf0ccb commit b3b996f
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 17 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
"test:e2e": "jest --config ./test/jest-e2e.json --watch --debug"
},
"dependencies": {
"@google/generative-ai": "^0.1.3",
"@nestjs/common": "^9.0.0",
"@nestjs/core": "^9.0.0",
"@nestjs/platform-express": "^9.0.0",
Expand All @@ -33,7 +34,6 @@
"langchain": "^0.0.149",
"lodash": "^4.17.21",
"mysql2": "^3.3.3",
"openai": "^4.24.0",
"reflect-metadata": "^0.1.13",
"rxjs": "^7.2.0",
"sequelize": "^6.32.0",
Expand Down
2 changes: 2 additions & 0 deletions src/models/events/generateCode.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ export interface IGenerateCodeParams {
isTermOfServiceAccepted?: boolean;
accessCode?: boolean;
resultImage?: string;
llm?: 'Gemini' | 'OpenAi';
}

export async function streamGenerateCode(
Expand Down Expand Up @@ -94,6 +95,7 @@ export async function streamGenerateCode(
{
openAiApiKey: params.openAiApiKey,
openAiBaseURL: params.openAiBaseURL,
llm: params.llm,
},
);
} catch (e) {
Expand Down
163 changes: 147 additions & 16 deletions src/models/events/llm.ts
Original file line number Diff line number Diff line change
@@ -1,30 +1,161 @@
import OpenAI from 'openai';
import { GoogleGenerativeAI } from '@google/generative-ai';

export async function streamingOpenAIResponses(messages, callback, params) {
/**
 * Splits a `data:` URL into its MIME type and raw base64 payload.
 *
 * @param dataUrl - a data URL of the form `data:<mime>;base64,<payload>`
 * @returns the extracted `mimeType` and `base64Data`
 * @throws Error when the string is not a base64 data URL
 */
function extractMimeAndBase64(dataUrl: string) {
  const match = /^data:(.+);base64,(.*)$/.exec(dataUrl);
  if (match === null || match.length !== 3) {
    throw new Error('Invalid base64 data URL');
  }
  const [, mimeType, base64Data] = match;
  return { mimeType, base64Data };
}

/**
 * Flattens an OpenAI-style chat message list into the Gemini request shape:
 * a single `user` turn whose `parts` array holds every text fragment and
 * inline image from the original messages, in order.
 *
 * String contents (system messages) become text parts; array contents
 * (user messages) contribute text parts and, for `image_url` entries,
 * inline-data parts decoded from their base64 data URLs.
 */
function transformData(data: Record<any, any>[]) {
  const parts = [];

  for (const message of data) {
    const { content } = message;
    if (!content) {
      continue;
    }
    if (typeof content === 'string') {
      // System-role messages carry their text directly.
      parts.push({ text: content });
      continue;
    }
    if (Array.isArray(content)) {
      // User-role messages are arrays of typed pieces.
      for (const piece of content) {
        if (piece.type === 'text') {
          parts.push({ text: piece.text });
        } else if (piece.type === 'image_url') {
          // Convert the data URL into Gemini's inline-data form.
          const { mimeType, base64Data } = extractMimeAndBase64(
            piece.image_url.url,
          );
          parts.push({
            inlineData: {
              data: base64Data,
              mimeType,
            },
          });
        }
      }
    }
  }

  // Everything is merged into one user turn, as Gemini expects.
  return [
    {
      role: 'user',
      parts,
    },
  ];
}

/**
 * Streams a completion from Gemini ('gemini-pro-vision') for the given chat
 * messages, invoking `callback` with each emitted text fragment and
 * returning the full concatenated response text.
 *
 * The stream is processed one chunk behind: each chunk is held in `perText`
 * until the next one arrives, so the final chunk can have its trailing
 * ``` code fence stripped before being flushed. The leading ```html fence
 * is stripped while no text has been flushed yet.
 */
async function useGeminiResponse([messages, callback, params]: Parameters<
  typeof streamingOpenAIResponses
>) {
  // NOTE(review): the Gemini client is keyed from openAiApiKey /
  // OPENAI_API_KEY — presumably callers pass a Google key through the same
  // field; confirm this is intentional.
  const genAI = new GoogleGenerativeAI(
    params.openAiApiKey || process.env['OPENAI_API_KEY'],
  );
  const model = genAI.getGenerativeModel({ model: 'gemini-pro-vision' });
  // Deterministic output (temperature 0) with a generous token budget.
  const generationConfig = {
    temperature: 0,
    topK: 32,
    topP: 1,
    maxOutputTokens: 30000,
  };

  // Collapse the OpenAI-style message list into a single Gemini user turn.
  const contents = transformData(messages);
  const result = await model.generateContentStream({
    contents: contents,
    generationConfig,
  });

  let text = ''; // everything flushed to the callback so far
  let perText = ''; // most recent chunk, held back until the next arrives
  for await (const chunk of result.stream) {
    // Flush the previous chunk now that we know it is not the last one.
    if (perText) {
      callback(perText);
      text += perText;
    }
    // Strip a leading ```html fence only while nothing has been flushed yet.
    const chunkText = text
      ? chunk.text()
      : chunk.text().replace(/^\s*```html/g, '');
    perText = chunkText;
  }
  // Final held-back chunk: strip a trailing ``` fence, then flush it.
  perText = perText.replace(/```\s*$/g, '');
  callback(perText);
  text += perText;
  return text;
}

export async function streamingOpenAIResponses(
messages: any[],
callback: {
(content: string, event?: string | undefined): void;
(arg0: string, arg1: string | undefined): void;
},
params: { openAiApiKey: any; openAiBaseURL: any; llm: string },
) {
if (params.llm === 'Gemini') {
const full_response = await useGeminiResponse([messages, callback, params]);
return full_response;
}
if (!params.openAiApiKey) {
callback('No openai key', 'error');
return '';
}
const openai = new OpenAI({
apiKey: params.openAiApiKey || process.env['OPEN_AI_API_KEY'], // defaults to process.env["OPENAI_API_KEY"]
baseURL:
params.openAiBaseURL ||
process.env['BASE_URL'] ||
'https://api.openai.com/v1',
});

const stream = await openai.chat.completions.create({
model: 'gpt-4-vision-preview',
temperature: 0,
max_tokens: 4096,
const openAi = [
'gpt-4-vision-preview',
params.openAiApiKey || process.env['OPENAI_API_KEY'],
process.env['OPENAI_BASE_URL'] ||
'https://api.openai.com/v1/chat/completions',
];

const [model, authorization, url] = openAi;

const body = JSON.stringify({
messages,
stream: true,
model: model,
temperature: 0,
max_tokens: 4096,
});

const res = await fetch(url, {
headers: {
accept: 'text/event-stream',
authorization: `Bearer ${authorization}`,
'content-type': 'application/json',
},
referrer: '',
referrerPolicy: 'strict-origin-when-cross-origin',
body,
method: 'POST',
mode: 'cors',
credentials: 'include',
});
let full_response = '';
const stream: any = res.body;
const decoder = new TextDecoder();
let perText = '';
for await (const chunk of stream) {
const content = chunk.choices[0]?.delta?.content || '';
full_response += content;
callback(content);
const string = decoder.decode(chunk);
const resArr = string
.trim()
.split(/\n\n/)
.map((v) => v.replace(/^data:/, '').trim());
resArr.forEach((item) => {
try {
const chunk = JSON.parse(perText + item);
const content = chunk.choices[0]?.delta?.content || '';
full_response += content;
callback(content);
perText = '';
} catch (e) {
perText += item;
}
});
}
return full_response;
}

0 comments on commit b3b996f

Please sign in to comment.