forked from vllm-project/vllm
Commit
Merge pull request #4 from neuralmagic/semi_structured
Semi-structured 2:4 sparsity via SparseSemiStructuredTensor
Showing 8 changed files with 159 additions and 66 deletions.
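The PR title refers to PyTorch's SparseSemiStructuredTensor, which stores a weight matrix in which at most 2 of every 4 contiguous elements per row are non-zero in a compressed layout. As a minimal sketch of that primitive, independent of this diff (it assumes PyTorch >= 2.1 and an NVIDIA GPU with compute capability >= 8.0, matching get_min_capability below):

import torch
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured

# Build a weight whose rows follow the 2:4 pattern: at most 2 non-zeros
# in every contiguous group of 4 elements (here: 1,0,1,0 repeating).
dense = torch.zeros(128, 128, dtype=torch.float16, device="cuda")
dense[:, ::2] = torch.randn(128, 64, dtype=torch.float16, device="cuda")

sparse = to_sparse_semi_structured(dense)  # compressed 2:4 representation

x = torch.randn(64, 128, dtype=torch.float16, device="cuda")
dense_out = F.linear(x, dense)
sparse_out = F.linear(x, sparse)  # dispatched to the 2:4 sparse kernel
assert torch.allclose(dense_out, sparse_out, rtol=1e-2, atol=1e-2)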
New file (12 additions; filename not shown in this view):
@@ -0,0 +1,12 @@
from vllm import LLM, SamplingParams

# Load a 50%-sparse (2:4) checkpoint with the semi-structured sparsity
# method added in this commit.
model = LLM("nm-testing/zephyr-50sparse-24",
            sparsity="semi_structured_sparse_w16a16",
            enforce_eager=True,
            dtype="float16",
            tensor_parallel_size=1,
            max_model_len=1024)

sampling_params = SamplingParams(max_tokens=100, temperature=0)
outputs = model.generate("Hello my name is", sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
(Diffs for several of the changed files are collapsed in this view.)
vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py (46 additions, 0 deletions)
@@ -0,0 +1,46 @@
import torch

from typing import Any, Dict, List, Type
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (CompressedStorageFormat,
                        SparseSemiStructuredStorageFormat)


class SemiStructuredSparseW16A16Config(SparsityConfig):
    """Config class for SemiStructuredSparseW16A16."""

    def __init__(self) -> None:
        pass

    def __repr__(self) -> str:
        return "SemiStructuredSparseW16A16Config()"

    @classmethod
    def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
        return SparseSemiStructuredStorageFormat

    @classmethod
    def get_name(cls) -> str:
        return "semi_structured_sparse_w16a16"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # TODO: Update after checks on more GPUs
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["sparsity_config.json"]

    @classmethod
    def from_config(
            cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config":
        return cls()

    def get_linear_method(self) -> "SparseW16A16LinearMethod":
        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
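For orientation, a hypothetical sketch of how this config hands out and exercises its linear method; the call signatures come from the diff itself, but the sketch assumes this vLLM fork and its magic_wand dependency are installed, and the shapes are illustrative (real use goes through vLLM's linear layers):

import torch
from vllm.model_executor.layers.sparsity.semi_structured_sparse_w16a16 import (
    SemiStructuredSparseW16A16Config)

# The config carries no options, so from_config ignores its argument.
config = SemiStructuredSparseW16A16Config.from_config({})
method = config.get_linear_method()

# Allocate a SparseParameter for one shard of a linear layer, then apply it.
weights = method.create_weights(input_size_per_partition=4096,
                                output_size_per_partition=4096,
                                input_size=4096,
                                output_size=4096,
                                params_dtype=torch.float16)
x = torch.randn(1, 4096, dtype=torch.float16)
y = method.apply_weights(weights, x)  # dispatches on the storage format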
vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py (55 additions, 0 deletions)
@@ -0,0 +1,55 @@
from typing import Any, Dict, Optional, Type

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter
from magic_wand import (CompressedStorageFormat,
                        SparseSemiStructuredStorageFormat)


class SparseW16A16LinearMethod(LinearMethodBase):
    """Linear method for Sparse W16A16.

    Args:
        sparsity_config: The sparse config.
    """
    storage_format_cls: Optional[Type[CompressedStorageFormat]] = None

    def __init__(self, sparsity_config: SparsityConfig,
                 storage_format_cls: Type[CompressedStorageFormat]):
        self.sparsity_config = sparsity_config
        self.storage_format_cls = storage_format_cls

    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        weight = SparseParameter(shape=torch.Size(
            (output_size_per_partition, input_size_per_partition)),
                                 dtype=params_dtype,
                                 storage_format_cls=self.storage_format_cls)

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        return {"weight": weight}

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sparse_weight = weights["weight"]

        if self.storage_format_cls == SparseSemiStructuredStorageFormat:
            # Semi-structured (2:4) weights can be passed to F.linear
            # directly; the sparse kernel handles the matmul.
            output = F.linear(x, sparse_weight, bias)
            return output
        else:
            # Standard matrix multiply: uncompress to dense first.
            output = F.linear(x, sparse_weight.to_dense(), bias)
            return output
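Checkpoints such as nm-testing/zephyr-50sparse-24 carry weights already pruned to the 2:4 pattern. As a generic illustration (not this repository's pruning code), magnitude pruning to 2:4 keeps the two largest-magnitude values in every contiguous group of four:

import torch

def prune_24(w: torch.Tensor) -> torch.Tensor:
    """Zero the 2 smallest-magnitude entries in each contiguous group of 4."""
    out_features, in_features = w.shape
    groups = w.reshape(out_features, in_features // 4, 4)
    drop = groups.abs().topk(2, dim=-1, largest=False).indices
    return groups.scatter(-1, drop, 0.0).reshape(out_features, in_features)

w = torch.randn(8, 16)
wp = prune_24(w)
# Every group of 4 now has at most 2 non-zeros -> a valid 2:4 pattern.
assert ((wp.reshape(-1, 4) != 0).sum(dim=-1) <= 2).all()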