From 82df51f3405d991a82f75d5c4fe4311eae529c3a Mon Sep 17 00:00:00 2001 From: saumya-saran Date: Fri, 20 Sep 2024 12:46:02 -0700 Subject: [PATCH] [Bugfix] Validate SamplingParam n is an int (#8548) --- vllm/sampling_params.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 5edbc8e424e81..86e80ae5e224d 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -273,9 +273,14 @@ def __post_init__(self) -> None: self._all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: + if not isinstance(self.n, int): + raise ValueError(f"n must be an int, but is of " + f"type {type(self.n)}") if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") - assert isinstance(self.best_of, int) + if not isinstance(self.best_of, int): + raise ValueError(f'best_of must be an int, but is of ' + f'type {type(self.best_of)}') if self.best_of < self.n: raise ValueError(f"best_of must be greater than or equal to n, " f"got n={self.n} and best_of={self.best_of}.")