Expose the tokenizer to clients #622

Open · wants to merge 1 commit into base: main
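This PR adds `tokenize()` and `decodeTokens()` to `MLCEngineInterface` so clients can reuse the tokenizer the engine has already loaded instead of fetching and initializing a second copy. A minimal usage sketch of the new surface (the model id below is illustrative; any model in your `AppConfig` works):

```ts
import { CreateMLCEngine } from "@mlc-ai/web-llm";

// reload() keeps the tokenizer on the engine, so no second download is needed.
const engine = await CreateMLCEngine("Llama-3.1-8B-Instruct-q4f32_1-MLC");

// Encode a prompt into token ids, e.g. to count tokens before sending a request.
const ids: Int32Array = await engine.tokenize("Hello, world!");
console.log(`prompt length: ${ids.length} tokens`);

// Decode token ids back into text.
const text: string = await engine.decodeTokens(ids);
console.log(text);
```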
20 changes: 17 additions & 3 deletions src/engine.ts
@@ -70,6 +70,7 @@ import {
} from "./error";
import { asyncLoadTokenizer } from "./cache_util";
import { EmbeddingPipeline } from "./embedding";
import { Tokenizer } from "@mlc-ai/web-tokenizers";

/**
* Creates `MLCEngine`, and loads `modelId` onto WebGPU.
@@ -131,6 +132,7 @@ export class MLCEngine implements MLCEngineInterface {
private logitProcessorRegistry?: Map<string, LogitProcessor>;
private initProgressCallback?: InitProgressCallback;
private appConfig: AppConfig;
private tokenizer: Tokenizer | null = null;

// Signals and flags
private interruptSignal = false;
@@ -359,7 +361,7 @@
});
tvm.initWebGPU(gpuDetectOutput.device);

const tokenizer = await asyncLoadTokenizer(
this.tokenizer = await asyncLoadTokenizer(
modelUrl,
curModelConfig,
this.appConfig,
@@ -379,11 +381,11 @@
// embedding model, and prompt user to use ModelRecord.model_type
let newPipeline: LLMChatPipeline | EmbeddingPipeline;
if (modelRecord.model_type === ModelType.embedding) {
newPipeline = new EmbeddingPipeline(tvm, tokenizer, curModelConfig);
newPipeline = new EmbeddingPipeline(tvm, this.tokenizer, curModelConfig);
} else {
newPipeline = new LLMChatPipeline(
tvm,
tokenizer,
this.tokenizer,
curModelConfig,
logitProcessor,
);
@@ -1333,4 +1335,16 @@
async decode(pipeline: LLMChatPipeline, genConfig?: GenerationConfig) {
return pipeline.decodeStep(genConfig);
}

//-----------------------------------------------
// 8. Expose tokenizer
//-----------------------------------------------

/** Encodes text into token ids with the tokenizer loaded during reload(). */
async tokenize(text: string) {
if (this.tokenizer === null) throw new Error("Tokenizer not loaded; call reload() first.");
return this.tokenizer.encode(text);
}

/** Decodes token ids back into text with the loaded tokenizer. */
async decodeTokens(ids: Int32Array) {
if (this.tokenizer === null) throw new Error("Tokenizer not loaded; call reload() first.");
return this.tokenizer.decode(ids);
}
}
11 changes: 10 additions & 1 deletion src/message.ts
@@ -34,7 +34,9 @@ type RequestKind =
| "customRequest"
| "keepAlive"
| "setLogLevel"
| "setAppConfig";
| "setAppConfig"
| "tokenize"
| "decodeTokens";

// eslint-disable-next-line @typescript-eslint/no-unused-vars
type ResponseKind = "return" | "throw" | "initProgressCallback";
@@ -58,6 +60,12 @@ export interface ForwardTokensAndSampleParams {
isPrefill: boolean;
modelId?: string;
}
export interface TokenizeParams {
text: string;
modelId?: string;
}
export interface DecodeTokensParams {
inputIds: Int32Array;
modelId?: string;
}

// Notes on the following Params with modelId and chatOpts:
// These fields are the model and chatOpts that the frontend engine expects the backend
@@ -128,6 +136,7 @@ export type MessageContent =
| CreateEmbeddingResponse
| Completion
| AppConfig
| Int32Array
| void;
/**
* The message used in exchange between worker
6 changes: 6 additions & 0 deletions src/types.ts
@@ -172,6 +172,12 @@ export interface MLCEngineInterface {
*/
embedding(request: EmbeddingCreateParams): Promise<CreateEmbeddingResponse>;

/**
* Exposes the tokenizer of the loaded model so clients can encode and decode
* text without having to load a second copy of the tokenizer themselves.
*/
tokenize(input: string): Promise<Int32Array>;
decodeTokens(input: Int32Array): Promise<string>;

/**
* @returns A text summarizing the runtime stats.
* @param modelId Only required when multiple models are loaded.
44 changes: 44 additions & 0 deletions src/web_worker.ts
@@ -26,6 +26,8 @@ import {
MessageContent,
ReloadParams,
ForwardTokensAndSampleParams,
TokenizeParams,
DecodeTokensParams,
ChatCompletionNonStreamingParams,
ChatCompletionStreamInitParams,
ResetChatParams,
@@ -345,6 +347,24 @@ export class WebWorkerMLCEngineHandler {
onComplete?.(null);
return;
}
case "decodeTokens": {
this.handleTask(msg.uuid, async () => {
const params = msg.content as DecodeTokensParams;
const res = await this.engine.decodeTokens(params.inputIds);
onComplete?.(res);
return res;
});
return;
}
case "tokenize": {
this.handleTask(msg.uuid, async () => {
const params = msg.content as TokenizeParams;
const res = await this.engine.tokenize(params.text);
onComplete?.(res);
return res;
});
return;
}
default: {
if (msg.kind && msg.content) {
onError?.();
@@ -633,6 +653,30 @@ export class WebWorkerMLCEngine implements MLCEngineInterface {
return await this.getPromise<number>(msg);
}

async tokenize(text: string, modelId?: string) {
const msg: WorkerRequest = {
kind: "tokenize",
uuid: crypto.randomUUID(),
content: {
text,
modelId: modelId,
},
};
return await this.getPromise<Int32Array>(msg);
}

async decodeTokens(ids: Int32Array, modelId?: string) {
const msg: WorkerRequest = {
kind: "decodeTokens",
uuid: crypto.randomUUID(),
content: {
// Int32Array survives structured cloning, so pass it through unchanged
// to match DecodeTokensParams.inputIds.
inputIds: ids,
modelId: modelId,
},
};
return await this.getPromise<string>(msg);
}

/**
* Every time the generator is called, we post a message to the worker asking it to
* decode one step, and we expect to receive a message of `ChatCompletionChunk` from
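The same two calls work across the worker boundary: `WebWorkerMLCEngine.tokenize()` and `decodeTokens()` post `"tokenize"` / `"decodeTokens"` messages, which the handler cases above forward to the backend engine. A sketch following the existing worker example (worker path and model id are illustrative):

```ts
// worker.ts: standard web-llm worker setup, unchanged by this PR.
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";
const handler = new WebWorkerMLCEngineHandler();
onmessage = (msg: MessageEvent) => handler.onmessage(msg);
```

```ts
// main.ts: tokenizer calls from the main thread are served by the worker.
import { CreateWebWorkerMLCEngine } from "@mlc-ai/web-llm";

const engine = await CreateWebWorkerMLCEngine(
  new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
  "Llama-3.1-8B-Instruct-q4f32_1-MLC",
);

const ids = await engine.tokenize("Tokenized in the worker, returned to the page.");
const text = await engine.decodeTokens(ids);
console.log(ids.length, text);
```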