diff --git a/vllm/config.py b/vllm/config.py index 116deee8fa353..d91e18c2ce436 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -140,8 +140,8 @@ def get_head_size(self) -> int: # FIXME(woosuk): This may not be true for all models. return self.hf_config.hidden_size // self.hf_config.num_attention_heads - def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: - """Returns the number of KV heads per GPU worker.""" + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" # For GPTBigCode & Falcon: # NOTE: for falcon, when new_decoder_architecture is True, the # multi_query flag is ignored and we use n_head_kv for the number of @@ -155,23 +155,34 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. return 1 - # For Falcon: - if getattr(self.hf_config, "n_head_kv", None) is not None: - return (self.hf_config.n_head_kv // - parallel_config.tensor_parallel_size) - if getattr(self.hf_config, "num_kv_heads", None) is not None: - return (self.hf_config.num_kv_heads // - parallel_config.tensor_parallel_size) - # For LLaMA-2: - if getattr(self.hf_config, "num_key_value_heads", None) is not None: - return (self.hf_config.num_key_value_heads // - parallel_config.tensor_parallel_size) - # For ChatGLM-2: - if getattr(self.hf_config, "multi_query_group_num", None) is not None: - return (self.hf_config.multi_query_group_num // - parallel_config.tensor_parallel_size) - total_num_attention_heads = self.hf_config.num_attention_heads - return total_num_attention_heads // parallel_config.tensor_parallel_size + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_config.num_attention_heads + + def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. 
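        # Illustrative numbers (hypothetical config, not from this change):
        # with total_num_kv_heads = 8 and tensor_parallel_size = 4, each GPU
        # gets 8 // 4 = 2 KV heads; with tensor_parallel_size = 16 the floor
        # division gives 0, so max(1, ...) keeps a single, replicated KV head
        # per GPU.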
+ return max(1, + total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_config.num_hidden_layers diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index bc36b64f7df0f..7dcd2eb632c4c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -142,10 +142,10 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None: self._request_streams[request_id].finish() - def get_new_and_finished_requests(self) -> Tuple[List[dict], Set[str]]: + def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: """Get the new requests and finished requests to be sent to the engine.""" - new_requests: List[dict] = [] + new_requests: List[Dict] = [] finished_requests: Set[str] = set() while not self._finished_requests.empty(): diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py new file mode 100644 index 0000000000000..810efb67df8d5 --- /dev/null +++ b/vllm/model_executor/layers/linear.py @@ -0,0 +1,541 @@ +from abc import ABC, abstractmethod +from typing import Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) +from vllm.model_executor.parallel_utils.utils import ( + divide, split_tensor_along_last_dim) +from vllm.model_executor.utils import set_weight_attrs +from vllm.logger import init_logger + +logger = init_logger(__name__) + + +class LinearMethodBase(ABC): + """Base class for different (maybe quantized) linear methods.""" + + @abstractmethod + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + """Create weights for a linear layer.""" + raise NotImplementedError + + @abstractmethod + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply the weights to the input tensor.""" + raise NotImplementedError + + +class UnquantizedLinearMethod(LinearMethodBase): + """Linear method without quantization. + + Args: + separate_bias_add: If true, add bias separately after matrix + multiplication. + """ + + def __init__(self, separate_bias_add: bool = False): + self.separate_bias_add = separate_bias_add + + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + weight = Parameter(torch.empty(output_size, + input_size, + device=torch.cuda.current_device(), + dtype=params_dtype), + requires_grad=False) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + return {"weight": weight} + + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + weight = weights["weight"] + if self.separate_bias_add: + if bias: + return F.linear(x, weight) + bias + return F.linear(x, weight) + return F.linear(x, weight, bias) + + +class ReplicatedLinear(torch.nn.Module): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. 
+ skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if linear_method is None: + linear_method = UnquantizedLinearMethod() + self.linear_method = linear_method + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size, self.params_dtype) + for name, weight in self.linear_weights.items(): + self.register_parameter(name, weight) + if bias: + self.bias = Parameter( + torch.empty(self.output_size, + device=torch.cuda.current_device(), + dtype=self.params_dtype)) + set_weight_attrs(self.bias, {"output_dim": 0}) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + bias = self.bias if not self.skip_bias_add else None + output = self.linear_method.apply_weights(self.linear_weights, x, bias) + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class ColumnParallelLinear(torch.nn.Module): + """Linear layer with column parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its second dimension as A = [A_1, ..., A_p]. + + Args: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make Y available + to all GPUs, otherwise, every GPU will have its output + which is Y_i = XA_i + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.gather_output = gather_output + # Divide the weight matrix along the last dimension. 
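        # Worked example (hypothetical sizes): with output_size = 4096 and a
        # tensor-parallel world size of 2, each rank holds a 2048-column shard
        # of A; forward() returns the local Y_i, or the full 4096-wide Y when
        # gather_output=True triggers the all-gather.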
+ tp_size = get_tensor_model_parallel_world_size() + self.output_size_per_partition = divide(output_size, tp_size) + self.skip_bias_add = skip_bias_add + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + if linear_method is None: + linear_method = UnquantizedLinearMethod() + self.linear_method = linear_method + self.linear_weights = self.linear_method.create_weights( + self.input_size, self.output_size_per_partition, self.params_dtype) + for name, weight in self.linear_weights.items(): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + if bias: + self.bias = Parameter( + torch.empty(self.output_size_per_partition, + device=torch.cuda.current_device(), + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + output_dim = getattr(param, "output_dim", None) + param_data = param.data + if output_dim is not None: + shard_size = param_data.shape[output_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward(self, input_): + bias = self.bias if not self.skip_bias_add else None + + # Matrix multiply. + output_parallel = self.linear_method.apply_weights( + self.linear_weights, input_, bias) + if self.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = self.bias if self.skip_bias_add else None + return output, output_bias + + +class MergedColumnParallelLinear(ColumnParallelLinear): + """Packed linear layers with column parallelism. + + Similar to ColumnParallelLinear, but the weight matrix is concatenated + along the output dimension. When the weight matrix is loaded, the + different partitions are sharded separately. + + Args: + input_size: input dimension of the linear layer. + output_sizes: list of output dimensions of the linear layer. + bias: If true, add bias. + gather_output: If true, call all-gather on output and make the output + available to all GPUs, otherwise, every GPU will have + its own output. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + input_size: int, + output_sizes: List[int], + bias: bool = True, + gather_output: bool = False, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + self.output_sizes = output_sizes + tp_size = get_tensor_model_parallel_world_size() + assert all(output_size % tp_size == 0 for output_size in output_sizes) + super().__init__(input_size, sum(output_sizes), bias, gather_output, + skip_bias_add, params_dtype, linear_method) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[int] = None): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + if loaded_shard_id is None: + # Loaded weight is already packed. 
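            # Example scenario (hypothetical checkpoint layout): a checkpoint
            # that stores gate_proj and up_proj already fused into one tensor
            # lands in this branch; the offsets computed below split it back
            # into its logical sub-weights, and each piece is re-loaded through
            # weight_loader() with its shard id so it gets sharded for the
            # current tensor-parallel rank.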
+ if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + current_shard_offset = 0 + shard_offsets = [] + for i, output_size in enumerate(self.output_sizes): + shard_offsets.append((i, current_shard_offset, output_size)) + current_shard_offset += output_size + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + assert loaded_shard_id < len(self.output_sizes) + tp_rank = get_tensor_model_parallel_rank() + tp_size = get_tensor_model_parallel_world_size() + if output_dim is not None: + shard_offset = sum(self.output_sizes[:loaded_shard_id]) // tp_size + shard_size = self.output_sizes[loaded_shard_id] // tp_size + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "MergedColumnParallelLinear, assume the weight is " + "the same for all partitions.") + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class QKVParallelLinear(ColumnParallelLinear): + """Linear layers for the attention's QKV transformation. + + Linear layers for the linear transformation of the query, key, and value + vectors in the attention layer. The weight matrix is concatenated along + the output dimension. The layer is parallelized along the head dimension. + When the number of key/value heads is smaller than the number of query + heads (e.g., multi-query/grouped-query attention), the key/value head may + be replicated while the query heads are partitioned. + + Args: + hidden_size: input hidden state size of the transformer. + head_size: size of each attention head. + total_num_heads: total number of attention query heads. + total_num_kv_heads: total number of attention key/value heads. If + None, assume total_num_kv_heads = total_num_heads. + bias: If true, add bias. + skip_bias_add: This was added to enable performance optimizations where + bias can be fused with other element-wise operations. we + skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + hidden_size: int, + head_size: int, + total_num_heads: int, + total_num_kv_heads: Optional[int] = None, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + linear_method: Optional[LinearMethodBase] = None, + ): + self.hidden_size = hidden_size + self.head_size = head_size + self.total_num_heads = total_num_heads + if total_num_kv_heads is None: + total_num_kv_heads = total_num_heads + self.total_num_kv_heads = total_num_kv_heads + # Divide the weight matrix along the last dimension. 
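        # Illustrative sharding (hypothetical sizes): grouped-query attention
        # with total_num_heads = 64 and total_num_kv_heads = 8 on tp_size = 16
        # yields num_heads = 4, num_kv_heads = 1 and num_kv_head_replicas = 2,
        # i.e. every pair of ranks shares a copy of the same KV head.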
+ tp_size = get_tensor_model_parallel_world_size() + self.num_heads = divide(self.total_num_heads, tp_size) + if tp_size >= self.total_num_kv_heads: + self.num_kv_heads = 1 + self.num_kv_head_replicas = divide(tp_size, + self.total_num_kv_heads) + else: + self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) + self.num_kv_head_replicas = 1 + input_size = self.hidden_size + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + super().__init__(input_size, output_size, bias, False, skip_bias_add, + params_dtype, linear_method) + + def weight_loader(self, + param: Parameter, + loaded_weight: torch.Tensor, + loaded_shard_id: Optional[str] = None): + param_data = param.data + output_dim = getattr(param, "output_dim", None) + if loaded_shard_id is None: + # Loaded weight is already packed. + if output_dim is None: + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + return + shard_offsets = [ + # (shard_id, shard_offset, shard_size) + ("q", 0, self.total_num_heads * self.head_size), + ("k", self.total_num_heads * self.head_size, + self.total_num_kv_heads * self.head_size), + ("v", (self.total_num_heads + self.total_num_kv_heads) * + self.head_size, self.total_num_kv_heads * self.head_size), + ] + packed_dim = getattr(param, "packed_dim", None) + for shard_id, shard_offset, shard_size in shard_offsets: + # If quantized, we need to adjust the offset and size to account + # for the packing. + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + loaded_weight_shard = loaded_weight.narrow( + output_dim, shard_offset, shard_size) + self.weight_loader(param, loaded_weight_shard, shard_id) + return + + tp_rank = get_tensor_model_parallel_rank() + assert loaded_shard_id in ["q", "k", "v"] + if output_dim is not None: + if loaded_shard_id == "q": + shard_offset = 0 + shard_size = self.num_heads * self.head_size + elif loaded_shard_id == "k": + shard_offset = self.num_heads * self.head_size + shard_size = self.num_kv_heads * self.head_size + elif loaded_shard_id == "v": + shard_offset = (self.num_heads + + self.num_kv_heads) * self.head_size + shard_size = self.num_kv_heads * self.head_size + # If quantized, we need to adjust the offset and size to account + # for the packing. + packed_dim = getattr(param, "packed_dim", None) + if packed_dim == output_dim: + shard_size = shard_size // param.pack_factor + shard_offset = shard_offset // param.pack_factor + param_data = param_data.narrow(output_dim, shard_offset, + shard_size) + shard_id = tp_rank // self.num_kv_head_replicas + start_idx = shard_id * shard_size + loaded_weight = loaded_weight.narrow(output_dim, start_idx, + shard_size) + else: + logger.warning( + "Loading a weight without `output_dim` attribute in " + "QKVParallelLinear, assume the weight is the same " + "for all partitions.") + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + +class RowParallelLinear(torch.nn.Module): + """Linear layer with row parallelism. + + The linear layer is defined as Y = XA + b. A is parallelized along + its first dimension and X along its second dimension as: + - - + | A_1 | + | . | + A = | . | X = [X_1, ..., X_p] + | . | + | A_p | + - - + Arguments: + input_size: first dimension of matrix A. + output_size: second dimension of matrix A. + bias: If true, add bias. Note that bias is not parallelized. 
+ input_is_parallel: If true, we assume that the input is already + split across the GPUs and we do not split + again. + skip_bias_add: This was added to enable performance optimization where + bias can be fused with other element-wise operations. + We skip adding bias but instead return it. + params_dtype: Data type for the parameters. + linear_method: (Maybe quantized) linear method. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + linear_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + # Keep input parameters + self.input_size = input_size + self.output_size = output_size + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + if params_dtype is None: + params_dtype = torch.get_default_dtype() + self.params_dtype = params_dtype + + # Divide the weight matrix along the last dimension. + self.tp_size = get_tensor_model_parallel_world_size() + self.input_size_per_partition = divide(input_size, self.tp_size) + self.skip_bias_add = skip_bias_add + if linear_method is None: + linear_method = UnquantizedLinearMethod() + self.linear_method = linear_method + self.linear_weights = self.linear_method.create_weights( + self.input_size_per_partition, self.output_size, self.params_dtype) + for name, weight in self.linear_weights.items(): + self.register_parameter(name, weight) + set_weight_attrs(weight, {"weight_loader": self.weight_loader}) + + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = Parameter( + torch.empty(self.output_size, + device=torch.cuda.current_device(), + dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + tp_rank = get_tensor_model_parallel_rank() + input_dim = getattr(param, "input_dim", None) + param_data = param.data + if input_dim is not None: + shard_size = param_data.shape[input_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(input_dim, start_idx, + shard_size) + assert param_data.shape == loaded_weight.shape + param_data.copy_(loaded_weight) + + def forward(self, input_): + # Set up backprop all-reduce. + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
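        # Each rank multiplies its input slice X_i by its weight shard A_i, so
        # output_parallel below holds only a partial sum; the all-reduce (when
        # reduce_results is set and tp_size > 1) adds the partial sums across
        # ranks to recover Y = sum_i X_i A_i.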
+ output_parallel = self.linear_method.apply_weights( + self.linear_weights, input_parallel) + if self.reduce_results and self.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.skip_bias_add: + output = output_ + self.bias if self.bias is not None else output_ + output_bias = None + else: + output = output_ + output_bias = self.bias + return output, output_bias diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py new file mode 100644 index 0000000000000..3d937ba64f9fa --- /dev/null +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -0,0 +1,22 @@ +from typing import Type + +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + +_QUANTIZATION_CONFIG_REGISTRY = { + "awq": AWQConfig, + "squeezellm": SqueezeLLMConfig, +} + + +def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: + if quantization not in _QUANTIZATION_CONFIG_REGISTRY: + raise ValueError(f"Invalid quantization method: {quantization}") + return _QUANTIZATION_CONFIG_REGISTRY[quantization] + + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", +] diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py new file mode 100644 index 0000000000000..2a077b439e49d --- /dev/null +++ b/vllm/model_executor/layers/quantization/awq.py @@ -0,0 +1,155 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import quantization_ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + + +class AWQConfig(QuantizationConfig): + """Config class for AWQ. + + Reference: https://arxiv.org/abs/2306.00978 + """ + + def __init__( + self, + weight_bits: int, + group_size: int, + zero_point: bool, + ) -> None: + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"AWQ, but got {self.weight_bits} bits.") + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return (f"AWQConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"zero_point={self.zero_point})") + + def get_name(self) -> str: + return "awq" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half] + + def get_min_capability(self) -> int: + # The AWQ kernel only supports Turing or newer GPUs. 
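        # 75 encodes compute capability 7.5 (major * 10 + minor), matching the
        # numbering described in QuantizationConfig.get_min_capability, so
        # Volta (70) and older are rejected.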
+ return 75 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq + "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": + weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) + group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) + zero_point = cls.get_from_keys(config, ["zero_point"]) + return cls(weight_bits, group_size, zero_point) + + def get_linear_method(self) -> "AWQLinearMethod": + return AWQLinearMethod(self) + + +class AWQLinearMethod(LinearMethodBase): + """Linear method for AWQ. + + Args: + quant_config: The AWQ quantization config. + """ + + def __init__(self, quant_config: AWQConfig): + self.quant_config = quant_config + + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + if input_size % self.quant_config.group_size != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + if output_size % self.quant_config.pack_factor != 0: + raise ValueError( + "The output size is not aligned with the quantized " + "weight shape. This can be caused by too large " + "tensor parallel size.") + + qweight = Parameter( + torch.empty( + input_size, + output_size // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + qzeros = Parameter( + torch.empty( + input_size // self.quant_config.group_size, + output_size // self.quant_config.pack_factor, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + scales = Parameter( + torch.empty( + input_size // self.quant_config.group_size, + output_size, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(scales, { + "input_dim": 0, + "output_dim": 1, + }) + return { + "qweight": qweight, + "qzeros": qzeros, + "scales": scales, + } + + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + qzeros = weights["qzeros"] + scales = weights["scales"] + pack_factor = self.quant_config.pack_factor + out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, )) + reshaped_x = x.reshape(-1, x.shape[-1]) + out = quantization_ops.awq_gemm(reshaped_x, qweight, scales, qzeros, + pack_factor) + if bias is not None: + out = out + bias + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py new file mode 100644 index 0000000000000..116ff903c2290 --- /dev/null +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -0,0 +1,56 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + +import torch + +from vllm.model_executor.layers.linear import LinearMethodBase + + +class QuantizationConfig(ABC): + """Base class for quantization configs.""" + + @abstractmethod + def get_name(self) -> str: + """Name of the quantization method.""" + raise 
NotImplementedError + + @abstractmethod + def get_supported_act_dtypes(self) -> List[torch.dtype]: + """List of supported activation dtypes.""" + raise NotImplementedError + + @abstractmethod + def get_min_capability(self) -> int: + """Minimum GPU capability to support the quantization method. + + E.g., 70 for Volta, 75 for Turing, 80 for Ampere. + This requirement is due to the custom CUDA kernels used by the + quantization method. + """ + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_config_filenames() -> List[str]: + """List of filenames to search for in the model directory.""" + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create a config class from the model's quantization config.""" + raise NotImplementedError + + @staticmethod + def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: + """Get a value from the model's quantization config.""" + for key in keys: + if key in config: + return config[key] + raise ValueError(f"Cannot find any of {keys} in the model's " + "quantization config.") + + @abstractmethod + def get_linear_method(self) -> LinearMethodBase: + """Get the linear method to use for the quantized linear layer.""" + raise NotImplementedError diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py new file mode 100644 index 0000000000000..a85dd91be7dbd --- /dev/null +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -0,0 +1,121 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from vllm import quantization_ops +from vllm.model_executor.layers.linear import (LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import QuantizationConfig + + +class SqueezeLLMConfig(QuantizationConfig): + """Config class for SqueezeLLM. + + Reference: https://arxiv.org/pdf/2306.07629 + """ + + def __init__( + self, + weight_bits: int, + ) -> None: + self.weight_bits = weight_bits + + if self.weight_bits != 4: + raise ValueError( + "Currently, only 4-bit weight quantization is supported for " + f"SqueezeLLM, but got {self.weight_bits} bits.") + + self.pack_factor = 32 // self.weight_bits + + def __repr__(self) -> str: + return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" + + def get_name(self) -> str: + return "squeezellm" + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + return [torch.half] + + def get_min_capability(self) -> int: + return 70 + + @staticmethod + def get_config_filenames() -> List[str]: + return ["quant_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": + weight_bits = cls.get_from_keys(config, ["wbits"]) + return cls(weight_bits) + + def get_linear_method(self) -> "SqueezeLLMLinearMethod": + return SqueezeLLMLinearMethod(self) + + +class SqueezeLLMLinearMethod(LinearMethodBase): + """Linear method for SqueezeLLM. + + Args: + quant_config: The SqueezeLLM quantization config. + """ + + def __init__(self, quant_config: SqueezeLLMConfig): + self.quant_config = quant_config + + def create_weights(self, input_size: int, output_size: int, + params_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + if input_size % self.quant_config.pack_factor != 0: + raise ValueError( + "The input size is not aligned with the quantized " + "weight shape. 
This can be caused by too large " + "tensor parallel size.") + qweight = Parameter( + torch.empty( + input_size // self.quant_config.pack_factor, + output_size, + device="cuda", + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + lookup_table = Parameter( + torch.empty( + output_size, + self.quant_config.weight_bits**2, + device="cuda", + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(lookup_table, { + "output_dim": 0, + }) + return { + "qweight": qweight, + "lookup_table": lookup_table, + } + + def apply_weights(self, + weights: Dict[str, torch.Tensor], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qweight = weights["qweight"] + lookup_table = weights["lookup_table"] + out_shape = x.shape[:-1] + (qweight.shape[-1], ) + reshaped_x = x.reshape(-1, x.shape[-1]) + # NOTE: The output tensor should be zero-initialized. + out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) + quantization_ops.squeezellm_gemm(reshaped_x, qweight, out, + lookup_table) + + if bias is not None: + out = out + bias + return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantized_linear/__init__.py b/vllm/model_executor/layers/quantized_linear/__init__.py deleted file mode 100644 index b09358261d5d1..0000000000000 --- a/vllm/model_executor/layers/quantized_linear/__init__.py +++ /dev/null @@ -1,41 +0,0 @@ -from vllm.model_executor.layers.quantized_linear.awq import ( - AWQColumnParallelLinear, AWQRowParallelLinear) -from vllm.model_executor.layers.quantized_linear.squeezellm import ( - SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear) -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) - -_QUANTIZED_LINEAR_REGISTRY = { - "awq": (AWQColumnParallelLinear, AWQRowParallelLinear), - "squeezellm": - (SqueezeLLMColumnParallelLinear, SqueezeLLMRowParallelLinear), -} - - -class ParallelLinear: - - @classmethod - def column(cls, *args, **kwargs) -> ColumnParallelLinear: - quant_config = kwargs.get("quant_config", None) - if quant_config is None: - return ColumnParallelLinear(*args, **kwargs) - - name = quant_config.get_name() - if name not in _QUANTIZED_LINEAR_REGISTRY: - raise ValueError(f"No quantized linear is found for {name}") - - quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][0] - return quant_linear_cls(*args, **kwargs) - - @classmethod - def row(cls, *args, **kwargs) -> RowParallelLinear: - quant_config = kwargs.get("quant_config", None) - if quant_config is None: - return RowParallelLinear(*args, **kwargs) - - name = quant_config.get_name() - if name not in _QUANTIZED_LINEAR_REGISTRY: - raise ValueError(f"No quantized linear is found for {name}") - - quant_linear_cls = _QUANTIZED_LINEAR_REGISTRY[name][1] - return quant_linear_cls(*args, **kwargs) diff --git a/vllm/model_executor/layers/quantized_linear/awq.py b/vllm/model_executor/layers/quantized_linear/awq.py deleted file mode 100644 index 31e341318d400..0000000000000 --- a/vllm/model_executor/layers/quantized_linear/awq.py +++ /dev/null @@ -1,106 +0,0 @@ -from typing import Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import quantization_ops -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) - - -class AWQColumnParallelLinear(ColumnParallelLinear): - - def create_weights(self, dtype: torch.dtype) -> 
None: - assert self.input_size % self.quant_config.group_size == 0 - if self.output_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - "The tensor parallel size is not aligned with the quantized " - "weight shape. Please use a different tensor parallel size.") - self.qweight = Parameter( - torch.empty( - self.input_size, - self.output_size_per_partition // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.qzeros = Parameter( - torch.empty( - self.input_size // self.quant_config.group_size, - self.output_size_per_partition // - self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.scales = Parameter( - torch.empty( - self.input_size // self.quant_config.group_size, - self.output_size_per_partition, - device="cuda", - dtype=dtype, - ), - requires_grad=False, - ) - - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, )) - reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, - self.qzeros, pack_factor) - if bias is not None: - out = out + bias - return out.reshape(out_shape) - - -class AWQRowParallelLinear(RowParallelLinear): - - def create_weights(self, dtype: torch.dtype) -> None: - assert self.output_size % self.quant_config.pack_factor == 0 - if self.input_size_per_partition % self.quant_config.group_size != 0: - raise ValueError( - "The tensor parallel size is not aligned with the quantized " - "weight shape. Please use a different tensor parallel size.") - self.qweight = Parameter( - torch.empty( - self.input_size_per_partition, - self.output_size // self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.qzeros = Parameter( - torch.empty( - self.input_size_per_partition // self.quant_config.group_size, - self.output_size // self.quant_config.pack_factor, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.scales = Parameter( - torch.empty( - self.input_size_per_partition // self.quant_config.group_size, - self.output_size, - device="cuda", - dtype=dtype, - ), - requires_grad=False, - ) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - pack_factor = self.quant_config.pack_factor - out_shape = (x.shape[:-1] + (self.qweight.shape[-1] * pack_factor, )) - reshaped_x = x.reshape(-1, x.shape[-1]) - out = quantization_ops.awq_gemm(reshaped_x, self.qweight, self.scales, - self.qzeros, pack_factor) - return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantized_linear/squeezellm.py b/vllm/model_executor/layers/quantized_linear/squeezellm.py deleted file mode 100644 index 3ccbc4e579dc6..0000000000000 --- a/vllm/model_executor/layers/quantized_linear/squeezellm.py +++ /dev/null @@ -1,84 +0,0 @@ -from typing import Optional - -import torch -from torch.nn.parameter import Parameter - -from vllm import quantization_ops -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear) - - -class SqueezeLLMColumnParallelLinear(ColumnParallelLinear): - - def create_weights(self, dtype: torch.dtype) -> None: - assert self.input_size % self.quant_config.pack_factor == 0 - self.qweight = Parameter( - torch.empty( - self.input_size // self.quant_config.pack_factor, - self.output_size_per_partition, - 
device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.lookup_table = Parameter( - torch.empty( - self.output_size_per_partition, - self.quant_config.weight_bits**2, - device="cuda", - dtype=dtype, - ), - requires_grad=False, - ) - - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) - reshaped_x = x.reshape(-1, x.shape[-1]) - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, - self.lookup_table) - - if bias is not None: - out = out + bias - return out.reshape(out_shape) - - -class SqueezeLLMRowParallelLinear(RowParallelLinear): - - def create_weights(self, dtype: torch.dtype) -> None: - if self.input_size_per_partition % self.quant_config.pack_factor != 0: - raise ValueError( - "The tensor parallel size is not aligned with the quantized " - "weight shape. Please use a different tensor parallel size.") - self.qweight = Parameter( - torch.empty( - self.input_size_per_partition // self.quant_config.pack_factor, - self.output_size, - device="cuda", - dtype=torch.int32, - ), - requires_grad=False, - ) - self.lookup_table = Parameter( - torch.empty( - self.output_size, - self.quant_config.weight_bits**2, - device="cuda", - dtype=dtype, - ), - requires_grad=False, - ) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - reshaped_x = x.reshape(-1, x.shape[-1]) - out_shape = x.shape[:-1] + (self.qweight.shape[-1], ) - # NOTE: The output tensor should be zero-initialized. - out = torch.zeros(out_shape, device="cuda", dtype=torch.float16) - quantization_ops.squeezellm_gemm(reshaped_x, self.qweight, out, - self.lookup_table) - return out.reshape(out_shape) diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py new file mode 100644 index 0000000000000..b08d5555b0faa --- /dev/null +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -0,0 +1,139 @@ +from typing import Optional, Sequence + +import torch +import torch.nn.functional as F +from torch.nn.parameter import Parameter + +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) +from vllm.model_executor.parallel_utils.utils import divide +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_reduce) +from vllm.model_executor.utils import set_weight_attrs + + +def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: + """Pad the vocab size to the given value.""" + return ((vocab_size + pad_to - 1) // pad_to) * pad_to + + +def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size: int, + rank: int) -> Sequence[int]: + index_f = rank * per_partition_vocab_size + index_l = index_f + per_partition_vocab_size + return index_f, index_l + + +def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, + world_size: int) -> Sequence[int]: + per_partition_vocab_size = divide(global_vocab_size, world_size) + return vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, + rank) + + +class VocabParallelEmbedding(torch.nn.Module): + """Embedding parallelized in the vocabulary dimension. + + Adapted from torch.nn.Embedding, note that we pad the vocabulary size to + make sure it is divisible by the number of model parallel GPUs. 
+
+    Args:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        params_dtype: type of the parameters.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 params_dtype: Optional[torch.dtype] = None):
+        super().__init__()
+
+        # Keep the input dimensions.
+        self.num_embeddings = num_embeddings
+        self.num_embeddings_padded = pad_vocab_size(num_embeddings)
+        self.embedding_dim = embedding_dim
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        # Divide the weight matrix along the vocabulary dimension.
+        self.vocab_start_index, self.vocab_end_index = (
+            vocab_range_from_global_vocab_size(
+                self.num_embeddings_padded, get_tensor_model_parallel_rank(),
+                self.tp_size))
+        self.num_embeddings_per_partition = (self.vocab_end_index -
+                                             self.vocab_start_index)
+        self.weight = Parameter(
+            torch.empty(self.num_embeddings_per_partition,
+                        self.embedding_dim,
+                        device=torch.cuda.current_device(),
+                        dtype=params_dtype))
+        set_weight_attrs(self.weight, {
+            "parallel_dim": 0,
+            "weight_loader": self.weight_loader
+        })
+
+    def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
+        parallel_dim = param.parallel_dim
+        assert loaded_weight.shape[parallel_dim] == self.num_embeddings
+        loaded_weight = loaded_weight[self.vocab_start_index:self.
+                                      vocab_end_index]
+        param[:loaded_weight.shape[0]].data.copy_(loaded_weight)
+
+    def forward(self, input_):
+        if self.tp_size > 1:
+            # Build the mask.
+            input_mask = ((input_ < self.vocab_start_index) |
+                          (input_ >= self.vocab_end_index))
+            # Mask the input.
+            masked_input = input_.clone() - self.vocab_start_index
+            masked_input[input_mask] = 0
+        else:
+            masked_input = input_
+        # Get the embeddings.
+        output_parallel = F.embedding(masked_input, self.weight)
+        # Mask the output embedding.
+        if self.tp_size > 1:
+            output_parallel[input_mask, :] = 0.0
+        # Reduce across all the model parallel GPUs.
+        output = tensor_model_parallel_all_reduce(output_parallel)
+        return output
+
+
+class ParallelLMHead(VocabParallelEmbedding):
+    """Parallelized LM head.
+
+    Output logits weight matrices used in the Sampler. The weight and bias
+    tensors are padded to make sure they are divisible by the number of
+    model parallel GPUs.
+
+    Args:
+        num_embeddings: vocabulary size.
+        embedding_dim: size of hidden state.
+        bias: whether to use bias.
+        params_dtype: type of the parameters.
+    """
+
+    def __init__(self,
+                 num_embeddings: int,
+                 embedding_dim: int,
+                 bias: bool = False,
+                 params_dtype: Optional[torch.dtype] = None):
+        super().__init__(num_embeddings, embedding_dim, params_dtype)
+        if bias:
+            self.bias = Parameter(
+                torch.empty(self.num_embeddings_per_partition,
+                            device=torch.cuda.current_device(),
+                            dtype=params_dtype))
+            set_weight_attrs(self.bias, {
+                "parallel_dim": 0,
+                "weight_loader": self.weight_loader
+            })
+        else:
+            self.register_parameter("bias", None)
+
+    def forward(self, input_):
+        del input_
+        raise RuntimeError("LMHead's weights should be used in the sampler.")
diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py
index b18f99223f10a..fdd860775c47c 100644
--- a/vllm/model_executor/model_loader.py
+++ b/vllm/model_executor/model_loader.py
@@ -37,13 +37,6 @@
     "YiForCausalLM": YiForCausalLM,
 }
 
-# FIXME(woosuk): Remove this once all models support quantization.
-_MODEL_CLASSES_SUPPORT_QUANTIZATION = [ - LlamaForCausalLM, - MistralForCausalLM, - YiForCausalLM, -] - @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): @@ -67,12 +60,9 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) - # Get the quantization config. - quant_config = None + # Get the (maybe quantized) linear method. + linear_method = None if model_config.quantization is not None: - if model_class not in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - raise ValueError( - f"Quantization is not supported for {model_class}.") quant_config = get_quant_config(model_config.quantization, model_config.model, model_config.download_dir) @@ -90,14 +80,12 @@ def get_model(model_config: ModelConfig) -> nn.Module: f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}") + linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - if model_class in _MODEL_CLASSES_SUPPORT_QUANTIZATION: - model = model_class(model_config.hf_config, quant_config) - else: - model = model_class(model_config.hf_config) + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/models/aquila.py b/vllm/model_executor/models/aquila.py index 4dae9e46dad7d..a1604bbba33b2 100644 --- a/vllm/model_executor/models/aquila.py +++ b/vllm/model_executor/models/aquila.py @@ -33,15 +33,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - hf_model_weights_iterator, load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.aquila import AquilaConfig @@ -55,20 +57,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, - gather_output=False, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + 
linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -111,6 +110,7 @@ def __init__( rope_theta: float = 10000, max_position_embeddings: int = 8192, rope_scaling: Optional[Dict[str, Any]] = None, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -128,29 +128,29 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, + linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, self.head_dim, self.scaling, - rotary_dim=self.head_dim, base=self.rope_theta, max_position=self.max_position_embeddings, + rotary_dim=self.head_dim, num_kv_heads=self.num_kv_heads, - rope_scaling=rope_scaling, - ) + rope_scaling=rope_scaling) def forward( self, @@ -171,7 +171,11 @@ def forward( class AquilaDecoderLayer(nn.Module): - def __init__(self, config: AquilaConfig): + def __init__( + self, + config: AquilaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -185,11 +189,13 @@ def __init__(self, config: AquilaConfig): rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, rope_scaling=rope_scaling, + linear_method=linear_method, ) self.mlp = AquilaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + linear_method=linear_method, ) self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -226,19 +232,22 @@ def forward( class AquilaModel(nn.Module): - def __init__(self, config: AquilaConfig): + def __init__( + self, + config: AquilaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - #vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers) + AquilaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) ]) self.norm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -271,17 +280,16 @@ def forward( class AquilaForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.model = AquilaModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear( - config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - ) + self.linear_method = linear_method + self.model = AquilaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -298,79 +306,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "qkv_proj.weight", "gate_proj.weight", 
"up_proj.weight" - ] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_size = get_tensor_model_parallel_world_size() - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True - break - if is_gate_up_weight: - continue - - param = state_dict[name] - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 7d0454271a799..64bbd5988fe37 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -30,18 +30,20 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import 
(PagedAttentionWithRoPE, PagedAttentionWithALiBi) +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.baichuan import BaiChuanConfig @@ -80,20 +82,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -116,6 +115,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -131,17 +131,19 @@ def __init__( self.max_position_embeddings = max_position_embeddings # pylint: disable=invalid-name - self.W_pack = ColumnParallelLinear( + self.W_pack = QKVParallelLinear( hidden_size, - 3 * hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_heads, bias=False, - gather_output=False, + linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, + linear_method=linear_method, ) # Create the alibi slopes and slice them. 
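        # "Slice them" here means each tensor-parallel rank keeps only the
        # ALiBi slopes for its own heads, e.g. (hypothetical layout) with 40
        # total heads on 2 ranks, rank 0 keeps slopes 0-19 and rank 1 keeps
        # slopes 20-39.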
if self.postion_embedding == "ALIBI": @@ -188,7 +190,10 @@ def forward( class BaiChuanDecoderLayer(nn.Module): - def __init__(self, config: BaiChuanConfig, position_embedding: str): + def __init__(self, + config: BaiChuanConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -200,11 +205,13 @@ def __init__(self, config: BaiChuanConfig, position_embedding: str): position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + linear_method=linear_method, ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -241,7 +248,10 @@ def forward( class BaiChuanModel(nn.Module): - def __init__(self, config: BaiChuanConfig, position_embedding: str): + def __init__(self, + config: BaiChuanConfig, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -252,7 +262,7 @@ def __init__(self, config: BaiChuanConfig, position_embedding: str): config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding) + BaiChuanDecoderLayer(config, position_embedding, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -285,16 +295,15 @@ def forward( class BaiChuanBaseForCausalLM(nn.Module): - def __init__(self, config, position_embedding: str): + def __init__(self, + config, + position_embedding: str, + linear_method: Optional[LinearMethodBase] = None): super().__init__() self.config = config - self.model = BaiChuanModel(config, position_embedding) - self.lm_head = ColumnParallelLinear( - config.hidden_size, - config.vocab_size, - bias=False, - gather_output=False, - ) + self.linear_method = linear_method + self.model = BaiChuanModel(config, position_embedding, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -311,79 +320,46 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_world_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - - if "W_pack" in name: - total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tp_world_size - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads - - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, 
head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights( - param, - loaded_weight, - name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank, - ) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b - def __init__(self, config): - super().__init__(config, "ALIBI") + def __init__(self, + config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__(config, "ALIBI", linear_method) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b - def __init__(self, config): - super().__init__(config, "ROPE") + def __init__(self, + config, + linear_method: Optional[LinearMethodBase] = None): + super().__init__(config, "ROPE", linear_method) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index f3bb17655c5b3..1d379a623c76d 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -30,14 +30,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -70,7 +73,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: class BloomAttention(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size self.total_num_heads = config.n_head @@ -81,17 +88,18 @@ def __init__(self, config: BloomConfig): assert self.total_num_heads % 
tp_world_size == 0 self.num_heads = self.total_num_heads // tp_world_size - self.query_key_value = ColumnParallelLinear( + self.query_key_value = QKVParallelLinear( self.hidden_size, - 3 * self.hidden_size, + self.head_dim, + self.total_num_heads, bias=True, - gather_output=False, + linear_method=linear_method, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - input_is_parallel=True, + linear_method=linear_method, ) # Create the alibi slopes and slice them. @@ -125,19 +133,23 @@ def forward( class BloomMLP(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.dense_h_to_4h = ColumnParallelLinear( hidden_size, 4 * hidden_size, - gather_output=False, + linear_method=linear_method, ) self.act = get_act_fn("gelu") self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, - input_is_parallel=True, + linear_method=linear_method, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -149,16 +161,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class BloomBlock(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config) + self.self_attention = BloomAttention(config, linear_method) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) - self.mlp = BloomMLP(config) + self.mlp = BloomMLP(config, linear_method) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm) @@ -203,7 +219,11 @@ def forward( class BloomModel(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.embed_dim = config.hidden_size @@ -216,8 +236,10 @@ def __init__(self, config: BloomConfig): self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks - self.h = nn.ModuleList( - [BloomBlock(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + BloomBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -251,12 +273,15 @@ def forward( class BloomForCausalLM(nn.Module): - def __init__(self, config: BloomConfig): + def __init__( + self, + config: BloomConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.transformer = BloomModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.linear_method = linear_method + self.transformer = BloomModel(config, linear_method) self.lm_head_weight = self.transformer.word_embeddings.weight self.sampler = Sampler(config.vocab_size) @@ -274,55 +299,36 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "word_embeddings.weight", "dense_h_to_4h.weight", "dense_h_to_4h.bias" - ] - _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_rank = get_tensor_model_parallel_rank() - 
state_dict = self.state_dict() + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if name == "lm_head.weight": - # Since hidden_states are parallelized, we need to - # load lm_head.weight in parallel. - self._column_parallel_weights.append(name) - # If lm_head is provided, use it instead. - param = self.lm_head_weight - else: - if not name.startswith("transformer."): - name = "transformer." + name - param = state_dict[name] + continue + if not name.startswith("transformer."): + name = "transformer." + name + param = params_dict[name] if "query_key_value" in name: - # NOTE(woosuk): BLOOM's fused QKV has the shape of - # [num_heads * 3 * head_size, hidden_size], while the - # required shape is [3 * num_heads * head_size, hidden_size]. + # NOTE: BLOOM's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). # Thus, we need weight conversion. - shard_size = param.shape[0] - start = shard_size * tp_rank - end = shard_size * (tp_rank + 1) - loaded_weight = loaded_weight[start:end] - + output_dim = getattr(param, "output_dim", None) num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // num_heads - if "query_key_value.weight" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size, - hidden_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif "query_key_value.bias" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected weight name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 8acc8e468b652..673ca2092146a 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -6,32 +6,28 @@ The input of the model is flattened to a 1D tensor of tokens. The model uses InputMetadata to extract the original 2D shape of the input. 
""" -from typing import Dict, List, Optional, Tuple +from typing import List, Optional, Tuple import torch from torch import nn from torch.nn import LayerNorm from vllm.model_executor.input_metadata import InputMetadata +from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - hf_model_weights_iterator, - load_tensor_parallel_weights, -) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding -from vllm.model_executor.parallel_utils.layers import ( - ColumnParallelLinear, - RowParallelLinear, -) -from vllm.sequence import SequenceOutputs - + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) +from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -39,7 +35,11 @@ class GLMAttention(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -50,25 +50,33 @@ def __init__(self, config): self.total_num_kv_heads = (config.multi_query_group_num if config.multi_query_attention else config.num_attention_heads) - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = config.hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 - self.query_key_value = ColumnParallelLinear( - config.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * + self.query_key_value = QKVParallelLinear( + self.hidden_size, self.head_dim, - bias=config.add_qkv_bias, - gather_output=False, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.add_bias_linear or config.add_qkv_bias, + linear_method=linear_method, ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, - input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( @@ -78,7 +86,6 @@ def __init__(self, config): rotary_dim=self.head_dim // 2, num_kv_heads=self.num_kv_heads, is_neox_style=False, - # is_glm_style=True ) def forward( @@ -117,17 +124,21 @@ class GLMMLP(nn.Module): state back into h hidden dimension. """ - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.add_bias = config.add_bias_linear # Project to 4h. - self.dense_h_to_4h = ColumnParallelLinear( + self.dense_h_to_4h = MergedColumnParallelLinear( config.hidden_size, - config.ffn_hidden_size * 2, + [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, - gather_output=False, + linear_method=linear_method, ) self.activation_func = SiluAndMul() @@ -137,7 +148,7 @@ def __init__(self, config): config.ffn_hidden_size, config.hidden_size, bias=config.add_bias_linear, - input_is_parallel=True, + linear_method=linear_method, ) def forward(self, hidden_states): @@ -159,6 +170,7 @@ class GLMBlock(nn.Module): def __init__( self, config, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -172,7 +184,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config) + self.self_attention = GLMAttention(config, linear_method) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -180,7 +192,7 @@ def __init__( config.hidden_size, eps=config.layernorm_epsilon) # MLP - self.mlp = GLMMLP(config) + self.mlp = GLMMLP(config, linear_method) def forward( self, @@ -227,7 +239,11 @@ def forward( class GLMTransformer(nn.Module): """Transformer class.""" - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -236,7 +252,7 @@ def __init__(self, config): # Transformer layers. 
self.layers = nn.ModuleList( - [GLMBlock(config) for i in range(self.num_layers)]) + [GLMBlock(config, linear_method) for i in range(self.num_layers)]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -274,7 +290,11 @@ def forward( class ChatGLMModel(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.embedding = VocabParallelEmbedding(config.padded_vocab_size, @@ -283,15 +303,10 @@ def __init__(self, config): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config) + self.encoder = GLMTransformer(config, linear_method) - self.output_layer = ColumnParallelLinear( - config.hidden_size, - config.padded_vocab_size, - bias=False, - gather_output=False, - params_dtype=config.torch_dtype, - ) + self.output_layer = ParallelLMHead(config.padded_vocab_size, + config.hidden_size) def forward( self, @@ -317,10 +332,15 @@ def forward( class ChatGLMForCausalLM(nn.Module): - def __init__(self, config: ChatGLMConfig): + def __init__( + self, + config: ChatGLMConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config: ChatGLMConfig = config - self.transformer = ChatGLMModel(config) + self.linear_method = linear_method + self.transformer = ChatGLMModel(config, linear_method) self.lm_head_weight = self.transformer.output_layer.weight self.sampler = Sampler(config.padded_vocab_size) @@ -331,78 +351,26 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, cache_events: Optional[List[torch.cuda.Event]], - ) -> Dict[int, SequenceOutputs]: + ) -> SamplerOutput: hidden_states = self.transformer(input_ids, positions, kv_caches, input_metadata, cache_events) next_tokens = self.sampler(self.lm_head_weight, hidden_states, input_metadata) return next_tokens - _column_parallel_weights = [ - "output_layer.weight", - "embedding.weight", - ] - _row_parallel_weights = ["dense_4h_to_h", "self_attention.dense"] - - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): - tp_rank = get_tensor_model_parallel_rank() - tp_size = get_tensor_model_parallel_world_size() - - q_proj_shard_size = self.config.hidden_size // tp_size - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.multi_query_group_num // tp_size) - - mlp_hidden_shard_size = self.config.ffn_hidden_size // tp_size - - state_dict = self.state_dict() + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + if "rotary_pos_emb.inv_freq" in name: + continue if "word_embeddings" in name: name = name.replace(".word_embeddings", "") - - if name in state_dict: - param = state_dict[name] - if "query_key_value" in name: - q_offset = q_proj_shard_size * tp_rank - k_offset = (q_proj_shard_size * tp_size + - kv_proj_shard_size * tp_rank) - v_offset = (q_proj_shard_size * tp_size + - kv_proj_shard_size * (tp_size + tp_rank)) - wq = loaded_weight[q_offset:q_offset + q_proj_shard_size] - wk = loaded_weight[k_offset:k_offset + kv_proj_shard_size] - wv = 
loaded_weight[v_offset:v_offset + kv_proj_shard_size] - loaded_weight = torch.cat([wq, wk, wv], dim=0) - param.data.copy_(loaded_weight) - continue - - if "dense_h_to_4h" in name: - w_gate = loaded_weight[mlp_hidden_shard_size * - tp_rank:mlp_hidden_shard_size * - (tp_rank + 1)] - w_proj = loaded_weight[mlp_hidden_shard_size * - (tp_size + - tp_rank):mlp_hidden_shard_size * - (tp_size + tp_rank + 1)] - loaded_weight = torch.cat([w_gate, w_proj], dim=0) - param.data.copy_(loaded_weight) - continue - - load_tensor_parallel_weights( - param, - loaded_weight, - name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank, - ) - elif name == "transformer.rotary_pos_emb.inv_freq": - continue - else: - print("Warning never found tensor's name:", name) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 6c249f6c98fec..3307d05494429 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -30,17 +30,19 @@ from vllm.model_executor.layers.attention import (PagedAttention, PagedAttentionWithALiBi, PagedAttentionWithRoPE) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_tensor_parallel_weights) -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.communication_op import ( tensor_model_parallel_all_reduce) +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import RWConfig @@ -48,19 +50,6 @@ FalconConfig = Union[HF_FalconConfig, RWConfig] -# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during -# training, this means that there's one additional quantization to bfloat16 -# between the operations. In order not to degrade the quality of our HF-port, -# we keep these characteristics in the final model. 
-class FalconLinear(nn.Linear): - - def forward(self, x: torch.Tensor) -> torch.Tensor: - hidden_states = x @ self.weight.T - if self.bias is None: - return hidden_states - return hidden_states + self.bias - - def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))), @@ -86,7 +75,11 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: class FalconAttention(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size @@ -103,41 +96,29 @@ def __init__(self, config: FalconConfig): if self.new_decoder_architecture: self.total_num_kv_heads = config.num_kv_heads - assert self.total_num_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size - self.query_key_value = ColumnParallelLinear( - self.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, - bias=config.bias, - gather_output=False, - skip_bias_add=True, - ) elif self.multi_query: self.total_num_kv_heads = 1 - self.num_kv_heads = 1 - self.query = ColumnParallelLinear( - self.hidden_size, - self.total_num_heads * self.head_dim, - bias=config.bias, - gather_output=False, - skip_bias_add=True, - ) - self.key_value = FalconLinear(self.hidden_size, - 2 * self.head_dim, - bias=config.bias) else: self.total_num_kv_heads = self.total_num_heads - self.num_kv_heads = self.num_heads - self.query_key_value = ColumnParallelLinear( - self.hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * - self.head_dim, - bias=config.bias, - gather_output=False, - skip_bias_add=True, - ) + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.query_key_value = QKVParallelLinear( + self.hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=config.bias, + skip_bias_add=True, + linear_method=linear_method, + ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -149,7 +130,6 @@ def __init__(self, config: FalconConfig): self.hidden_size, self.hidden_size, bias=config.bias, - input_is_parallel=True, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results) @@ -196,18 +176,10 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - if not self.new_decoder_architecture and self.multi_query: - q, bias = self.query(hidden_states) - if bias is not None: - q += bias - kv = self.key_value(hidden_states) - k, v = kv.split([self.kv_size, self.kv_size], dim=-1) - else: - qkv, bias = self.query_key_value(hidden_states) - if bias is not None: - qkv += bias - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + qkv, bias = self.query_key_value(hidden_states) + if bias is not None: + qkv += bias + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) k_cache, v_cache = kv_cache if self.use_rotary: attn_output = self.attn(positions, q, k, v, k_cache, v_cache, @@ -221,15 +193,19 @@ def forward( class FalconMLP(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.dense_h_to_4h = ColumnParallelLinear(hidden_size, 4 * hidden_size, bias=config.bias, - gather_output=False, - skip_bias_add=True) + skip_bias_add=True, + linear_method=linear_method) self.act = nn.GELU() self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -237,9 +213,9 @@ def __init__(self, config: FalconConfig): 4 * hidden_size, hidden_size, bias=config.bias, - input_is_parallel=True, skip_bias_add=True, - reduce_results=self.reduce_row_parallel_results) + reduce_results=self.reduce_row_parallel_results, + linear_method=linear_method) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. 
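# --- Illustrative sketch (not part of this patch) ---------------------------
# Several attention modules in this diff (GLMAttention, FalconAttention) now
# derive the per-GPU KV head count with the same rule: partition the KV heads
# when there are at least as many of them as tensor-parallel ranks, otherwise
# replicate them so that every rank still holds one. A minimal standalone
# version of that rule, using plain integers instead of the real
# get_tensor_model_parallel_world_size() helper:

def num_kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    if total_num_kv_heads >= tp_size:
        # Partition: each rank owns an equal slice of the KV heads.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Replicate: several ranks share a copy of the same KV head.
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

# For example, 8 KV heads with tp_size=4 gives 2 KV heads per rank, while
# 8 KV heads with tp_size=16 gives 1 per rank (each head copied to 2 ranks).
# ---------------------------------------------------------------------------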
@@ -253,12 +229,16 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FalconDecoderLayer(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config) - self.mlp = FalconMLP(config) + self.self_attention = FalconAttention(config, linear_method) + self.mlp = FalconMLP(config, linear_method) self.config = config if config.new_decoder_architecture: @@ -334,7 +314,11 @@ def forward( class FalconModel(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -349,7 +333,8 @@ def __init__(self, config: FalconConfig): # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config) for _ in range(config.num_hidden_layers) + FalconDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) ]) # Final Layer Norm @@ -383,15 +368,18 @@ def forward( class FalconForCausalLM(nn.Module): - def __init__(self, config: FalconConfig): + def __init__( + self, + config: FalconConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.transformer = FalconModel(config) - self.lm_head = ColumnParallelLinear( - config.hidden_size, + self.linear_method = linear_method + self.transformer = FalconModel(config, linear_method) + self.lm_head = ParallelLMHead( config.vocab_size, - bias=False, - gather_output=False, + config.hidden_size, ) self.sampler = Sampler(config.vocab_size) @@ -415,89 +403,44 @@ def forward( return next_tokens - _column_parallel_weights = [ - "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight", - "dense_h_to_4h.bias" - ] - _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_size = (get_tensor_model_parallel_world_size()) - tp_rank = get_tensor_model_parallel_rank() - - hidden_size = self.config.hidden_size total_num_heads = self.config.num_attention_heads - num_heads = total_num_heads // tp_size - head_size = hidden_size // total_num_heads - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads if self.config.new_decoder_architecture: total_num_kv_heads = self.config.num_kv_heads - num_kv_heads = total_num_kv_heads // tp_size - separated_q_kv = False - kv_head_start = tp_rank * num_kv_heads - kv_head_end = (tp_rank + 1) * num_kv_heads elif self.config.multi_query: total_num_kv_heads = 1 - num_kv_heads = 1 - separated_q_kv = True - kv_head_start = 0 - kv_head_end = 1 else: total_num_kv_heads = total_num_heads - num_kv_heads = total_num_kv_heads // tp_size - separated_q_kv = False - kv_head_start = tp_rank * num_kv_heads - kv_head_end = (tp_rank + 1) * num_kv_heads num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): + param = params_dict[name] if "query_key_value" in name: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight_size = loaded_weight.size() + output_dim = 
getattr(param, "output_dim", None) + loaded_weight_shape = loaded_weight.shape loaded_weight = loaded_weight.view( - total_num_kv_heads, num_query_heads_per_kv_head + 2, - head_size, *loaded_weight_size[1:]) - - wq = loaded_weight[:, :-2].reshape(-1, *loaded_weight_size[1:]) - wk = loaded_weight[:, [-2]].reshape(-1, - *loaded_weight_size[1:]) - wv = loaded_weight[:, [-1]].reshape(-1, - *loaded_weight_size[1:]) - - wq = wq[head_size * head_start:head_size * head_end] - wk = wk[head_size * kv_head_start:head_size * kv_head_end] - wv = wv[head_size * kv_head_start:head_size * kv_head_end] - - if separated_q_kv: - loaded_weight_q = wq - loaded_weight_kv = torch.cat([wk, wv], dim=0) - q_weight_name = name.replace("query_key_value", "query") - kv_weight_name = name.replace("query_key_value", - "key_value") - load_tensor_parallel_weights(state_dict[q_weight_name], - loaded_weight_q, - q_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank) - load_tensor_parallel_weights(state_dict[kv_weight_name], - loaded_weight_kv, - kv_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank) - continue - else: - loaded_weight = torch.cat([wq, wk, wv], dim=0) - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + loaded_weight_shape[:output_dim] + + (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) + + loaded_weight_shape[output_dim + 1:]) + wq = loaded_weight.narrow( + output_dim + 1, 0, num_query_heads_per_kv_head).reshape( + *loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wk = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + wv = loaded_weight.narrow( + output_dim + 1, num_query_heads_per_kv_head + 1, + 1).reshape(*loaded_weight_shape[:output_dim], -1, + *loaded_weight_shape[output_dim + 1:]) + loaded_weight = torch.cat([wq, wk, wv], dim=output_dim) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index b9309eb956544..d540f74724202 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -30,15 +30,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] 
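# --- Illustrative sketch (not part of this patch) ---------------------------
# The Falcon loader above reorders the checkpoint's fused query_key_value
# weight, which interleaves [q_0 .. q_{n-1}, k, v] per KV-head group along the
# output dimension, into the [all Q | all K | all V] layout that
# QKVParallelLinear expects. A toy version of that conversion with made-up
# sizes (2 KV heads, 3 query heads per KV head, head_dim 4, output_dim 0):

import torch

total_num_kv_heads = 2
num_q_per_kv = 3
head_dim = 4
hidden_size = 8

# Checkpoint layout: per KV-head group, the q rows come first, then k, then v.
fused = torch.randn(total_num_kv_heads * (num_q_per_kv + 2) * head_dim,
                    hidden_size)

grouped = fused.view(total_num_kv_heads, num_q_per_kv + 2, head_dim,
                     hidden_size)
wq = grouped.narrow(1, 0, num_q_per_kv).reshape(-1, hidden_size)
wk = grouped.narrow(1, num_q_per_kv, 1).reshape(-1, hidden_size)
wv = grouped.narrow(1, num_q_per_kv + 1, 1).reshape(-1, hidden_size)
converted = torch.cat([wq, wk, wv], dim=0)

assert wq.shape == (total_num_kv_heads * num_q_per_kv * head_dim, hidden_size)
assert wk.shape == wv.shape == (total_num_kv_heads * head_dim, hidden_size)
assert converted.shape == fused.shape
# ---------------------------------------------------------------------------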
@@ -46,7 +48,11 @@ class GPT2Attention(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads @@ -57,17 +63,18 @@ def __init__(self, config: GPT2Config): self.head_dim = self.hidden_size // total_num_heads self.scale = self.head_dim**-0.5 - self.c_attn = ColumnParallelLinear( + self.c_attn = QKVParallelLinear( self.hidden_size, - 3 * self.hidden_size, + self.head_dim, + total_num_heads, bias=True, - gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -95,6 +102,7 @@ def __init__( self, intermediate_size: int, config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() hidden_size = config.hidden_size @@ -102,13 +110,13 @@ def __init__( hidden_size, intermediate_size, bias=True, - gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) @@ -121,16 +129,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPT2Block(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 * hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config) + self.attn = GPT2Attention(config, linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config) + self.mlp = GPT2MLP(inner_dim, config, linear_method) def forward( self, @@ -160,24 +172,23 @@ def forward( class GPT2Model(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config assert not config.add_cross_attention assert not config.scale_attn_by_inverse_layer_idx assert not config.reorder_and_upcast_attn self.embed_dim = config.hidden_size - - # Optimization: While the vocab size of GPT-2 is 50257, we extend it - # to 50304 in order to make it divisible by 64. - # This improves performance since GPUs are faster if the dimension - # is divisible by 64. In addition, it allows us to shard the embedding - # layer across 2, 4, 8, or more GPUs. 
- vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList( - [GPT2Block(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + GPT2Block(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,12 +218,15 @@ def forward( class GPT2LMHeadModel(nn.Module): - def __init__(self, config: GPT2Config): + def __init__( + self, + config: GPT2Config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.transformer = GPT2Model(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.linear_method = linear_method + self.transformer = GPT2Model(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) @@ -230,19 +244,12 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = ["c_fc.weight", "c_fc.bias"] - _row_parallel_weights = ["c_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: @@ -253,53 +260,19 @@ def load_weights(self, # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue - if not name.startswith("transformer."): name = "transformer." + name - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - + param = params_dict[name] # The HF's GPT-2 implementation uses Conv1D instead of Linear. # Because of this, we need to transpose the weights. + # Note(zhuohan): the logic below might break quantized models. for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]: if conv1d_weight_name not in name: continue if not name.endswith(".weight"): continue loaded_weight = loaded_weight.t() - param = state_dict[name] - - if name == "transformer.wte.weight": - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - # For the fused QKV linear layer, manually shard the weights. - if "c_attn" in name: - # GPT-2's fused QKV has the shape of - # [3 * num_heads * head_size, hidden_size]. - # When tensor parallelism is used, we shard the weights along - # the head dimension. 
- total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tensor_model_parallel_world_size - head_start = tensor_model_parallel_rank * num_heads - head_end = (tensor_model_parallel_rank + 1) * num_heads - if name.endswith(".weight"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif name.endswith(".bias"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) - loaded_weight = loaded_weight[:, head_start:head_end, :] - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected parameter name {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 41f72c8cb7086..1e489e97052a7 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -31,15 +31,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -47,7 +49,11 @@ class GPTBigCodeAttention(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size total_num_heads = config.num_attention_heads @@ -61,32 +67,26 @@ def __init__(self, config: GPTBigCodeConfig): self.multi_query = config.multi_query if self.multi_query: + total_num_kv_heads = 1 self.num_kv_heads = 1 - self.kv_dim = self.head_dim - self.c_attn_q = ColumnParallelLinear( - self.hidden_size, - self.hidden_size, - bias=True, - gather_output=False, - ) - self.c_attn_kv = nn.Linear(self.hidden_size, - 2 * self.kv_dim, - bias=True) else: + total_num_kv_heads = total_num_heads self.num_kv_heads = self.num_heads - self.kv_dim = self.num_kv_heads * self.head_dim - self.c_attn = ColumnParallelLinear( - self.hidden_size, - self.hidden_size + 2 * self.kv_dim, - bias=True, - gather_output=False, - ) + self.kv_dim = self.head_dim * self.num_kv_heads + self.c_attn = QKVParallelLinear( + self.hidden_size, + 
self.head_dim, + total_num_heads, + total_num_kv_heads, + bias=True, + linear_method=linear_method, + ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -100,17 +100,14 @@ def forward( input_metadata: InputMetadata, cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: - if self.multi_query: - q, _ = self.c_attn_q(hidden_states) - kv = self.c_attn_kv(hidden_states) - k, v = kv.split([self.kv_dim, self.kv_dim], dim=-1) - else: - qkv, _ = self.c_attn(hidden_states) - q, k, v = qkv.split([ + qkv, _ = self.c_attn(hidden_states) + q, k, v = qkv.split( + [ self.hidden_size // self.tensor_model_parallel_world_size, self.kv_dim, self.kv_dim ], - dim=-1) + dim=-1, + ) key_cache, value_cache = kv_cache attn_output = self.attn(q, k, v, key_cache, value_cache, input_metadata, cache_event) @@ -124,6 +121,7 @@ def __init__( self, intermediate_size: int, config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() hidden_size = config.hidden_size @@ -131,13 +129,13 @@ def __init__( hidden_size, intermediate_size, bias=True, - gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) @@ -150,16 +148,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPTBigCodeBlock(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.hidden_size inner_dim = (config.n_inner if config.n_inner is not None else 4 * hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config) + self.attn = GPTBigCodeAttention(config, linear_method) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config) + self.mlp = GPTBigMLP(inner_dim, config, linear_method) def forward( self, @@ -189,23 +191,23 @@ def forward( class GPTBigCodeModel(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config assert not config.add_cross_attention self.embed_dim = config.hidden_size - # Optimization: While the vocab size of GPT-2 is 50257, we extend it - # to 50304 in order to make it divisible by 64. - # This improves performance since GPUs are faster if the dimension - # is divisible by 64. In addition, it allows us to shard the embedding - # layer across 2, 4, 8, or more GPUs. 
- vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.wte = VocabParallelEmbedding(vocab_size, self.embed_dim) + self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList( - [GPTBigCodeBlock(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + GPTBigCodeBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -235,12 +237,15 @@ def forward( class GPTBigCodeForCausalLM(nn.Module): - def __init__(self, config: GPTBigCodeConfig): + def __init__( + self, + config: GPTBigCodeConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.transformer = GPTBigCodeModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.linear_method = linear_method + self.transformer = GPTBigCodeModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) @@ -258,89 +263,21 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = ["c_fc.weight", "c_fc.bias"] - _row_parallel_weights = ["c_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_world_size = ( - get_tensor_model_parallel_world_size()) - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: - # GPT-2 ties the weights of the embedding layer and the final - # linear layer. continue if ".attn.bias" in name: # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue - - if not name.startswith("transformer."): - name = "transformer." + name - - # For the fused QKV linear layer, manually shard the weights. - if "c_attn" in name: - # GPT-2's fused QKV has the shape of - # [3 * num_heads * head_size, hidden_size]. - # When tensor parallelism is used, we shard the weights along - # the head dimension. - total_num_heads = self.config.num_attention_heads - total_num_kv_heads = (1 if self.config.multi_query else - total_num_heads) - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - total_kv_size = head_size * total_num_kv_heads - num_heads = total_num_heads // tensor_model_parallel_world_size - head_start = tensor_model_parallel_rank * num_heads - head_end = (tensor_model_parallel_rank + 1) * num_heads - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - wq, wk, wv = torch.split( - loaded_weight, [hidden_size, total_kv_size, total_kv_size], - dim=0) - - wq = wq[head_size * head_start:head_size * head_end] - if not self.config.multi_query: - # Split the heads when using normal multi-head attention - wk = wk[head_size * head_start:head_size * head_end] - wv = wv[head_size * head_start:head_size * head_end] - loaded_weight = torch.cat([wq, wk, wv], dim=0) - else: - # For multi-query attention, we split the query - # but replicate the key and value. 
- loaded_weight_q = wq - loaded_weight_kv = torch.cat([wk, wv], dim=0) - q_weight_name = name.replace("c_attn", "c_attn_q") - kv_weight_name = name.replace("c_attn", "c_attn_kv") - load_tensor_parallel_weights(state_dict[q_weight_name], - loaded_weight_q, - q_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) - load_tensor_parallel_weights(state_dict[kv_weight_name], - loaded_weight_kv, - kv_weight_name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) - continue - - param = state_dict[name] - - if name == "transformer.wte.weight": - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index f61eab73b3a89..a5b77138bd17f 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -29,14 +29,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -44,23 +47,28 @@ class GPTJAttention(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.total_num_heads - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( config.hidden_size, - 3 * config.hidden_size, + self.head_size, + self.total_num_heads, bias=False, - gather_output=False, + linear_method=linear_method, ) self.out_proj = RowParallelLinear( config.hidden_size, config.hidden_size, bias=False, - input_is_parallel=True, + linear_method=linear_method, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -102,18 +110,23 @@ def forward( class GPTJMLP(nn.Module): - def __init__(self, intermediate_size: int, config: GPTJConfig): + def __init__( + self, + intermediate_size: int, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.n_embd self.fc_in = ColumnParallelLinear( hidden_size, intermediate_size, - gather_output=False, + 
linear_method=linear_method, ) self.fc_out = RowParallelLinear( intermediate_size, hidden_size, - input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.activation_function) @@ -126,15 +139,19 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class GPTJBlock(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() if config.n_inner is None: inner_dim = 4 * config.n_embd else: inner_dim = config.n_inner self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config) - self.mlp = GPTJMLP(inner_dim, config) + self.attn = GPTJAttention(config, linear_method) + self.mlp = GPTJMLP(inner_dim, config, linear_method) def forward( self, @@ -160,7 +177,11 @@ def forward( class GPTJModel(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.embed_dim = config.n_embd @@ -169,7 +190,7 @@ def __init__(self, config: GPTJConfig): self.embed_dim, ) self.h = nn.ModuleList( - [GPTJBlock(config) for _ in range(config.n_layer)]) + [GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -200,15 +221,20 @@ def forward( class GPTJForCausalLM(nn.Module): - def __init__(self, config: GPTJConfig): + def __init__( + self, + config: GPTJConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config + self.linear_method = linear_method assert not config.tie_word_embeddings - self.transformer = GPTJModel(config) - self.lm_head = ColumnParallelLinear( - config.n_embd, + self.transformer = GPTJModel(config, linear_method) + self.lm_head = ParallelLMHead( config.vocab_size, - gather_output=False, + config.n_embd, + bias=True, ) self.sampler = Sampler(config.vocab_size) @@ -226,43 +252,33 @@ def forward( input_metadata, self.lm_head.bias) return next_tokens - _column_parallel_weights = [ - "wte.weight", "fc_in.weight", "fc_in.bias", "lm_head.weight", - "lm_head.bias" - ] - _row_parallel_weights = ["out_proj.weight", "fc_out.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "attn.bias" in name or "attn.masked_bias" in name: continue - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == 
loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_attention_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index d0187c93c541e..5c40783262ce7 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -29,14 +29,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -44,7 +47,11 @@ class GPTNeoXAttention(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -56,15 +63,16 @@ def __init__(self, config: GPTNeoXConfig): self.num_heads = (self.total_num_heads // tensor_model_parallel_world_size) - self.query_key_value = ColumnParallelLinear( + self.query_key_value = QKVParallelLinear( config.hidden_size, - 3 * config.hidden_size, - gather_output=False, + self.head_size, + self.total_num_heads, + linear_method=linear_method, ) self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, - input_is_parallel=True, + linear_method=linear_method, ) scaling = self.head_size**-0.5 @@ -100,17 +108,21 @@ def forward( class GPTNeoXMLP(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.intermediate_size, - gather_output=False, + linear_method=linear_method, ) self.dense_4h_to_h = RowParallelLinear( config.intermediate_size, config.hidden_size, - input_is_parallel=True, + linear_method=linear_method, ) self.act = get_act_fn(config.hidden_act) @@ -123,15 +135,19 @@ def forward(self, hidden_states): class GPTNeoXLayer(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, 
+ ): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config) - self.mlp = GPTNeoXMLP(config) + self.attention = GPTNeoXAttention(config, linear_method) + self.mlp = GPTNeoXMLP(config, linear_method) def forward( self, @@ -169,7 +185,11 @@ def forward( class GPTNeoXModel(nn.Module): - def __init__(self, config: GPTNeoXConfig): + def __init__( + self, + config: GPTNeoXConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config @@ -177,8 +197,10 @@ def __init__(self, config: GPTNeoXConfig): config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList( - [GPTNeoXLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ + GPTNeoXLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -210,15 +232,18 @@ def forward( class GPTNeoXForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.gpt_neox = GPTNeoXModel(config) - self.embed_out = ColumnParallelLinear( - config.hidden_size, + self.linear_method = linear_method + self.gpt_neox = GPTNeoXModel(config, linear_method) + self.embed_out = ParallelLMHead( config.vocab_size, - bias=False, - gather_output=False, + config.hidden_size, ) self.sampler = Sampler(config.vocab_size) @@ -236,50 +261,35 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "embed_in.weight", "embed_out.weight", "dense_h_to_4h.weight", - "dense_h_to_4h.bias" - ] - _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if ("attention.bias" in name or "attention.masked_bias" in name or "rotary_emb.inv_freq" in name): continue - param = state_dict[name] + param = params_dict[name] + if "query_key_value" in name: - # NOTE(woosuk): GPT-NeoX's fused QKV has the shape of - # [num_heads * 3 * head_size, hidden_size], while the - # required shape is [3 * num_heads * head_size, hidden_size]. + # NOTE: GPT-NeoX's fused QKV's output_dim has the shape of + # (num_heads * 3 * head_size), while the + # required shape is (3 * num_heads * head_size). # Thus, we need weight conversion. 
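As an illustration of the NOTE just above: a minimal, self-contained sketch of the view/transpose/reshape performed a few lines below, using toy sizes that are assumptions for illustration only. It shows how rows move from GPT-NeoX's per-head-interleaved checkpoint order to the q/k/v-stacked order the fused query_key_value parameter expects.

import torch

# Toy sizes, chosen only for illustration.
num_heads, head_size, hidden_size = 2, 3, 4
output_dim = 0  # dim 0 is the output dimension of the weight matrix

# Checkpoint order: rows grouped per head as [h0.q, h0.k, h0.v, h1.q, h1.k, h1.v].
loaded_weight = torch.arange(num_heads * 3 * head_size * hidden_size,
                             dtype=torch.float32).view(
                                 num_heads * 3 * head_size, hidden_size)

# Same steps as the hunk below: view -> transpose -> reshape.
shape = loaded_weight.shape
w = loaded_weight.view(shape[:output_dim] + (num_heads, 3, -1) +
                       shape[output_dim + 1:])
w = w.transpose(output_dim, output_dim + 1)
w = w.reshape(shape)
# Rows are now ordered [h0.q, h1.q, h0.k, h1.k, h0.v, h1.v]: all Q heads first,
# then all K heads, then all V heads, i.e. the stacked q/k/v layout.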
- shard_size = param.shape[0] - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - + output_dim = getattr(param, "output_dim", None) num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // num_heads - if "query_key_value.weight" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size, - hidden_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif "query_key_value.bias" in name: - loaded_weight = loaded_weight.view(-1, 3, head_size) - loaded_weight = loaded_weight.transpose(0, 1) - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected weight name: {name}") - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + if output_dim is not None: + loaded_weight_shape = loaded_weight.shape + loaded_weight = loaded_weight.view( + loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + + loaded_weight_shape[output_dim + 1:]) + loaded_weight = loaded_weight.transpose( + output_dim, output_dim + 1) + loaded_weight = loaded_weight.reshape(loaded_weight_shape) + + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/internlm.py b/vllm/model_executor/models/internlm.py index 4a595a37730da..d90f8aaed624c 100644 --- a/vllm/model_executor/models/internlm.py +++ b/vllm/model_executor/models/internlm.py @@ -9,15 +9,17 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttentionWithRoPE from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (ColumnParallelLinear, - RowParallelLinear, - VocabParallelEmbedding) -from vllm.model_executor.weight_utils import ( - hf_model_weights_iterator, load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -30,20 +32,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - ) - self.down_proj = RowParallelLinear( - intermediate_size, - hidden_size, + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -65,6 +64,7 @@ def __init__( bias: bool, rope_theta: float = 10000, max_position_embeddings: int = 8192, + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() self.hidden_size = hidden_size @@ -79,17 +79,18 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( hidden_size, - 3 * self.total_num_heads * self.head_dim, + self.head_dim, + self.total_num_heads, bias=bias, - gather_output=False, + linear_method=linear_method, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -118,7 +119,11 @@ def forward( class InternLMDecoderLayer(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -130,11 +135,13 @@ def __init__(self, config: LlamaConfig): bias=config.bias, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, + linear_method=linear_method, ) self.mlp = InternLMMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -171,7 +178,11 @@ def forward( class InternLMModel(nn.Module): - def __init__(self, config: LlamaConfig): + def __init__( + self, + config: LlamaConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -183,7 +194,7 @@ def __init__(self, config: LlamaConfig): config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config) + InternLMDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -216,17 +227,16 @@ def forward( class InternLMForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.model = InternLMModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear( - config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - ) + self.linear_method = linear_method + self.model = InternLMModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -243,69 +253,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "qkv_proj.weight", "gate_proj.weight", "up_proj.weight" - ] - _row_parallel_weights = ["o_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = 
dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - if "embed_tokens" in name or "lm_head" in name: - param = state_dict[name] - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: - continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 735e4ad172182..9381a2390c712 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -33,17 +33,19 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding -from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) + get_tensor_model_parallel_world_size) +from 
vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -56,19 +58,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -91,7 +91,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -109,7 +109,6 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -117,21 +116,19 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + - 2 * self.total_num_kv_heads * num_kv_heads_replicas) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, - quant_config=quant_config, + linear_method=linear_method, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, - quant_config=quant_config, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -165,11 +162,10 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -181,13 +177,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - quant_config=quant_config, + linear_method=linear_method, ) self.mlp = LlamaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -227,20 +223,18 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - 
quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, quant_config) + LlamaDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -276,19 +270,13 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config - self.quant_config = quant_config - self.model = LlamaModel(config, quant_config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - # NOTE: The LM head is not quantized. - self.lm_head = ParallelLinear.column(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - quant_config=None) + self.linear_method = linear_method + self.model = LlamaModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -305,124 +293,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - col_weight_suffixes = ["weight"] - row_weight_suffixes = ["weight"] - else: - col_weight_suffixes = ( - self.quant_config.get_col_parallel_tensor_names()) - row_weight_suffixes = ( - self.quant_config.get_row_parallel_tensor_names()) - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in col_weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in row_weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - num_kv_heads_replicas = max(1, - tp_size // self.config.num_key_value_heads) - num_kv_heads_per_gpu = max(1, - self.config.num_key_value_heads // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - num_kv_heads_per_gpu) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - packed_dim = None - is_transposed = False - if self.quant_config is not None: - packed_dim = self.quant_config.get_packed_dim(name) - 
is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T - - if packed_dim is not None: - shard_dim = 0 if not is_transposed else 1 - if packed_dim == shard_dim: - shard_size //= self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - if weight_name in ["k_proj", "v_proj"]: - shard_id = tp_rank // num_kv_heads_replicas - else: - shard_id = tp_rank - loaded_weight = loaded_weight[shard_size * - shard_id:shard_size * - (shard_id + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - if is_transposed: - param = param.T - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, tp_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 8b09276e6f91d..f9b9120aff80d 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -33,17 +33,19 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding -from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.weight_utils import 
( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -56,19 +58,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -89,7 +89,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -98,8 +98,15 @@ def __init__(self, assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads - assert self.total_num_kv_heads % tp_size == 0 - self.num_kv_heads = self.total_num_kv_heads // tp_size + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
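The partition-or-replicate rule in this attention hunk reduces to a small helper. The sketch below is illustrative only: kv_heads_per_rank is a hypothetical name, and the example head counts are assumed values, not taken from any particular config.

def kv_heads_per_rank(total_num_kv_heads: int, tp_size: int) -> int:
    """Per-GPU KV-head count: partition when there are enough KV heads,
    otherwise replicate so every rank still owns at least one."""
    if total_num_kv_heads >= tp_size:
        assert total_num_kv_heads % tp_size == 0
    else:
        assert tp_size % total_num_kv_heads == 0
    return max(1, total_num_kv_heads // tp_size)

# Illustrative checks with 8 KV heads (a typical grouped-query setting):
assert kv_heads_per_rank(8, 2) == 4    # KV heads partitioned across ranks
assert kv_heads_per_rank(8, 16) == 1   # each KV head replicated on two ranks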
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -107,20 +114,19 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, - quant_config=quant_config, + linear_method=linear_method, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, - quant_config=quant_config, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim, @@ -153,7 +159,7 @@ class MistralDecoderLayer(nn.Module): def __init__( self, config: MistralConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -165,13 +171,13 @@ def __init__( max_position=config.max_position_embeddings, num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, - quant_config=quant_config, + linear_method=linear_method, sliding_window=config.sliding_window) self.mlp = MistralMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + linear_method=linear_method, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -211,20 +217,19 @@ class MistralModel(nn.Module): def __init__( self, config: MistralConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - MistralDecoderLayer(config, quant_config) + MistralDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -260,19 +265,13 @@ class MistralForCausalLM(nn.Module): def __init__( self, config: MistralConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config - self.quant_config = quant_config - self.model = MistralModel(config, quant_config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - # NOTE: The LM head is not quantized. 
- self.lm_head = ParallelLinear.column(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - quant_config=None) + self.linear_method = linear_method + self.model = MistralModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -289,118 +288,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - col_weight_suffixes = ["weight"] - row_weight_suffixes = ["weight"] - else: - col_weight_suffixes = ( - self.quant_config.get_col_parallel_tensor_names()) - row_weight_suffixes = ( - self.quant_config.get_row_parallel_tensor_names()) - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in col_weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in row_weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - - tp_size = get_tensor_model_parallel_world_size() - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - self.config.num_key_value_heads // tp_size) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - packed_dim = None - is_transposed = False - if self.quant_config is not None: - packed_dim = self.quant_config.get_packed_dim(name) - is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T - - if packed_dim is not None: - shard_dim = 0 if not is_transposed else 1 - if packed_dim == shard_dim: - shard_size //= self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param 
= state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - if is_transposed: - param = param.T - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tensor_model_parallel_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 4a66c5b5dec6c..30ccb9a4295c9 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -10,15 +10,17 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttentionWithALiBi +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -39,7 +41,11 @@ def _get_alibi_slopes( class MptAttention(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.d_model = config.d_model self.total_num_heads = config.n_heads @@ -49,11 +55,13 @@ def __init__(self, config: MptConfig): assert not config.attn_config.prefix_lm assert config.attn_config.alibi - self.qkv_proj = ColumnParallelLinear( + # pylint: disable=invalid-name + self.Wqkv = QKVParallelLinear( self.d_model, - 3 * self.d_model, + self.d_model // self.total_num_heads, + self.total_num_heads, bias=not config.no_bias, - gather_output=False, + linear_method=linear_method, ) if self.qk_ln: self.q_ln = nn.LayerNorm(self.d_model) @@ -62,7 +70,7 @@ def __init__(self, config: MptConfig): self.d_model, self.d_model, bias=not config.no_bias, - input_is_parallel=True, + linear_method=linear_method, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -91,7 +99,7 @@ def forward( cache_event: Optional[torch.cuda.Event], ) -> torch.Tensor: del position_ids # unused. 
- qkv, _ = self.qkv_proj(hidden_states) + qkv, _ = self.Wqkv(hidden_states) if self.clip_qkv is not None: qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv) q, k, v = qkv.chunk(chunks=3, dim=-1) @@ -107,7 +115,11 @@ def forward( class MptMLP(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.d_model expansion_ratio = config.expansion_ratio @@ -116,14 +128,14 @@ def __init__(self, config: MptConfig): hidden_size, intermediate_size, bias=not config.no_bias, - gather_output=False, + linear_method=linear_method, ) self.act = get_act_fn("gelu") self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=not config.no_bias, - input_is_parallel=True, + linear_method=linear_method, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -135,13 +147,17 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MptBlock(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MptAttention(config) + self.attn = MptAttention(config, linear_method) self.norm_2 = nn.LayerNorm(hidden_size) - self.ffn = MptMLP(config) + self.ffn = MptMLP(config, linear_method) def forward( self, @@ -168,7 +184,11 @@ def forward( class MptModel(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() assert config.embedding_fraction == 1.0 assert config.norm_type == "low_precision_layernorm" @@ -178,7 +198,7 @@ def __init__(self, config: MptConfig): config.d_model, ) self.blocks = nn.ModuleList( - [MptBlock(config) for _ in range(config.n_layers)]) + [MptBlock(config, linear_method) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -215,14 +235,17 @@ def forward( class MptForCausalLM(nn.Module): - def __init__(self, config: MptConfig): + def __init__( + self, + config: MptConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config assert config.tie_word_embeddings + self.linear_method = linear_method - self.transformer = MptModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.transformer = MptModel(config, linear_method) self.lm_head_weight = self.transformer.wte.weight self.sampler = Sampler(config.vocab_size) @@ -240,45 +263,15 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = ["wte.weight", "up_proj.weight", "up_proj.bias"] - _row_parallel_weights = ["out_proj.weight", "down_proj.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tp_world_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): - if "Wqkv" in name: - # NOTE(woosuk): MPT's fused QKV has the shape of - # [3 * num_heads * head_size, hidden_size]. - # When tensor model parallelism is used, we need to shard - # the weight along the hidden dimension. 
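For context on the removed loader below: it sliced the fused Wqkv tensor by hand, keeping only the current tensor-parallel rank's heads for each of q, k and v. A standalone sketch of that slicing, with toy sizes assumed purely for illustration, follows; in the new code this bookkeeping is presumably delegated to the QKVParallelLinear parameter's weight_loader instead.

import torch

# Toy sizes (illustration only).
total_num_heads, head_size, hidden_size = 4, 2, 3
tp_world_size, tp_rank = 2, 1
num_heads = total_num_heads // tp_world_size
head_start, head_end = tp_rank * num_heads, (tp_rank + 1) * num_heads

wqkv = torch.randn(3 * total_num_heads * head_size, hidden_size)

# Keep this rank's heads for each of q, k and v, as the removed code did.
w = wqkv.view(3, total_num_heads, head_size, hidden_size)
w = w[:, head_start:head_end, :, :]
w = w.reshape(-1, hidden_size)
assert w.shape == (3 * num_heads * head_size, hidden_size)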
- total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tp_world_size - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - if name.endswith(".weight"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif name.endswith(".bias"): - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) - loaded_weight = loaded_weight[:, head_start:head_end, :] - loaded_weight = loaded_weight.reshape(-1) - else: - raise ValueError(f"Unexpected parameter name {name}") - name = name.replace("Wqkv", "qkv_proj") - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, tp_rank) + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 5295c73981856..2dde92577bff6 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -30,14 +30,18 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.attention import PagedAttention +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + LinearMethodBase, + QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import (hf_model_weights_iterator, - load_tensor_parallel_weights) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -63,6 +67,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -74,17 +79,18 @@ def __init__( self.head_dim = embed_dim // total_num_heads self.scaling = self.head_dim**-0.5 - self.qkv_proj = ColumnParallelLinear( + self.qkv_proj = QKVParallelLinear( embed_dim, - 3 * embed_dim, + self.head_dim, + total_num_heads, bias=bias, - gather_output=False, + linear_method=linear_method, ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, bias=bias, - input_is_parallel=True, + linear_method=linear_method, ) self.attn = PagedAttention(self.num_heads, self.head_dim, @@ -108,7 +114,11 @@ def forward( class OPTDecoderLayer(nn.Module): - def __init__(self, config: OPTConfig): + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.embed_dim = config.hidden_size @@ -116,6 +126,7 @@ def __init__(self, config: OPTConfig): embed_dim=self.embed_dim, 
num_heads=config.num_attention_heads, bias=config.enable_bias, + linear_method=linear_method, ) self.do_layer_norm_before = config.do_layer_norm_before self.activation_fn = get_act_fn(config.activation_function) @@ -127,13 +138,13 @@ def __init__(self, config: OPTConfig): self.embed_dim, config.ffn_dim, bias=config.enable_bias, - gather_output=False, + linear_method=linear_method, ) self.fc2 = RowParallelLinear( config.ffn_dim, self.embed_dim, bias=config.enable_bias, - input_is_parallel=True, + linear_method=linear_method, ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, @@ -177,7 +188,11 @@ def forward( class OPTDecoder(nn.Module): - def __init__(self, config: OPTConfig): + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -194,16 +209,18 @@ def __init__(self, config: OPTConfig): # Project out & in will be replicated if they exist. if config.word_embed_proj_dim != config.hidden_size: - self.project_out = nn.Linear(config.hidden_size, - config.word_embed_proj_dim, - bias=False) + self.project_out = ReplicatedLinear(config.hidden_size, + config.word_embed_proj_dim, + bias=False, + linear_method=linear_method) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: - self.project_in = nn.Linear(config.word_embed_proj_dim, - config.hidden_size, - bias=False) + self.project_in = ReplicatedLinear(config.word_embed_proj_dim, + config.hidden_size, + bias=False, + linear_method=linear_method) else: self.project_in = None @@ -218,8 +235,10 @@ def __init__(self, config: OPTConfig): else: self.final_layer_norm = None - self.layers = nn.ModuleList( - [OPTDecoderLayer(config) for _ in range(config.num_hidden_layers)]) + self.layers = nn.ModuleList([ + OPTDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) def forward( self, @@ -253,9 +272,13 @@ def forward( class OPTModel(nn.Module): - def __init__(self, config: OPTConfig): + def __init__( + self, + config: OPTConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() - self.decoder = OPTDecoder(config) + self.decoder = OPTDecoder(config, linear_method) def forward( self, @@ -271,12 +294,15 @@ def forward( class OPTForCausalLM(nn.Module): - def __init__(self, config): + def __init__( + self, + config, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.model = OPTModel(config) - # TODO(zhuohan): create a new weight after implementing pipeline - # parallelism + self.linear_method = linear_method + self.model = OPTModel(config, linear_method) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.sampler = Sampler(config.vocab_size) @@ -294,48 +320,31 @@ def forward( input_metadata) return next_tokens - _column_parallel_weights = [ - "embed_tokens.weight", "fc1.weight", "fc1.bias" - ] - _row_parallel_weights = ["out_proj.weight", "fc2.weight"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in hf_model_weights_iterator( 
model_name_or_path, cache_dir, load_format, revision): if "lm_head.weight" in name: continue - - if name.startswith("decoder."): - name = "model." + name - - is_attention_weight = False - for stride_id, att_weight_name in enumerate( - ["q_proj", "k_proj", "v_proj"]): - if att_weight_name not in name: + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: continue - param = state_dict[name.replace(att_weight_name, "qkv_proj")] - shard_size = param.shape[0] // 3 - loaded_weight = loaded_weight[ - shard_size * tensor_model_parallel_rank:shard_size * - (tensor_model_parallel_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_attention_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_attention_weight: - continue - - param = state_dict[name] - load_tensor_parallel_weights(param, loaded_weight, name, - self._column_parallel_weights, - self._row_parallel_weights, - tensor_model_parallel_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index bd5280b35cc34..45710edcc0bb4 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -15,24 +15,19 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, - hf_model_weights_iterator, - load_padded_tensor_parallel_vocab, - load_tensor_parallel_weights, -) +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.parallel_utils.layers import ( - VocabParallelEmbedding, - ColumnParallelLinear, - RowParallelLinear, -) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs.qwen import QWenConfig @@ -46,20 +41,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str = "silu", + linear_method: Optional[LinearMethodBase] = None, ): super().__init__() - self.gate_up_proj = ColumnParallelLinear( - hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - ) - self.c_proj = RowParallelLinear( - intermediate_size, - hidden_size, + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, bias=False, - input_is_parallel=True, - ) + linear_method=linear_method) + self.c_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": 
raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -74,12 +66,15 @@ def forward(self, x): class QWenAttention(nn.Module): - def __init__(self, - hidden_size: int, - num_heads: int, - max_position_embeddings: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None): + def __init__( + self, + hidden_size: int, + num_heads: int, + max_position_embeddings: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.hidden_size = hidden_size tensor_model_parallel_world_size = get_tensor_model_parallel_world_size( @@ -90,18 +85,18 @@ def __init__(self, tensor_model_parallel_world_size) self.head_dim = hidden_size // self.total_num_heads - # pylint: disable=invalid-name - self.c_attn = ColumnParallelLinear( + self.c_attn = QKVParallelLinear( hidden_size, - 3 * hidden_size, + self.head_dim, + self.total_num_heads, bias=True, - gather_output=False, + linear_method=linear_method, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, + linear_method=linear_method, ) self.scaling = self.head_dim**-0.5 self.attn = PagedAttentionWithRoPE( @@ -134,7 +129,11 @@ def forward( class QWenBlock(nn.Module): - def __init__(self, config: QWenConfig): + def __init__( + self, + config: QWenConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -144,11 +143,14 @@ def __init__(self, config: QWenConfig): config.num_attention_heads, config.max_position_embeddings, rope_theta=rope_theta, - rope_scaling=rope_scaling) + rope_scaling=rope_scaling, + linear_method=linear_method) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2) + self.mlp = QWenMLP(config.hidden_size, + config.intermediate_size // 2, + linear_method=linear_method) def forward( self, @@ -180,18 +182,23 @@ def forward( class QWenModel(nn.Module): - def __init__(self, config: QWenConfig): + def __init__( + self, + config: QWenConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config self.vocab_size = config.vocab_size - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.wte = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) - self.h = nn.ModuleList( - [QWenBlock(config) for _ in range(config.num_hidden_layers)]) + self.h = nn.ModuleList([ + QWenBlock(config, linear_method) + for _ in range(config.num_hidden_layers) + ]) self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) def forward( @@ -222,17 +229,16 @@ def forward( class QWenLMHeadModel(nn.Module): - def __init__(self, config: QWenConfig): + def __init__( + self, + config: QWenConfig, + linear_method: Optional[LinearMethodBase] = None, + ): super().__init__() self.config = config - self.transformer = QWenModel(config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - self.lm_head = ColumnParallelLinear( - config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - ) + self.linear_method = linear_method + self.transformer = QWenModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -249,75 +255,30 @@ def forward( input_metadata) return next_tokens - 
_column_parallel_weights = [] - _row_parallel_weights = ["c_proj.weight"] - - def load_weights( - self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None, - ): - tp_world_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - state_dict = self.state_dict() - + def load_weights(self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "w2", 0), + ("gate_up_proj", "w1", 1), + ] + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - - if "c_attn" in name: - total_num_heads = self.config.num_attention_heads - hidden_size = self.config.hidden_size - head_size = hidden_size // total_num_heads - num_heads = total_num_heads // tp_world_size - head_start = tp_rank * num_heads - head_end = (tp_rank + 1) * num_heads - - if "weight" in name: - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size, hidden_size) - loaded_weight = loaded_weight[:, head_start:head_end, :, :] - loaded_weight = loaded_weight.reshape(-1, hidden_size) - elif "bias" in name: - loaded_weight = loaded_weight.view(3, total_num_heads, - head_size) - loaded_weight = loaded_weight[:, head_start:head_end, :] - loaded_weight = loaded_weight.reshape(-1) - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["w2", "w1"]): + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - - if "wte" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights( - param, - loaded_weight, - name, - self._column_parallel_weights, - self._row_parallel_weights, - tp_rank, - ) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/models/yi.py b/vllm/model_executor/models/yi.py index a0958f6164e49..204c33ed42825 100644 --- a/vllm/model_executor/models/yi.py +++ b/vllm/model_executor/models/yi.py @@ -33,17 +33,19 @@ from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.attention import PagedAttentionWithRoPE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.sampler import Sampler -from 
vllm.model_executor.layers.quantized_linear import ParallelLinear +from vllm.model_executor.layers.vocab_parallel_embedding import ( + VocabParallelEmbedding, ParallelLMHead) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) -from vllm.model_executor.parallel_utils.layers import VocabParallelEmbedding -from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.weight_utils import ( - convert_pyslice_to_tensor, hf_model_weights_iterator, - load_tensor_parallel_weights, load_padded_tensor_parallel_vocab) + get_tensor_model_parallel_world_size) +from vllm.model_executor.weight_utils import (default_weight_loader, + hf_model_weights_iterator) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -56,19 +58,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() - self.gate_up_proj = ParallelLinear.column(hidden_size, - 2 * intermediate_size, - bias=False, - gather_output=False, - quant_config=quant_config) - self.down_proj = ParallelLinear.row(intermediate_size, - hidden_size, - bias=False, - input_is_parallel=True, - quant_config=quant_config) + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + linear_method=linear_method) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + linear_method=linear_method) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -91,7 +91,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -109,7 +109,6 @@ def __init__( # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) - num_kv_heads_replicas = max(1, tp_size // self.total_num_kv_heads) self.head_dim = hidden_size // self.total_num_heads self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -117,21 +116,19 @@ def __init__( self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings - self.qkv_proj = ParallelLinear.column( + self.qkv_proj = QKVParallelLinear( hidden_size, - (self.total_num_heads + - 2 * self.total_num_kv_heads * num_kv_heads_replicas) * self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, bias=False, - gather_output=False, - quant_config=quant_config, + linear_method=linear_method, ) - self.o_proj = ParallelLinear.row( + self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - input_is_parallel=True, - quant_config=quant_config, + linear_method=linear_method, ) self.attn = PagedAttentionWithRoPE( self.num_heads, @@ -165,11 +162,10 @@ class YiDecoderLayer(nn.Module): def __init__( self, config: YiConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size - # Requires transformers > 4.32.0 rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) max_position_embeddings = getattr(config, "max_position_embeddings", @@ -181,13 +177,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - quant_config=quant_config, + linear_method=linear_method, ) self.mlp = YiMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - quant_config=quant_config, + linear_method=linear_method, ) self.ln1 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.ln2 = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -225,20 +221,18 @@ class YiModel(nn.Module): def __init__( self, config: YiConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size - - vocab_size = ((config.vocab_size + 63) // 64) * 64 self.embed_tokens = VocabParallelEmbedding( - vocab_size, + config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - YiDecoderLayer(config, quant_config) + YiDecoderLayer(config, linear_method) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -274,19 +268,13 @@ class YiForCausalLM(nn.Module): def __init__( self, config: YiConfig, - quant_config: Optional[QuantizationConfig] = None, + linear_method: Optional[LinearMethodBase] = None, ) -> None: super().__init__() self.config = config - self.quant_config = quant_config - self.model = YiModel(config, quant_config) - vocab_size = ((config.vocab_size + 63) // 64) * 64 - # NOTE: The LM head is not quantized. 
- self.lm_head = ParallelLinear.column(config.hidden_size, - vocab_size, - bias=False, - gather_output=False, - quant_config=None) + self.linear_method = linear_method + self.model = YiModel(config, linear_method) + self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.sampler = Sampler(config.vocab_size) def forward( @@ -303,124 +291,33 @@ def forward( input_metadata) return next_tokens - _column_parallel_layers = [] - _row_parallel_layers = ["o_proj", "down_proj"] - def load_weights(self, model_name_or_path: str, cache_dir: Optional[str] = None, load_format: str = "auto", revision: Optional[str] = None): - if self.quant_config is None: - col_weight_suffixes = ["weight"] - row_weight_suffixes = ["weight"] - else: - col_weight_suffixes = ( - self.quant_config.get_col_parallel_tensor_names()) - row_weight_suffixes = ( - self.quant_config.get_row_parallel_tensor_names()) - - column_parallel_weights: List[str] = [] - for layer in self._column_parallel_layers: - for suffix in col_weight_suffixes: - column_parallel_weights.append(f"{layer}.{suffix}") - row_parallel_weights: List[str] = [] - for layer in self._row_parallel_layers: - for suffix in row_weight_suffixes: - row_parallel_weights.append(f"{layer}.{suffix}") - - tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() - q_proj_shard_size = (self.config.hidden_size // tp_size) - num_kv_heads_replicas = max(1, - tp_size // self.config.num_key_value_heads) - num_kv_heads_per_gpu = max(1, - self.config.num_key_value_heads // tp_size) - kv_proj_shard_size = (self.config.hidden_size // - self.config.num_attention_heads * - num_kv_heads_per_gpu) - attention_weight_specs = [ - # (weight_name, shard_size, offset) - ("q_proj", q_proj_shard_size, 0), - ("k_proj", kv_proj_shard_size, q_proj_shard_size), - ("v_proj", kv_proj_shard_size, - q_proj_shard_size + kv_proj_shard_size), + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), ] - state_dict = self.state_dict() - + params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( model_name_or_path, cache_dir, load_format, revision): if "rotary_emb.inv_freq" in name: continue - - packed_dim = None - is_transposed = False - if self.quant_config is not None: - packed_dim = self.quant_config.get_packed_dim(name) - is_transposed = self.quant_config.is_transposed(name) - if is_transposed: - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - loaded_weight = loaded_weight.T - - is_attention_weight = False - for weight_name, shard_size, offset in attention_weight_specs: + for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: continue - param = state_dict[name.replace(weight_name, "qkv_proj")] - if is_transposed: - param = param.T - - if packed_dim is not None: - shard_dim = 0 if not is_transposed else 1 - if packed_dim == shard_dim: - shard_size //= self.quant_config.pack_factor - offset //= self.quant_config.pack_factor - - if weight_name in ["k_proj", "v_proj"]: - shard_id = tp_rank // num_kv_heads_replicas - else: - shard_id = tp_rank - loaded_weight = loaded_weight[shard_size * - shard_id:shard_size * - (shard_id + 1)] - param_slice = param.data[offset:offset + shard_size] - assert param_slice.shape == loaded_weight.shape - - param_slice.copy_(loaded_weight) - is_attention_weight = True - 
break - if is_attention_weight: - continue - - is_gate_up_weight = False - for stride_id, weight_name in enumerate(["gate_proj", "up_proj"]): - if weight_name not in name: - continue - param = state_dict[name.replace(weight_name, "gate_up_proj")] - if is_transposed: - param = param.T - - shard_size = param.shape[0] // 2 - loaded_weight = loaded_weight[shard_size * tp_rank:shard_size * - (tp_rank + 1)] - param_slice = param.data[shard_size * stride_id:shard_size * - (stride_id + 1)] - assert param_slice.shape == loaded_weight.shape - param_slice.copy_(loaded_weight) - is_gate_up_weight = True + param = params_dict[name.replace(weight_name, param_name)] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) break - if is_gate_up_weight: - continue - - param = state_dict[name] - if is_transposed: - param = param.T - - if "embed_tokens" in name or "lm_head" in name: - load_padded_tensor_parallel_vocab(param, loaded_weight, - tp_rank) - continue - - load_tensor_parallel_weights(param, loaded_weight, name, - column_parallel_weights, - row_parallel_weights, tp_rank) + else: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/model_executor/parallel_utils/layers.py b/vllm/model_executor/parallel_utils/layers.py deleted file mode 100644 index c1aea2c1d5543..0000000000000 --- a/vllm/model_executor/parallel_utils/layers.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright 2023 The vLLM team. -# Adapted from -# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py -# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. - -# Parts of the code here are adapted from PyTorch -# repo: https://github.com/pytorch/pytorch -from typing import Optional - -import torch -import torch.nn.functional as F -from torch.nn.parameter import Parameter - -from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, -) -from vllm.model_executor.quantization_utils import QuantizationConfig -from vllm.model_executor.parallel_utils.communication_op import ( - tensor_model_parallel_all_reduce, tensor_model_parallel_all_gather) - -from vllm.model_executor.parallel_utils.utils import ( - divide, - VocabUtility, - split_tensor_along_last_dim, -) - - -class VocabParallelEmbedding(torch.nn.Module): - """Embedding parallelized in the vocabulary dimension. - - This is mainly adapted from torch.nn.Embedding and all the default - values are kept. - Arguments: - num_embeddings: vocabulary size. - embedding_dim: size of hidden state. - params_dtype: type of the parameters. - """ - - def __init__(self, - num_embeddings: int, - embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): - super().__init__() - - # Keep the input dimensions. - self.num_embeddings = num_embeddings - self.embedding_dim = embedding_dim - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - self.tp_size = get_tensor_model_parallel_world_size() - # TODO: Handle vocab padding here. - # Divide the weight matrix along the vocaburaly dimension. 
- self.vocab_start_index, self.vocab_end_index = ( - VocabUtility.vocab_range_from_global_vocab_size( - self.num_embeddings, get_tensor_model_parallel_rank(), - self.tp_size)) - self.num_embeddings_per_partition = (self.vocab_end_index - - self.vocab_start_index) - - self.weight = Parameter( - torch.empty(self.num_embeddings_per_partition, - self.embedding_dim, - device=torch.cuda.current_device(), - dtype=params_dtype)) - - def forward(self, input_): - if self.tp_size > 1: - # Build the mask. - input_mask = ((input_ < self.vocab_start_index) | - (input_ >= self.vocab_end_index)) - # Mask the input. - masked_input = input_.clone() - self.vocab_start_index - masked_input[input_mask] = 0 - else: - masked_input = input_ - # Get the embeddings. - output_parallel = F.embedding(masked_input, self.weight) - # Mask the output embedding. - if self.tp_size > 1: - output_parallel[input_mask, :] = 0.0 - # Reduce across all the model parallel GPUs. - output = tensor_model_parallel_all_reduce(output_parallel) - return output - - -class ColumnParallelLinear(torch.nn.Module): - """Linear layer with column parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its second dimension as A = [A_1, ..., A_p]. - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments - bias: If true, add bias - gather_output: If true, call all-gather on output and make Y available - to all GPUs, otherwise, every GPU will have its output - which is Y_i = XA_i - skip_bias_add: This was added to enable performance optimizations where - bias can be fused with other element-wise operations. we - skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configuration. - """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - gather_output: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.gather_output = gather_output - # Divide the weight matrix along the last dimension. - self.tp_size = get_tensor_model_parallel_world_size() - self.output_size_per_partition = divide(output_size, self.tp_size) - self.skip_bias_add = skip_bias_add - self.quant_config = quant_config - - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Parameters. - # NOTE: torch.nn.functional.linear performs XA^T + b and as a result - # we allocate the transpose. - self.create_weights(params_dtype) - - if bias: - self.bias = Parameter( - torch.empty(self.output_size_per_partition, - device=torch.cuda.current_device(), - dtype=params_dtype)) - else: - self.register_parameter('bias', None) - - def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter( - torch.empty(self.output_size_per_partition, - self.input_size, - device=torch.cuda.current_device(), - dtype=dtype)) - - def apply_weights( - self, - x: torch.Tensor, - bias: Optional[torch.Tensor], - ) -> torch.Tensor: - return F.linear(x, self.weight, bias) - - def forward(self, input_): - """Forward of ColumnParallelLinear - - Args: - input_: Tensor whose last dimension is `input_size`. - - Returns: - - output - - bias - """ - bias = self.bias if not self.skip_bias_add else None - - input_parallel = input_ - # Matrix multiply. 
- output_parallel = self.apply_weights(input_parallel, bias) - if self.gather_output: - # All-gather across the partitions. - output = tensor_model_parallel_all_gather(output_parallel) - else: - output = output_parallel - output_bias = self.bias if self.skip_bias_add else None - return output, output_bias - - -class RowParallelLinear(torch.nn.Module): - """Linear layer with row parallelism. - - The linear layer is defined as Y = XA + b. A is parallelized along - its first dimension and X along its second dimension as: - - - - | A_1 | - | . | - A = | . | X = [X_1, ..., X_p] - | . | - | A_p | - - - - Arguments: - input_size: first dimension of matrix A. - output_size: second dimension of matrix A. - - Keyword Arguments: - bias: If true, add bias. Note that bias is not parallelized. - input_is_parallel: If true, we assume that the input is already - split across the GPUs and we do not split - again. - skip_bias_add: This was added to enable performance optimization where - bias can be fused with other element-wise operations. - We skip adding bias but instead return it. - params_dtype: Data type for the parameters. - quant_config: Quantization configuration. - """ - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = False, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - ): - super().__init__() - # Keep input parameters - self.input_size = input_size - self.output_size = output_size - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - if params_dtype is None: - params_dtype = torch.get_default_dtype() - - # Divide the weight matrix along the last dimension. - self.tp_size = get_tensor_model_parallel_world_size() - self.input_size_per_partition = divide(input_size, self.tp_size) - self.skip_bias_add = skip_bias_add - self.quant_config = quant_config - - self.create_weights(params_dtype) - - if not reduce_results and (bias and not skip_bias_add): - raise ValueError('When not reduce the results, adding bias to the ' - 'results can lead to incorrect results') - - if bias: - self.bias = Parameter( - torch.empty(self.output_size, - device=torch.cuda.current_device(), - dtype=params_dtype)) - - # Always initialize bias to zero. - with torch.no_grad(): - self.bias.zero_() - else: - self.register_parameter('bias', None) - - def create_weights(self, dtype: torch.dtype) -> None: - self.weight = Parameter( - torch.empty(self.output_size, - self.input_size_per_partition, - device=torch.cuda.current_device(), - dtype=dtype)) - - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - return F.linear(x, self.weight) - - def forward(self, input_): - """Forward of RowParallelLinear - - Args: - input_: tensor whose last dimension is `input_size`. If - `input_is_parallel` is set, then the last dimension - is `input_size // tp_size`. - - Returns: - - output - - bias - """ - # Set up backprop all-reduce. - if self.input_is_parallel: - input_parallel = input_ - else: - # TODO: simplify code below - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. 
- output_parallel = self.apply_weights(input_parallel) - if self.reduce_results and self.tp_size > 1: - output_ = tensor_model_parallel_all_reduce(output_parallel) - else: - output_ = output_parallel - - if not self.skip_bias_add: - output = output_ + self.bias if self.bias is not None else output_ - output_bias = None - else: - output = output_ - output_bias = self.bias - return output, output_bias diff --git a/vllm/model_executor/parallel_utils/utils.py b/vllm/model_executor/parallel_utils/utils.py index 004a81f130a76..0cd420c8e11b5 100644 --- a/vllm/model_executor/parallel_utils/utils.py +++ b/vllm/model_executor/parallel_utils/utils.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. -from typing import List, Sequence +from typing import Sequence import torch @@ -24,7 +24,7 @@ def split_tensor_along_last_dim( tensor: torch.Tensor, num_partitions: int, contiguous_split_chunks: bool = False, -) -> List[torch.Tensor]: +) -> Sequence[torch.Tensor]: """ Split a tensor along its last dimension. Arguments: @@ -46,25 +46,3 @@ def split_tensor_along_last_dim( return tuple(chunk.contiguous() for chunk in tensor_list) return tensor_list - - -class VocabUtility: - """ Split the vocabulary into `world_size` chunks and return the first - and last index of the vocabulary belonging to the `rank` - partition: Note that indices in [fist, last) - - """ - - @staticmethod - def vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size: int, rank: int) -> Sequence[int]: - index_f = rank * per_partition_vocab_size - index_l = index_f + per_partition_vocab_size - return index_f, index_l - - @staticmethod - def vocab_range_from_global_vocab_size(global_vocab_size: int, rank: int, - world_size: int) -> Sequence[int]: - per_partition_vocab_size = divide(global_vocab_size, world_size) - return VocabUtility.vocab_range_from_per_partition_vocab_size( - per_partition_vocab_size, rank) diff --git a/vllm/model_executor/quantization_utils/__init__.py b/vllm/model_executor/quantization_utils/__init__.py deleted file mode 100644 index 345f6494bf836..0000000000000 --- a/vllm/model_executor/quantization_utils/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from typing import Type - -from vllm.model_executor.quantization_utils.awq import AWQConfig -from vllm.model_executor.quantization_utils.base import QuantizationConfig -from vllm.model_executor.quantization_utils.squeezellm import SqueezeLLMConfig - -_QUANTIZATION_REGISTRY = { - "awq": AWQConfig, - "squeezellm": SqueezeLLMConfig, -} - - -def get_quant_class(quantization: str) -> Type[QuantizationConfig]: - if quantization not in _QUANTIZATION_REGISTRY: - raise ValueError(f"Invalid quantization method: {quantization}") - return _QUANTIZATION_REGISTRY[quantization] - - -__all__ = [ - "QuantizationConfig", - "get_quant_class", -] diff --git a/vllm/model_executor/quantization_utils/awq.py b/vllm/model_executor/quantization_utils/awq.py deleted file mode 100644 index ebc89560a4477..0000000000000 --- a/vllm/model_executor/quantization_utils/awq.py +++ /dev/null @@ -1,76 +0,0 @@ -from typing import Any, Dict, List - -import torch - -from vllm.model_executor.quantization_utils.base import QuantizationConfig - - -class AWQConfig(QuantizationConfig): - """Config class for AWQ. 
- - Reference: https://arxiv.org/abs/2306.00978 - """ - - def __init__( - self, - weight_bits: int, - group_size: int, - zero_point: bool, - ) -> None: - self.weight_bits = weight_bits - self.group_size = group_size - self.zero_point = zero_point - - if self.weight_bits != 4: - raise ValueError( - "Currently, only 4-bit weight quantization is supported for " - f"AWQ, but got {self.weight_bits} bits.") - self.pack_factor = 32 // self.weight_bits - - def __repr__(self) -> str: - return (f"AWQConfig(weight_bits={self.weight_bits}, " - f"group_size={self.group_size}, " - f"zero_point={self.zero_point})") - - @classmethod - def get_name(cls) -> str: - return "awq" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - # The AWQ kernel only supports Turing or newer GPUs. - return 75 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return [ - "quant_config.json", # E.g., casperhansen/vicuna-7b-v1.5-awq - "quantize_config.json", # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq # pylint: disable=line-too-long - ] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": - weight_bits = cls.get_from_keys(config, ["w_bit", "bits"]) - group_size = cls.get_from_keys(config, ["q_group_size", "group_size"]) - zero_point = cls.get_from_keys(config, ["zero_point"]) - return cls(weight_bits, group_size, zero_point) - - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - return {"qweight": 1, "qzeros": 1} - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "qzeros", "scales"] diff --git a/vllm/model_executor/quantization_utils/base.py b/vllm/model_executor/quantization_utils/base.py deleted file mode 100644 index a70a7a8631e41..0000000000000 --- a/vllm/model_executor/quantization_utils/base.py +++ /dev/null @@ -1,85 +0,0 @@ -from typing import Any, Dict, List, Optional - -import torch - - -class QuantizationConfig: - - @classmethod - def get_name(cls) -> str: - """Name of the quantization method.""" - raise NotImplementedError - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - """List of supported activation dtypes.""" - raise NotImplementedError - - @classmethod - def get_min_capability(cls) -> int: - """Minimum GPU capability to support the quantization method. - - E.g., 70 for Volta, 75 for Turing, 80 for Ampere. - This requirement is due to the custom CUDA kernels used by the - quantization method. 
- """ - raise NotImplementedError - - @classmethod - def get_config_filenames(cls) -> List[str]: - """List of filenames to search for in the model directory.""" - raise NotImplementedError - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": - """Create a config class from the model's quantization config.""" - raise NotImplementedError - - @staticmethod - def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: - """Get a value from the model's quantization config.""" - for key in keys: - if key in config: - return config[key] - raise ValueError(f"Cannot find any of {keys} in the model's " - "quantization config.") - - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - """Returns a dictionary of packed tensor names and their pack dims.""" - raise NotImplementedError - - @classmethod - def get_packed_dim(cls, tensor_name: str) -> Optional[int]: - """Returns the pack dim of a tensor if it is packed. - - A tensor is considered packed if each element in the tensor is a - packed representation of multiple elements in the original tensor. - For example, an INT32 element in the tensor may represent 8 INT4 - elements in the original tensor. - If the tensor is not packed, returns None. - """ - packed_tensors = cls.get_packed_tensors() - for packed_tensor_name, pack_dim in packed_tensors.items(): - if packed_tensor_name in tensor_name: - return pack_dim - return None - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - raise NotImplementedError - - @classmethod - def is_transposed(cls, tensor_name: str) -> bool: - """Returns True if a tensor is transposed relative to nn.Linear.weight. - """ - return any(tag in tensor_name - for tag in cls.get_transposed_tensor_names()) - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - raise NotImplementedError - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - raise NotImplementedError diff --git a/vllm/model_executor/quantization_utils/squeezellm.py b/vllm/model_executor/quantization_utils/squeezellm.py deleted file mode 100644 index 8a1db3e233217..0000000000000 --- a/vllm/model_executor/quantization_utils/squeezellm.py +++ /dev/null @@ -1,65 +0,0 @@ -from typing import Any, Dict, List - -import torch - -from vllm.model_executor.quantization_utils.base import QuantizationConfig - - -class SqueezeLLMConfig(QuantizationConfig): - """Config class for SqueezeLLM. 
- - Reference: https://arxiv.org/pdf/2306.07629 - """ - - def __init__( - self, - weight_bits: int, - ) -> None: - self.weight_bits = weight_bits - - if self.weight_bits != 4: - raise ValueError( - "Currently, only 4-bit weight quantization is supported for " - f"SqueezeLLM, but got {self.weight_bits} bits.") - - self.pack_factor = 32 // self.weight_bits - - def __repr__(self) -> str: - return f"SqueezeLLMConfig(weight_bits={self.weight_bits})" - - @classmethod - def get_name(cls) -> str: - return "squeezellm" - - @classmethod - def get_supported_act_dtypes(cls) -> List[torch.dtype]: - return [torch.half] - - @classmethod - def get_min_capability(cls) -> int: - return 70 - - @classmethod - def get_config_filenames(cls) -> List[str]: - return ["quant_config.json"] - - @classmethod - def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": - weight_bits = cls.get_from_keys(config, ["wbits"]) - return cls(weight_bits) - - @classmethod - def get_packed_tensors(cls) -> Dict[str, int]: - return {"qweight": 0} - - @classmethod - def get_transposed_tensor_names(cls) -> List[str]: - return ["qweight"] - - @classmethod - def get_col_parallel_tensor_names(cls) -> List[str]: - return ["qweight", "lookup_table"] - - @classmethod - def get_row_parallel_tensor_names(cls) -> List[str]: - return ["qweight"] diff --git a/vllm/model_executor/utils.py b/vllm/model_executor/utils.py index bd74ae96aa19e..336bc1cd005cf 100644 --- a/vllm/model_executor/utils.py +++ b/vllm/model_executor/utils.py @@ -1,5 +1,6 @@ """Utils for model executor.""" import random +from typing import Any, Dict, Optional import numpy as np import torch @@ -11,3 +12,24 @@ def set_random_seed(seed: int) -> None: torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) + + +def set_weight_attrs( + weight: torch.Tensor, + weight_attrs: Optional[Dict[str, Any]], +): + """Set attributes on a weight tensor. + + This method is used to set attributes on a weight tensor. This method + will not overwrite existing attributes. + + Args: + weight: The weight tensor. + weight_attrs: A dictionary of attributes to set on the weight tensor. 
+ """ + if weight_attrs is None: + return + for key, value in weight_attrs.items(): + assert not hasattr( + weight, key), (f"Overwriting existing tensor attribute: {key}") + setattr(weight, key, value) diff --git a/vllm/model_executor/weight_utils.py b/vllm/model_executor/weight_utils.py index 027e741c5f6bf..0a05fe7e340b4 100644 --- a/vllm/model_executor/weight_utils.py +++ b/vllm/model_executor/weight_utils.py @@ -13,8 +13,8 @@ from tqdm.auto import tqdm from vllm.logger import init_logger -from vllm.model_executor.quantization_utils import get_quant_class -from vllm.model_executor.quantization_utils.base import QuantizationConfig +from vllm.model_executor.layers.quantization import (get_quantization_config, + QuantizationConfig) logger = init_logger(__name__) @@ -98,7 +98,7 @@ def get_quant_config( hf_folder = model_name_or_path config_files = glob.glob(os.path.join(hf_folder, "*.json")) - quant_cls = get_quant_class(quantization) + quant_cls = get_quantization_config(quantization) quant_config_files = [ f for f in config_files if any( f.endswith(x) for x in quant_cls.get_config_filenames()) @@ -237,7 +237,7 @@ def hf_model_weights_iterator( with safe_open(st_file, framework="pt") as f: for name in f.keys(): param = f.get_slice(name) - yield name, param + yield name, convert_pyslice_to_tensor(param) else: for bin_file in hf_weights_files: state = torch.load(bin_file, map_location="cpu") @@ -262,46 +262,10 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: return x -def load_padded_tensor_parallel_vocab( - param: torch.Tensor, - loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` - tensor_model_parallel_rank: int, -) -> None: - shard_size = param.shape[0] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - loaded_weight = loaded_weight[start_idx:end_idx] - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - param[:loaded_weight.shape[0]].copy_(loaded_weight) - - -def load_tensor_parallel_weights( - param: torch.Tensor, - loaded_weight: Any, # `torch.Tensor` or `PySafeSlice` - param_name: str, - column_parallel_weight_names: List[str], - row_parallel_weight_names: List[str], - tensor_model_parallel_rank: int, -) -> None: - for p in column_parallel_weight_names: - if p in param_name: - shard_size = param.shape[0] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - loaded_weight = loaded_weight[start_idx:end_idx] - break - for p in row_parallel_weight_names: - if p in param_name: - shard_size = param.shape[1] - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - loaded_weight = loaded_weight[:, start_idx:end_idx] - break - - loaded_weight = convert_pyslice_to_tensor(loaded_weight) - assert param.shape == loaded_weight.shape, ( - f"{param_name} shape mismatch between model and checkpoint: " - f"{param.shape} != {loaded_weight.shape}") +def default_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Default weight loader.""" + assert param.size() == loaded_weight.size() param.data.copy_(loaded_weight)