This repository has been archived by the owner on Oct 11, 2024. It is now read-only.

Semi-structured 2:4 sparsity via SparseSemiStructuredTensor #4

Merged
merged 17 commits on Feb 13, 2024
14 changes: 14 additions & 0 deletions examples/offline_inference_semi_structured_sparse.py
@@ -0,0 +1,14 @@
from vllm import LLM, SamplingParams

model = LLM(
    "nm-testing/zephyr-50sparse-24",
    sparsity="semi_structured_sparse_w16a16",
    enforce_eager=True,
    dtype="float16",
    tensor_parallel_size=1,
    max_model_len=1024
)

sampling_params = SamplingParams(max_tokens=100, temperature=0)
outputs = model.generate("Hello my name is", sampling_params=sampling_params)
print(outputs[0].outputs[0].text)
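
Note for reviewers: 2:4 semi-structured sparsity keeps at most two non-zero values in every contiguous group of four, a pattern that Ampere-and-newer tensor cores accelerate natively. A minimal sketch of the PyTorch prototype API this PR builds on (not part of the diff; assumes PyTorch >= 2.1 and a CUDA device with compute capability >= 8.0):

import torch
import torch.nn.functional as F
from torch.sparse import to_sparse_semi_structured

# Build an fp16 weight obeying the 2:4 pattern (two non-zeros per group of four).
mask = torch.tensor([0, 0, 1, 1], dtype=torch.bool).tile(128, 32).cuda()
weight = torch.rand(128, 128, dtype=torch.float16, device="cuda") * mask

w_24 = to_sparse_semi_structured(weight)  # compressed 2:4 representation
x = torch.rand(128, 128, dtype=torch.float16, device="cuda")

# F.linear consumes the compressed weight directly, which is what the
# semi-structured branch of apply_weights() below relies on.
out = F.linear(x, w_24)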
2 changes: 1 addition & 1 deletion vllm/config.py
@@ -148,7 +148,7 @@ def _verify_tokenizer_mode(self) -> None:
        self.tokenizer_mode = tokenizer_mode

    def _verify_sparsity(self) -> None:
        supported_sparsity = ["sparse_w16a16"]
        supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]

        if self.quantization is not None:
            raise ValueError("Both sparsity and quantization detected. Only "
32 changes: 29 additions & 3 deletions vllm/model_executor/layers/parameters/sparsity.py
@@ -1,6 +1,13 @@
import torch
from torch.sparse import SparseSemiStructuredTensor

from magic_wand import SparseTensor, SparseBitmaskStorageFormat
from typing import Type
from magic_wand import (
    SparseTensor,
    CompressedStorageFormat,
    SparseBitmaskStorageFormat,
    SparseSemiStructuredStorageFormat
)


class SparseParameter(SparseTensor):
@@ -10,20 +17,25 @@ def __new__(
        cls,
        shape: torch.Size,
        dtype: torch.dtype,
        storage_format_cls: Type[CompressedStorageFormat] = SparseBitmaskStorageFormat
    ):
        assert torch.__version__ > (1, 10), "SparseTensor requires PyTorch 1.11+"

        self = torch.Tensor._make_wrapper_subclass(cls,
                                                   size=shape,
                                                   dtype=dtype,
                                                   requires_grad=False)
        self.storage_format_cls = SparseBitmaskStorageFormat
        self.storage_format_cls = storage_format_cls
        self.compressed_data = None
        self.dense_data = None
        self._is_param = True

        return self

    def has_compressed_data(self) -> bool:
        return (self.compressed_data is not None)

    def get_dense_data(self) -> torch.Tensor:
        if self.dense_data is not None:
            raise ValueError(
@@ -39,8 +51,22 @@ def _unpack(self) -> torch.Tensor:
                           dtype=self.dtype,
                           device="cuda")

    @classmethod
    def _copy(cls, arg0, arg1):
        assert arg0.shape == arg1.shape

        if arg0.has_compressed_data():
            arg0.compressed_data.copy_(arg1)
        else:
            arg0.compressed_data = arg0.storage_format_cls.compress(arg1)

        return arg0

    def copy_(self, src, non_blocking=False):
        return SparseParameter._copy(self, src)

    def pack(self) -> None:
        if self.dense_data is None:
            raise ValueError("Called pack() but dense_data does not exist.")
        self.copy_(self.dense_data)
        self.dense_data = None
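
The _copy()/pack() pair above implies a load-then-compress lifecycle for weights. A hedged sketch of how a weight loader would presumably drive it (illustrative only; it assumes get_dense_data() allocates and records the dense staging buffer that _unpack() creates, and loaded_weight stands in for a checkpoint tensor):

param = SparseParameter(shape=torch.Size((256, 256)),
                        dtype=torch.float16)  # bitmask format by default

dense = param.get_dense_data()  # assumption: stages an empty dense buffer
dense.copy_(loaded_weight)      # fill with checkpoint data
param.pack()                    # compress via storage_format_cls.compress(),
                                # then release the dense copy
assert param.has_compressed_data()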
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/sparsity/__init__.py
@@ -2,9 +2,11 @@

from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config
from vllm.model_executor.layers.sparsity.semi_structured_sparse_w16a16 import SemiStructuredSparseW16A16Config

_SPARSITY_CONFIG_REGISTRY = {
    "sparse_w16a16": SparseW16A16Config,
    "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
}
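
A hedged sketch of the lookup this registry presumably serves once config.py has validated the sparsity string (the accessor name get_sparsity_config is an assumption, mirroring vLLM's quantization registry pattern):

from typing import Type

def get_sparsity_config(sparsity: str) -> Type[SparsityConfig]:
    # Hypothetical accessor: map the engine's sparsity string to its config class.
    if sparsity not in _SPARSITY_CONFIG_REGISTRY:
        raise ValueError(f"Invalid sparsity method: {sparsity}")
    return _SPARSITY_CONFIG_REGISTRY[sparsity]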


8 changes: 7 additions & 1 deletion vllm/model_executor/layers/sparsity/base_config.py
@@ -2,13 +2,19 @@
from typing import Any, Dict, List

import torch
from typing import Type

from vllm.model_executor.layers.linear import LinearMethodBase

from magic_wand import CompressedStorageFormat

class SparsityConfig(ABC):
"""Base class for sparsity configs."""

@abstractmethod
def get_storage_format_cls(self) -> Type[CompressedStorageFormat]:
"""Sparse representation format"""
raise NotImplementedError

@abstractmethod
def get_name(self) -> str:
"""Name of the sparse method."""
47 changes: 47 additions & 0 deletions vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py
@@ -0,0 +1,47 @@
import torch

from typing import Any, Dict, List, Type
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (
    CompressedStorageFormat,
    SparseSemiStructuredStorageFormat
)

class SemiStructuredSparseW16A16Config(SparsityConfig):
"""Config class for SemiStructuredSparseW16A16."""

def __init__(self) -> None:
pass

def __repr__(self) -> str:
return "SemiStructuredSparseW16A16Config()"

@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
return SparseSemiStructuredStorageFormat

@classmethod
def get_name(cls) -> str:
return "semi_structured_sparse_w16a16"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]

@classmethod
def get_min_capability(cls) -> int:
# TODO: Update after checks on more GPUs
return 80

@classmethod
def get_config_filenames(cls) -> List[str]:
return ["sparsity_config.json"]
afeldman-nm marked this conversation as resolved.
Show resolved Hide resolved

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config":
return cls()

def get_linear_method(self) -> "SparseW16A16LinearMethod":
return SparseW16A16LinearMethod(self,self.get_storage_format_cls())
68 changes: 11 additions & 57 deletions vllm/model_executor/layers/sparsity/sparse_w16a16.py
@@ -1,12 +1,15 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Type

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter

from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
from magic_wand import (
    CompressedStorageFormat,
    SparseBitmaskStorageFormat
)

class SparseW16A16Config(SparsityConfig):
"""Config class for SparseW16A16.
Expand All @@ -21,6 +24,10 @@ def __init__(self) -> None:
def __repr__(self) -> str:
return "SparseW16A16Config()"

@classmethod
def get_storage_format_cls(cls) -> Type[CompressedStorageFormat]:
return SparseBitmaskStorageFormat

@classmethod
def get_name(cls) -> str:
return "sparse_w16a16"
Expand All @@ -43,57 +50,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config":
return cls()

def get_linear_method(self) -> "SparseW16A16LinearMethod":
return SparseW16A16LinearMethod(self)


class SparseW16A16LinearMethod(LinearMethodBase):
    """Linear method for Sparse W16A16.

    Args:
        sparsity_config: The sparse config.
    """

    def __init__(self, sparsity_config: SparseW16A16Config):
        self.sparsity_config = sparsity_config

    def create_weights(
        self,
        input_size_per_partition: int,
        output_size_per_partition: int,
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        weight = SparseParameter(
            shape=torch.Size(
                (output_size_per_partition, input_size_per_partition)),
            dtype=params_dtype,
        )

        set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

        return {"weight": weight}

    def apply_weights(
        self,
        weights: Dict[str, Any],
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        sparse_weight = weights["weight"]

        # Uncompress to dense
        dense_weight = sparse_weight.to_dense()

        # # Uncomment to verify sparsity
        # density = torch.count_nonzero(
        #     dense_weight).item() / dense_weight.numel()
        # print(f"sparsity = {1.0 - density}")

        # Standard matrix multiply
        if bias is not None:
            output = F.linear(x, dense_weight, bias)
        else:
            output = F.linear(x, dense_weight)

        return output
        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
61 changes: 61 additions & 0 deletions vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
@@ -0,0 +1,61 @@
from typing import Any, Dict, Optional, Type

import torch
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
from vllm.model_executor.layers.parameters import SparseParameter
from magic_wand import (
    CompressedStorageFormat,
    SparseSemiStructuredStorageFormat
)

class SparseW16A16LinearMethod(LinearMethodBase):
"""Linear method for Sparse W16A16.

Args:
sparsity_config: The sparse config.
"""
storage_format_cls: Type[CompressedStorageFormat] = None

def __init__(self, sparsity_config: SparsityConfig, storage_format_cls: Type[CompressedStorageFormat]):
self.sparsity_config = sparsity_config
self.storage_format_cls = storage_format_cls

def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype
) -> Dict[str, Any]:
weight = SparseParameter(
shape=torch.Size(
(output_size_per_partition, input_size_per_partition)),
dtype=params_dtype,
storage_format_cls=self.storage_format_cls
)

set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})

return {"weight": weight}

def apply_weights(
self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
sparse_weight = weights["weight"]

if self.storage_format_cls == SparseSemiStructuredStorageFormat:
output = F.linear(x, sparse_weight, bias)
return output
else:

# Standard matrix multiply
# Uncompress to dense
output = F.linear(x, sparse_weight.to_dense(), bias)
return output
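
To summarize the refactor: both configs now share one SparseW16A16LinearMethod, parameterized by storage format. A hedged usage sketch (class names are from this diff; the wiring is illustrative, not the full vLLM model-loading path):

from vllm.model_executor.layers.sparsity import (
    SparseW16A16Config, SemiStructuredSparseW16A16Config)

# Bitmask format: apply_weights() decompresses to dense before F.linear.
bitmask_method = SparseW16A16Config().get_linear_method()

# Semi-structured 2:4 format: apply_weights() calls F.linear on the
# compressed weight directly.
semi_structured_method = SemiStructuredSparseW16A16Config().get_linear_method()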