Improve memory usage for PyTorch quantization (#541)
Signed-off-by: Cheng, Penghui <[email protected]>
Signed-off-by: Xin He <[email protected]>
Signed-off-by: Lv, Liang1 <[email protected]>
PenghuiCheng authored Feb 28, 2023
1 parent 6e10efd commit c295a7f
Showing 11 changed files with 384 additions and 187 deletions.
393 changes: 275 additions & 118 deletions neural_compressor/adaptor/pytorch.py

Large diffs are not rendered by default.

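The adaptor diff above is where the actual memory saving lives, but the surrounding files show the user-facing shape of the change: a new example_inputs knob on the quantization config, and tests that deep-copy the FP32 model before handing it to fit()/load() because the tool may now work on the passed object in place. A minimal end-to-end sketch under those assumptions (MyModel and build_calib_loader are placeholders, not part of Neural Compressor):

    import copy
    import torch
    from neural_compressor import quantization
    from neural_compressor.config import PostTrainingQuantConfig

    fp32_model = MyModel()                # placeholder FP32 torch.nn.Module
    calib_loader = build_calib_loader()   # placeholder calibration dataloader

    conf = PostTrainingQuantConfig()
    # New option in this change: a representative input so the adaptor can
    # trace the model with torch.fx / torch.jit.
    conf.example_inputs = torch.randn(1, 3, 224, 224)

    # Pass a copy if the FP32 original must stay untouched; working on the
    # given object in place is what reduces peak memory during quantization.
    q_model = quantization.fit(copy.deepcopy(fp32_model), conf,
                               calib_dataloader=calib_loader)
    q_model.save("./saved")
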
14 changes: 12 additions & 2 deletions neural_compressor/conf/config.py
@@ -1401,6 +1401,7 @@ def map_pyconfig_to_cfg(self, pythonic_config):
'model.domain': pythonic_config.quantization.domain,
'quantization.recipes': pythonic_config.quantization.recipes,
'quantization.approach': pythonic_config.quantization.approach,
'quantization.example_inputs': pythonic_config.quantization.example_inputs,
'quantization.calibration.sampling_size':
pythonic_config.quantization.calibration_sampling_size,
'quantization.optype_wise': pythonic_config.quantization.op_type_list,
@@ -1429,7 +1430,7 @@ def map_pyconfig_to_cfg(self, pythonic_config):
if st_key in st_kwargs:
st_val = st_kwargs[st_key]
mapping.update({'tuning.strategy.' + st_key: st_val})

if pythonic_config.distillation is not None:
mapping.update({
'distillation.train.criterion': pythonic_config.distillation.criterion,
@@ -1458,7 +1459,6 @@ def map_pyconfig_to_cfg(self, pythonic_config):
if pythonic_config.benchmark.outputs != []:
mapping.update({'model.outputs': pythonic_config.benchmark.outputs})
mapping.update({
'model.backend': pythonic_config.benchmark.backend,
'evaluation.performance.warmup': pythonic_config.benchmark.warmup,
'evaluation.performance.iteration': pythonic_config.benchmark.iteration,
'evaluation.performance.configs.cores_per_instance':
@@ -1478,6 +1478,16 @@ def map_pyconfig_to_cfg(self, pythonic_config):
'evaluation.accuracy.configs.intra_num_of_threads':
pythonic_config.benchmark.intra_num_of_threads,
})
if "model.backend" not in mapping:
mapping.update({
'model.backend': pythonic_config.benchmark.backend,
})
else:
if mapping['model.backend'] == 'default' and \
pythonic_config.benchmark.backend != 'default':
mapping.update({
'model.backend': pythonic_config.benchmark.backend,
})

if "model.backend" not in mapping:
mapping.update({
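
Paraphrased, the block added above only lets the benchmark backend fill in model.backend when nothing is set yet or when the existing entry is still the 'default' placeholder, so an explicit backend chosen elsewhere is never downgraded. A standalone sketch of that rule (the function name and example values are illustrative only):

    def resolve_backend(mapping, benchmark_backend):
        """Fill model.backend from the benchmark config without clobbering
        an explicit, non-default backend set earlier in the mapping."""
        current = mapping.get('model.backend')
        if current is None:
            mapping['model.backend'] = benchmark_backend
        elif current == 'default' and benchmark_backend != 'default':
            mapping['model.backend'] = benchmark_backend
        return mapping

    # resolve_backend({'model.backend': 'ipex'}, 'default') keeps 'ipex';
    # resolve_backend({'model.backend': 'default'}, 'ipex') switches to 'ipex'.
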
13 changes: 13 additions & 0 deletions neural_compressor/config.py
@@ -388,6 +388,7 @@ def __init__(self,
max_trials=100,
performance_only=False,
reduce_range=None,
example_inputs=None,
excluded_precisions=[],
quant_level=1,
accuracy_criterion=accuracy_criterion,
@@ -428,6 +429,7 @@ def __init__(self,
max_trials: max tune times. default value is 100. Combine with timeout field to decide when to exit
performance_only: whether do evaluation
reduce_range: whether use 7 bit
example_inputs: used to trace PyTorch model with torch.jit/torch.fx
excluded_precisions: precisions to be excluded, support 'bf16'
quant_level: support 0 and 1, 0 is conservative strategy, 1 is basic(default) or user-specified strategy
accuracy_criterion: accuracy constraint settings
@@ -455,6 +457,7 @@ def __init__(self,
self.calibration_sampling_size = calibration_sampling_size
self.quant_level = quant_level
self.use_distributed_tuning=use_distributed_tuning
self._example_inputs = example_inputs

@property
def domain(self):
@@ -766,6 +769,16 @@ def inputs(self, inputs):
if check_value('inputs', inputs, str):
self._inputs = inputs

@property
def example_inputs(self):
"""Get strategy_kwargs."""
return self._example_inputs

@example_inputs.setter
def example_inputs(self, example_inputs):
"""Set example_inputs."""
self._example_inputs = example_inputs


class TuningCriterion:
"""Class for Tuning Criterion.
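
The new docstring says example_inputs is "used to trace PyTorch model with torch.jit/torch.fx"; for reference, this is roughly what tracing with a concrete input means, in plain PyTorch and independent of Neural Compressor (TinyNet and the input shape are made up for illustration):

    import torch

    class TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 8, kernel_size=3)

        def forward(self, x):
            return torch.relu(self.conv(x))

    example_inputs = torch.randn(1, 3, 224, 224)
    # torch.jit.trace records the operations executed for this concrete
    # tensor, which is why the config needs a representative input.
    traced = torch.jit.trace(TinyNet().eval(), example_inputs)
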
4 changes: 2 additions & 2 deletions neural_compressor/model/torch_model.py
@@ -793,7 +793,7 @@ def save(self, root=None):
logger.info("Save config file of quantized model to {}.".format(root))
except IOError as e:
logger.error("Fail to save configure file and weights due to {}.".format(e))

if isinstance(self.model, torch.jit._script.RecursiveScriptModule):
self.model.save(os.path.join(root, "best_model.pt"))

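When the quantized model is already a TorchScript module (e.g. from the IPEX path), save() above writes it as a single archive instead of a state_dict plus config, and such a file can be restored without the original Python class definition. A brief sketch, assuming save() was called with root='./saved':

    import os
    import torch

    scripted = torch.jit.load(os.path.join("./saved", "best_model.pt"))
    scripted.eval()
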
5 changes: 3 additions & 2 deletions neural_compressor/strategy/strategy.py
@@ -995,7 +995,8 @@ def _create_path(self, custom_path, filename):
def _set_framework_info(self, q_dataloader, q_func=None):
framework_specific_info = {'device': self.cfg.device,
'approach': self.cfg.quantization.approach,
'random_seed': self.cfg.tuning.random_seed}
'random_seed': self.cfg.tuning.random_seed,
'performance_only': self.cfg.tuning.exit_policy.performance_only,}
framework = self.cfg.model.framework.lower()
framework_specific_info.update({'backend': self.cfg.model.get('backend', 'default')})
framework_specific_info.update({'format': self.cfg.model.get('quant_format', 'default')})
@@ -1010,7 +1011,6 @@ def _set_framework_info(self, q_dataloader, q_func=None):
"outputs": self.cfg.model.outputs,
'workspace_path': self.cfg.tuning.workspace.path,
'recipes': self.cfg.quantization.recipes,
'performance_only': self.cfg.tuning.exit_policy.performance_only,
'use_bf16': self.cfg.use_bf16 if self.cfg.use_bf16 is not None else False})
for item in ['scale_propagation_max_pooling', 'scale_propagation_concat']:
if item not in framework_specific_info['recipes']:
@@ -1054,6 +1054,7 @@ def _set_framework_info(self, q_dataloader, q_func=None):
framework_specific_info.update(
{"default_qconfig": self.cfg['quantization']['op_wise']['default_qconfig']})
framework_specific_info.update({"q_func": q_func})
framework_specific_info.update({"example_inputs": self.cfg.quantization.example_inputs})
return framework, framework_specific_info

def _set_objectives(self):
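
performance_only now travels in the common framework_specific_info for every backend, alongside the new example_inputs entry for PyTorch. The adaptor diff is not rendered here, but a plausible reading of the memory enhancement is that the PyTorch adaptor uses this flag to skip its defensive deep copy when only performance is being measured; a sketch of that pattern, not the actual adaptor code:

    import copy

    def working_model(model, performance_only):
        """Reuse the caller's model when only performance matters;
        otherwise protect the FP32 original by copying it."""
        if performance_only:
            return model            # no extra copy, lower peak memory
        try:
            return copy.deepcopy(model)
        except Exception:           # some models cannot be deep-copied
            return model
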
77 changes: 35 additions & 42 deletions neural_compressor/utils/pytorch.py
@@ -201,7 +201,7 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
stat_dict['best_configure'] = tune_cfg
else:
logger.error("Unexpected checkpoint type:{}. \
Only file dir/path or state_dict is acceptable")
Only file dir/path or state_dict is acceptable")

if not isinstance(stat_dict, torch.jit._script.RecursiveScriptModule):
assert 'best_configure' in stat_dict, \
@@ -223,17 +223,10 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
logger.info("Finish load the model quantized by INC IPEX backend.")
return q_model

try:
q_model = copy.deepcopy(model)
except Exception as e: # pragma: no cover
logger.warning("Fail to deep copy the model due to {}, inplace is used now.".
format(repr(e)))
q_model = model

if 'is_oneshot' in tune_cfg and tune_cfg['is_oneshot']:
return _load_int8_orchestration(q_model, tune_cfg, stat_dict, example_inputs, **kwargs)
return _load_int8_orchestration(model, tune_cfg, stat_dict, example_inputs, **kwargs)

q_model.eval()
model.eval()
approach_quant_mode = None
if tune_cfg['approach'] == "post_training_dynamic_quant":
approach_quant_mode = 'dynamic'
@@ -279,79 +272,79 @@ def load(checkpoint_dir=None, model=None, history_cfg=None, **kwargs):
op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
fx_op_cfgs = _cfgs_to_fx_cfgs(op_cfgs, tune_cfg['approach'])
if not tune_cfg['fx_sub_module_list']:
tmp_model = q_model
tmp_model = model
if tune_cfg['approach'] == "quant_aware_training":
q_model.train()
model.train()
if version > Version("1.12.1"): # pragma: no cover
# pylint: disable=E1123
q_model = prepare_qat_fx(q_model,
fx_op_cfgs,
prepare_custom_config=prepare_custom_config_dict,
example_inputs=example_inputs)
model = prepare_qat_fx(model,
fx_op_cfgs,
prepare_custom_config=prepare_custom_config_dict,
example_inputs=example_inputs)
else:
q_model = prepare_qat_fx(q_model,
fx_op_cfgs,
prepare_custom_config_dict=prepare_custom_config_dict)
model = prepare_qat_fx(model,
fx_op_cfgs,
prepare_custom_config_dict=prepare_custom_config_dict)
else:
if version > Version("1.12.1"): # pragma: no cover
# pylint: disable=E1123
q_model = prepare_fx(q_model,
fx_op_cfgs,
prepare_custom_config=prepare_custom_config_dict,
example_inputs=example_inputs)
model = prepare_fx(model,
fx_op_cfgs,
prepare_custom_config=prepare_custom_config_dict,
example_inputs=example_inputs)
else:
q_model = prepare_fx(q_model,
fx_op_cfgs,
prepare_custom_config_dict=prepare_custom_config_dict)
model = prepare_fx(model,
fx_op_cfgs,
prepare_custom_config_dict=prepare_custom_config_dict)
if version > Version("1.12.1"): # pragma: no cover
# pylint: disable=E1123
q_model = convert_fx(q_model,
model = convert_fx(model,
convert_custom_config=convert_custom_config_dict)
else:
q_model = convert_fx(q_model,
model = convert_fx(model,
convert_custom_config_dict=convert_custom_config_dict)
util.append_attr(q_model, tmp_model)
util.append_attr(model, tmp_model)
del tmp_model
else:
sub_module_list = tune_cfg['fx_sub_module_list']
if tune_cfg['approach'] == "quant_aware_training":
q_model.train()
model.train()
PyTorch_FXAdaptor.prepare_sub_graph(sub_module_list,
fx_op_cfgs,
q_model,
model,
prefix='',
is_qat=True,
example_inputs=example_inputs)
else:
PyTorch_FXAdaptor.prepare_sub_graph(sub_module_list,
fx_op_cfgs,
q_model,
model,
prefix='',
example_inputs=example_inputs)
PyTorch_FXAdaptor.convert_sub_graph(sub_module_list, q_model, prefix='')
PyTorch_FXAdaptor.convert_sub_graph(sub_module_list, model, prefix='')
else:
if tune_cfg['approach'] == "post_training_dynamic_quant":
op_cfgs = _cfg_to_qconfig(tune_cfg, tune_cfg['approach'])
else:
op_cfgs = _cfg_to_qconfig(tune_cfg)

_propagate_qconfig(q_model, op_cfgs, approach=tune_cfg['approach'])
_propagate_qconfig(model, op_cfgs, approach=tune_cfg['approach'])
# sanity check common API misusage
if not any(hasattr(m, 'qconfig') and m.qconfig for m in q_model.modules()):
if not any(hasattr(m, 'qconfig') and m.qconfig for m in model.modules()):
logger.warn("None of the submodule got qconfig applied. Make sure you "
"passed correct configuration through `qconfig_dict` or "
"by assigning the `.qconfig` attribute directly on submodules")
if tune_cfg['approach'] != "post_training_dynamic_quant":
add_observer_(q_model)
q_model = convert(q_model, mapping=q_mapping, inplace=True)
add_observer_(model)
model = convert(model, mapping=q_mapping, inplace=True)

bf16_ops_list = tune_cfg['bf16_ops_list'] if 'bf16_ops_list' in tune_cfg.keys() else []
if len(bf16_ops_list) > 0 and (version >= Version("1.11.0-rc1")):
from ..adaptor.torch_utils.bf16_convert import Convert
q_model = Convert(q_model, tune_cfg)
model = Convert(model, tune_cfg)
if checkpoint_dir is None and history_cfg is not None:
_set_activation_scale_zeropoint(q_model, history_cfg)
_set_activation_scale_zeropoint(model, history_cfg)
else:
q_model.load_state_dict(stat_dict)
util.get_embedding_contiguous(q_model)
return q_model
model.load_state_dict(stat_dict)
util.get_embedding_contiguous(model)
return model
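
Note that the deep copy load() used to make has been removed: the prepare/convert calls above now run directly on the model object passed in, which is why the tests below start handing in copy.deepcopy(model). A brief usage sketch, assuming a checkpoint saved under ./saved and a placeholder fp32_model:

    import copy
    from neural_compressor.utils.pytorch import load

    # load() works on the module it receives, so pass a copy if the FP32
    # original must survive unchanged.
    int8_model = load("./saved", copy.deepcopy(fp32_model))
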
20 changes: 10 additions & 10 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch_1.x.py
@@ -264,9 +264,9 @@ def build_pytorch_yaml():

def build_pytorch_fx_yaml():
if PT_VERSION >= Version("1.9.0").release:
fake_fx_ptq_yaml = fake_ptq_yaml_for_fx
fake_fx_ptq_yaml = fake_ptq_yaml_for_fx
else:
fake_fx_ptq_yaml = fake_ptq_yaml.replace('pytorch', 'pytorch_fx')
fake_fx_ptq_yaml = fake_ptq_yaml.replace('pytorch', 'pytorch_fx')
with open('fx_ptq_yaml.yaml', 'w', encoding="utf-8") as f:
f.write(fake_fx_ptq_yaml)

@@ -712,11 +712,11 @@ def test_tensor_dump_and_set(self):
a = load_array('saved/dump_tensor/activation_iter1.npz')
w = load_array('saved/dump_tensor/weight.npz')
if PT_VERSION >= Version("1.8.0").release:
self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
a['conv1.0'].item()['conv1.0.output0'].shape[1])
self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
a['conv1.0'].item()['conv1.0.output0'].shape[1])
else:
self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
a['conv1.0'].item()['conv1.1.output0'].shape[1])
self.assertTrue(w['conv1.0'].item()['conv1.0.weight'].shape[0] ==
a['conv1.0'].item()['conv1.1.output0'].shape[1])
data = np.random.random(w['conv1.0'].item()['conv1.0.weight'].shape).astype(np.float32)
quantizer.strategy.adaptor.set_tensor(q_model, {'conv1.0.weight': data})
changed_tensor = q_model.get_weight('conv1.weight')
@@ -789,7 +789,7 @@ def forward(self, x):
q_capability = self.adaptor.query_fw_capability(model)
for k, v in q_capability["opwise"].items():
if k[0] != "quant" and k[0] != "dequant":
fallback_ops.append(k[0])
fallback_ops.append(k[0])
model.model.qconfig = torch.quantization.default_qconfig
model.model.quant.qconfig = torch.quantization.default_qconfig
if PT_VERSION >= Version("1.8.0").release:
@@ -903,7 +903,7 @@ def test_fx_dynamic_quant(self):
# run fx_quant in neural_compressor and save the quantized GraphModule
model.eval()
quantizer = Quantization('fx_dynamic_yaml.yaml')
quantizer.model = common.Model(model,
quantizer.model = common.Model(copy.deepcopy(model),
**{'prepare_custom_config_dict': \
{'non_traceable_module_name': ['a']},
'convert_custom_config_dict': \
@@ -913,7 +913,7 @@ def test_fx_dynamic_quant(self):
q_model.save('./saved')

# Load configure and weights by neural_compressor.utils
model_fx = load("./saved", model,
model_fx = load("./saved", copy.deepcopy(model),
**{'prepare_custom_config_dict': \
{'non_traceable_module_name': ['a']},
'convert_custom_config_dict': \
@@ -929,7 +929,7 @@ def test_fx_dynamic_quant(self):
yaml.dump(tune_cfg, f, default_flow_style=False)
torch.save(state_dict, "./saved/best_model_weights.pt")
os.remove('./saved/best_model.pt')
model_fx = load("./saved", model,
model_fx = load("./saved", copy.deepcopy(model),
**{'prepare_custom_config_dict': \
{'non_traceable_module_name': ['a']},
'convert_custom_config_dict': \
12 changes: 6 additions & 6 deletions test/adaptor/pytorch_adaptor/test_adaptor_pytorch_2.x.py
@@ -322,6 +322,7 @@ def test_fx_quant(self):
else:
conf = PostTrainingQuantConfig(
op_name_list=ptq_fx_op_name_list)
conf.example_inputs = torch.randn([1, 3, 224, 224])
set_workspace("./saved")
q_model = quantization.fit(model_origin,
conf,
@@ -381,11 +382,11 @@ def test_fx_dynamic_quant(self):
origin_model.eval()
conf = PostTrainingQuantConfig(approach="dynamic", op_name_list=ptq_fx_op_name_list)
set_workspace("./saved")
q_model = quantization.fit(origin_model, conf)
q_model = quantization.fit(copy.deepcopy(origin_model), conf)
q_model.save("./saved")

# Load configure and weights by neural_compressor.utils
model_fx = load("./saved", origin_model)
model_fx = load("./saved", copy.deepcopy(origin_model))
self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule))

# Test the functionality of older model saving type
@@ -396,7 +397,7 @@ def test_fx_dynamic_quant(self):
yaml.dump(tune_cfg, f, default_flow_style=False)
torch.save(state_dict, "./saved/best_model_weights.pt")
os.remove("./saved/best_model.pt")
model_fx = load("./saved", origin_model)
model_fx = load("./saved", copy.deepcopy(origin_model))
self.assertTrue(isinstance(model_fx, torch.fx.graph_module.GraphModule))

# recover int8 model with only tune_cfg
@@ -472,8 +473,7 @@ def test_fx_sub_module_quant(self):
# recover int8 model with only tune_cfg
history_file = "./saved/history.snapshot"
model_fx_recover = recover(model_origin, history_file, 0,
**{"dataloader": torch.utils.data.DataLoader(dataset)
})
**{"dataloader": torch.utils.data.DataLoader(dataset)})
self.assertEqual(model_fx.sub.code, model_fx_recover.sub.code)
shutil.rmtree("./saved", ignore_errors=True)

@@ -489,7 +489,7 @@ def test_mix_precision(self):
q_model = quantization.fit(model_origin,
conf,
calib_dataloader=dataloader,
calib_func = eval_func)
calib_func=eval_func)
tune_cfg = q_model.q_config
tune_cfg["op"][("conv.module", "Conv2d")].clear()
tune_cfg["op"][("conv.module", "Conv2d")] = \