Skip to content

Commit

Permalink
Fix loading of INC quantized model (#452)
Browse files Browse the repository at this point in the history
  • Loading branch information
echarlaix authored Oct 13, 2023
1 parent b14059f commit f52d7c8
Showing 1 changed file with 7 additions and 40 deletions.
47 changes: 7 additions & 40 deletions optimum/intel/neural_compressor/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
import logging
import os
from pathlib import Path
Expand Down Expand Up @@ -138,7 +137,6 @@ def _from_pretrained(

model_save_dir = Path(model_cache_path).parent
inc_config = None
q_config = None
msg = None
try:
inc_config = INCConfig.from_pretrained(model_id)
Expand All @@ -153,54 +151,23 @@ def _from_pretrained(
# load(model_cache_path)
model = torch.jit.load(model_cache_path)
model = torch.jit.freeze(model.eval())
return cls(model, config=config, model_save_dir=model_save_dir, **kwargs)
return cls(model, config=config, model_save_dir=model_save_dir, inc_config=inc_config, **kwargs)

model_class = _get_model_class(config, cls.auto_model_class._model_mapping)
keys_to_ignore_on_load_unexpected = copy.deepcopy(
getattr(model_class, "_keys_to_ignore_on_load_unexpected", None)
)
keys_to_ignore_on_load_missing = copy.deepcopy(getattr(model_class, "_keys_to_ignore_on_load_missing", None))
# Avoid unnecessary warnings resulting from quantized model initialization
quantized_keys_to_ignore_on_load = [
r"zero_point",
r"scale",
r"packed_params",
r"constant",
r"module",
r"best_configure",
r"max_val",
r"min_val",
r"eps",
r"fake_quant_enabled",
r"observer_enabled",
]
if keys_to_ignore_on_load_unexpected is None:
model_class._keys_to_ignore_on_load_unexpected = quantized_keys_to_ignore_on_load
else:
model_class._keys_to_ignore_on_load_unexpected.extend(quantized_keys_to_ignore_on_load)
missing_keys_to_ignore_on_load = [r"weight", r"bias"]
if keys_to_ignore_on_load_missing is None:
model_class._keys_to_ignore_on_load_missing = missing_keys_to_ignore_on_load
else:
model_class._keys_to_ignore_on_load_missing.extend(missing_keys_to_ignore_on_load)
# Load the state dictionary of the model to verify whether the model to get the quantization config
state_dict = torch.load(model_cache_path, map_location="cpu")
q_config = state_dict.get("best_configure", None)

try:
if q_config is None:
model = model_class.from_pretrained(model_save_dir)
except AttributeError:
else:
init_contexts = [no_init_weights(_enable=True)]
with ContextManagers(init_contexts):
model = model_class(config)

model_class._keys_to_ignore_on_load_unexpected = keys_to_ignore_on_load_unexpected
model_class._keys_to_ignore_on_load_missing = keys_to_ignore_on_load_missing

# Load the state dictionary of the model to verify whether the model is quantized or not
state_dict = torch.load(model_cache_path, map_location="cpu")
if "best_configure" in state_dict and state_dict["best_configure"] is not None:
q_config = state_dict["best_configure"]
try:
model = load(model_cache_path, model)
except Exception as e:
# For incompatible torch version check
if msg is not None:
e.args += (msg,)
raise
Expand Down

0 comments on commit f52d7c8

Please sign in to comment.