diff --git a/src/engine.ts b/src/engine.ts
index 82df2127..c2fc3069 100644
--- a/src/engine.ts
+++ b/src/engine.ts
@@ -70,6 +70,7 @@ import {
 } from "./error";
 import { asyncLoadTokenizer } from "./cache_util";
 import { EmbeddingPipeline } from "./embedding";
+import { Tokenizer } from "@mlc-ai/web-tokenizers";
 
 /**
  * Creates `MLCEngine`, and loads `modelId` onto WebGPU.
@@ -131,6 +132,7 @@ export class MLCEngine implements MLCEngineInterface {
   private logitProcessorRegistry?: Map<string, LogitProcessor>;
   private initProgressCallback?: InitProgressCallback;
   private appConfig: AppConfig;
+  private tokenizer: Tokenizer | null = null;
 
   // Signals and flags
   private interruptSignal = false;
@@ -359,7 +361,7 @@ export class MLCEngine implements MLCEngineInterface {
     });
     tvm.initWebGPU(gpuDetectOutput.device);
 
-    const tokenizer = await asyncLoadTokenizer(
+    this.tokenizer = await asyncLoadTokenizer(
       modelUrl,
       curModelConfig,
       this.appConfig,
@@ -379,11 +381,11 @@ export class MLCEngine implements MLCEngineInterface {
     // embedding model, and prompt user to use ModelRecord.model_type
     let newPipeline: LLMChatPipeline | EmbeddingPipeline;
     if (modelRecord.model_type === ModelType.embedding) {
-      newPipeline = new EmbeddingPipeline(tvm, tokenizer, curModelConfig);
+      newPipeline = new EmbeddingPipeline(tvm, this.tokenizer, curModelConfig);
     } else {
       newPipeline = new LLMChatPipeline(
         tvm,
-        tokenizer,
+        this.tokenizer,
         curModelConfig,
         logitProcessor,
       );
@@ -1333,4 +1335,16 @@ export class MLCEngine implements MLCEngineInterface {
   async decode(pipeline: LLMChatPipeline, genConfig?: GenerationConfig) {
     return pipeline.decodeStep(genConfig);
   }
+
+  //-----------------------------------------------
+  // 8. Expose tokenizer
+  //-----------------------------------------------
+
+  async tokenize(text: string) {
+    return this.tokenizer!.encode(text);
+  }
+
+  async decodeTokens(ids: Int32Array) {
+    return this.tokenizer!.decode(ids);
+  }
 }
diff --git a/src/message.ts b/src/message.ts
index 8be38bb7..913d64d1 100644
--- a/src/message.ts
+++ b/src/message.ts
@@ -34,7 +34,9 @@ type RequestKind =
   | "customRequest"
   | "keepAlive"
   | "setLogLevel"
-  | "setAppConfig";
+  | "setAppConfig"
+  | "tokenize"
+  | "decodeTokens";
 
 // eslint-disable-next-line @typescript-eslint/no-unused-vars
 type ResponseKind = "return" | "throw" | "initProgressCallback";
@@ -58,6 +60,14 @@ export interface ForwardTokensAndSampleParams {
   isPrefill: boolean;
   modelId?: string;
 }
+export interface TokenizeParams {
+  text: string;
+  modelId?: string;
+}
+export interface DecodeTokensParams {
+  inputIds: Int32Array;
+  modelId?: string;
+}
 
 // Notes on the following Params with modelId and chatOpts:
 // These fields are the model and chatOpts that the frontend engine expects the backend
@@ -128,6 +138,9 @@ export type MessageContent =
   | CreateEmbeddingResponse
   | Completion
   | AppConfig
+  | TokenizeParams
+  | DecodeTokensParams
+  | Int32Array
   | void;
 /**
  * The message used in exchange between worker
diff --git a/src/types.ts b/src/types.ts
index 4d4522c0..c831b0aa 100644
--- a/src/types.ts
+++ b/src/types.ts
@@ -172,6 +172,12 @@ export interface MLCEngineInterface {
    */
   embedding(request: EmbeddingCreateParams): Promise<CreateEmbeddingResponse>;
 
+  /**
+   * Exposes the tokenizer so clients do not need to load it a second time.
+   */
+  tokenize(input: string): Promise<Int32Array>;
+  decodeTokens(input: Int32Array): Promise<string>;
+
   /**
    * @returns A text summarizing the runtime stats.
    * @param modelId Only required when multiple models are loaded.
diff --git a/src/web_worker.ts b/src/web_worker.ts
index c683c004..9807b6f8 100644
--- a/src/web_worker.ts
+++ b/src/web_worker.ts
@@ -26,6 +26,8 @@ import {
   MessageContent,
   ReloadParams,
   ForwardTokensAndSampleParams,
+  TokenizeParams,
+  DecodeTokensParams,
   ChatCompletionNonStreamingParams,
   ChatCompletionStreamInitParams,
   ResetChatParams,
@@ -345,6 +347,24 @@ export class WebWorkerMLCEngineHandler {
         onComplete?.(null);
         return;
       }
+      case "decodeTokens": {
+        this.handleTask(msg.uuid, async () => {
+          const params = msg.content as DecodeTokensParams;
+          const res = await this.engine.decodeTokens(params.inputIds);
+          onComplete?.(res);
+          return res;
+        });
+        return;
+      }
+      case "tokenize": {
+        this.handleTask(msg.uuid, async () => {
+          const params = msg.content as TokenizeParams;
+          const res = await this.engine.tokenize(params.text);
+          onComplete?.(res);
+          return res;
+        });
+        return;
+      }
       default: {
         if (msg.kind && msg.content) {
           onError?.();
@@ -633,6 +653,30 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
     return await this.getPromise(msg);
   }
 
+  async tokenize(text: string, modelId?: string) {
+    const msg: WorkerRequest = {
+      kind: "tokenize",
+      uuid: crypto.randomUUID(),
+      content: {
+        text,
+        modelId: modelId,
+      },
+    };
+    return await this.getPromise<Int32Array>(msg);
+  }
+
+  async decodeTokens(ids: Int32Array, modelId?: string) {
+    const msg: WorkerRequest = {
+      kind: "decodeTokens",
+      uuid: crypto.randomUUID(),
+      content: {
+        inputIds: ids,
+        modelId: modelId,
+      },
+    };
+    return await this.getPromise<string>(msg);
+  }
+
   /**
    * Every time the generator is called, we post a message to the worker asking it to
    * decode one step, and we expect to receive a message of `ChatCompletionChunk` from
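
For reference, a minimal usage sketch of the new tokenizer surface. It is not part of the diff above; it assumes the existing CreateMLCEngine factory from "@mlc-ai/web-llm" and a model id that is present in the prebuilt app config (the id used here is only an example):

import { CreateMLCEngine } from "@mlc-ai/web-llm";

async function demo() {
  // Example model id; any id from the prebuilt app config works.
  const engine = await CreateMLCEngine("Llama-3.1-8B-Instruct-q4f32_1-MLC");

  // Encode text with the tokenizer that reload() already loaded for the model.
  const ids = await engine.tokenize("Hello, WebLLM!");

  // Decode the token ids back into text via the same tokenizer.
  const text = await engine.decodeTokens(ids);
  console.log(ids.length, text);
}

The same calls go through WebWorkerMLCEngine unchanged: they are forwarded to the worker via the new "tokenize" and "decodeTokens" message kinds, so clients never load a second copy of the tokenizer.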