Skip to content

Commit

Permalink
fix: modify polyfill
Browse files Browse the repository at this point in the history
  • Loading branch information
jeasonstudio committed Jul 16, 2024
1 parent dce6c87 commit dc098d9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 28 deletions.
12 changes: 5 additions & 7 deletions src/embedding-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@ import { TextEmbedder, FilesetResolver } from '@mediapipe/tasks-text';
export interface ChromeAIEmbeddingModelSettings {
/**
* An optional base path to specify the directory the Wasm files should be loaded from.
* It's about 6mb before gzip.
* @default 'https://pub-ddcfe353995744e89b8002f16bf98575.r2.dev/text_wasm_internal.js'
*/
wasmLoaderPath?: string;
/**
* It's about 6mb before gzip.
* @default 'https://pub-ddcfe353995744e89b8002f16bf98575.r2.dev/text_wasm_internal.wasm'
*/
wasmBinaryPath?: string;
Expand Down Expand Up @@ -61,19 +61,18 @@ export class ChromeAIEmbeddingModel implements EmbeddingModelV1<string> {
quantize: false,
};
private modelAssetBuffer!: Promise<ReadableStreamDefaultReader>;
private textEmbedder: TextEmbedder | null = null;
private textEmbedder!: Promise<TextEmbedder>;

public constructor(settings: ChromeAIEmbeddingModelSettings = {}) {
this.settings = { ...this.settings, ...settings };
this.modelAssetBuffer = fetch(this.settings.modelAssetPath!).then(
(response) => response.body!.getReader()
)!;
this.getTextEmbedder();
this.textEmbedder = this.getTextEmbedder();
}

protected getTextEmbedder = async (): Promise<TextEmbedder> => {
if (this.textEmbedder !== null) return this.textEmbedder;
this.textEmbedder = await TextEmbedder.createFromOptions(
return TextEmbedder.createFromOptions(
{
wasmBinaryPath: this.settings.wasmBinaryPath!,
wasmLoaderPath: this.settings.wasmLoaderPath!,
Expand All @@ -87,7 +86,6 @@ export class ChromeAIEmbeddingModel implements EmbeddingModelV1<string> {
quantize: this.settings.quantize,
}
);
return this.textEmbedder;
};

public doEmbed = async (options: {
Expand All @@ -98,7 +96,7 @@ export class ChromeAIEmbeddingModel implements EmbeddingModelV1<string> {
rawResponse?: Record<PropertyKey, any>;
}> => {
// if (options.abortSignal) console.warn('abortSignal is not supported');
const embedder = await this.getTextEmbedder();
const embedder = await this.textEmbedder;
const embeddings = options.values.map((text) => {
const embedderResult = embedder.embed(text);
const [embedding] = embedderResult.embeddings;
Expand Down
44 changes: 23 additions & 21 deletions src/polyfill/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,14 @@ class PolyfillChromeAISession implements ChromeAISession {
debug('PolyfillChromeAISession created', llm);
}

private generateText = async (prompt: string): Promise<string> => {
public prompt = async (prompt: string): Promise<string> => {
const response = await this.llm.generateResponse(prompt);
debug('generateText', prompt, response);
debug('prompt', prompt, response);
return response;
};

private streamText = (prompt: string): ReadableStream<string> => {
debug('streamText', prompt);
public promptStreaming = (prompt: string): ReadableStream<string> => {
debug('promptStreaming', prompt);
const stream = new ReadableStream<string>({
start: (controller) => {
const listener: ProgressListener = (
Expand All @@ -40,15 +40,11 @@ class PolyfillChromeAISession implements ChromeAISession {
console.warn('stream text canceled', reason);
},
});
debug('streamText', prompt);
debug('promptStreaming', prompt);
return stream;
};

public destroy = async () => this.llm.close();
public prompt = this.generateText;
public execute = this.generateText;
public promptStreaming = this.streamText;
public executeStreaming = this.streamText;
}

/**
Expand All @@ -75,22 +71,32 @@ export class PolyfillChromeAI implements ChromePromptAPI {

private modelAssetBuffer: Promise<ReadableStreamDefaultReader>;

private canCreateSession = async (): Promise<ChromeAISessionAvailable> => {
// TODO@jeasonstudio:
// * if browser do not support WebAssembly/WebGPU, return 'no';
// * check if modelAssetBuffer is downloaded, if not, return 'after-download';
return 'readily';
public canCreateTextSession = async (): Promise<ChromeAISessionAvailable> => {
// If browser do not support WebAssembly/WebGPU, return 'no';
if (typeof WebAssembly.instantiate !== 'function') return 'no';
if (!(<any>navigator).gpu) return 'no';

// Check if modelAssetBuffer is downloaded, if not, return 'after-download';
const isModelAssetBufferReady = await Promise.race([
this.modelAssetBuffer,
Promise.resolve('sentinel'),
])
.then((value) => value === 'sentinel')
.catch(() => true);

return isModelAssetBufferReady ? 'readily' : 'after-download';
};
private defaultSessionOptions =

public defaultTextSessionOptions =
async (): Promise<ChromeAISessionOptions> => ({
temperature: 0.8,
topK: 3,
});

private createSession = async (
public createTextSession = async (
options?: ChromeAISessionOptions
): Promise<ChromeAISession> => {
const argv = options ?? (await this.defaultSessionOptions());
const argv = options ?? (await this.defaultTextSessionOptions());
const llm = await LlmInference.createFromOptions(
{
wasmLoaderPath: this.aiOptions.wasmLoaderPath!,
Expand All @@ -108,10 +114,6 @@ export class PolyfillChromeAI implements ChromePromptAPI {
debug('createSession', options, session);
return session;
};

public canCreateTextSession = this.canCreateSession;
public defaultTextSessionOptions = this.defaultSessionOptions;
public createTextSession = this.createSession;
}

export const polyfillChromeAI = (
Expand Down

0 comments on commit dc098d9

Please sign in to comment.