Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

E8M0 scale for E2M1 weights. #2767

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ def transform_model(
) -> ov.Model:
for wc_params in weight_compression_parameters:
compression_config = wc_params.compression_config
scale_dtype = ov.Type.f16
if compression_config.mode == CompressWeightsMode.NF4:
compression_dtype = ov.Type.nf4
elif compression_config.mode == CompressWeightsMode.E2M1:
compression_dtype = ov.Type.f4e2m1
scale_dtype = ov.Type.f8e8m0
elif compression_config.mode == CompressWeightsMode.INT4_SYM:
compression_dtype = ov.Type.i4
elif compression_config.mode == CompressWeightsMode.INT4_ASYM:
Expand Down Expand Up @@ -190,8 +192,11 @@ def transform_model(
)

scale_const = opset.constant(
compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
compressed_weight.scale.data, dtype=scale_dtype, name=f"{const_node_name}/scale"
)
if scale_dtype != ov.Type.f16:
scale_const = opset.convert(scale_const, ov.Type.f16)

mul = opset.multiply(
converted_const,
scale_const,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def calculate_e2m1_scale(weight: Tensor, reduction_axes: ReductionAxes, max_val=
"""
scale = calculate_nf4_scale(weight, reduction_axes) / max_val

scale = fns.log2(scale)
scale = fns.ceil(scale)
scale = fns.clip(scale, -127, 127)
scale = 2**scale

return scale


Expand Down
2 changes: 2 additions & 0 deletions nncf/tensor/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf.tensor.functions.numeric import argsort as argsort
from nncf.tensor.functions.numeric import as_tensor_like as as_tensor_like
from nncf.tensor.functions.numeric import astype as astype
from nncf.tensor.functions.numeric import ceil as ceil
from nncf.tensor.functions.numeric import clip as clip
from nncf.tensor.functions.numeric import concatenate as concatenate
from nncf.tensor.functions.numeric import count_nonzero as count_nonzero
Expand All @@ -32,6 +33,7 @@
from nncf.tensor.functions.numeric import isclose as isclose
from nncf.tensor.functions.numeric import isempty as isempty
from nncf.tensor.functions.numeric import item as item
from nncf.tensor.functions.numeric import log2 as log2
from nncf.tensor.functions.numeric import logical_or as logical_or
from nncf.tensor.functions.numeric import masked_mean as masked_mean
from nncf.tensor.functions.numeric import masked_median as masked_median
Expand Down
24 changes: 24 additions & 0 deletions nncf/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,27 @@ def from_numpy(ndarray: np.ndarray, *, backend: TensorBackend) -> Tensor:
if backend == TensorBackend.numpy:
return Tensor(ndarray)
return Tensor(get_numeric_backend_fn("from_numpy", backend)(ndarray))


@functools.singledispatch
@tensor_guard
def log2(a: Tensor) -> Tensor:
"""
Base-2 logarithm of a.

:param a: The input tensor.
:return: A tensor containing the base-2 logarithm of each element in a.
"""
return Tensor(log2(a.data))


@functools.singledispatch
@tensor_guard
def ceil(a: Tensor) -> Tensor:
ljaljushkin marked this conversation as resolved.
Show resolved Hide resolved
"""
Return the ceiling of the input, element-wise.

:param a: Input data.
:return: An array of the same type as a, containing the ceiling values.
"""
return Tensor(ceil(a.data))
10 changes: 10 additions & 0 deletions nncf/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,13 @@ def arange(
if dtype is not None:
dtype = DTYPE_MAP[dtype]
return np.arange(start, end, step, dtype=dtype)


@register_numpy_types(numeric.log2)
def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]:
return np.log2(a)


@register_numpy_types(numeric.ceil)
def _(a: Union[np.ndarray, np.generic]) -> np.ndarray:
return np.ceil(a)
10 changes: 10 additions & 0 deletions nncf/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,13 @@ def arange(

def from_numpy(ndarray: np.ndarray) -> torch.Tensor:
return torch.from_numpy(ndarray)


@numeric.log2.register(torch.Tensor)
def _(a: torch.Tensor) -> torch.Tensor:
return torch.log2(a)


@numeric.ceil.register(torch.Tensor)
def _(a: torch.Tensor) -> torch.Tensor:
return torch.ceil(a)
12 changes: 9 additions & 3 deletions tests/openvino/native/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,14 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
sensitivity_metric=mode,
dataset=dataset,
)
names = {
names_e2m1 = {
op.get_friendly_name() for op in compressed_model.get_ordered_ops() if op.get_element_type() == ov.Type.f4e2m1
}
ref_nf4_nodes = {f"weights_{i}" for i in ref_ids}
assert ref_nf4_nodes == names
ref_e2m1_nodes = {f"weights_{i}" for i in ref_ids}
assert ref_e2m1_nodes == names_e2m1

names_e8m0 = {
op.get_friendly_name() for op in compressed_model.get_ordered_ops() if op.get_element_type() == ov.Type.f8e8m0
}
ref_e8m0_nodes = {f"weights_{i}/scale" for i in ref_ids}
assert ref_e8m0_nodes == names_e8m0
39 changes: 39 additions & 0 deletions tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import operator
from abc import abstractmethod
from math import log2
from math import sqrt
from typing import TypeVar

Expand Down Expand Up @@ -1574,6 +1575,44 @@ def test_searchsorted_2d_error(self):
with pytest.raises(ValueError):
fns.searchsorted(tensor_a, tensor_v)

@pytest.mark.parametrize(
"val,ref",
(
(1.1, 2.0),
([1.1, 0.9], [2.0, 1.0]),
([1.11, 0.91], [2.0, 1.0]),
),
)
def test_fn_ceil(self, val, ref):
tensor = Tensor(self.to_tensor(val))
ref_tensor = self.to_tensor(ref)

res = fns.ceil(tensor)

assert isinstance(res, Tensor)
assert fns.allclose(res.data, ref_tensor)
assert res.device == tensor.device

@pytest.mark.parametrize(
"x,ref",
[
(list(map(float, range(1, 10))), [log2(x) for x in map(float, range(1, 10))]),
],
)
def test_fn_log2(self, x, ref):
if isinstance(x, list):
x = self.to_tensor(x)
tensor = Tensor(x)

ref_tensor = self.to_tensor(ref)

res = fns.log2(tensor)

assert isinstance(res, Tensor)
assert fns.allclose(res.data, ref_tensor)
assert res.device == tensor.device
assert res.shape == tuple(ref_tensor.shape)

@pytest.mark.parametrize(
"x, y, a_ref, b_ref",
(
Expand Down
Loading