Fix do_copy in strip function (#2296)
### Changes

Fix `do_copy` option for `strip` function

### Reason for changes

The `strip` function did not make a copy of a model quantized by PTQ, even when `do_copy=True` was requested.
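For context, a minimal usage sketch of the behavior this fix restores, modeled on the updated tests in this commit (`BasicConvTestModel` and its `INPUT_SIZE` come from the repository's test helpers; this snippet is an illustration, not part of the change itself):

```python
import torch

import nncf
from tests.torch.helpers import BasicConvTestModel  # test helper used by the updated tests

# Quantize a small test model with PTQ; a single calibration sample is enough here.
model = BasicConvTestModel()
calibration_dataset = nncf.Dataset([torch.ones(model.INPUT_SIZE)])
quantized_model = nncf.quantize(model, calibration_dataset, subset_size=1)

# do_copy=True: strip a deep copy and leave the quantized model untouched.
stripped_copy = nncf.strip(quantized_model, do_copy=True)
assert stripped_copy is not quantized_model

# do_copy=False: strip the quantized model in place and return the same object.
stripped_inplace = nncf.strip(quantized_model, do_copy=False)
assert stripped_inplace is quantized_model
```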
AlexanderDokuchaev authored Nov 29, 2023
1 parent 3706639 commit e17d268
Showing 3 changed files with 33 additions and 89 deletions.
3 changes: 2 additions & 1 deletion nncf/torch/nncf_network.py
@@ -737,7 +737,8 @@ def strip(self, do_copy: bool = True) -> "NNCFNetwork":
             # PTQ algorithm does not set compressed controller
             from nncf.torch.quantization.strip import strip_quantized_model
 
-            return strip_quantized_model(self._model_ref)
+            model = deepcopy(self._model_ref) if do_copy else self._model_ref
+            return strip_quantized_model(model)
         return self.compression_controller.strip(do_copy)


80 changes: 16 additions & 64 deletions tests/torch/ptq/test_strip.py
@@ -10,78 +10,30 @@
 # limitations under the License.
 
 import pytest
+import torch
 from torch.quantization import FakeQuantize
 
 import nncf
-from nncf.data import Dataset
-from nncf.parameters import TargetDevice
-from nncf.quantization import QuantizationPreset
-from nncf.torch.nncf_network import ExtraCompressionModuleType
-from nncf.torch.nncf_network import NNCFNetwork
-from nncf.torch.quantization.layers import BaseQuantizer
-from tests.torch.helpers import LeNet
-from tests.torch.helpers import RandomDatasetMock
-
-
-def check_fq(model: NNCFNetwork, striped: bool):
-    compression_module_type = ExtraCompressionModuleType.EXTERNAL_QUANTIZER
-    if model.nncf.is_compression_module_registered(compression_module_type):
-        external_quantizers = model.nncf.get_compression_modules_by_type(compression_module_type)
-        for key in list(external_quantizers.keys()):
-            op = external_quantizers[key]
-            if striped:
-                assert isinstance(op, FakeQuantize)
-            else:
-                assert isinstance(op, BaseQuantizer)
-
-    for node in model.nncf.get_original_graph().get_all_nodes():
-        if node.node_type in ["nncf_model_input", "nncf_model_output"]:
-            continue
-
-        nncf_module = model.nncf.get_containing_module(node.node_name)
-
-        if hasattr(nncf_module, "pre_ops"):
-            for key in list(nncf_module.pre_ops.keys()):
-                op = nncf_module.get_pre_op(key)
-                if striped:
-                    assert isinstance(op.op, FakeQuantize)
-                else:
-                    assert isinstance(op.op, BaseQuantizer)
-
-        if hasattr(nncf_module, "post_ops"):
-            for key in list(nncf_module.post_ops.keys()):
-                op = nncf_module.get_post_ops(key)
-                if striped:
-                    assert isinstance(op.op, FakeQuantize)
-                else:
-                    assert isinstance(op.op, BaseQuantizer)
+from tests.torch.helpers import BasicConvTestModel
 
 
 @pytest.mark.parametrize("strip_type", ("nncf", "torch", "nncf_interfere"))
-def test_nncf_strip_api(strip_type):
-    model = LeNet()
-    input_size = [1, 1, 32, 32]
-
-    def transform_fn(data_item):
-        images, _ = data_item
-        return images
-
-    dataset = Dataset(RandomDatasetMock(input_size), transform_fn)
-
-    quantized_model = nncf.quantize(
-        model=model,
-        calibration_dataset=dataset,
-        preset=QuantizationPreset.MIXED,
-        target_device=TargetDevice.CPU,
-        subset_size=1,
-        fast_bias_correction=True,
-    )
+@pytest.mark.parametrize("do_copy", (True, False), ids=["copy", "inplace"])
+def test_nncf_strip_api(strip_type, do_copy):
+    model = BasicConvTestModel()
+    quantized_model = nncf.quantize(model, nncf.Dataset([torch.ones(model.INPUT_SIZE)]), subset_size=1)
 
     if strip_type == "nncf":
-        strip_model = nncf.strip(quantized_model)
+        strip_model = nncf.strip(quantized_model, do_copy)
     elif strip_type == "torch":
-        strip_model = nncf.torch.strip(quantized_model)
+        strip_model = nncf.torch.strip(quantized_model, do_copy)
     elif strip_type == "nncf_interfere":
-        strip_model = quantized_model.nncf.strip()
+        strip_model = quantized_model.nncf.strip(do_copy)
 
+    if do_copy:
+        assert id(strip_model) != id(quantized_model)
+    else:
+        assert id(strip_model) == id(quantized_model)
+
-    check_fq(quantized_model, True if strip_model is None else strip_model)
+    assert isinstance(strip_model.conv.get_pre_op("0").op, FakeQuantize)
+    assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)
39 changes: 15 additions & 24 deletions tests/torch/quantization/test_strip.py
@@ -300,36 +300,27 @@ def test_strip_quantization(mode, overflow_fix, tmp_path):
     torch.onnx.export(inference_model, input_tensor, f"{tmp_path}/model.onnx")
 
 
-@pytest.mark.parametrize("do_copy", (True, False))
-def test_do_copy(do_copy):
-    model = BasicConvTestModel()
-    config = _get_config_for_algo(model.INPUT_SIZE)
-    register_bn_adaptation_init_args(config)
-    compressed_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
-
-    inference_model = compression_ctrl.strip(do_copy=do_copy)
-
-    if do_copy:
-        assert id(inference_model) != id(compressed_model)
-    else:
-        assert id(inference_model) == id(compressed_model)
-
-    assert id(compressed_model) == id(compression_ctrl.model)
-
-
 @pytest.mark.parametrize("strip_type", ("nncf", "torch", "nncf_interfere"))
-def test_nncf_strip_api(strip_type):
+@pytest.mark.parametrize("do_copy", (True, False), ids=["copy", "inplace"])
+def test_nncf_strip_api(strip_type, do_copy):
     model = BasicConvTestModel()
     config = _get_config_for_algo(model.INPUT_SIZE)
 
-    quantized_model, _ = create_compressed_model_and_algo_for_test(model, config)
+    quantized_model, compression_ctrl = create_compressed_model_and_algo_for_test(model, config)
 
     if strip_type == "nncf":
-        strip_model = nncf.strip(quantized_model)
+        strip_model = nncf.strip(quantized_model, do_copy)
     elif strip_type == "torch":
-        strip_model = nncf.torch.strip(quantized_model)
+        strip_model = nncf.torch.strip(quantized_model, do_copy)
     elif strip_type == "nncf_interfere":
-        strip_model = quantized_model.nncf.strip()
+        strip_model = quantized_model.nncf.strip(do_copy)
 
+    if do_copy:
+        assert id(strip_model) != id(quantized_model)
+    else:
+        assert id(strip_model) == id(quantized_model)
+
+    assert id(quantized_model) == id(compression_ctrl.model)
+
-    fq = strip_model.conv.get_pre_op("0").op
-    assert isinstance(fq, FakeQuantize)
+    assert isinstance(strip_model.conv.get_pre_op("0").op, FakeQuantize)
+    assert isinstance(strip_model.nncf.external_quantizers["/nncf_model_input_0|OUTPUT"], FakeQuantize)
