From 444cb64c41cdf8ed28a65c441c41491af6a1b558 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 23 Jul 2024 12:06:12 +0200 Subject: [PATCH] Comments --- nncf/experimental/torch/fx/quantization/quantize_model.py | 7 ------- nncf/experimental/torch/fx/transformations.py | 6 ++++++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index 8a5348754ef..6dcd2bc154d 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -18,7 +18,6 @@ from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat from torch.ao.quantization.pt2e.utils import _disallow_eval_train -from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ from torch.fx import GraphModule from torch.fx.passes.infra.pass_manager import PassManager @@ -91,12 +90,6 @@ def quantize_impl( advanced_parameters=advanced_parameters, ) - # BatchNorm operations have 3 output ports, - # to make it easier for alorithms to work - # with the target graph BatchNorm operations - # are being fused - _fuse_conv_bn_(copied_model) - # To make it easier for bias correction algorithms, # biases are being separated by the followng calls. 
apply_quantization_transformations(copied_model) diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 0cb730ca4d0..985e28f66db 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -14,6 +14,7 @@ import torch import torch.fx from torch.ao.quantization.fx.utils import create_getattr_from_value +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node from torch.quantization.fake_quantize import FakeQuantize @@ -239,6 +240,11 @@ def apply_quantization_transformations(model: torch.fx.Graph): Applies quantization transformations to the model. :param model: Model to apply transformations to. """ + # BatchNorm operations have 3 output ports, + # to make it easier for algorithms to work + # with the target graph BatchNorm operations + # are being fused + _fuse_conv_bn_(model) separate_conv_and_bias(model) separate_linear_and_bias(model)