diff --git a/examples/offline_inference_semi_structured_sparse.py b/examples/offline_inference_semi_structured_sparse.py
new file mode 100644
index 0000000000000..118725b4448d4
--- /dev/null
+++ b/examples/offline_inference_semi_structured_sparse.py
@@ -0,0 +1,12 @@
+from vllm import LLM, SamplingParams
+
+model = LLM("nm-testing/zephyr-50sparse-24",
+            sparsity="semi_structured_sparse_w16a16",
+            enforce_eager=True,
+            dtype="float16",
+            tensor_parallel_size=1,
+            max_model_len=1024)
+
+sampling_params = SamplingParams(max_tokens=100, temperature=0)
+outputs = model.generate("Hello my name is", sampling_params=sampling_params)
+print(outputs[0].outputs[0].text)
diff --git a/vllm/config.py b/vllm/config.py
index d735819c0c2b1..a86fbc3cfde84 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -148,7 +148,7 @@ def _verify_tokenizer_mode(self) -> None:
         self.tokenizer_mode = tokenizer_mode
 
     def _verify_sparsity(self) -> None:
-        supported_sparsity = ["sparse_w16a16"]
+        supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]
 
         if self.quantization is not None:
             raise ValueError("Both sparsity and quantization detected. Only "
diff --git a/vllm/model_executor/layers/parameters/sparsity.py b/vllm/model_executor/layers/parameters/sparsity.py
index 37ddd05d89636..017fb6b825965 100644
--- a/vllm/model_executor/layers/parameters/sparsity.py
+++ b/vllm/model_executor/layers/parameters/sparsity.py
@@ -1,29 +1,35 @@
 import torch
 
-from magic_wand import SparseTensor, SparseBitmaskStorageFormat
+from typing import Type
+from magic_wand import (SparseTensor, CompressedStorageFormat,
+                        SparseBitmaskStorageFormat)
 
 
 class SparseParameter(SparseTensor):
 
     @staticmethod
-    def __new__(
-        cls,
-        shape: torch.Size,
-        dtype: torch.dtype,
-    ):
+    def __new__(cls,
+                shape: torch.Size,
+                dtype: torch.dtype,
+                storage_format_cls: Type[
+                    CompressedStorageFormat] = SparseBitmaskStorageFormat):
         assert torch.__version__ > (1,
                                     10), "SparseTensor requires PyTorch 1.11+"
+
         self = torch.Tensor._make_wrapper_subclass(cls,
                                                    size=shape,
                                                    dtype=dtype,
                                                    requires_grad=False)
-        self.storage_format_cls = SparseBitmaskStorageFormat
+        self.storage_format_cls = storage_format_cls
         self.compressed_data = None
         self.dense_data = None
         self._is_param = True
 
         return self
 
+    def has_compressed_data(self) -> bool:
+        return (self.compressed_data is not None)
+
     def get_dense_data(self) -> torch.Tensor:
         if self.dense_data is not None:
             raise ValueError(
@@ -39,6 +45,20 @@ def _unpack(self) -> torch.Tensor:
                                dtype=self.dtype,
                                device="cuda")
 
+    @classmethod
+    def _copy(cls, arg0, arg1):
+        assert arg0.shape == arg1.shape
+
+        if arg0.has_compressed_data():
+            arg0.compressed_data.copy_(arg1)
+        else:
+            arg0.compressed_data = arg0.storage_format_cls.compress(arg1)
+
+        return arg0
+
+    def copy_(self, src, non_blocking=False):
+        return SparseParameter._copy(self, src)
+
     def pack(self) -> None:
         if self.dense_data is None:
             raise ValueError("Called pack() but dense_data does not exist.")
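The `copy_()` override above is what lets vLLM's weight loader write a dense checkpoint tensor straight into a `SparseParameter`: the first copy compresses the source tensor, and later copies update the existing compressed storage in place. Below is a minimal sketch of that compress-on-copy pattern, with a toy bitmask format standing in for magic_wand's `CompressedStorageFormat` (its `compress()`/`copy_()`/`decompress()` interface is assumed here, not taken from the library):

```python
import torch


class ToyBitmaskFormat:
    """Toy stand-in for a storage format: a nonzero mask plus packed values."""

    @classmethod
    def compress(cls, dense: torch.Tensor) -> "ToyBitmaskFormat":
        fmt = cls()
        fmt.shape = dense.shape
        fmt.mask = dense != 0
        fmt.values = dense[fmt.mask]
        return fmt

    def copy_(self, dense: torch.Tensor) -> None:
        # Mirrors `compressed_data.copy_(src)` above: recompress in place.
        self.mask = dense != 0
        self.values = dense[self.mask]

    def decompress(self) -> torch.Tensor:
        out = torch.zeros(self.shape, dtype=self.values.dtype)
        out[self.mask] = self.values
        return out


# The first copy_() finds no compressed data and compresses from scratch;
# a second copy_() reuses the existing storage, as in _copy() above.
dense = torch.randn(4, 8)
dense[dense.abs() < 0.8] = 0.0  # make it sparse
compressed = ToyBitmaskFormat.compress(dense)
assert torch.equal(compressed.decompress(), dense)
```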
"sparse_w16a16": SparseW16A16Config, + "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config, } diff --git a/vllm/model_executor/layers/sparsity/base_config.py b/vllm/model_executor/layers/sparsity/base_config.py index aa09fb623bc00..fe46b55cbf39f 100644 --- a/vllm/model_executor/layers/sparsity/base_config.py +++ b/vllm/model_executor/layers/sparsity/base_config.py @@ -2,13 +2,20 @@ from typing import Any, Dict, List import torch +from typing import Type from vllm.model_executor.layers.linear import LinearMethodBase +from magic_wand import CompressedStorageFormat class SparsityConfig(ABC): """Base class for sparsity configs.""" + @abstractmethod + def get_storage_format_cls(self) -> Type[CompressedStorageFormat]: + """Sparse representation format""" + raise NotImplementedError + @abstractmethod def get_name(self) -> str: """Name of the sparse method.""" diff --git a/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py new file mode 100644 index 0000000000000..2cdd34fd0ff1c --- /dev/null +++ b/vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py @@ -0,0 +1,46 @@ +import torch + +from typing import Any, Dict, List, Type +from vllm.model_executor.layers.sparsity.base_config import SparsityConfig +from .sparse_w16a16_linear_method import SparseW16A16LinearMethod +from magic_wand import (CompressedStorageFormat, + SparseSemiStructuredStorageFormat) + + +class SemiStructuredSparseW16A16Config(SparsityConfig): + """Config class for SemiStructuredSparseW16A16.""" + + def __init__(self) -> None: + pass + + def __repr__(self) -> str: + return "SemiStructuredSparseW16A16Config()" + + @classmethod + def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]: + return SparseSemiStructuredStorageFormat + + @classmethod + def get_name(cls) -> str: + return "semi_structured_sparse_w16a16" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.float16, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # TODO: Update after checks on more GPUs + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["sparsity_config.json"] + + @classmethod + def from_config( + cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config": + return cls() + + def get_linear_method(self) -> "SparseW16A16LinearMethod": + return SparseW16A16LinearMethod(self, self.get_storage_format_cls()) diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py index 771fae9b8ff45..69905eab0c0af 100644 --- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py +++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Type import torch -import torch.nn.functional as F -from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs from vllm.model_executor.layers.sparsity.base_config import SparsityConfig -from vllm.model_executor.layers.parameters import SparseParameter + +from .sparse_w16a16_linear_method import SparseW16A16LinearMethod +from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat) class SparseW16A16Config(SparsityConfig): @@ -21,6 +21,10 @@ def __init__(self) -> None: def __repr__(self) -> str: return "SparseW16A16Config()" + @classmethod + def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]: + return 
diff --git a/vllm/model_executor/layers/sparsity/sparse_w16a16.py b/vllm/model_executor/layers/sparsity/sparse_w16a16.py
index 771fae9b8ff45..69905eab0c0af 100644
--- a/vllm/model_executor/layers/sparsity/sparse_w16a16.py
+++ b/vllm/model_executor/layers/sparsity/sparse_w16a16.py
@@ -1,11 +1,11 @@
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Type
 
 import torch
-import torch.nn.functional as F
 
-from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
-from vllm.model_executor.layers.parameters import SparseParameter
+
+from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
+from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)
 
 
 class SparseW16A16Config(SparsityConfig):
@@ -21,6 +21,10 @@ def __init__(self) -> None:
     def __repr__(self) -> str:
         return "SparseW16A16Config()"
 
+    @classmethod
+    def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
+        return SparseBitmaskStorageFormat
+
     @classmethod
     def get_name(cls) -> str:
         return "sparse_w16a16"
@@ -43,57 +47,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config":
         return cls()
 
     def get_linear_method(self) -> "SparseW16A16LinearMethod":
-        return SparseW16A16LinearMethod(self)
-
-
-class SparseW16A16LinearMethod(LinearMethodBase):
-    """Linear method for Sparse W16A16.
-
-    Args:
-        sparsity_config: The sparse config.
-    """
-
-    def __init__(self, sparsity_config: SparseW16A16Config):
-        self.sparsity_config = sparsity_config
-
-    def create_weights(
-        self,
-        input_size_per_partition: int,
-        output_size_per_partition: int,
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype,
-    ) -> Dict[str, Any]:
-        weight = SparseParameter(
-            shape=torch.Size(
-                (output_size_per_partition, input_size_per_partition)),
-            dtype=params_dtype,
-        )
-
-        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
-
-        return {"weight": weight}
-
-    def apply_weights(
-        self,
-        weights: Dict[str, Any],
-        x: torch.Tensor,
-        bias: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
-        sparse_weight = weights["weight"]
-
-        # Uncompress to dense
-        dense_weight = sparse_weight.to_dense()
-
-        # # Uncomment to verify sparsity
-        # density = torch.count_nonzero(
-        #     dense_weight).item() / dense_weight.numel()
-        # print(f"sparsity = {1.0 - density}")
-
-        # Standard matrix multiply
-        if bias is not None:
-            output = F.linear(x, dense_weight, bias)
-        else:
-            output = F.linear(x, dense_weight)
-
-        return output
+        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
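After this refactor, the bitmask and semi-structured schemes differ only in the `CompressedStorageFormat` subclass their configs return; the linear method itself is shared. A compact, dependency-free sketch of that pattern (all names here are stand-ins):

```python
from typing import Type


class MockFormat:
    """Illustrative stand-in for magic_wand's CompressedStorageFormat."""


class MockBitmaskFormat(MockFormat):
    pass


class MockSemiStructuredFormat(MockFormat):
    pass


class SharedLinearMethod:
    """Single linear method, parameterized by the storage format class."""

    def __init__(self, storage_format_cls: Type[MockFormat]) -> None:
        self.storage_format_cls = storage_format_cls


class BitmaskConfig:

    @classmethod
    def get_storage_format_cls(cls) -> Type[MockFormat]:
        return MockBitmaskFormat

    def get_linear_method(self) -> SharedLinearMethod:
        return SharedLinearMethod(self.get_storage_format_cls())


class SemiStructuredConfig(BitmaskConfig):
    # Only the storage format changes; the linear method is reused as-is.

    @classmethod
    def get_storage_format_cls(cls) -> Type[MockFormat]:
        return MockSemiStructuredFormat


assert (SemiStructuredConfig().get_linear_method().storage_format_cls
        is MockSemiStructuredFormat)
```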
+ """ + storage_format_cls: Type[CompressedStorageFormat] = None + + def __init__(self, sparsity_config: SparsityConfig, + storage_format_cls: Type[CompressedStorageFormat]): + self.sparsity_config = sparsity_config + self.storage_format_cls = storage_format_cls + + def create_weights(self, input_size_per_partition: int, + output_size_per_partition: int, input_size: int, + output_size: int, + params_dtype: torch.dtype) -> Dict[str, Any]: + weight = SparseParameter(shape=torch.Size( + (output_size_per_partition, input_size_per_partition)), + dtype=params_dtype, + storage_format_cls=self.storage_format_cls) + + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + + return {"weight": weight} + + def apply_weights( + self, + weights: Dict[str, Any], + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + sparse_weight = weights["weight"] + + if self.storage_format_cls == SparseSemiStructuredStorageFormat: + output = F.linear(x, sparse_weight, bias) + return output + else: + + # Standard matrix multiply + # Uncompress to dense + output = F.linear(x, sparse_weight.to_dense(), bias) + return output