Skip to content

Commit

Permalink
Only the number of past tokens is needed
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jan 11, 2024
1 parent f1314a5 commit 2bee022
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions examples/python/run_llama_batched_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ class SequenceGenerationResponse:
@dataclass
class EvalQueryRequest:
request_id: int
- past_token_ids: List[int]
+ num_past_tokens: int
query_token_ids: List[int]


Expand Down Expand Up @@ -262,11 +262,11 @@ def _prepare_eval_queries(
positions = []
permute_map = []

- query_offset = sum([len(request.past_token_ids) for request in requests])
+ query_offset = sum([request.num_past_tokens for request in requests])
past_offset = 0

for request in requests:
- num_past_tokens = len(request.past_token_ids)
+ num_past_tokens = request.num_past_tokens
num_queries = len(request.query_token_ids)
query_lens.append(num_queries)
request_id = request.request_id
Expand Down Expand Up @@ -521,8 +521,8 @@ def run(args):

for request_id, query_token_len in zip(request_ids, query_token_lens):
queries_to_eval = requests[request_id].token_ids[-query_token_len:]
- past_tokens = requests[request_id].token_ids[:-query_token_len]
- eval_query_requests.append(EvalQueryRequest(request_id, past_tokens, queries_to_eval))
+ num_past = len(requests[request_id].token_ids) - query_token_len
+ eval_query_requests.append(EvalQueryRequest(request_id, num_past, queries_to_eval))

(
input_ids,
Expand Down

0 comments on commit 2bee022

Please sign in to comment.