diff --git a/examples/python/run_llama_batched_vllm.py b/examples/python/run_llama_batched_vllm.py index dcb16a878d..0a2c8f0b9c 100644 --- a/examples/python/run_llama_batched_vllm.py +++ b/examples/python/run_llama_batched_vllm.py @@ -274,17 +274,18 @@ def _prepare_eval_queries( positions += [num_past_tokens + i for i in range(num_queries)] - if sliding_window: - seq_lens.append(min(num_past_tokens + num_queries, sliding_window)) - num_past = min(num_past_tokens, sliding_window) - past_slot_mapping += all_slot_mappings[request_id][:num_past] - slot_mapping += all_slot_mappings[request_id][num_past: num_past + num_queries] + if sliding_window and num_past_tokens + num_queries >= sliding_window: + seq_lens.append(sliding_window) + past_slot_mapping += all_slot_mappings[request_id][ + num_past_tokens - (sliding_window - num_queries) : num_past_tokens + ] else: seq_lens.append(num_past_tokens + num_queries) past_slot_mapping += all_slot_mappings[request_id][:num_past_tokens] - slot_mapping += all_slot_mappings[request_id][ - num_past_tokens : num_past_tokens + num_queries - ] + + slot_mapping += all_slot_mappings[request_id][ + num_past_tokens : num_past_tokens + num_queries + ] permute_map += list(range(past_offset, past_offset + num_past_tokens)) + list( range(query_offset, query_offset + num_queries)