Commit

update wc_reference_data.yaml
l-bat committed Jun 10, 2024
1 parent a651234 commit fceef1b
Showing 3 changed files with 22 additions and 9 deletions.
7 changes: 4 additions & 3 deletions nncf/torch/quantization/layers.py
@@ -46,7 +46,8 @@
 from nncf.torch.quantization.quantize_functions import ExportQuantizeToONNXQuantDequant
 from nncf.torch.quantization.quantize_functions import TuneRange
 from nncf.torch.quantization.quantize_functions import asymmetric_quantize
-from nncf.torch.quantization.quantize_functions import decompress
+from nncf.torch.quantization.quantize_functions import decompress_asymmetric
+from nncf.torch.quantization.quantize_functions import decompress_symmetric
 from nncf.torch.quantization.quantize_functions import get_scale_zp_from_input_low_input_high
 from nncf.torch.quantization.quantize_functions import symmetric_quantize
 from nncf.torch.return_types import maybe_get_values_from_torch_return_type
@@ -1061,7 +1062,7 @@ def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype:
         self.result_dtype = result_dtype

     def forward(self, x):
-        result = decompress(x, self._scale, self._zero_point)
+        result = decompress_asymmetric(x, self._scale, self._zero_point)
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result

@@ -1081,6 +1082,6 @@ def __init__(self, scale: torch.Tensor, result_dtype: torch.dtype = None):
         self.result_dtype = result_dtype

     def forward(self, x):
-        result = decompress(x, self._scale)
+        result = decompress_symmetric(x, self._scale)
         result = result.type(dtype=self.result_dtype) if self.result_dtype is not None else result
         return result
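For readers skimming the diff: the two forward() methods above appear to belong to weight-decompressor modules (the class names are not visible in these hunks) that hold a scale, and for the asymmetric case a zero point, and rebuild a floating-point weight from a stored integer tensor on each call. The sketch below shows that pattern in isolation; the ToyAsymmetricDecompressor class, the buffer names, and the tensor shapes are illustrative assumptions, not the classes edited in layers.py.

import torch


def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    # Mirrors the new helper: cast to the scale dtype, subtract the zero point, rescale.
    input = input.type(dtype=scale.dtype)
    return (input - zero_point) * scale


class ToyAsymmetricDecompressor(torch.nn.Module):
    # Illustrative stand-in for an asymmetric weights-decompressor module.
    def __init__(self, scale: torch.Tensor, zero_point: torch.Tensor, result_dtype: torch.dtype = None):
        super().__init__()
        self.register_buffer("_scale", scale)
        self.register_buffer("_zero_point", zero_point)
        self.result_dtype = result_dtype

    def forward(self, x):
        result = decompress_asymmetric(x, self._scale, self._zero_point)
        return result.type(dtype=self.result_dtype) if self.result_dtype is not None else result


# Rebuild a float weight from an int8 tensor with per-output-channel parameters.
w_int8 = torch.randint(-128, 127, (16, 32), dtype=torch.int8)
scale = torch.rand(16, 1) * 0.01
zero_point = torch.randint(-8, 8, (16, 1)).float()
decompressor = ToyAsymmetricDecompressor(scale, zero_point, result_dtype=torch.float32)
w_fp32 = decompressor(w_int8)  # shape (16, 32), dtype torch.float32

The effect of the split in this commit is visible here: the asymmetric path always receives a zero point and the symmetric path never does, so neither needs the optional-argument branch the old shared decompress carried.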
22 changes: 17 additions & 5 deletions nncf/torch/quantization/quantize_functions.py
@@ -8,7 +8,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Optional
+from typing import Any

 import torch

@@ -249,17 +249,29 @@ def backward(ctx: Any, *grad_outputs: Any) -> Any:


 @register_operator()
-def decompress(input: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor] = None) -> torch.Tensor:
+def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
     """
-    Decompress the input tensor.
+    Decompress the asymmetrically quantized input tensor.

     :param input: An input tensor
     :param scale: A scale tensor
     :param zero_point: A zero point tensor
     :return: The decompressed tensor
     """
     input = input.type(dtype=scale.dtype)
-    if zero_point is not None:
-        input -= zero_point
+    decompressed_input = (input - zero_point) * scale
+    return decompressed_input
+
+
+@register_operator()
+def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    """
+    Decompress the symmetrically quantized input tensor.
+
+    :param input: An input tensor
+    :param scale: A scale tensor
+    :return: The decompressed tensor
+    """
+    input = input.type(dtype=scale.dtype)
     decompressed_input = input * scale
     return decompressed_input
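To make the two formulas concrete, here is a small self-contained round-trip check (an assumed setup, not part of the commit): a float tensor is quantized to int8 per row, then rebuilt with the asymmetric rule (x - zero_point) * scale and the symmetric rule x * scale, mirroring the helpers added above. The way the quantization parameters are derived below (min/max-based scale and zero point, abs-max-based scale) is an illustrative choice, not NNCF's exact scheme.

import torch


def decompress_asymmetric(input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor) -> torch.Tensor:
    input = input.type(dtype=scale.dtype)
    return (input - zero_point) * scale


def decompress_symmetric(input: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    input = input.type(dtype=scale.dtype)
    return input * scale


w = torch.randn(8, 16)

# Asymmetric int8 quantization: map each row's [min, max] range onto [-128, 127].
low = w.min(dim=1, keepdim=True).values
high = w.max(dim=1, keepdim=True).values
scale_a = (high - low).clamp_min(1e-8) / 255.0
zero_point_a = torch.round(-128 - low / scale_a)
q_a = torch.clamp(torch.round(w / scale_a + zero_point_a), -128, 127).to(torch.int8)
w_a = decompress_asymmetric(q_a, scale_a, zero_point_a)

# Symmetric int8 quantization: per-row abs-max scale, no zero point.
scale_s = w.abs().max(dim=1, keepdim=True).values.clamp_min(1e-8) / 127.0
q_s = torch.clamp(torch.round(w / scale_s), -128, 127).to(torch.int8)
w_s = decompress_symmetric(q_s, scale_s)

# Both reconstructions should match the original to within about one quantization step.
print((w - w_a).abs().max(), (w - w_s).abs().max())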
2 changes: 1 addition & 1 deletion tests/post_training/data/wc_reference_data.yaml
@@ -24,5 +24,5 @@ tinyllama_int8_data_free_backend_TORCH:
   num_int8: 312
 tinyllama_data_aware_gptq_backend_OV:
   metric_value: 0.83387
-  num_int4: 188
+  num_int4: 94
   num_int8: 124
