diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index a290eb892c..0a2c8f0b9c 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -117,6 +117,13 @@ class SequenceGenerationResponse: token_id: int +@dataclass +class EvalQueryRequest: + request_id: int + num_past_tokens: int + query_token_ids: List[int] + + def sample(logits): logits = torch.from_dlpack(logits) return torch.argmax(logits, -1).cpu().numpy() @@ -241,6 +248,72 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]: ) +def _prepare_eval_queries( + requests: List[EvalQueryRequest], + all_slot_mappings, + sliding_window, + dev, +): + seq_lens = [] + query_lens = [] + input_ids = [] + slot_mapping = [] + past_slot_mapping = [] + positions = [] + permute_map = [] + + query_offset = sum([request.num_past_tokens for request in requests]) + past_offset = 0 + + for request in requests: + num_past_tokens = request.num_past_tokens + num_queries = len(request.query_token_ids) + query_lens.append(num_queries) + request_id = request.request_id + input_ids += request.query_token_ids + + positions += [num_past_tokens + i for i in range(num_queries)] + + if sliding_window and num_past_tokens + num_queries >= sliding_window: + seq_lens.append(sliding_window) + past_slot_mapping += all_slot_mappings[request_id][ + num_past_tokens - (sliding_window - num_queries) : num_past_tokens + ] + else: + seq_lens.append(num_past_tokens + num_queries) + past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] + + slot_mapping += all_slot_mappings[request_id][ + num_past_tokens : num_past_tokens + num_queries + ] + + permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list( + range(query_offset, query_offset + num_queries) + ) + + query_offset += num_queries + past_offset += num_past_tokens + + input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev) + positions = tvm.nd.array(np.array(positions, dtype="int32"), dev) + seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev) + slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev) + + query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev) + past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev) + permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev) + + return ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) + + class Model: def __init__( self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window @@ -443,6 +516,59 @@ def run(args): for p, g in zip(prompts, generated): print("Prompt = '{}', generated text = '{}'".format(p, g)) + query_token_lens = [4, 3, 5, 2] + + eval_query_requests = [] + + for request_id, query_token_len in zip(request_ids, query_token_lens): + queries_to_eval = requests[request_id].token_ids[-query_token_len:] + num_past = len(requests[request_id].token_ids) - query_token_len + eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval)) + + ( + input_ids, + positions, + seq_lens, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + ) = _prepare_eval_queries( + eval_query_requests, + cache.slot_mappings, + None, + model.dev, + ) + + logits = model.mod["evaluate_multi_query"]( + input_ids, + positions, + seq_lens, + cache.cache, + slot_mapping, + query_lens, + past_slot_mapping, + permute_map, + model.params, + )[0].numpy() + + assert logits.shape[0] == sum(query_token_lens) + + logits_offset = 0 + + for request_id, query_token_len in zip(request_ids, query_token_lens): + for i in range(query_token_len - 1): + # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens. + # Doing argmax over multi-timestep logits computed in parallel should yield the same + # tokens at the corresponding positions. + past_tokens = requests[request_id].token_ids[:-query_token_len] + assert ( + np.argmax(logits[logits_offset + i]) + == requests[request_id].token_ids[len(past_tokens) + i + 1] + ) + + logits_offset += query_token_len + if __name__ == "__main__": run(parse_args()) diff --git a/mlc_llm/core.py b/mlc_llm/core.py index 7a69562696..f7afbbb693 100644 --- a/mlc_llm/core.py +++ b/mlc_llm/core.py @@ -593,6 +593,7 @@ def mod_transform_before_build( # This is equivalent to prefill but without KV cache. It is used for # determining the number of paged cache blocks that can be allocated. model_names.append("evaluate") + model_names.append("evaluate_multi_query") if args.sep_embed: model_names = ["embed", "prefill_with_embed"] + model_names[1:] diff --git a/mlc_llm/relax_model/llama_batched_vllm.py b/mlc_llm/relax_model/llama_batched_vllm.py index 67e0e12f90..0d1ad13bcd 100644 --- a/mlc_llm/relax_model/llama_batched_vllm.py +++ b/mlc_llm/relax_model/llama_batched_vllm.py @@ -1,9 +1,11 @@ from typing import Optional, Tuple +from dataclasses import dataclass + import numpy as np import tvm from tvm import relax, te -from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, repeat, take +from tvm.relax.op import ccl, reshape, expand_dims, concat, zeros, take, concat from tvm.relax.op.nn import attention_var_len from tvm.relax.testing import nn from tvm.ir import VDevice @@ -45,6 +47,17 @@ def rotary_compute(*idx): return q_embed, k_embed +@dataclass +class EvaluateMultiQueryInput: + query_start: relax.Expr # (num_query_token + 1,) + max_query_len: relax.Expr # (), must be on CPU + # The followings are only needed for our naive implementation of multi-query eval + # with paged KV cache. They can be replaced with block_tables when a proper attention + # kernel becomes available. + past_slot_mapping: relax.Expr # (num_past_token,) + permute_indices_after_concat: relax.Expr # (num_past_token + num_query_token,) + + class LlamaAttentionBatched(LlamaAttentionBase): def __init__(self, config: LlamaConfig): super().__init__(config) @@ -58,24 +71,25 @@ def __init__(self, config: LlamaConfig): def forward( self, - hidden_states: relax.Expr, # (num_token, hidden_size) - positions: relax.Expr, # (num_token,), for batched RoPE + hidden_states: relax.Expr, # (num_query_token, hidden_size) + positions: relax.Expr, # (num_query_token,), for batched RoPE seq_lens: relax.Expr, # (num_seq,) kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], - slot_mapping: Optional[relax.Expr], # (num_token,) + slot_mapping: Optional[relax.Expr], # (num_query_token,) max_seqlen: Optional[relax.Expr], # (), must be on CPU - seqstart: Optional[relax.Expr], # (num_seq + 1,), for prefill + seq_start: Optional[relax.Expr], # (num_seq + 1,), for prefill block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode indices_within_window: Optional[ relax.Expr - ], # (num_cached_total,), for prefill with sliding-window attention + ], # (num_cached_total,), for prefill with sliding-window attention, + eval_multi_input: Optional[EvaluateMultiQueryInput], ): - num_tokens, _ = hidden_states.struct_info.shape + num_query_tokens, _ = hidden_states.struct_info.shape queries, keys, values = self.project_qkv( hidden_states, - (num_tokens, self.num_query_heads, self.head_dim), - (num_tokens, self.num_key_value_heads, self.head_dim), + (num_query_tokens, self.num_query_heads, self.head_dim), + (num_query_tokens, self.num_key_value_heads, self.head_dim), ) queries, keys = apply_rotary_pos_emb(queries, keys, positions, self.position_embedding_base) @@ -84,15 +98,15 @@ def forward( # Paged KV cache update k_cache, v_cache = kv_cache - if self.sliding_window is None or block_tables: - # For decode or prefill without sliding window, cache all keys / values. - keys_to_cache = keys - values_to_cache = values - else: + if indices_within_window: # Cache only the most recent keys and values within the window. keys_to_cache = nn.emit(take(keys, indices_within_window, axis=0)) values_to_cache = nn.emit(take(values, indices_within_window, axis=0)) slot_mapping = nn.emit(take(slot_mapping, indices_within_window, axis=0)) + else: + # For decode or prefill without sliding window, cache all keys / values. + keys_to_cache = keys + values_to_cache = values # kv caches are updated inplace, but make it look like a pure operation kv = nn.emit( @@ -111,15 +125,65 @@ def forward( else: k_cache = v_cache = None - if seqstart: - # Prefill, batched attention over variable sequence lengths + if eval_multi_input: + assert k_cache and v_cache + num_kv_head = v_cache.struct_info.shape[1] + head_size = v_cache.struct_info.shape[2] + num_past_token = eval_multi_input.past_slot_mapping.struct_info.shape[0] + kv_shape = (num_past_token, num_kv_head, head_size) + kv_sinfo = relax.TensorStructInfo(kv_shape, k_cache.struct_info.dtype) + + kv_tensors = nn.emit( + relax.op.call_pure_packed( + "tvm.contrib.vllm.reconstruct_from_cache", + k_cache, + v_cache, + eval_multi_input.past_slot_mapping, + sinfo_args=[kv_sinfo, kv_sinfo], + ) + ) + keys_past, values_past = kv_tensors[0], kv_tensors[1] + # Say we have past tokens [P1, P2, P3] and the current ones [C1, C2, C3]. + # Each of P1, C1 etc is a sequence of tokens. + # After concat, we have [P1, P2, P3, C1, C2, C3], but batched sequences need to + # be in the format [P1, C1, P2, C2, P3, C3]. This permutation is done by the take + # op and the provided permutation indices. + keys = nn.emit( + take( + concat([keys_past, keys]), eval_multi_input.permute_indices_after_concat, axis=0 + ) + ) + values = nn.emit( + take( + concat([values_past, values]), + eval_multi_input.permute_indices_after_concat, + axis=0, + ) + ) + seq_start_q = eval_multi_input.query_start + max_seqlen_q = eval_multi_input.max_query_len + seq_start_k = seq_start + max_seqlen_k = max_seqlen + elif seq_start: + # prefill + seq_start_q = seq_start_k = seq_start + max_seqlen_q = max_seqlen_k = max_seqlen + else: + # decode + seq_start_q = seq_start_k = None + max_seqlen_q = max_seqlen_k = None + + if seq_start_q: + # Prefill or multi-query evaluation, batched attention over variable sequence lengths attn_output = nn.emit( attention_var_len( nn.emit(expand_dims(queries, axis=0)), nn.emit(expand_dims(keys, axis=0)), nn.emit(expand_dims(values, axis=0)), - seqstart_q=seqstart, - max_seqlen_q=max_seqlen, + seq_start_q, + max_seqlen_q, + seq_start_k, + max_seqlen_k, causal_mask="BottomRight", window_size=self.sliding_window, ) @@ -128,14 +192,14 @@ def forward( # Decode, using vLLM kernel exp_sums = nn.emit( relax.op.builtin.alloc_tensor( - relax.ShapeExpr((num_tokens, self.num_query_heads, self.max_num_partitions)), + relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)), dtype="float32", runtime_device_index=0, ) ) max_logits = nn.emit( relax.op.builtin.alloc_tensor( - relax.ShapeExpr((num_tokens, self.num_query_heads, self.max_num_partitions)), + relax.ShapeExpr((num_query_tokens, self.num_query_heads, self.max_num_partitions)), dtype="float32", runtime_device_index=0, ) @@ -143,7 +207,7 @@ def forward( tmp_out = nn.emit( relax.op.builtin.alloc_tensor( relax.ShapeExpr( - (num_tokens, self.num_query_heads, self.max_num_partitions, self.head_dim) + (num_query_tokens, self.num_query_heads, self.max_num_partitions, self.head_dim) ), dtype=queries.struct_info.dtype, runtime_device_index=0, @@ -169,7 +233,7 @@ def forward( ) attn_output = nn.emit( - reshape(attn_output, (num_tokens, self.num_query_heads * self.head_dim)) + reshape(attn_output, (num_query_tokens, self.num_query_heads * self.head_dim)) ) attn_output = self.o_proj(attn_output) @@ -189,9 +253,10 @@ def forward( kv_cache: Optional[Tuple[relax.Expr, relax.Expr]], slot_mapping: Optional[relax.Expr], max_seqlen: Optional[relax.Expr], - seqstart: Optional[relax.Expr], + seq_start: Optional[relax.Expr], block_tables: Optional[relax.Expr], indices_within_window: Optional[relax.Expr], + eval_multi_input: Optional[EvaluateMultiQueryInput], ) -> Tuple[relax.Expr, Optional[Tuple[relax.Expr, relax.Expr]]]: residual = hidden_states @@ -205,9 +270,10 @@ def forward( kv_cache=kv_cache, slot_mapping=slot_mapping, max_seqlen=max_seqlen, - seqstart=seqstart, + seq_start=seq_start, block_tables=block_tables, indices_within_window=indices_within_window, + eval_multi_input=eval_multi_input, ) hidden_states = self.post_self_attn(hidden_states, residual) @@ -215,12 +281,22 @@ def forward( return hidden_states, new_kv +def create_seq_start(seq_lens): + # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust + cumsum = nn.emit( + relax.op.call_dps_packed( + "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + ) + ) + return nn.emit(concat([zeros((1,), "int32"), cumsum])) + + class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, cpu_device: VDevice, - vocab_size_var: tvm.tir.SizeVar, + vocab_size_var: tvm.tir.Var, sep_embed: bool = False, ): self.padding_idx = config.pad_token_id @@ -247,9 +323,12 @@ def forward( seq_lens: relax.Expr, kv_caches: Optional[relax.Expr], slot_mapping: Optional[relax.Expr], - seqstart: Optional[relax.Expr], + seq_start: Optional[relax.Expr], block_tables: Optional[relax.Expr], indices_within_window: Optional[relax.Expr], + query_lens: Optional[relax.Expr], + past_slot_mapping: Optional[relax.Expr], + permute_indices_after_concat: Optional[relax.Expr], ): if self.embed_tokens: inputs_embeds = self.embed_tokens(inputs) @@ -265,6 +344,15 @@ def forward( new_kvs = () + if query_lens: + max_query_len = R.to_vdevice(R.max(query_lens), self.cpu_device) + query_start = create_seq_start(query_lens) + eval_multi_input = EvaluateMultiQueryInput( + query_start, max_query_len, past_slot_mapping, permute_indices_after_concat + ) + else: + eval_multi_input = None + for idx, decoder_layer in enumerate(self.layers): if kv_caches: cache = (kv_caches[2 * idx], kv_caches[2 * idx + 1]) @@ -278,9 +366,10 @@ def forward( cache, slot_mapping, max_seqlen, - seqstart, + seq_start, block_tables, indices_within_window, + eval_multi_input, ) new_kvs += new_kv @@ -312,17 +401,18 @@ def __init__( def forward( self, - input_ids: relax.Expr, # (num_token,) - positions: relax.Expr, # (num_token,), for batched RoPE + input_ids: relax.Expr, # (num_query_token,) + positions: relax.Expr, # (num_query_token,), for batched RoPE seq_lens: relax.Expr, # (num_seq,) kv_caches: Optional[relax.Expr], # For prefill and decode, not needed for evaluate - slot_mapping: Optional[ - relax.Expr - ], # (num_token,), for prefill and decode, not needed for evaluate + slot_mapping: Optional[relax.Expr], # (num_query_token,), Not needed for evaluate block_tables: Optional[relax.Expr], # (num_seq, max_num_blocks_per_seq), for decode indices_within_window: Optional[ relax.Expr ], # (num_cached_total,), for prefill with sliding-window attention + query_lens: Optional[relax.Expr], + past_slot_mapping: Optional[relax.Expr], + permute_indices_after_concat: Optional[relax.Expr], ): """ In vLLM, the paged KV cache is simply a pair of tensors, one for keys and the other @@ -338,7 +428,7 @@ def forward( So the length of a block table for each sequence is at most ceil(window_size / block_size). With sliding window, not all past K / V values need to be cached during prefill. - The last input, indices_within_window, tells which tokens among (num_token,) need to have + The last input, indices_within_window, tells which tokens among (num_query_token,) need to have their K / V values cached. """ if self.num_shards > 1: @@ -355,18 +445,21 @@ def forward( if indices_within_window: indices_within_window = nn.emit(ccl.broadcast_from_worker0(indices_within_window)) - is_prompt = block_tables is None - - if is_prompt: # prefill and evaluate - # https://github.com/apache/tvm/issues/15851 for why we need to use Thrust - cumsum = nn.emit( - relax.op.call_dps_packed( - "tvm.contrib.thrust.sum_scan", seq_lens, out_sinfo=seq_lens.struct_info + if query_lens: + query_lens = nn.emit(ccl.broadcast_from_worker0(query_lens)) + past_slot_mapping = nn.emit(ccl.broadcast_from_worker0(past_slot_mapping)) + permute_indices_after_concat = nn.emit( + ccl.broadcast_from_worker0(permute_indices_after_concat) ) - ) - seqstart = nn.emit(concat([zeros((1,), "int32"), cumsum])) + + # TODO: Update this condition for evaluate multi + is_prompt = block_tables is None and query_lens is None + is_eval_multi = query_lens is not None + + if is_prompt or is_eval_multi: # prefill and evaluate + seq_start = create_seq_start(seq_lens) else: - seqstart = None + seq_start = None hidden_states, new_kvs = self.model( input_ids, @@ -374,18 +467,21 @@ def forward( seq_lens, kv_caches, slot_mapping, - seqstart, + seq_start, block_tables, indices_within_window, + query_lens, + past_slot_mapping, + permute_indices_after_concat, ) if is_prompt: # Extract logits for the last token in each sequence - def get_logits_last_tokens(x, seq_len_tensor, seqstart): + def get_logits_last_tokens(x, seq_len_tensor, seq_start): return te.compute( shape=(seq_len_tensor.shape[0], x.shape[-1]), - fcompute=lambda i, j: x[seqstart[i] + seq_len_tensor[i] - 1, j], + fcompute=lambda i, j: x[seq_start[i] + seq_len_tensor[i] - 1, j], name="get_logits_last_tokens", ) @@ -394,7 +490,7 @@ def get_logits_last_tokens(x, seq_len_tensor, seqstart): get_logits_last_tokens, hidden_states, seq_lens, - seqstart, + seq_start, primfunc_name_hint="get_logits_last_tokens", ) ) @@ -408,21 +504,21 @@ def get_logits_last_tokens(x, seq_len_tensor, seqstart): def get_inputs( - num_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True + num_query_token, num_seq, config, max_num_blocks_per_seq=None, sep_embed=False, need_cache=True ): hidden_size = config.hidden_size inputs = ( - nn.Placeholder((num_token, hidden_size), dtype=config.dtype, name="inputs_embeds") + nn.Placeholder((num_query_token, hidden_size), dtype=config.dtype, name="inputs_embeds") if sep_embed - else nn.Placeholder((num_token,), dtype="int32", name="input_ids") + else nn.Placeholder((num_query_token,), dtype="int32", name="input_ids") ) seq_lens = nn.Placeholder((num_seq,), dtype="int32", name="seq_lens") - positions = nn.Placeholder((num_token,), dtype="int32", name="positions") + positions = nn.Placeholder((num_query_token,), dtype="int32", name="positions") if need_cache: - num_blocks = tvm.tir.SizeVar("num_blocks", "int64") + num_blocks = tvm.tir.Var("num_blocks", "int64") block_size = 16 vec_size = 8 # 128 bit, fp16 x 8 @@ -448,7 +544,7 @@ def get_inputs( [get_cache_sinfo(i) for i in range(config.num_hidden_layers * 2)] ), ) - slot_mapping = nn.Placeholder((num_token,), dtype="int32", name="slot_mapping") + slot_mapping = nn.Placeholder((num_query_token,), dtype="int32", name="slot_mapping") else: past_key_values = None slot_mapping = None @@ -475,15 +571,15 @@ def create_evaluate_func( """Evaluate logits for the last token in each sequence. Same as prefill but without KV cache.""" func_name = "evaluate" - num_token = tvm.tir.SizeVar("num_token", "int64") + num_query_token = tvm.tir.SizeVar("num_query_token", "int64") num_seq = tvm.tir.SizeVar("num_seq", "int64") with bb.function(func_name): - model = LlamaForCausalLM(config, cpu_dev, tvm.tir.SizeVar("vocab_size", "int64"), sep_embed) + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), sep_embed) param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) inputs, positions, seq_lens, _, _, _ = get_inputs( - num_token, num_seq, config, sep_embed=sep_embed + num_query_token, num_seq, config, sep_embed=sep_embed ) with bb.dataflow(): @@ -495,6 +591,9 @@ def create_evaluate_func( slot_mapping=None, block_tables=None, indices_within_window=None, + query_lens=None, + past_slot_mapping=None, + permute_indices_after_concat=None, ) params = [ inputs, @@ -524,7 +623,7 @@ def create_encoding_func( """ func_name = "prefill_with_embed" if sep_embed else "prefill" - num_token = tvm.tir.SizeVar("num_token", "int64") + num_query_token = tvm.tir.SizeVar("num_query_token", "int64") num_seq = tvm.tir.SizeVar("num_seq", "int64") num_inputs = 5 @@ -534,7 +633,7 @@ def create_encoding_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( - num_token, num_seq, config, sep_embed=sep_embed + num_query_token, num_seq, config, sep_embed=sep_embed ) with bb.dataflow(): @@ -558,9 +657,9 @@ def create_encoding_func( if config.sliding_window: num_inputs += 1 # The value of num_cached_total is between - # num_token (if seq_len < sliding_window for all seq) and + # num_query_token (if seq_len < sliding_window for all seq) and # num_seq * config.sliding_window (if seq_len > sliding_window for all seq) - num_cached_total = tvm.tir.SizeVar("num_cached_total", "int64") + num_cached_total = tvm.tir.Var("num_cached_total", "int64") indices_within_window = nn.Placeholder( (num_cached_total,), dtype="int32", name="indices_within_window" ) @@ -569,6 +668,8 @@ def create_encoding_func( else: inputs.append(None) + inputs += [None, None, None] + logits, new_kvs = model(*inputs) gv = bb.emit_output((logits, relax.Tuple(new_kvs))) @@ -602,7 +703,16 @@ def create_decoding_func( param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) logits, new_kvs = model( - inputs, positions, seq_lens, past_key_values, slot_mapping, block_tables, None + inputs, + positions, + seq_lens, + past_key_values, + slot_mapping, + block_tables, + None, + None, + None, + None, ) params = [ inputs, @@ -620,6 +730,72 @@ def create_decoding_func( bb.update_func(gv, mod[gv].with_attr("num_input", 6)) +def create_evaluate_multi_query_func( + bb: relax.BlockBuilder, + param_manager: ParamManager, + config: LlamaConfig, + cpu_dev: VDevice, + quant_scheme: QuantizationScheme, +) -> None: + func_name = "evaluate_multi_query" + + num_query_token = tvm.tir.SizeVar("num_query_token", "int64") + num_past_token = tvm.tir.SizeVar("num_past_token", "int64") + num_seq = tvm.tir.SizeVar("num_seq", "int64") + seq_lens_sum = tvm.tir.SizeVar("seq_lens_sum", "int64") + + num_inputs = 8 + + with bb.function(func_name): + model = LlamaForCausalLM(config, cpu_dev, tvm.tir.Var("vocab_size", "int64"), False) + param_manager.register_params(model, func_name, quant_scheme, get_param_quant_kind) + + input_ids, positions, seq_lens, past_key_values, slot_mapping, _ = get_inputs( + num_query_token, num_seq, config, sep_embed=False + ) + + query_lens = nn.Placeholder((num_seq,), dtype="int32", name="query_lens") + + # Replace them with block_tables when a proper attention kernel becomes available. + past_slot_mapping = nn.Placeholder( + (num_past_token,), dtype="int32", name="past_slot_mapping" + ) + permute_indices_after_concat = nn.Placeholder( + (seq_lens_sum,), dtype="int32", name="permute_indices_after_concat" + ) + + with bb.dataflow(): + params = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + ] + + inputs = [ + input_ids, + positions, + seq_lens, + past_key_values, + slot_mapping, + None, # block_tables + None, # indices_within_window + ] + + inputs += [query_lens, past_slot_mapping, permute_indices_after_concat] + params += [query_lens, past_slot_mapping, permute_indices_after_concat] + + logits, new_kvs = model(*inputs) + gv = bb.emit_output((logits, relax.Tuple(new_kvs))) + + bb.emit_func_output(gv, params + model.parameters()) + + mod = bb.get() + gv = mod.get_global_var(func_name) + bb.update_func(gv, mod[gv].with_attr("num_input", num_inputs)) + + def get_model(args, hf_config): dtype = args.quantization.model_dtype sep_embed = False @@ -685,6 +861,7 @@ def get_model(args, hf_config): create_evaluate_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_encoding_func(bb, param_manager, config, cpu_dev, args.quantization, sep_embed) create_decoding_func(bb, param_manager, config, cpu_dev, args.quantization) + create_evaluate_multi_query_func(bb, param_manager, config, cpu_dev, args.quantization) mod = bb.get()