From f52d7c8801ce23299096bb7626b3ef8e226a887f Mon Sep 17 00:00:00 2001
From: Ella Charlaix <80481427+echarlaix@users.noreply.github.com>
Date: Fri, 13 Oct 2023 17:13:59 +0200
Subject: [PATCH] Fix loading of INC quantized model (#452)

---
 .../intel/neural_compressor/modeling_base.py | 47 +++----------
 1 file changed, 7 insertions(+), 40 deletions(-)

diff --git a/optimum/intel/neural_compressor/modeling_base.py b/optimum/intel/neural_compressor/modeling_base.py
index 768164c05c..19c06c8c4c 100644
--- a/optimum/intel/neural_compressor/modeling_base.py
+++ b/optimum/intel/neural_compressor/modeling_base.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import copy
 import logging
 import os
 from pathlib import Path
@@ -138,7 +137,6 @@ def _from_pretrained(
         model_save_dir = Path(model_cache_path).parent
 
         inc_config = None
-        q_config = None
         msg = None
         try:
             inc_config = INCConfig.from_pretrained(model_id)
@@ -153,54 +151,23 @@ def _from_pretrained(
             # load(model_cache_path)
             model = torch.jit.load(model_cache_path)
             model = torch.jit.freeze(model.eval())
-            return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
+            return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs)
 
         model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
-        keys_to_ignore_on_load_unexpected = copy.deepcopy(
-            getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
-        )
-        keys_to_ignore_on_load_missing = copy.deepcopy(getattr(model_class, "_keys_to_ignore_on_load_missing", None))
-        # Avoid unnecessary warnings resulting from quantized model initialization
-        quantized_keys_to_ignore_on_load = [
-            r"zero_point",
-            r"scale",
-            r"packed_params",
-            r"constant",
-            r"module",
-            r"best_configure",
-            r"max_val",
-            r"min_val",
-            r"eps",
-            r"fake_quant_enabled",
-            r"observer_enabled",
-        ]
-        if keys_to_ignore_on_load_unexpected is None:
-            model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
-        else:
-            model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
-        missing_keys_to_ignore_on_load = [r"weight", r"bias"]
-        if keys_to_ignore_on_load_missing is None:
-            model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
-        else:
-            model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
+        # Load the state dictionary of the model to get the quantization config
+        state_dict = torch.load(model_cache_path, map_location="cpu")
+        q_config = state_dict.get("best_configure", None)
 
-        try:
+        if q_config is None:
             model = model_class.from_pretrained(model_save_dir)
-        except AttributeError:
+        else:
             init_contexts = [no_init_weights(_enable=True)]
             with ContextManagers(init_contexts):
                 model = model_class(config)
-
-        model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
-        model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing
-
-        # Load the state dictionary of the model to verify whether the model is quantized or not
-        state_dict = torch.load(model_cache_path, map_location="cpu")
-        if "best_configure" in state_dict and state_dict["best_configure"] is not None:
-            q_config = state_dict["best_configure"]
             try:
                 model = load(model_cache_path, model)
             except Exception as e:
+                # For an incompatible torch version, append the version mismatch message
                 if msg is not None:
                     e.args += (msg,)
                 raise
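
Note (not part of the patch): the substance of the fix is in the last hunk. Rather than temporarily mutating the model class's _keys_to_ignore_on_load_unexpected / _keys_to_ignore_on_load_missing attributes and relying on an AttributeError to detect quantized checkpoints, the loader now reads the checkpoint's state dict up front and looks for the "best_configure" entry that Intel Neural Compressor stores alongside quantized weights, taking the quantized load path only when it is present. The following is a minimal sketch of that detection step; the helper name and the local pytorch_model.bin path are hypothetical.

import torch

def get_inc_quantization_config(checkpoint_path: str):
    """Return the INC quantization config embedded in a checkpoint, or None."""
    # INC serializes its quantization config under the "best_configure" key of
    # the saved state dict; a vanilla transformers checkpoint has no such entry.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    return state_dict.get("best_configure", None)

# Mirror the patched branching: plain checkpoints go through from_pretrained(),
# quantized ones are initialized without weights and restored by INC's load().
q_config = get_inc_quantization_config("pytorch_model.bin")  # hypothetical local path
if q_config is None:
    print("vanilla checkpoint: model_class.from_pretrained(model_save_dir)")
else:
    print("INC-quantized checkpoint: no_init_weights + neural_compressor load()")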