Modifying the sampler to allow FORCED type of sampling. (#265)
Alexei-V-Ivanov-AMD authored Nov 5, 2024
1 parent c091eaf commit 1c740db
Showing 1 changed file with 4 additions and 5 deletions.
vllm/model_executor/layers/sampler.py (9 changes: 4 additions & 5 deletions)
@@ -753,11 +753,11 @@ def get_pythonized_sample_results(
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
             sample_results = _random_sample(seq_groups,
                                             multinomial_samples[sampling_type])
-        elif sampling_type == SamplingType.FORCED:
-            sample_results = _forced_sample(seq_groups, forced_samples)
         elif sampling_type == SamplingType.BEAM:
             sample_results = _beam_search_sample(seq_groups,
                                                  beam_search_logprobs)
+        elif sampling_type == SamplingType.FORCED:
+            sample_results = _forced_sample(seq_groups, forced_samples)
         sample_results_dict.update(zip(seq_group_id, sample_results))

     return [
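
Context note: this hunk only reorders the dispatch so the FORCED branch is checked after BEAM. The helper _forced_sample is referenced here but its body is not part of this diff. Below is a minimal hypothetical sketch, assuming it follows the (next_token_ids, parent_ids) convention of vLLM's other _*_sample helpers; SeqGroupStub and _forced_sample_sketch are illustrative names, not from the repository.

    # Hypothetical sketch only: _forced_sample is called in the diff above,
    # but its implementation is not shown there. Assuming it mirrors the
    # other _*_sample helpers, it would pair each sequence in each group
    # with its pre-chosen token id.
    from dataclasses import dataclass
    from typing import List, Tuple

    import torch


    @dataclass
    class SeqGroupStub:  # stand-in for vLLM's SequenceGroupToSample
        seq_ids: List[int]


    def _forced_sample_sketch(
            seq_groups: List[SeqGroupStub],
            forced_samples: torch.Tensor) -> List[Tuple[List[int], List[int]]]:
        samples = forced_samples.tolist()
        results, idx = [], 0
        for group in seq_groups:
            n = len(group.seq_ids)
            # One forced token per sequence; each parent is the sequence
            # itself, as in greedy sampling.
            results.append((samples[idx:idx + n], list(range(n))))
            idx += n
        return results


    # Example: two groups, one sequence each, forced to tokens 42 and 7.
    print(_forced_sample_sketch([SeqGroupStub([0]), SeqGroupStub([1])],
                                torch.tensor([42, 7])))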
@@ -869,9 +869,6 @@ def _sample_with_torch(
                 # Store sampled tokens in output tensor.
                 sampled_token_ids_tensor[long_sample_indices] = \
                     multinomial_samples[sampling_type].to(torch.long)
-
-        elif sampling_type == SamplingType.BEAM:
-            beam_search_logprobs = logprobs[sample_indices]
         elif sampling_type == SamplingType.FORCED:
             if (seq_groups[0].sampling_params.future_context is not None):
                 forced_samples = torch.tensor([
@@ -884,6 +881,8 @@ def _sample_with_torch(
             else:
                 forced_samples = torch.argmax(logprobs[long_sample_indices],
                                               dim=-1)
+        elif sampling_type == SamplingType.BEAM:
+            beam_search_logprobs = logprobs[sample_indices]
         else:
             raise ValueError(f"Unsupported sampling type: {sampling_type}")
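
In _sample_with_torch, the FORCED branch either builds forced_samples from sampling_params.future_context (the torch.tensor contents are elided between the two hunks above) or falls back to a greedy argmax over the logprobs. The following is a minimal standalone sketch of that fallback path; all tensor values are fabricated for illustration.

    # Standalone illustration of the FORCED fallback shown above: with no
    # future_context, the forced token is the argmax of the log-probabilities,
    # i.e. plain greedy selection. Shapes and values here are made up.
    import torch

    vocab_size = 8
    logprobs = torch.log_softmax(torch.randn(3, vocab_size), dim=-1)

    future_context = None  # would hold pre-planned token ids when forcing
    if future_context is not None:
        forced_samples = torch.tensor([seq[0] for seq in future_context])
    else:
        # Mirrors: torch.argmax(logprobs[long_sample_indices], dim=-1)
        forced_samples = torch.argmax(logprobs, dim=-1)

    print(forced_samples)  # one token id per row of logprobs

Note that the branch inspects only seq_groups[0].sampling_params, so the presence of future_context is effectively decided once per batch of FORCED sequence groups rather than per group.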
