lint: formatting changes from linter
Signed-off-by: Travis Johnson <[email protected]>
tjohnson31415 committed Mar 20, 2024
1 parent 6393a50 commit b93e18f
Showing 4 changed files with 18 additions and 13 deletions.
17 changes: 11 additions & 6 deletions tests/samplers/test_sampler.py
@@ -396,15 +396,20 @@ def run_test_case(*,

         if should_penalize:
             for token_id in tokens_to_check:
-                assert logits[logits_idx, token_id] == -float('inf'), \
-                    f"Expected token {token_id} for sequence {logits_idx} to be penalized"
+                assert logits[logits_idx, token_id] == -float(
+                    'inf'
+                ), f"Expected token {token_id} for logits row {logits_idx}"
+                " to be penalized"
             # no other tokens should be set to -inf
-            assert torch.count_nonzero(logits[logits_idx, :] == -float('inf')) == len(tokens_to_check), \
-                f"Expected only {len(tokens_to_check)} to be penalized"
+            assert torch.count_nonzero(
+                logits[logits_idx, :] == -float('inf')) == len(
+                    tokens_to_check
+                ), f"Expected only {len(tokens_to_check)} to be penalized"
         else:
             # no tokens should be set to -inf
-            assert torch.count_nonzero(logits[logits_idx, :] == -float('inf')) == 0, \
-                "No tokens should have been penalized"
+            assert torch.count_nonzero(
+                logits[logits_idx, :] ==
+                -float('inf')) == 0, "No tokens should have been penalized"

         del model_runner
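For readers unfamiliar with the assertion pattern above, the property it checks can be reproduced in a tiny standalone snippet; the tensor shape and token ids below are made up for illustration and are not taken from the test suite:

import torch

# Toy logits for a single sequence over a vocabulary of 8 tokens; token ids
# 2 and 5 stand in for the stop/eos tokens that min_tokens should mask out.
logits = torch.zeros(1, 8)
penalized = [2, 5]
logits[0, penalized] = -float("inf")

# Exactly the penalized tokens are -inf ...
assert torch.count_nonzero(logits[0] == -float("inf")) == len(penalized)
# ... and each penalized token individually is -inf.
for token_id in penalized:
    assert logits[0, token_id] == -float("inf")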
3 changes: 2 additions & 1 deletion vllm/engine/llm_engine.py
@@ -269,7 +269,8 @@ def add_request(
         # Defensive copy of SamplingParams, which are used by the sampler,
         # this doesn't deep-copy LogitsProcessor objects
         sampling_params = sampling_params.clone()
-        # inject the eos token id into the sampling_params to support min_tokens processing
+        # inject the eos token id into the sampling_params to support min_tokens
+        # processing
         sampling_params.eos_token_id = self.get_tokenizer_for_seq(
             seq).eos_token_id
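A rough sketch of the clone-then-inject pattern in isolation; ToySamplingParams and its fields are invented for this example and are not the engine's real SamplingParams API:

from dataclasses import dataclass, replace
from typing import Optional

@dataclass
class ToySamplingParams:
    min_tokens: int = 0
    eos_token_id: Optional[int] = None

    def clone(self) -> "ToySamplingParams":
        # A shallow copy is enough here; the idea is to avoid mutating the
        # caller's object while skipping deep copies of heavyweight members.
        return replace(self)

user_params = ToySamplingParams(min_tokens=16)
engine_params = user_params.clone()      # defensive copy
engine_params.eos_token_id = 2           # injected from the tokenizer
assert user_params.eos_token_id is None  # caller's object is untouched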
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/sampler.py
@@ -168,7 +168,7 @@ def _apply_min_tokens_penalty(
                     seqs_to_penalize.append(i)

             if seqs_to_penalize:
-                # convert from the index for this seq_group to the index into logits
+                # convert to the index into logits
                 seqs_to_penalize = [start_idx + i for i in seqs_to_penalize]
                 # use set() to remove any duplicates
                 token_ids_to_penalize = set(sampling_params.stop_token_ids +
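The index conversion that comment refers to can be sketched on its own, assuming a flat logits tensor in which each sequence group occupies a contiguous block of rows; the group sizes below are illustrative, not the sampler's code:

# Two sequence groups: group 0 has 3 sequences (rows 0-2), group 1 has 2 (rows 3-4).
group_sizes = [3, 2]

start_idx = 0
global_rows = []
for size in group_sizes:
    # Pretend every sequence in this group still needs the min_tokens penalty.
    seqs_to_penalize = list(range(size))
    # Convert group-local indices into row indices of the flat logits tensor.
    global_rows.extend(start_idx + i for i in seqs_to_penalize)
    start_idx += size

assert global_rows == [0, 1, 2, 3, 4]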
9 changes: 4 additions & 5 deletions vllm/sampling_params.py
@@ -198,13 +198,12 @@ def _verify_args(self) -> None:
             raise ValueError(
                 f"max_tokens must be at least 1, got {self.max_tokens}.")
         if self.min_tokens < 0:
-            raise ValueError(
-                f"min_tokens must be greater than or equal to 0, got {self.min_tokens}."
-            )
+            raise ValueError(f"min_tokens must be greater than or equal to 0, "
+                             f"got {self.min_tokens}.")
         if self.max_tokens is not None and self.min_tokens > self.max_tokens:
             raise ValueError(
-                f"min_tokens must be less than or equal to max_tokens={self.max_tokens}, got {self.min_tokens}."
-            )
+                f"min_tokens must be less than or equal to "
+                f"max_tokens={self.max_tokens}, got {self.min_tokens}.")
         if self.logprobs is not None and self.logprobs < 0:
             raise ValueError(
                 f"logprobs must be non-negative, got {self.logprobs}.")
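For context, the relationship these checks enforce can be exercised with a small standalone validator; this is a simplified sketch, not the SamplingParams class itself:

from typing import Optional

def verify_min_max_tokens(min_tokens: int, max_tokens: Optional[int]) -> None:
    # Mirrors the intent of the checks above: min_tokens must be >= 0 and
    # must not exceed max_tokens when max_tokens is set.
    if min_tokens < 0:
        raise ValueError(f"min_tokens must be greater than or equal to 0, "
                         f"got {min_tokens}.")
    if max_tokens is not None and min_tokens > max_tokens:
        raise ValueError(f"min_tokens must be less than or equal to "
                         f"max_tokens={max_tokens}, got {min_tokens}.")

verify_min_max_tokens(0, None)    # ok
verify_min_max_tokens(4, 16)      # ok
# verify_min_max_tokens(32, 16)   # would raise ValueError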
