From 6541618635d233f96cd7e7fac8c88c9694115f54 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 06:43:20 +0000 Subject: [PATCH 01/51] Create linear method --- vllm/model_executor/layers/linear.py | 290 ++++++++++++++++++ .../layers/quantized_linear/awq.py | 111 +++---- .../layers/quantized_linear/squeezellm.py | 84 ++--- 3 files changed, 356 insertions(+), 129 deletions(-) create mode 100644 vllm/model_executor/layers/linear.py diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py new file mode 100644 index 0000000000000..f3200e782c9ba --- /dev/null +++ b/vllm/model_executor/layers/linear.py @@ -0,0 +1,290 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) +from vllm.model_executor.parallel_utils.utils import ( + divide, + split_tensor_along_last_dim, +) + + +class LinearMethodBase(ABC): + + @abstractmethod + def create_weights(self, module: torch.nn.Module, input_size: int, + output_size: int, params_dtype: torch.dtype) -> None: + del module + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + module: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + del module, x + raise NotImplementedError + + +class FullPrecisionLinearMethod(LinearMethodBase): + + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights(self, module: torch.nn.Module, input_size: int, + output_size: int, params_dtype: torch.dtype) -> None: + weight = Parameter(torch.empty(input_size, + output_size, + device=torch.cuda.current_device(), + dtype=params_dtype), + requires_grad=False) + module.register_parameter("weight", weight) + + def apply_weights(self, + module: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + if self.separate_bias_add: + if bias: + return F.linear(x, module.weight) + bias + return F.linear(x, module.weight) + return F.linear(x, module.weight, bias) + + +class ReplicatedLinear(torch.nn.Module): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if linear_method is None: + linear_method = FullPrecisionLinearMethod() + self.linear_method = linear_method + self.linear_method.create_weights(self, self.input_size, + self.output_size, self.params_dtype) + if bias: + self.bias = Parameter( + torch.empty(self.output_size, + device=torch.cuda.current_device(), + dtype=self.params_dtype)) + else: + self.register_parameter('bias', None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + output = self.linear_method.apply_weights(self, x, bias) + output_bias = self.bias if self.skip_bias_add else 
None + return output, output_bias + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments + bias: If true, add bias + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configuration. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, self.tp_size) + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if linear_method is None: + linear_method = FullPrecisionLinearMethod() + self.linear_method = linear_method + self.linear_method.create_weights(self, self.input_size, + self.output_size_per_partition, + self.params_dtype) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype)) + else: + self.register_parameter('bias', None) + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + output_parallel = self.linear_method.apply_weights(self, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + + Keyword Arguments: + bias: If true, add bias. Note that bias is not parallelized. + input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configuration. 
+ """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.skip_bias_add = skip_bias_add + if linear_method is None: + linear_method = FullPrecisionLinearMethod() + self.linear_method = linear_method + + self.linear_method.create_weights(self.input_size_per_partition, + self.output_size, self.params_dtype) + + if not reduce_results and (bias and not skip_bias_add): + raise ValueError('When not reduce the results, adding bias to the ' + 'results can lead to incorrect results') + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, + device=torch.cuda.current_device(), + dtype=params_dtype)) + else: + self.register_parameter('bias', None) + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.linear_method.apply_weights(input_parallel) + if self.reduce_results and self.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 31e341318d400..43e51e8bdaab1 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -4,103 +4,68 @@ from torch.nn.parameter import Parameter from vllm import quantization_ops -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import LinearMethodBase -class AWQColumnParallelLinear(ColumnParallelLinear): +class AWQLinearMethod(LinearMethodBase): - def create_weights(self, dtype: torch.dtype) -> None: - assert self.input_size % self.quant_config.group_size == 0 - if self.output_size_per_partition % self.quant_config.pack_factor != 0: + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights(self, module: torch.nn.Module, input_size: int, + output_size: int, params_dtype: torch.dtype) -> None: + if input_size % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + if output_size % self.quant_config.pack_factor != 0: raise ValueError( - "The tensor parallel size is not aligned with the quantized " - "weight shape. Please use a different tensor parallel size.") - self.qweight = Parameter( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + qweight = Parameter( torch.empty( - self.input_size, - self.output_size_per_partition // - self.quant_config.pack_factor, + input_size, + output_size // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), requires_grad=False, ) - self.qzeros = Parameter( + qzeros = Parameter( torch.empty( - self.input_size // self.quant_config.group_size, - self.output_size_per_partition // - self.quant_config.pack_factor, + input_size // self.quant_config.group_size, + output_size // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), requires_grad=False, ) - self.scales = Parameter( + scales = Parameter( torch.empty( - self.input_size // self.quant_config.group_size, - self.output_size_per_partition, + input_size // self.quant_config.group_size, + output_size, device="cuda", - dtype=dtype, + dtype=params_dtype, ), requires_grad=False, ) + module.register_parameter("qweight", qweight) + module.register_parameter("qzeros", qzeros) + module.register_parameter("scales", scales) - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: + def apply_weights(self, + module: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, )) + out_shape = (x.shape[:-1] + (module.qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, - self.qzeros, pack_factor) + out = quantization_ops.awq_gemm(reshaped_x, module.qweight, + module.scales, module.qzeros, + pack_factor) if bias is not None: out = out + bias return out.reshape(out_shape) - - -class AWQRowParallelLinear(RowParallelLinear): - - def create_weights(self, dtype: torch.dtype) -> None: - assert self.output_size % self.quant_config.pack_factor == 0 - if self.input_size_per_partition % self.quant_config.group_size != 0: - raise ValueError( - "The tensor parallel size is not aligned with the quantized " - "weight shape. 
Please use a different tensor parallel size.") - self.qweight = Parameter( - torch.empty( - self.input_size_per_partition, - self.output_size // self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.qzeros = Parameter( - torch.empty( - self.input_size_per_partition // self.quant_config.group_size, - self.output_size // self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.scales = Parameter( - torch.empty( - self.input_size_per_partition // self.quant_config.group_size, - self.output_size, - device="cuda", - dtype=dtype, - ), - requires_grad=False, - ) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, )) - reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, - self.qzeros, pack_factor) - return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 3ccbc4e579dc6..eeccdc241e963 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -4,81 +4,53 @@ from torch.nn.parameter import Parameter from vllm import quantization_ops -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import LinearMethodBase -class SqueezeLLMColumnParallelLinear(ColumnParallelLinear): +class SqueezeLLMLinearMethod(LinearMethodBase): - def create_weights(self, dtype: torch.dtype) -> None: - assert self.input_size % self.quant_config.pack_factor == 0 - self.qweight = Parameter( + def __init__(self, quant_config): + self.quant_config = quant_config + + def create_weights(self, module: torch.nn.Module, input_size: int, + output_size: int, params_dtype: torch.dtype) -> None: + if input_size % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + qweight = Parameter( torch.empty( - self.input_size // self.quant_config.pack_factor, - self.output_size_per_partition, + input_size // self.quant_config.pack_factor, + output_size, device="cuda", dtype=torch.int32, ), requires_grad=False, ) - self.lookup_table = Parameter( + lookup_table = Parameter( torch.empty( - self.output_size_per_partition, + output_size, self.quant_config.weight_bits**2, device="cuda", - dtype=dtype, + dtype=params_dtype, ), requires_grad=False, ) - - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) + module.register_parameter("qweight", qweight) + module.register_parameter("lookup_table", lookup_table) + + def apply_weights(self, + module: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + out_shape = x.shape[:-1] + (module.qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # NOTE: The output tensor should be zero-initialized. 
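+        # (squeezellm_gemm accumulates partial results into out rather than
+        # overwriting it, hence the zero initialization.)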
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, - self.lookup_table) + quantization_ops.squeezellm_gemm(reshaped_x, module.qweight, out, + module.lookup_table) if bias is not None: out = out + bias return out.reshape(out_shape) - - -class SqueezeLLMRowParallelLinear(RowParallelLinear): - - def create_weights(self, dtype: torch.dtype) -> None: - if self.input_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - "The tensor parallel size is not aligned with the quantized " - "weight shape. Please use a different tensor parallel size.") - self.qweight = Parameter( - torch.empty( - self.input_size_per_partition // self.quant_config.pack_factor, - self.output_size, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.lookup_table = Parameter( - torch.empty( - self.output_size, - self.quant_config.weight_bits**2, - device="cuda", - dtype=dtype, - ), - requires_grad=False, - ) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, - self.lookup_table) - return out.reshape(out_shape) From a97ede840f3677c6d28e05d7853bc5d5d3283a69 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 22:21:48 +0000 Subject: [PATCH 02/51] Support llama with the new quantization scheme --- vllm/model_executor/layers/linear.py | 11 +-- .../layers/quantized_linear/awq.py | 85 ++++++++++++++++++- .../layers/quantized_linear/squeezellm.py | 8 ++ .../layers/vocab_parallel_embedding.py | 73 ++++++++++++++++ vllm/model_executor/models/llama.py | 66 +++++++------- .../model_executor/quantization_utils/base.py | 5 ++ 6 files changed, 212 insertions(+), 36 deletions(-) create mode 100644 vllm/model_executor/layers/vocab_parallel_embedding.py diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f3200e782c9ba..d0be87f51d8b9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -42,8 +42,8 @@ def __init__(self, separate_bias_add: bool = False): def create_weights(self, module: torch.nn.Module, input_size: int, output_size: int, params_dtype: torch.dtype) -> None: - weight = Parameter(torch.empty(input_size, - output_size, + weight = Parameter(torch.empty(output_size, + input_size, device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) @@ -229,6 +229,7 @@ def __init__( self.reduce_results = reduce_results if params_dtype is None: params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() @@ -237,8 +238,7 @@ def __init__( if linear_method is None: linear_method = FullPrecisionLinearMethod() self.linear_method = linear_method - - self.linear_method.create_weights(self.input_size_per_partition, + self.linear_method.create_weights(self, self.input_size_per_partition, self.output_size, self.params_dtype) if not reduce_results and (bias and not skip_bias_add): @@ -275,7 +275,8 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. 
- output_parallel = self.linear_method.apply_weights(input_parallel) + output_parallel = self.linear_method.apply_weights( + self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 43e51e8bdaab1..bdf7a963ae6c6 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -1,10 +1,85 @@ -from typing import Optional +from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter from vllm import quantization_ops from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. + + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point})") + + @classmethod + def get_name(cls) -> str: + return "awq" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + # The AWQ kernel only supports Turing or newer GPUs. 
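+        # 75 corresponds to CUDA compute capability 7.5, i.e. Turing.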
+ return 75 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + return cls(weight_bits, group_size, zero_point) + + @classmethod + def get_packed_tensors(cls) -> Dict[str, int]: + return {"qweight": 1, "qzeros": 1} + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros", "scales"] + + @classmethod + def get_col_parallel_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros", "scales"] + + @classmethod + def get_row_parallel_tensor_names(cls) -> List[str]: + return ["qweight", "qzeros", "scales"] + + def get_linear_method(self) -> "AWQLinearMethod": + return AWQLinearMethod(self) class AWQLinearMethod(LinearMethodBase): @@ -69,3 +144,11 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) + + +class AWQColumnParallelLinear: + pass + + +class AWQRowParallelLinear: + pass diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index eeccdc241e963..7a50bb978c4ba 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -54,3 +54,11 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) + + +class SqueezeLLMColumnParallelLinear: + pass + + +class SqueezeLLMRowParallelLinear: + pass diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000000000..f7cf8bd52dafa --- /dev/null +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -0,0 +1,73 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) + +from vllm.model_executor.parallel_utils.utils import VocabUtility + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + This is mainly adapted from torch.nn.Embedding and all the default + values are kept. + Arguments: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + params_dtype: type of the parameters. + """ + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + params_dtype: Optional[torch.dtype] = None): + super().__init__() + + # Keep the input dimensions. + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + if params_dtype is None: + params_dtype = torch.get_default_dtype() + + self.tp_size = get_tensor_model_parallel_world_size() + # TODO: Handle vocab padding here. + # Divide the weight matrix along the vocaburaly dimension. 
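+        # For example, a 32000-token vocabulary with tp_size 4 gives each
+        # rank a contiguous slice of 8000 rows.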
+ self.vocab_start_index, self.vocab_end_index = ( + VocabUtility.vocab_range_from_global_vocab_size( + self.num_embeddings, get_tensor_model_parallel_rank(), + self.tp_size)) + self.num_embeddings_per_partition = (self.vocab_end_index - + self.vocab_start_index) + + self.weight = Parameter( + torch.empty(self.num_embeddings_per_partition, + self.embedding_dim, + device=torch.cuda.current_device(), + dtype=params_dtype)) + + def forward(self, input_): + if self.tp_size > 1: + # Build the mask. + input_mask = ((input_ < self.vocab_start_index) | + (input_ >= self.vocab_end_index)) + # Mask the input. + masked_input = input_.clone() - self.vocab_start_index + masked_input[input_mask] = 0 + else: + masked_input = input_ + # Get the embeddings. + output_parallel = F.embedding(masked_input, self.weight) + # Mask the output embedding. + if self.tp_size > 1: + output_parallel[input_mask, :] = 0.0 + # Reduce across all the model parallel GPUs. + output = tensor_model_parallel_all_reduce(output_parallel) + return output diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 735e4ad172182..c87ef6f4518b0 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -34,12 +34,15 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + RowParallelLinear) from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, @@ -56,19 +59,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config) + self.gate_up_proj = ColumnParallelLinear(hidden_size, + 2 * intermediate_size, + bias=False, + gather_output=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -91,7 +94,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -117,21 +120,21 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = ColumnParallelLinear( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads * num_kv_heads_replicas) * self.head_dim, bias=False, gather_output=False, - quant_config=quant_config, + linear_method=linear_method, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - quant_config=quant_config, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -165,7 +168,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -181,13 +184,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - quant_config=quant_config, + linear_method=linear_method, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -227,7 +230,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config @@ -240,7 +243,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -281,14 +284,18 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = LlamaModel(config, quant_config) + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = LlamaModel(config, linear_method) vocab_size = ((config.vocab_size + 63) // 64) * 64 # NOTE: The LM head is not quantized. 
- self.lm_head = ParallelLinear.column(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - quant_config=None) + self.lm_head = ColumnParallelLinear(config.hidden_size, + vocab_size, + bias=False, + gather_output=False, + linear_method=None) self.sampler = Sampler(config.vocab_size) def forward( @@ -422,7 +429,6 @@ def load_weights(self, load_padded_tensor_parallel_vocab(param, loaded_weight, tp_rank) continue - load_tensor_parallel_weights(param, loaded_weight, name, column_parallel_weights, row_parallel_weights, tp_rank) diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py index a70a7a8631e41..97da60cd699cb 100644 --- a/vllm/model_executor/quantization_utils/base.py +++ b/vllm/model_executor/quantization_utils/base.py @@ -2,6 +2,8 @@ import torch +from vllm.model_executor.layers.linear import LinearMethodBase + class QuantizationConfig: @@ -83,3 +85,6 @@ def get_col_parallel_tensor_names(cls) -> List[str]: @classmethod def get_row_parallel_tensor_names(cls) -> List[str]: raise NotImplementedError + + def get_linear_method(self) -> LinearMethodBase: + raise NotImplementedError From 4671286d6430bbf772ea6533d3757795e2cc776a Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 22:38:26 +0000 Subject: [PATCH 03/51] make awq work --- .../layers/quantized_linear/__init__.py | 29 +------------------ .../quantization_utils/__init__.py | 2 +- 2 files changed, 2 insertions(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index b09358261d5d1..61eb8a7fe9c4e 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -2,8 +2,6 @@ AWQColumnParallelLinear, AWQRowParallelLinear) from vllm.model_executor.layers.quantized_linear.squeezellm import ( SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear) -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) _QUANTIZED_LINEAR_REGISTRY = { "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), @@ -13,29 +11,4 @@ class ParallelLinear: - - @classmethod - def column(cls, *args, **kwargs) -> ColumnParallelLinear: - quant_config = kwargs.get("quant_config", None) - if quant_config is None: - return ColumnParallelLinear(*args, **kwargs) - - name = quant_config.get_name() - if name not in _QUANTIZED_LINEAR_REGISTRY: - raise ValueError(f"No quantized linear is found for {name}") - - quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0] - return quant_linear_cls(*args, **kwargs) - - @classmethod - def row(cls, *args, **kwargs) -> RowParallelLinear: - quant_config = kwargs.get("quant_config", None) - if quant_config is None: - return RowParallelLinear(*args, **kwargs) - - name = quant_config.get_name() - if name not in _QUANTIZED_LINEAR_REGISTRY: - raise ValueError(f"No quantized linear is found for {name}") - - quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1] - return quant_linear_cls(*args, **kwargs) + pass \ No newline at end of file diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py index 345f6494bf836..eb39585ddc9ce 100644 --- a/vllm/model_executor/quantization_utils/__init__.py +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -1,6 +1,6 @@ from typing import Type -from vllm.model_executor.quantization_utils.awq import AWQConfig +from 
vllm.model_executor.layers.quantized_linear.awq import AWQConfig from vllm.model_executor.quantization_utils.base import QuantizationConfig from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig From 4579d6745be85841a005a765daf71e872f1187bc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 22:59:23 +0000 Subject: [PATCH 04/51] Fix squeezellm --- pyproject.toml | 2 +- requirements.txt | 4 +- .../layers/quantized_linear/awq.py | 2 +- .../layers/quantized_linear/squeezellm.py | 70 ++++++++++++++++- .../quantization_utils/__init__.py | 2 +- vllm/model_executor/quantization_utils/awq.py | 76 ------------------- .../quantization_utils/squeezellm.py | 65 ---------------- 7 files changed, 72 insertions(+), 149 deletions(-) delete mode 100644 vllm/model_executor/quantization_utils/awq.py delete mode 100644 vllm/model_executor/quantization_utils/squeezellm.py diff --git a/pyproject.toml b/pyproject.toml index 360e023a05a02..27285bb683565 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = [ "ninja", "packaging", "setuptools", - "torch == 2.0.1", + "torch >= 2.1.0", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements.txt b/requirements.txt index d8597b3ec5543..9ef5725b3792d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,9 +5,9 @@ pandas # Required for Ray data. pyarrow # Required for Ray data. sentencepiece # Required for LLaMA tokenizer. numpy -torch == 2.0.1 +torch >= 2.1.0 transformers >= 4.34.0 # Required for Mistral. -xformers == 0.0.22 # Required for Mistral. +xformers >= 0.0.22 # Required for Mistral. fastapi uvicorn[standard] pydantic == 1.10.13 # Required for OpenAI server. diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index bdf7a963ae6c6..e6874bac3589c 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -84,7 +84,7 @@ def get_linear_method(self) -> "AWQLinearMethod": class AWQLinearMethod(LinearMethodBase): - def __init__(self, quant_config): + def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config def create_weights(self, module: torch.nn.Module, input_size: int, diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 7a50bb978c4ba..a910b7d747fdb 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -1,20 +1,84 @@ -from typing import Optional +from typing import Any, Dict, List, Optional import torch from torch.nn.parameter import Parameter from vllm import quantization_ops from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.quantization_utils.base import QuantizationConfig + + +class SqueezeLLMConfig(QuantizationConfig): + """Config class for SqueezeLLM. 
+ + Reference: https://arxiv.org/pdf/2306.07629 + """ + + def __init__( + self, + weight_bits: int, + ) -> None: + self.weight_bits = weight_bits + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"SqueezeLLM, but got {self.weight_bits} bits.") + + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" + + @classmethod + def get_name(cls) -> str: + return "squeezellm" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 70 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + return cls(weight_bits) + + @classmethod + def get_packed_tensors(cls) -> Dict[str, int]: + return {"qweight": 0} + + @classmethod + def get_transposed_tensor_names(cls) -> List[str]: + return ["qweight"] + + @classmethod + def get_col_parallel_tensor_names(cls) -> List[str]: + return ["qweight", "lookup_table"] + + @classmethod + def get_row_parallel_tensor_names(cls) -> List[str]: + return ["qweight"] + + def get_linear_method(self) -> "SqueezeLLMLinearMethod": + return SqueezeLLMLinearMethod(self) class SqueezeLLMLinearMethod(LinearMethodBase): - def __init__(self, quant_config): + def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config def create_weights(self, module: torch.nn.Module, input_size: int, output_size: int, params_dtype: torch.dtype) -> None: - if input_size % self.quant_config.group_size != 0: + if input_size % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " "weight shape. This can be caused by too large " diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py index eb39585ddc9ce..2b4df01400e95 100644 --- a/vllm/model_executor/quantization_utils/__init__.py +++ b/vllm/model_executor/quantization_utils/__init__.py @@ -1,8 +1,8 @@ from typing import Type from vllm.model_executor.layers.quantized_linear.awq import AWQConfig +from vllm.model_executor.layers.quantized_linear.squeezellm import SqueezeLLMConfig from vllm.model_executor.quantization_utils.base import QuantizationConfig -from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig _QUANTIZATION_REGISTRY = { "awq": AWQConfig, diff --git a/vllm/model_executor/quantization_utils/awq.py b/vllm/model_executor/quantization_utils/awq.py deleted file mode 100644 index ebc89560a4477..0000000000000 --- a/vllm/model_executor/quantization_utils/awq.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Any, Dict, List - -import torch - -from vllm.model_executor.quantization_utils.base import QuantizationConfig - - -class AWQConfig(QuantizationConfig): - """Config class for AWQ. 
- - Reference: https://arxiv.org/abs/2306.00978 - """ - - def __init__( - self, - weight_bits: int, - group_size: int, - zero_point: bool, - ) -> None: - self.weight_bits = weight_bits - self.group_size = group_size - self.zero_point = zero_point - - if self.weight_bits != 4: - raise ValueError( - "Currently, only 4-bit weight quantization is supported for " - f"AWQ, but got {self.weight_bits} bits.") - self.pack_factor = 32 // self.weight_bits - - def __repr__(self) -> str: - return (f"AWQConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point})") - - @classmethod - def get_name(cls) -> str: - return "awq" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - # The AWQ kernel only supports Turing or newer GPUs. - return 75 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return [ - "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq - "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long - ] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) - zero_point = cls.get_from_keys(config, ["zero_point"]) - return cls(weight_bits, group_size, zero_point) - - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - return {"qweight": 1, "qzeros": 1} - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] diff --git a/vllm/model_executor/quantization_utils/squeezellm.py b/vllm/model_executor/quantization_utils/squeezellm.py deleted file mode 100644 index 8a1db3e233217..0000000000000 --- a/vllm/model_executor/quantization_utils/squeezellm.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Any, Dict, List - -import torch - -from vllm.model_executor.quantization_utils.base import QuantizationConfig - - -class SqueezeLLMConfig(QuantizationConfig): - """Config class for SqueezeLLM. 
- - Reference: https://arxiv.org/pdf/2306.07629 - """ - - def __init__( - self, - weight_bits: int, - ) -> None: - self.weight_bits = weight_bits - - if self.weight_bits != 4: - raise ValueError( - "Currently, only 4-bit weight quantization is supported for " - f"SqueezeLLM, but got {self.weight_bits} bits.") - - self.pack_factor = 32 // self.weight_bits - - def __repr__(self) -> str: - return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" - - @classmethod - def get_name(cls) -> str: - return "squeezellm" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 70 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quant_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - return cls(weight_bits) - - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - return {"qweight": 0} - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - return ["qweight"] - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "lookup_table"] - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - return ["qweight"] From 4406447cfce33a798a284d8d161f5afc17436eb5 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 23:10:17 +0000 Subject: [PATCH 05/51] Remove unused codes --- .../layers/quantized_linear/__init__.py | 14 +------------- vllm/model_executor/layers/quantized_linear/awq.py | 8 -------- .../layers/quantized_linear/squeezellm.py | 8 -------- 3 files changed, 1 insertion(+), 29 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index 61eb8a7fe9c4e..21b7a4bdfcbc6 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -1,14 +1,2 @@ -from vllm.model_executor.layers.quantized_linear.awq import ( - AWQColumnParallelLinear, AWQRowParallelLinear) -from vllm.model_executor.layers.quantized_linear.squeezellm import ( - SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear) - -_QUANTIZED_LINEAR_REGISTRY = { - "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), - "squeezellm": - (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear), -} - - class ParallelLinear: - pass \ No newline at end of file + pass diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index e6874bac3589c..685a1d40200b8 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -144,11 +144,3 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) - - -class AWQColumnParallelLinear: - pass - - -class AWQRowParallelLinear: - pass diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index a910b7d747fdb..582e502fb2267 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -118,11 +118,3 @@ def apply_weights(self, if bias is not None: out = out + bias return out.reshape(out_shape) - - -class SqueezeLLMColumnParallelLinear: - pass - - -class SqueezeLLMRowParallelLinear: - pass From 5a535e39c42952eb8ab06f03447995ba6f3d241b Mon 
Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 23:36:18 +0000 Subject: [PATCH 06/51] Fix mistral --- vllm/model_executor/models/mistral.py | 48 +++++++++++++++------------ 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 8b09276e6f91d..bd03c53052923 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -34,12 +34,15 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + RowParallelLinear) from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import ( convert_pyslice_to_tensor, hf_model_weights_iterator, @@ -56,19 +59,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, + self.gate_up_proj = ColumnParallelLinear(hidden_size, 2 * intermediate_size, bias=False, gather_output=False, - quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, input_is_parallel=True, - quant_config=quant_config) + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -89,7 +92,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -107,20 +110,20 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = ColumnParallelLinear( hidden_size, (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, bias=False, gather_output=False, - quant_config=quant_config, + linear_method=linear_method, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - quant_config=quant_config, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, @@ -153,7 +156,7 @@ class MistralDecoderLayer(nn.Module): def __init__( self, config: MistralConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -165,13 +168,13 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - quant_config=quant_config, + linear_method=linear_method, sliding_window=config.sliding_window) self.mlp = MistralMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -211,7 +214,7 @@ class MistralModel(nn.Module): def __init__( self, config: MistralConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config @@ -224,7 +227,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MistralDecoderLayer(config, quant_config) + MistralDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -265,14 +268,17 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = MistralModel(config, quant_config) + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = MistralModel(config, linear_method) vocab_size = ((config.vocab_size + 63) // 64) * 64 # NOTE: The LM head is not quantized. 
- self.lm_head = ParallelLinear.column(config.hidden_size, + self.lm_head = ColumnParallelLinear(config.hidden_size, vocab_size, bias=False, - gather_output=False, - quant_config=None) + gather_output=False) self.sampler = Sampler(config.vocab_size) def forward( From 14e66f8735a1d19096ac7ef64711c234885467c4 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 3 Nov 2023 23:56:18 +0000 Subject: [PATCH 07/51] Fix format --- vllm/model_executor/models/mistral.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index bd03c53052923..d8e440123cc4d 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -63,15 +63,15 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - linear_method=linear_method) + 2 * intermediate_size, + bias=False, + gather_output=False, + linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - linear_method=linear_method) + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -276,9 +276,9 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 # NOTE: The LM head is not quantized. self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False) + vocab_size, + bias=False, + gather_output=False) self.sampler = Sampler(config.vocab_size) def forward( From f4643758e3775184902bc2d1e448578b0d8a4ae8 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 8 Nov 2023 00:59:06 +0000 Subject: [PATCH 08/51] New weight loading method, working for llama --- vllm/model_executor/layers/linear.py | 49 ++++++- .../layers/quantized_linear/awq.py | 29 ++++- vllm/model_executor/models/llama.py | 123 +++++------------- vllm/model_executor/weight_utils.py | 67 ++++++---- 4 files changed, 136 insertions(+), 132 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d0be87f51d8b9..44173c0643a7e 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Optional +from typing import Any, Optional import torch import torch.nn.functional as F @@ -18,11 +18,25 @@ ) +def set_weight_attrs(weight: torch.Tensor, weight_attrs: Optional[dict[str, + Any]]): + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key), (f"Overwriting existing tensor attribute: {key}") + setattr(weight, key, value) + + class LinearMethodBase(ABC): @abstractmethod - def create_weights(self, module: torch.nn.Module, input_size: int, - output_size: int, params_dtype: torch.dtype) -> None: + def create_weights(self, + module: torch.nn.Module, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_attrs: dict[str, Any] = None) -> None: del module raise NotImplementedError @@ -40,14 +54,24 @@ class FullPrecisionLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, module: torch.nn.Module, input_size: int, - output_size: 
int, params_dtype: torch.dtype) -> None: + def create_weights(self, + module: torch.nn.Module, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_attrs: dict[str, Any] = None) -> None: weight = Parameter(torch.empty(output_size, input_size, device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) + set_weight_attrs(weight, { + "input_dim": 1, + "output_dim": 0, + **weight_attrs + }) module.register_parameter("weight", weight) + #print("dir(module.weight):", dir(module.weight)) def apply_weights(self, module: torch.nn.Module, @@ -90,6 +114,7 @@ def __init__( torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=self.params_dtype)) + set_weight_attrs(self.bias, {"output_dim": 0}) else: self.register_parameter('bias', None) @@ -150,12 +175,17 @@ def __init__( self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size, self.output_size_per_partition, - self.params_dtype) + self.params_dtype, + {"output_dim_parallel": True}) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, device=torch.cuda.current_device(), dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "output_dim_parallel": True + }) else: self.register_parameter('bias', None) @@ -239,7 +269,8 @@ def __init__( linear_method = FullPrecisionLinearMethod() self.linear_method = linear_method self.linear_method.create_weights(self, self.input_size_per_partition, - self.output_size, self.params_dtype) + self.output_size, self.params_dtype, + {"input_dim_parallel": True}) if not reduce_results and (bias and not skip_bias_add): raise ValueError('When not reduce the results, adding bias to the ' @@ -250,6 +281,10 @@ def __init__( torch.empty(self.output_size, device=torch.cuda.current_device(), dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "input_dim_parallel": True + }) else: self.register_parameter('bias', None) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 685a1d40200b8..cfd96c879c0a0 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -4,7 +4,8 @@ from torch.nn.parameter import Parameter from vllm import quantization_ops -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) from vllm.model_executor.quantization_utils.base import QuantizationConfig @@ -87,8 +88,12 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, module: torch.nn.Module, input_size: int, - output_size: int, params_dtype: torch.dtype) -> None: + def create_weights(self, + module: torch.nn.Module, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_attrs: dict[str, Any] = None) -> None: if input_size % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -109,6 +114,13 @@ def create_weights(self, module: torch.nn.Module, input_size: int, ), requires_grad=False, ) + set_weight_attrs(qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + **weight_attrs, + }) + set_weight_attrs(qweight, weight_attrs) qzeros = Parameter( torch.empty( input_size // self.quant_config.group_size, @@ -118,6 +130,12 @@ def create_weights(self, module: torch.nn.Module, input_size: int, ), 
requires_grad=False, ) + set_weight_attrs(qzeros, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + **weight_attrs, + }) scales = Parameter( torch.empty( input_size // self.quant_config.group_size, @@ -127,6 +145,11 @@ def create_weights(self, module: torch.nn.Module, input_size: int, ), requires_grad=False, ) + set_weight_attrs(qzeros, { + "input_dim": 0, + "output_dim": 1, + **weight_attrs, + }) module.register_parameter("qweight", qweight) module.register_parameter("qzeros", qzeros) module.register_parameter("scales", scales) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c87ef6f4518b0..76926ebfe5ce3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -312,123 +312,58 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - col_weight_suffixes = ["weight"] - row_weight_suffixes = ["weight"] - else: - col_weight_suffixes = ( - self.quant_config.get_col_parallel_tensor_names()) - row_weight_suffixes = ( - self.quant_config.get_row_parallel_tensor_names()) - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in col_weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in row_weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) num_kv_heads_replicas = max(1, tp_size // self.config.num_key_value_heads) num_kv_heads_per_gpu = max(1, self.config.num_key_value_heads // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - num_kv_heads_per_gpu) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + + q_proj_size = self.config.hidden_size + kv_proj_size = (self.config.hidden_size // + self.config.num_attention_heads * + self.config.num_key_value_heads) + ffn_intermediate_size = self.config.intermediate_size + stacked_params_mapping = [ + # (param_name, shard_name, slice_size, offset) + ("qkv_proj", "q_proj", q_proj_size, 0), + ("qkv_proj", "k_proj", kv_proj_size, q_proj_size), + ("qkv_proj", "v_proj", kv_proj_size, q_proj_size + kv_proj_size), + ("gate_up_proj", "gate_proj", ffn_intermediate_size, 0), + ("gate_up_proj", "up_proj", ffn_intermediate_size, + ffn_intermediate_size), ] - state_dict = self.state_dict() + + state_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - packed_dim = None - is_transposed = False - if self.quant_config is not None: - packed_dim = self.quant_config.get_packed_dim(name) - is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: + loaded = 
False + for (param_name, weight_name, slice_size, + slice_offset) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T - - if packed_dim is not None: - shard_dim = 0 if not is_transposed else 1 - if packed_dim == shard_dim: - shard_size //= self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - if weight_name in ["k_proj", "v_proj"]: - shard_id = tp_rank // num_kv_heads_replicas - else: - shard_id = tp_rank - loaded_weight = loaded_weight[shard_size * - shard_id:shard_size * - (shard_id + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = state_dict[name.replace(weight_name, param_name)] + # TODO: fix the case when num kv heads < tp size + load_tensor_parallel_weights(param, + loaded_weight, + output_slice_offset=slice_offset, + output_slice_size=slice_size) + loaded = True break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: + if loaded: continue param = state_dict[name] - if is_transposed: - param = param.T - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) + load_padded_tensor_parallel_vocab(param, loaded_weight) continue - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, tp_rank) + load_tensor_parallel_weights(param, loaded_weight) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 027e741c5f6bf..81ccadf063ce7 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -13,6 +13,8 @@ from tqdm.auto import tqdm from vllm.logger import init_logger +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import get_quant_class from vllm.model_executor.quantization_utils.base import QuantizationConfig @@ -263,46 +265,55 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def load_padded_tensor_parallel_vocab( - param: torch.Tensor, - loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` - tensor_model_parallel_rank: int, + param: torch.Tensor, + loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` ) -> None: + tp_rank = get_tensor_model_parallel_rank() shard_size = param.shape[0] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size + start_idx = tp_rank * shard_size + end_idx = (tp_rank + 1) * shard_size loaded_weight = loaded_weight[start_idx:end_idx] loaded_weight = convert_pyslice_to_tensor(loaded_weight) - param[:loaded_weight.shape[0]].copy_(loaded_weight) + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) def load_tensor_parallel_weights( param: torch.Tensor, 
loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` - param_name: str, - column_parallel_weight_names: List[str], - row_parallel_weight_names: List[str], - tensor_model_parallel_rank: int, + output_slice_offset: Optional[int] = None, + output_slice_size: Optional[int] = None, ) -> None: - for p in column_parallel_weight_names: - if p in param_name: - shard_size = param.shape[0] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - loaded_weight = loaded_weight[start_idx:end_idx] - break - for p in row_parallel_weight_names: - if p in param_name: - shard_size = param.shape[1] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - loaded_weight = loaded_weight[:, start_idx:end_idx] - break + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param_data = param.data + if output_slice_offset is not None and output_slice_size is not None: + output_dim = getattr(param, "output_dim", None) + assert output_dim is not None + output_slice_offset = output_slice_offset // tp_size + output_slice_size = output_slice_size // tp_size + param_data = param_data.narrow(output_dim, output_slice_offset, + output_slice_size) loaded_weight = convert_pyslice_to_tensor(loaded_weight) - assert param.shape == loaded_weight.shape, ( - f"{param_name} shape mismatch between model and checkpoint: " - f"{param.shape} != {loaded_weight.shape}") - param.data.copy_(loaded_weight) + if getattr(param, "output_dim_parallel", False): + output_dim = getattr(param, "output_dim", None) + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + if getattr(param, "input_dim_parallel", False): + input_dim = getattr(param, "input_dim", None) + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + + assert param_data.shape == loaded_weight.shape, ( + f"shape mismatch between model and checkpoint: " + f"{param_data.shape} != {loaded_weight.shape}") + param_data.copy_(loaded_weight) def initialize_dummy_weights( From a5852eff71868f7acbcbbf6befee7ac57db4b9b5 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 8 Nov 2023 01:13:20 +0000 Subject: [PATCH 09/51] Fix awq loading --- .../model_executor/layers/quantized_linear/awq.py | 5 +++-- vllm/model_executor/weight_utils.py | 15 +++++++++++---- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index cfd96c879c0a0..55d8f1e50f468 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -118,9 +118,9 @@ def create_weights(self, "input_dim": 0, "output_dim": 1, "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, **weight_attrs, }) - set_weight_attrs(qweight, weight_attrs) qzeros = Parameter( torch.empty( input_size // self.quant_config.group_size, @@ -134,6 +134,7 @@ def create_weights(self, "input_dim": 0, "output_dim": 1, "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, **weight_attrs, }) scales = Parameter( @@ -145,7 +146,7 @@ def create_weights(self, ), requires_grad=False, ) - set_weight_attrs(qzeros, { + set_weight_attrs(scales, { "input_dim": 0, "output_dim": 1, 
**weight_attrs, diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 81ccadf063ce7..71ec29a313008 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -287,10 +287,17 @@ def load_tensor_parallel_weights( tp_size = get_tensor_model_parallel_world_size() param_data = param.data if output_slice_offset is not None and output_slice_size is not None: - output_dim = getattr(param, "output_dim", None) - assert output_dim is not None - output_slice_offset = output_slice_offset // tp_size - output_slice_size = output_slice_size // tp_size + output_dim = param.output_dim + # Adjust the offset and size to account for tensor parallelism. + output_slice_offset //= tp_size + output_slice_size //= tp_size + # If quantized, we need to adjust the offset and size to account for + # the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + pack_factor = param.pack_factor + output_slice_offset //= pack_factor + output_slice_size //= pack_factor param_data = param_data.narrow(output_dim, output_slice_offset, output_slice_size) From 7bf933f836d28cc213d7cdd04657439c5204d419 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 8 Nov 2023 01:39:59 +0000 Subject: [PATCH 10/51] Fix squeeze llm --- .../layers/quantized_linear/awq.py | 46 +++++++------------ .../layers/quantized_linear/squeezellm.py | 24 ++++++++-- 2 files changed, 37 insertions(+), 33 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 55d8f1e50f468..4c1707222e435 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -63,22 +63,6 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - return {"qweight": 1, "qzeros": 1} - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - def get_linear_method(self) -> "AWQLinearMethod": return AWQLinearMethod(self) @@ -114,13 +98,14 @@ def create_weights(self, ), requires_grad=False, ) - set_weight_attrs(qweight, { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - **weight_attrs, - }) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **weight_attrs, + }) qzeros = Parameter( torch.empty( input_size // self.quant_config.group_size, @@ -130,13 +115,14 @@ def create_weights(self, ), requires_grad=False, ) - set_weight_attrs(qzeros, { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - **weight_attrs, - }) + set_weight_attrs( + qzeros, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + **weight_attrs, + }) scales = Parameter( torch.empty( input_size // self.quant_config.group_size, diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 582e502fb2267..7e8a646dc1294 
100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -4,7 +4,8 @@ from torch.nn.parameter import Parameter from vllm import quantization_ops -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) from vllm.model_executor.quantization_utils.base import QuantizationConfig @@ -76,8 +77,12 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, module: torch.nn.Module, input_size: int, - output_size: int, params_dtype: torch.dtype) -> None: + def create_weights(self, + module: torch.nn.Module, + input_size: int, + output_size: int, + params_dtype: torch.dtype, + weight_attrs: dict[str, Any] = None) -> None: if input_size % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -92,6 +97,14 @@ def create_weights(self, module: torch.nn.Module, input_size: int, ), requires_grad=False, ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + **weight_attrs, + }) lookup_table = Parameter( torch.empty( output_size, @@ -101,6 +114,11 @@ def create_weights(self, module: torch.nn.Module, input_size: int, ), requires_grad=False, ) + set_weight_attrs(lookup_table, { + "output_dim": 0, + **weight_attrs, + }) + module.register_parameter("qweight", qweight) module.register_parameter("lookup_table", lookup_table) From 8af8b609878c7cfd23551dcd69c3ff7e3c45058d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Wed, 8 Nov 2023 01:49:23 +0000 Subject: [PATCH 11/51] fix quantization --- .../layers/quantized_linear/squeezellm.py | 16 -------- .../model_executor/quantization_utils/base.py | 40 ------------------- 2 files changed, 56 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 7e8a646dc1294..a881281354943 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -52,22 +52,6 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) return cls(weight_bits) - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - return {"qweight": 0} - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - return ["qweight"] - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "lookup_table"] - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - return ["qweight"] - def get_linear_method(self) -> "SqueezeLLMLinearMethod": return SqueezeLLMLinearMethod(self) diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py index 97da60cd699cb..65e36d24264b6 100644 --- a/vllm/model_executor/quantization_utils/base.py +++ b/vllm/model_executor/quantization_utils/base.py @@ -46,45 +46,5 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: raise ValueError(f"Cannot find any of {keys} in the model's " "quantization config.") - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - """Returns a dictionary of packed tensor names and their pack dims.""" - raise NotImplementedError - - @classmethod - def get_packed_dim(cls, 
tensor_name: str) -> Optional[int]: - """Returns the pack dim of a tensor if it is packed. - - A tensor is considered packed if each element in the tensor is a - packed representation of multiple elements in the original tensor. - For example, an INT32 element in the tensor may represent 8 INT4 - elements in the original tensor. - If the tensor is not packed, returns None. - """ - packed_tensors = cls.get_packed_tensors() - for packed_tensor_name, pack_dim in packed_tensors.items(): - if packed_tensor_name in tensor_name: - return pack_dim - return None - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - raise NotImplementedError - - @classmethod - def is_transposed(cls, tensor_name: str) -> bool: - """Returns True if a tensor is transposed relative to nn.Linear.weight. - """ - return any(tag in tensor_name - for tag in cls.get_transposed_tensor_names()) - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - raise NotImplementedError - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - raise NotImplementedError - def get_linear_method(self) -> LinearMethodBase: raise NotImplementedError From 686dafbe71c65fa18f33c3427efa4636fcd95475 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 04:33:59 +0000 Subject: [PATCH 12/51] new weight loader --- vllm/model_executor/layers/linear.py | 228 +++++++++++++++++++++------ vllm/model_executor/models/llama.py | 61 +++---- vllm/model_executor/weight_utils.py | 43 +---- 3 files changed, 212 insertions(+), 120 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 44173c0643a7e..b3d5fb79a3a56 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import Any, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -16,6 +16,9 @@ divide, split_tensor_along_last_dim, ) +from vllm.logger import init_logger + +logger = init_logger(__name__) def set_weight_attrs(weight: torch.Tensor, weight_attrs: Optional[dict[str, @@ -31,21 +34,15 @@ def set_weight_attrs(weight: torch.Tensor, weight_attrs: Optional[dict[str, class LinearMethodBase(ABC): @abstractmethod - def create_weights(self, - module: torch.nn.Module, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - weight_attrs: dict[str, Any] = None) -> None: - del module + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: raise NotImplementedError @abstractmethod def apply_weights(self, - module: torch.nn.Module, + weights: Dict[str, torch.Tensor], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - del module, x raise NotImplementedError @@ -54,34 +51,26 @@ class FullPrecisionLinearMethod(LinearMethodBase): def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add - def create_weights(self, - module: torch.nn.Module, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - weight_attrs: dict[str, Any] = None) -> None: + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: weight = Parameter(torch.empty(output_size, input_size, device=torch.cuda.current_device(), dtype=params_dtype), requires_grad=False) - set_weight_attrs(weight, { - "input_dim": 1, - "output_dim": 0, - **weight_attrs - }) - 
module.register_parameter("weight", weight) - #print("dir(module.weight):", dir(module.weight)) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + return {"weight": weight} def apply_weights(self, - module: torch.nn.Module, + weights: Dict[str, torch.Tensor], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = weights["weight"] if self.separate_bias_add: if bias: - return F.linear(x, module.weight) + bias - return F.linear(x, module.weight) - return F.linear(x, module.weight, bias) + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) class ReplicatedLinear(torch.nn.Module): @@ -107,8 +96,10 @@ def __init__( if linear_method is None: linear_method = FullPrecisionLinearMethod() self.linear_method = linear_method - self.linear_method.create_weights(self, self.input_size, - self.output_size, self.params_dtype) + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size, self.params_dtype) + for name, weight in self.linear_weights.items(): + self.register_parameter(name, weight) if bias: self.bias = Parameter( torch.empty(self.output_size, @@ -120,7 +111,7 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self, x, bias) + output = self.linear_method.apply_weights(self.linear_weights, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias @@ -152,7 +143,7 @@ def __init__( input_size: int, output_size: int, bias: bool = True, - gather_output: bool = True, + gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, linear_method: Optional[LinearMethodBase] = None, @@ -164,8 +155,8 @@ def __init__( self.output_size = output_size self.gather_output = gather_output # Divide the weight matrix along the last dimension. 
- self.tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, self.tp_size) + tp_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, tp_size) self.skip_bias_add = skip_bias_add if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -173,10 +164,11 @@ def __init__( if linear_method is None: linear_method = FullPrecisionLinearMethod() self.linear_method = linear_method - self.linear_method.create_weights(self, self.input_size, - self.output_size_per_partition, - self.params_dtype, - {"output_dim_parallel": True}) + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size_per_partition, self.params_dtype) + for name, weight in self.linear_weights.items(): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -184,11 +176,23 @@ def __init__( dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, - "output_dim_parallel": True + "weight_loader": self.weight_loader, }) else: self.register_parameter('bias', None) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + def forward(self, input_): """Forward of ColumnParallelLinear @@ -202,7 +206,8 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights(self, input_, bias) + output_parallel = self.linear_method.apply_weights( + self.linear_weights, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -212,6 +217,125 @@ def forward(self, input_): return output, output_bias +class PackedColumnParallelLinear(ColumnParallelLinear): + + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size, sum(output_sizes), bias, gather_output, + skip_bias_add, params_dtype, linear_method) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + loaded_shard_id: int): + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + param_data = param.data + output_dim = getattr(param, "output_dim", None) + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ packed_dim = getattr(param, "packed_dim", None) + if packed_dim is not None: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "PackedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class QKVParallelLinear(ColumnParallelLinear): + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. + tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(self.total_num_kv_heads, + tp_size) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + super().__init__(input_size, output_size, bias, False, skip_bias_add, + params_dtype, linear_method) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, + loaded_shard_name: str): + assert loaded_shard_name in ["q", "k", "v"] + tp_rank = get_tensor_model_parallel_rank() + param_data = param.data + output_dim = getattr(param, "output_dim", None) + if output_dim is not None: + if loaded_shard_name == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_name == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_name == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim is not None: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + class RowParallelLinear(torch.nn.Module): """Linear layer with row parallelism. 
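# ---------------------------------------------------------------------------
# Editor's note: a minimal usage sketch of the per-parameter weight-loading
# scheme this patch introduces; it is not part of the diff above. It assumes
# the tensor-model-parallel group is already initialized (world size 1 here)
# and that a CUDA device is available, since create_weights() allocates on the
# current device. The `checkpoint` dict and the 4096/32/128 sizes are
# hypothetical placeholders, not values taken from the patch.
qkv = QKVParallelLinear(hidden_size=4096,
                        head_size=128,
                        total_num_heads=32,
                        total_num_kv_heads=32,
                        bias=False)
params = dict(qkv.named_parameters())
checkpoint = {
    "q_proj.weight": torch.empty(4096, 4096),
    "k_proj.weight": torch.empty(4096, 4096),
    "v_proj.weight": torch.empty(4096, 4096),
}
for ckpt_name, shard_id in (("q_proj.weight", "q"),
                            ("k_proj.weight", "k"),
                            ("v_proj.weight", "v")):
    param = params["weight"]
    # Every tensor returned by create_weights() carries a `weight_loader`
    # attribute; for a stacked QKV weight it narrows the destination to the
    # q/k/v slice (adjusting for pack_factor when quantized), narrows the
    # checkpoint tensor to this rank's shard, and copies it in place.
    param.weight_loader(param, checkpoint[ckpt_name], shard_id)
# --------------------------- end editor's sketch ---------------------------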
@@ -245,7 +369,7 @@ def __init__( input_size: int, output_size: int, bias: bool = True, - input_is_parallel: bool = False, + input_is_parallel: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, @@ -268,9 +392,11 @@ def __init__( if linear_method is None: linear_method = FullPrecisionLinearMethod() self.linear_method = linear_method - self.linear_method.create_weights(self, self.input_size_per_partition, - self.output_size, self.params_dtype, - {"input_dim_parallel": True}) + self.linear_weights = self.linear_method.create_weights( + self.input_size_per_partition, self.output_size, self.params_dtype) + for name, weight in self.linear_weights.items(): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if not reduce_results and (bias and not skip_bias_add): raise ValueError('When not reduce the results, adding bias to the ' @@ -283,11 +409,23 @@ def __init__( dtype=params_dtype)) set_weight_attrs(self.bias, { "output_dim": 0, - "input_dim_parallel": True + "weight_loader": self.weight_loader, }) else: self.register_parameter('bias', None) + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + input_dim = getattr(param, "input_dim", None) + param_data = param.data + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + def forward(self, input_): """Forward of RowParallelLinear @@ -311,7 +449,7 @@ def forward(self, input_): # Matrix multiply. output_parallel = self.linear_method.apply_weights( - self, input_parallel) + self.linear_weights, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 76926ebfe5ce3..f437b6732bff8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -36,6 +36,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearMethodBase, + QKVParallelLinear, + PackedColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler @@ -62,11 +64,11 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - linear_method=linear_method) + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + gather_output=False, + linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, @@ -112,7 +114,6 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -120,13 +121,12 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + - 2 * self.total_num_kv_heads * num_kv_heads_replicas) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( @@ -317,29 +317,16 @@ def load_weights(self, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - num_kv_heads_replicas = max(1, - tp_size // self.config.num_key_value_heads) - num_kv_heads_per_gpu = max(1, - self.config.num_key_value_heads // tp_size) - - q_proj_size = self.config.hidden_size - kv_proj_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads) - ffn_intermediate_size = self.config.intermediate_size stacked_params_mapping = [ - # (param_name, shard_name, slice_size, offset) - ("qkv_proj", "q_proj", q_proj_size, 0), - ("qkv_proj", "k_proj", kv_proj_size, q_proj_size), - ("qkv_proj", "v_proj", kv_proj_size, q_proj_size + kv_proj_size), - ("gate_up_proj", "gate_proj", ffn_intermediate_size, 0), - ("gate_up_proj", "up_proj", ffn_intermediate_size, - ffn_intermediate_size), + # (param_name, shard_name, shard_idx) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = dict(self.named_parameters()) + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): @@ -347,22 +334,18 @@ def load_weights(self, continue loaded = False - for (param_name, weight_name, slice_size, - slice_offset) in stacked_params_mapping: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, param_name)] - # TODO: fix the case when num kv heads < tp size - load_tensor_parallel_weights(param, - loaded_weight, - output_slice_offset=slice_offset, - output_slice_size=slice_size) + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) loaded = True break if loaded: continue - param = state_dict[name] + param = params_dict[name] if "embed_tokens" in name or "lm_head" in name: load_padded_tensor_parallel_vocab(param, loaded_weight) continue diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 71ec29a313008..3d1bd16ca59e6 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -239,7 +239,7 @@ def hf_model_weights_iterator( with safe_open(st_file, framework="pt") as f: for name in f.keys(): param = f.get_slice(name) - yield name, param + yield name, convert_pyslice_to_tensor(param) else: for bin_file in hf_weights_files: state = torch.load(bin_file, map_location="cpu") @@ -278,44 +278,15 
@@ def load_padded_tensor_parallel_vocab( def load_tensor_parallel_weights( - param: torch.Tensor, - loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` - output_slice_offset: Optional[int] = None, - output_slice_size: Optional[int] = None, + param: torch.Tensor, + loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` ) -> None: - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() param_data = param.data - if output_slice_offset is not None and output_slice_size is not None: - output_dim = param.output_dim - # Adjust the offset and size to account for tensor parallelism. - output_slice_offset //= tp_size - output_slice_size //= tp_size - # If quantized, we need to adjust the offset and size to account for - # the packing. - packed_dim = getattr(param, "packed_dim", None) - if packed_dim == output_dim: - pack_factor = param.pack_factor - output_slice_offset //= pack_factor - output_slice_size //= pack_factor - param_data = param_data.narrow(output_dim, output_slice_offset, - output_slice_size) - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - if getattr(param, "output_dim_parallel", False): - output_dim = getattr(param, "output_dim", None) - if output_dim is not None: - shard_size = param_data.shape[output_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(output_dim, start_idx, - shard_size) - if getattr(param, "input_dim_parallel", False): - input_dim = getattr(param, "input_dim", None) - if input_dim is not None: - shard_size = param_data.shape[input_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(input_dim, start_idx, - shard_size) + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + weight_loader(param_data, loaded_weight) + return assert param_data.shape == loaded_weight.shape, ( f"shape mismatch between model and checkpoint: " From e4740204fe24f843abf3e74832a1564202f5af05 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:04:35 +0000 Subject: [PATCH 13/51] Fix vocab loading --- vllm/model_executor/layers/linear.py | 19 ++------ .../layers/vocab_parallel_embedding.py | 46 ++++++++++++++++--- vllm/model_executor/models/llama.py | 22 ++++----- vllm/model_executor/utils.py | 13 ++++++ 4 files changed, 65 insertions(+), 35 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index b3d5fb79a3a56..36af744f49fb2 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -7,30 +7,17 @@ from torch.nn.parameter import Parameter from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) from vllm.model_executor.parallel_utils.utils import ( - divide, - split_tensor_along_last_dim, -) + divide, split_tensor_along_last_dim) +from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger logger = init_logger(__name__) -def set_weight_attrs(weight: torch.Tensor, weight_attrs: Optional[dict[str, - Any]]): - if weight_attrs is None: - return - for key, value in weight_attrs.items(): - assert not hasattr( - weight, key), (f"Overwriting existing tensor attribute: {key}") - setattr(weight, key, value) - - class LinearMethodBase(ABC): 
@abstractmethod diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index f7cf8bd52dafa..3c7c4752df448 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Sequence import torch import torch.nn.functional as F @@ -8,10 +8,29 @@ get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) +from vllm.model_executor.parallel_utils.utils import divide from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) +from vllm.model_executor.utils import set_weight_attrs -from vllm.model_executor.parallel_utils.utils import VocabUtility + +def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, + rank: int) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, + world_size: int) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank) class VocabParallelEmbedding(torch.nn.Module): @@ -33,25 +52,31 @@ def __init__(self, # Keep the input dimensions. self.num_embeddings = num_embeddings + self.num_embeddings_padded = pad_vocab_size(num_embeddings) self.embedding_dim = embedding_dim if params_dtype is None: params_dtype = torch.get_default_dtype() - self.tp_size = get_tensor_model_parallel_world_size() - # TODO: Handle vocab padding here. # Divide the weight matrix along the vocaburaly dimension. self.vocab_start_index, self.vocab_end_index = ( - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), + vocab_range_from_global_vocab_size( + self.num_embeddings_padded, get_tensor_model_parallel_rank(), self.tp_size)) self.num_embeddings_per_partition = (self.vocab_end_index - self.vocab_start_index) - self.weight = Parameter( torch.empty(self.num_embeddings_per_partition, self.embedding_dim, device=torch.cuda.current_device(), dtype=params_dtype)) + set_weight_attrs(self.weight, {"parallel_dim": 0}) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + assert loaded_weight.shape[0] == self.num_embeddings + assert loaded_weight.shape[1] == self.embedding_dim + loaded_weight = loaded_weight[self.vocab_start_index:self. + vocab_end_index] + param[:self.num_embeddings].data.copy_(loaded_weight) def forward(self, input_): if self.tp_size > 1: @@ -71,3 +96,10 @@ def forward(self, input_): # Reduce across all the model parallel GPUs. 
output = tensor_model_parallel_all_reduce(output_parallel) return output + + +class ParallelLMHead(VocabParallelEmbedding): + # TODO: Add docstring + def forward(self, input_): + del input_ + raise RuntimeError("LMHead's weight should be used in the sampler.") diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f437b6732bff8..9e2c8a76f5ca9 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -42,7 +42,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import QuantizationConfig @@ -289,13 +289,7 @@ def __init__( else: linear_method = None self.model = LlamaModel(config, linear_method) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - # NOTE: The LM head is not quantized. - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - linear_method=None) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -346,7 +340,11 @@ def load_weights(self, continue param = params_dict[name] - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight) - continue - load_tensor_parallel_weights(param, loaded_weight) + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is not None: + weight_loader(param, loaded_weight) + else: + assert param.shape == loaded_weight.shape, ( + f"shape mismatch between model and checkpoint: " + f"{param.shape} != {loaded_weight.shape}") + param.data.copy_(loaded_weight) diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index bd74ae96aa19e..501906416b56b 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,5 +1,6 @@ """Utils for model executor.""" import random +from typing import Any, Optional import numpy as np import torch @@ -11,3 +12,15 @@ def set_random_seed(seed: int) -> None: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Optional[dict[str, Any]], +): + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key), (f"Overwriting existing tensor attribute: {key}") + setattr(weight, key, value) From d10761360cc2562177bff311879ff58fbbf6a6d1 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:22:18 +0000 Subject: [PATCH 14/51] clean up llama loader --- vllm/model_executor/models/llama.py | 32 ++++++++--------------------- vllm/model_executor/weight_utils.py | 27 ++++++++---------------- 2 files changed, 18 insertions(+), 41 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9e2c8a76f5ca9..2bf6c9d8680bd 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -34,8 +34,7 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import 
(ColumnParallelLinear, - LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearMethodBase, QKVParallelLinear, PackedColumnParallelLinear, RowParallelLinear) @@ -44,11 +43,10 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -312,39 +310,27 @@ def load_weights(self, load_format: str = "auto", revision: Optional[str] = None): stacked_params_mapping = [ - # (param_name, shard_name, shard_idx) + # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - loaded = False for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) - loaded = True break - if loaded: - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", None) - if weight_loader is not None: - weight_loader(param, loaded_weight) else: - assert param.shape == loaded_weight.shape, ( - f"shape mismatch between model and checkpoint: " - f"{param.shape} != {loaded_weight.shape}") - param.data.copy_(loaded_weight) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 3d1bd16ca59e6..cc75cf58ebc1e 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -264,34 +264,25 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: return x +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() + param.data.copy_(loaded_weight) + + def load_padded_tensor_parallel_vocab( param: torch.Tensor, loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` ) -> None: - tp_rank = get_tensor_model_parallel_rank() - shard_size = param.shape[0] - start_idx = tp_rank * shard_size - end_idx = (tp_rank + 1) * shard_size - loaded_weight = loaded_weight[start_idx:end_idx] - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - param[:loaded_weight.shape[0]].data.copy_(loaded_weight) + raise NotImplementedError() def load_tensor_parallel_weights( param: torch.Tensor, loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` ) -> None: - param_data = param.data - - weight_loader = getattr(param, "weight_loader", None) - if weight_loader is not None: - weight_loader(param_data, loaded_weight) - return - - assert param_data.shape == 
loaded_weight.shape, ( - f"shape mismatch between model and checkpoint: " - f"{param_data.shape} != {loaded_weight.shape}") - param_data.copy_(loaded_weight) + raise NotImplementedError() def initialize_dummy_weights( From d4aa8c97738815e4da95ae5d53118617a18f6a9b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:32:19 +0000 Subject: [PATCH 15/51] fix awq --- .../layers/quantized_linear/awq.py | 29 +++++++++---------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 4c1707222e435..256bd84fd89bc 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -72,12 +72,8 @@ class AWQLinearMethod(LinearMethodBase): def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config - def create_weights(self, - module: torch.nn.Module, - input_size: int, - output_size: int, - params_dtype: torch.dtype, - weight_attrs: dict[str, Any] = None) -> None: + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: if input_size % self.quant_config.group_size != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -104,7 +100,6 @@ def create_weights(self, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - **weight_attrs, }) qzeros = Parameter( torch.empty( @@ -121,7 +116,6 @@ def create_weights(self, "output_dim": 1, "packed_dim": 1, "pack_factor": self.quant_config.pack_factor, - **weight_attrs, }) scales = Parameter( torch.empty( @@ -135,21 +129,24 @@ def create_weights(self, set_weight_attrs(scales, { "input_dim": 0, "output_dim": 1, - **weight_attrs, }) - module.register_parameter("qweight", qweight) - module.register_parameter("qzeros", qzeros) - module.register_parameter("scales", scales) + return { + "qweight": qweight, + "qzeros": qzeros, + "scales": scales, + } def apply_weights(self, - module: torch.nn.Module, + weights: Dict[str, torch.Tensor], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + qzeros = weights["qzeros"] + scales = weights["scales"] pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (module.qweight.shape[-1] * pack_factor, )) + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, module.qweight, - module.scales, module.qzeros, + out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros, pack_factor) if bias is not None: out = out + bias From f48381bce9e9a8af1ab1544f9e39210134d078c6 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:42:45 +0000 Subject: [PATCH 16/51] wip fix squeezellm --- .../layers/quantized_linear/squeezellm.py | 27 +++++++++---------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index a881281354943..8bf0c232eac83 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -61,12 +61,8 @@ class SqueezeLLMLinearMethod(LinearMethodBase): def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config - def create_weights(self, - module: torch.nn.Module, - input_size: int, - output_size: int, - params_dtype: 
torch.dtype, - weight_attrs: dict[str, Any] = None) -> None: + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: if input_size % self.quant_config.pack_factor != 0: raise ValueError( "The input size is not aligned with the quantized " @@ -87,7 +83,6 @@ def create_weights(self, "output_dim": 1, "packed_dim": 0, "pack_factor": self.quant_config.pack_factor, - **weight_attrs, }) lookup_table = Parameter( torch.empty( @@ -100,22 +95,24 @@ def create_weights(self, ) set_weight_attrs(lookup_table, { "output_dim": 0, - **weight_attrs, }) - - module.register_parameter("qweight", qweight) - module.register_parameter("lookup_table", lookup_table) + return { + "qweight": qweight, + "lookup_table": lookup_table, + } def apply_weights(self, - module: torch.nn.Module, + weights: Dict[str, torch.Tensor], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: - out_shape = x.shape[:-1] + (module.qweight.shape[-1], ) + qweight = weights["qweight"] + lookup_table = weights["lookup_table"] + out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) # NOTE: The output tensor should be zero-initialized. out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, module.qweight, out, - module.lookup_table) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, + lookup_table) if bias is not None: out = out + bias From c5a9f9c42bac1727e887d1c4b2bba05c56f9c10f Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:47:48 +0000 Subject: [PATCH 17/51] fix squeeze llm --- vllm/model_executor/layers/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 36af744f49fb2..f7965905e4acb 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -235,7 +235,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, # If quantized, we need to adjust the offset and size to account # for the packing. packed_dim = getattr(param, "packed_dim", None) - if packed_dim is not None: + if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor param_data = param_data.narrow(output_dim, shard_offset, @@ -305,7 +305,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, # If quantized, we need to adjust the offset and size to account # for the packing. 
packed_dim = getattr(param, "packed_dim", None) - if packed_dim is not None: + if packed_dim == output_dim: shard_size = shard_size // param.pack_factor shard_offset = shard_offset // param.pack_factor param_data = param_data.narrow(output_dim, shard_offset, From 92155dad87bbeed7b8fb8d90992fbae9a9808165 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:53:43 +0000 Subject: [PATCH 18/51] fix weight loader for embedding --- vllm/model_executor/layers/vocab_parallel_embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 3c7c4752df448..697a40d9ae9c8 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -69,7 +69,8 @@ def __init__(self, self.embedding_dim, device=torch.cuda.current_device(), dtype=params_dtype)) - set_weight_attrs(self.weight, {"parallel_dim": 0}) + set_weight_attrs(self.weight, {"parallel_dim": 0, + "weight_loader": self.weight_loader}) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert loaded_weight.shape[0] == self.num_embeddings From e528dbcc24c19f944dd259094861b9962bf33481 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 05:56:40 +0000 Subject: [PATCH 19/51] fix --- vllm/model_executor/parallel_utils/utils.py | 22 --------------------- 1 file changed, 22 deletions(-) diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py index 004a81f130a76..6982c621f113d 100644 --- a/vllm/model_executor/parallel_utils/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -46,25 +46,3 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list - - -class VocabUtility: - """ Split the vocabulary into `world_size` chunks and return the first - and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last) - - """ - - @staticmethod - def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank: int) -> Sequence[int]: - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - @staticmethod - def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, - world_size: int) -> Sequence[int]: - per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank) From 772ab72965c8049941d1af5df22518b23057677e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 06:04:18 +0000 Subject: [PATCH 20/51] support mistral --- vllm/model_executor/models/llama.py | 2 +- vllm/model_executor/models/mistral.py | 167 +++++--------------- vllm/model_executor/parallel_utils/utils.py | 3 + 3 files changed, 46 insertions(+), 126 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2bf6c9d8680bd..2e6fc7ef517fe 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -35,8 +35,8 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, PackedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.attention import 
PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index d8e440123cc4d..6d1e2baec77c0 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -34,19 +34,19 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, +from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding) + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -62,11 +62,11 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ColumnParallelLinear(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - linear_method=linear_method) + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + gather_output=False, + linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, @@ -101,8 +101,15 @@ def __init__(self, assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -110,12 +117,12 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, linear_method=linear_method, ) self.o_proj = RowParallelLinear( @@ -273,12 +280,7 @@ def __init__( else: linear_method = None self.model = MistralModel(config, linear_method) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - # NOTE: The LM head is not quantized. - self.lm_head = ColumnParallelLinear(config.hidden_size, - vocab_size, - bias=False, - gather_output=False) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -295,118 +297,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - col_weight_suffixes = ["weight"] - row_weight_suffixes = ["weight"] - else: - col_weight_suffixes = ( - self.quant_config.get_col_parallel_tensor_names()) - row_weight_suffixes = ( - self.quant_config.get_row_parallel_tensor_names()) - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in col_weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in row_weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - - tp_size = get_tensor_model_parallel_world_size() - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - packed_dim = None - is_transposed = False - if self.quant_config is not None: - packed_dim = self.quant_config.get_packed_dim(name) - is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if 
is_transposed: - param = param.T - - if packed_dim is not None: - shard_dim = 0 if not is_transposed else 1 - if packed_dim == shard_dim: - shard_size //= self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - if is_transposed: - param = param.T - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py index 6982c621f113d..2e8a15ee3f53f 100644 --- a/vllm/model_executor/parallel_utils/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -46,3 +46,6 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list + +class VocabUtility: + pass From 0a08e66c763e555dac404b791d4158b1aed245a0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 06:04:18 +0000 Subject: [PATCH 21/51] fix --- vllm/model_executor/models/mistral.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 6d1e2baec77c0..541f89b63fed7 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -228,9 +228,8 @@ def __init__( self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ From 7d7aa4b699e483ab8475442acb3ced52629d8c86 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 08:15:00 +0000 Subject: [PATCH 22/51] Fix aqulia --- .../layers/vocab_parallel_embedding.py | 6 +- vllm/model_executor/models/aquila.py | 169 +++++++----------- vllm/model_executor/models/llama.py | 6 +- vllm/model_executor/models/mistral.py | 2 +- vllm/model_executor/parallel_utils/utils.py | 1 + 5 files changed, 77 insertions(+), 107 
deletions(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 697a40d9ae9c8..b75b76237ca62 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -69,8 +69,10 @@ def __init__(self, self.embedding_dim, device=torch.cuda.current_device(), dtype=params_dtype)) - set_weight_attrs(self.weight, {"parallel_dim": 0, - "weight_loader": self.weight_loader}) + set_weight_attrs(self.weight, { + "parallel_dim": 0, + "weight_loader": self.weight_loader + }) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert loaded_weight.shape[0] == self.num_embeddings diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 4dae9e46dad7d..7c4ed7f6da513 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -33,15 +33,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - hf_model_weights_iterator, load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.aquila import AquilaConfig @@ -55,20 +58,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, gather_output=False, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -111,6 +113,7 @@ def __init__( rope_theta: float = 10000, max_position_embeddings: int = 8192, rope_scaling: Optional[Dict[str, Any]] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -128,29 +131,30 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, + linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, self.head_dim, self.scaling, - rotary_dim=self.head_dim, base=self.rope_theta, max_position=self.max_position_embeddings, + rotary_dim=self.head_dim, num_kv_heads=self.num_kv_heads, - rope_scaling=rope_scaling, - ) + rope_scaling=rope_scaling) def forward( self, @@ -171,7 +175,11 @@ def forward( class AquilaDecoderLayer(nn.Module): - def __init__(self, config: AquilaConfig): + def __init__( + self, + config: AquilaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -185,11 +193,13 @@ def __init__(self, config: AquilaConfig): rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, rope_scaling=rope_scaling, + linear_method=linear_method, ) self.mlp = AquilaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + linear_method=linear_method, ) self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -226,19 +236,22 @@ def forward( class AquilaModel(nn.Module): - def __init__(self, config: AquilaConfig): + def __init__( + self, + config: AquilaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - #vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers) + AquilaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) ]) self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -271,17 +284,19 @@ def forward( class AquilaForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.model = AquilaModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear( - config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - ) + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = AquilaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -298,79 +313,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" - ] - 
_row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_size = get_tensor_model_parallel_world_size() - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2e6fc7ef517fe..b72379d9237ff 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -33,12 +33,12 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from 
vllm.model_executor.layers.linear import (LinearMethodBase, PackedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) @@ -234,10 +234,8 @@ def __init__( self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 541f89b63fed7..4587432c3ac3d 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -33,12 +33,12 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, PackedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding, ParallelLMHead) diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py index 2e8a15ee3f53f..8244bcb02535f 100644 --- a/vllm/model_executor/parallel_utils/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -47,5 +47,6 @@ def split_tensor_along_last_dim( return tensor_list + class VocabUtility: pass From 1df5d6bab807f21eddce9375d6333c1277cb5f4e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 21:31:17 +0000 Subject: [PATCH 23/51] fix vocab loader --- vllm/model_executor/layers/vocab_parallel_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b75b76237ca62..7dc2d1f824316 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -79,7 +79,7 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): assert loaded_weight.shape[1] == self.embedding_dim loaded_weight = loaded_weight[self.vocab_start_index:self. 
vocab_end_index] - param[:self.num_embeddings].data.copy_(loaded_weight) + param[:loaded_weight.shape[0]].data.copy_(loaded_weight) def forward(self, input_): if self.tp_size > 1: From 93685f47e2b685e9d7e0aa7f7b54f53d85d0f437 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 22:16:58 +0000 Subject: [PATCH 24/51] fix baichuan --- vllm/model_executor/layers/linear.py | 46 ++++++-- vllm/model_executor/models/baichuan.py | 157 +++++++++++-------------- 2 files changed, 108 insertions(+), 95 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f7965905e4acb..26efa0e20cfb4 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -259,7 +259,7 @@ def __init__( hidden_size: int, head_size: int, total_num_heads: int, - total_num_kv_heads: int, + total_num_kv_heads: Optional[int] = None, bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, @@ -268,6 +268,8 @@ def __init__( self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads self.total_num_kv_heads = total_num_kv_heads # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() @@ -285,20 +287,48 @@ def __init__( super().__init__(input_size, output_size, bias, False, skip_bias_add, params_dtype, linear_method) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_name: str): - assert loaded_shard_name in ["q", "k", "v"] - tp_rank = get_tensor_model_parallel_rank() + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): param_data = param.data output_dim = getattr(param, "output_dim", None) + if loaded_shard_id is None: + # Loaded weight is already packed. + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. 
+ if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] if output_dim is not None: - if loaded_shard_name == "q": + if loaded_shard_id == "q": shard_offset = 0 shard_size = self.num_heads * self.head_size - elif loaded_shard_name == "k": + elif loaded_shard_id == "k": shard_offset = self.num_heads * self.head_size shard_size = self.num_kv_heads * self.head_size - elif loaded_shard_name == "v": + elif loaded_shard_id == "v": shard_offset = (self.num_heads + self.num_kv_heads) * self.head_size shard_size = self.num_kv_heads * self.head_size diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 7d0454271a799..27bf4c0ace18d 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -30,18 +30,21 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import (PagedAttentionWithRoPE, PagedAttentionWithALiBi) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig @@ -80,20 +83,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, gather_output=False, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -116,6 +118,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -131,11 +134,13 @@ def __init__( self.max_position_embeddings = max_position_embeddings # pylint: disable=invalid-name - self.W_pack = ColumnParallelLinear( + self.W_pack = QKVParallelLinear( hidden_size, - 3 * hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, bias=False, - gather_output=False, + linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, @@ -188,7 +193,10 @@ def forward( class BaiChuanDecoderLayer(nn.Module): - def __init__(self, config: BaiChuanConfig, position_embedding: str): + def __init__(self, + config: BaiChuanConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -200,11 +208,13 @@ def __init__(self, config: BaiChuanConfig, position_embedding: str): position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + linear_method=linear_method, ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -241,7 +251,10 @@ def forward( class BaiChuanModel(nn.Module): - def __init__(self, config: BaiChuanConfig, position_embedding: str): + def __init__(self, + config: BaiChuanConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -252,7 +265,7 @@ def __init__(self, config: BaiChuanConfig, position_embedding: str): config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding) + BaiChuanDecoderLayer(config, position_embedding, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -285,16 +298,19 @@ def forward( class BaiChuanBaseForCausalLM(nn.Module): - def __init__(self, config, position_embedding: str): + def __init__(self, + config, + position_embedding: str, + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = BaiChuanModel(config, position_embedding) - self.lm_head = ColumnParallelLinear( - config.hidden_size, - config.vocab_size, - bias=False, - gather_output=False, - ) + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = BaiChuanModel(config, position_embedding, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -311,79 +327,46 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_world_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + stacked_params_mapping = [ + # 
(param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - - if "W_pack" in name: - total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tp_world_size - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads - - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights( - param, - loaded_weight, - name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank, - ) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b - def __init__(self, config): - super().__init__(config, "ALIBI") + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None): + super().__init__(config, "ALIBI", quant_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b - def __init__(self, config): - super().__init__(config, "ROPE") + def __init__(self, + config, + quant_config: Optional[QuantizationConfig] = None): + super().__init__(config, "ROPE", quant_config) From 5f5ea902839abcabd13c89c934da72ae098e9382 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 9 Nov 2023 23:33:22 +0000 Subject: [PATCH 25/51] fix bloom --- vllm/model_executor/models/bloom.py | 134 +++++++++++++++------------- 1 file changed, 74 insertions(+), 60 deletions(-) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index f3bb17655c5b3..e993871082809 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -30,14 +30,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils 
import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -70,7 +74,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: class BloomAttention(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size self.total_num_heads = config.n_head @@ -81,17 +89,19 @@ def __init__(self, config: BloomConfig): assert self.total_num_heads % tp_world_size == 0 self.num_heads = self.total_num_heads // tp_world_size - self.query_key_value = ColumnParallelLinear( + self.query_key_value = QKVParallelLinear( self.hidden_size, - 3 * self.hidden_size, + self.head_dim, + self.total_num_heads, bias=True, - gather_output=False, + linear_method=linear_method, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, input_is_parallel=True, + linear_method=linear_method, ) # Create the alibi slopes and slice them. @@ -125,19 +135,25 @@ def forward( class BloomMLP(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.dense_h_to_4h = ColumnParallelLinear( hidden_size, 4 * hidden_size, gather_output=False, + linear_method=linear_method, ) self.act = get_act_fn("gelu") self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, input_is_parallel=True, + linear_method=linear_method, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -149,16 +165,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BloomBlock(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config) + self.self_attention = BloomAttention(config, linear_method) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) - self.mlp = BloomMLP(config) + self.mlp = BloomMLP(config, linear_method) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm) @@ -203,7 +223,11 @@ def forward( class BloomModel(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.embed_dim = config.hidden_size @@ -216,8 +240,10 @@ def __init__(self, config: BloomConfig): self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks - self.h = nn.ModuleList( - [BloomBlock(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + BloomBlock(config, linear_method) + for _ in range(config.num_hidden_layers) 
+ ]) # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -251,12 +277,19 @@ def forward( class BloomForCausalLM(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.transformer = BloomModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.transformer = BloomModel(config, linear_method) self.lm_head_weight = self.transformer.word_embeddings.weight self.sampler = Sampler(config.vocab_size) @@ -274,55 +307,36 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias" - ] - _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if name == "lm_head.weight": - # Since hidden_states are parallelized, we need to - # load lm_head.weight in parallel. - self._column_parallel_weights.append(name) - # If lm_head is provided, use it instead. - param = self.lm_head_weight - else: - if not name.startswith("transformer."): - name = "transformer." + name - param = state_dict[name] + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] if "query_key_value" in name: - # NOTE(woosuk): BLOOM's fused QKV has the shape of - # [num_heads * 3 * head_size, hidden_size], while the - # required shape is [3 * num_heads * head_size, hidden_size]. + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). # Thus, we need weight conversion. 
- shard_size = param.shape[0] - start = shard_size * tp_rank - end = shard_size * (tp_rank + 1) - loaded_weight = loaded_weight[start:end] - + output_dim = getattr(param, "output_dim", None) num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // num_heads - if "query_key_value.weight" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size, - hidden_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif "query_key_value.bias" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected weight name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 31af3ea3f266a1af98ed4a8c56d6fd621f0be6b0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 00:08:55 +0000 Subject: [PATCH 26/51] fix qwen --- vllm/model_executor/models/baichuan.py | 1 + vllm/model_executor/models/qwen.py | 203 +++++++++++-------------- 2 files changed, 87 insertions(+), 117 deletions(-) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 27bf4c0ace18d..4571ab30431aa 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -147,6 +147,7 @@ def __init__( hidden_size, bias=False, input_is_parallel=True, + linear_method=linear_method, ) # Create the alibi slopes and slice them. 
if self.postion_embedding == "ALIBI": diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index bd5280b35cc34..014e954969c07 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -15,24 +15,20 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights, -) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear, -) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.qwen import QWenConfig @@ -46,20 +42,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str = "silu", + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, gather_output=False, - ) - self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.c_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -74,12 +69,15 @@ def forward(self, x): class QWenAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - max_position_embeddings: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None): + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = hidden_size tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( @@ -90,18 +88,19 @@ def __init__(self, tensor_model_parallel_world_size) self.head_dim = hidden_size // self.total_num_heads - # pylint: disable=invalid-name - self.c_attn = ColumnParallelLinear( + self.c_attn = QKVParallelLinear( hidden_size, - 3 * hidden_size, + self.head_dim, + self.total_num_heads, bias=True, - gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, + linear_method=linear_method, ) self.scaling = self.head_dim**-0.5 self.attn = PagedAttentionWithRoPE( @@ -134,7 +133,11 @@ def forward( class QWenBlock(nn.Module): - def __init__(self, config: QWenConfig): + def __init__( + self, + config: QWenConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -144,11 +147,14 @@ def __init__(self, config: QWenConfig): config.num_attention_heads, config.max_position_embeddings, rope_theta=rope_theta, - rope_scaling=rope_scaling) + rope_scaling=rope_scaling, + linear_method=linear_method) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2) + self.mlp = QWenMLP(config.hidden_size, + config.intermediate_size // 2, + linear_method=linear_method) def forward( self, @@ -180,18 +186,23 @@ def forward( class QWenModel(nn.Module): - def __init__(self, config: QWenConfig): + def __init__( + self, + config: QWenConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.vocab_size = config.vocab_size - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.wte = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) - self.h = nn.ModuleList( - [QWenBlock(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + QWenBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( @@ -222,17 +233,20 @@ def forward( class QWenLMHeadModel(nn.Module): - def __init__(self, config: QWenConfig): + def __init__( + self, + config: QWenConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.transformer = QWenModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear( - config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - ) + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.transformer = QWenModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -249,75 +263,30 @@ 
def forward( input_metadata) return next_tokens - _column_parallel_weights = [] - _row_parallel_weights = ["c_proj.weight"] - - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): - tp_world_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w2", 0), + ("gate_up_proj", "w1", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - - if "c_attn" in name: - total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tp_world_size - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads - - if "weight" in name: - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif "bias" in name: - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) - loaded_weight = loaded_weight[:, head_start:head_end, :] - loaded_weight = loaded_weight.reshape(-1) - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["w2", "w1"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - - if "wte" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights( - param, - loaded_weight, - name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank, - ) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 68f5a3f6445339f023ac8c8a0aa197ee54a11c10 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 00:09:17 +0000 Subject: [PATCH 27/51] fix qwen --- test_class.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) create mode 100644 test_class.py diff --git a/test_class.py b/test_class.py new file mode 100644 index 0000000000000..80d4f23975160 --- /dev/null +++ b/test_class.py @@ -0,0 +1,17 @@ +class A: + def __init__(self) -> None: + self.hello() + + def hello(self): + print("Hello, I'm A") + + +class B(A): + def __init__(self) -> None: + super().__init__() + self.hello() + + def hello(self): + print("Hello, I'm B") + +b = B() \ No newline at end of file 
From 4f68d07d9d42e5cc3b7b2ac7080a163dbb3f5de0 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 00:25:27 +0000 Subject: [PATCH 28/51] fix opt --- vllm/model_executor/models/opt.py | 135 +++++++++++++++++------------- 1 file changed, 76 insertions(+), 59 deletions(-) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 5295c73981856..7aeead7226e2f 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -30,14 +30,19 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (LinearMethodBase, + ColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -63,6 +68,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -74,17 +80,19 @@ def __init__( self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim**-0.5 - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( embed_dim, - 3 * embed_dim, + self.head_dim, + total_num_heads, bias=bias, - gather_output=False, + linear_method=linear_method, ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, bias=bias, input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -108,7 +116,11 @@ def forward( class OPTDecoderLayer(nn.Module): - def __init__(self, config: OPTConfig): + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -116,6 +128,7 @@ def __init__(self, config: OPTConfig): embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, + linear_method=linear_method, ) self.do_layer_norm_before = config.do_layer_norm_before self.activation_fn = get_act_fn(config.activation_function) @@ -128,12 +141,14 @@ def __init__(self, config: OPTConfig): config.ffn_dim, bias=config.enable_bias, gather_output=False, + linear_method=linear_method, ) self.fc2 = RowParallelLinear( config.ffn_dim, self.embed_dim, bias=config.enable_bias, input_is_parallel=True, + linear_method=linear_method, ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, @@ -177,7 +192,11 @@ def forward( class OPTDecoder(nn.Module): - def __init__(self, config: OPTConfig): + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.padding_idx = 
config.pad_token_id @@ -194,16 +213,18 @@ def __init__(self, config: OPTConfig): # Project out & in will be replicated if they exist. if config.word_embed_proj_dim != config.hidden_size: - self.project_out = nn.Linear(config.hidden_size, - config.word_embed_proj_dim, - bias=False) + self.project_out = ReplicatedLinear(config.hidden_size, + config.word_embed_proj_dim, + bias=False, + linear_method=linear_method) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: - self.project_in = nn.Linear(config.word_embed_proj_dim, - config.hidden_size, - bias=False) + self.project_in = ReplicatedLinear(config.word_embed_proj_dim, + config.hidden_size, + bias=False, + linear_method=linear_method) else: self.project_in = None @@ -218,8 +239,10 @@ def __init__(self, config: OPTConfig): else: self.final_layer_norm = None - self.layers = nn.ModuleList( - [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ + OPTDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) def forward( self, @@ -253,9 +276,13 @@ def forward( class OPTModel(nn.Module): - def __init__(self, config: OPTConfig): + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() - self.decoder = OPTDecoder(config) + self.decoder = OPTDecoder(config, linear_method) def forward( self, @@ -271,12 +298,19 @@ def forward( class OPTForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.model = OPTModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = OPTModel(config, linear_method) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.sampler = Sampler(config.vocab_size) @@ -294,48 +328,31 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "embed_tokens.weight", "fc1.weight", "fc1.bias" - ] - _row_parallel_weights = ["out_proj.weight", "fc2.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: continue - - if name.startswith("decoder."): - name = "model." 
+ name - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_attention_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 23099e2c08bea5013d86cd52fa93a02dfd18a213 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 00:55:28 +0000 Subject: [PATCH 29/51] fix mpt --- vllm/model_executor/models/mpt.py | 113 +++++++++++++++--------------- vllm/model_executor/models/opt.py | 4 +- 2 files changed, 59 insertions(+), 58 deletions(-) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 4a66c5b5dec6c..8c6237714ef30 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -10,15 +10,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -39,7 +42,11 @@ def _get_alibi_slopes( class MptAttention(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.d_model = config.d_model self.total_num_heads = config.n_heads @@ -49,11 +56,13 @@ def __init__(self, config: MptConfig): assert not config.attn_config.prefix_lm assert config.attn_config.alibi - self.qkv_proj = ColumnParallelLinear( + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( self.d_model, - 3 * self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, bias=not config.no_bias, - 
gather_output=False, + linear_method=linear_method, ) if self.qk_ln: self.q_ln = nn.LayerNorm(self.d_model) @@ -63,6 +72,7 @@ def __init__(self, config: MptConfig): self.d_model, bias=not config.no_bias, input_is_parallel=True, + linear_method=linear_method, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -91,7 +101,7 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: del position_ids # unused. - qkv, _ = self.qkv_proj(hidden_states) + qkv, _ = self.Wqkv(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.chunk(chunks=3, dim=-1) @@ -107,7 +117,11 @@ def forward( class MptMLP(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.d_model expansion_ratio = config.expansion_ratio @@ -117,6 +131,7 @@ def __init__(self, config: MptConfig): intermediate_size, bias=not config.no_bias, gather_output=False, + linear_method=linear_method, ) self.act = get_act_fn("gelu") self.down_proj = RowParallelLinear( @@ -124,6 +139,7 @@ def __init__(self, config: MptConfig): hidden_size, bias=not config.no_bias, input_is_parallel=True, + linear_method=linear_method, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -135,13 +151,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MptBlock(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MptAttention(config) + self.attn = MptAttention(config, linear_method) self.norm_2 = nn.LayerNorm(hidden_size) - self.ffn = MptMLP(config) + self.ffn = MptMLP(config, linear_method) def forward( self, @@ -168,7 +188,11 @@ def forward( class MptModel(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() assert config.embedding_fraction == 1.0 assert config.norm_type == "low_precision_layernorm" @@ -178,7 +202,7 @@ def __init__(self, config: MptConfig): config.d_model, ) self.blocks = nn.ModuleList( - [MptBlock(config) for _ in range(config.n_layers)]) + [MptBlock(config, linear_method) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -215,14 +239,21 @@ def forward( class MptForCausalLM(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config assert config.tie_word_embeddings + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None - self.transformer = MptModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.transformer = MptModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) @@ -240,45 +271,15 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"] - _row_parallel_weights = ["out_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = 
"auto", revision: Optional[str] = None): - tp_world_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): - if "Wqkv" in name: - # NOTE(woosuk): MPT's fused QKV has the shape of - # [3 * num_heads * head_size, hidden_size]. - # When tensor model parallelism is used, we need to shard - # the weight along the hidden dimension. - total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tp_world_size - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - if name.endswith(".weight"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif name.endswith(".bias"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) - loaded_weight = loaded_weight[:, head_start:head_end, :] - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected parameter name {name}") - name = name.replace("Wqkv", "qkv_proj") - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 7aeead7226e2f..482310a3b4c0a 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -30,8 +30,8 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention -from vllm.model_executor.layers.linear import (LinearMethodBase, - ColumnParallelLinear, +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) From d7d108d391c0fc643b2616b1e58525749b7b17ad Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 01:11:34 +0000 Subject: [PATCH 30/51] fix internlm --- test_class.py | 5 +- vllm/model_executor/models/internlm.py | 154 +++++++++++-------------- 2 files changed, 72 insertions(+), 87 deletions(-) diff --git a/test_class.py b/test_class.py index 80d4f23975160..5a62560a7d45b 100644 --- a/test_class.py +++ b/test_class.py @@ -1,4 +1,5 @@ class A: + def __init__(self) -> None: self.hello() @@ -7,6 +8,7 @@ def hello(self): class B(A): + def __init__(self) -> None: super().__init__() self.hello() @@ -14,4 +16,5 @@ def __init__(self) -> None: def hello(self): print("Hello, I'm B") -b = B() \ No newline at end of file + +b = B() diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 4a595a37730da..046f2ff6faeef 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -9,15 +9,18 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm 
+from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding) -from vllm.model_executor.weight_utils import ( - hf_model_weights_iterator, load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -30,20 +33,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, gather_output=False, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -65,6 +67,7 @@ def __init__( bias: bool, rope_theta: float = 10000, max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -79,17 +82,19 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - 3 * self.total_num_heads * self.head_dim, + self.head_dim, + self.total_num_heads, bias=bias, - gather_output=False, + linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -118,7 +123,11 @@ def forward( class InternLMDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -130,11 +139,13 @@ def __init__(self, config: LlamaConfig): bias=config.bias, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + linear_method=linear_method, ) self.mlp = InternLMMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -171,7 +182,11 @@ def forward( class InternLMModel(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() 
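# --- Illustrative note (not part of the patch) ---
# gate_up_proj and qkv_proj above are packed projections: several logical
# matrices stored in one fused weight, filled shard by shard via the
# (param_name, weight_name, shard_id) mapping in load_weights(). A rough
# single-rank sketch of that assembly (sizes and the helper name are
# assumptions; the real PackedColumnParallelLinear / QKVParallelLinear also
# split each shard across tensor-parallel ranks):
import torch

hidden_size, intermediate_size = 8, 16
fused = torch.empty(2 * intermediate_size, hidden_size)  # e.g. gate_up_proj.weight

def load_packed_shard(fused_weight, loaded_weight, shard_id):
    # shard_id 0 -> gate_proj slice, shard_id 1 -> up_proj slice.
    start = shard_id * intermediate_size
    fused_weight[start:start + intermediate_size].copy_(loaded_weight)

load_packed_shard(fused, torch.randn(intermediate_size, hidden_size), shard_id=0)
load_packed_shard(fused, torch.randn(intermediate_size, hidden_size), shard_id=1)
# --- end note ---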
self.config = config self.padding_idx = config.pad_token_id @@ -183,7 +198,7 @@ def __init__(self, config: LlamaConfig): config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config) + InternLMDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -216,17 +231,20 @@ def forward( class InternLMForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.model = InternLMModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear( - config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - ) + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = InternLMModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -243,69 +261,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" - ] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - if "embed_tokens" in name or "lm_head" in name: - param = state_dict[name] - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: - continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + 
weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From ed4415677b2baecff8a78e7db6fcf38783542d04 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 01:29:38 +0000 Subject: [PATCH 31/51] fix gpt2 --- vllm/model_executor/models/gpt2.py | 133 +++++++++++++---------------- 1 file changed, 57 insertions(+), 76 deletions(-) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index b9309eb956544..5af4b9f8f628d 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -30,15 +30,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -46,7 +49,11 @@ class GPT2Attention(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads @@ -57,17 +64,19 @@ def __init__(self, config: GPT2Config): self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim**-0.5 - self.c_attn = ColumnParallelLinear( + self.c_attn = QKVParallelLinear( self.hidden_size, - 3 * self.hidden_size, + self.head_dim, + total_num_heads, bias=True, - gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -95,6 +104,7 @@ def __init__( self, intermediate_size: int, config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() hidden_size = config.hidden_size @@ -103,12 +113,14 @@ def __init__( intermediate_size, bias=True, gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) @@ -121,16 +133,20 @@ def forward(self, hidden_states: torch.Tensor) -> 
torch.Tensor: class GPT2Block(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 * hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config) + self.attn = GPT2Attention(config, linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config) + self.mlp = GPT2MLP(inner_dim, config, linear_method) def forward( self, @@ -160,24 +176,23 @@ def forward( class GPT2Model(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config assert not config.add_cross_attention assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size - - # Optimization: While the vocab size of GPT-2 is 50257, we extend it - # to 50304 in order to make it divisible by 64. - # This improves performance since GPUs are faster if the dimension - # is divisible by 64. In addition, it allows us to shard the embedding - # layer across 2, 4, 8, or more GPUs. - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList( - [GPT2Block(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + GPT2Block(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,12 +222,19 @@ def forward( class GPT2LMHeadModel(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.transformer = GPT2Model(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.transformer = GPT2Model(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) @@ -230,19 +252,12 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = ["c_fc.weight", "c_fc.bias"] - _row_parallel_weights = ["c_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: @@ -253,53 +268,19 @@ def load_weights(self, # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue - if not name.startswith("transformer."): name = "transformer." 
+ name - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - + param = params_dict[name] # The HF's GPT-2 implementation uses Conv1D instead of Linear. # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: if conv1d_weight_name not in name: continue if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - param = state_dict[name] - - if name == "transformer.wte.weight": - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - # For the fused QKV linear layer, manually shard the weights. - if "c_attn" in name: - # GPT-2's fused QKV has the shape of - # [3 * num_heads * head_size, hidden_size]. - # When tensor parallelism is used, we shard the weights along - # the head dimension. - total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tensor_model_parallel_world_size - head_start = tensor_model_parallel_rank * num_heads - head_end = (tensor_model_parallel_rank + 1) * num_heads - if name.endswith(".weight"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif name.endswith(".bias"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) - loaded_weight = loaded_weight[:, head_start:head_end, :] - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected parameter name {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From a75dea106c4ff0ba9c5ce34264bfecb6b9ac52ac Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 01:47:32 +0000 Subject: [PATCH 32/51] fix gpt neox --- vllm/model_executor/models/gpt_neox.py | 132 ++++++++++++++----------- 1 file changed, 75 insertions(+), 57 deletions(-) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index d0187c93c541e..6e3996dee665b 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -29,14 +29,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + 
hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -44,7 +48,11 @@ class GPTNeoXAttention(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -56,15 +64,17 @@ def __init__(self, config: GPTNeoXConfig): self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) - self.query_key_value = ColumnParallelLinear( + self.query_key_value = QKVParallelLinear( config.hidden_size, - 3 * config.hidden_size, - gather_output=False, + self.head_size, + self.total_num_heads, + linear_method=linear_method, ) self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, input_is_parallel=True, + linear_method=linear_method, ) scaling = self.head_size**-0.5 @@ -100,17 +110,23 @@ def forward( class GPTNeoXMLP(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.intermediate_size, gather_output=False, + linear_method=linear_method, ) self.dense_4h_to_h = RowParallelLinear( config.intermediate_size, config.hidden_size, input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.hidden_act) @@ -123,15 +139,19 @@ def forward(self, hidden_states): class GPTNeoXLayer(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config) - self.mlp = GPTNeoXMLP(config) + self.attention = GPTNeoXAttention(config, linear_method) + self.mlp = GPTNeoXMLP(config, linear_method) def forward( self, @@ -169,7 +189,11 @@ def forward( class GPTNeoXModel(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config @@ -177,8 +201,10 @@ def __init__(self, config: GPTNeoXConfig): config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList( - [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ + GPTNeoXLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -210,15 +236,22 @@ def forward( class GPTNeoXForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.gpt_neox = GPTNeoXModel(config) - self.embed_out = ColumnParallelLinear( - config.hidden_size, + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.gpt_neox = GPTNeoXModel(config, linear_method) + self.embed_out = ParallelLMHead( config.vocab_size, - bias=False, - gather_output=False, + config.hidden_size, ) self.sampler = 
Sampler(config.vocab_size) @@ -236,50 +269,35 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", - "dense_h_to_4h.bias" - ] - _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue - param = state_dict[name] + param = params_dict[name] + if "query_key_value" in name: - # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of - # [num_heads * 3 * head_size, hidden_size], while the - # required shape is [3 * num_heads * head_size, hidden_size]. + # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). # Thus, we need weight conversion. - shard_size = param.shape[0] - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - + output_dim = getattr(param, "output_dim", None) num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // num_heads - if "query_key_value.weight" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size, - hidden_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif "query_key_value.bias" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected weight name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From 1f6ca338ca970909fd3b81d057e001431439a386 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 02:11:46 +0000 Subject: [PATCH 33/51] fix gptj --- .../layers/vocab_parallel_embedding.py | 26 +++- vllm/model_executor/models/gpt_j.py | 124 +++++++++++------- 2 files changed, 96 insertions(+), 54 deletions(-) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 7dc2d1f824316..42926ed63efdc 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -75,8 +75,8 @@ def __init__(self, }) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - assert loaded_weight.shape[0] == self.num_embeddings - assert loaded_weight.shape[1] == self.embedding_dim + parallel_dim = param.parallel_dim + assert 
loaded_weight.shape[parallel_dim] == self.num_embeddings loaded_weight = loaded_weight[self.vocab_start_index:self. vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) @@ -102,7 +102,25 @@ def forward(self, input_): class ParallelLMHead(VocabParallelEmbedding): - # TODO: Add docstring + + def __init__(self, + num_embeddings: int, + embedding_dim: int, + bias: bool = False, + params_dtype: Optional[torch.dtype] = None): + super().__init__(num_embeddings, embedding_dim, params_dtype) + if bias: + self.bias = Parameter( + torch.empty(self.num_embeddings_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "parallel_dim": 0, + "weight_loader": self.weight_loader + }) + else: + self.register_parameter('bias', None) + def forward(self, input_): del input_ - raise RuntimeError("LMHead's weight should be used in the sampler.") + raise RuntimeError("LMHead's weights should be used in the sampler.") diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 3606fdc76fb15..b23e4c3d4d16d 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -29,14 +29,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -44,23 +48,29 @@ class GPTJAttention(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.total_num_heads - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( config.hidden_size, - 3 * config.hidden_size, + self.head_size, + self.total_num_heads, bias=False, - gather_output=False, + linear_method=linear_method, ) self.out_proj = RowParallelLinear( config.hidden_size, config.hidden_size, bias=False, input_is_parallel=True, + linear_method=linear_method, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -102,18 +112,25 @@ def forward( class GPTJMLP(nn.Module): - def __init__(self, intermediate_size: int, config: GPTJConfig): + def __init__( + self, + intermediate_size: int, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.n_embd self.fc_in = ColumnParallelLinear( hidden_size, intermediate_size, gather_output=False, + 
linear_method=linear_method, ) self.fc_out = RowParallelLinear( intermediate_size, hidden_size, input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) @@ -126,15 +143,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPTJBlock(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() if config.n_inner is None: inner_dim = 4 * config.n_embd else: inner_dim = config.n_inner self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config) - self.mlp = GPTJMLP(inner_dim, config) + self.attn = GPTJAttention(config, linear_method) + self.mlp = GPTJMLP(inner_dim, config, linear_method) def forward( self, @@ -160,7 +181,11 @@ def forward( class GPTJModel(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.embed_dim = config.n_embd @@ -169,7 +194,7 @@ def __init__(self, config: GPTJConfig): self.embed_dim, ) self.h = nn.ModuleList( - [GPTJBlock(config) for _ in range(config.n_layer)]) + [GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -200,15 +225,24 @@ def forward( class GPTJForCausalLM(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None assert not config.tie_word_embeddings - self.transformer = GPTJModel(config) - self.lm_head = ColumnParallelLinear( - config.n_embd, + self.transformer = GPTJModel(config, linear_method) + self.lm_head = ParallelLMHead( config.vocab_size, - gather_output=False, + config.n_embd, + bias=True, ) self.sampler = Sampler(config.vocab_size) @@ -226,43 +260,33 @@ def forward( input_metadata, self.lm_head.bias) return next_tokens - _column_parallel_weights = [ - "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight", - "lm_head.bias" - ] - _row_parallel_weights = ["out_proj.weight", "fc_out.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "attn.bias" in name or "attn.masked_bias" in name: continue - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[1] - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - 
param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_attention_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From b118a2fcceeb9b87cbf4aa76ecea19ea35a46979 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 08:05:56 +0000 Subject: [PATCH 34/51] fix falcon --- vllm/model_executor/models/falcon.py | 245 +++++++++++---------------- 1 file changed, 98 insertions(+), 147 deletions(-) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 6c249f6c98fec..b34fb3ad6de0d 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -30,17 +30,20 @@ from vllm.model_executor.layers.attention import (PagedAttention, PagedAttentionWithALiBi, PagedAttentionWithRoPE) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_tensor_parallel_weights) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -48,19 +51,6 @@ FalconConfig = Union[HF_FalconConfig, RWConfig] -# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during -# training, this means that there's one additional quantization to bfloat16 -# between the operations. In order not to degrade the quality of our HF-port, -# we keep these characteristics in the final model. 
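# --- Illustrative note (not part of the patch) ---
# FalconLinear, removed just below, kept the matmul and the bias add as two
# separate ops. The replacement leans on `skip_bias_add=True` on the parallel
# linear layers: the layer returns (output, bias) and the caller performs the
# add itself, as the new forward() does with `qkv += bias`. A minimal
# stand-in for that contract (an assumed sketch, not vLLM's class):
import torch
import torch.nn.functional as F

def linear_skip_bias_add(x, weight, bias):
    # Matmul only; hand the bias back so the caller can defer or fuse the add.
    return F.linear(x, weight), bias

x = torch.randn(2, 8)
w, b = torch.randn(16, 8), torch.randn(16)
out, bias = linear_skip_bias_add(x, w, b)
if bias is not None:
    out = out + bias
# --- end note ---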
-class FalconLinear(nn.Linear): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - hidden_states = x @ self.weight.T - if self.bias is None: - return hidden_states - return hidden_states + self.bias - - def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), @@ -86,7 +76,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: class FalconAttention(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size @@ -103,41 +97,29 @@ def __init__(self, config: FalconConfig): if self.new_decoder_architecture: self.total_num_kv_heads = config.num_kv_heads - assert self.total_num_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size - self.query_key_value = ColumnParallelLinear( - self.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, - bias=config.bias, - gather_output=False, - skip_bias_add=True, - ) elif self.multi_query: self.total_num_kv_heads = 1 - self.num_kv_heads = 1 - self.query = ColumnParallelLinear( - self.hidden_size, - self.total_num_heads * self.head_dim, - bias=config.bias, - gather_output=False, - skip_bias_add=True, - ) - self.key_value = FalconLinear(self.hidden_size, - 2 * self.head_dim, - bias=config.bias) else: self.total_num_kv_heads = self.total_num_heads - self.num_kv_heads = self.num_heads - self.query_key_value = ColumnParallelLinear( - self.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, - bias=config.bias, - gather_output=False, - skip_bias_add=True, - ) + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.bias, + skip_bias_add=True, + linear_method=linear_method, + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -196,18 +178,10 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - if not self.new_decoder_architecture and self.multi_query: - q, bias = self.query(hidden_states) - if bias is not None: - q += bias - kv = self.key_value(hidden_states) - k, v = kv.split([self.kv_size, self.kv_size], dim=-1) - else: - qkv, bias = self.query_key_value(hidden_states) - if bias is not None: - qkv += bias - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache if self.use_rotary: attn_output = self.attn(positions, q, k, v, k_cache, v_cache, @@ -221,7 +195,11 @@ def forward( class FalconMLP(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size @@ -229,7 +207,8 @@ def __init__(self, config: FalconConfig): 4 * hidden_size, bias=config.bias, gather_output=False, - skip_bias_add=True) + skip_bias_add=True, + linear_method=linear_method) self.act = nn.GELU() self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -239,7 +218,8 @@ def __init__(self, config: FalconConfig): bias=config.bias, input_is_parallel=True, skip_bias_add=True, - reduce_results=self.reduce_row_parallel_results) + reduce_results=self.reduce_row_parallel_results, + linear_method=linear_method) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. 
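# --- Illustrative note (not part of the patch) ---
# With the fused query_key_value layer, the new forward() splits its output by
# (q_size, kv_size, kv_size), so multi-query (a single KV head) and grouped KV
# heads fall out of the same split. A small sketch with assumed shapes:
import torch

num_heads, num_kv_heads, head_dim = 8, 1, 4      # multi-query example
q_size = num_heads * head_dim
kv_size = num_kv_heads * head_dim
qkv = torch.randn(2, q_size + 2 * kv_size)       # (tokens, fused output dim)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
assert q.shape[-1] == q_size and k.shape[-1] == kv_size
# --- end note ---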
@@ -253,12 +233,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FalconDecoderLayer(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config) - self.mlp = FalconMLP(config) + self.self_attention = FalconAttention(config, linear_method) + self.mlp = FalconMLP(config, linear_method) self.config = config if config.new_decoder_architecture: @@ -334,7 +318,11 @@ def forward( class FalconModel(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -349,7 +337,8 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config) for _ in range(config.num_hidden_layers) + FalconDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) ]) # Final Layer Norm @@ -383,15 +372,22 @@ def forward( class FalconForCausalLM(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.transformer = FalconModel(config) - self.lm_head = ColumnParallelLinear( - config.hidden_size, + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.transformer = FalconModel(config, linear_method) + self.lm_head = ParallelLMHead( config.vocab_size, - bias=False, - gather_output=False, + config.hidden_size, ) self.sampler = Sampler(config.vocab_size) @@ -415,89 +411,44 @@ def forward( return next_tokens - _column_parallel_weights = [ - "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight", - "dense_h_to_4h.bias" - ] - _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_size = (get_tensor_model_parallel_world_size()) - tp_rank = get_tensor_model_parallel_rank() - - hidden_size = self.config.hidden_size total_num_heads = self.config.num_attention_heads - num_heads = total_num_heads // tp_size - head_size = hidden_size // total_num_heads - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads - num_kv_heads = total_num_kv_heads // tp_size - separated_q_kv = False - kv_head_start = tp_rank * num_kv_heads - kv_head_end = (tp_rank + 1) * num_kv_heads elif self.config.multi_query: total_num_kv_heads = 1 - num_kv_heads = 1 - separated_q_kv = True - kv_head_start = 0 - kv_head_end = 1 else: total_num_kv_heads = total_num_heads - num_kv_heads = total_num_kv_heads // tp_size - separated_q_kv = False - kv_head_start = tp_rank * num_kv_heads - kv_head_end = (tp_rank + 1) * num_kv_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + param = params_dict[name] if "query_key_value" in name: - loaded_weight = 
convert_pyslice_to_tensor(loaded_weight) - loaded_weight_size = loaded_weight.size() + output_dim = getattr(param, "output_dim", None) + loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - total_num_kv_heads, num_query_heads_per_kv_head + 2, - head_size, *loaded_weight_size[1:]) - - wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) - wk = loaded_weight[:, [-2]].reshape(-1, - *loaded_weight_size[1:]) - wv = loaded_weight[:, [-1]].reshape(-1, - *loaded_weight_size[1:]) - - wq = wq[head_size * head_start:head_size * head_end] - wk = wk[head_size * kv_head_start:head_size * kv_head_end] - wv = wv[head_size * kv_head_start:head_size * kv_head_end] - - if separated_q_kv: - loaded_weight_q = wq - loaded_weight_kv = torch.cat([wk, wv], dim=0) - q_weight_name = name.replace("query_key_value", "query") - kv_weight_name = name.replace("query_key_value", - "key_value") - load_tensor_parallel_weights(state_dict[q_weight_name], - loaded_weight_q, - q_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank) - load_tensor_parallel_weights(state_dict[kv_weight_name], - loaded_weight_kv, - kv_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank) - continue - else: - loaded_weight = torch.cat([wq, wk, wv], dim=0) - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) + + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From fb595c730b0a9ab7e5b0e46a410a62bd6da0ffd7 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Fri, 10 Nov 2023 08:10:33 +0000 Subject: [PATCH 35/51] clean up --- test_class.py | 20 -------------------- 1 file changed, 20 deletions(-) delete mode 100644 test_class.py diff --git a/test_class.py b/test_class.py deleted file mode 100644 index 5a62560a7d45b..0000000000000 --- a/test_class.py +++ /dev/null @@ -1,20 +0,0 @@ -class A: - - def __init__(self) -> None: - self.hello() - - def hello(self): - print("Hello, I'm A") - - -class B(A): - - def __init__(self) -> None: - super().__init__() - self.hello() - - def hello(self): - print("Hello, I'm B") - - -b = B() From d5ffe8865b90125f45f60306dc3af7767570cd2b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 02:39:18 +0000 Subject: [PATCH 36/51] Fix GPT Bigcode --- vllm/model_executor/models/gpt_bigcode.py | 189 ++++-------- vllm/model_executor/parallel_utils/layers.py | 303 ------------------- 2 files changed, 67 insertions(+), 425 deletions(-) delete mode 100644 vllm/model_executor/parallel_utils/layers.py diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 41f72c8cb7086..f3e875de08a39 100644 --- 
a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -31,15 +31,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -47,7 +50,11 @@ class GPTBigCodeAttention(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads @@ -61,32 +68,27 @@ def __init__(self, config: GPTBigCodeConfig): self.multi_query = config.multi_query if self.multi_query: + total_num_kv_heads = 1 self.num_kv_heads = 1 - self.kv_dim = self.head_dim - self.c_attn_q = ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - bias=True, - gather_output=False, - ) - self.c_attn_kv = nn.Linear(self.hidden_size, - 2 * self.kv_dim, - bias=True) else: + total_num_kv_heads = total_num_heads self.num_kv_heads = self.num_heads - self.kv_dim = self.num_kv_heads * self.head_dim - self.c_attn = ColumnParallelLinear( - self.hidden_size, - self.hidden_size + 2 * self.kv_dim, - bias=True, - gather_output=False, - ) + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + linear_method=linear_method, + ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -100,17 +102,14 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - if self.multi_query: - q, _ = self.c_attn_q(hidden_states) - kv = self.c_attn_kv(hidden_states) - k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1) - else: - qkv, _ = self.c_attn(hidden_states) - q, k, v = qkv.split([ + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ self.hidden_size // self.tensor_model_parallel_world_size, self.kv_dim, self.kv_dim ], - dim=-1) + dim=-1, + ) key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata, cache_event) @@ -124,6 +123,7 @@ def __init__( self, intermediate_size: int, config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() hidden_size = config.hidden_size @@ -132,12 +132,14 @@ def __init__( intermediate_size, 
bias=True, gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) @@ -150,16 +152,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPTBigCodeBlock(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 * hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config) + self.attn = GPTBigCodeAttention(config, linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config) + self.mlp = GPTBigMLP(inner_dim, config, linear_method) def forward( self, @@ -189,23 +195,23 @@ def forward( class GPTBigCodeModel(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - # Optimization: While the vocab size of GPT-2 is 50257, we extend it - # to 50304 in order to make it divisible by 64. - # This improves performance since GPUs are faster if the dimension - # is divisible by 64. In addition, it allows us to shard the embedding - # layer across 2, 4, 8, or more GPUs. - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList( - [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + GPTBigCodeBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -235,12 +241,19 @@ def forward( class GPTBigCodeForCausalLM(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config = config - self.transformer = GPTBigCodeModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.transformer = GPTBigCodeModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) @@ -258,89 +271,21 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = ["c_fc.weight", "c_fc.bias"] - _row_parallel_weights = ["c_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, 
cache_dir, load_format, revision): if "lm_head.weight" in name: - # GPT-2 ties the weights of the embedding layer and the final - # linear layer. continue if ".attn.bias" in name: # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue - - if not name.startswith("transformer."): - name = "transformer." + name - - # For the fused QKV linear layer, manually shard the weights. - if "c_attn" in name: - # GPT-2's fused QKV has the shape of - # [3 * num_heads * head_size, hidden_size]. - # When tensor parallelism is used, we shard the weights along - # the head dimension. - total_num_heads = self.config.num_attention_heads - total_num_kv_heads = (1 if self.config.multi_query else - total_num_heads) - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - total_kv_size = head_size * total_num_kv_heads - num_heads = total_num_heads // tensor_model_parallel_world_size - head_start = tensor_model_parallel_rank * num_heads - head_end = (tensor_model_parallel_rank + 1) * num_heads - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - wq, wk, wv = torch.split( - loaded_weight, [hidden_size, total_kv_size, total_kv_size], - dim=0) - - wq = wq[head_size * head_start:head_size * head_end] - if not self.config.multi_query: - # Split the heads when using normal multi-head attention - wk = wk[head_size * head_start:head_size * head_end] - wv = wv[head_size * head_start:head_size * head_end] - loaded_weight = torch.cat([wq, wk, wv], dim=0) - else: - # For multi-query attention, we split the query - # but replicate the key and value. - loaded_weight_q = wq - loaded_weight_kv = torch.cat([wk, wv], dim=0) - q_weight_name = name.replace("c_attn", "c_attn_q") - kv_weight_name = name.replace("c_attn", "c_attn_kv") - load_tensor_parallel_weights(state_dict[q_weight_name], - loaded_weight_q, - q_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) - load_tensor_parallel_weights(state_dict[kv_weight_name], - loaded_weight_kv, - kv_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) - continue - - param = state_dict[name] - - if name == "transformer.wte.weight": - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/parallel_utils/layers.py b/vllm/model_executor/parallel_utils/layers.py deleted file mode 100644 index c1aea2c1d5543..0000000000000 --- a/vllm/model_executor/parallel_utils/layers.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2023 The vLLM team. -# Adapted from -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. 
- -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -from typing import Optional - -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) - -from vllm.model_executor.parallel_utils.utils import ( - divide, - VocabUtility, - split_tensor_along_last_dim, -) - - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - params_dtype: type of the parameters. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): - super().__init__() - - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - self.tp_size = get_tensor_model_parallel_world_size() - # TODO: Handle vocab padding here. - # Divide the weight matrix along the vocaburaly dimension. - self.vocab_start_index, self.vocab_end_index = ( - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), - self.tp_size)) - self.num_embeddings_per_partition = (self.vocab_end_index - - self.vocab_start_index) - - self.weight = Parameter( - torch.empty(self.num_embeddings_per_partition, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=params_dtype)) - - def forward(self, input_): - if self.tp_size > 1: - # Build the mask. - input_mask = ((input_ < self.vocab_start_index) | - (input_ >= self.vocab_end_index)) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - else: - masked_input = input_ - # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight) - # Mask the output embedding. - if self.tp_size > 1: - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) - return output - - -class ColumnParallelLinear(torch.nn.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments - bias: If true, add bias - gather_output: If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configuration. 
- """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - self.tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, self.tp_size) - self.skip_bias_add = skip_bias_add - self.quant_config = quant_config - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Parameters. - # NOTE: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - self.create_weights(params_dtype) - - if bias: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype)) - else: - self.register_parameter('bias', None) - - def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter( - torch.empty(self.output_size_per_partition, - self.input_size, - device=torch.cuda.current_device(), - dtype=dtype)) - - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - return F.linear(x, self.weight, bias) - - def forward(self, input_): - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = self.bias if not self.skip_bias_add else None - - input_parallel = input_ - # Matrix multiply. - output_parallel = self.apply_weights(input_parallel, bias) - if self.gather_output: - # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments: - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - skip_bias_add: This was added to enable performance optimization where - bias can be fused with other element-wise operations. - We skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configuration. - """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Divide the weight matrix along the last dimension. 
- self.tp_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, self.tp_size) - self.skip_bias_add = skip_bias_add - self.quant_config = quant_config - - self.create_weights(params_dtype) - - if not reduce_results and (bias and not skip_bias_add): - raise ValueError('When not reduce the results, adding bias to the ' - 'results can lead to incorrect results') - - if bias: - self.bias = Parameter( - torch.empty(self.output_size, - device=torch.cuda.current_device(), - dtype=params_dtype)) - - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter( - torch.empty(self.output_size, - self.input_size_per_partition, - device=torch.cuda.current_device(), - dtype=dtype)) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - return F.linear(x, self.weight) - - def forward(self, input_): - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - # TODO: simplify code below - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. - output_parallel = self.apply_weights(input_parallel) - if self.reduce_results and self.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.skip_bias_add: - output = output_ + self.bias if self.bias is not None else output_ - output_bias = None - else: - output = output_ - output_bias = self.bias - return output, output_bias From 7acf443ba394234155841be700eff42a4b254577 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 05:36:46 +0000 Subject: [PATCH 37/51] Fix chatglm and yi models --- vllm/model_executor/layers/linear.py | 33 +++- vllm/model_executor/models/chatglm.py | 192 +++++++++++------------ vllm/model_executor/models/llama.py | 1 - vllm/model_executor/models/yi.py | 209 +++++++------------------- 4 files changed, 170 insertions(+), 265 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 26efa0e20cfb4..c7190e0002780 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -222,13 +222,38 @@ def __init__( super().__init__(input_size, sum(output_sizes), bias, gather_output, skip_bias_add, params_dtype, linear_method) - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor, - loaded_shard_id: int): + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + if loaded_shard_id is None: + # Loaded weight is already packed. 
+ if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + assert loaded_shard_id < len(self.output_sizes) tp_rank = get_tensor_model_parallel_rank() tp_size = get_tensor_model_parallel_world_size() - param_data = param.data - output_dim = getattr(param, "output_dim", None) if output_dim is not None: shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size shard_size = self.output_sizes[loaded_shard_id] // tp_size diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 8acc8e468b652..edf0d3579f679 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -6,32 +6,29 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. """ -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn from torch.nn import LayerNorm from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - hf_model_weights_iterator, - load_tensor_parallel_weights, -) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding -from vllm.model_executor.parallel_utils.layers import ( - ColumnParallelLinear, - RowParallelLinear, -) -from vllm.sequence import SequenceOutputs - + get_tensor_model_parallel_world_size) +from vllm.model_executor.quantization_utils import QuantizationConfig +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -39,7 +36,11 @@ class GLMAttention(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -50,25 +51,34 @@ def __init__(self, config): self.total_num_kv_heads = (config.multi_query_group_num if 
config.multi_query_attention else config.num_attention_heads) - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.query_key_value = ColumnParallelLinear( - config.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.query_key_value = QKVParallelLinear( + self.hidden_size, self.head_dim, - bias=config.add_qkv_bias, - gather_output=False, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.add_bias_linear or config.add_qkv_bias, + linear_method=linear_method, ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( @@ -78,7 +88,6 @@ def __init__(self, config): rotary_dim=self.head_dim // 2, num_kv_heads=self.num_kv_heads, is_neox_style=False, - # is_glm_style=True ) def forward( @@ -117,17 +126,22 @@ class GLMMLP(nn.Module): state back into h hidden dimension. """ - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.add_bias = config.add_bias_linear # Project to 4h. - self.dense_h_to_4h = ColumnParallelLinear( + self.dense_h_to_4h = PackedColumnParallelLinear( config.hidden_size, - config.ffn_hidden_size * 2, + [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, gather_output=False, + linear_method=linear_method, ) self.activation_func = SiluAndMul() @@ -138,6 +152,7 @@ def __init__(self, config): config.hidden_size, bias=config.add_bias_linear, input_is_parallel=True, + linear_method=linear_method, ) def forward(self, hidden_states): @@ -159,6 +174,7 @@ class GLMBlock(nn.Module): def __init__( self, config, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -172,7 +188,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config) + self.self_attention = GLMAttention(config, linear_method) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -180,7 +196,7 @@ def __init__( config.hidden_size, eps=config.layernorm_epsilon) # MLP - self.mlp = GLMMLP(config) + self.mlp = GLMMLP(config, linear_method) def forward( self, @@ -227,7 +243,11 @@ def forward( class GLMTransformer(nn.Module): """Transformer class.""" - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -236,7 +256,7 @@ def __init__(self, config): # Transformer layers. 
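# --- Illustrative aside (not part of the patch): per-rank head bookkeeping ---
# A minimal sketch of the arithmetic GLMAttention above relies on when the
# number of KV heads is smaller than the tensor-parallel size. The numbers
# below are example values only (roughly ChatGLM2-6B-like); they are
# assumptions, not read from any config in this series.
hidden_size, total_num_heads, total_num_kv_heads, tp_size = 4096, 32, 2, 4

head_dim = hidden_size // total_num_heads              # 128
num_heads = total_num_heads // tp_size                 # 8 query heads per rank
# Fewer KV heads than ranks: each KV head is replicated across
# tp_size // total_num_kv_heads = 2 ranks.
num_kv_heads = max(1, total_num_kv_heads // tp_size)   # 1 KV head per rank
q_size = num_heads * head_dim                          # 1024
kv_size = num_kv_heads * head_dim                      # 128
# Per-rank output width of the fused QKV projection, i.e. what qkv.split(
# [q_size, kv_size, kv_size], dim=-1) operates on in the forward pass.
assert q_size + 2 * kv_size == 1280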
self.layers = nn.ModuleList( - [GLMBlock(config) for i in range(self.num_layers)]) + [GLMBlock(config, linear_method) for i in range(self.num_layers)]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -274,7 +294,11 @@ def forward( class ChatGLMModel(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.embedding = VocabParallelEmbedding(config.padded_vocab_size, @@ -283,15 +307,10 @@ def __init__(self, config): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config) + self.encoder = GLMTransformer(config, linear_method) - self.output_layer = ColumnParallelLinear( - config.hidden_size, - config.padded_vocab_size, - bias=False, - gather_output=False, - params_dtype=config.torch_dtype, - ) + self.output_layer = ParallelLMHead(config.padded_vocab_size, + config.hidden_size) def forward( self, @@ -317,10 +336,19 @@ def forward( class ChatGLMForCausalLM(nn.Module): - def __init__(self, config: ChatGLMConfig): + def __init__( + self, + config: ChatGLMConfig, + quant_config: Optional[QuantizationConfig] = None, + ): super().__init__() self.config: ChatGLMConfig = config - self.transformer = ChatGLMModel(config) + self.quant_config = quant_config + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.transformer = ChatGLMModel(config, linear_method) self.lm_head_weight = self.transformer.output_layer.weight self.sampler = Sampler(config.padded_vocab_size) @@ -331,78 +359,26 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, input_metadata) return next_tokens - _column_parallel_weights = [ - "output_layer.weight", - "embedding.weight", - ] - _row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"] - - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - q_proj_shard_size = self.config.hidden_size // tp_size - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.multi_query_group_num // tp_size) - - mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size - - state_dict = self.state_dict() + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + if "rotary_pos_emb.inv_freq" in name: + continue if "word_embeddings" in name: name = name.replace(".word_embeddings", "") - - if name in state_dict: - param = state_dict[name] - if "query_key_value" in name: - q_offset = q_proj_shard_size * tp_rank - k_offset = (q_proj_shard_size * tp_size + - kv_proj_shard_size * tp_rank) - v_offset = (q_proj_shard_size * tp_size + - kv_proj_shard_size * (tp_size + tp_rank)) - wq = 
loaded_weight[q_offset:q_offset + q_proj_shard_size] - wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size] - wv = loaded_weight[v_offset:v_offset + kv_proj_shard_size] - loaded_weight = torch.cat([wq, wk, wv], dim=0) - param.data.copy_(loaded_weight) - continue - - if "dense_h_to_4h" in name: - w_gate = loaded_weight[mlp_hidden_shard_size * - tp_rank:mlp_hidden_shard_size * - (tp_rank + 1)] - w_proj = loaded_weight[mlp_hidden_shard_size * - (tp_size + - tp_rank):mlp_hidden_shard_size * - (tp_size + tp_rank + 1)] - loaded_weight = torch.cat([w_gate, w_proj], dim=0) - param.data.copy_(loaded_weight) - continue - - load_tensor_parallel_weights( - param, - loaded_weight, - name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank, - ) - elif name == "transformer.rotary_pos_emb.inv_freq": - continue - else: - print("Warning never found tensor's name:", name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index b72379d9237ff..8c885153c4d0d 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -170,7 +170,6 @@ def __init__( ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index a0958f6164e49..1799df3c40687 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -33,17 +33,20 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + PackedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding + get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -56,19 +59,19 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - quant_config=quant_config) - self.down_proj = 
ParallelLinear.row(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config) + self.gate_up_proj = PackedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + gather_output=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + input_is_parallel=True, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -91,7 +94,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -109,7 +112,6 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -117,21 +119,20 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + - 2 * self.total_num_kv_heads * num_kv_heads_replicas) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, - quant_config=quant_config, + linear_method=linear_method, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, input_is_parallel=True, - quant_config=quant_config, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -165,11 +166,10 @@ class YiDecoderLayer(nn.Module): def __init__( self, config: YiConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -181,13 +181,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - quant_config=quant_config, + linear_method=linear_method, ) self.mlp = YiMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + linear_method=linear_method, ) self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -225,20 +225,18 @@ class YiModel(nn.Module): def __init__( self, config: YiConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - YiDecoderLayer(config, quant_config) + YiDecoderLayer(config, linear_method) 
for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -279,14 +277,12 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config - self.model = YiModel(config, quant_config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - # NOTE: The LM head is not quantized. - self.lm_head = ParallelLinear.column(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - quant_config=None) + if quant_config is not None: + linear_method = quant_config.get_linear_method() + else: + linear_method = None + self.model = YiModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -303,124 +299,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - col_weight_suffixes = ["weight"] - row_weight_suffixes = ["weight"] - else: - col_weight_suffixes = ( - self.quant_config.get_col_parallel_tensor_names()) - row_weight_suffixes = ( - self.quant_config.get_row_parallel_tensor_names()) - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in col_weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in row_weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - num_kv_heads_replicas = max(1, - tp_size // self.config.num_key_value_heads) - num_kv_heads_per_gpu = max(1, - self.config.num_key_value_heads // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - num_kv_heads_per_gpu) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - packed_dim = None - is_transposed = False - if self.quant_config is not None: - packed_dim = self.quant_config.get_packed_dim(name) - is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T - - if packed_dim is not None: - shard_dim = 0 if not is_transposed else 1 - if packed_dim == shard_dim: - shard_size //= 
self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - if weight_name in ["k_proj", "v_proj"]: - shard_id = tp_rank // num_kv_heads_replicas - else: - shard_id = tp_rank - loaded_weight = loaded_weight[shard_size * - shard_id:shard_size * - (shard_id + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - - param = state_dict[name] - if is_transposed: - param = param.T - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, tp_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) From c33e0f081ef5557c028dda64ac289a6349a2e80c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 05:41:54 +0000 Subject: [PATCH 38/51] format --- vllm/model_executor/layers/linear.py | 13 ++++++------- .../layers/quantized_linear/__init__.py | 2 -- .../layers/vocab_parallel_embedding.py | 2 +- vllm/model_executor/models/gpt_bigcode.py | 2 +- vllm/model_executor/parallel_utils/utils.py | 8 ++------ vllm/model_executor/quantization_utils/base.py | 2 +- vllm/model_executor/weight_utils.py | 16 ---------------- 7 files changed, 11 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index c7190e0002780..341103fad7350 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass -from typing import Any, Dict, List, Optional +from typing import Dict, List, Optional import torch import torch.nn.functional as F @@ -94,7 +93,7 @@ def __init__( dtype=self.params_dtype)) set_weight_attrs(self.bias, {"output_dim": 0}) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None @@ -166,7 +165,7 @@ def __init__( "weight_loader": self.weight_loader, }) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() @@ -441,8 +440,8 @@ def __init__( set_weight_attrs(weight, {"weight_loader": self.weight_loader}) if not reduce_results and (bias and not skip_bias_add): - raise ValueError('When not reduce the results, adding bias to the ' - 'results can lead to incorrect 
results') + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") if bias: self.bias = Parameter( @@ -454,7 +453,7 @@ def __init__( "weight_loader": self.weight_loader, }) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): tp_rank = get_tensor_model_parallel_rank() diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index 21b7a4bdfcbc6..e69de29bb2d1d 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -1,2 +0,0 @@ -class ParallelLinear: - pass diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 42926ed63efdc..85d0370fe8fc7 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -119,7 +119,7 @@ def __init__(self, "weight_loader": self.weight_loader }) else: - self.register_parameter('bias', None) + self.register_parameter("bias", None) def forward(self, input_): del input_ diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index f3e875de08a39..216b29e9a196c 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -39,7 +39,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py index 8244bcb02535f..0cd420c8e11b5 100644 --- a/vllm/model_executor/parallel_utils/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import List, Sequence +from typing import Sequence import torch @@ -24,7 +24,7 @@ def split_tensor_along_last_dim( tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: +) -> Sequence[torch.Tensor]: """ Split a tensor along its last dimension. 
Arguments: @@ -46,7 +46,3 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list - - -class VocabUtility: - pass diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py index 65e36d24264b6..8af16b7e145f3 100644 --- a/vllm/model_executor/quantization_utils/base.py +++ b/vllm/model_executor/quantization_utils/base.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List import torch diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index cc75cf58ebc1e..daffbf4082568 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -13,8 +13,6 @@ from tqdm.auto import tqdm from vllm.logger import init_logger -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.quantization_utils import get_quant_class from vllm.model_executor.quantization_utils.base import QuantizationConfig @@ -271,20 +269,6 @@ def default_weight_loader(param: torch.Tensor, param.data.copy_(loaded_weight) -def load_padded_tensor_parallel_vocab( - param: torch.Tensor, - loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` -) -> None: - raise NotImplementedError() - - -def load_tensor_parallel_weights( - param: torch.Tensor, - loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` -) -> None: - raise NotImplementedError() - - def initialize_dummy_weights( model: torch.nn.Module, low: float = -1e-3, From 63af93cecf2654afbca4a90cf337ebfd7a00da8c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 06:17:50 +0000 Subject: [PATCH 39/51] Simplify code logic --- vllm/model_executor/model_loader.py | 7 ++++--- vllm/model_executor/models/aquila.py | 7 ++----- vllm/model_executor/models/baichuan.py | 16 ++++++---------- vllm/model_executor/models/bloom.py | 8 ++------ vllm/model_executor/models/chatglm.py | 8 ++------ vllm/model_executor/models/falcon.py | 8 ++------ vllm/model_executor/models/gpt2.py | 8 ++------ vllm/model_executor/models/gpt_bigcode.py | 8 ++------ vllm/model_executor/models/gpt_j.py | 8 ++------ vllm/model_executor/models/gpt_neox.py | 8 ++------ vllm/model_executor/models/internlm.py | 8 ++------ vllm/model_executor/models/llama.py | 8 ++------ vllm/model_executor/models/mistral.py | 8 ++------ vllm/model_executor/models/mpt.py | 8 ++------ vllm/model_executor/models/opt.py | 8 ++------ vllm/model_executor/models/qwen.py | 8 ++------ vllm/model_executor/models/yi.py | 8 ++------ 17 files changed, 40 insertions(+), 102 deletions(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index b18f99223f10a..ea55661dc3129 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -67,8 +67,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) - # Get the quantization config. - quant_config = None + # Get the (maybe quantized) linear method. + linear_method = None if model_config.quantization is not None: if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION: raise ValueError( @@ -90,12 +90,13 @@ def get_model(model_config: ModelConfig) -> nn.Module: f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. 
Supported dtypes: " f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - model = model_class(model_config.hf_config, quant_config) + model = model_class(model_config.hf_config, linear_method) else: model = model_class(model_config.hf_config) if model_config.load_format == "dummy": diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 7c4ed7f6da513..46d616fb54575 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -287,14 +287,11 @@ class AquilaForCausalLM(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = AquilaModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 4571ab30431aa..6399b480d9d69 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -302,14 +302,10 @@ class BaiChuanBaseForCausalLM(nn.Module): def __init__(self, config, position_embedding: str, - quant_config: Optional[QuantizationConfig] = None): + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = BaiChuanModel(config, position_embedding, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) @@ -361,13 +357,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b def __init__(self, config, - quant_config: Optional[QuantizationConfig] = None): - super().__init__(config, "ALIBI", quant_config) + linear_method: Optional[LinearMethodBase] = None): + super().__init__(config, "ALIBI", linear_method) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b def __init__(self, config, - quant_config: Optional[QuantizationConfig] = None): - super().__init__(config, "ROPE", quant_config) + linear_method: Optional[LinearMethodBase] = None): + super().__init__(config, "ROPE", linear_method) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index e993871082809..b522af7a83bcd 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -280,15 +280,11 @@ class BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = BloomModel(config, linear_method) self.lm_head_weight = self.transformer.word_embeddings.weight self.sampler = Sampler(config.vocab_size) diff --git 
a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index edf0d3579f679..e76f34aad1664 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -339,15 +339,11 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config: ChatGLMConfig = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = ChatGLMModel(config, linear_method) self.lm_head_weight = self.transformer.output_layer.weight self.sampler = Sampler(config.padded_vocab_size) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b34fb3ad6de0d..70b18cf9462f1 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -375,15 +375,11 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = FalconModel(config, linear_method) self.lm_head = ParallelLMHead( config.vocab_size, diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 5af4b9f8f628d..8f900109c9f25 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -225,15 +225,11 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = GPT2Model(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 216b29e9a196c..30905e858ead3 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -244,15 +244,11 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = GPTBigCodeModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index b23e4c3d4d16d..dc05a1a3a7f82 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -228,15 +228,11 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, - quant_config: Optional[QuantizationConfig] = None, + 
linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method assert not config.tie_word_embeddings self.transformer = GPTJModel(config, linear_method) self.lm_head = ParallelLMHead( diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 6e3996dee665b..1dcd65b983df9 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -239,15 +239,11 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.gpt_neox = GPTNeoXModel(config, linear_method) self.embed_out = ParallelLMHead( config.vocab_size, diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 046f2ff6faeef..6f546f8b487c4 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -234,15 +234,11 @@ class InternLMForCausalLM(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = InternLMModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8c885153c4d0d..f01878becafae 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -274,15 +274,11 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = LlamaModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 4587432c3ac3d..e9bed98fbecf6 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -269,15 +269,11 @@ class MistralForCausalLM(nn.Module): def __init__( self, config: MistralConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = MistralModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) 
self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8c6237714ef30..fd6e6aac6ac06 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -242,16 +242,12 @@ class MptForCausalLM(nn.Module): def __init__( self, config: MptConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config assert config.tie_word_embeddings - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = MptModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 482310a3b4c0a..a1102df039ebd 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -301,15 +301,11 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = OPTModel(config, linear_method) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 014e954969c07..0cb14e31c60ea 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -236,15 +236,11 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: QWenConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.transformer = QWenModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 1799df3c40687..f9e8f5e101a78 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -272,15 +272,11 @@ class YiForCausalLM(nn.Module): def __init__( self, config: YiConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config - self.quant_config = quant_config - if quant_config is not None: - linear_method = quant_config.get_linear_method() - else: - linear_method = None + self.linear_method = linear_method self.model = YiModel(config, linear_method) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) From 82c76b10112aece6fadff65eb351adf0582945cc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 06:23:22 +0000 Subject: [PATCH 40/51] Simplify code --- vllm/model_executor/models/aquila.py | 4 ---- vllm/model_executor/models/baichuan.py | 4 ---- vllm/model_executor/models/bloom.py | 4 ---- vllm/model_executor/models/chatglm.py | 4 ---- 
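# Sketch of where the removed quant_config -> linear_method branch is expected
# to live after this change: the caller (the model loader later in this
# series) builds the linear method once and hands it to whichever model class
# it constructs. `build_model` and its arguments are illustrative only.

from vllm.model_executor.models.llama import LlamaForCausalLM

def build_model(hf_config, quant_config=None):
    linear_method = None
    if quant_config is not None:
        linear_method = quant_config.get_linear_method()
    return LlamaForCausalLM(hf_config, linear_method)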
vllm/model_executor/models/falcon.py | 4 ---- vllm/model_executor/models/gpt2.py | 4 ---- vllm/model_executor/models/gpt_bigcode.py | 4 ---- vllm/model_executor/models/gpt_j.py | 4 ---- vllm/model_executor/models/gpt_neox.py | 4 ---- vllm/model_executor/models/internlm.py | 4 ---- vllm/model_executor/models/llama.py | 4 ---- vllm/model_executor/models/mistral.py | 4 ---- vllm/model_executor/models/mpt.py | 4 ---- vllm/model_executor/models/opt.py | 4 ---- vllm/model_executor/models/qwen.py | 4 ---- vllm/model_executor/models/yi.py | 4 ---- 16 files changed, 64 deletions(-) diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 46d616fb54575..0fc004dc9caee 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -42,7 +42,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -64,12 +63,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -143,7 +140,6 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 6399b480d9d69..712fe61ae5910 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -42,7 +42,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -89,12 +88,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -146,7 +143,6 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) # Create the alibi slopes and slice them. 
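# The call-site removals in this commit assume the parallel linear layers now
# default to the values every model was passing explicitly, i.e.
# gather_output=False for ColumnParallelLinear and input_is_parallel=True for
# RowParallelLinear. Under that assumption the two constructions below build
# the same layer (sizes illustrative; creating the layer still requires an
# initialized tensor-parallel runtime).

from vllm.model_executor.layers.linear import RowParallelLinear

explicit = RowParallelLinear(11008, 4096, bias=False,
                             input_is_parallel=True, linear_method=None)
implicit = RowParallelLinear(11008, 4096, bias=False, linear_method=None)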
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index b522af7a83bcd..1d379a623c76d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -39,7 +39,6 @@ VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -100,7 +99,6 @@ def __init__( self.hidden_size, self.hidden_size, bias=True, - input_is_parallel=True, linear_method=linear_method, ) @@ -145,14 +143,12 @@ def __init__( self.dense_h_to_4h = ColumnParallelLinear( hidden_size, 4 * hidden_size, - gather_output=False, linear_method=linear_method, ) self.act = get_act_fn("gelu") self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, - input_is_parallel=True, linear_method=linear_method, ) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index e76f34aad1664..1dbc9a0847baa 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -25,7 +25,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -77,7 +76,6 @@ def __init__( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, - input_is_parallel=True, linear_method=linear_method, ) @@ -140,7 +138,6 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, - gather_output=False, linear_method=linear_method, ) @@ -151,7 +148,6 @@ def __init__( config.ffn_hidden_size, config.hidden_size, bias=config.add_bias_linear, - input_is_parallel=True, linear_method=linear_method, ) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 70b18cf9462f1..3307d05494429 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -41,7 +41,6 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -131,7 +130,6 @@ def __init__( self.hidden_size, self.hidden_size, bias=config.bias, - input_is_parallel=True, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results) @@ -206,7 +204,6 @@ def __init__( self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size, bias=config.bias, - gather_output=False, skip_bias_add=True, linear_method=linear_method) self.act = nn.GELU() @@ -216,7 +213,6 @@ def __init__( 4 * hidden_size, hidden_size, bias=config.bias, - input_is_parallel=True, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, linear_method=linear_method) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 8f900109c9f25..d540f74724202 100644 --- a/vllm/model_executor/models/gpt2.py +++ 
b/vllm/model_executor/models/gpt2.py @@ -39,7 +39,6 @@ VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -75,7 +74,6 @@ def __init__( self.hidden_size, self.hidden_size, bias=True, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, @@ -112,14 +110,12 @@ def __init__( hidden_size, intermediate_size, bias=True, - gather_output=False, linear_method=linear_method, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - input_is_parallel=True, linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 30905e858ead3..1e489e97052a7 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -40,7 +40,6 @@ VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -87,7 +86,6 @@ def __init__( self.hidden_size, self.hidden_size, bias=True, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, @@ -131,14 +129,12 @@ def __init__( hidden_size, intermediate_size, bias=True, - gather_output=False, linear_method=linear_method, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - input_is_parallel=True, linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index dc05a1a3a7f82..a5b77138bd17f 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -38,7 +38,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -69,7 +68,6 @@ def __init__( config.hidden_size, config.hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) @@ -123,13 +121,11 @@ def __init__( self.fc_in = ColumnParallelLinear( hidden_size, intermediate_size, - gather_output=False, linear_method=linear_method, ) self.fc_out = RowParallelLinear( intermediate_size, hidden_size, - input_is_parallel=True, linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 1dcd65b983df9..5c40783262ce7 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -38,7 +38,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, 
hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -73,7 +72,6 @@ def __init__( self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, - input_is_parallel=True, linear_method=linear_method, ) @@ -119,13 +117,11 @@ def __init__( self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.intermediate_size, - gather_output=False, linear_method=linear_method, ) self.dense_4h_to_h = RowParallelLinear( config.intermediate_size, config.hidden_size, - input_is_parallel=True, linear_method=linear_method, ) self.act = get_act_fn(config.hidden_act) diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 6f546f8b487c4..3a030447eddb4 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -18,7 +18,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -39,12 +38,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -93,7 +90,6 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index f01878becafae..d786e1cda30b6 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -44,7 +44,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -65,12 +64,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" @@ -131,7 +128,6 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index e9bed98fbecf6..3032a86ad7505 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -44,7 +44,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -65,12 +64,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -129,7 +126,6 @@ def __init__(self, self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE(self.num_heads, diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index fd6e6aac6ac06..30ccb9a4295c9 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -19,7 +19,6 @@ VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -71,7 +70,6 @@ def __init__( self.d_model, self.d_model, bias=not config.no_bias, - input_is_parallel=True, linear_method=linear_method, ) @@ -130,7 +128,6 @@ def __init__( hidden_size, intermediate_size, bias=not config.no_bias, - gather_output=False, linear_method=linear_method, ) self.act = get_act_fn("gelu") @@ -138,7 +135,6 @@ def __init__( intermediate_size, hidden_size, bias=not config.no_bias, - input_is_parallel=True, linear_method=linear_method, ) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index a1102df039ebd..2dde92577bff6 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -40,7 +40,6 @@ VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -91,7 +90,6 @@ def __init__( embed_dim, embed_dim, bias=bias, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, @@ -140,14 +138,12 @@ def __init__( self.embed_dim, config.ffn_dim, bias=config.enable_bias, - gather_output=False, linear_method=linear_method, ) self.fc2 = RowParallelLinear( config.ffn_dim, self.embed_dim, bias=config.enable_bias, - input_is_parallel=True, linear_method=linear_method, ) self.final_layer_norm = nn.LayerNorm( diff --git a/vllm/model_executor/models/qwen.py 
b/vllm/model_executor/models/qwen.py index 0cb14e31c60ea..9953291ad986f 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -26,7 +26,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -48,12 +47,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.c_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -99,7 +96,6 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) self.scaling = self.head_dim**-0.5 diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index f9e8f5e101a78..71ec672ba1a3f 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -44,7 +44,6 @@ VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) -from vllm.model_executor.quantization_utils import QuantizationConfig from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput @@ -65,12 +64,10 @@ def __init__( self.gate_up_proj = PackedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, linear_method=linear_method) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" @@ -131,7 +128,6 @@ def __init__( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( From f53469bbf5a92482a8bf5d19d9bbef10bf138c1c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 06:26:21 +0000 Subject: [PATCH 41/51] fix --- vllm/model_executor/layers/linear.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 341103fad7350..6baa13a41d0ab 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -300,8 +300,8 @@ def __init__( self.num_heads = divide(self.total_num_heads, tp_size) if tp_size >= self.total_num_kv_heads: self.num_kv_heads = 1 - self.num_kv_head_replicas = divide(self.total_num_kv_heads, - tp_size) + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 From d4c07985c76552576f19e11a0718d336d395bd7b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 08:19:12 +0000 Subject: [PATCH 42/51] Add comment for linear.py --- vllm/model_executor/layers/linear.py | 93 ++++++++++++++++++---------- 1 file changed, 62 insertions(+), 31 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 6baa13a41d0ab..fa0904ffad5b6 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -18,10 +18,11 @@ class LinearMethodBase(ABC): - + """Base class for different quantized linear methods.""" @abstractmethod def create_weights(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + """Create weights for a linear layer.""" raise NotImplementedError @abstractmethod @@ -29,11 +30,17 @@ def apply_weights(self, weights: Dict[str, torch.Tensor], x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights to the input tensor.""" raise NotImplementedError class FullPrecisionLinearMethod(LinearMethodBase): + """Full precision linear method without quantization. + Args: + separate_bias_add: If true, add bias separately after matrix + multiplication. + """ def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add @@ -60,7 +67,16 @@ def apply_weights(self, class ReplicatedLinear(torch.nn.Module): + """Replicated linear layer. + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ def __init__( self, input_size: int, @@ -108,12 +124,10 @@ class ColumnParallelLinear(torch.nn.Module): The linear layer is defined as Y = XA + b. A is parallelized along its second dimension as A = [A_1, ..., A_p]. - Arguments: + Args: input_size: first dimension of matrix A. output_size: second dimension of matrix A. - - Keyword Arguments - bias: If true, add bias + bias: If true, add bias. gather_output: If true, call all-gather on output and make Y available to all GPUs, otherwise, every GPU will have its output which is Y_i = XA_i @@ -121,7 +135,7 @@ class ColumnParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. we skip adding bias but instead return it. 
params_dtype: Data type for the parameters. - quant_config: Quantization configuration. + linear_method: (Maybe quantized) linear method. """ def __init__( @@ -180,15 +194,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data.copy_(loaded_weight) def forward(self, input_): - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ bias = self.bias if not self.skip_bias_add else None # Matrix multiply. @@ -204,7 +209,25 @@ def forward(self, input_): class PackedColumnParallelLinear(ColumnParallelLinear): - + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ def __init__( self, input_size: int, @@ -277,7 +300,28 @@ def weight_loader(self, class QKVParallelLinear(ColumnParallelLinear): - + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ def __init__( self, hidden_size: int, @@ -392,8 +436,6 @@ class RowParallelLinear(torch.nn.Module): Arguments: input_size: first dimension of matrix A. output_size: second dimension of matrix A. - - Keyword Arguments: bias: If true, add bias. Note that bias is not parallelized. input_is_parallel: If true, we assume that the input is already split across the GPUs and we do not split @@ -402,7 +444,7 @@ class RowParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. We skip adding bias but instead return it. params_dtype: Data type for the parameters. - quant_config: Quantization configuration. + linear_method: (Maybe quantized) linear method. 
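# Worked example for QKVParallelLinear under grouped-query attention, tying
# together the docstring above and the num_kv_head_replicas fix from the
# previous commit (all numbers illustrative; divide() asserts divisibility):
#
#   total_num_heads=32, total_num_kv_heads=8, head_size=128
#   tp_size=4  -> 8 query heads and 2 KV heads per rank, no replication
#   tp_size=16 -> 2 query heads and 1 KV head per rank,
#                 num_kv_head_replicas = 16 // 8 = 2 ranks share each KV head

total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4
num_heads = total_num_heads // tp_size                  # 8 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)    # 2 KV heads per rank
output_size_per_partition = (num_heads + 2 * num_kv_heads) * head_size  # (8 + 4) * 128 = 1536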
""" def __init__( @@ -468,17 +510,6 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): param_data.copy_(loaded_weight) def forward(self, input_): - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ # Set up backprop all-reduce. if self.input_is_parallel: input_parallel = input_ From f0e7f440bc251e5c67b992b9780a609c3243969d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 20:30:01 +0000 Subject: [PATCH 43/51] Add comments --- .../layers/quantized_linear/awq.py | 5 +++++ .../layers/quantized_linear/squeezellm.py | 6 ++++++ .../layers/vocab_parallel_embedding.py | 18 +++++++++++++++--- vllm/model_executor/utils.py | 9 +++++++++ 4 files changed, 35 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index 256bd84fd89bc..f52b5b590895f 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -68,6 +68,11 @@ def get_linear_method(self) -> "AWQLinearMethod": class AWQLinearMethod(LinearMethodBase): + """Linear method for AWQ. + + Args: + quant_config: The AWQ quantization config. + """ def __init__(self, quant_config: AWQConfig): self.quant_config = quant_config diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 8bf0c232eac83..0e8a11fa0c50c 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -57,6 +57,12 @@ def get_linear_method(self) -> "SqueezeLLMLinearMethod": class SqueezeLLMLinearMethod(LinearMethodBase): + """Linear method for SqueezeLLM. + + Args: + quant_config: The SqueezeLLM quantization config. + """ + def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 85d0370fe8fc7..7139cd06ec54e 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -36,9 +36,10 @@ def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, class VocabParallelEmbedding(torch.nn.Module): """Embedding parallelized in the vocabulary dimension. - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: + Adapted from torch.nn.Embedding, note that we pad the vocabulary size to + make sure it is divisible by the number of model parallel GPUs. + + Args: num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. @@ -102,7 +103,18 @@ def forward(self, input_): class ParallelLMHead(VocabParallelEmbedding): + """Parallelized LM head. + + Output logits weight matrices used in the Sampler. The weight and bias + tensors are padded to make sure they are divisible by the number of + model parallel GPUs. + Args: + num_embeddings: vocabulary size. + embedding_dim: size of hidden state. + bias: whether to use bias. + params_dtype: type of the parameters. 
+ """ def __init__(self, num_embeddings: int, embedding_dim: int, diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index 501906416b56b..ed23ee24f75c6 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -18,6 +18,15 @@ def set_weight_attrs( weight: torch.Tensor, weight_attrs: Optional[dict[str, Any]], ): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. + """ if weight_attrs is None: return for key, value in weight_attrs.items(): From 247252c1fb920b76b9d06c742b422485a19aea0b Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 20:50:18 +0000 Subject: [PATCH 44/51] code cleanup --- vllm/model_executor/layers/linear.py | 5 +++++ .../layers/quantized_linear/__init__.py | 22 +++++++++++++++++++ .../layers/quantized_linear/awq.py | 2 +- .../quantized_linear/base_config.py} | 0 .../layers/quantized_linear/squeezellm.py | 3 +-- .../layers/vocab_parallel_embedding.py | 1 + .../quantization_utils/__init__.py | 22 ------------------- vllm/model_executor/weight_utils.py | 6 ++--- 8 files changed, 33 insertions(+), 28 deletions(-) rename vllm/model_executor/{quantization_utils/base.py => layers/quantized_linear/base_config.py} (100%) delete mode 100644 vllm/model_executor/quantization_utils/__init__.py diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index fa0904ffad5b6..f0f5fbca9a16c 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -19,6 +19,7 @@ class LinearMethodBase(ABC): """Base class for different quantized linear methods.""" + @abstractmethod def create_weights(self, input_size: int, output_size: int, params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: @@ -41,6 +42,7 @@ class FullPrecisionLinearMethod(LinearMethodBase): separate_bias_add: If true, add bias separately after matrix multiplication. """ + def __init__(self, separate_bias_add: bool = False): self.separate_bias_add = separate_bias_add @@ -77,6 +79,7 @@ class ReplicatedLinear(torch.nn.Module): params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. """ + def __init__( self, input_size: int, @@ -228,6 +231,7 @@ class PackedColumnParallelLinear(ColumnParallelLinear): params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. """ + def __init__( self, input_size: int, @@ -322,6 +326,7 @@ class QKVParallelLinear(ColumnParallelLinear): params_dtype: Data type for the parameters. linear_method: (Maybe quantized) linear method. 
""" + def __init__( self, hidden_size: int, diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py index e69de29bb2d1d..e3755fbc6cae9 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantized_linear/__init__.py @@ -0,0 +1,22 @@ +from typing import Type + +from vllm.model_executor.layers.quantized_linear.awq import AWQConfig +from vllm.model_executor.layers.quantized_linear.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig + +_QUANTIZATION_CONFIG_REGISTRY = { + "awq": AWQConfig, + "squeezellm": SqueezeLLMConfig, +} + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in _QUANTIZATION_CONFIG_REGISTRY: + raise ValueError(f"Invalid quantization method: {quantization}") + return _QUANTIZATION_CONFIG_REGISTRY[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", +] diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py index f52b5b590895f..9212bff9903b4 100644 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ b/vllm/model_executor/layers/quantized_linear/awq.py @@ -6,7 +6,7 @@ from vllm import quantization_ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) -from vllm.model_executor.quantization_utils.base import QuantizationConfig +from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig class AWQConfig(QuantizationConfig): diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/layers/quantized_linear/base_config.py similarity index 100% rename from vllm/model_executor/quantization_utils/base.py rename to vllm/model_executor/layers/quantized_linear/base_config.py diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py index 0e8a11fa0c50c..7a50f037ccf8a 100644 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ b/vllm/model_executor/layers/quantized_linear/squeezellm.py @@ -6,7 +6,7 @@ from vllm import quantization_ops from vllm.model_executor.layers.linear import (LinearMethodBase, set_weight_attrs) -from vllm.model_executor.quantization_utils.base import QuantizationConfig +from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig class SqueezeLLMConfig(QuantizationConfig): @@ -63,7 +63,6 @@ class SqueezeLLMLinearMethod(LinearMethodBase): quant_config: The SqueezeLLM quantization config. """ - def __init__(self, quant_config: SqueezeLLMConfig): self.quant_config = quant_config diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 7139cd06ec54e..b08d5555b0faa 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -115,6 +115,7 @@ class ParallelLMHead(VocabParallelEmbedding): bias: whether to use bias. params_dtype: type of the parameters. 
""" + def __init__(self, num_embeddings: int, embedding_dim: int, diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py deleted file mode 100644 index 2b4df01400e95..0000000000000 --- a/vllm/model_executor/quantization_utils/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Type - -from vllm.model_executor.layers.quantized_linear.awq import AWQConfig -from vllm.model_executor.layers.quantized_linear.squeezellm import SqueezeLLMConfig -from vllm.model_executor.quantization_utils.base import QuantizationConfig - -_QUANTIZATION_REGISTRY = { - "awq": AWQConfig, - "squeezellm": SqueezeLLMConfig, -} - - -def get_quant_class(quantization: str) -> Type[QuantizationConfig]: - if quantization not in _QUANTIZATION_REGISTRY: - raise ValueError(f"Invalid quantization method: {quantization}") - return _QUANTIZATION_REGISTRY[quantization] - - -__all__ = [ - "QuantizationConfig", - "get_quant_class", -] diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index daffbf4082568..2d609944e5ef2 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -13,8 +13,8 @@ from tqdm.auto import tqdm from vllm.logger import init_logger -from vllm.model_executor.quantization_utils import get_quant_class -from vllm.model_executor.quantization_utils.base import QuantizationConfig +from vllm.model_executor.layers.quantized_linear import ( + get_quantization_config, QuantizationConfig) logger = init_logger(__name__) @@ -98,7 +98,7 @@ def get_quant_config( hf_folder = model_name_or_path config_files = glob.glob(os.path.join(hf_folder, "*.json")) - quant_cls = get_quant_class(quantization) + quant_cls = get_quantization_config(quantization) quant_config_files = [ f for f in config_files if any( f.endswith(x) for x in quant_cls.get_config_filenames()) From dfb4a810540bfc8f8a9fc1fc702f24be749cd53e Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 11 Nov 2023 20:53:20 +0000 Subject: [PATCH 45/51] Add comment --- vllm/model_executor/layers/quantized_linear/base_config.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/layers/quantized_linear/base_config.py b/vllm/model_executor/layers/quantized_linear/base_config.py index 8af16b7e145f3..ea37be3337903 100644 --- a/vllm/model_executor/layers/quantized_linear/base_config.py +++ b/vllm/model_executor/layers/quantized_linear/base_config.py @@ -6,6 +6,7 @@ class QuantizationConfig: + """Base class for quantization configs.""" @classmethod def get_name(cls) -> str: @@ -47,4 +48,5 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") def get_linear_method(self) -> LinearMethodBase: + """Get the linear method to use for the quantized linear layer.""" raise NotImplementedError From f7501662a252066b6ef8071cb96f0d4dfafe3002 Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 16 Nov 2023 00:13:02 +0000 Subject: [PATCH 46/51] Fix review comments --- vllm/engine/async_llm_engine.py | 4 ++-- vllm/model_executor/model_loader.py | 16 +--------------- vllm/model_executor/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 19 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index bc36b64f7df0f..7dcd2eb632c4c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -142,10 +142,10 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None: self._request_streams[request_id].finish() - def 
get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]: + def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: """Get the new requests and finished requests to be sent to the engine.""" - new_requests: List[dict] = [] + new_requests: List[Dict] = [] finished_requests: Set[str] = set() while not self._finished_requests.empty(): diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index ea55661dc3129..976ead9c787ff 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -37,14 +37,6 @@ "YiForCausalLM": YiForCausalLM, } -# FIXME(woosuk): Remove this once all models support quantization. -_MODEL_CLASSES_SUPPORT_QUANTIZATION = [ - LlamaForCausalLM, - MistralForCausalLM, - YiForCausalLM, -] - - @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" @@ -70,9 +62,6 @@ def get_model(model_config: ModelConfig) -> nn.Module: # Get the (maybe quantized) linear method. linear_method = None if model_config.quantization is not None: - if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - raise ValueError( - f"Quantization is not supported for {model_class}.") quant_config = get_quant_config(model_config.quantization, model_config.model, model_config.download_dir) @@ -95,10 +84,7 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - model = model_class(model_config.hf_config, linear_method) - else: - model = model_class(model_config.hf_config) + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index ed23ee24f75c6..336bc1cd005cf 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,6 +1,6 @@ """Utils for model executor.""" import random -from typing import Any, Optional +from typing import Any, Dict, Optional import numpy as np import torch @@ -16,7 +16,7 @@ def set_random_seed(seed: int) -> None: def set_weight_attrs( weight: torch.Tensor, - weight_attrs: Optional[dict[str, Any]], + weight_attrs: Optional[Dict[str, Any]], ): """Set attributes on a weight tensor. 
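# Usage sketch for set_weight_attrs(): a create_weights() implementation tags
# each parameter with metadata that the weight loaders and parallel layers
# look up later (the input_dim/output_dim names follow the attributes used by
# the weight loaders in this series; sizes are illustrative).

import torch
from torch.nn.parameter import Parameter
from vllm.model_executor.utils import set_weight_attrs

weight = Parameter(torch.empty(4096, 1024), requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})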
From fd4f4d5474ed06e35ae4a0b9de578500c6c5d74d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 16 Nov 2023 00:30:51 +0000 Subject: [PATCH 47/51] fix naming --- vllm/model_executor/layers/linear.py | 14 +++++++------- vllm/model_executor/model_loader.py | 1 + vllm/model_executor/models/aquila.py | 4 ++-- vllm/model_executor/models/baichuan.py | 4 ++-- vllm/model_executor/models/chatglm.py | 4 ++-- vllm/model_executor/models/internlm.py | 4 ++-- vllm/model_executor/models/llama.py | 4 ++-- vllm/model_executor/models/mistral.py | 4 ++-- vllm/model_executor/models/qwen.py | 4 ++-- vllm/model_executor/models/yi.py | 4 ++-- 10 files changed, 24 insertions(+), 23 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index f0f5fbca9a16c..34b1a8915a528 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -35,8 +35,8 @@ def apply_weights(self, raise NotImplementedError -class FullPrecisionLinearMethod(LinearMethodBase): - """Full precision linear method without quantization. +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization. Args: separate_bias_add: If true, add bias separately after matrix @@ -99,7 +99,7 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype if linear_method is None: - linear_method = FullPrecisionLinearMethod() + linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( self.input_size, self.output_size, self.params_dtype) @@ -165,7 +165,7 @@ def __init__( params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype if linear_method is None: - linear_method = FullPrecisionLinearMethod() + linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( self.input_size, self.output_size_per_partition, self.params_dtype) @@ -211,7 +211,7 @@ def forward(self, input_): return output, output_bias -class PackedColumnParallelLinear(ColumnParallelLinear): +class MergedColumnParallelLinear(ColumnParallelLinear): """Packed linear layers with column parallelism. 
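# What the renamed class is used for in the model files that follow: a merged
# gate/up projection packs two column-parallel matrices into one weight, and
# the fused output is split back in half by the activation. A sketch with
# illustrative sizes (per-rank shapes; requires an initialized tensor-parallel
# runtime, and `x` stands for the incoming hidden states):

from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.linear import MergedColumnParallelLinear

gate_up_proj = MergedColumnParallelLinear(4096, [11008] * 2, bias=False,
                                          linear_method=None)
act_fn = SiluAndMul()

gate_up, _ = gate_up_proj(x)   # [num_tokens, 4096] -> [num_tokens, 2 * 11008 / tp]
out = act_fn(gate_up)          # -> [num_tokens, 11008 / tp]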
Similar to ColumnParallelLinear, but the weight matrix is concatenated @@ -297,7 +297,7 @@ def weight_loader(self, else: logger.warning( "Loading a weight without `output_dim` attribute in " - "PackedColumnParallelLinear, assume the weight is " + "MergedColumnParallelLinear, assume the weight is " "the same for all partitions.") assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) @@ -478,7 +478,7 @@ def __init__( self.input_size_per_partition = divide(input_size, self.tp_size) self.skip_bias_add = skip_bias_add if linear_method is None: - linear_method = FullPrecisionLinearMethod() + linear_method = UnquantizedLinearMethod() self.linear_method = linear_method self.linear_weights = self.linear_method.create_weights( self.input_size_per_partition, self.output_size, self.params_dtype) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 976ead9c787ff..fdd860775c47c 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -37,6 +37,7 @@ "YiForCausalLM": YiForCausalLM, } + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 0fc004dc9caee..a1604bbba33b2 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -34,7 +34,7 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -60,7 +60,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 712fe61ae5910..64bbd5988fe37 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -34,7 +34,7 @@ PagedAttentionWithALiBi) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -85,7 +85,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 1dbc9a0847baa..673ca2092146a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -17,7 +17,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -134,7 +134,7 @@ def __init__( self.add_bias = 
config.add_bias_linear # Project to 4h. - self.dense_h_to_4h = PackedColumnParallelLinear( + self.dense_h_to_4h = MergedColumnParallelLinear( config.hidden_size, [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 3a030447eddb4..d90f8aaed624c 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -10,7 +10,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -35,7 +35,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index d786e1cda30b6..9381a2390c712 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -61,7 +61,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 3032a86ad7505..f9b9120aff80d 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -61,7 +61,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 9953291ad986f..45710edcc0bb4 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -18,7 +18,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -44,7 +44,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = 
PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index 71ec672ba1a3f..204c33ed42825 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -36,7 +36,7 @@ from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (LinearMethodBase, - PackedColumnParallelLinear, + MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler @@ -61,7 +61,7 @@ def __init__( linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = PackedColumnParallelLinear( + self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, linear_method=linear_method) From a7dd7f43abef46088d2f7eadee855c9f7092f25d Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 16 Nov 2023 00:41:52 +0000 Subject: [PATCH 48/51] fix comment --- vllm/model_executor/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 34b1a8915a528..810efb67df8d5 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -18,7 +18,7 @@ class LinearMethodBase(ABC): - """Base class for different quantized linear methods.""" + """Base class for different (maybe quantized) linear methods.""" @abstractmethod def create_weights(self, input_size: int, output_size: int, From 18898f73b5b5ee3a105a857b802ed1623e949e5c Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Thu, 16 Nov 2023 00:44:19 +0000 Subject: [PATCH 49/51] rename --- .../layers/{quantized_linear => quantization}/__init__.py | 6 +++--- .../layers/{quantized_linear => quantization}/awq.py | 2 +- .../{quantized_linear => quantization}/base_config.py | 0 .../layers/{quantized_linear => quantization}/squeezellm.py | 2 +- vllm/model_executor/weight_utils.py | 4 ++-- 5 files changed, 7 insertions(+), 7 deletions(-) rename vllm/model_executor/layers/{quantized_linear => quantization}/__init__.py (65%) rename vllm/model_executor/layers/{quantized_linear => quantization}/awq.py (98%) rename vllm/model_executor/layers/{quantized_linear => quantization}/base_config.py (100%) rename vllm/model_executor/layers/{quantized_linear => quantization}/squeezellm.py (97%) diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantization/__init__.py similarity index 65% rename from vllm/model_executor/layers/quantized_linear/__init__.py rename to vllm/model_executor/layers/quantization/__init__.py index e3755fbc6cae9..3d937ba64f9fa 100644 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,8 +1,8 @@ from typing import Type -from vllm.model_executor.layers.quantized_linear.awq import AWQConfig -from vllm.model_executor.layers.quantized_linear.squeezellm import SqueezeLLMConfig -from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig 

 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantization/awq.py
similarity index 98%
rename from vllm/model_executor/layers/quantized_linear/awq.py
rename to vllm/model_executor/layers/quantization/awq.py
index 9212bff9903b4..917c08782e5f5 100644
--- a/vllm/model_executor/layers/quantized_linear/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -6,7 +6,7 @@
 from vllm import quantization_ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
-from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


 class AWQConfig(QuantizationConfig):
diff --git a/vllm/model_executor/layers/quantized_linear/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
similarity index 100%
rename from vllm/model_executor/layers/quantized_linear/base_config.py
rename to vllm/model_executor/layers/quantization/base_config.py
diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
similarity index 97%
rename from vllm/model_executor/layers/quantized_linear/squeezellm.py
rename to vllm/model_executor/layers/quantization/squeezellm.py
index 7a50f037ccf8a..9a3936dbc7029 100644
--- a/vllm/model_executor/layers/quantized_linear/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -6,7 +6,7 @@
 from vllm import quantization_ops
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
-from vllm.model_executor.layers.quantized_linear.base_config import QuantizationConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig


 class SqueezeLLMConfig(QuantizationConfig):
diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py
index 2d609944e5ef2..0a05fe7e340b4 100644
--- a/vllm/model_executor/weight_utils.py
+++ b/vllm/model_executor/weight_utils.py
@@ -13,8 +13,8 @@
 from tqdm.auto import tqdm

 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantized_linear import (
-    get_quantization_config, QuantizationConfig)
+from vllm.model_executor.layers.quantization import (get_quantization_config,
+                                                     QuantizationConfig)

 logger = init_logger(__name__)

From 2d01ce0bac7775748a5bf513b7725663c72c9945 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Thu, 16 Nov 2023 00:54:30 +0000
Subject: [PATCH 50/51] Fix issues in PR #1640

---
 vllm/config.py | 49 ++++++++++++++++++++++++++++++-------------------
 1 file changed, 30 insertions(+), 19 deletions(-)

diff --git a/vllm/config.py b/vllm/config.py
index 116deee8fa353..d91e18c2ce436 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -140,8 +140,8 @@ def get_head_size(self) -> int:
         # FIXME(woosuk): This may not be true for all models.
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

-    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
-        """Returns the number of KV heads per GPU worker."""
+    def get_total_num_kv_heads(self) -> int:
+        """Returns the total number of KV heads."""
         # For GPTBigCode & Falcon:
         # NOTE: for falcon, when new_decoder_architecture is True, the
         # multi_query flag is ignored and we use n_head_kv for the number of
@@ -155,23 +155,34 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
             # Multi-query attention, only one KV head.
             # Currently, tensor parallelism is not supported in this case.
             return 1
-        # For Falcon:
-        if getattr(self.hf_config, "n_head_kv", None) is not None:
-            return (self.hf_config.n_head_kv //
-                    parallel_config.tensor_parallel_size)
-        if getattr(self.hf_config, "num_kv_heads", None) is not None:
-            return (self.hf_config.num_kv_heads //
-                    parallel_config.tensor_parallel_size)
-        # For LLaMA-2:
-        if getattr(self.hf_config, "num_key_value_heads", None) is not None:
-            return (self.hf_config.num_key_value_heads //
-                    parallel_config.tensor_parallel_size)
-        # For ChatGLM-2:
-        if getattr(self.hf_config, "multi_query_group_num", None) is not None:
-            return (self.hf_config.multi_query_group_num //
-                    parallel_config.tensor_parallel_size)
-        total_num_attention_heads = self.hf_config.num_attention_heads
-        return total_num_attention_heads // parallel_config.tensor_parallel_size
+
+        attributes = [
+            # For Falcon:
+            "n_head_kv",
+            "num_kv_heads",
+            # For LLaMA-2:
+            "num_key_value_heads",
+            # For ChatGLM:
+            "multi_query_group_num",
+        ]
+        for attr in attributes:
+            num_kv_heads = getattr(self.hf_config, attr, None)
+            if num_kv_heads is not None:
+                return num_kv_heads
+
+        # For non-grouped-query attention models, the number of KV heads is
+        # equal to the number of attention heads.
+        return self.hf_config.num_attention_heads
+
+    def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
+        """Returns the number of KV heads per GPU."""
+        total_num_kv_heads = self.get_total_num_kv_heads()
+        # If tensor parallelism is used, we divide the number of KV heads by
+        # the tensor parallel size. We will replicate the KV heads in the
+        # case where the number of KV heads is smaller than the tensor
+        # parallel size so each GPU has at least one KV head.
+        return max(1,
+                   total_num_kv_heads // parallel_config.tensor_parallel_size)

     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_config.num_hidden_layers

From 241bfa87458bceb1a8d4a58d25dd0da0c7a494c3 Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Thu, 16 Nov 2023 01:05:36 +0000
Subject: [PATCH 51/51] Fix config

---
 .../model_executor/layers/quantization/awq.py | 13 +++++------
 .../layers/quantization/base_config.py        | 22 +++++++++++--------
 .../layers/quantization/squeezellm.py         | 13 +++++------
 3 files changed, 23 insertions(+), 25 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 917c08782e5f5..2a077b439e49d 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -36,21 +36,18 @@ def __repr__(self) -> str:
                 f"group_size={self.group_size}, "
                 f"zero_point={self.zero_point})")

-    @classmethod
-    def get_name(cls) -> str:
+    def get_name(self) -> str:
         return "awq"

-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]

-    @classmethod
-    def get_min_capability(cls) -> int:
+    def get_min_capability(self) -> int:
         # The AWQ kernel only supports Turing or newer GPUs.
         return 75

-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    @staticmethod
+    def get_config_filenames() -> List[str]:
         return [
             "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
             "quantize_config.json",  # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq  # pylint: disable=line-too-long
diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py
index ea37be3337903..116ff903c2290 100644
--- a/vllm/model_executor/layers/quantization/base_config.py
+++ b/vllm/model_executor/layers/quantization/base_config.py
@@ -1,3 +1,4 @@
+from abc import ABC, abstractmethod
 from typing import Any, Dict, List

 import torch
@@ -5,21 +6,21 @@
 from vllm.model_executor.layers.linear import LinearMethodBase


-class QuantizationConfig:
+class QuantizationConfig(ABC):
     """Base class for quantization configs."""

-    @classmethod
-    def get_name(cls) -> str:
+    @abstractmethod
+    def get_name(self) -> str:
         """Name of the quantization method."""
         raise NotImplementedError

-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    @abstractmethod
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
         """List of supported activation dtypes."""
         raise NotImplementedError

-    @classmethod
-    def get_min_capability(cls) -> int:
+    @abstractmethod
+    def get_min_capability(self) -> int:
         """Minimum GPU capability to support the quantization method.

         E.g., 70 for Volta, 75 for Turing, 80 for Ampere.
@@ -28,12 +29,14 @@ def get_min_capability(cls) -> int:
         """
         raise NotImplementedError

-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    @staticmethod
+    @abstractmethod
+    def get_config_filenames() -> List[str]:
         """List of filenames to search for in the model directory."""
         raise NotImplementedError

     @classmethod
+    @abstractmethod
     def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
         """Create a config class from the model's quantization config."""
         raise NotImplementedError
@@ -47,6 +50,7 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any:
         raise ValueError(f"Cannot find any of {keys} in the model's "
                          "quantization config.")

+    @abstractmethod
     def get_linear_method(self) -> LinearMethodBase:
         """Get the linear method to use for the quantized linear layer."""
         raise NotImplementedError
diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py
index 9a3936dbc7029..a85dd91be7dbd 100644
--- a/vllm/model_executor/layers/quantization/squeezellm.py
+++ b/vllm/model_executor/layers/quantization/squeezellm.py
@@ -31,20 +31,17 @@ def __init__(
     def __repr__(self) -> str:
         return f"SqueezeLLMConfig(weight_bits={self.weight_bits})"

-    @classmethod
-    def get_name(cls) -> str:
+    def get_name(self) -> str:
         return "squeezellm"

-    @classmethod
-    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+    def get_supported_act_dtypes(self) -> List[torch.dtype]:
         return [torch.half]

-    @classmethod
-    def get_min_capability(cls) -> int:
+    def get_min_capability(self) -> int:
         return 70

-    @classmethod
-    def get_config_filenames(cls) -> List[str]:
+    @staticmethod
+    def get_config_filenames() -> List[str]:
         return ["quant_config.json"]

     @classmethod