Skip to content

Commit

Permalink
Remove legacy external quantizer storage names
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Sep 26, 2023
1 parent 36f45a2 commit ad03d21
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 207 deletions.
42 changes: 5 additions & 37 deletions nncf/torch/checkpoint_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ class NormalizedKeys:
def __init__(self, keys: List[str], keys_to_ignore: List[str]):
self._unique_normalized_key_vs_orig_key_map = {}
self.is_unified_group_detected = False
self.has_legacy_storage_keys = False
unique_clipped_key_vs_orig_key_map, ignored_keys = self._clip_keys_without_collisions(keys, keys_to_ignore)
self.ignored_orig_keys = ignored_keys
ignored_keys = self._normalize_keys_without_collisions(unique_clipped_key_vs_orig_key_map, keys_to_ignore)
Expand Down Expand Up @@ -228,9 +227,8 @@ def _clip_keys_without_collisions(keys: List[str], keys_to_ignore: List[str]) ->
@staticmethod
def _key_clipper(key: str) -> str:
new_key = key
from nncf.torch.nncf_network import LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME # pylint: disable=cyclic-import

clip_patterns = [LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME + ".", "module.", "|OUTPUT", "|INPUT", "_nncf."]
clip_patterns = ["module.", "|OUTPUT", "|INPUT", "_nncf."]
for pattern in clip_patterns:
new_key = new_key.replace(pattern, "")
return new_key
Expand All @@ -240,11 +238,6 @@ def _key_replacer(self, key: str) -> List[str]:

match = re.search("(pre_ops|post_ops)\\.(\\d+?)\\.op", key)
new_key = new_key if not match else new_key.replace(match.group(), "operation")

new_key, did_replace = self._replace_legacy_act_quantizer_storage_name(new_key)
if did_replace:
self.has_legacy_storage_keys = True

result = self._split_unified_parameters(new_key)
if len(result) > 1:
self.is_unified_group_detected = True
Expand All @@ -263,29 +256,17 @@ def _split_unified_parameters(new_key: str) -> List[str]:
Returns original key if there's no ';' and operation doesn't start with EXTERNAL_QUANTIZERS_STORAGE_NAME
"""
result = [new_key]
from nncf.torch.nncf_network import CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX # pylint: disable=cyclic-import
from nncf.torch.nncf_network import EXTERNAL_QUANTIZERS_STORAGE_PREFIX # pylint: disable=cyclic-import

if ";" in new_key and new_key.startswith(CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX):
if ";" in new_key and new_key.startswith(EXTERNAL_QUANTIZERS_STORAGE_PREFIX):
group_of_keys = new_key.split(";")
last_key = group_of_keys[-1]
common_op = last_key.split(".")[-1]
result = [group_of_keys[0] + "." + common_op, CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + last_key]
result = [group_of_keys[0] + "." + common_op, EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + last_key]
for key in group_of_keys[1:-1]:
result.append(CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + key + "." + common_op)
result.append(EXTERNAL_QUANTIZERS_STORAGE_PREFIX + "." + key + "." + common_op)
return result

@staticmethod
def _replace_legacy_act_quantizer_storage_name(checkpoint_key: str) -> Tuple[str, bool]:
did_replace = False
splits = checkpoint_key.split(".")
from nncf.torch.nncf_network import CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX # pylint: disable=cyclic-import
from nncf.torch.nncf_network import LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX # pylint: disable=cyclic-import

if splits[0] == LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX:
did_replace = True
splits[0] = CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
reconstructed_key = ".".join(splits)
return reconstructed_key, did_replace


class KeyMatcher:
Expand Down Expand Up @@ -341,19 +322,6 @@ def run(self) -> Dict[str, torch.Tensor]:
"names. The newly exported checkpoints will be adjusted to the new format."
)

if normalized_keys_to_load.has_legacy_storage_keys:
from nncf.torch.nncf_network import CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
from nncf.torch.nncf_network import LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX

warning_deprecated(
f"Legacy NNCF-enabled .pth checkpoint has been loaded! "
f"The {LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX} storage key is replaced with "
f"{CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX} in newer versions of NNCF, and support "
f"for the legacy storage key will be dropped in a future release. "
f"This checkpoint will be loaded; update your checkpoint file by saving this model's"
f"checkpoint file again."
)

if normalized_model_keys.is_unified_group_detected and not normalized_keys_to_load.is_unified_group_detected:
nncf_logger.warning(
"Unified parameters are detected in the compressed model, but all parameters are independent "
Expand Down
5 changes: 1 addition & 4 deletions nncf/torch/nncf_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,8 @@
from nncf.torch.utils import get_model_device
from nncf.torch.utils import training_mode_switcher

LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME = "nncf_module"
LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "external_quantizers"

EXTERNAL_QUANTIZERS_STORAGE_NAME = "external_quantizers"
CURRENT_EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME
EXTERNAL_QUANTIZERS_STORAGE_PREFIX = "_nncf." + EXTERNAL_QUANTIZERS_STORAGE_NAME

Module = TypeVar("Module", bound=nn.Module)

Expand Down
40 changes: 0 additions & 40 deletions tests/torch/test_backward_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,10 @@
from examples.torch.common.execution import prepare_model_for_execution
from examples.torch.common.model_loader import load_model
from nncf.api.compression import CompressionStage
from nncf.common.graph.definitions import MODEL_INPUT_OP_NAME
from nncf.common.logging.logger import NNCFDeprecationWarning
from nncf.config import NNCFConfig
from nncf.torch import register_default_init_args
from nncf.torch.checkpoint_loading import load_state
from nncf.torch.nncf_network import LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX
from nncf.torch.nncf_network import LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME
from nncf.torch.quantization.algo import QUANTIZER_BUILDER_STATE_VERSION_SAVE_NAME
from nncf.torch.quantization.algo import QuantizerBuilderStateVersion
from tests.shared.helpers import get_cli_dict_args
Expand All @@ -37,7 +34,6 @@
from tests.torch.helpers import create_compressed_model_and_algo_for_test
from tests.torch.helpers import create_ones_mock_dataloader
from tests.torch.helpers import register_bn_adaptation_init_args
from tests.torch.quantization.test_range_init import SingleConv2dIdentityModel
from tests.torch.test_compressed_graph import get_basic_quantization_config
from tests.torch.test_sanity_sample import create_command_line

Expand Down Expand Up @@ -164,42 +160,6 @@ def test_loaded_model_evals_according_to_saved_acc(_params, tmp_path, dataset_di
assert torch.load(checkpoint_path)["best_acc1"] == pytest.approx(metrics["Accuracy"], abs=1e-2)


old_style_sd = {
f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.weight": torch.ones([3, 3, 1, 1]),
f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.bias": torch.ones([3]),
f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op._num_bits": 8 * torch.ones([1], dtype=torch.int32),
f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op.signed_tensor": torch.ones([1], dtype=torch.int32),
f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op.enabled": torch.ones([1], dtype=torch.int32),
f"{LEGACY_MODEL_WRAPPED_BY_NNCF_ATTR_NAME}.conv2d.pre_ops.0.op.scale": torch.ones([3, 1, 1, 1]),
f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT._num_bits": 8
* torch.ones([1], dtype=torch.int32),
f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT.signed_tensor": torch.zeros(
[1], dtype=torch.int32
),
f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT.enabled": torch.ones(
[1], dtype=torch.int32
),
f"{LEGACY_EXTERNAL_QUANTIZERS_STORAGE_PREFIX}./{MODEL_INPUT_OP_NAME}_0|OUTPUT.scale": torch.ones([1]),
}


def test_renamed_activation_quantizer_storage_in_state_dict():
model = SingleConv2dIdentityModel()
config = get_basic_quantization_config(input_info={"sample_size": [1, 3, 100, 100]})
register_bn_adaptation_init_args(config)
compressed_model, _ = create_compressed_model_and_algo_for_test(model, config)

with pytest.warns(NNCFDeprecationWarning):
_ = load_state(compressed_model, old_style_sd, is_resume=True)


def test_can_compress_with_config_and_resume_of_old_checkpoint():
model = SingleConv2dIdentityModel()
config = get_basic_quantization_config(input_info={"sample_size": [1, 3, 100, 100]})
register_bn_adaptation_init_args(config)
create_compressed_model_and_algo_for_test(model, config, compression_state=old_style_sd)


# BN Wrapping backward compatibility test


Expand Down
Loading

0 comments on commit ad03d21

Please sign in to comment.