Remove legacy external quantizer storage names #2163

Merged
43 changes: 5 additions & 38 deletions nncf/torch/checkpoint_loading.py

@@ -154,7 +154,6 @@ class NormalizedKeys:
     def __init__(self, keys: List[str], keys_to_ignore: List[str]):
         self._unique_normalized_key_vs_orig_key_map = {}
         self.is_unified_group_detected = False
-        self.has_legacy_storage_keys = False
         unique_clipped_key_vs_orig_key_map, ignored_keys = self._clip_keys_without_collisions(keys, keys_to_ignore)
         self.ignored_orig_keys = ignored_keys
         ignored_keys = self._normalize_keys_without_collisions(unique_clipped_key_vs_orig_key_map, keys_to_ignore)
@@ -228,9 +227,8 @@ def _clip_keys_without_collisions(keys: List[str], keys_to_ignore: List[str]) ->
     @staticmethod
     def _key_clipper(key: str) -> str:
         new_key = key
-        from nncf.torch.nncf_network import LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME  # pylint: disable=cyclic-import
 
-        clip_patterns = [LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME + ".", "module.", "|OUTPUT", "|INPUT", "_nncf."]
+        clip_patterns = ["module.", "|OUTPUT", "|INPUT"]
         for pattern in clip_patterns:
             new_key = new_key.replace(pattern, "")
         return new_key
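
Reviewer note: a minimal sketch of the post-change clipping behaviour, with invented sample keys. Only DataParallel's `module.` prefix and the `|OUTPUT`/`|INPUT` port markers are stripped now; the `nncf_module.` and `_nncf.` patterns no longer need clipping because the legacy layout is no longer supported:

```python
# Standalone re-implementation of the new _key_clipper body, for illustration only.
def key_clipper(key: str) -> str:
    clip_patterns = ["module.", "|OUTPUT", "|INPUT"]
    for pattern in clip_patterns:
        key = key.replace(pattern, "")
    return key

# The DataParallel prefix and port markers are clipped; the storage prefix stays,
# since model and checkpoint keys now share it and still match after clipping.
assert key_clipper("module.conv2d.weight") == "conv2d.weight"
assert key_clipper("_nncf.external_quantizers./x_0|OUTPUT.scale") == "_nncf.external_quantizers./x_0.scale"
```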
@@ -240,11 +238,6 @@ def _key_replacer(self, key: str) -> List[str]:
 
         match = re.search("(pre_ops|post_ops)\\.(\\d+?)\\.op", key)
         new_key = new_key if not match else new_key.replace(match.group(), "operation")
-
-        new_key, did_replace = self._replace_legacy_act_quantizer_storage_name(new_key)
-        if did_replace:
-            self.has_legacy_storage_keys = True
-
         result = self._split_unified_parameters(new_key)
         if len(result) > 1:
             self.is_unified_group_detected = True
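
The surviving `pre_ops`/`post_ops` normalization can be checked in isolation: pre-op indices in saved keys are folded into the single `operation` attribute name used at load time. A quick sketch (the sample key is invented):

```python
import re

key = "conv2d.pre_ops.0.op.scale"  # hypothetical checkpoint key
match = re.search("(pre_ops|post_ops)\\.(\\d+?)\\.op", key)
new_key = key if not match else key.replace(match.group(), "operation")
assert new_key == "conv2d.operation.scale"
```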
@@ -263,30 +256,17 @@ def _split_unified_parameters(new_key: str) -> List[str]:
         Returns original key if there's no ';' and operation doesn't start with EXTERNAL_QUANTIZERS_STORAGE_NAME
         """
         result = [new_key]
-        from nncf.torch.nncf_network import CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX  # pylint: disable=cyclic-import
+        from nncf.torch.nncf_network import EXTERNAL_QUANTIZERS_STORAGE_PREFIX  # pylint: disable=cyclic-import
 
-        if ";" in new_key and new_key.startswith(CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX):
+        if ";" in new_key and new_key.startswith(EXTERNAL_QUANTIZERS_STORAGE_PREFIX):
             group_of_keys = new_key.split(";")
             last_key = group_of_keys[-1]
             common_op = last_key.split(".")[-1]
-            result = [group_of_keys[0] + "." + common_op, CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + last_key]
+            result = [group_of_keys[0] + "." + common_op, EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + last_key]
             for key in group_of_keys[1:-1]:
-                result.append(CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + key + "." + common_op)
+                result.append(EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + key + "." + common_op)
         return result
 
-    @staticmethod
-    def _replace_legacy_act_quantizer_storage_name(checkpoint_key: str) -> Tuple[str, bool]:
-        did_replace = False
-        splits = checkpoint_key.split(".")
-        from nncf.torch.nncf_network import CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX  # pylint: disable=cyclic-import
-        from nncf.torch.nncf_network import LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX  # pylint: disable=cyclic-import
-
-        if splits[0] == LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX:
-            did_replace = True
-            splits[0] = CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
-        reconstructed_key = ".".join(splits)
-        return reconstructed_key, did_replace
-
 
 class KeyMatcher:
     """
@@ -341,19 +321,6 @@ def run(self) -> Dict[str, torch.Tensor]:
                 "names. The newly exported checkpoints will be adjusted to the new format."
             )
 
-        if normalized_keys_to_load.has_legacy_storage_keys:
-            from nncf.torch.nncf_network import CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
-            from nncf.torch.nncf_network import LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
-
-            warning_deprecated(
-                f"Legacy NNCF-enabled .pth checkpoint has been loaded! "
-                f"The {LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX} storage key is replaced with "
-                f"{CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX} in newer versions of NNCF, and support "
-                f"for the legacy storage key will be dropped in a future release. "
-                f"This checkpoint will be loaded; update your checkpoint file by saving this model's"
-                f"checkpoint file again."
-            )
-
         if normalized_model_keys.is_unified_group_detected and not normalized_keys_to_load.is_unified_group_detected:
             nncf_logger.warning(
                 "Unified parameters are detected in the compressed model, but all parameters are independent "
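
Since the automatic `external_quantizers` → `_nncf.external_quantizers` translation (and its deprecation warning) is removed here, holders of pre-rename checkpoints would have to migrate keys themselves before calling `load_state`. A hypothetical one-off migration helper, mirroring the deleted `_replace_legacy_act_quantizer_storage_name` logic — the helper and file name are illustrative, not part of NNCF:

```python
LEGACY_PREFIX = "external_quantizers"         # legacy top-level storage key
CURRENT_PREFIX = "_nncf.external_quantizers"  # current storage key

def migrate_legacy_quantizer_keys(state_dict: dict) -> dict:
    """Rewrite legacy activation-quantizer storage keys to the current name."""
    migrated = {}
    for key, value in state_dict.items():
        splits = key.split(".")
        if splits[0] == LEGACY_PREFIX:
            key = ".".join([CURRENT_PREFIX] + splits[1:])
        migrated[key] = value
    return migrated

# sd = torch.load("legacy_checkpoint.pth")["state_dict"]  # illustrative path
# load_state(compressed_model, migrate_legacy_quantizer_keys(sd), is_resume=True)
```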
5 changes: 1 addition & 4 deletions nncf/torch/nncf_network.py

@@ -65,11 +65,8 @@
 from nncf.torch.utils import get_model_device
 from nncf.torch.utils import training_mode_switcher
 
-LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME = "nncf_module"
-LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "external_quantizers"
-
 EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
-CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME
+EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME
 
 Module = TypeVar("Module", bound=nn.Module)
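
The rename leaves a single prefix constant; its concrete value, and the shape of the state-dict keys it produces, follow directly from the two remaining lines (the sample key below also appears in the updated test file later in this diff):

```python
EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME

assert EXTERNAL_QUANTIZERS_STORAGE_PREFIX == "_nncf.external_quantizers"

sample_key = f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.scale"
assert sample_key == "_nncf.external_quantizers./nncf_model_input_0|OUTPUT.scale"
```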
107 changes: 35 additions & 72 deletions tests/torch/test_backward_compat.py

@@ -21,13 +21,11 @@
 from examples.torch.common.execution import prepare_model_for_execution
 from examples.torch.common.model_loader import load_model
 from nncf.api.compression import CompressionStage
-from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME
 from nncf.common.logging.logger import NNCFDeprecationWarning
 from nncf.config import NNCFConfig
 from nncf.torch import register_default_init_args
 from nncf.torch.checkpoint_loading import load_state
-from nncf.torch.nncf_network import LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
-from nncf.torch.nncf_network import LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME
+from nncf.torch.nncf_network import EXTERNAL_QUANTIZERS_STORAGE_PREFIX
 from nncf.torch.quantization.algo import QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME
 from nncf.torch.quantization.algo import QuantizerBuilderStateVersion
 from tests.shared.helpers import get_cli_dict_args
@@ -37,7 +35,6 @@
 from tests.torch.helpers import create_compressed_model_and_algo_for_test
 from tests.torch.helpers import create_ones_mock_dataloader
 from tests.torch.helpers import register_bn_adaptation_init_args
-from tests.torch.quantization.test_range_init import SingleConv2dIdentityModel
 from tests.torch.test_compressed_graph import get_basic_quantization_config
 from tests.torch.test_sanity_sample import create_command_line
 
@@ -164,42 +161,6 @@ def test_loaded_model_evals_according_to_saved_acc(_params, tmp_path, dataset_dir):
     assert torch.load(checkpoint_path)["best_acc1"] == pytest.approx(metrics["Accuracy"], abs=1e-2)
 
 
-old_style_sd = {
-    f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.weight": torch.ones([3, 3, 1, 1]),
-    f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.bias": torch.ones([3]),
-    f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op._num_bits": 8 * torch.ones([1], dtype=torch.int32),
-    f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op.signed_tensor": torch.ones([1], dtype=torch.int32),
-    f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op.enabled": torch.ones([1], dtype=torch.int32),
-    f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op.scale": torch.ones([3, 1, 1, 1]),
-    f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT._num_bits": 8
-    * torch.ones([1], dtype=torch.int32),
-    f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT.signed_tensor": torch.zeros(
-        [1], dtype=torch.int32
-    ),
-    f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT.enabled": torch.ones(
-        [1], dtype=torch.int32
-    ),
-    f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT.scale": torch.ones([1]),
-}
-
-
-def test_renamed_activation_quantizer_storage_in_state_dict():
-    model = SingleConv2dIdentityModel()
-    config = get_basic_quantization_config(input_info={"sample_size": [1, 3, 100, 100]})
-    register_bn_adaptation_init_args(config)
-    compressed_model, _ = create_compressed_model_and_algo_for_test(model, config)
-
-    with pytest.warns(NNCFDeprecationWarning):
-        _ = load_state(compressed_model, old_style_sd, is_resume=True)
-
-
-def test_can_compress_with_config_and_resume_of_old_checkpoint():
-    model = SingleConv2dIdentityModel()
-    config = get_basic_quantization_config(input_info={"sample_size": [1, 3, 100, 100]})
-    register_bn_adaptation_init_args(config)
-    create_compressed_model_and_algo_for_test(model, config, compression_state=old_style_sd)
-
-
 # BN Wrapping backward compatibility test
 
 
@@ -217,39 +178,41 @@ def forward(self, x):
 
 
 sd_without_nncf_bn_wrapping = {
-    "nncf_module.conv.weight": torch.ones([9, 3, 3, 3]),
-    "nncf_module.conv.bias": torch.ones([9]),
-    "nncf_module.conv.nncf_padding_value": torch.ones([1]),
-    "nncf_module.conv.pre_ops.0.op._num_bits": torch.ones([1]),
-    "nncf_module.conv.pre_ops.0.op.signed_tensor": torch.ones([1]),
-    "nncf_module.conv.pre_ops.0.op.enabled": torch.ones([1]),
-    "nncf_module.conv.pre_ops.0.op.scale": torch.ones([9, 1, 1, 1]),
-    "nncf_module.bn.weight": torch.ones([9]),
-    "nncf_module.bn.bias": torch.ones([9]),
-    "nncf_module.bn.running_mean": torch.ones([9]),
-    "nncf_module.bn.running_var": torch.ones([9]),
-    "nncf_module.bn.num_batches_tracked": torch.ones([]),
-    "nncf_module.conv1.weight": torch.ones([3, 9, 3, 3]),
-    "nncf_module.conv1.bias": torch.ones([3]),
-    "nncf_module.conv1.nncf_padding_value": torch.ones([1]),
-    "nncf_module.conv1.pre_ops.0.op._num_bits": torch.ones([1]),
-    "nncf_module.conv1.pre_ops.0.op.signed_tensor": torch.ones([1]),
-    "nncf_module.conv1.pre_ops.0.op.enabled": torch.ones([1]),
-    "nncf_module.conv1.pre_ops.0.op.scale": torch.ones([3, 1, 1, 1]),
-    "nncf_module.bn1.weight": torch.ones([3]),
-    "nncf_module.bn1.bias": torch.ones([3]),
-    "nncf_module.bn1.running_mean": torch.ones([3]),
-    "nncf_module.bn1.running_var": torch.ones([3]),
-    "nncf_module.bn1.num_batches_tracked": torch.ones([]),
-    "external_quantizers./nncf_model_input_0|OUTPUT._num_bits": torch.ones([1]),
-    "external_quantizers./nncf_model_input_0|OUTPUT.signed_tensor": torch.ones([1]),
-    "external_quantizers./nncf_model_input_0|OUTPUT.enabled": torch.ones([1]),
-    "external_quantizers./nncf_model_input_0|OUTPUT.scale": torch.ones([1]),
+    "conv.weight": torch.ones([9, 3, 3, 3]),
+    "conv.bias": torch.ones([9]),
+    "conv.nncf_padding_value": torch.ones([1]),
+    "conv.pre_ops.0.op._num_bits": torch.ones([1]),
+    "conv.pre_ops.0.op.signed_tensor": torch.ones([1]),
+    "conv.pre_ops.0.op.enabled": torch.ones([1]),
+    "conv.pre_ops.0.op.scale": torch.ones([9, 1, 1, 1]),
+    "bn.weight": torch.ones([9]),
+    "bn.bias": torch.ones([9]),
+    "bn.running_mean": torch.ones([9]),
+    "bn.running_var": torch.ones([9]),
+    "bn.num_batches_tracked": torch.ones([]),
+    "conv1.weight": torch.ones([3, 9, 3, 3]),
+    "conv1.bias": torch.ones([3]),
+    "conv1.nncf_padding_value": torch.ones([1]),
+    "conv1.pre_ops.0.op._num_bits": torch.ones([1]),
+    "conv1.pre_ops.0.op.signed_tensor": torch.ones([1]),
+    "conv1.pre_ops.0.op.enabled": torch.ones([1]),
+    "conv1.pre_ops.0.op.scale": torch.ones([3, 1, 1, 1]),
+    "bn1.weight": torch.ones([3]),
+    "bn1.bias": torch.ones([3]),
+    "bn1.running_mean": torch.ones([3]),
+    "bn1.running_var": torch.ones([3]),
+    "bn1.num_batches_tracked": torch.ones([]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT._num_bits": torch.ones([1]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.signed_tensor": torch.ones([1]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.enabled": torch.ones([1]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./nncf_model_input_0|OUTPUT.scale": torch.ones([1]),
     # Old bn layer names:            |||||||||||
-    "external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT._num_bits": torch.ones([1]),
-    "external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.signed_tensor": torch.ones([1]),
-    "external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.enabled": torch.ones([1]),
-    "external_quantizers.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.scale": torch.ones([1]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT._num_bits": torch.ones([1]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.signed_tensor": torch.ones(
+        [1]
+    ),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.enabled": torch.ones([1]),
+    f"{EXTERNAL_QUANTIZERS_STORAGE_PREFIX}.ConvBNLayer/BatchNorm2d[bn]/batch_norm_0|OUTPUT.scale": torch.ones([1]),
 }
 
 compression_state_without_bn_wrapping = {
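
A sketch of how the updated fixture would typically be exercised; `load_state` and the test helpers are imported at the top of this module, but the invocation below (including the `sample_size`) is illustrative rather than quoted from the test body:

```python
# Hypothetical usage: keys carry no "nncf_module." prefix anymore, and the
# activation quantizers live under "_nncf.external_quantizers".
model = ConvBNLayer()  # the test model whose forward() appears above
config = get_basic_quantization_config(input_info={"sample_size": [1, 3, 100, 100]})
register_bn_adaptation_init_args(config)
compressed_model, _ = create_compressed_model_and_algo_for_test(model, config)

loaded_count = load_state(compressed_model, sd_without_nncf_bn_wrapping, is_resume=True)
```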