diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py
index dc30a0cfa3..5cb7f52ae6 100644
--- a/examples/python/run_llama_batched_vllm.py
+++ b/examples/python/run_llama_batched_vllm.py
@@ -120,7 +120,7 @@ class SequenceGenerationResponse:
 @dataclass
 class EvalQueryRequest:
     request_id: int
-    past_token_ids: List[int]
+    num_past_tokens: int
     query_token_ids: List[int]
 
 
@@ -262,11 +262,11 @@ def _prepare_eval_queries(
     positions = []
     permute_map = []
 
-    query_offset = sum([len(request.past_token_ids) for request in requests])
+    query_offset = sum([request.num_past_tokens for request in requests])
     past_offset = 0
 
     for request in requests:
-        num_past_tokens = len(request.past_token_ids)
+        num_past_tokens = request.num_past_tokens
         num_queries = len(request.query_token_ids)
         query_lens.append(num_queries)
         request_id = request.request_id
@@ -521,8 +521,8 @@ def run(args):
 
     for request_id, query_token_len in zip(request_ids, query_token_lens):
         queries_to_eval = requests[request_id].token_ids[-query_token_len:]
-        past_tokens = requests[request_id].token_ids[:-query_token_len]
-        eval_query_requests.append(EvalQueryRequest(request_id, past_tokens, queries_to_eval))
+        num_past = len(requests[request_id].token_ids) - query_token_len
+        eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval))
 
     (
         input_ids,
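
A minimal standalone sketch (not part of the diff) of what this refactor amounts to: the evaluation path only ever consumed len(past_token_ids), so EvalQueryRequest now carries the count directly and the caller no longer slices out a copy of the past-token prefix. The sample token_ids and values below are hypothetical.

from dataclasses import dataclass
from typing import List


@dataclass
class EvalQueryRequest:
    request_id: int
    num_past_tokens: int          # was: past_token_ids: List[int]
    query_token_ids: List[int]


token_ids = [101, 7592, 2088, 999, 102]  # hypothetical token stream
query_token_len = 2

# Before: past_tokens = token_ids[:-query_token_len] materialized a list
# just so its length could be taken later. After: compute the count directly.
num_past = len(token_ids) - query_token_len
request = EvalQueryRequest(
    request_id=0,
    num_past_tokens=num_past,
    query_token_ids=token_ids[-query_token_len:],
)
assert request.num_past_tokens == 3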