Add new Relax function to the batched model for evaluating query tokens over multiple time steps in parallel #156

Merged (7 commits) on Jan 13, 2024
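
For orientation, a minimal sketch (not taken from the PR itself) of what the new entry point is meant to do. It assumes the EvalQueryRequest dataclass and the evaluate_multi_query function added in this diff; the token values below are made up, and the exact call signature appears in examples/python/run_llama_batched_vllm.py further down.

# A request whose first 10 tokens are already in the paged KV cache and whose
# last 3 tokens should be scored in a single forward pass:
req = EvalQueryRequest(request_id=0, num_past_tokens=10, query_token_ids=[42, 7, 99])

# evaluate_multi_query returns one row of logits per query token (3 rows here),
# which should match what sequential single-token evaluation would produce; the
# check at the end of the example verifies this via argmax against known tokens.
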
examples/python/run_llama_batched_vllm.py (126 additions, 0 deletions)

@@ -117,6 +117,13 @@ class SequenceGenerationResponse:
    token_id: int


@dataclass
class EvalQueryRequest:
    request_id: int
    num_past_tokens: int  # number of tokens for this request already in the KV cache
    query_token_ids: List[int]  # tokens to evaluate in a single batched pass


def sample(logits):
    logits = torch.from_dlpack(logits)
    return torch.argmax(logits, -1).cpu().numpy()

@@ -241,6 +248,72 @@ def _pad_to_max(x: List[int], max_len: int) -> List[int]:
    )


def _prepare_eval_queries(
    requests: List[EvalQueryRequest],
    all_slot_mappings,
    sliding_window,
    dev,
):
    seq_lens = []
    query_lens = []
    input_ids = []
    slot_mapping = []
    past_slot_mapping = []
    positions = []
    permute_map = []

    # Past entries of all requests occupy indices [0, total past tokens) in the
    # concatenated layout, and query tokens occupy the indices right after them.
    query_offset = sum([request.num_past_tokens for request in requests])
    past_offset = 0

    for request in requests:
        num_past_tokens = request.num_past_tokens
        num_queries = len(request.query_token_ids)
        query_lens.append(num_queries)
        request_id = request.request_id
        input_ids += request.query_token_ids

        positions += [num_past_tokens + i for i in range(num_queries)]

        if sliding_window and num_past_tokens + num_queries >= sliding_window:
            seq_lens.append(sliding_window)
            past_slot_mapping += all_slot_mappings[request_id][
                num_past_tokens - (sliding_window - num_queries) : num_past_tokens
            ]
        else:
            seq_lens.append(num_past_tokens + num_queries)
            past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens]

        slot_mapping += all_slot_mappings[request_id][
            num_past_tokens : num_past_tokens + num_queries
        ]

        # Record the indices of this request's past entries followed by the
        # indices of its query tokens within that concatenated layout, so the
        # entries can be gathered back into per-request order.
        permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list(
            range(query_offset, query_offset + num_queries)
        )

        query_offset += num_queries
        past_offset += num_past_tokens

    input_ids = tvm.nd.array(np.array(input_ids, dtype="int32"), dev)
    positions = tvm.nd.array(np.array(positions, dtype="int32"), dev)
    seq_lens = tvm.nd.array(np.array(seq_lens, dtype="int32"), dev)
    slot_mapping = tvm.nd.array(np.array(slot_mapping, dtype="int32"), dev)

    query_lens = tvm.nd.array(np.array(query_lens, dtype="int32"), dev)
    past_slot_mapping = tvm.nd.array(np.array(past_slot_mapping, dtype="int32"), dev)
    permute_map = tvm.nd.array(np.array(permute_map, dtype="int32"), dev)

    return (
        input_ids,
        positions,
        seq_lens,
        slot_mapping,
        query_lens,
        past_slot_mapping,
        permute_map,
    )
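
# Illustrative only (not part of the diff): for two requests whose
# (num_past_tokens, num_queries) are (2, 2) and (3, 1), with sliding_window=None,
# _prepare_eval_queries produces (as int32 tvm.nd.arrays)
#   query_lens  = [2, 1]
#   seq_lens    = [4, 4]
#   positions   = [2, 3, 3]
#   permute_map = [0, 1, 5, 6, 2, 3, 4, 7]
# where indices 0-4 cover the past entries of both requests and indices 5-7
# cover their query tokens.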


class Model:
    def __init__(
        self, artifact_path, model_name, quant, vocab_size, num_shards, dev, sliding_window

@@ -443,6 +516,59 @@ def run(args):
    for p, g in zip(prompts, generated):
        print("Prompt = '{}', generated text = '{}'".format(p, g))

    query_token_lens = [4, 3, 5, 2]

    eval_query_requests = []

    # Re-evaluate the last `query_token_len` tokens of each request; everything
    # before them is already in the KV cache.
    for request_id, query_token_len in zip(request_ids, query_token_lens):
        queries_to_eval = requests[request_id].token_ids[-query_token_len:]
        num_past = len(requests[request_id].token_ids) - query_token_len
        eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval))

    (
        input_ids,
        positions,
        seq_lens,
        slot_mapping,
        query_lens,
        past_slot_mapping,
        permute_map,
    ) = _prepare_eval_queries(
        eval_query_requests,
        cache.slot_mappings,
        None,  # sliding_window
        model.dev,
    )

    logits = model.mod["evaluate_multi_query"](
        input_ids,
        positions,
        seq_lens,
        cache.cache,
        slot_mapping,
        query_lens,
        past_slot_mapping,
        permute_map,
        model.params,
    )[0].numpy()

    assert logits.shape[0] == sum(query_token_lens)

    logits_offset = 0

    for request_id, query_token_len in zip(request_ids, query_token_lens):
        for i in range(query_token_len - 1):
            # requests[request_id].token_ids[-query_token_len:] are the "ground truth" tokens.
            # Doing argmax over multi-timestep logits computed in parallel should yield the same
            # tokens at the corresponding positions.
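            # Row logits_offset + i scores the token at position len(past_tokens) + i,
            # so its argmax should reproduce token_ids[len(past_tokens) + i + 1]; the
            # last query position is skipped because its target token is not in token_ids.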
            past_tokens = requests[request_id].token_ids[:-query_token_len]
            assert (
                np.argmax(logits[logits_offset + i])
                == requests[request_id].token_ids[len(past_tokens) + i + 1]
            )

        logits_offset += query_token_len


if __name__ == "__main__":
    run(parse_args())

mlc_llm/core.py (1 addition, 0 deletions)

@@ -593,6 +593,7 @@ def mod_transform_before_build(
    # This is equivalent to prefill but without KV cache. It is used for
    # determining the number of paged cache blocks that can be allocated.
    model_names.append("evaluate")
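    # New in this PR: a batched entry point that evaluates multiple query tokens
    # per sequence in one pass (exercised by examples/python/run_llama_batched_vllm.py).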
    model_names.append("evaluate_multi_query")

    if args.sep_embed:
        model_names = ["embed", "prefill_with_embed"] + model_names[1:]