Skip to content

Commit

Permalink
API code moved to a separate PR
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Apr 2, 2024
1 parent f89019b commit 765775a
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 147 deletions.
2 changes: 0 additions & 2 deletions nncf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,9 @@
from nncf.parameters import SensitivityMetric as SensitivityMetric
from nncf.parameters import TargetDevice as TargetDevice
from nncf.quantization import QuantizationPreset as QuantizationPreset
from nncf.quantization import apply_serialized_transformations as apply_serialized_transformations
from nncf.quantization import compress_weights as compress_weights
from nncf.quantization import quantize as quantize
from nncf.quantization import quantize_with_accuracy_control as quantize_with_accuracy_control
from nncf.quantization import serialize_transformations as serialize_transformations
from nncf.quantization.advanced_parameters import (
AdvancedAccuracyRestorerParameters as AdvancedAccuracyRestorerParameters,
)
Expand Down
2 changes: 0 additions & 2 deletions nncf/quantization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
# limitations under the License.
"""Post-training quantization APIs."""
from nncf.common.quantization.structs import QuantizationPreset as QuantizationPreset
from nncf.quantization.quantize_model import apply_serialized_transformations as apply_serialized_transformations
from nncf.quantization.quantize_model import compress_weights as compress_weights
from nncf.quantization.quantize_model import quantize as quantize
from nncf.quantization.quantize_model import quantize_with_accuracy_control as quantize_with_accuracy_control
from nncf.quantization.quantize_model import serialize_transformations as serialize_transformations
33 changes: 1 addition & 32 deletions nncf/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Iterable, List, Optional, Tuple, TypeVar, Union

import nncf
from nncf.api.compression import TModel
Expand Down Expand Up @@ -540,34 +540,3 @@ def quantize_with_tune_hyperparams(
quantized_model = hyperparameter_tuner.apply(model, validation_dataset)

return quantized_model


@api(canonical_alias="nncf.apply_serialized_transformations")
def apply_serialized_transformations(
model: TModel,
serialized_transformations,
) -> TModel:
"""
Applies transformation layout to the model.
"""
backend = get_backend(model)
if backend == BackendType.TORCH:
from nncf.torch.quantization.quantize_model import apply_serialized_transformations_impl

return apply_serialized_transformations_impl(model, serialized_transformations)
raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}")


@api(canonical_alias="nncf.serialize_transformations")
def serialize_transformations(
model: TModel,
) -> Dict[str, Any]:
"""
Applies transformation layout to the model.
"""
backend = get_backend(model)
if backend == BackendType.TORCH:
from nncf.torch.quantization.quantize_model import serialize_transformations_impl

return serialize_transformations_impl(model)
raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}")
19 changes: 5 additions & 14 deletions nncf/torch/graph/transformations/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@
# limitations under the License.

from enum import Enum
from typing import Any, Dict, Tuple, Union

import torch
from typing import Any, Dict, Union

from nncf.common.graph.transformations.commands import TransformationPriority
from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.graph.transformations.commands import ExtraCompressionModuleType
from nncf.torch.graph.transformations.commands import PTInsertionCommand
from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand
Expand All @@ -25,36 +22,30 @@
from nncf.torch.layer_utils import COMPRESSION_MODULES

COMPRESSION_STATE_ATTR = "compression_state"
INPUT_INFO_ATTR = "example_input"


class CompressionKeys(Enum):
SHARED_INSERTION_COMMAND = "SHARED_INSERTION_COMMAND"
INSERTION_COMMAND = "INSERTION_COMMAND"


def serialize_transformations(model: torch.nn.Module, transformations_layout: TransformationLayout) -> Dict[str, Any]:
input_info = model.nncf._input_info
if not isinstance(input_info, FillerInputInfo):
raise RuntimeError("Could not serialize model inputs input: {input_info}")

def serialize_transformations(transformations_layout: TransformationLayout) -> Dict[str, Any]:
transformation_commands = []
for command in transformations_layout.transformations:
serialized_command = serialize_command(command)
if serialized_command:
transformation_commands.append(serialized_command)

return {COMPRESSION_STATE_ATTR: transformation_commands, INPUT_INFO_ATTR: input_info.get_state()}
return {COMPRESSION_STATE_ATTR: transformation_commands}


def load_transformations(transformations_state: Dict[str, Any]) -> Tuple[TransformationLayout, FillerInputInfo]:
def load_transformations(transformations_state: Dict[str, Any]) -> TransformationLayout:
transformation_layout = TransformationLayout()
for serialized_command in transformations_state[COMPRESSION_STATE_ATTR]:
command = load_command(serialized_command)
transformation_layout.register(command)

input_info = FillerInputInfo.from_state(transformations_state[INPUT_INFO_ATTR])
return transformation_layout, input_info
return transformation_layout


def serialize_command(command: PTTransformationCommand) -> Dict[str, Any]:
Expand Down
22 changes: 0 additions & 22 deletions nncf/torch/quantization/quantize_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import torch

import nncf
from nncf.common.factory import ModelTransformerFactory
from nncf.common.factory import NNCFGraphFactory
from nncf.common.quantization.structs import QuantizationPreset
from nncf.data import Dataset
Expand All @@ -30,10 +29,7 @@
from nncf.quantization.quantize_model import warning_model_no_batchwise_support
from nncf.scopes import IgnoredScope
from nncf.torch.graph.operator_metatypes import OPERATIONS_OUTPUT_HAS_NO_BATCH_AXIS
from nncf.torch.graph.transformations.serialization import load_transformations
from nncf.torch.graph.transformations.serialization import serialize_transformations
from nncf.torch.model_creation import wrap_model
from nncf.torch.nncf_network import NNCFNetwork

DEFAULT_RANGE_TYPE = "mean_min_max"

Expand Down Expand Up @@ -104,21 +100,3 @@ def compress_weights_impl(
)
graph = NNCFGraphFactory.create(model)
return compression_algorithm.apply(model, graph, dataset=dataset)


def apply_serialized_transformations_impl(model: torch.nn.Module, serialized_transformations):
transformations_layout, input_info = load_transformations(serialized_transformations)

nncf_network = NNCFNetwork(deepcopy(model), input_info=input_info)
model_transformer = ModelTransformerFactory.create(nncf_network)
transformed_model = model_transformer.transform(transformations_layout)

transformed_model.nncf.disable_dynamic_graph_building()
return transformed_model


def serialize_transformations_impl(
model: NNCFNetwork,
):
layout = model.nncf.get_applied_transformation_layout()
return serialize_transformations(model, layout)
64 changes: 0 additions & 64 deletions tests/torch/qat/test_qat_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,67 +293,3 @@ def test_compression_training(quantization_config_path: Path, sota_data_dir):
del sample_config["compression"]["initializer"]["range"]

start_worker_clean_memory(main_worker, sample_config)


def save_load_main_worker(current_gpu: int, config: SampleConfig):
configure_device(current_gpu, config)
if is_main_process():
configure_logging(logger, config)
else:
config.tb = None

pretrained = is_pretrained_model_requested(config)
model_name = config["model"]
# create model
logger.info(f"\nCreating model from config: {config.config}")
model = load_model(
model_name,
pretrained=pretrained,
num_classes=config.get("num_classes", 1000),
model_params=config.get("model_params"),
weights_path=config.get("weights"),
)
model.to(config.device)

datasets = get_datasets(config)
criterion = nn.CrossEntropyLoss()
criterion = criterion.to(config.device)

logger.info("Original model validation:")
# original_accuracy, *_ = validate(datasets.val_data_loader, model, criterion, config)
original_accuracy = 100.0

logger.info("Apply quantization to the model:")
config_quantization_params = config["compression"]

preset = get_quantization_preset(config_quantization_params)
advanced_parameters = get_advanced_ptq_parameters(config_quantization_params)
# subset_size = get_num_samples(config_quantization_params)

quantized_model = nncf.quantize(
model,
datasets.calibration_dataset,
preset=preset,
advanced_parameters=advanced_parameters,
subset_size=1,
)

transformations_state = nncf.serialize_transformations(quantized_model)
state_dict = quantized_model.state_dict()
del quantized_model
quantized_model = nncf.apply_serialized_transformations(model, transformations_state)
quantized_model.load_state_dict(state_dict)

train_criterion_fn = inception_criterion_fn if "inception" in model_name else default_criterion_fn
acc_drop = train(
quantized_model,
config,
criterion,
train_criterion_fn,
datasets,
original_accuracy,
get_mocked_compression_ctrl(),
)
assert accuracy_drop_is_acceptable(acc_drop)
check_training_correctness(config, model, datasets, criterion, train_criterion_fn)
logger.info("Done!")
39 changes: 28 additions & 11 deletions tests/torch/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
import pytest
import torch

import nncf
from nncf.common.factory import ModelTransformerFactory
from nncf.common.quantization.structs import QuantizationScheme
from nncf.quantization.algorithms.smooth_quant.torch_backend import SQMultiply
from nncf.torch.dynamic_graph.io_handling import FillerInputElement
from nncf.torch.dynamic_graph.io_handling import FillerInputInfo
from nncf.torch.dynamic_graph.io_handling import ModelInputInfo
from nncf.torch.graph.transformations.serialization import load_command
from nncf.torch.graph.transformations.serialization import load_transformations
from nncf.torch.graph.transformations.serialization import serialize_command
Expand Down Expand Up @@ -56,39 +56,56 @@ def test_serialize_load_command(target_type, command_builder, priority):
_check_commands_after_serialization(command, recovered_command, dummy_op_state)


def test_serialize_transformations(mocker):
def test_serialize_transformations():
layout = TwoConvTestModel.get_all_available_commands()
model = mocker.MagicMock()
input_info_ref = FillerInputInfo([FillerInputElement([1, 1, 4, 4])])
model.nncf._input_info = input_info_ref

serialized_transformations = serialize_transformations(model, layout)
serialized_transformations = serialize_transformations(layout)

# Check serialized transformation are json compatible
j_str = json.dumps(serialized_transformations)
serialized_transformations = json.loads(j_str)

recovered_layout, input_info = load_transformations(serialized_transformations)
assert input_info == input_info_ref
recovered_layout = load_transformations(serialized_transformations)
assert len(layout.transformations) == len(recovered_layout.transformations)
# Can zip layouts because the order should not be altered
for command, recovered_command in zip(layout.transformations, recovered_layout.transformations):
_check_commands_after_serialization(command, recovered_command)


def apply_serialized_transformations_impl(
model: torch.nn.Module, input_info: ModelInputInfo, serialized_transformations
):
transformations_layout = load_transformations(serialized_transformations)

nncf_network = NNCFNetwork(deepcopy(model), input_info=input_info)
model_transformer = ModelTransformerFactory.create(nncf_network)
transformed_model = model_transformer.transform(transformations_layout)

transformed_model.nncf.disable_dynamic_graph_building()
return transformed_model


def serialize_transformations_impl(
model: NNCFNetwork,
):
layout = model.nncf.get_applied_transformation_layout()
return serialize_transformations(layout)


def test_get_apply_serialization_from_a_model():
layout = TwoConvTestModel.get_all_available_commands(skip_model_transformer_unsupported=True)
model = TwoConvTestModel()
nncf_model = NNCFNetwork(deepcopy(model), input_info=FillerInputInfo([FillerInputElement([1, 1, 4, 4])]))
input_info = FillerInputInfo([FillerInputElement([1, 1, 4, 4])])
nncf_model = NNCFNetwork(deepcopy(model), input_info=input_info)
modified_model = ModelTransformerFactory.create(nncf_model).transform(layout)

serialized_transformations = nncf.serialize_transformations(modified_model)
serialized_transformations = serialize_transformations_impl(modified_model)

# Check serialized transformation are json compatible
j_str = json.dumps(serialized_transformations)
serialized_transformations = json.loads(j_str)

recovered_model = nncf.apply_serialized_transformations(model, serialized_transformations)
recovered_model = apply_serialized_transformations_impl(model, input_info, serialized_transformations)
for conv, recovered_conv in zip(modified_model.features, recovered_model.features):
for hooks_attr in ["pre_ops", "post_ops"]:
hooks = getattr(conv[0], hooks_attr)
Expand Down

0 comments on commit 765775a

Please sign in to comment.