diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f7c1569696249..11cda053260ec 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -96,6 +96,9 @@ steps: - label: Metrics Test command: pytest -v -s metrics +- label: Quantization Test + command: pytest -v -s quantization + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" commands: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e56af9075e2fd..6ad7ae0f63197 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -34,9 +34,19 @@ def create_weights(self, layer: torch.nn.Module, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs): - """Create weights for a linear layer. - - The weights will be set as attributes of the layer.""" + """Create weights for a linear layer. + The weights will be set as attributes of the layer. + + Args: + layer: The layer that is using the LinearMethodBase factory. + input_size_per_partition: Size of the weight input dim on rank X. + output_partition_sizes: Sizes of the output dim of each logical + weight on rank X. E.g., output_partition_sizes for QKVLinear + is a list contains the width of Wq, Wk, Wv on rank X. + input_size: Size of the input dim of the weight across all ranks. + output_size: Size of the output dim of the weight across all ranks. + params_dtype: Datatype of the parameters. + """ raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 8df82e0e18edd..01e494c870e71 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -64,12 +64,13 @@ def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, - output_size_per_partition: int, + output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): + output_size_per_partition = sum(output_partition_sizes) weight = Parameter(torch.empty(output_size_per_partition, input_size_per_partition, dtype=params_dtype),