[LLMChat] Make llm_chat compatible with PagedKVCache (#293)
PagedKVCache was introduced in MLC-LLM a while back to unify the
interface for KVCache. This PR makes WebLLM compatible with the new
PagedKVCache interface, encapsulating it so that WebLLM users will not
notice any difference.
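
A goal of the encapsulation is that the public surface stays the same. A rough sketch of the unchanged user experience (hedged: only `ChatModule` appears in this diff; `reload`/`generate` are assumed to be WebLLM's existing entry points, and the model id is just an example):

```ts
import { ChatModule } from "@mlc-ai/web-llm";

// Whether the pipeline underneath uses the legacy KVCache or the new
// PagedKVCache is detected internally; the caller is unaffected.
const chat = new ChatModule();
await chat.reload("Llama-2-7b-chat-hf-q4f32_1"); // example model id
console.log(await chat.generate("What is the capital of Canada?"));
```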

This PR is equivalent to the changes to `llm_chat.cc` in
mlc-ai/mlc-llm#1651, and should address issues
like mlc-ai/mlc-llm#1628.
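
The crux of those changes is the new forward path: with PagedKVCache, `prefill`/`decode` consume embeddings instead of token ids, and each step is bracketed by begin/end-forward calls on the cache. A condensed sketch of the prefill case (not standalone code; the names are the `LLMChatPipeline` members from the diff below):

```ts
// One sequence (id 0), seqLen input tokens; tvm, kvCache, params, embed,
// prefill, and the fKVCache* packed functions are class members in the diff.
const seqIds = tvm.makeShapeTuple([0]);
const inputLen = tvm.makeShapeTuple([seqLen]);
fKVCacheBeginForward(kvCache, seqIds, inputLen); // open the forward step
const embedding = embed(inputs, params);         // token ids -> embeddings
const logits = prefill(embedding, kvCache, params);
fKVCacheEndForward(kvCache);                     // close the forward step
```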

There are still outstanding model compilation issues regarding
`workgroup_size` (since WebGPU, unlike most other backends, supports at
most 256 threads per workgroup). We will address this issue more
elegantly soon; for now, compiling Llama-based models requires manually
changing kernel sizes as shown in [this
branch](https://github.com/CharlieFRuan/mlc-llm/tree/local-workgroupSize-webLLM-kvCache).
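
For reference, the 256 figure is WebGPU's default `maxComputeInvocationsPerWorkgroup` limit; a standalone sketch for inspecting it in the browser (assumes WebGPU typings are available, e.g. via `@webgpu/types`):

```ts
// Query the WebGPU limits that constrain compiled kernel workgroup sizes.
const adapter = await navigator.gpu.requestAdapter();
if (adapter === null) throw Error("WebGPU is not available in this browser.");
const device = await adapter.requestDevice();
// Defaults to 256; kernels declaring larger workgroups fail validation.
console.log(device.limits.maxComputeInvocationsPerWorkgroup);
console.log(device.limits.maxComputeWorkgroupSizeX); // per-dimension cap
```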

This PR is also largely dependent on
apache/tvm#16554.
CharlieFRuan authored Feb 13, 2024
1 parent 3319d1c commit ec2662f
Showing 2 changed files with 116 additions and 32 deletions.
1 change: 0 additions & 1 deletion src/chat_module.ts
@@ -139,7 +139,6 @@ export class ChatModule implements ChatInterface {

this.pipeline = new LLMChatPipeline(tvm, tokenizer, config, this.logitProcessor);
await this.pipeline?.asyncLoadWebGPUPipelines();

const tend = performance.now();

if (this.initProgressCallback !== undefined) {
147 changes: 116 additions & 31 deletions src/llm_chat.ts
@@ -1,3 +1,5 @@
/* eslint-disable @typescript-eslint/no-non-null-assertion */
/* eslint-disable no-prototype-builtins */
import * as tvmjs from "tvmjs";
import { Tokenizer } from "@mlc-ai/web-tokenizers";
import { ChatConfig } from "./config";
@@ -16,6 +18,12 @@ export class LLMChatPipeline {
private prefill: tvmjs.PackedFunc;
private decoding: tvmjs.PackedFunc;
private fclearKVCaches: tvmjs.PackedFunc;
// Functions for PagedKVCache only
private embed?: tvmjs.PackedFunc = undefined;
private fKVCacheAddSequence?: tvmjs.PackedFunc = undefined;
private fKVCacheRemoveSequence?: tvmjs.PackedFunc = undefined;
private fKVCacheBeginForward?: tvmjs.PackedFunc = undefined;
private fKVCacheEndForward?: tvmjs.PackedFunc = undefined;

// parameter states
private params: tvmjs.TVMObject;
@@ -41,10 +49,11 @@ export class LLMChatPipeline {
private appearedTokens = new Set<number>();
private conversation: Conversation;
// Whether sink is in action
private sinkTriggered: boolean = false;
private sinkTriggered = false;
// sliding window cache offset (next position to be overwritten in the rolling KV cache)
private slidingWindowCacheOffset: number = 0;
// Total amount of seq len prefilled so far
private slidingWindowCacheOffset = 0;
// Whether we are using PagedKVCache (eventually this will become default)
private usePagedKVCache = false;

// stats
private decodingTotalTime = 0;
@@ -59,6 +68,7 @@ export class LLMChatPipeline {
private logitProcessor?: LogitProcessor = undefined;

constructor(tvm: tvmjs.Instance, tokenizer: Tokenizer, config: ChatConfig, logitProcessor?: LogitProcessor) {
// 0. Setting attributes
this.tvm = tvm;
this.tokenizer = tokenizer;
this.config = config;
@@ -69,19 +79,28 @@ export class LLMChatPipeline {
this.stopTokens = this.conversation.getStopTokens();

this.device = this.tvm.webgpu();
tvm.beginScope();

// 1. Create VM and get the core functions
tvm.beginScope();
this.vm = this.tvm.detachFromCurrentScope(
this.tvm.createVirtualMachine(this.device)
);
this.prefill = this.tvm.detachFromCurrentScope(
this.vm.getFunction("prefill")
);
try {
// We expect to find `embed` if we usePagedKVCache
this.embed = this.tvm.detachFromCurrentScope(
this.vm.getFunction("embed")
);
} catch {
// Do nothing
}
this.decoding = this.tvm.detachFromCurrentScope(
this.vm.getFunction("decode")
);

// Get json stored in the vm's metadata function
// 2. Get json stored in the vm's metadata function
let fgetMetadata;
let useSLIM = false; // SLIM is the new workflow
try {
@@ -94,7 +113,7 @@ export class LLMChatPipeline {
const metadataStr = this.tvm.detachFromCurrentScope(ret_value).toString();
const metadata = JSON.parse(metadataStr);

// Load parameters
// 3. Load parameters
if (useSLIM) {
// Under SLIM workflow, we load parameters by name
const paramNames: string[] = [];
@@ -109,14 +128,16 @@ export class LLMChatPipeline {
);
}

if (metadata.hasOwnProperty("prefill_chunk_size") && metadata.prefill_chunk_size != -1) {
// 4. Read in compilation configurations from metadata
if (metadata.hasOwnProperty("prefill_chunk_size")) {
this.prefillChunkSize = metadata.prefill_chunk_size;
this.logger("Using prefillChunkSize: ", this.prefillChunkSize);
if (this.prefillChunkSize <= 0) {
throw Error("Prefill chunk size needs to be positive.");
}
} else {
throw Error("Cannot find `prefill_chunk_size` in metadta; make sure the wasm is up to date.");
}

// Only use one of slidingWindowSize and maxWindowLength
if (metadata.hasOwnProperty("sliding_window_size") && metadata.sliding_window_size != -1) {
this.slidingWindowSize = metadata.sliding_window_size;
@@ -149,20 +170,60 @@
}
}

// 5. Create cache
// Use `fcreateCache` to determine whether we are using the new KVCache implementation
let fcreateCache;
if (useSLIM) {
fcreateCache = this.vm.getFunction("_initialize_effect");
} else {
fcreateCache = this.vm.getFunction("create_kv_cache");
try {
if (useSLIM) {
fcreateCache = this.vm.getFunction("_initialize_effect");
} else {
fcreateCache = this.vm.getFunction("create_kv_cache");
}
} catch (err) {
// If we cannot find the function above, it means that we are using the new PagedKVCache
this.usePagedKVCache = true;
fcreateCache = this.vm.getFunction("create_tir_paged_kv_cache");
console.log("Using Paged KVCache");
if (this.embed === undefined) {
throw Error("If using paged KVCache, method `embed()` needs to be defined.");
}
}

this.fclearKVCaches = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
);
// Load cache functions and instantiate KVCache
if (this.usePagedKVCache) {
this.fclearKVCaches = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_clear")
);
this.fKVCacheAddSequence = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_add_sequence")
);
this.fKVCacheRemoveSequence = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_remove_sequence")
);
this.fKVCacheBeginForward = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_begin_forward")
);
this.fKVCacheEndForward = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.paged_attention_kv_cache_end_forward")
);

// use extern config for now
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
// Create PagedKVCache; we do not expose KVCache config for now
const defaultPageSize = 16;
const defaultMaxNumSequence = 1;
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache(
this.tvm.makeShapeTuple([defaultMaxNumSequence]), // max_num_sequence
this.tvm.makeShapeTuple([this.maxWindowLength]), // max_total_sequence_length
this.tvm.makeShapeTuple([this.prefillChunkSize]), // prefill_chunk_size
this.tvm.makeShapeTuple([defaultPageSize]), // page_size, hard coded for now
));
} else {
this.fclearKVCaches = this.tvm.detachFromCurrentScope(
this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
);
this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
}
this.filledKVCacheLength = 0;
this.resetChat(); // especially needed for PagedKVCache as we need to call fKVCacheAddSequence
tvm.endScope();
}

@@ -198,18 +259,28 @@ export class LLMChatPipeline {
/**
* Reset the chat history
*/
resetChat(keepStats: boolean = false) {
resetChat(keepStats = false) {
this.conversation.reset();
if (!keepStats) {
this.resetRuntimeStats();
}
this.fclearKVCaches(this.kvCache);
this.resetKVCache();
this.filledKVCacheLength = 0;
this.sinkTriggered = false;
this.slidingWindowCacheOffset = 0;
this.logitProcessor?.resetState();
}

/**
* Reset KV Cache
*/
resetKVCache() {
this.fclearKVCaches(this.kvCache);
if (this.usePagedKVCache) {
this.fKVCacheAddSequence!(this.kvCache, new tvmjs.Scalar(0, "int64"));
}
}

/**
* @returns Whether stop is triggered.
*/
@@ -397,16 +468,23 @@ export class LLMChatPipeline {
private forward(inputs: tvmjs.NDArray, curPos: number): tvmjs.NDArray {
this.tvm.beginScope();
let retValue;
const seqLen = inputs.shape[1]; // Num input tokens
const seqLenShape = this.tvm.makeShapeTuple([curPos]);
if (inputs.shape[1] > 1) {
if (seqLen > 1) {
// Prefill
if (this.slidingWindowSize == -1) {
retValue = this.prefill(
inputs, seqLenShape, this.kvCache, this.params
);
if (this.usePagedKVCache) {
const seqIdsTuple = this.tvm.makeShapeTuple([0]);
const inputLenShape = this.tvm.makeShapeTuple([seqLen]);
this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, inputLenShape);
const embed = this.embed!(inputs, this.params);
retValue = this.prefill(embed, this.kvCache, this.params);
this.fKVCacheEndForward!(this.kvCache);
} else {
retValue = this.prefill(inputs, seqLenShape, this.kvCache, this.params);
}
} else {
// Sliding window attention needs extra shape parameters
const seqLen = inputs.shape[1]; // Num input tokens
const cacheLen = Math.min(this.slidingWindowSize, curPos - seqLen); // Num elements in the cache
const cacheLenShape = this.tvm.makeShapeTuple([cacheLen]);
const kvSeqLenShape = this.tvm.makeShapeTuple([cacheLen + seqLen]);
@@ -419,9 +497,16 @@ export class LLMChatPipeline {
} else {
// Decode
if (this.slidingWindowSize == -1) {
retValue = this.decoding(
inputs, seqLenShape, this.kvCache, this.params
);
if (this.usePagedKVCache) {
const seqIdsTuple = this.tvm.makeShapeTuple([0]);
const appendLength = this.tvm.makeShapeTuple([1]);
this.fKVCacheBeginForward!(this.kvCache, seqIdsTuple, appendLength);
const embed = this.embed!(inputs, this.params);
retValue = this.decoding(embed, this.kvCache, this.params);
this.fKVCacheEndForward!(this.kvCache);
} else {
retValue = this.decoding(inputs, seqLenShape, this.kvCache, this.params);
}
} else {
// Same logic as above; keeping this if-else structure to match mlc-llm's llm_chat.cc
const seqLen = inputs.shape[1]; // Num input tokens
@@ -463,7 +548,7 @@
) {
// 1. Move logits to CPU
this.tvm.beginScope();
let logitsOnCPU = this.updateLogitsOnCPU(logitsOnGPU);
const logitsOnCPU = this.updateLogitsOnCPU(logitsOnGPU);
this.tvm.endScope();
await this.device.sync();

@@ -544,7 +629,7 @@ export class LLMChatPipeline {
// need shift window and re-encode
this.logger("need shift window")
this.filledKVCacheLength = 0;
this.fclearKVCaches(this.kvCache);
this.resetKVCache();

// abandon all tokens we collected
if (this.conversation.config.add_bos) {
@@ -585,7 +670,7 @@ export class LLMChatPipeline {
inputData.copyFrom(inputIds);

// 2. Forward tokens and get logits
let logitsOnGPU: tvmjs.NDArray = this.forward(inputData, curPos);
const logitsOnGPU: tvmjs.NDArray = this.forward(inputData, curPos);
const nextToken = await this.sampleTokenFromLogits(
logitsOnGPU, this.config.temperature, this.config.top_p);
this.tvm.endScope();
@@ -605,7 +690,7 @@ export class LLMChatPipeline {

async evaluate() {
// run a canonical evaluation of the flow
this.fclearKVCaches(this.kvCache);
this.resetKVCache();
this.filledKVCacheLength = 0;

const testPrompt = "The capital of Canada is";
