
Commit

lint/format
afeldman-nm committed Feb 13, 2024
1 parent 95303b3 commit 51ebca3
Showing 6 changed files with 39 additions and 56 deletions.
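
The changes below are purely cosmetic: multi-line calls and parenthesized imports are collapsed into a continuation-aligned style (consistent with an auto-formatter such as yapf — an assumption, since the commit message only says "lint/format"), a few unused or duplicate imports are dropped, and blank lines around class definitions are normalized. A minimal, self-contained sketch of the call-wrapping pattern, using a made-up function:

def do_something(first_argument, second_argument):
    return first_argument + second_argument

# Before: one argument per line, hand-wrapped.
result = do_something(
    1,
    2)

# After: continuation arguments aligned under the opening parenthesis,
# the style applied throughout the diffs below.
result = do_something(1,
                      2)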
16 changes: 7 additions & 9 deletions examples/offline_inference_semi_structured_sparse.py
@@ -1,14 +1,12 @@
 from vllm import LLM, SamplingParams
 
-model = LLM(
-    "nm-testing/zephyr-50sparse-24",
-    sparsity="semi_structured_sparse_w16a16",
-    enforce_eager=True,
-    dtype="float16",
-    tensor_parallel_size=1,
-    max_model_len=1024
-)
+model = LLM("nm-testing/zephyr-50sparse-24",
+            sparsity="semi_structured_sparse_w16a16",
+            enforce_eager=True,
+            dtype="float16",
+            tensor_parallel_size=1,
+            max_model_len=1024)
 
 sampling_params = SamplingParams(max_tokens=100, temperature=0)
 outputs = model.generate("Hello my name is", sampling_params=sampling_params)
-print(outputs[0].outputs[0].text)
+print(outputs[0].outputs[0].text)
22 changes: 8 additions & 14 deletions vllm/model_executor/layers/parameters/sparsity.py
@@ -1,24 +1,18 @@
 import torch
-from torch.sparse import SparseSemiStructuredTensor
 
 from typing import Type
-from magic_wand import (
-    SparseTensor,
-    CompressedStorageFormat,
-    SparseBitmaskStorageFormat,
-    SparseSemiStructuredStorageFormat
-)
+from magic_wand import (SparseTensor, CompressedStorageFormat,
+                        SparseBitmaskStorageFormat)
 
 
 class SparseParameter(SparseTensor):
 
     @staticmethod
-    def __new__(
-        cls,
-        shape: torch.Size,
-        dtype: torch.dtype,
-        storage_format_cls: Type[CompressedStorageFormat] = SparseBitmaskStorageFormat
-    ):
+    def __new__(cls,
+                shape: torch.Size,
+                dtype: torch.dtype,
+                storage_format_cls: Type[
+                    CompressedStorageFormat] = SparseBitmaskStorageFormat):
         assert torch.__version__ > (1,
                                     10), "SparseTensor requires PyTorch 1.11+"
 
@@ -69,4 +63,4 @@ def pack(self) -> None:
         if self.dense_data is None:
             raise ValueError("Called pack() but dense_data does not exist.")
         self.copy_(self.dense_data)
-        self.dense_data = None
+        self.dense_data = None
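
For orientation, a small usage sketch of the SparseParameter constructor whose signature appears above. The shape and dtype are arbitrary example values, and running it assumes an nm-vllm checkout with the magic_wand package installed (SparseBitmaskStorageFormat is the default storage format per the signature):

import torch

from vllm.model_executor.layers.parameters import SparseParameter

# Allocates a sparse-capable weight tensor; the storage format defaults to
# SparseBitmaskStorageFormat, as in the __new__ signature shown in the diff.
weight = SparseParameter(shape=torch.Size((4096, 4096)), dtype=torch.float16)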
1 change: 1 addition & 0 deletions vllm/model_executor/layers/sparsity/base_config.py
@@ -7,6 +7,7 @@
 from vllm.model_executor.layers.linear import LinearMethodBase
 from magic_wand import CompressedStorageFormat
 
+
 class SparsityConfig(ABC):
     """Base class for sparsity configs."""
 
13 changes: 6 additions & 7 deletions vllm/model_executor/layers/sparsity/semi_structured_sparse_w16a16.py
@@ -1,13 +1,11 @@
 import torch
 
 from typing import Any, Dict, List, Type
-from magic_wand import CompressedStorageFormat
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
-from magic_wand import (
-    CompressedStorageFormat,
-    SparseSemiStructuredStorageFormat
-)
+from magic_wand import (CompressedStorageFormat,
+                        SparseSemiStructuredStorageFormat)
 
+
 class SemiStructuredSparseW16A16Config(SparsityConfig):
     """Config class for SemiStructuredSparseW16A16."""
@@ -40,8 +38,9 @@ def get_config_filenames(cls) -> List[str]:
         return ["sparsity_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config":
+    def from_config(
+            cls, config: Dict[str, Any]) -> "SemiStructuredSparseW16A16Config":
         return cls()
 
     def get_linear_method(self) -> "SparseW16A16LinearMethod":
-        return SparseW16A16LinearMethod(self,self.get_storage_format_cls())
+        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
9 changes: 3 additions & 6 deletions vllm/model_executor/layers/sparsity/sparse_w16a16.py
@@ -1,15 +1,12 @@
 from typing import Any, Dict, List, Type
 
 import torch
-import torch.nn.functional as F
 
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 
 from .sparse_w16a16_linear_method import SparseW16A16LinearMethod
-from magic_wand import (
-    CompressedStorageFormat,
-    SparseBitmaskStorageFormat
-)
+from magic_wand import (CompressedStorageFormat, SparseBitmaskStorageFormat)
 
+
 class SparseW16A16Config(SparsityConfig):
     """Config class for SparseW16A16.
@@ -50,4 +47,4 @@ def from_config(cls, config: Dict[str, Any]) -> "SparseW16A16Config":
         return cls()
 
     def get_linear_method(self) -> "SparseW16A16LinearMethod":
-        return SparseW16A16LinearMethod(self,self.get_storage_format_cls())
+        return SparseW16A16LinearMethod(self, self.get_storage_format_cls())
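
A brief sketch of how this config class is driven, using only the methods visible in the diff above; the empty dict stands in for a parsed sparsity_config.json and is illustrative rather than a real configuration:

from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config

config = SparseW16A16Config.from_config({})  # from_config ignores its input and returns cls()
linear_method = config.get_linear_method()   # -> SparseW16A16LinearMethod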
34 changes: 14 additions & 20 deletions vllm/model_executor/layers/sparsity/sparse_w16a16_linear_method.py
@@ -6,10 +6,9 @@
 from vllm.model_executor.layers.linear import LinearMethodBase, set_weight_attrs
 from vllm.model_executor.layers.sparsity.base_config import SparsityConfig
 from vllm.model_executor.layers.parameters import SparseParameter
-from magic_wand import (
-    CompressedStorageFormat,
-    SparseSemiStructuredStorageFormat
-)
+from magic_wand import (CompressedStorageFormat,
+                        SparseSemiStructuredStorageFormat)
 
+
 class SparseW16A16LinearMethod(LinearMethodBase):
     """Linear method for Sparse W16A16.
@@ -19,24 +18,19 @@ class SparseW16A16LinearMethod(LinearMethodBase):
     """
     storage_format_cls: Type[CompressedStorageFormat] = None
 
-    def __init__(self, sparsity_config: SparsityConfig, storage_format_cls: Type[CompressedStorageFormat]):
+    def __init__(self, sparsity_config: SparsityConfig,
+                 storage_format_cls: Type[CompressedStorageFormat]):
         self.sparsity_config = sparsity_config
         self.storage_format_cls = storage_format_cls
 
-    def create_weights(
-        self,
-        input_size_per_partition: int,
-        output_size_per_partition: int,
-        input_size: int,
-        output_size: int,
-        params_dtype: torch.dtype
-    ) -> Dict[str, Any]:
-        weight = SparseParameter(
-            shape=torch.Size(
-                (output_size_per_partition, input_size_per_partition)),
-            dtype=params_dtype,
-            storage_format_cls=self.storage_format_cls
-        )
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = SparseParameter(shape=torch.Size(
+            (output_size_per_partition, input_size_per_partition)),
+                                 dtype=params_dtype,
+                                 storage_format_cls=self.storage_format_cls)
 
         set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
 
@@ -58,4 +52,4 @@ def apply_weights(
         # Standard matrix multiply
         # Uncompress to dense
         output = F.linear(x, sparse_weight.to_dense(), bias)
-        return output
+        return output
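
Finally, a sketch of the create_weights path shown above. The partition and model sizes are arbitrary example values; it assumes the same nm-vllm/magic_wand environment as the other sketches:

import torch

from vllm.model_executor.layers.sparsity.sparse_w16a16 import SparseW16A16Config
from vllm.model_executor.layers.sparsity.sparse_w16a16_linear_method import (
    SparseW16A16LinearMethod)

config = SparseW16A16Config.from_config({})
method = SparseW16A16LinearMethod(config, config.get_storage_format_cls())

# Keyword names match the create_weights signature in the diff; sizes are examples.
weights = method.create_weights(input_size_per_partition=4096,
                                output_size_per_partition=4096,
                                input_size=4096,
                                output_size=4096,
                                params_dtype=torch.float16)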
