Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Apr 25, 2024
1 parent 9e5197d commit f14aa92
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 18 deletions.
1 change: 0 additions & 1 deletion nncf/torch/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(self, model: NNCFNetwork):
]

def transform(self, transformation_layout: PTTransformationLayout) -> NNCFNetwork:
# self._model.nncf.record_commands(transformation_layout.transformations)
transformations = transformation_layout.transformations
aggregated_transformations = defaultdict(list)
requires_graph_rebuild = False
Expand Down
40 changes: 23 additions & 17 deletions tests/torch/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,29 @@
from tests.torch.nncf_network.helpers import InsertionCommandBuilder


def load_from_config_impl(model: torch.nn.Module, serialized_transformations, example_input, trace_parameters):
"""
Test implementation of nncf.torch.load_from_config(). Should be replaced by the implementation
"""
transformations_layout = load_transformations(serialized_transformations)

nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters)
transformed_model = ModelTransformerFactory.create(nncf_network).transform(transformations_layout)

transformed_model.nncf.disable_dynamic_graph_building()
return transformed_model


def nncf_get_config_impl(
model: NNCFNetwork,
):
"""
Test implementation of model.nncf.get_config(). Should be replaced by the implementation
"""
layout = model.nncf.transformation_layout()
return serialize_transformations(layout)


@pytest.mark.parametrize("target_type", AVAILABLE_TARGET_TYPES)
@pytest.mark.parametrize("command_builder", InsertionCommandBuilder(TwoConvTestModel).get_command_builders())
@pytest.mark.parametrize("priority", InsertionCommandBuilder.PRIORITIES)
Expand Down Expand Up @@ -80,23 +103,6 @@ def test_serialize_transformations():
_check_commands_after_serialization(command, recovered_command, dummy_op_state)


def load_from_config_impl(model: torch.nn.Module, serialized_transformations, example_input, trace_parameters):
transformations_layout = load_transformations(serialized_transformations)

nncf_network = wrap_model(deepcopy(model), example_input=example_input, trace_parameters=trace_parameters)
transformed_model = ModelTransformerFactory.create(nncf_network).transform(transformations_layout)

transformed_model.nncf.disable_dynamic_graph_building()
return transformed_model


def nncf_get_config_impl(
model: NNCFNetwork,
):
layout = model.nncf.transformation_layout()
return serialize_transformations(layout)


@pytest.mark.parametrize("model_cls", InsertionCommandBuilder.AVAILABLE_MODELS)
@pytest.mark.parametrize("trace_parameters", (False, True))
def test_get_apply_serialization_from_a_model(model_cls, trace_parameters):
Expand Down

0 comments on commit f14aa92

Please sign in to comment.