Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Sep 27, 2023
1 parent ad03d21 commit 24744d5
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 51 deletions.
3 changes: 1 addition & 2 deletions nncf/torch/checkpoint_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ def _clip_keys_without_collisions(keys: List[str], keys_to_ignore: List[str]) ->
def _key_clipper(key: str) -> str:
new_key = key

clip_patterns = ["module.", "|OUTPUT", "|INPUT", "_nncf."]
clip_patterns = ["module.", "|OUTPUT", "|INPUT"]
for pattern in clip_patterns:
new_key = new_key.replace(pattern, "")
return new_key
Expand Down Expand Up @@ -268,7 +268,6 @@ def _split_unified_parameters(new_key: str) -> List[str]:
return result



class KeyMatcher:
"""
Matches state_dict_to_load parameters to the model's state_dict parameters while discarding irrelevant prefixes
Expand Down
67 changes: 35 additions & 32 deletions tests/torch/test_backward_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from nncf.config import NNCFConfig
from nncf.torch import register_default_init_args
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.nncf_network import EXTERNAL_QUANTIZERS_STORAGE_PREFIX
from nncf.torch.quantization.algo import QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME
from nncf.torch.quantization.algo import QuantizerBuilderStateVersion
from tests.shared.helpers import get_cli_dict_args
Expand Down Expand Up @@ -177,39 +178,41 @@ def forward(self, x):


sd_without_nncf_bn_wrapping = {
"nncf_module.conv.weight": torch.ones([9, 3, 3, 3]),
"nncf_module.conv.bias": torch.ones([9]),
"nncf_module.conv.nncf_padding_value": torch.ones([1]),
"nncf_module.conv.pre_ops.0.op._num_bits": torch.ones([1]),
"nncf_module.conv.pre_ops.0.op.signed_tensor": torch.ones([1]),
"nncf_module.conv.pre_ops.0.op.enabled": torch.ones([1]),
"nncf_module.conv.pre_ops.0.op.scale": torch.ones([9, 1, 1, 1]),
"nncf_module.bn.weight": torch.ones([9]),
"nncf_module.bn.bias": torch.ones([9]),
"nncf_module.bn.running_mean": torch.ones([9]),
"nncf_module.bn.running_var": torch.ones([9]),
"nncf_module.bn.num_batches_tracked": torch.ones([]),
"nncf_module.conv1.weight": torch.ones([3, 9, 3, 3]),
"nncf_module.conv1.bias": torch.ones([3]),
"nncf_module.conv1.nncf_padding_value": torch.ones([1]),
"nncf_module.conv1.pre_ops.0.op._num_bits": torch.ones([1]),
"nncf_module.conv1.pre_ops.0.op.signed_tensor": torch.ones([1]),
"nncf_module.conv1.pre_ops.0.op.enabled": torch.ones([1]),
"nncf_module.conv1.pre_ops.0.op.scale": torch.ones([3, 1, 1, 1]),
"nncf_module.bn1.weight": torch.ones([3]),
"nncf_module.bn1.bias": torch.ones([3]),
"nncf_module.bn1.running_mean": torch.ones([3]),
"nncf_module.bn1.running_var": torch.ones([3]),
"nncf_module.bn1.num_batches_tracked": torch.ones([]),
"external_quantizers./nncf_model_input_0|OUTPUT._num_bits": torch.ones([1]),
"external_quantizers./nncf_model_input_0|OUTPUT.signed_tensor": torch.ones([1]),
"external_quantizers./nncf_model_input_0|OUTPUT.enabled": torch.ones([1]),
"external_quantizers./nncf_model_input_0|OUTPUT.scale": torch.ones([1]),
"conv.weight": torch.ones([9, 3, 3, 3]),
"conv.bias": torch.ones([9]),
"conv.nncf_padding_value": torch.ones([1]),
"conv.pre_ops.0.op._num_bits": torch.ones([1]),
"conv.pre_ops.0.op.signed_tensor": torch.ones([1]),
"conv.pre_ops.0.op.enabled": torch.ones([1]),
"conv.pre_ops.0.op.scale": torch.ones([9, 1, 1, 1]),
"bn.weight": torch.ones([9]),
"bn.bias": torch.ones([9]),
"bn.running_mean": torch.ones([9]),
"bn.running_var": torch.ones([9]),
"bn.num_batches_tracked": torch.ones([]),
"conv1.weight": torch.ones([3, 9, 3, 3]),
"conv1.bias": torch.ones([3]),
"conv1.nncf_padding_value": torch.ones([1]),
"conv1.pre_ops.0.op._num_bits": torch.ones([1]),
"conv1.pre_ops.0.op.signed_tensor": torch.ones([1]),
"conv1.pre_ops.0.op.enabled": torch.ones([1]),
"conv1.pre_ops.0.op.scale": torch.ones([3, 1, 1, 1]),
"bn1.weight": torch.ones([3]),
"bn1.bias": torch.ones([3]),
"bn1.running_mean": torch.ones([3]),
"bn1.running_var": torch.ones([3]),
"bn1.num_batches_tracked": torch.ones([]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT._num_bits": torch.ones([1]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.signed_tensor": torch.ones([1]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.enabled": torch.ones([1]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.scale": torch.ones([1]),
# Old bn layer names: |||||||||||
"external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT._num_bits": torch.ones([1]),
"external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.signed_tensor": torch.ones([1]),
"external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.enabled": torch.ones([1]),
"external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.scale": torch.ones([1]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT._num_bits": torch.ones([1]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.signed_tensor": torch.ones(
[1]
),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.enabled": torch.ones([1]),
f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.scale": torch.ones([1]),
}

compression_state_without_bn_wrapping = {
Expand Down
26 changes: 9 additions & 17 deletions tests/torch/test_load_model_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,14 +229,12 @@ def fn() -> Set["str"]:
.missing(['2']).matched(['1']),

# wrapping by NNCFNetwork and DataParallel & DistributedDataParallel
MatchKeyDesc(num_loaded=2).keys_to_load(['module.1', 'nncf_module.2']).model_keys(['1', '2'])
MatchKeyDesc(num_loaded=2).keys_to_load(['1', '2']).model_keys(['module.1', 'module.2'])
.all_matched(),
MatchKeyDesc(num_loaded=2).keys_to_load(['1', '2']).model_keys(['module.1', 'nncf_module.2'])
.all_matched(),
MatchKeyDesc(num_loaded=2).keys_to_load(['module.nncf_module.1', 'module.2']).model_keys(['1', 'nncf_module.2'])
MatchKeyDesc(num_loaded=2).keys_to_load(['module.1', 'module.2']).model_keys(['1', 'module.2'])
.all_matched(),
MatchKeyDesc(num_loaded=0, expects_error=True)
.keys_to_load(['module.nncf_module.1.1', 'module.2']).model_keys(['1', '2.2'])
.keys_to_load(['module.1.1', 'module.2']).model_keys(['1', '2.2'])
.all_not_matched(),

# collisions after normalization of keys
Expand All @@ -252,8 +250,8 @@ def fn() -> Set["str"]:
.model_keys(['pre_ops.0.op.1', 'pre_ops.1.op.1'])
.all_matched(),
MatchKeyDesc(num_loaded=2)
.keys_to_load(['nncf_module.pre_ops.1.op.1', 'nncf_module.pre_ops.0.op.1'])
.model_keys(['module.nncf_module.pre_ops.1.op.1', 'module.nncf_module.pre_ops.0.op.1'])
.keys_to_load(['module.pre_ops.1.op.1', 'module.pre_ops.0.op.1'])
.model_keys(['module.module.pre_ops.1.op.1', 'module.module.pre_ops.0.op.1'])
.all_matched(),
# quantization -> quantization + sparsity: op.1 was first, then
MatchKeyDesc(num_loaded=2)
Expand Down Expand Up @@ -361,9 +359,9 @@ def fn() -> Set["str"]:
.keys_to_ignore(['1'])
.skipped(['1']),
MatchKeyDesc(num_loaded=0, expects_error=True)
.keys_to_load(['module.nncf_module.1.1', '2.2']).model_keys(['module.1', 'module.2'])
.keys_to_load(['module.1.1', '2.2']).model_keys(['module.1', 'module.2'])
.keys_to_ignore(['1', '2.2'])
.skipped(['module.1', '2.2']).missing(['module.2']).unexpected(['module.nncf_module.1.1']),
.skipped(['module.1', '2.2']).missing(['module.2']).unexpected(['module.1.1']),

# optional parameter - not necessary in checkpoint can be initialized by default in the model
# can match unified FQ
Expand Down Expand Up @@ -419,9 +417,9 @@ def fn() -> Set["str"]:
.keys_to_ignore(['1'])
.skipped(['1']),
MatchKeyDesc(num_loaded=0, expects_error=True)
.keys_to_load(['module.nncf_module.1.1', '2.2']).model_keys(['module.1', 'module.2'])
.keys_to_load(['module.1.1', '2.2']).model_keys(['module.1', 'module.2'])
.keys_to_ignore(['1', '2.2'])
.skipped(['module.1', '2.2']).missing(['module.2']).unexpected(['module.nncf_module.1.1']),
.skipped(['module.1', '2.2']).missing(['module.2']).unexpected(['module.1.1']),

# optional parameter - not necessary in checkpoint can be initialized by default in the model
# can match unified FQ
Expand Down Expand Up @@ -466,12 +464,6 @@ def fn() -> Set["str"]:
.matched([EXTERNAL_QUANTIZERS_STORAGE_PREFIX + '.RELU_0|OUTPUT;RELU_2|OUTPUT;RELU_1|OUTPUT.' + OP1])
.unexpected(['module.' + EXTERNAL_QUANTIZERS_STORAGE_PREFIX + '.RELU_3.' + OP1]),

# can match keys under _nncf in the new style and the keys without _nncf
MatchKeyDesc(num_loaded=1)
.keys_to_load(['key.op'])
.model_keys(['_nncf.key.op'])
.all_matched()
.with_deprecation_warning(),

OptionalMatchKeyDesc(num_loaded=0)
.keys_to_load([])
Expand Down

0 comments on commit 24744d5

Please sign in to comment.