From b93e18f7ab80cbe63dc1cc1f032af3078137d36d Mon Sep 17 00:00:00 2001
From: Travis Johnson
Date: Wed, 20 Mar 2024 14:01:39 -0600
Subject: [PATCH] lint: formatting changes from linter

Signed-off-by: Travis Johnson
---
 tests/samplers/test_sampler.py        | 17 +++++++++++------
 vllm/engine/llm_engine.py             |  3 ++-
 vllm/model_executor/layers/sampler.py |  2 +-
 vllm/sampling_params.py               |  9 ++++-----
 4 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py
index 20ca14274e59e..0e67e974e194f 100644
--- a/tests/samplers/test_sampler.py
+++ b/tests/samplers/test_sampler.py
@@ -396,15 +396,20 @@ def run_test_case(*,
 
         if should_penalize:
             for token_id in tokens_to_check:
-                assert logits[logits_idx, token_id] == -float('inf'), \
-                    f"Expected token {token_id} for sequence {logits_idx} to be penalized"
+                assert logits[logits_idx, token_id] == -float(
+                    'inf'
+                ), (f"Expected token {token_id} for logits row {logits_idx}"
+                    " to be penalized")
 
             # no other tokens should be set to -inf
-            assert torch.count_nonzero(logits[logits_idx, :] == -float('inf')) == len(tokens_to_check), \
-                f"Expected only {len(tokens_to_check)} to be penalized"
+            assert torch.count_nonzero(
+                logits[logits_idx, :] == -float('inf')) == len(
+                    tokens_to_check
+                ), f"Expected only {len(tokens_to_check)} to be penalized"
         else:
             # no tokens should be set to -inf
-            assert torch.count_nonzero(logits[logits_idx, :] == -float('inf')) == 0, \
-                "No tokens should have been penalized"
+            assert torch.count_nonzero(
+                logits[logits_idx, :] ==
+                -float('inf')) == 0, "No tokens should have been penalized"
 
     del model_runner
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 1c5984d5aa3b1..796a4895dc971 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -269,7 +269,8 @@ def add_request(
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
-        # inject the eos token id into the sampling_params to support min_tokens processing
+        # inject the eos token id into the sampling_params to support min_tokens
+        # processing
         sampling_params.eos_token_id = self.get_tokenizer_for_seq(
             seq).eos_token_id
diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py
index 6c6ee5c61cf08..431c4f0e7db52 100644
--- a/vllm/model_executor/layers/sampler.py
+++ b/vllm/model_executor/layers/sampler.py
@@ -168,7 +168,7 @@ def _apply_min_tokens_penalty(
                 seqs_to_penalize.append(i)
 
         if seqs_to_penalize:
-            # convert from the index for this seq_group to the index into logits
+            # convert to the index into logits
            seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
            # use set() to remove any duplicates
            token_ids_to_penalize = set(sampling_params.stop_token_ids +
diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py
index e6ec1316adac7..6f81ee31f84dd 100644
--- a/vllm/sampling_params.py
+++ b/vllm/sampling_params.py
@@ -198,13 +198,12 @@ def _verify_args(self) -> None:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
         if self.min_tokens < 0:
-            raise ValueError(
-                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
-            )
+            raise ValueError(f"min_tokens must be greater than or equal to 0, "
+                             f"got {self.min_tokens}.")
         if self.max_tokens is not None and self.min_tokens > self.max_tokens:
             raise ValueError(
-                f"min_tokens must be less than or equal to max_tokens={self.max_tokens}, got {self.min_tokens}."
-            )
+                f"min_tokens must be less than or equal to "
+                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
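
For reviewers who want the context behind these formatting-only changes: the code being reflowed belongs to vLLM's min_tokens support, where the sampler sets the logits of the EOS and stop tokens to -inf until a sequence has produced at least min_tokens output tokens, which is what the asserts in test_sampler.py verify. The snippet below is a minimal, self-contained sketch of that masking idea, not the actual vLLM implementation; the name apply_min_tokens_penalty, the plain-tensor interface, and the example values are illustrative assumptions.

import torch


def apply_min_tokens_penalty(logits: torch.Tensor, num_output_tokens: int,
                             min_tokens: int,
                             tokens_to_penalize: set) -> torch.Tensor:
    """Sketch only: mask EOS/stop token logits to -inf until min_tokens."""
    if num_output_tokens < min_tokens:
        # a -inf logit becomes zero probability after softmax, so the
        # sequence cannot terminate before reaching min_tokens
        logits[:, list(tokens_to_penalize)] = -float("inf")
    return logits


# Toy check mirroring the asserts in test_sampler.py: only the penalized
# token ids end up at -inf, everything else is untouched.
logits = torch.zeros(1, 8)
out = apply_min_tokens_penalty(logits, num_output_tokens=1, min_tokens=4,
                               tokens_to_penalize={0, 7})
assert torch.count_nonzero(out == -float("inf")) == 2

Masking to -inf, rather than subtracting a finite penalty, makes min_tokens a hard constraint: the banned tokens get exactly zero probability, so early termination is impossible rather than merely discouraged.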