Skip to content

Commit

Permalink
streaming for LangChain gemini
Browse files Browse the repository at this point in the history
  • Loading branch information
stephmilovic committed Jun 10, 2024
1 parent 47edd78 commit 5210e07
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 98 deletions.
2 changes: 2 additions & 0 deletions x-pack/packages/kbn-langchain/server/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,12 @@ import { ActionsClientChatOpenAI } from './language_models/chat_openai';
import { ActionsClientLlm } from './language_models/llm';
import { ActionsClientSimpleChatModel } from './language_models/simple_chat_model';
import { parseBedrockStream } from './utils/bedrock';
import { parseGeminiResponse } from './utils/gemini';
import { getDefaultArguments } from './language_models/constants';

export {
parseBedrockStream,
parseGeminiResponse,
getDefaultArguments,
ActionsClientChatOpenAI,
ActionsClientLlm,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -162,11 +162,9 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
let streamingFinished = false;
const finalOutputStopRegex = /(?<!\\)\"/;
const handleLLMNewToken = async (token: string) => {
console.log('stephhh handleLlMNewToken', token);
if (finalOutputIndex === -1) {
// Remove whitespace to simplify parsing
currentOutput += token.replace(/\s/g, '');
console.log('stephhh currentOutput', currentOutput);
if (currentOutput.includes(finalOutputStartToken)) {
finalOutputIndex = currentOutput.indexOf(finalOutputStartToken);
}
Expand All @@ -175,7 +173,6 @@ export class ActionsClientSimpleChatModel extends SimpleChatModel {
if (finalOutputEndIndex !== -1) {
streamingFinished = true;
} else {
console.log('stephhh PUSH THIS TOKEN', token);
await runManager?.handleLLMNewToken(token);
}
}
Expand Down
75 changes: 13 additions & 62 deletions x-pack/packages/kbn-langchain/server/utils/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,34 @@ export const parseGeminiStream: StreamParser = async (
tokenHandler
) => {
let responseBody = '';
let responseBody2 = '';
stream.on('data', (chunk) => {
const decoded = chunk.toString();
const parsed = parseGeminiResponse(decoded);
if (tokenHandler) {
tokenHandler(parsed);
const splitByQuotes = parsed.split(`"`);
splitByQuotes.forEach((chunkk, index) => {
// add quote back on except for last chunk
const splitBySpace = `${chunkk}${index === splitByQuotes.length - 1 ? '' : '"'}`.split(` `);

for (const char of splitBySpace) {
tokenHandler(`${char} `);
}
});
}
responseBody += decoded;
responseBody2 += parsed;
responseBody += parsed;
});
return new Promise((resolve, reject) => {
stream.on('end', () => {
console.log('stephhh END responseBody', responseBody2);
const parsed = parseGeminiResponse(responseBody);
console.log('stephhh END parsed', parsed);
resolve(parseGeminiResponse(responseBody));
resolve(responseBody);
});
stream.on('error', (err) => {
reject(err);
});
if (abortSignal) {
abortSignal.addEventListener('abort', () => {
logger.info('Bedrock stream parsing was aborted.');
stream.destroy();
resolve(parseGeminiResponse(responseBody));
resolve(responseBody);
});
}
});
Expand Down Expand Up @@ -74,56 +78,3 @@ export const parseGeminiResponse = (responseBody: string) => {
return prev;
}, '');
};
//
// export const parseGeminiStream: StreamParser = async (
// responseStream,
// logger,
// abortSignal,
// tokenHandler
// ) => {
// const responseChunks: string[] = [];
// const decoder = new TextDecoder();
// if (abortSignal) {
// abortSignal.addEventListener('abort', () => {
// responseStream.destroy(new Error('Aborted'));
// return parseGeminiChunks(responseChunks, logger);
// });
// }
// responseStream.on('data', (chunk) => {
// const value = decoder.decode(chunk, { stream: true });
// console.log('stephhh value', value);
// const lines = value.split('\r');
// console.log('stephhh lines', lines);
// const parsedLines = parseGeminiChunks(lines, logger);
// console.log('stephhh parsedLines', parsedLines);
// const parsedChunk = parsedLines[0];
// responseChunks.push(parsedChunk);
// if (tokenHandler) {
// tokenHandler(parsedChunk);
// }
// });
//
// await finished(responseStream).catch((err) => {
// if (abortSignal?.aborted) {
// logger.info('Gemini stream parsing was aborted.');
// } else {
// throw err;
// }
// });
//
// return responseChunks.join(); // parseGeminiChunks(responseChunks, logger);
// };
//
// const parseGeminiChunks = (chunks: string[], logger: Logger) => {
// return chunks
// .filter((str) => !!str && str !== '[DONE]')
// .map((line) => {
// try {
// const newLine = line.replaceAll('data: ', '');
// const geminiResponse: GeminiResponseSchema = JSON.parse(newLine);
// return geminiResponse.candidates[0]?.content.parts.map((part) => part.text).join('') ?? '';
// } catch (err) {
// return '';
// }
// });
// };
34 changes: 1 addition & 33 deletions x-pack/plugins/elastic_assistant/server/lib/parse_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import { Readable } from 'stream';
import { Logger } from '@kbn/core/server';
import { parseBedrockStream } from '@kbn/langchain/server';
import { parseBedrockStream, parseGeminiResponse } from '@kbn/langchain/server';

type StreamParser = (
responseStream: Readable,
Expand Down Expand Up @@ -114,35 +114,3 @@ export const parseGeminiStream: StreamParser = async (stream, logger, abortSigna
}
});
};

/** Parse Gemini stream response body */
export const parseGeminiResponse = (responseBody: string) => {
return responseBody
.split('\n')
.filter((line) => line.startsWith('data: ') && !line.endsWith('[DONE]'))
.map((line) => JSON.parse(line.replace('data: ', '')))
.filter(
(
line
): line is {
candidates: Array<{
content: { role: string; parts: Array<{ text: string }> };
finishReason: string;
safetyRatings: Array<{ category: string; probability: string }>;
}>;
usageMetadata: {
promptTokenCount: number;
candidatesTokenCount: number;
totalTokenCount: number;
};
} => 'candidates' in line
)
.reduce((prev, line) => {
if (line.candidates[0] && line.candidates[0].content) {
const parts = line.candidates[0].content.parts;
const text = parts.map((part) => part.text).join('');
return prev + text;
}
return prev;
}, '');
};

0 comments on commit 5210e07

Please sign in to comment.