Skip to content

Commit

Permalink
E8M0 scale for E2M1 weights. (#2767)
Browse files Browse the repository at this point in the history
### Changes

Changed the default scale data type for E2M1 weights from fp16 to fp8_e8m0.

### Reason for changes

This follows the "OCP Microscaling Formats (MX) Specification", which uses an E8M0 scale for E2M1 elements.

### Related tickets

CVS-140944

### Tests

Will extend E2M1 tests.
  • Loading branch information
andreyanufr authored Jul 3, 2024
1 parent a644a1e commit fd0e33c
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,12 @@ def transform_model(
) -> ov.Model:
for wc_params in weight_compression_parameters:
compression_config = wc_params.compression_config
scale_dtype = ov.Type.f16
if compression_config.mode == CompressWeightsMode.NF4:
compression_dtype = ov.Type.nf4
elif compression_config.mode == CompressWeightsMode.E2M1:
compression_dtype = ov.Type.f4e2m1
scale_dtype = ov.Type.f8e8m0
elif compression_config.mode == CompressWeightsMode.INT4_SYM:
compression_dtype = ov.Type.i4
elif compression_config.mode == CompressWeightsMode.INT4_ASYM:
Expand Down Expand Up @@ -190,8 +192,11 @@ def transform_model(
)

scale_const = opset.constant(
compressed_weight.scale.data, dtype=ov.Type.f16, name=f"{const_node_name}/scale"
compressed_weight.scale.data, dtype=scale_dtype, name=f"{const_node_name}/scale"
)
if scale_dtype != ov.Type.f16:
scale_const = opset.convert(scale_const, ov.Type.f16)

mul = opset.multiply(
converted_const,
scale_const,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,11 @@ def calculate_e2m1_scale(weight: Tensor, reduction_axes: ReductionAxes, max_val=
"""
scale = calculate_nf4_scale(weight, reduction_axes) / max_val

scale = fns.log2(scale)
scale = fns.ceil(scale)
scale = fns.clip(scale, -127, 127)
scale = 2**scale

return scale


Expand Down
2 changes: 2 additions & 0 deletions nncf/tensor/functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from nncf.tensor.functions.numeric import argsort as argsort
from nncf.tensor.functions.numeric import as_tensor_like as as_tensor_like
from nncf.tensor.functions.numeric import astype as astype
from nncf.tensor.functions.numeric import ceil as ceil
from nncf.tensor.functions.numeric import clip as clip
from nncf.tensor.functions.numeric import concatenate as concatenate
from nncf.tensor.functions.numeric import count_nonzero as count_nonzero
Expand All @@ -32,6 +33,7 @@
from nncf.tensor.functions.numeric import isclose as isclose
from nncf.tensor.functions.numeric import isempty as isempty
from nncf.tensor.functions.numeric import item as item
from nncf.tensor.functions.numeric import log2 as log2
from nncf.tensor.functions.numeric import logical_or as logical_or
from nncf.tensor.functions.numeric import masked_mean as masked_mean
from nncf.tensor.functions.numeric import masked_median as masked_median
Expand Down
24 changes: 24 additions & 0 deletions nncf/tensor/functions/numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,3 +881,27 @@ def from_numpy(ndarray: np.ndarray, *, backend: TensorBackend) -> Tensor:
if backend == TensorBackend.numpy:
return Tensor(ndarray)
return Tensor(get_numeric_backend_fn("from_numpy", backend)(ndarray))


@functools.singledispatch
@tensor_guard
def log2(a: Tensor) -> Tensor:
    """
    Compute the element-wise base-2 logarithm of a tensor.

    The call on ``a.data`` goes back through this single-dispatch function, so
    the implementation registered for the type of the wrapped data does the work.

    :param a: The input tensor.
    :return: A tensor holding the base-2 logarithm of every element of a.
    """
    backend_result = log2(a.data)
    return Tensor(backend_result)


@functools.singledispatch
@tensor_guard
def ceil(a: Tensor) -> Tensor:
    """
    Compute the element-wise ceiling of a tensor.

    The call on ``a.data`` goes back through this single-dispatch function, so
    the implementation registered for the type of the wrapped data does the work.

    :param a: The input tensor.
    :return: A tensor holding the ceiling of every element of a.
    """
    backend_result = ceil(a.data)
    return Tensor(backend_result)
10 changes: 10 additions & 0 deletions nncf/tensor/functions/numpy_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,3 +421,13 @@ def arange(
if dtype is not None:
dtype = DTYPE_MAP[dtype]
return np.arange(start, end, step, dtype=dtype)


@register_numpy_types(numeric.log2)
def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]:
    """NumPy implementation registered for numeric.log2: element-wise base-2 logarithm."""
    logarithm = np.log2(a)
    return logarithm


@register_numpy_types(numeric.ceil)
def _(a: Union[np.ndarray, np.generic]) -> Union[np.ndarray, np.generic]:
    """
    NumPy implementation registered for numeric.ceil: element-wise ceiling.

    :param a: A NumPy array or NumPy scalar.
    :return: The ceiling of a. np.ceil returns a scalar for scalar input, so
        the return annotation mirrors the parameter annotation, in line with
        the sibling log2 registration.
    """
    return np.ceil(a)
10 changes: 10 additions & 0 deletions nncf/tensor/functions/torch_numeric.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,3 +455,13 @@ def arange(

def from_numpy(ndarray: np.ndarray) -> torch.Tensor:
    """
    Convert a NumPy array into a torch tensor using torch.from_numpy.

    :param ndarray: The NumPy array to convert.
    :return: The torch tensor produced by torch.from_numpy for ndarray.
    """
    converted = torch.from_numpy(ndarray)
    return converted


@numeric.log2.register(torch.Tensor)
def _(a: torch.Tensor) -> torch.Tensor:
    """Torch implementation registered for numeric.log2: element-wise base-2 logarithm."""
    logarithm = torch.log2(a)
    return logarithm


@numeric.ceil.register(torch.Tensor)
def _(a: torch.Tensor) -> torch.Tensor:
    """Torch implementation registered for numeric.ceil: element-wise ceiling."""
    rounded_up = torch.ceil(a)
    return rounded_up
12 changes: 9 additions & 3 deletions tests/openvino/native/quantization/test_weights_compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -901,8 +901,14 @@ def test_mixed_precision_e2m1(mode, all_layers, ratio, ref_ids):
sensitivity_metric=mode,
dataset=dataset,
)
names = {
names_e2m1 = {
op.get_friendly_name() for op in compressed_model.get_ordered_ops() if op.get_element_type() == ov.Type.f4e2m1
}
ref_nf4_nodes = {f"weights_{i}" for i in ref_ids}
assert ref_nf4_nodes == names
ref_e2m1_nodes = {f"weights_{i}" for i in ref_ids}
assert ref_e2m1_nodes == names_e2m1

names_e8m0 = {
op.get_friendly_name() for op in compressed_model.get_ordered_ops() if op.get_element_type() == ov.Type.f8e8m0
}
ref_e8m0_nodes = {f"weights_{i}/scale" for i in ref_ids}
assert ref_e8m0_nodes == names_e8m0
39 changes: 39 additions & 0 deletions tests/shared/test_templates/template_test_nncf_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import operator
from abc import abstractmethod
from math import log2
from math import sqrt
from typing import TypeVar

Expand Down Expand Up @@ -1574,6 +1575,44 @@ def test_searchsorted_2d_error(self):
with pytest.raises(ValueError):
fns.searchsorted(tensor_a, tensor_v)

@pytest.mark.parametrize(
    "val,ref",
    (
        (1.1, 2.0),
        ([1.1, 0.9], [2.0, 1.0]),
        ([1.11, 0.91], [2.0, 1.0]),
    ),
)
def test_fn_ceil(self, val, ref):
    """Check that fns.ceil matches the reference values and keeps the Tensor type and device."""
    source = Tensor(self.to_tensor(val))
    expected = self.to_tensor(ref)

    rounded = fns.ceil(source)

    # The result must stay wrapped, match the reference, and live on the input's device.
    assert isinstance(rounded, Tensor)
    assert fns.allclose(rounded.data, expected)
    assert rounded.device == source.device

@pytest.mark.parametrize(
    "x,ref",
    [
        ([float(i) for i in range(1, 10)], [log2(float(i)) for i in range(1, 10)]),
    ],
)
def test_fn_log2(self, x, ref):
    """Check that fns.log2 matches math.log2 element-wise and keeps type, device and shape."""
    # Plain Python lists are first converted to the backend's native tensor type.
    data = self.to_tensor(x) if isinstance(x, list) else x
    source = Tensor(data)

    expected = self.to_tensor(ref)

    result = fns.log2(source)

    assert isinstance(result, Tensor)
    assert fns.allclose(result.data, expected)
    assert result.device == source.device
    assert result.shape == tuple(expected.shape)

@pytest.mark.parametrize(
"x, y, a_ref, b_ref",
(
Expand Down

0 comments on commit fd0e33c

Please sign in to comment.