Skip to content

Commit

Permalink
fix nncf precision config to exported model
Browse files Browse the repository at this point in the history
wonjuleee committed Jun 2, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent eed9902 commit 0957213
Showing 2 changed files with 11 additions and 6 deletions.
Original file line number Diff line number Diff line change
@@ -100,7 +100,7 @@ def __init__(self, task_environment: TaskEnvironment):

# Set default model attributes.
self._optimization_methods = []
self._precision = [ModelPrecision.FP16] if self._config.get('fp16', None) else [ModelPrecision.FP32]
self._precision = self._precision_from_config
self._optimization_type = ModelOptimizationType.MO

# Create and initialize PyTorch model.
@@ -113,6 +113,10 @@ def __init__(self, task_environment: TaskEnvironment):
self._should_stop = False
logger.info('Task initialization completed')

@property
def _precision_from_config(self):
return [ModelPrecision.FP16] if self._config.get('fp16', None) else [ModelPrecision.FP32]

@property
def _hyperparams(self):
return self._task_environment.get_hyper_parameters(OTEDetectionConfig)
@@ -402,7 +406,8 @@ def export(self,
model = self._model.cpu()
pruning_transformation = OptimizationMethod.FILTER_PRUNING in self._optimization_methods
export_model(model, self._config, tempdir, target='openvino',
pruning_transformation=pruning_transformation, precision=self._precision[0].name)
pruning_transformation=pruning_transformation,
precision=self._precision_from_config[0].name)
bin_file = [f for f in os.listdir(tempdir) if f.endswith('.bin')][0]
xml_file = [f for f in os.listdir(tempdir) if f.endswith('.xml')][0]
with open(os.path.join(tempdir, bin_file), "rb") as f:
Original file line number Diff line number Diff line change
@@ -81,17 +81,17 @@ def _set_attributes_by_hyperparams(self):
if quantization and pruning:
self._nncf_preset = "nncf_quantization_pruning"
self._optimization_methods = [OptimizationMethod.QUANTIZATION, OptimizationMethod.FILTER_PRUNING]
self._nncf_precision = [ModelPrecision.INT8]
self._precision = [ModelPrecision.INT8]
return
if quantization and not pruning:
self._nncf_preset = "nncf_quantization"
self._optimization_methods = [OptimizationMethod.QUANTIZATION]
self._nncf_precision = [ModelPrecision.INT8]
self._precision = [ModelPrecision.INT8]
return
if not quantization and pruning:
self._nncf_preset = "nncf_pruning"
self._optimization_methods = [OptimizationMethod.FILTER_PRUNING]
self._nncf_precision = [ModelPrecision.FP32]
self._precision = [ModelPrecision.INT8]
return
raise RuntimeError('Not selected optimization algorithm')

@@ -249,7 +249,7 @@ def optimize(
output_model.model_format = ModelFormat.BASE_FRAMEWORK
output_model.optimization_type = self._optimization_type
output_model.optimization_methods = self._optimization_methods
output_model.precision = self._nncf_precision
output_model.precision = self._precision

self._is_training = False

0 comments on commit 0957213

Please sign in to comment.