Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Oct 26, 2023
1 parent b31ed6e commit 203e20f
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 18 deletions.
29 changes: 12 additions & 17 deletions nncf/torch/model_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def create_compressed_model(
"""
The main function used to produce a model ready for compression fine-tuning from an original PyTorch
model and a configuration object.
dummy_forward_fn
:param model: The original model. Should have its parameters already loaded from a checkpoint or another
source.
:param config: A configuration object used to determine the exact compression modifications to be applied
Expand Down Expand Up @@ -124,8 +124,7 @@ def create_compressed_model(

should_init = compression_state is None

input_info = get_input_info_from_config(config)
nncf_network = create_nncf_network(model, input_info, dummy_forward_fn, wrap_inputs_fn, wrap_outputs_fn)
nncf_network = create_nncf_network(model, config, dummy_forward_fn, wrap_inputs_fn, wrap_outputs_fn)

if dump_graphs and is_main_process():
nncf_network.nncf.get_graph().visualize_graph(osp.join(config.get("log_dir", "."), "original_graph.dot"))
Expand Down Expand Up @@ -186,27 +185,24 @@ def get_input_info_from_config(config: NNCFConfig) -> ModelInputInfo:


def create_nncf_network_with_inputs_from_config(model: torch.nn.Module, config: NNCFConfig):
return create_nncf_network(model, input_info=get_input_info_from_config(config))
return create_nncf_network(model, config=config)


def create_nncf_network(
model: torch.nn.Module,
input_info: ModelInputInfo,
config: NNCFConfig,
dummy_forward_fn: Callable[[Module], Any] = None,
wrap_inputs_fn: Callable = None,
wrap_outputs_fn: Callable = None,
ignored_scopes: List[str] = None,
target_scopes: List[str] = None,
scopes_without_shape_matching: List[str] = None,
wrap_outputs_fn: Callable = None
) -> NNCFNetwork:
"""
The main function used to produce a model ready for adding compression from an original PyTorch
model and a configuration object.
:param model: The original model. Should have its parameters already loaded from a checkpoint or another
source.
:param input_info: A list of descriptors for each of the model's tensor inputs, used to build the input
parameters for the internal `.forward` calls made on the model object to build its graph representation
:param config: A configuration object used to determine the exact compression modifications to be applied
to the model
:param dummy_forward_fn: if supplied, will be used instead of a *forward* function call to build
the internal graph representation via tracing. Specifying this is useful when the original training pipeline
has special formats of data loader output or has additional *forward* arguments other than input tensors.
Expand All @@ -228,12 +224,6 @@ def create_nncf_network(
the same as were supplied in input, but each tensor in the original input. Must be specified if
dummy_forward_fn is specified.
:param wrap_outputs_fn: Same as `wrap_inputs_fn`, but for marking model outputs with
:param scopes_without_shape_matching: A list of scopes in the model in which the activation tensor shapes will
not be considered for purposes of scope matching - this helps handle RNN-like cases.
:param ignored_scopes: A list of scopes in the model for which NNCF handling should not be applied. Functions as
a "denylist". If left unspecified, nothing will be ignored.
:param target_scopes: A list of scopes in the model for which NNCF handling should be applied. Functions as
an "allowlist". If left unspecified, everything will be targeted.
:return: A model wrapped by NNCFNetwork, which is ready for adding compression."""

Expand All @@ -251,6 +241,11 @@ def create_nncf_network(
# model that are used on training stage only (e.g. AuxLogits of Inception-v3 model) or unused modules with
# weights. As a consequence, no need to care about spoiling BN statistics, as they're disabled in eval mode.

input_info = get_input_info_from_config(config)
scopes_without_shape_matching = config.get("scopes_without_shape_matching", [])
ignored_scopes = config.get("ignored_scopes")
target_scopes = config.get("target_scopes")

nncf_network = NNCFNetwork(
model,
input_info=input_info,
Expand Down
2 changes: 1 addition & 1 deletion tests/torch/sparsity/movement/helpers/run_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,7 @@ class ClipVisionRunRecipe(BaseMockRunRecipe):
def model_input_info(self) -> FillerInputInfo:
num_channels = self.model_config.num_channels
image_size = self.model_config.image_size
return [FillerInputElement(shape=[1, num_channels, image_size, image_size], type_str="float")]
return FillerInputInfo([FillerInputElement(shape=[1, num_channels, image_size, image_size], type_str="float")])

@staticmethod
def get_nncf_modules_in_transformer_block_order(compressed_model: NNCFNetwork) -> List[DictInTransformerBlockOrder]:
Expand Down

0 comments on commit 203e20f

Please sign in to comment.