From ed601df69bc5e710b39898b364dc9a3734f4762e Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 21 May 2024 20:53:39 +0200 Subject: [PATCH 1/7] TorchFX quantization init --- aa_torch_fx.py | 456 ++++++++++++++++++ .../openvino/tiny_llama/main.py | 2 +- .../tiny_llama_find_hyperparams/main.py | 2 +- nncf/common/factory.py | 16 + nncf/common/graph/patterns/manager.py | 10 + nncf/common/utils/backend.py | 26 +- nncf/experimental/torch_fx/__init__.py | 10 + nncf/experimental/torch_fx/engine.py | 48 ++ .../torch_fx/model_transformer.py | 183 +++++++ .../torch_fx/nncf_graph_builder.py | 167 +++++++ .../torch_fx/quantization/__init__.py | 10 + .../torch_fx/quantization/quantize_model.py | 105 ++++ .../torch_fx/statistics/__init__.py | 10 + .../torch_fx/statistics/aggregator.py | 99 ++++ nncf/experimental/torch_fx/transformations.py | 350 ++++++++++++++ .../fast_bias_correction/algorithm.py | 8 +- .../fast_bias_correction/torch_fx_backend.py | 116 +++++ .../algorithms/min_max/algorithm.py | 6 +- .../algorithms/min_max/torch_fx_backend.py | 351 ++++++++++++++ nncf/quantization/quantize_model.py | 14 + nncf/torch/dynamic_graph/patch_pytorch.py | 2 +- nncf/torch/dynamic_graph/structs.py | 1 + nncf/torch/graph/operator_metatypes.py | 10 +- .../pipelines/lm_weight_compression.py | 2 +- .../sparsity/movement/helpers/run_recipe.py | 2 +- .../sparsity/movement/helpers/trainer.py | 2 +- .../sparsity/movement/test_model_saving.py | 2 +- .../movement/training_scripts/run_glue.py | 3 +- tests/torch_fx/__init__.py | 10 + tests/torch_fx/helpers.py | 105 ++++ tests/torch_fx/requirements.txt | 1 + tests/torch_fx/test_sanity.py | 141 ++++++ torch_compile_ex_release.py | 217 +++++++++ yolo_fx_bad_metrics_repro.py | 86 ++++ 34 files changed, 2558 insertions(+), 15 deletions(-) create mode 100644 aa_torch_fx.py create mode 100644 nncf/experimental/torch_fx/__init__.py create mode 100644 nncf/experimental/torch_fx/engine.py create mode 100644 nncf/experimental/torch_fx/model_transformer.py create mode 100644 nncf/experimental/torch_fx/nncf_graph_builder.py create mode 100644 nncf/experimental/torch_fx/quantization/__init__.py create mode 100644 nncf/experimental/torch_fx/quantization/quantize_model.py create mode 100644 nncf/experimental/torch_fx/statistics/__init__.py create mode 100644 nncf/experimental/torch_fx/statistics/aggregator.py create mode 100644 nncf/experimental/torch_fx/transformations.py create mode 100644 nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py create mode 100644 nncf/quantization/algorithms/min_max/torch_fx_backend.py create mode 100644 tests/torch_fx/__init__.py create mode 100644 tests/torch_fx/helpers.py create mode 100644 tests/torch_fx/requirements.txt create mode 100644 tests/torch_fx/test_sanity.py create mode 100644 torch_compile_ex_release.py create mode 100644 yolo_fx_bad_metrics_repro.py diff --git a/aa_torch_fx.py b/aa_torch_fx.py new file mode 100644 index 00000000000..339d33d1598 --- /dev/null +++ b/aa_torch_fx.py @@ -0,0 +1,456 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import re +import subprocess +import time +import warnings +from itertools import islice +from pathlib import Path + +import numpy as np +import openvino as ov +import openvino.torch # noqa +import pandas as pd +import torch +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torchvision.models as models +from sklearn.metrics import accuracy_score +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.fx.passes.graph_drawer import FxGraphDrawer +from torch.jit import TracerWarning +from torchao.utils import benchmark_model as ao_benchmark_model +from torchvision import datasets +from transformers import AutoImageProcessor +from transformers import AutoModelForImageClassification + +import nncf +from nncf.common.logging.track_progress import track +from nncf.common.quantization.structs import QuantizationPreset # noqa +from nncf.parameters import ModelType +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching + +warnings.filterwarnings("ignore", category=TracerWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +DATASET_IMAGENET = "/home/dlyakhov/datasets/imagenet/val" + +hf_models = () + + +def hf_model_builder(model_id: str): + def build(weights): + processor = AutoImageProcessor.from_pretrained(model_id) + model = AutoModelForImageClassification.from_pretrained(model_id) + + class ModelWithProcessing(torch.nn.Module): + def __init__(self, processor, model): + super().__init__() + self.processor = processor + self.model = model + + def forward(self, x): + processed_input = processor(x, return_tensors="pt") + return model(processed_input) + + # return ModelWithProcessing(processor, model) + return model + + class DummyWeights: + def transforms(self): + return models.ResNet18_Weights.DEFAULT.transforms() + + @property + def meta(self): + return {} + + return build, DummyWeights() + + +MODELS_DICT = { + "vit_h_14": (models.vit_h_14, models.ViT_H_14_Weights.DEFAULT), + "vit_b_16": (models.vit_b_16, models.ViT_B_16_Weights.DEFAULT), + "swin_v2_t": (models.swin_v2_t, models.Swin_V2_T_Weights.DEFAULT), + "swin_v2_s": (models.swin_v2_s, models.Swin_V2_S_Weights.DEFAULT), + "resnet18": (models.resnet18, models.ResNet18_Weights.DEFAULT), + "resnet50": (models.resnet50, models.ResNet50_Weights.DEFAULT), + "mobilenet_v2": (models.mobilenet_v2, models.MobileNet_V2_Weights.DEFAULT), + "mobilenet_v3_small": (models.mobilenet_v3_small, models.MobileNet_V3_Small_Weights.DEFAULT), + "mobilenet_v3_large": (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT), + # "densenet161": (models.densenet161, models.DenseNet161_Weights.DEFAULT), + "vgg16": (models.vgg16, models.VGG16_Weights.DEFAULT), + "efficientnet_b7": (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT), + "inception_v3": (models.inception_v3, models.Inception_V3_Weights.DEFAULT), + "regnet_x_32gf": (models.regnet_x_32gf, models.RegNet_X_32GF_Weights.DEFAULT), + # "google/vit-base-patch16-224": hf_model_builder("google/vit-base-patch16-224"), + # "convnext_large": (models.convnext_large, models.ConvNeXt_Large_Weights.DEFAULT), + # "convnext_small": (models.convnext_small, models.ConvNeXt_Small_Weights.DEFAULT), +} + + +def measure_time(model, example_inputs, num_iters=1000): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def measure_time_ov(model, example_inputs, num_iters=1000): + ie = ov.Core() + compiled_model = ie.compile_model(model, "CPU") + infer_request = compiled_model.create_infer_request() + infer_request.infer(example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + infer_request.infer(example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def quantize(model, example_inputs, calibration_dataset, subset_size=300): + with torch.no_grad(): + exported_model = capture_pre_autograd_graph(model, example_inputs) + + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + from tqdm import tqdm + + for inp, _ in islice(tqdm(calibration_dataset), subset_size): + prepared_model(inp) + converted_model = convert_pt2e(prepared_model) + return converted_model + + +def validate(model, val_loader, subset_size=None): + dataset_size = len(val_loader) + + predictions = np.zeros((dataset_size)) + references = -1 * np.ones((dataset_size)) + + with track(total=dataset_size, description="Validation") as pbar: + + for i, (images, target) in enumerate(val_loader): + if subset_size is not None and i >= subset_size: + break + + output_data = model(images).detach().numpy() + predicted_label = np.argmax(output_data, axis=1) + predictions[i] = predicted_label.item() + references[i] = target + pbar.progress.update(pbar.task, advance=1) + acc_top1 = accuracy_score(predictions, references) * 100 + print(acc_top1) + return acc_top1 + + +def validate_ov(model, val_loader): + dataset_size = len(val_loader) + + # Initialize result tensors for async inference support. + predictions = np.zeros((dataset_size)) + references = -1 * np.ones((dataset_size)) + + core = ov.Core() + compiled_model = core.compile_model(model) + + infer_queue = ov.AsyncInferQueue(compiled_model, 4) + with track(total=dataset_size, description="Validation") as pbar: + + def process_result(request, userdata): + output_data = request.get_output_tensor().data + predicted_label = np.argmax(output_data, axis=1) + predictions[userdata] = predicted_label.item() + pbar.progress.update(pbar.task, advance=1) + + infer_queue.set_callback(process_result) + + for i, (images, target) in enumerate(val_loader): + # W/A for memory leaks when using torch DataLoader and OpenVINO + image_copies = copy.deepcopy(images.numpy()) + infer_queue.start_async(image_copies, userdata=i) + references[i] = target + + infer_queue.wait_all() + + acc_top1 = accuracy_score(predictions, references) * 100 + print(acc_top1) + return acc_top1 + + +def run_benchmark(model_path: Path, shape) -> float: + command = f"benchmark_app -m {model_path} -d CPU -api async -t 15" + command += f' -shape="[{",".join(str(x) for x in shape)}]"' + cmd_output = subprocess.check_output(command, shell=True) # nosec + match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output)) + return float(match.group(1)) + + +def torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + import torch + from torchao.quantization.smoothquant import smooth_fq_linear_to_inference + from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear + + # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor + torch._inductor.config.force_fuse_int_mm_with_mul = True + + # plug in your model + # model = torch.compile(pt_model) + model = pt_model + + # convert linear modules to smoothquant + # linear module in calibration mode + swap_linear_with_smooth_fq_linear(model) + + # Create a data loader for calibration + calibration_loader = val_loader + + # Calibrate the model + model.train() + from tqdm import tqdm + + for batch in tqdm(islice(calibration_loader, 300)): + inputs = batch[0] + model(inputs) + + # set it to inference mode + smooth_fq_linear_to_inference(model) + + # compile the model to improve performance + model = torch.compile(model, mode="max-autotune") + acc1_quant_model = validate(model, val_loader) + print(f"torch ao metric acc@1: {acc1_quant_model}") + result["torch_ao_quant_model_acc"] = acc1_quant_model + + latency = ao_benchmark_model(model, 20, example_input) + print(f"torch ao latency: {latency}") + result["torch_ao_quant_model_latency"] = latency + + +def nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + with disable_patching(): + with torch.no_grad(): + exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) + + def transform(x): + return x[0] + + quant_fx_model = nncf.quantize( + exported_model, nncf.Dataset(val_loader, transform_func=transform), model_type=ModelType.TRANSFORMER + ) + quant_compile_model = torch.compile(quant_fx_model, backend="openvino") + + # acc1_quant_model = validate(quant_compile_model, val_loader) + acc1_quant_model = -1.0 + latency_fx = measure_time(quant_compile_model, (example_input,)) + print(f"latency: {latency_fx}") + result["acc1_nncf_fx_quant_model"] = acc1_quant_model + result["torch_compile_ov_latency_nncf_fx_quant_model"] = latency_fx + + g = FxGraphDrawer(quant_compile_model, f"b_nncf_{pt_model.__class__.__name__}_int8") + g.get_dot_graph().write_svg(f"b_nncf_{pt_model.__class__.__name__}_int8.svg") + + # EXPORT TO OV + exported_model = torch.export.export(quant_compile_model, (example_input,)) + ov_quant_model = ov.convert_model(exported_model, example_input=example_input) + quant_file_path = output_dir / "quant.xml" + ov.save_model(ov_quant_model, quant_file_path) + + fps = run_benchmark(quant_file_path, shape_input) + print(f"fps: {fps}") + result["ov_fps_nncf_fx_quant_model"] = fps + + +def fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + with disable_patching(): + fp32_pt_model = copy.deepcopy(pt_model) + fp32_compile_model = torch.compile(fp32_pt_model, backend="openvino") + + quant_pt_model = quantize(fp32_compile_model, (example_input,), val_loader) + quant_compile_model = torch.compile(quant_pt_model, backend="openvino") + + g = FxGraphDrawer(quant_pt_model, f"b_pt_{pt_model.__class__.__name__}_int8") + g.get_dot_graph().write_svg(f"b_pt_{pt_model.__class__.__name__}_int8.svg") + + acc1_quant_model = validate(quant_compile_model, val_loader) + result["acc1_quant_model"] = acc1_quant_model + + latency_fx = measure_time(quant_compile_model, (example_input,)) + print(f"latency: {latency_fx}") + result["torch_compile_latency_fps_quant_model"] = latency_fx + + +def nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input): + def transform(x): + return x[0] + + nncf_model = nncf.quantize(copy.deepcopy(pt_model), nncf.Dataset(val_loader, transform_func=transform)) + + ov_nncf_model = ov.convert_model(nncf_model, example_input=example_input) + nncf_pt_file_path = output_dir / "nncf_pt.xml" + ov.save_model(ov_nncf_model, nncf_pt_file_path) + acc1_nncf_pt = validate_ov(ov_nncf_model, val_loader) + result["acc1_nncf_pt"] = acc1_nncf_pt + fps = run_benchmark(nncf_pt_file_path, shape_input) + print(f"fps: {fps}") + result["ov_fps_nncf_pt"] = fps + + +def nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input): + def transform(x): + return np.array(x[0]) + + from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters + from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters + + advanced_params = AdvancedQuantizationParameters() + # for sq_param in [-1, 0.15, 0.5, 0.75]: + for sq_param in [0.95]: + advanced_params.smooth_quant_alphas = AdvancedSmoothQuantParameters(matmul=sq_param) + + from copy import deepcopy + + fast_bias_correction = True + nncf_ov_int8_model = nncf.quantize( + deepcopy(ov_fp32_model), + nncf.Dataset(val_loader, transform_func=transform), + fast_bias_correction=fast_bias_correction, + model_type=ModelType.TRANSFORMER, + preset=QuantizationPreset.MIXED, + advanced_parameters=advanced_params, + ) + acc1_nncf_ov = validate_ov(nncf_ov_int8_model, val_loader) + result[f"acc1_nncf_ov_{sq_param}"] = acc1_nncf_ov + for precision, model in (("int8", nncf_ov_int8_model), ("fp32", ov_fp32_model)): + nncf_ov_file_path = output_dir / f"nncf_ov_{precision}.xml" + ov.save_model(model, nncf_ov_file_path) + fps = run_benchmark(nncf_ov_file_path, shape_input) + print(f"fps_{precision}: {fps} {sq_param}") + result[f"ov_fps_nncf_ov_{precision}_{sq_param}"] = fps + + latency = measure_time_ov(model, next(iter(val_loader))[0], num_iters=10_000) + print(f"latency_{precision}: {latency}") + result[f"ov_latency_nncf_ov_{precision}_{sq_param}"] = latency + + +def process_model(model_name: str): + + result = {"name": model_name} + model_cls, model_weights = MODELS_DICT[model_name] + output_dir = Path("models") / model_name + output_dir.mkdir(exist_ok=True) + ############################################################## + # Prepare dataset + ############################################################## + + val_dataset = datasets.ImageFolder(root=DATASET_IMAGENET, transform=model_weights.transforms()) + val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False) + + ############################################################## + # Prepare original model + ############################################################## + + pt_model = model_cls(weights=model_weights) + pt_model = pt_model.eval() + example_input = next(iter(val_loader))[0] + shape_input = list(example_input.shape) + ############################################################## + # Process FP32 Model + ############################################################## + + fp32_pt_model = copy.deepcopy(pt_model) + + orig_infer_acc1 = model_weights.meta.get("_metrics", {}).get("ImageNet-1K", {}).get("acc@1") + print(f"fp32 model metric: {orig_infer_acc1}") + # orig_infer_acc1 = validate(fp32_pt_model, val_loader) + result["acc1_fp32_openvino"] = orig_infer_acc1 + + fp32_pt_model = torch.export.export(fp32_pt_model, (example_input,)) + ov_fp32_model = ov.convert_model(fp32_pt_model, example_input=example_input) + ov_fp32_file_path = None + ov_fp32_file_path = output_dir / "fp32.xml" + ov.save_model(ov_fp32_model, ov_fp32_file_path) + # result["fps_fp32_openvino"] = run_benchmark(ov_fp32_file_path, shape_input) + # print(f"fps_fp32_openvino {result['fps_fp32_openvino']}") + + del fp32_pt_model + ############################################################## + # Process Torch AO Quantize with SQ + ############################################################## + # torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # with torch.no_grad(): + # exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) + # latency_fx = measure_time(torch.compile(exported_model), (example_input,)) + # print(f"latency: {latency_fx}") + ############################################################# + + ############################################################## + # Process PT Quantize + ############################################################## + fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # Process NNCF FX Quantize + ############################################################## + # nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # Process NNCF Quantize by PT + ############################################################## + # nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input) + + ############################################################## + # Process NNCF Quantize by OV + ############################################################## + # nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input) + + print(result) + return result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", help="torchvision model name", type=str, default="all") + parser.add_argument("--file_name", help="output csv file_name", type=str, default="result.csv") + + args = parser.parse_args() + + results_list = [] + if args.model == "all": + for model_name in MODELS_DICT: + print("---------------------------------------------------") + print(f"name: {model_name}") + results_list.append(process_model(model_name)) + else: + results_list.append(process_model(args.model)) + + df = pd.DataFrame(results_list) + print(df) + df.to_csv(args.file_name) + + +if __name__ == "__main__": + main() diff --git a/examples/llm_compression/openvino/tiny_llama/main.py b/examples/llm_compression/openvino/tiny_llama/main.py index dd03a4361c6..436f95f0af4 100644 --- a/examples/llm_compression/openvino/tiny_llama/main.py +++ b/examples/llm_compression/openvino/tiny_llama/main.py @@ -11,12 +11,12 @@ import time from functools import partial -import datasets import numpy as np import openvino as ov from optimum.intel.openvino import OVModelForCausalLM from transformers import AutoTokenizer +import datasets import nncf diff --git a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py index 7ab0176eb85..56acd0dd4f2 100644 --- a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py +++ b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py @@ -17,12 +17,12 @@ import numpy as np import openvino as ov -from datasets import load_dataset from optimum.intel import OVModelForCausalLM from transformers import AutoTokenizer from whowhatbench import Evaluator import nncf +from datasets import load_dataset from nncf.common.logging import nncf_logger DataItem = TypeVar("DataItem") diff --git a/nncf/common/factory.py b/nncf/common/factory.py index 6616f9dbe3a..d5d13605a07 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph: if model_backend == BackendType.OPENVINO: from nncf.openvino.graph.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: return model.nncf.get_graph() @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer: from nncf.torch.model_transformer import PTModelTransformer return PTModelTransformer(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + + return FXModelTransformer(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value) ) @@ -99,6 +107,10 @@ def create(model: TModel) -> Engine: from nncf.torch.engine import PTEngine return PTEngine(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.engine import FXEngine + + return FXEngine(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value) ) @@ -151,6 +163,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator: from nncf.torch.statistics.aggregator import PTStatisticsAggregator return PTStatisticsAggregator(dataset) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.statistics.aggregator import FXStatisticsAggregator + + return FXStatisticsAggregator(dataset) raise nncf.UnsupportedBackendError( "Cannot create backend-specific statistics aggregator because {} is not supported!".format( model_backend.value diff --git a/nncf/common/graph/patterns/manager.py b/nncf/common/graph/patterns/manager.py index 08bae0000af..e784f772628 100644 --- a/nncf/common/graph/patterns/manager.py +++ b/nncf/common/graph/patterns/manager.py @@ -52,6 +52,11 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict) return registry + if backend == BackendType.TORCH_FX: + from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS + + registry = PT_HW_FUSED_PATTERNS.registry_dict + return registry raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.") @staticmethod @@ -82,6 +87,11 @@ def _get_backend_ignored_patterns_map( registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict) return registry + if backend == BackendType.TORCH_FX: + from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS + + registry = PT_IGNORED_PATTERNS.registry_dict + return registry raise ValueError(f"Ignored patterns not implemented for {backend} backend.") @staticmethod diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index 7589aabb739..ee3f9b75768 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -20,6 +20,7 @@ class BackendType(Enum): TORCH = "Torch" + TORCH_FX = "TorchFX" TENSORFLOW = "Tensorflow" ONNX = "ONNX" OPENVINO = "OpenVINO" @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]: """ frameworks = [ ("torch", BackendType.TORCH), + ("torch.fx", BackendType.TORCH_FX), ("tensorflow", BackendType.TENSORFLOW), ("onnx", BackendType.ONNX), ("openvino.runtime", BackendType.OPENVINO), @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]: def is_torch_model(model: TModel) -> bool: """ - Returns True if the model is an instance of torch.nn.Module, otherwise False. + Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False. :param model: A target model. - :return: True if the model is an instance of torch.nn.Module, otherwise False. + :return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False. """ - import torch # type: ignore + import torch # type: ignore + import torch.fx - return isinstance(model, torch.nn.Module) + return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module) + + +def is_torch_fx_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of torch.fx.GraphModule, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of torch.fx.GraphModule, otherwise False. + """ + import torch.fx + + return isinstance(model, torch.fx.GraphModule) def is_tensorflow_model(model: TModel) -> bool: @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType: """ available_backends = get_available_backends() + if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model): + return BackendType.TORCH_FX + if BackendType.TORCH in available_backends and is_torch_model(model): return BackendType.TORCH diff --git a/nncf/experimental/torch_fx/__init__.py b/nncf/experimental/torch_fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/engine.py b/nncf/experimental/torch_fx/engine.py new file mode 100644 index 00000000000..5f9dc2ac221 --- /dev/null +++ b/nncf/experimental/torch_fx/engine.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Tuple, Union + +import torch +from torch import nn + +from nncf.common.engine import Engine + + +class FXEngine(Engine): + """ + Engine for the Pytorch FX backend. + """ + + def __init__(self, model: nn.Module): + """ + Constructor. + + :param model: Pytorch module to infer. + """ + + self._model = model + + def infer( + self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] + ) -> Union[torch.Tensor, Dict[str, Any]]: + """ + Runs Torch model on the provided input. + + :param input_data: Inputs for the model. + :return: Model outputs. + """ + + if isinstance(input_data, dict): + return self._model(**input_data) + if isinstance(input_data, tuple): + return self._model(*input_data) + return self._model(input_data) diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch_fx/model_transformer.py new file mode 100644 index 00000000000..48b3cf0c1f1 --- /dev/null +++ b/nncf/experimental/torch_fx/model_transformer.py @@ -0,0 +1,183 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import defaultdict + +# from functools import partial +from typing import Callable, List, Union + +import torch +import torch.fx +from torch.fx.passes.split_utils import split_by_tags + +from nncf.common.graph.model_transformer import ModelTransformer +from nncf.common.graph.transformations.commands import Command +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.commands import TransformationType +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.graph.transformations.layout import PTTransformationLayout + + +class FXModuleInsertionCommand(Command): + def __init__( + self, + target_points: List[PTTargetPoint], + module_to_insert: torch.nn.Module, + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, + ): + super().__init__(TransformationType.INSERT) + self.target_points = target_points + self.module_to_insert = module_to_insert + self.priority = priority + + +class FXApplyTransformationCommand(Command): + def __init__( + self, + transformation_fn: Callable[[torch.fx.GraphModule], None], + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, + ): + super().__init__(TransformationType.INSERT) + self.tranformation_fn = transformation_fn + self.priority = priority + + +class FXModelTransformer(ModelTransformer): + """ + Applies transformations upon Torch FX model. + """ + + # TODO: manage priorities of transformations + + def __init__(self, model: torch.fx.GraphModule): + super().__init__(model) + + self._command_transformation_ordered_pairs = [ + # TODO: Move the module insertion command to a transformation + (FXApplyTransformationCommand, self._apply_transformation), + (FXModuleInsertionCommand, self._apply_module_insertion), + (PTModelExtractionCommand, self._apply_model_extraction), + ] + + def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: + transformations = transformation_layout.transformations + aggregated_transformations = defaultdict(list) + for transformation in transformations: + aggregated_transformations[transformation.__class__].append(transformation) + + model = self._model + for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs: + transformations = aggregated_transformations[transformation_cls] + if transformations: + model = transformation_fn(model, transformations) + + # Do not eliminate dead code as + # the dead code is coputing statistics :) + # model.graph.eliminate_dead_code() + model.recompile() + return model + + @staticmethod + def _apply_model_extraction( + model: torch.fx.GraphModule, + transformations: List[PTModelExtractionCommand], + ) -> torch.fx.GraphModule: + transformation = transformations[-1] + assert len(transformation.input_node_names) == 1 + assert transformation.input_node_names == transformation.output_node_names + node_name = transformation.input_node_names[0] + + tags = ["before", "extracted", "after"] + i = 0 + for node in model.graph.nodes: + if node.name == node_name: + node.tag = tags[1] + weights = [node.all_input_nodes[1]] + while weights: + w_node = weights.pop() + assert w_node.tag in tags[0:2] + w_node.tag = tags[1] + weights.extend(w_node.all_input_nodes) + i = 2 + continue + node.tag = tags[i] + + splitted_gm = split_by_tags(model, tags) + return splitted_gm.extracted + + @staticmethod + def _apply_module_insertion( + model: torch.fx.GraphModule, + transformations: List[FXModuleInsertionCommand], + ) -> torch.fx.GraphModule: + """ + Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts + a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points. + + :param model: Model to apply transformations. + :param transformations: List of the bias correction transformations. + :param device: Target device for the insertion functions. Applies only to + functions which are subclassed from torch.nn.Module. Do nothing in case device is None. + :return: A modified torch.fx.GraphModule. + """ + for transformation in transformations: + # Set fn to the model as an attribute + module_to_insert = transformation.module_to_insert + module_name_in_model = ( + ";".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) + for tp in transformation.target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + # Insert call_module nodes to the model + for target_point in transformation.target_points: + FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model) + return model + + @staticmethod + def get_graph_node_by_name(graph, name): + for node in graph.nodes: + if node.name == name: + return node + raise RuntimeError(f"Node with name {name} is not found") + + @staticmethod + def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint): + target_type = target_point.target_type + target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) + if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: + target_node = target_node.all_input_nodes[target_point.input_port_id] + elif target_type == TargetType.OPERATOR_POST_HOOK: + pass + else: + raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") + return target_node + + @staticmethod + def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str): + target_node = FXModelTransformer._get_target_node(graph, target_point) + with graph.inserting_after(target_node): + graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node") + + @staticmethod + def _apply_transformation( + model: torch.fx.GraphModule, + transformations: List[FXApplyTransformationCommand], + ) -> torch.fx.GraphModule: + for transformation in transformations: + transformation.tranformation_fn(model) + return model diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch_fx/nncf_graph_builder.py new file mode 100644 index 00000000000..9990ee3bf2f --- /dev/null +++ b/nncf/experimental/torch_fx/nncf_graph_builder.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from itertools import chain +from typing import Tuple + +import torch.fx +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.layer_attributes import Dtype +from nncf.common.graph.operator_metatypes import UnknownMetatype +from nncf.common.logging import nncf_logger +from nncf.experimental.torch_fx.transformations import separate_conv_and_bias +from nncf.experimental.torch_fx.transformations import separate_linear_and_bias +from nncf.experimental.torch_fx.transformations import view_to_reshape +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES + + +class GraphConverter: + """ + Builds the NNCFGraph from an OpenVINO model. + """ + + @staticmethod + def _get_leaf_node(module: torch.nn.Module, node: torch.fx.Node) -> torch.nn.Module: + py_obj = module + assert isinstance(node.target, str) + atoms = node.target.split(".") + for atom in atoms: + if not hasattr(py_obj, atom): + raise RuntimeError(str(py_obj) + " does not have attribute " + atom + "!") + py_obj = getattr(py_obj, atom) + return py_obj + + @staticmethod + def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]: + if node.op == "placeholder": + node_type = "input" + node_metatype = om.PTInputNoopMetatype + elif node.op == "output": + node_type = "output" + node_metatype = om.PTOutputNoopMetatype + elif node.op == "get_attr": + node_type = "get_attr" + node_metatype = om.PTConstNoopMetatype + elif node.op in ("call_function",): + if hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + elif node.target.__name__ == "getitem": + node_type = "__getitem__" + else: + # TODO: get correct nodes types from this nodes as well + node_type = str(node.target) + node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type) + else: + node_type = node.op + node_metatype = UnknownMetatype + if node_metatype is UnknownMetatype: + nncf_logger.info(f"Unknown metatype for node: {node}") + return node_type, node_metatype + + @staticmethod + def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph: + """ + Creates NNCFGraph from GraphModule. + All nodes from model which have valid metatype are added to NNCFGraph. + Then, corresponding edges are added to the NNCFGraph with shape, type, output and input port ids. + + :param model: torch fx GraphModule. + :return: NNCFGraph. + """ + + _fuse_conv_bn_(model) + # BN fuses to conv bias, conv+bias joined op + # needs to be splited for nncf + separate_linear_and_bias(model) + separate_conv_and_bias(model) + view_to_reshape(model) + + nncf_graph = PTNNCFGraph() + + for source_node in model.graph.nodes: + + node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node) + + nncf_node = nncf_graph.add_nncf_node( + node_name=source_node.name, + node_type=node_type, + node_metatype=node_metatype, # layer_attributes, + ) + + def get_module_params_or_buffers(): + for pname, ptensor in chain(leaf_module.named_parameters(), leaf_module.named_buffers()): + pname1 = source_node.name + "." + pname + nncf_param_node = nncf_graph.add_nncf_node( + pname1, + "parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer", + om.PTConstNoopMetatype, + ) + # TODO: Use valid tensor_shape, input_port_id, output_port_id + nncf_graph.add_edge_between_nncf_nodes( + nncf_param_node, nncf_node, tensor_shape=[1, 1, 1, 1], input_port_id=0, output_port_id=0 + ) + + if source_node.op == "call_module": + leaf_module = GraphConverter._get_leaf_node(model, source_node) + + if not isinstance(leaf_module, torch.fx.GraphModule): + get_module_params_or_buffers() + + for source_node in model.graph.nodes: + + source_nncf_node = nncf_graph.get_node_by_name(source_node.name) + for idx, dist_node in enumerate(source_node.users): + dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id + input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params( + model, source_node, source_nncf_node, dist_node, idx + ) + + nncf_graph.add_edge_between_nncf_nodes( + source_nncf_node.node_id, + dist_node_id, + tensor_shape=tensor_shape, + input_port_id=input_port_id, + output_port_id=output_port_id, + dtype=Dtype.FLOAT, + ) + + return nncf_graph + + @staticmethod + def get_edge_params( + model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int + ): + output_port_id = 0 + if source_node.op in ("get_attr",): + tensor_shape = tuple(getattr(model, source_node.target).shape) + elif "val" in source_node.meta: + if source_nncf_node.metatype is om.PTBatchNormMetatype: + tensor = source_node.meta["val"][0] + elif source_nncf_node.metatype is om.PTSplitMetatype: + tensor = source_node.meta["val"][output_idx] + # Assume every split outputs corresponds to an unique output_port_id + output_port_id = output_idx + else: + tensor = source_node.meta["val"] + tensor_shape = tuple(tensor.shape) + else: + nncf_logger.info( + f"Edge shape between {source_node.name} and {dist_node.name} is unknown. Using [1,1,1,1] instead." + ) + tensor_shape = [1, 1, 1, 1] + + input_port_id = dist_node.all_input_nodes.index(source_node) + return input_port_id, output_port_id, tensor_shape diff --git a/nncf/experimental/torch_fx/quantization/__init__.py b/nncf/experimental/torch_fx/quantization/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/quantization/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch_fx/quantization/quantize_model.py new file mode 100644 index 00000000000..0f40800fb49 --- /dev/null +++ b/nncf/experimental/torch_fx/quantization/quantize_model.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Optional + +import torch +import torch.fx +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat +from torch.ao.quantization.pt2e.utils import _disallow_eval_train +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import PassManager + +import nncf +from nncf.common.factory import NNCFGraphFactory +from nncf.common.quantization.structs import QuantizationPreset +from nncf.common.quantization.structs import QuantizationScheme +from nncf.data import Dataset +from nncf.experimental.torch_fx.transformations import merge_conv_and_bias +from nncf.parameters import ModelType +from nncf.parameters import QuantizationMode +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.advanced_parameters import QuantizationParameters +from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.scopes import IgnoredScope + +DEFAULT_RANGE_TYPE = "mean_min_max" + + +def quantize_impl( + model: torch.fx.GraphModule, + calibration_dataset: Dataset, + mode: Optional[QuantizationMode] = None, + preset: Optional[QuantizationPreset] = None, + target_device: TargetDevice = TargetDevice.ANY, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + ignored_scope: Optional[IgnoredScope] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> torch.nn.Module: + """ + Implementation of the `quantize()` method for the Torch FX backend. + """ + if fast_bias_correction is False: + raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported") + if target_device == TargetDevice.CPU_SPR: + raise nncf.InternalError("target_device == CPU_SPR is not supported") + if mode is not None: + raise ValueError(f"mode={mode} is not supported") + + original_graph_meta = model.meta + + copied_model = deepcopy(model) + + if advanced_parameters is None: + advanced_parameters = AdvancedQuantizationParameters() + # torch.fx supports only assymetric activations quantization + # force to use only this type of quantization + activations_quantization_params = advanced_parameters.activations_quantization_params + if activations_quantization_params is None: + activations_quantization_params = QuantizationParameters() + + activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC + advanced_parameters.activations_quantization_params = activations_quantization_params + + quantization_algorithm = PostTrainingQuantization( + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) + nncf_graph = NNCFGraphFactory.create(copied_model) + quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) + merge_conv_and_bias(quantized_model) + + # Magic. Without this call compiled model + # is not preformant + quantized_model = GraphModule(quantized_model, quantized_model.graph) + + quantized_model = _fold_conv_bn_qat(quantized_model) + pm = PassManager([DuplicateDQPass()]) + + quantized_model = pm(quantized_model).graph_module + pm = PassManager([PortNodeMetaForQDQ()]) + quantized_model = pm(quantized_model).graph_module + + quantized_model.meta.update(original_graph_meta) + quantized_model = _disallow_eval_train(quantized_model) + + return quantized_model diff --git a/nncf/experimental/torch_fx/statistics/__init__.py b/nncf/experimental/torch_fx/statistics/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/statistics/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/statistics/aggregator.py b/nncf/experimental/torch_fx/statistics/aggregator.py new file mode 100644 index 00000000000..f1ce6ff05ec --- /dev/null +++ b/nncf/experimental/torch_fx/statistics/aggregator.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict + +import numpy as np +import torch + +from nncf.common.factory import TModel +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.layout import TransformationLayout +from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer +from nncf.common.tensor_statistics.aggregator import StatisticsAggregator +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch_fx.model_transformer import FXModuleInsertionCommand +from nncf.tensor import Tensor +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.return_types import maybe_get_values_from_torch_return_type + + +class TensorCollectorModule(torch.nn.Module): + """ + torch.nn.Module which calls given collector in forward + """ + + def __init__(self, collector: TensorCollector): + super().__init__() + self._collector = collector + + def forward(self, x: torch.Tensor): + """ + Register inputs hook function. + + :parameter x: tensor to register in hook. + :return: tensor to register in hook. + """ + x_unwrapped = maybe_get_values_from_torch_return_type(x) + self._collector.register_input_for_all_reducers(Tensor(x_unwrapped)) + return x + + +class FXStatisticsAggregator(StatisticsAggregator): + HOOKS_GROUP_NAME = "statistics_hooks" + + def collect_statistics(self, model: NNCFNetwork, graph: NNCFGraph) -> None: + with torch.no_grad(): + super().collect_statistics(model, graph) + # All statistics are collected as a dead code, + # so eliminate dead core removed statistcs collector + # from the target model. No additional code required + # for that, horay! + model.graph.eliminate_dead_code() + model.recompile() + + def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None: + return + + def _get_transformation_layout_extra_outputs( + self, statistic_points: StatisticPointsContainer + ) -> TransformationLayout: + transformation_layout = TransformationLayout() + transformation_commands = [] + + for _statistic_points in statistic_points.values(): + for _statistic_point in _statistic_points: + for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): + for collector in collectors: + transformation_commands.append( + FXModuleInsertionCommand( + [_statistic_point.target_point], + TensorCollectorModule(collector), + TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, + ) + ) + + for transformation_command in transformation_commands: + transformation_layout.register(transformation_command) + + return transformation_layout + + @staticmethod + def _get_merged_statistic_points( + statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph + ) -> StatisticPointsContainer: + # TODO: mirgate to experimental statistic collector and use common merging algorithm + return statistic_points + + @staticmethod + def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]: + return outputs diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch_fx/transformations.py new file mode 100644 index 00000000000..d572c06b120 --- /dev/null +++ b/nncf/experimental/torch_fx/transformations.py @@ -0,0 +1,350 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import Callable, List, Optional + +import torch +import torch.fx +from torch.ao.quantization.fx.utils import create_getattr_from_value +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node +from torch.ao.quantization.pt2e.utils import _is_conv +from torch.quantization.fake_quantize import FakeQuantize + +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.torch_fx.model_transformer import FXModelTransformer +from nncf.torch.graph.transformations.commands import PTTargetPoint + + +def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + def fake_quantize_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points) + graph = model.graph + for target_point in target_points: + target_node = FXModelTransformer._get_target_node(model.graph, target_point) + with graph.inserting_after(target_node): + fq_node = graph.create_node( + "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer" + ) + for user in list(target_node.users): + if user is fq_node: + continue + user.replace_input_with(target_node, fq_node) + + return fake_quantize_insertion_transformation + + +def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor): + def bias_update_transformation(model: torch.fx.GraphModule): + graph = model.graph + target_node_name = node.node_name + graph_node = FXModelTransformer.get_graph_node_by_name(graph, target_node_name) + bias_node = next(iter(graph_node.users)) + with graph.inserting_before(bias_node): + new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value) + args = list(bias_node.args) + args[1] = new_constant + bias_node.args = tuple(args) + graph.eliminate_dead_code() + + return bias_update_transformation + + +def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + def qdq_insertion_tranformation(model: torch.fx.GraphModule): + if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: + raise RuntimeError + for target_point in target_points: + target_node = FXModelTransformer._get_target_node(model.graph, target_point) + insert_one_qdq(model, target_node, quantizer, target_point) + + return qdq_insertion_tranformation + + +def insert_one_qdq( + model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize, target_point: PTTargetPoint +): + # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e + # 1. extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Optional[Callable] = None + # scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8 + if quantizer.is_per_channel: + qparams = { + "_scale_": quantizer.scale, + "_zero_point_": quantizer.zero_point, + "_axis_": quantizer.ch_axis, + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default + else: + qparams = { + "_scale_": float(quantizer.scale), + "_zero_point_": int(quantizer.zero_point), + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + # 2. replace activation_post_process node with quantize and dequantize + graph = model.graph + # TODO: use metatype to get correct input_port_id + # Do not quantize already quantized nodes + # inserting_before handle only order in the graph generated code. + # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes + with graph.inserting_before(target_node): + quantize_op_inputs = [target_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store + # them as literals in the graph. + quantize_op_inputs.append(value_or_node) + with graph.inserting_after(target_node): + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + user_dq_nodes = [] + with graph.inserting_after(quantized_node): + for user in target_node.users: + if user is quantized_node: + continue + user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {}))) + + for user, dq_node in user_dq_nodes: + user.replace_input_with(target_node, dq_node) + + +def _set_module_to_the_graph_module( + model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> str: + """ + Sets given module to the given torch.fx.GraphModule with unique name. + """ + module_to_insert = module_to_insert + module_name_in_model = ( + ";".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + return module_name_in_model + + +def _is_linear(n: torch.fx.Node): + return n.op == "call_function" and n.target in [torch.ops.aten.linear.default] + + +def separate_linear_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined linear+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_linear(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + linear_node = n + linear_bias_node = linear_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) + args = list(n.args) + args[2] = None + linear_node.args = tuple(args) + with model.graph.inserting_after(linear_node): + new_linear_bias_node = create_getattr_from_value( + model, + model.graph, + linear_bias_node.name + "_", + conv_bias_value, + ) + with model.graph.inserting_after(new_linear_bias_node): + add_node = model.graph.create_node( + "call_function", add_node_target, (linear_node, new_linear_bias_node), {} + ) + for user in list(linear_node.users): + if user is add_node: + continue + user.replace_input_with(linear_node, add_node) + if "val" in linear_node.meta: + add_node.meta["val"] = linear_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def view_to_reshape(model: torch.fx.GraphModule): + for n in model.graph.nodes: + if not (n.op == "call_function" and n.target in [torch.ops.aten.view.default]): + continue + with model.graph.inserting_after(n): + reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {}) + reshape.meta = n.meta + + for user in list(n.users): + user.replace_input_with(n, reshape) + + model.graph.eliminate_dead_code() + model.recompile() + + +def separate_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_conv(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + conv_node = n + dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape) + conv_bias_node = conv_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model) + args = list(n.args) + args[2] = None + conv_node.args = tuple(args) + with model.graph.inserting_after(conv_node): + new_conv_bias_node = create_getattr_from_value( + model, + model.graph, + conv_bias_node.name + "_", + conv_bias_value.reshape( + ( + 1, + -1, + ) + + (1,) * (dims - 2) + ), + ) + with model.graph.inserting_after(new_conv_bias_node): + add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {}) + for user in list(conv_node.users): + if user is add_node: + continue + user.replace_input_with(conv_node, add_node) + + if "val" in conv_node.meta: + add_node.meta["val"] = conv_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def merge_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_targets = (torch.ops.aten.add_.Tensor,) + for n in model.graph.nodes: + if not _is_conv(n): + continue + if len(n.args) > 2 and n.args[2] is not None: + continue + bias_node = next(iter(n.users)) + if len(n.users) > 1 or bias_node.target not in add_node_targets: + continue + conv_node = n + const_node = None + for node in bias_node.all_input_nodes: + if node is not conv_node: + const_node = node + break + assert const_node is not None + bias_value = _get_tensor_constant_from_node(const_node, model).squeeze() + with model.graph.inserting_before(conv_node): + new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value) + args = list(conv_node.args) + args[2] = new_bias_node + conv_node.args = tuple(args) + for user in list(bias_node.users): + user.replace_input_with(bias_node, conv_node) + + model.graph.eliminate_dead_code() + model.recompile() + + +def _is_scaled_dot_product_attention(n: torch.fx.Node): + return n.op == "call_function" and n.target in [torch.ops.aten.scaled_dot_product_attention.default] + + +def _unfold_sdp(model: torch.fx.GraphModule, node: torch.fx.Node): + transpose_target = torch.ops.aten.transpose.int + matmul_target = torch.ops.aten.matmul.default + mul_target = torch.ops.aten.multiply.Scalar + softmax_target = torch.ops.aten.softmax.int + + query, key, value = node.args + q, k, v = (n.meta["val"] for n in node.args) + n = query.meta["val"].shape[-1] + scale_factor = 1 / math.sqrt(n) + + with model.graph.inserting_before(node): + k_transposed = model.graph.create_node("call_function", transpose_target, (key, -2, -1), {}) + k = k.transpose(-2, -1) + k_transposed.meta["val"] = torch.clone(k) + + sa = model.graph.create_node("call_function", matmul_target, (query, k_transposed), {}) + attn_value = q @ k + sa.meta["val"] = torch.clone(attn_value) + + sa_scaled = model.graph.create_node("call_function", mul_target, (sa, float(scale_factor)), {}) + sa_scaled.meta["val"] = torch.clone(attn_value) + + softmax = model.graph.create_node("call_function", softmax_target, (sa_scaled, -1), {}) + softmax.meta["val"] = torch.clone(attn_value) + + result = model.graph.create_node("call_function", matmul_target, (softmax, value), {}) + r = attn_value @ v + result.meta["val"] = torch.clone(r) + + for user in list(node.users): + user.replace_input_with(node, result) + model.graph.eliminate_dead_code() + + +@staticmethod +def unfold_scaled_dot_product_attention(model: torch.fx.GraphModule): + for n in model.graph.nodes: + if not _is_scaled_dot_product_attention(n): + continue + args = n.args + if len(args) > 3: + raise NotImplementedError( + f"Unfolding of scaled dot product attention node {n}" " with more than 3 inputs is not implemented yet" + ) + _unfold_sdp(model, n) + model.graph.eliminate_dead_code() + model.recompile() diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py index ff7836035c9..3d104cad3c9 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py +++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py @@ -93,7 +93,7 @@ def __init__( @property def available_backends(self) -> List[BackendType]: - return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _set_backend_entity(self, model: TModel) -> None: """ @@ -116,6 +116,12 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend self._backend_entity = PTFastBiasCorrectionAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import ( + FXFastBiasCorrectionAlgoBackend, + ) + + self._backend_entity = FXFastBiasCorrectionAlgoBackend() else: raise nncf.UnsupportedBackendError( "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py new file mode 100644 index 00000000000..089afd4ab11 --- /dev/null +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.fx +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder +from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend +from nncf.tensor import Tensor +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector + + +class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_bias_correction_command( + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph + ) -> FXApplyTransformationCommand: + return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data)) + + @staticmethod + def model_extraction_command( + input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]] + ) -> PTModelExtractionCommand: + return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]]) + + @staticmethod + def mean_statistic_collector( + channel_axis: int, + inplace: bool, + num_samples: Optional[int] = None, + window_size: Optional[int] = None, + ) -> TensorCollector: + return get_mean_statistic_collector(num_samples, channel_axis, window_size) + + @staticmethod + def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]: + # Pytorch does not have name for extracted node + return None, None + + @staticmethod + def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor: + blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device) + for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])): + index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim)) + blob[index] = data[j].data + return blob + + @staticmethod + def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: + # TODO: make a node_name_vs_node map to speed up the process + from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + + bias_node = nncf_graph.get_next_nodes(node)[0] + graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name) + return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model)) + + @staticmethod + def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]: + return 0, 0 + + @staticmethod + def process_model_output(raw_data: Dict, output_name: str) -> Tensor: + return Tensor(raw_data) + + @staticmethod + def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + weight_node = nncf_graph.get_previous_nodes(node)[1] + return weight_node.node_type == "dequantize_per_channel" + + @staticmethod + def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: + # Assumes that all biases were unfused + if node.metatype in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype): + next_nodes = nncf_graph.get_next_nodes(node) + if len(next_nodes) != 1: + return False + return next_nodes[0].metatype in (om.PTAddMetatype,) + + @staticmethod + def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: + return node.node_name, node.node_name diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py index 2fefb664a18..f8cdd316529 100644 --- a/nncf/quantization/algorithms/min_max/algorithm.py +++ b/nncf/quantization/algorithms/min_max/algorithm.py @@ -328,7 +328,7 @@ def _init_cache(self) -> None: @property def available_backends(self) -> List[BackendType]: - return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH] + return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX] def _get_quantizer_constraints( self, @@ -381,6 +381,10 @@ def _set_backend_entity(self, model: TModel) -> None: from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend self._backend_entity = OVMinMaxAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend + + self._backend_entity = FXMinMaxAlgoBackend() elif model_backend == BackendType.TORCH: from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py new file mode 100644 index 00000000000..c5403386441 --- /dev/null +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -0,0 +1,351 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Set, Tuple + +import torch + +import nncf +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.hardware.config import HWConfig +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder +from nncf.parameters import ModelType +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import StatisticsType +from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend +from nncf.quantization.fake_quantize import FakeConvertParameters +from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import AggregatorType +from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.graph import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.hardware.config import PTHWConfig +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT +from nncf.torch.quantization.layers import QUANTIZATION_MODULES +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import get_scale_shape +from nncf.torch.quantization.strip import convert_to_torch_fakequantizer +from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP + + +class FXMinMaxAlgoBackend(MinMaxAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @property + def mat_mul_metatypes(self) -> List[OperatorMetatype]: + return [om.PTLinearMetatype, om.PTMatMulMetatype] + + @property + def post_processing_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def shapeof_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def dropout_metatypes(self) -> List[OperatorMetatype]: + return [om.PTDropoutMetatype] + + @property + def read_variable_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def conv_metatypes(self) -> List[OperatorMetatype]: + return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype] + + @property + def overflow_fix_metatypes(self) -> List[OperatorMetatype]: + return [ + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTLinearMetatype, + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, + ] + + @property + def add_metatypes(self) -> List[OperatorMetatype]: + return [om.PTAddMetatype] + + @property + def group_conv_metatypes(self) -> List[OperatorMetatype]: + return self.conv_metatypes + + @property + def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: + return {om.PTCatMetatype: self.overflow_fix_metatypes} + + @property + def hw_config(self) -> HWConfig: + return PTHWConfig + + @property + def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: + return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT + + @staticmethod + def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[NNCFNode]: + return nncf_graph.get_input_nodes() + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_convert_insertion_command( + target_point: PTTargetPoint, + parameters: FakeConvertParameters, + ) -> TransformationCommand: + raise nncf.InternalError("FakeConvert insertion not implemented in PyTorch backend!") + + @staticmethod + def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]: + return nncf_graph.get_input_shape_for_insertion_point(target_point) + + @staticmethod + def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int]: + # TODO: support transpose conv and other cases + return (0,) + + @staticmethod + def get_statistic_collector( + range_estimator_params: RangeEstimatorParameters, + use_abs_max: bool, + reduction_axes: Optional[Tuple[int, ...]], + aggregation_axes: Optional[Tuple[int, ...]], + inplace: bool, + num_samples: Optional[int] = None, + ) -> TensorCollector: + collector = TensorCollector(MinMaxTensorStatistic) + for params, container_key in zip( + [range_estimator_params.min, range_estimator_params.max], + [MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT], + ): + if params.statistics_type not in PT_REDUCERS_MAP: + raise nncf.InternalError( + f"Statistic type: {params.statistics_type} is not supported for Torch PTQ backend yet." + ) + + if params.aggregator_type not in AGGREGATORS_MAP: + raise nncf.InternalError( + f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." + ) + + statistic_type = params.statistics_type + if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: + # TODO(dlyakhov): merge two quantile aggregators in one + if container_key == MinMaxTensorStatistic.MIN_STAT: + quantile = params.quantile_outlier_prob + else: + quantile = 1 - params.quantile_outlier_prob + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes, quantile=[quantile]) + else: + if use_abs_max and statistic_type == StatisticsType.MAX: + statistic_type = StatisticsType.ABS_MAX + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) + + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + } + if params.aggregator_type in [AggregatorType.MEAN_NO_OUTLIERS, AggregatorType.MEDIAN_NO_OUTLIERS]: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) + + collector.register_statistic_branch(container_key, reducer, aggregator) + return collector + + @staticmethod + def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: + return node.metatype.weight_port_ids + + @staticmethod + def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: + weighted_node = nncf_graph.get_node_by_name(target_point.target_node_name) + weight = nncf_graph.get_previous_nodes(weighted_node)[target_point.input_port_id] + return weight.node_name + + @staticmethod + def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: + # If the nodes share one weight tensor, we should have only one quantizer on that + return weight_name not in quantized_weight_names + + @staticmethod + def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerConfig: + return config + + @staticmethod + def _get_input_scale_shape( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + is_weights = target_point.is_weight_target_point() + if is_weights: + # TODO: support transpose conv/ make channel_idx common + channel_idx = 0 + else: + channel_idx = 1 # channel dim for activations + + input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) + scale_shape = tuple( + get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) + ) + + return input_shape, scale_shape, channel_idx + + @staticmethod + def _create_quantizer( + quantizer_config: QuantizerConfig, + scale_shape: Tuple, + parameters: FakeQuantizeParameters, + target_type: TargetType, + ) -> BaseQuantizer: + mode = quantizer_config.mode + quantizer_cls = QUANTIZATION_MODULES.get(mode) + narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC + quantizer_spec = PTQuantizerSpec.from_config( + quantizer_config, + narrow_range=narrow_range, + scale_shape=scale_shape, + half_range=False, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=None, + ) + quantizer = quantizer_cls(quantizer_spec) + + # Fill it with minmax + FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) + # Convert to the torch fake quantizer + torch_fq = convert_to_torch_fakequantizer(quantizer) + return torch_fq + + @staticmethod + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: + if isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) + input_range = parameters.input_high - parameters.input_low + # Subtract eps from the input_range to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) + else: + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + # Subtract eps from the scale to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) + + @staticmethod + def create_quantizer_insertion_command( + nncf_graph: NNCFGraph, + target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> FXApplyTransformationCommand: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_point, quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_point.target_type + ) + transformation = qdq_insertion_tranformation_builder(quantizer, [target_point]) + return FXApplyTransformationCommand(transformation) + + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[PTTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[PTSharedFnInsertionCommand]: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_points[0], quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_points[0].target_type + ) + + transformations = [] + for tp in target_points: + transformation = qdq_insertion_tranformation_builder(quantizer, [tp]) + transformations.append(FXApplyTransformationCommand(transformation)) + return transformations + + @staticmethod + def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]: + types = [] + if model_type == ModelType.TRANSFORMER: + types = [ + om.PTAddMetatype, + om.PTPowerMetatype, + om.PTSubMetatype, + om.PTAvgPool2dMetatype, + om.PTAvgPool3dMetatype, + om.PTMeanMetatype, + om.PTSumMetatype, + om.PTReduceL2, + om.PTDivMetatype, + om.PTMaxMetatype, + om.PTSqueezeMetatype, + om.PTLayerNormMetatype, + om.PTModuleLayerNormMetatype, + om.PTGroupNormMetatype, + om.PTModuleGroupNormMetatype, + # Batchnorm + om.PTBatchNormMetatype, + om.PTModuleBatchNormMetatype, + ] + if device != TargetDevice.CPU_SPR: + types.append(om.PTMulMetatype) + return types + + @staticmethod + def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]: + return [] + + @staticmethod + def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + retval = set() + for node in nncf_graph.get_all_nodes(): + if node.metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype]: + retval.add(node) + return list(retval) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 65b804f10a8..166bae404ef 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -228,7 +228,21 @@ def quantize( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) + if backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.quantization.quantize_model import quantize_impl + return quantize_impl( + model=model, + calibration_dataset=calibration_dataset, + mode=mode, + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 5148496fea6..5d20a0d7ba6 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -371,7 +371,7 @@ def patch_torch_operators(): functions_to_patch = {} for namespace in NamespaceTarget: - if namespace == NamespaceTarget.EXTERNAL: + if namespace in [NamespaceTarget.ATEN, NamespaceTarget.EXTERNAL]: continue functions_to_patch[namespace] = get_all_functions_from_namespace(namespace) diff --git a/nncf/torch/dynamic_graph/structs.py b/nncf/torch/dynamic_graph/structs.py index c767790a92c..d8cf563107f 100644 --- a/nncf/torch/dynamic_graph/structs.py +++ b/nncf/torch/dynamic_graph/structs.py @@ -22,6 +22,7 @@ class NamespaceTarget(Enum): TORCH_TENSOR = "torch.tensor" TORCH_NN_PARAMETER = "torch.nn.parameter" TORCH = "torch" + ATEN = "aten" EXTERNAL = "external_function" diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index ca038b77f24..d8d8a1a50c8 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -56,6 +56,7 @@ class PTOperatorMetatype(OperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: [], NamespaceTarget.TORCH: [], + NamespaceTarget.ATEN: [], } subtypes: List[Type["PTOperatorMetatype"]] = [] @@ -528,7 +529,7 @@ class PTGELUMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTSILUMetatype(PTOperatorMetatype): name = "SiluOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"]} + module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"], NamespaceTarget.ATEN: ["silu_"]} @PT_OPERATOR_METATYPES.register() @@ -553,6 +554,7 @@ class PTAddMetatype(PTOperatorMetatype): "__radd__", ], NamespaceTarget.TORCH: ["add"], + NamespaceTarget.ATEN: ["add_"], } hw_config_names = [HWConfigOpName.ADD] num_expected_input_edges = 2 @@ -570,6 +572,7 @@ class PTSubMetatype(PTOperatorMetatype): "__rsub__", ], NamespaceTarget.TORCH: ["sub"], + NamespaceTarget.ATEN: ["sub_"], } hw_config_names = [HWConfigOpName.SUBTRACT] num_expected_input_edges = 2 @@ -706,6 +709,7 @@ class PTModuleBatchNormMetatype(PTModuleOperatorSubtype): name = "BatchNormOp" module_to_function_names = { NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], } @@ -714,6 +718,7 @@ class PTBatchNormMetatype(PTOperatorMetatype): name = "BatchNormOp" module_to_function_names = { NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], } subtypes = [PTModuleBatchNormMetatype] weight_port_ids = [3] @@ -844,6 +849,7 @@ class PTGatherMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"], NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"], + NamespaceTarget.ATEN: ["slice"], } @@ -880,6 +886,7 @@ class PTSplitMetatype(PTOperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: ["split", "chunk", "unbind"], NamespaceTarget.TORCH: ["split", "chunk", "unbind"], + NamespaceTarget.ATEN: ["split_with_sizes"], } hw_config_names = [HWConfigOpName.SPLIT, HWConfigOpName.CHUNK] @@ -1047,6 +1054,7 @@ class PTInterpolateMetatype(PTOperatorMetatype): name = "InterpolateOp" module_to_function_names = { NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"], + NamespaceTarget.ATEN: ["upsample_nearest2d", "upsample_nearest_exact2d"], } hw_config_names = [HWConfigOpName.INTERPOLATE] num_expected_input_edges = 1 diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py index 27479fe6a50..b31a7aae5ca 100644 --- a/tests/post_training/pipelines/lm_weight_compression.py +++ b/tests/post_training/pipelines/lm_weight_compression.py @@ -19,7 +19,6 @@ import numpy as np import openvino as ov import torch -from datasets import load_dataset from memory_profiler import memory_usage from optimum.exporters.openvino.convert import export_from_model from optimum.intel.openvino import OVModelForCausalLM @@ -28,6 +27,7 @@ from whowhatbench import Evaluator import nncf +from datasets import load_dataset from tests.post_training.pipelines.base import BackendType from tests.post_training.pipelines.base import BaseTestPipeline from tests.post_training.pipelines.base import StatsFromOutput diff --git a/tests/torch/sparsity/movement/helpers/run_recipe.py b/tests/torch/sparsity/movement/helpers/run_recipe.py index 77b3140a967..383552932d5 100644 --- a/tests/torch/sparsity/movement/helpers/run_recipe.py +++ b/tests/torch/sparsity/movement/helpers/run_recipe.py @@ -20,7 +20,6 @@ import torch.nn import torch.nn.functional as F import torch.utils.data -from datasets import Dataset from transformers import AutoModelForAudioClassification from transformers import AutoModelForImageClassification from transformers import AutoModelForSequenceClassification @@ -34,6 +33,7 @@ from transformers import SwinConfig from transformers import Wav2Vec2Config +from datasets import Dataset from nncf import NNCFConfig from nncf.experimental.torch.sparsity.movement.scheduler import MovementSchedulerParams from nncf.torch.dynamic_graph.io_handling import FillerInputElement diff --git a/tests/torch/sparsity/movement/helpers/trainer.py b/tests/torch/sparsity/movement/helpers/trainer.py index 89ffeb6c865..2af37c5b2f4 100644 --- a/tests/torch/sparsity/movement/helpers/trainer.py +++ b/tests/torch/sparsity/movement/helpers/trainer.py @@ -14,7 +14,6 @@ import numpy as np import torch -from datasets import Dataset # pylint: disable=no-name-in-module from transformers import TrainingArguments from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback @@ -22,6 +21,7 @@ from transformers.trainer_callback import TrainerState from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from datasets import Dataset # pylint: disable=no-name-in-module from nncf.api.compression import CompressionAlgorithmController from nncf.common.compression import BaseCompressionAlgorithmController from nncf.common.utils.tensorboard import prepare_for_tensorboard diff --git a/tests/torch/sparsity/movement/test_model_saving.py b/tests/torch/sparsity/movement/test_model_saving.py index 979b86a7b18..9d2401609fc 100644 --- a/tests/torch/sparsity/movement/test_model_saving.py +++ b/tests/torch/sparsity/movement/test_model_saving.py @@ -18,7 +18,6 @@ import pytest import torch from addict import Dict -from datasets import Dataset from onnx import numpy_helper from openvino._offline_transformations import apply_fused_names_cleanup from openvino._offline_transformations import apply_moc_transformations @@ -29,6 +28,7 @@ from scipy.special import softmax from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from datasets import Dataset from nncf.torch import create_compressed_model from nncf.torch.checkpoint_loading import load_state from tests.torch.helpers import PTTensorListComparator diff --git a/tests/torch/sparsity/movement/training_scripts/run_glue.py b/tests/torch/sparsity/movement/training_scripts/run_glue.py index 360832a5bb7..d0f5b14269e 100644 --- a/tests/torch/sparsity/movement/training_scripts/run_glue.py +++ b/tests/torch/sparsity/movement/training_scripts/run_glue.py @@ -12,12 +12,13 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple -import datasets import evaluate import jstyleson import numpy as np from transformers.training_args import ParallelMode +import datasets + # isort: off from nncf import NNCFConfig from nncf.api.compression import CompressionAlgorithmController diff --git a/tests/torch_fx/__init__.py b/tests/torch_fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/tests/torch_fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/torch_fx/helpers.py b/tests/torch_fx/helpers.py new file mode 100644 index 00000000000..8bbc721e0fa --- /dev/null +++ b/tests/torch_fx/helpers.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from fastdownload import FastDownload + + +class TinyImagenetDatasetManager: + DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip" + DATASET_PATH = "~/.cache/nncf/tests/datasets" + + def __init__(self, image_size: int, batch_size: int) -> None: + self.image_size = image_size + self.batch_size = batch_size + + @staticmethod + def download_dataset() -> Path: + downloader = FastDownload(base=TinyImagenetDatasetManager.DATASET_PATH, archive="downloaded", data="extracted") + return downloader.get(TinyImagenetDatasetManager.DATASET_URL) + + @staticmethod + def prepare_tiny_imagenet_200(dataset_dir: Path): + # Format validation set the same way as train set is formatted. + val_data_dir = dataset_dir / "val" + val_images_dir = val_data_dir / "images" + if not val_images_dir.exists(): + return + + val_annotations_file = val_data_dir / "val_annotations.txt" + with open(val_annotations_file, "r") as f: + val_annotation_data = map(lambda line: line.split("\t")[:2], f.readlines()) + for image_filename, image_label in val_annotation_data: + from_image_filepath = val_images_dir / image_filename + to_image_dir = val_data_dir / image_label + if not to_image_dir.exists(): + to_image_dir.mkdir() + to_image_filepath = to_image_dir / image_filename + from_image_filepath.rename(to_image_filepath) + val_annotations_file.unlink() + val_images_dir.rmdir() + + def create_data_loaders(self): + dataset_path = TinyImagenetDatasetManager.download_dataset() + + TinyImagenetDatasetManager.prepare_tiny_imagenet_200(dataset_path) + print(f"Successfully downloaded and prepared dataset at: {dataset_path}") + + train_dir = dataset_path / "train" + val_dir = dataset_path / "val" + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.ToTensor(), + normalize, + ] + ), + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True, sampler=None + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True + ) + + # Creating separate dataloader with batch size = 1 + # as dataloaders with batches > 1 are not supported yet. + calibration_dataset = torch.utils.data.DataLoader( + val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True + ) + + return train_loader, val_loader, calibration_dataset diff --git a/tests/torch_fx/requirements.txt b/tests/torch_fx/requirements.txt new file mode 100644 index 00000000000..99ee43ce754 --- /dev/null +++ b/tests/torch_fx/requirements.txt @@ -0,0 +1 @@ +fastdownload==0.0.7 \ No newline at end of file diff --git a/tests/torch_fx/test_sanity.py b/tests/torch_fx/test_sanity.py new file mode 100644 index 00000000000..197c2f95472 --- /dev/null +++ b/tests/torch_fx/test_sanity.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import openvino.torch # noqa +import pytest +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.models as models +from torch._export import capture_pre_autograd_graph + +import nncf +from nncf.common.logging.track_progress import track +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching +from tests.torch_fx.helpers import TinyImagenetDatasetManager + +IMAGE_SIZE = 64 +BATCH_SIZE = 128 + + +@dataclass +class SanitySampleCase: + model_id: str + checkpoint_url: str + top1_int8_ref: float + ref_num_q: int + ref_num_dq: int + + +MODELS = ( + SanitySampleCase( + "resnet18", + "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth", + 55.23, + 51, + 58, + ), +) + + +def get_model(model_id: str, checkpoint_url: str, device: torch.device) -> torch.nn.Module: + num_classes = 200 # 200 is for Tiny ImageNet, default is 1000 for ImageNet + model = getattr(models, model_id)(weights=None) + # Update the last FC layer for Tiny ImageNet number of classes. + model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True) + model.to(device) + checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device("cpu"), progress=False) + model.load_state_dict(checkpoint["state_dict"]) + return model + + +def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float: + top1_sum = 0.0 + with torch.no_grad(): + for images, target in track(val_loader, total=len(val_loader), description="Validation:"): + images = images.to(device) + target = target.to(device) + + # Compute output. + output = model(images) + + # Measure accuracy and record loss. + [acc1] = accuracy(output, target, topk=(1,)) + top1_sum += acc1.item() + + num_samples = len(val_loader) + top1_avg = top1_sum / num_samples + return top1_avg + + +def accuracy(output: torch.Tensor, target: torch.tensor, topk: Tuple[int, ...] = (1,)): + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def count_q_dq(model: torch.fx.GraphModule): + q, dq = 0, 0 + for node in model.graph.nodes: + if node.op == "call_function" and hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + if node_type in ["quantize_per_tensor", "quantize_per_channel"]: + q += 1 + elif node_type in ["dequantize_per_tensor", "dequantize_per_channel"]: + dq += 1 + return q, dq + + +@pytest.mark.parametrize("test_case", MODELS) +def test_sanity(test_case: SanitySampleCase): + with disable_patching(): + device = torch.device("cpu") + model = get_model(test_case.model_id, test_case.checkpoint_url, device) + _, val_dataloader, calibration_dataset = TinyImagenetDatasetManager( + IMAGE_SIZE, BATCH_SIZE + ).create_data_loaders() + + def transform_fn(data_item): + return data_item[0].to(device) + + calibration_dataset = nncf.Dataset(calibration_dataset, transform_fn) + + with torch.no_grad(): + ex_input = next(iter(calibration_dataset.get_inference_data())) + model.eval() + exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + quantized_model = torch.compile(quantized_model, backend="openvino") + + top1_int8 = validate(val_dataloader, quantized_model, device) + assert np.isclose(top1_int8, test_case.top1_int8_ref, atol=1e-2) + + num_q, num_dq = count_q_dq(quantized_model) + assert num_q == test_case.ref_num_q + assert num_dq == test_case.ref_num_dq diff --git a/torch_compile_ex_release.py b/torch_compile_ex_release.py new file mode 100644 index 00000000000..7bd0addf02e --- /dev/null +++ b/torch_compile_ex_release.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Enable torch inductor freezing feature first +import os + +os.environ["TORCHINDUCTOR_FREEZING"] = "1" + + +import argparse +import copy +import time +from collections import defaultdict + +import openvino.torch # noqa +import torch + +# Optional: using the C++ wrapper instead of default Python wrapper +import torch._inductor.config as config +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torchvision.models as models +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.fx.passes.graph_drawer import FxGraphDrawer + +from nncf.experimental.torch_fx.model_transformer import QPARAMPerChannel +from nncf.experimental.torch_fx.model_transformer import QPARAMSPerTensor +from nncf.experimental.torch_fx.model_transformer import insert_qdq_to_model +from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter # noqa + + +def get_exported_model_from_nn_module(module, example_inputs): + with torch.no_grad(): + return capture_pre_autograd_graph(module, example_inputs) + + +NNCF_IMPL = True + + +def get_qsetup(exported_model, example_inputs): + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + prepared_model(*example_inputs) + converted_model = convert_pt2e(prepared_model) + g = FxGraphDrawer(converted_model, "resnet18_int8") + g.get_dot_graph().write_svg("resnet18_int8_compiled.svg") + qsetup = defaultdict(lambda: dict()) + + for node in converted_model.graph.nodes: + if "dequantize" in node.name: + quantize = node.all_input_nodes[0] + # place = "activations" + # if len(quantize.all_input_nodes) > 1: + # place = "weights" + if "per_tensor" in node.name: + params = QPARAMSPerTensor(*node.args[1:]) + else: + params = [] + for i in range(1, 3): + name = node.args[i].target + params.append(getattr(converted_model, name)) + params = QPARAMPerChannel(*(params + list(node.args[3:]))) + + target_node_name = quantize.all_input_nodes[0].name + qsetup[target_node_name] = params + return qsetup + + +def quantize(model, example_inputs): + if NNCF_IMPL: + # Use NNCF here on exported model + # to create a quantized model which is compatible with + # convert_pt2e function + pass + # 1. Convert torch.graph to NNCFGraph. + # # 2. Analize nncf grpah for SQ/CA + # # 3. Collect statistics + # # 4. Update params + # 5. Analize nncf graph for quantization + # 6. Insert observers + # 7. prepared_model(*example_inputs) + # 8. convert_pt2e(prepared_model) + import nncf + + calibration_dataset = nncf.Dataset(example_inputs) + exported_model = get_exported_model_from_nn_module(model, example_inputs) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf") + g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg") + return quantized_model + + else: + # g = FxGraphDrawer(exported_model, "resnet18") + # g.get_dot_graph().write_svg("resnet18_compiled.svg") + + # MOCK NNCF QUANTIZATION + exported_model = get_exported_model_from_nn_module(model, example_inputs) + qsetup = get_qsetup(exported_model, example_inputs) + exported_model = get_exported_model_from_nn_module(model, example_inputs) + exported_model = insert_qdq_to_model(exported_model, qsetup) + g = FxGraphDrawer(exported_model, "resnet18_int8") + g.get_dot_graph().write_svg("resnet18_int8_compiled_manually.svg") + return exported_model + + return None # converted_model + + +config.cpp_wrapper = True + + +def measure_time(model, example_inputs, num_iters): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def get_dummy_dataset(): + traced_bs = 1 + x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last) + example_inputs = (x,) + return example_inputs + + +def main_nncf(model_name, num_iters): + model = models.__dict__[model_name](pretrained=True) + model = model.eval() + + example_inputs = get_dummy_dataset() + import nncf + + calibration_dataset = nncf.Dataset(example_inputs) + quantized_model = nncf.quantize(model, calibration_dataset) + + import openvino as ov + + ov_model = ov.convert_model(quantized_model.cpu(), example_input=example_inputs[0]) + ov.serialize(ov_model, "./model_cache_nncf/model.xml") + + +def main(model_name, num_iters): + model = models.__dict__[model_name](pretrained=True) + model = model.eval() + + example_inputs = get_dummy_dataset() + + converted_model = quantize(copy.deepcopy(model), example_inputs) + + print("original model execution time: ", measure_time(model, example_inputs, num_iters)) + + native_optimized_model_fp32 = torch.compile(model) + print( + "Torch Inductor FP32 model execution time: ", + measure_time(native_optimized_model_fp32, example_inputs, num_iters), + ) + + native_optimized_model_int8 = torch.compile(converted_model) + print( + "Torch Inductor INT8 model execution time: ", + measure_time(native_optimized_model_int8, example_inputs, num_iters), + ) + + ov_optimized_model_fp32 = torch.compile(model, backend="openvino") + print( + "Torch.compile OpenVINO FP32 model execution time: ", + measure_time(ov_optimized_model_fp32, example_inputs, num_iters), + ) + + ov_optimized_model_int8 = torch.compile( + converted_model, backend="openvino", options={"model_caching": True, "cache_dir": "./model_cache"} + ) + print( + "Torch.compile OpenVINO INT8 model execution time: ", + measure_time(ov_optimized_model_int8, example_inputs, num_iters), + ) + + import intel_extension_for_pytorch # noqa + + ipex_optimized_model_fp32 = torch.compile(model, backend="ipex") + print( + "Torch.compile IPEX FP32 model execution time: ", + measure_time(ipex_optimized_model_fp32, example_inputs, num_iters), + ) + + ipex_optimized_model_int8 = torch.compile(converted_model, backend="ipex") + print( + "Torch.compile IPEX INT8 model execution time: ", + measure_time(ipex_optimized_model_int8, example_inputs, num_iters), + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num_iters", help="number of inference iterations", type=int, default=100) + parser.add_argument("--model", help="torchvision model name", type=str, default="resnet18") + args = parser.parse_args() + model_name = args.model + num_iters = args.num_iters + main(model_name, num_iters) + # main_nncf(model_name, num_iters) diff --git a/yolo_fx_bad_metrics_repro.py b/yolo_fx_bad_metrics_repro.py new file mode 100644 index 00000000000..b5c05d6bbcb --- /dev/null +++ b/yolo_fx_bad_metrics_repro.py @@ -0,0 +1,86 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Tuple + +import numpy as np +import torch +from tqdm import tqdm +from ultralytics.data.utils import check_det_dataset +from ultralytics.engine.validator import BaseValidator as Validator +from ultralytics.models.yolo import YOLO +from ultralytics.utils.torch_utils import de_parallel + + +def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None: + mp, mr, map50, mean_ap = ( + stats["metrics/precision(B)"], + stats["metrics/recall(B)"], + stats["metrics/mAP50(B)"], + stats["metrics/mAP50-95(B)"], + ) + s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95") + print(s) + pf = "%20s" + "%12i" * 2 + "%12.3g" * 4 # print format + print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap)) + + +def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]: + # custom = {"rect": True, "batch": 1} # method defaults + # rect: false forces to resize all input pictures to one size + custom = {"rect": False, "batch": 1} # method defaults + args = {**model.overrides, **custom, "mode": "val"} # highest priority args on the right + + validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks) + stride = 32 # default stride + validator.stride = stride # used in get_dataloader() for padding + validator.data = check_det_dataset(data) + validator.init_metrics(de_parallel(model)) + + data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch) + return validator, data_loader + + +def validate(model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> Tuple[Dict, int, int]: + with torch.no_grad(): + for batch in data_loader: + batch = validator.preprocess(batch) + preds = model(batch["img"]) + preds = validator.postprocess(preds) + validator.update_metrics(preds, batch) + stats = validator.get_stats() + return stats, validator.seen, validator.nt_per_class.sum() + + +def main(torch_fx): + # ultralytics @ git+https://github.com/THU-MIG/yolov10.git@2c36ab0f108efdd17c7e290564bb845ccb6844d8 + # pip install git+https://github.com/THU-MIG/yolov10.git + # pip install huggingface-hub + # yolo_model = YOLO("yolov10n.pt") + + yolo_model = YOLO("yolov8n") + + model_type = "torch" + model = yolo_model.model + if torch_fx: + model = torch.compile(model) + model_type = "FX" + print(f"FP32 {model_type} model validation results:") + validator, data_loader = prepare_validation(yolo_model, "coco128.yaml") + stats, total_images, total_objects = validate(model, tqdm(data_loader), validator) + print_statistics(stats, total_images, total_objects) + + +if __name__ == "__main__": + print("Torch model:") + main(torch_fx=False) + print("Torch FX model:") + main(torch_fx=True) From 98b5a2539f080ae9f182bd97533a8919643cd280 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Tue, 25 Jun 2024 16:40:41 +0200 Subject: [PATCH 2/7] Test code is removed Comments --- aa_torch_fx.py | 456 ------------------ .../openvino/tiny_llama/main.py | 2 +- .../tiny_llama_find_hyperparams/main.py | 2 +- nncf/common/factory.py | 12 +- nncf/common/graph/patterns/manager.py | 14 +- .../{torch_fx => torch/fx}/__init__.py | 0 .../fx}/model_transformer.py | 81 +--- .../fx}/nncf_graph_builder.py | 86 ++-- .../fx}/quantization/__init__.py | 0 .../fx}/quantization/quantize_model.py | 33 +- .../fx}/statistics/__init__.py | 0 .../fx}/statistics/aggregator.py | 16 +- .../{torch_fx => torch/fx}/transformations.py | 220 ++++++--- nncf/experimental/torch_fx/engine.py | 48 -- .../fast_bias_correction/torch_fx_backend.py | 10 +- .../algorithms/min_max/torch_fx_backend.py | 13 +- nncf/quantization/quantize_model.py | 2 +- nncf/torch/engine.py | 5 +- nncf/torch/graph/operator_metatypes.py | 2 - .../pipelines/lm_weight_compression.py | 2 +- tests/{torch_fx => torch/fx}/__init__.py | 0 tests/{torch_fx => torch/fx}/helpers.py | 0 tests/{torch_fx => torch/fx}/test_sanity.py | 18 +- tests/torch/requirements.txt | 5 + .../sparsity/movement/helpers/run_recipe.py | 2 +- .../sparsity/movement/helpers/trainer.py | 2 +- .../sparsity/movement/test_model_saving.py | 2 +- .../movement/training_scripts/run_glue.py | 3 +- tests/torch_fx/requirements.txt | 1 - torch_compile_ex_release.py | 217 --------- yolo_fx_bad_metrics_repro.py | 86 ---- 31 files changed, 266 insertions(+), 1074 deletions(-) delete mode 100644 aa_torch_fx.py rename nncf/experimental/{torch_fx => torch/fx}/__init__.py (100%) rename nncf/experimental/{torch_fx => torch/fx}/model_transformer.py (53%) rename nncf/experimental/{torch_fx => torch/fx}/nncf_graph_builder.py (62%) rename nncf/experimental/{torch_fx => torch/fx}/quantization/__init__.py (100%) rename nncf/experimental/{torch_fx => torch/fx}/quantization/quantize_model.py (75%) rename nncf/experimental/{torch_fx => torch/fx}/statistics/__init__.py (100%) rename nncf/experimental/{torch_fx => torch/fx}/statistics/aggregator.py (83%) rename nncf/experimental/{torch_fx => torch/fx}/transformations.py (64%) delete mode 100644 nncf/experimental/torch_fx/engine.py rename tests/{torch_fx => torch/fx}/__init__.py (100%) rename tests/{torch_fx => torch/fx}/helpers.py (100%) rename tests/{torch_fx => torch/fx}/test_sanity.py (91%) delete mode 100644 tests/torch_fx/requirements.txt delete mode 100644 torch_compile_ex_release.py delete mode 100644 yolo_fx_bad_metrics_repro.py diff --git a/aa_torch_fx.py b/aa_torch_fx.py deleted file mode 100644 index 339d33d1598..00000000000 --- a/aa_torch_fx.py +++ /dev/null @@ -1,456 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import copy -import re -import subprocess -import time -import warnings -from itertools import islice -from pathlib import Path - -import numpy as np -import openvino as ov -import openvino.torch # noqa -import pandas as pd -import torch -import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -import torchvision.models as models -from sklearn.metrics import accuracy_score -from torch._export import capture_pre_autograd_graph -from torch.ao.quantization.quantize_pt2e import convert_pt2e -from torch.ao.quantization.quantize_pt2e import prepare_pt2e -from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer -from torch.fx.passes.graph_drawer import FxGraphDrawer -from torch.jit import TracerWarning -from torchao.utils import benchmark_model as ao_benchmark_model -from torchvision import datasets -from transformers import AutoImageProcessor -from transformers import AutoModelForImageClassification - -import nncf -from nncf.common.logging.track_progress import track -from nncf.common.quantization.structs import QuantizationPreset # noqa -from nncf.parameters import ModelType -from nncf.torch.dynamic_graph.patch_pytorch import disable_patching - -warnings.filterwarnings("ignore", category=TracerWarning) -warnings.filterwarnings("ignore", category=UserWarning) - -DATASET_IMAGENET = "/home/dlyakhov/datasets/imagenet/val" - -hf_models = () - - -def hf_model_builder(model_id: str): - def build(weights): - processor = AutoImageProcessor.from_pretrained(model_id) - model = AutoModelForImageClassification.from_pretrained(model_id) - - class ModelWithProcessing(torch.nn.Module): - def __init__(self, processor, model): - super().__init__() - self.processor = processor - self.model = model - - def forward(self, x): - processed_input = processor(x, return_tensors="pt") - return model(processed_input) - - # return ModelWithProcessing(processor, model) - return model - - class DummyWeights: - def transforms(self): - return models.ResNet18_Weights.DEFAULT.transforms() - - @property - def meta(self): - return {} - - return build, DummyWeights() - - -MODELS_DICT = { - "vit_h_14": (models.vit_h_14, models.ViT_H_14_Weights.DEFAULT), - "vit_b_16": (models.vit_b_16, models.ViT_B_16_Weights.DEFAULT), - "swin_v2_t": (models.swin_v2_t, models.Swin_V2_T_Weights.DEFAULT), - "swin_v2_s": (models.swin_v2_s, models.Swin_V2_S_Weights.DEFAULT), - "resnet18": (models.resnet18, models.ResNet18_Weights.DEFAULT), - "resnet50": (models.resnet50, models.ResNet50_Weights.DEFAULT), - "mobilenet_v2": (models.mobilenet_v2, models.MobileNet_V2_Weights.DEFAULT), - "mobilenet_v3_small": (models.mobilenet_v3_small, models.MobileNet_V3_Small_Weights.DEFAULT), - "mobilenet_v3_large": (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT), - # "densenet161": (models.densenet161, models.DenseNet161_Weights.DEFAULT), - "vgg16": (models.vgg16, models.VGG16_Weights.DEFAULT), - "efficientnet_b7": (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT), - "inception_v3": (models.inception_v3, models.Inception_V3_Weights.DEFAULT), - "regnet_x_32gf": (models.regnet_x_32gf, models.RegNet_X_32GF_Weights.DEFAULT), - # "google/vit-base-patch16-224": hf_model_builder("google/vit-base-patch16-224"), - # "convnext_large": (models.convnext_large, models.ConvNeXt_Large_Weights.DEFAULT), - # "convnext_small": (models.convnext_small, models.ConvNeXt_Small_Weights.DEFAULT), -} - - -def measure_time(model, example_inputs, num_iters=1000): - with torch.no_grad(): - model(*example_inputs) - total_time = 0 - for i in range(0, num_iters): - start_time = time.time() - model(*example_inputs) - total_time += time.time() - start_time - average_time = (total_time / num_iters) * 1000 - return average_time - - -def measure_time_ov(model, example_inputs, num_iters=1000): - ie = ov.Core() - compiled_model = ie.compile_model(model, "CPU") - infer_request = compiled_model.create_infer_request() - infer_request.infer(example_inputs) - total_time = 0 - for i in range(0, num_iters): - start_time = time.time() - infer_request.infer(example_inputs) - total_time += time.time() - start_time - average_time = (total_time / num_iters) * 1000 - return average_time - - -def quantize(model, example_inputs, calibration_dataset, subset_size=300): - with torch.no_grad(): - exported_model = capture_pre_autograd_graph(model, example_inputs) - - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - - prepared_model = prepare_pt2e(exported_model, quantizer) - from tqdm import tqdm - - for inp, _ in islice(tqdm(calibration_dataset), subset_size): - prepared_model(inp) - converted_model = convert_pt2e(prepared_model) - return converted_model - - -def validate(model, val_loader, subset_size=None): - dataset_size = len(val_loader) - - predictions = np.zeros((dataset_size)) - references = -1 * np.ones((dataset_size)) - - with track(total=dataset_size, description="Validation") as pbar: - - for i, (images, target) in enumerate(val_loader): - if subset_size is not None and i >= subset_size: - break - - output_data = model(images).detach().numpy() - predicted_label = np.argmax(output_data, axis=1) - predictions[i] = predicted_label.item() - references[i] = target - pbar.progress.update(pbar.task, advance=1) - acc_top1 = accuracy_score(predictions, references) * 100 - print(acc_top1) - return acc_top1 - - -def validate_ov(model, val_loader): - dataset_size = len(val_loader) - - # Initialize result tensors for async inference support. - predictions = np.zeros((dataset_size)) - references = -1 * np.ones((dataset_size)) - - core = ov.Core() - compiled_model = core.compile_model(model) - - infer_queue = ov.AsyncInferQueue(compiled_model, 4) - with track(total=dataset_size, description="Validation") as pbar: - - def process_result(request, userdata): - output_data = request.get_output_tensor().data - predicted_label = np.argmax(output_data, axis=1) - predictions[userdata] = predicted_label.item() - pbar.progress.update(pbar.task, advance=1) - - infer_queue.set_callback(process_result) - - for i, (images, target) in enumerate(val_loader): - # W/A for memory leaks when using torch DataLoader and OpenVINO - image_copies = copy.deepcopy(images.numpy()) - infer_queue.start_async(image_copies, userdata=i) - references[i] = target - - infer_queue.wait_all() - - acc_top1 = accuracy_score(predictions, references) * 100 - print(acc_top1) - return acc_top1 - - -def run_benchmark(model_path: Path, shape) -> float: - command = f"benchmark_app -m {model_path} -d CPU -api async -t 15" - command += f' -shape="[{",".join(str(x) for x in shape)}]"' - cmd_output = subprocess.check_output(command, shell=True) # nosec - match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output)) - return float(match.group(1)) - - -def torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): - import torch - from torchao.quantization.smoothquant import smooth_fq_linear_to_inference - from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear - - # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor - torch._inductor.config.force_fuse_int_mm_with_mul = True - - # plug in your model - # model = torch.compile(pt_model) - model = pt_model - - # convert linear modules to smoothquant - # linear module in calibration mode - swap_linear_with_smooth_fq_linear(model) - - # Create a data loader for calibration - calibration_loader = val_loader - - # Calibrate the model - model.train() - from tqdm import tqdm - - for batch in tqdm(islice(calibration_loader, 300)): - inputs = batch[0] - model(inputs) - - # set it to inference mode - smooth_fq_linear_to_inference(model) - - # compile the model to improve performance - model = torch.compile(model, mode="max-autotune") - acc1_quant_model = validate(model, val_loader) - print(f"torch ao metric acc@1: {acc1_quant_model}") - result["torch_ao_quant_model_acc"] = acc1_quant_model - - latency = ao_benchmark_model(model, 20, example_input) - print(f"torch ao latency: {latency}") - result["torch_ao_quant_model_latency"] = latency - - -def nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): - with disable_patching(): - with torch.no_grad(): - exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) - - def transform(x): - return x[0] - - quant_fx_model = nncf.quantize( - exported_model, nncf.Dataset(val_loader, transform_func=transform), model_type=ModelType.TRANSFORMER - ) - quant_compile_model = torch.compile(quant_fx_model, backend="openvino") - - # acc1_quant_model = validate(quant_compile_model, val_loader) - acc1_quant_model = -1.0 - latency_fx = measure_time(quant_compile_model, (example_input,)) - print(f"latency: {latency_fx}") - result["acc1_nncf_fx_quant_model"] = acc1_quant_model - result["torch_compile_ov_latency_nncf_fx_quant_model"] = latency_fx - - g = FxGraphDrawer(quant_compile_model, f"b_nncf_{pt_model.__class__.__name__}_int8") - g.get_dot_graph().write_svg(f"b_nncf_{pt_model.__class__.__name__}_int8.svg") - - # EXPORT TO OV - exported_model = torch.export.export(quant_compile_model, (example_input,)) - ov_quant_model = ov.convert_model(exported_model, example_input=example_input) - quant_file_path = output_dir / "quant.xml" - ov.save_model(ov_quant_model, quant_file_path) - - fps = run_benchmark(quant_file_path, shape_input) - print(f"fps: {fps}") - result["ov_fps_nncf_fx_quant_model"] = fps - - -def fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): - with disable_patching(): - fp32_pt_model = copy.deepcopy(pt_model) - fp32_compile_model = torch.compile(fp32_pt_model, backend="openvino") - - quant_pt_model = quantize(fp32_compile_model, (example_input,), val_loader) - quant_compile_model = torch.compile(quant_pt_model, backend="openvino") - - g = FxGraphDrawer(quant_pt_model, f"b_pt_{pt_model.__class__.__name__}_int8") - g.get_dot_graph().write_svg(f"b_pt_{pt_model.__class__.__name__}_int8.svg") - - acc1_quant_model = validate(quant_compile_model, val_loader) - result["acc1_quant_model"] = acc1_quant_model - - latency_fx = measure_time(quant_compile_model, (example_input,)) - print(f"latency: {latency_fx}") - result["torch_compile_latency_fps_quant_model"] = latency_fx - - -def nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input): - def transform(x): - return x[0] - - nncf_model = nncf.quantize(copy.deepcopy(pt_model), nncf.Dataset(val_loader, transform_func=transform)) - - ov_nncf_model = ov.convert_model(nncf_model, example_input=example_input) - nncf_pt_file_path = output_dir / "nncf_pt.xml" - ov.save_model(ov_nncf_model, nncf_pt_file_path) - acc1_nncf_pt = validate_ov(ov_nncf_model, val_loader) - result["acc1_nncf_pt"] = acc1_nncf_pt - fps = run_benchmark(nncf_pt_file_path, shape_input) - print(f"fps: {fps}") - result["ov_fps_nncf_pt"] = fps - - -def nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input): - def transform(x): - return np.array(x[0]) - - from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters - from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters - - advanced_params = AdvancedQuantizationParameters() - # for sq_param in [-1, 0.15, 0.5, 0.75]: - for sq_param in [0.95]: - advanced_params.smooth_quant_alphas = AdvancedSmoothQuantParameters(matmul=sq_param) - - from copy import deepcopy - - fast_bias_correction = True - nncf_ov_int8_model = nncf.quantize( - deepcopy(ov_fp32_model), - nncf.Dataset(val_loader, transform_func=transform), - fast_bias_correction=fast_bias_correction, - model_type=ModelType.TRANSFORMER, - preset=QuantizationPreset.MIXED, - advanced_parameters=advanced_params, - ) - acc1_nncf_ov = validate_ov(nncf_ov_int8_model, val_loader) - result[f"acc1_nncf_ov_{sq_param}"] = acc1_nncf_ov - for precision, model in (("int8", nncf_ov_int8_model), ("fp32", ov_fp32_model)): - nncf_ov_file_path = output_dir / f"nncf_ov_{precision}.xml" - ov.save_model(model, nncf_ov_file_path) - fps = run_benchmark(nncf_ov_file_path, shape_input) - print(f"fps_{precision}: {fps} {sq_param}") - result[f"ov_fps_nncf_ov_{precision}_{sq_param}"] = fps - - latency = measure_time_ov(model, next(iter(val_loader))[0], num_iters=10_000) - print(f"latency_{precision}: {latency}") - result[f"ov_latency_nncf_ov_{precision}_{sq_param}"] = latency - - -def process_model(model_name: str): - - result = {"name": model_name} - model_cls, model_weights = MODELS_DICT[model_name] - output_dir = Path("models") / model_name - output_dir.mkdir(exist_ok=True) - ############################################################## - # Prepare dataset - ############################################################## - - val_dataset = datasets.ImageFolder(root=DATASET_IMAGENET, transform=model_weights.transforms()) - val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False) - - ############################################################## - # Prepare original model - ############################################################## - - pt_model = model_cls(weights=model_weights) - pt_model = pt_model.eval() - example_input = next(iter(val_loader))[0] - shape_input = list(example_input.shape) - ############################################################## - # Process FP32 Model - ############################################################## - - fp32_pt_model = copy.deepcopy(pt_model) - - orig_infer_acc1 = model_weights.meta.get("_metrics", {}).get("ImageNet-1K", {}).get("acc@1") - print(f"fp32 model metric: {orig_infer_acc1}") - # orig_infer_acc1 = validate(fp32_pt_model, val_loader) - result["acc1_fp32_openvino"] = orig_infer_acc1 - - fp32_pt_model = torch.export.export(fp32_pt_model, (example_input,)) - ov_fp32_model = ov.convert_model(fp32_pt_model, example_input=example_input) - ov_fp32_file_path = None - ov_fp32_file_path = output_dir / "fp32.xml" - ov.save_model(ov_fp32_model, ov_fp32_file_path) - # result["fps_fp32_openvino"] = run_benchmark(ov_fp32_file_path, shape_input) - # print(f"fps_fp32_openvino {result['fps_fp32_openvino']}") - - del fp32_pt_model - ############################################################## - # Process Torch AO Quantize with SQ - ############################################################## - # torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) - - ############################################################## - # with torch.no_grad(): - # exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) - # latency_fx = measure_time(torch.compile(exported_model), (example_input,)) - # print(f"latency: {latency_fx}") - ############################################################# - - ############################################################## - # Process PT Quantize - ############################################################## - fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) - - ############################################################## - # Process NNCF FX Quantize - ############################################################## - # nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) - - ############################################################## - # Process NNCF Quantize by PT - ############################################################## - # nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input) - - ############################################################## - # Process NNCF Quantize by OV - ############################################################## - # nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input) - - print(result) - return result - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--model", help="torchvision model name", type=str, default="all") - parser.add_argument("--file_name", help="output csv file_name", type=str, default="result.csv") - - args = parser.parse_args() - - results_list = [] - if args.model == "all": - for model_name in MODELS_DICT: - print("---------------------------------------------------") - print(f"name: {model_name}") - results_list.append(process_model(model_name)) - else: - results_list.append(process_model(args.model)) - - df = pd.DataFrame(results_list) - print(df) - df.to_csv(args.file_name) - - -if __name__ == "__main__": - main() diff --git a/examples/llm_compression/openvino/tiny_llama/main.py b/examples/llm_compression/openvino/tiny_llama/main.py index 436f95f0af4..dd03a4361c6 100644 --- a/examples/llm_compression/openvino/tiny_llama/main.py +++ b/examples/llm_compression/openvino/tiny_llama/main.py @@ -11,12 +11,12 @@ import time from functools import partial +import datasets import numpy as np import openvino as ov from optimum.intel.openvino import OVModelForCausalLM from transformers import AutoTokenizer -import datasets import nncf diff --git a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py index 56acd0dd4f2..7ab0176eb85 100644 --- a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py +++ b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py @@ -17,12 +17,12 @@ import numpy as np import openvino as ov +from datasets import load_dataset from optimum.intel import OVModelForCausalLM from transformers import AutoTokenizer from whowhatbench import Evaluator import nncf -from datasets import load_dataset from nncf.common.logging import nncf_logger DataItem = TypeVar("DataItem") diff --git a/nncf/common/factory.py b/nncf/common/factory.py index d5d13605a07..c5a921c8068 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -43,7 +43,7 @@ def create(model: TModel) -> NNCFGraph: return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH_FX: - from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter + from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: @@ -77,7 +77,7 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer: return PTModelTransformer(model) if model_backend == BackendType.TORCH_FX: - from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + from nncf.experimental.torch.fx.model_transformer import FXModelTransformer return FXModelTransformer(model) raise nncf.UnsupportedBackendError( @@ -103,14 +103,10 @@ def create(model: TModel) -> Engine: from nncf.openvino.engine import OVNativeEngine return OVNativeEngine(model) - if model_backend == BackendType.TORCH: + if model_backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.engine import PTEngine return PTEngine(model) - if model_backend == BackendType.TORCH_FX: - from nncf.experimental.torch_fx.engine import FXEngine - - return FXEngine(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value) ) @@ -164,7 +160,7 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator: return PTStatisticsAggregator(dataset) if model_backend == BackendType.TORCH_FX: - from nncf.experimental.torch_fx.statistics.aggregator import FXStatisticsAggregator + from nncf.experimental.torch.fx.statistics.aggregator import FXStatisticsAggregator return FXStatisticsAggregator(dataset) raise nncf.UnsupportedBackendError( diff --git a/nncf/common/graph/patterns/manager.py b/nncf/common/graph/patterns/manager.py index e784f772628..2c32e3abf56 100644 --- a/nncf/common/graph/patterns/manager.py +++ b/nncf/common/graph/patterns/manager.py @@ -47,16 +47,11 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam Dict[HWFusedPatternNames, Callable[[], GraphPattern]], OPENVINO_HW_FUSED_PATTERNS.registry_dict ) return registry - if backend == BackendType.TORCH: + if backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS registry = cast(Dict[HWFusedPatternNames, Callable[[], GraphPattern]], PT_HW_FUSED_PATTERNS.registry_dict) return registry - if backend == BackendType.TORCH_FX: - from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS - - registry = PT_HW_FUSED_PATTERNS.registry_dict - return registry raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.") @staticmethod @@ -82,16 +77,11 @@ def _get_backend_ignored_patterns_map( Dict[IgnoredPatternNames, Callable[[], GraphPattern]], OPENVINO_IGNORED_PATTERNS.registry_dict ) return registry - if backend == BackendType.TORCH: + if backend in (BackendType.TORCH, BackendType.TORCH_FX): from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS registry = cast(Dict[IgnoredPatternNames, Callable[[], GraphPattern]], PT_IGNORED_PATTERNS.registry_dict) return registry - if backend == BackendType.TORCH_FX: - from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS - - registry = PT_IGNORED_PATTERNS.registry_dict - return registry raise ValueError(f"Ignored patterns not implemented for {backend} backend.") @staticmethod diff --git a/nncf/experimental/torch_fx/__init__.py b/nncf/experimental/torch/fx/__init__.py similarity index 100% rename from nncf/experimental/torch_fx/__init__.py rename to nncf/experimental/torch/fx/__init__.py diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py similarity index 53% rename from nncf/experimental/torch_fx/model_transformer.py rename to nncf/experimental/torch/fx/model_transformer.py index 48b3cf0c1f1..b4db5ed4fa7 100644 --- a/nncf/experimental/torch_fx/model_transformer.py +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -10,8 +10,6 @@ # limitations under the License. from collections import defaultdict - -# from functools import partial from typing import Callable, List, Union import torch @@ -20,27 +18,12 @@ from nncf.common.graph.model_transformer import ModelTransformer from nncf.common.graph.transformations.commands import Command -from nncf.common.graph.transformations.commands import TargetType from nncf.common.graph.transformations.commands import TransformationPriority from nncf.common.graph.transformations.commands import TransformationType from nncf.torch.graph.transformations.commands import PTModelExtractionCommand -from nncf.torch.graph.transformations.commands import PTTargetPoint from nncf.torch.graph.transformations.layout import PTTransformationLayout -class FXModuleInsertionCommand(Command): - def __init__( - self, - target_points: List[PTTargetPoint], - module_to_insert: torch.nn.Module, - priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, - ): - super().__init__(TransformationType.INSERT) - self.target_points = target_points - self.module_to_insert = module_to_insert - self.priority = priority - - class FXApplyTransformationCommand(Command): def __init__( self, @@ -57,19 +40,16 @@ class FXModelTransformer(ModelTransformer): Applies transformations upon Torch FX model. """ - # TODO: manage priorities of transformations - def __init__(self, model: torch.fx.GraphModule): super().__init__(model) self._command_transformation_ordered_pairs = [ - # TODO: Move the module insertion command to a transformation (FXApplyTransformationCommand, self._apply_transformation), - (FXModuleInsertionCommand, self._apply_module_insertion), (PTModelExtractionCommand, self._apply_model_extraction), ] def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: + # TODO(dlyakhov): Manage priorities of transformations. transformations = transformation_layout.transformations aggregated_transformations = defaultdict(list) for transformation in transformations: @@ -81,9 +61,9 @@ def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.G if transformations: model = transformation_fn(model, transformations) - # Do not eliminate dead code as - # the dead code is coputing statistics :) - # model.graph.eliminate_dead_code() + # Do not use model.graph.eliminate_dead_code() + # because the computational statistics code + # is interpolated as dead code. model.recompile() return model @@ -116,63 +96,12 @@ def _apply_model_extraction( return splitted_gm.extracted @staticmethod - def _apply_module_insertion( - model: torch.fx.GraphModule, - transformations: List[FXModuleInsertionCommand], - ) -> torch.fx.GraphModule: - """ - Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts - a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points. - - :param model: Model to apply transformations. - :param transformations: List of the bias correction transformations. - :param device: Target device for the insertion functions. Applies only to - functions which are subclassed from torch.nn.Module. Do nothing in case device is None. - :return: A modified torch.fx.GraphModule. - """ - for transformation in transformations: - # Set fn to the model as an attribute - module_to_insert = transformation.module_to_insert - module_name_in_model = ( - ";".join( - "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) - for tp in transformation.target_points - ) - + "_" - + str(id(module_to_insert)) - ) - assert not hasattr(model, module_name_in_model) - setattr(model, module_name_in_model, module_to_insert) - # Insert call_module nodes to the model - for target_point in transformation.target_points: - FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model) - return model - - @staticmethod - def get_graph_node_by_name(graph, name): + def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: for node in graph.nodes: if node.name == name: return node raise RuntimeError(f"Node with name {name} is not found") - @staticmethod - def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint): - target_type = target_point.target_type - target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) - if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: - target_node = target_node.all_input_nodes[target_point.input_port_id] - elif target_type == TargetType.OPERATOR_POST_HOOK: - pass - else: - raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") - return target_node - - @staticmethod - def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str): - target_node = FXModelTransformer._get_target_node(graph, target_point) - with graph.inserting_after(target_node): - graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node") - @staticmethod def _apply_transformation( model: torch.fx.GraphModule, diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py similarity index 62% rename from nncf/experimental/torch_fx/nncf_graph_builder.py rename to nncf/experimental/torch/fx/nncf_graph_builder.py index 9990ee3bf2f..e90d6bf7fa7 100644 --- a/nncf/experimental/torch_fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -9,43 +9,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from itertools import chain from typing import Tuple import torch.fx -from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ import nncf.torch.graph.operator_metatypes as om -from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.layer_attributes import Dtype from nncf.common.graph.operator_metatypes import UnknownMetatype from nncf.common.logging import nncf_logger -from nncf.experimental.torch_fx.transformations import separate_conv_and_bias -from nncf.experimental.torch_fx.transformations import separate_linear_and_bias -from nncf.experimental.torch_fx.transformations import view_to_reshape from nncf.torch.graph.graph import PTNNCFGraph from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES class GraphConverter: """ - Builds the NNCFGraph from an OpenVINO model. + Builds the NNCFGraph from an torch.fx.GraphModule instance. """ - @staticmethod - def _get_leaf_node(module: torch.nn.Module, node: torch.fx.Node) -> torch.nn.Module: - py_obj = module - assert isinstance(node.target, str) - atoms = node.target.split(".") - for atom in atoms: - if not hasattr(py_obj, atom): - raise RuntimeError(str(py_obj) + " does not have attribute " + atom + "!") - py_obj = getattr(py_obj, atom) - return py_obj - @staticmethod def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]: + """ + Retrieves node's type and metatype. + + :param node: Given node. + :return: Node's type and metatype. + """ if node.op == "placeholder": node_type = "input" node_metatype = om.PTInputNoopMetatype @@ -61,7 +50,7 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe elif node.target.__name__ == "getitem": node_type = "__getitem__" else: - # TODO: get correct nodes types from this nodes as well + # TODO(dlyakhov): get correct nodes types from this nodes as well node_type = str(node.target) node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type) else: @@ -72,7 +61,7 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe return node_type, node_metatype @staticmethod - def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph: + def create_nncf_graph(model: torch.fx.GraphModule) -> PTNNCFGraph: """ Creates NNCFGraph from GraphModule. All nodes from model which have valid metatype are added to NNCFGraph. @@ -82,46 +71,18 @@ def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph: :return: NNCFGraph. """ - _fuse_conv_bn_(model) - # BN fuses to conv bias, conv+bias joined op - # needs to be splited for nncf - separate_linear_and_bias(model) - separate_conv_and_bias(model) - view_to_reshape(model) - nncf_graph = PTNNCFGraph() for source_node in model.graph.nodes: - node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node) - nncf_node = nncf_graph.add_nncf_node( + nncf_graph.add_nncf_node( node_name=source_node.name, node_type=node_type, - node_metatype=node_metatype, # layer_attributes, + node_metatype=node_metatype, ) - def get_module_params_or_buffers(): - for pname, ptensor in chain(leaf_module.named_parameters(), leaf_module.named_buffers()): - pname1 = source_node.name + "." + pname - nncf_param_node = nncf_graph.add_nncf_node( - pname1, - "parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer", - om.PTConstNoopMetatype, - ) - # TODO: Use valid tensor_shape, input_port_id, output_port_id - nncf_graph.add_edge_between_nncf_nodes( - nncf_param_node, nncf_node, tensor_shape=[1, 1, 1, 1], input_port_id=0, output_port_id=0 - ) - - if source_node.op == "call_module": - leaf_module = GraphConverter._get_leaf_node(model, source_node) - - if not isinstance(leaf_module, torch.fx.GraphModule): - get_module_params_or_buffers() - for source_node in model.graph.nodes: - source_nncf_node = nncf_graph.get_node_by_name(source_node.name) for idx, dist_node in enumerate(source_node.users): dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id @@ -142,8 +103,23 @@ def get_module_params_or_buffers(): @staticmethod def get_edge_params( - model, source_node: torch.fx.Node, source_nncf_node: NNCFNode, dist_node: torch.fx.Node, output_idx: int - ): + model: torch.fx.GraphModule, + source_node: torch.fx.Node, + source_nncf_node: NNCFNode, + dist_node: torch.fx.Node, + output_idx: int, + ) -> Tuple[int, int, Tuple[int, ...]]: + """ + Retrieves edge params from the given source_node and dist_node pair. + + :param model: A torch.fx.GraphModule instance. + :param source_node: Source node in format of torch.fx.Node. + :param source_nncf_node: Source node in format of NNCFNode. + :param dist_node: Distance node in format of torch.fx.Node. + :param output_idx: Output indes of the source_node. + :return: Tuple of edge parameters: edge input port id, edge output port id and + edge tensor shape. + """ output_port_id = 0 if source_node.op in ("get_attr",): tensor_shape = tuple(getattr(model, source_node.target).shape) @@ -158,10 +134,8 @@ def get_edge_params( tensor = source_node.meta["val"] tensor_shape = tuple(tensor.shape) else: - nncf_logger.info( - f"Edge shape between {source_node.name} and {dist_node.name} is unknown. Using [1,1,1,1] instead." - ) - tensor_shape = [1, 1, 1, 1] + nncf_logger.info(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") + tensor_shape = None input_port_id = dist_node.all_input_nodes.index(source_node) return input_port_id, output_port_id, tensor_shape diff --git a/nncf/experimental/torch_fx/quantization/__init__.py b/nncf/experimental/torch/fx/quantization/__init__.py similarity index 100% rename from nncf/experimental/torch_fx/quantization/__init__.py rename to nncf/experimental/torch/fx/quantization/__init__.py diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py similarity index 75% rename from nncf/experimental/torch_fx/quantization/quantize_model.py rename to nncf/experimental/torch/fx/quantization/quantize_model.py index 0f40800fb49..08bb73ee854 100644 --- a/nncf/experimental/torch_fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -18,15 +18,20 @@ from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat from torch.ao.quantization.pt2e.utils import _disallow_eval_train +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ from torch.fx import GraphModule from torch.fx.passes.infra.pass_manager import PassManager import nncf from nncf.common.factory import NNCFGraphFactory +from nncf.common.logging import nncf_logger from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizationScheme from nncf.data import Dataset -from nncf.experimental.torch_fx.transformations import merge_conv_and_bias +from nncf.experimental.torch.fx.transformations import merge_conv_and_bias +from nncf.experimental.torch.fx.transformations import separate_conv_and_bias +from nncf.experimental.torch.fx.transformations import separate_linear_and_bias +from nncf.experimental.torch.fx.transformations import view_to_reshape from nncf.parameters import ModelType from nncf.parameters import QuantizationMode from nncf.parameters import TargetDevice @@ -53,6 +58,11 @@ def quantize_impl( """ Implementation of the `quantize()` method for the Torch FX backend. """ + nncf_logger.warning( + "Experimental Torch FX quantization backend is being used for the given torch.fx.GraphModule model." + " Torch FX PTQ is an experimental feature, consider using Torch or OpenVino PTQ backends" + " in case of errors or a poor model performance." + ) if fast_bias_correction is False: raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported") if target_device == TargetDevice.CPU_SPR: @@ -66,8 +76,7 @@ def quantize_impl( if advanced_parameters is None: advanced_parameters = AdvancedQuantizationParameters() - # torch.fx supports only assymetric activations quantization - # force to use only this type of quantization + # Default quantization mode is assymmetric activations_quantization_params = advanced_parameters.activations_quantization_params if activations_quantization_params is None: activations_quantization_params = QuantizationParameters() @@ -84,6 +93,24 @@ def quantize_impl( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) + + # BatchNorm operations have 3 output ports, + # to make it easier for alorithms to work + # with the target graph BatchNorm operations + # are being fused + _fuse_conv_bn_(copied_model) + + # To make it easier for bias correction algorithms, + # biases are being separated by the followng calls. + separate_linear_and_bias(copied_model) + separate_conv_and_bias(copied_model) + + # View requires at least one dimension spans + # across two contiguous subspaces and reshape is not. + # To prevent error during statistics collection + # all view operation are translated to reshape. + view_to_reshape(copied_model) + nncf_graph = NNCFGraphFactory.create(copied_model) quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) merge_conv_and_bias(quantized_model) diff --git a/nncf/experimental/torch_fx/statistics/__init__.py b/nncf/experimental/torch/fx/statistics/__init__.py similarity index 100% rename from nncf/experimental/torch_fx/statistics/__init__.py rename to nncf/experimental/torch/fx/statistics/__init__.py diff --git a/nncf/experimental/torch_fx/statistics/aggregator.py b/nncf/experimental/torch/fx/statistics/aggregator.py similarity index 83% rename from nncf/experimental/torch_fx/statistics/aggregator.py rename to nncf/experimental/torch/fx/statistics/aggregator.py index f1ce6ff05ec..fd9e8e1e386 100644 --- a/nncf/experimental/torch_fx/statistics/aggregator.py +++ b/nncf/experimental/torch/fx/statistics/aggregator.py @@ -21,7 +21,8 @@ from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer from nncf.common.tensor_statistics.aggregator import StatisticsAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.torch_fx.model_transformer import FXModuleInsertionCommand +from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder from nncf.tensor import Tensor from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.return_types import maybe_get_values_from_torch_return_type @@ -36,7 +37,7 @@ def __init__(self, collector: TensorCollector): super().__init__() self._collector = collector - def forward(self, x: torch.Tensor): + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Register inputs hook function. @@ -74,11 +75,12 @@ def _get_transformation_layout_extra_outputs( for _statistic_point in _statistic_points: for collectors in _statistic_point.algorithm_to_tensor_collectors.values(): for collector in collectors: + transformation = leaf_module_insertion_transformation_builder( + TensorCollectorModule(collector), [_statistic_point.target_point] + ) transformation_commands.append( - FXModuleInsertionCommand( - [_statistic_point.target_point], - TensorCollectorModule(collector), - TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION, + FXApplyTransformationCommand( + transformation, TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION ) ) @@ -91,7 +93,7 @@ def _get_transformation_layout_extra_outputs( def _get_merged_statistic_points( statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph ) -> StatisticPointsContainer: - # TODO: mirgate to experimental statistic collector and use common merging algorithm + # TODO(dlyakhov): mirgate to experimental statistic collector and use common merging algorithm return statistic_points @staticmethod diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch/fx/transformations.py similarity index 64% rename from nncf/experimental/torch_fx/transformations.py rename to nncf/experimental/torch/fx/transformations.py index d572c06b120..dc8a2122adb 100644 --- a/nncf/experimental/torch_fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -9,7 +9,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import Callable, List, Optional import torch @@ -21,29 +20,74 @@ from nncf.common.graph.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType -from nncf.experimental.torch_fx.model_transformer import FXModelTransformer +from nncf.experimental.torch.fx.model_transformer import FXModelTransformer from nncf.torch.graph.transformations.commands import PTTargetPoint +TransformationFNType = Callable[[torch.fx.GraphModule], None] -def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): - def fake_quantize_insertion_transformation(model: torch.fx.GraphModule): - module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points) + +def module_insertion_tranformation_builder( + module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> TransformationFNType: + """ + Returns transformation which inserts given module to a target model and calls given module + after each target points. For each target node all original ouputs are being replaced + by outputs of corresponded module call. + + :param module_to_insert: Given torch.nn.Module to insert. + :param target_points: Target points to insert the target module. + :returns: Transformation which inserts given module to a target model and calls given module + after each target points. For each target node all original ouputs + are being replaced by outputs of corresponded module call. + """ + + def module_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points) graph = model.graph for target_point in target_points: - target_node = FXModelTransformer._get_target_node(model.graph, target_point) - with graph.inserting_after(target_node): - fq_node = graph.create_node( - "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer" - ) + target_node = _get_target_node(graph, target_point) + new_node = _insert_call_module(graph, target_node, module_attr_name) for user in list(target_node.users): - if user is fq_node: + if user is new_node: continue - user.replace_input_with(target_node, fq_node) + user.replace_input_with(target_node, new_node) + + return module_insertion_transformation - return fake_quantize_insertion_transformation +def leaf_module_insertion_transformation_builder( + module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> TransformationFNType: + """ + Returns transformation which inserts given module to a target model + and calls given module after each target points. + + :param module_to_insert: Given torch.nn.Module to insert. + :param target_points: Target points to insert the target module. + :returns: Transformation which which inserts given module to a target model + and calls given module after each target points. + """ + + def leaf_module_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points) + # Insert call_module nodes to the model + graph = model.graph + for target_point in target_points: + target_node = _get_target_node(graph, target_point) + _insert_call_module(graph, target_node, module_attr_name) + + return leaf_module_insertion_transformation + + +def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType: + """ + Return transformation which updates constant of the given bias node to the given value. + + :param node: Bias node which requires bias constant update. + :param value: New value to use as the bias constant. + :return: Transformation which updates constant of the given bias node to the given value. + """ -def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor): def bias_update_transformation(model: torch.fx.GraphModule): graph = model.graph target_node_name = node.node_name @@ -59,25 +103,44 @@ def bias_update_transformation(model: torch.fx.GraphModule): return bias_update_transformation -def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): +def qdq_insertion_tranformation_builder( + quantizer: FakeQuantize, target_points: List[PTTargetPoint] +) -> TransformationFNType: + """ + Returns transformation which inserts quantize-dequantize operations with parameters + inherited from the given quantizer to each given target point. + + :param quantizer: Quantizer module to inherit quantization parameters from. + :param target_points: List of target point used to insert quantize-dequantize pairs. + :return: Transformation which inserts quantize-dequantize operations with parameters + inherited from the given quantizer to each given target point. + """ + def qdq_insertion_tranformation(model: torch.fx.GraphModule): if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: raise RuntimeError for target_point in target_points: - target_node = FXModelTransformer._get_target_node(model.graph, target_point) - insert_one_qdq(model, target_node, quantizer, target_point) + target_node = _get_target_node(model.graph, target_point) + insert_one_qdq_before_node(model, target_node, quantizer) return qdq_insertion_tranformation -def insert_one_qdq( - model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize, target_point: PTTargetPoint -): +def insert_one_qdq_before_node(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize): + """ + Inserts quantize-dequantize after the target node to the target model. + + :param model: Target model. + :param target_node: Target node, quantizer-dequantizer pair is inserted just after the + target node. + :param quantizer: Quantizer module to inherit quantization parameters from. + """ + # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e # 1. extract information for inserting q/dq node from activation_post_process node_type = "call_function" quantize_op: Optional[Callable] = None - # scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8 if quantizer.is_per_channel: qparams = { @@ -142,11 +205,50 @@ def insert_one_qdq( user.replace_input_with(target_node, dq_node) +def _insert_call_module(graph: torch.fx.Graph, target_node: torch.fx.Node, module_attr_name: str): + """ + Inserts module call node to the graph after the target node. + + :param graph: Graph to insert module call node. + :param target_node: Target node, module call node is being iserted just after the target node. + :param module_attr_name: The name of the graph attribute which keeps the target module. + """ + with graph.inserting_after(target_node): + return graph.create_node( + "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_graph_node" + ) + + +def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torch.fx.Node: + """ + Returns TorchFX graph node correspondent to the target point. + + :param graph: Target torch.fx.Graph. + :param target_point: A target point to find the target node. + :return: TorchFX graph node correspondent to the target point. + """ + target_type = target_point.target_type + target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) + if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: + target_node = target_node.all_input_nodes[target_point.input_port_id] + elif target_type == TargetType.OPERATOR_POST_HOOK: + pass + else: + raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") + return target_node + + def _set_module_to_the_graph_module( model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] ) -> str: """ Sets given module to the given torch.fx.GraphModule with unique name. + + :param graph: Target torch.fx.Graph. + :param module_to_insert: Module to insert to the target graph. + :param target_points: Target points which will be used to insert target module + to the graph. + :return: A graph module attribute name which keep given module. """ module_to_insert = module_to_insert module_name_in_model = ( @@ -161,7 +263,13 @@ def _set_module_to_the_graph_module( return module_name_in_model -def _is_linear(n: torch.fx.Node): +def _is_linear(n: torch.fx.Node) -> bool: + """ + Returns true if given node is a linear node, else False. + + :param n: The given node. + :return: True if given node is a linear node, else False. + """ return n.op == "call_function" and n.target in [torch.ops.aten.linear.default] @@ -169,6 +277,8 @@ def separate_linear_and_bias(model: torch.fx.GraphModule): """ Separates one joined linear+bias node to two nodes: conv and bias. Needed as nncf does not expect joined conv + + :param model: Target model. """ add_node_target = torch.ops.aten.add_.Tensor for n in model.graph.nodes: @@ -178,6 +288,9 @@ def separate_linear_and_bias(model: torch.fx.GraphModule): continue linear_node = n linear_bias_node = linear_node.args[2] + while linear_bias_node.op != "get_attr": + # Assume zero argument is on a path to the constant + linear_bias_node = linear_bias_node.args[0] conv_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) args = list(n.args) args[2] = None @@ -204,6 +317,11 @@ def separate_linear_and_bias(model: torch.fx.GraphModule): def view_to_reshape(model: torch.fx.GraphModule): + """ + Replaces all instances of view to a reshape call. + + :param model: Target model. + """ for n in model.graph.nodes: if not (n.op == "call_function" and n.target in [torch.ops.aten.view.default]): continue @@ -222,6 +340,8 @@ def separate_conv_and_bias(model: torch.fx.GraphModule): """ Separates one joined conv+bias node to two nodes: conv and bias. Needed as nncf does not expect joined conv + + :param model: Target model. """ add_node_target = torch.ops.aten.add_.Tensor for n in model.graph.nodes: @@ -266,6 +386,8 @@ def merge_conv_and_bias(model: torch.fx.GraphModule): """ Separates one joined conv+bias node to two nodes: conv and bias. Needed as nncf does not expect joined conv + + :param model: Target model. """ add_node_targets = (torch.ops.aten.add_.Tensor,) for n in model.graph.nodes: @@ -294,57 +416,3 @@ def merge_conv_and_bias(model: torch.fx.GraphModule): model.graph.eliminate_dead_code() model.recompile() - - -def _is_scaled_dot_product_attention(n: torch.fx.Node): - return n.op == "call_function" and n.target in [torch.ops.aten.scaled_dot_product_attention.default] - - -def _unfold_sdp(model: torch.fx.GraphModule, node: torch.fx.Node): - transpose_target = torch.ops.aten.transpose.int - matmul_target = torch.ops.aten.matmul.default - mul_target = torch.ops.aten.multiply.Scalar - softmax_target = torch.ops.aten.softmax.int - - query, key, value = node.args - q, k, v = (n.meta["val"] for n in node.args) - n = query.meta["val"].shape[-1] - scale_factor = 1 / math.sqrt(n) - - with model.graph.inserting_before(node): - k_transposed = model.graph.create_node("call_function", transpose_target, (key, -2, -1), {}) - k = k.transpose(-2, -1) - k_transposed.meta["val"] = torch.clone(k) - - sa = model.graph.create_node("call_function", matmul_target, (query, k_transposed), {}) - attn_value = q @ k - sa.meta["val"] = torch.clone(attn_value) - - sa_scaled = model.graph.create_node("call_function", mul_target, (sa, float(scale_factor)), {}) - sa_scaled.meta["val"] = torch.clone(attn_value) - - softmax = model.graph.create_node("call_function", softmax_target, (sa_scaled, -1), {}) - softmax.meta["val"] = torch.clone(attn_value) - - result = model.graph.create_node("call_function", matmul_target, (softmax, value), {}) - r = attn_value @ v - result.meta["val"] = torch.clone(r) - - for user in list(node.users): - user.replace_input_with(node, result) - model.graph.eliminate_dead_code() - - -@staticmethod -def unfold_scaled_dot_product_attention(model: torch.fx.GraphModule): - for n in model.graph.nodes: - if not _is_scaled_dot_product_attention(n): - continue - args = n.args - if len(args) > 3: - raise NotImplementedError( - f"Unfolding of scaled dot product attention node {n}" " with more than 3 inputs is not implemented yet" - ) - _unfold_sdp(model, n) - model.graph.eliminate_dead_code() - model.recompile() diff --git a/nncf/experimental/torch_fx/engine.py b/nncf/experimental/torch_fx/engine.py deleted file mode 100644 index 5f9dc2ac221..00000000000 --- a/nncf/experimental/torch_fx/engine.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Any, Dict, Tuple, Union - -import torch -from torch import nn - -from nncf.common.engine import Engine - - -class FXEngine(Engine): - """ - Engine for the Pytorch FX backend. - """ - - def __init__(self, model: nn.Module): - """ - Constructor. - - :param model: Pytorch module to infer. - """ - - self._model = model - - def infer( - self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] - ) -> Union[torch.Tensor, Dict[str, Any]]: - """ - Runs Torch model on the provided input. - - :param input_data: Inputs for the model. - :return: Model outputs. - """ - - if isinstance(input_data, dict): - return self._model(**input_data) - if isinstance(input_data, tuple): - return self._model(*input_data) - return self._model(input_data) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py index 089afd4ab11..e32d31a9c7d 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -22,8 +22,8 @@ from nncf.common.graph.definitions import NNCFGraphNodeType from nncf.common.graph.transformations.commands import TargetType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand -from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder +from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.tensor import Tensor from nncf.torch.graph.transformations.commands import PTModelExtractionCommand @@ -83,7 +83,7 @@ def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, ch @staticmethod def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: # TODO: make a node_name_vs_node map to speed up the process - from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + from nncf.experimental.torch.fx.model_transformer import FXModelTransformer bias_node = nncf_graph.get_next_nodes(node)[0] graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name) @@ -114,3 +114,7 @@ def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: @staticmethod def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]: return node.node_name, node.node_name + + @staticmethod + def get_activation_channel_axis(node: NNCFNode, pord_id: int, input_shape: Tuple[int]) -> int: + return node.metatype.output_channel_axis diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index c5403386441..3cde7eb552e 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -27,8 +27,8 @@ from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand -from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder +from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder from nncf.parameters import ModelType from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import StatisticsType @@ -41,6 +41,7 @@ from nncf.torch.graph.graph import PTTargetPoint from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand from nncf.torch.hardware.config import PTHWConfig +from nncf.torch.model_graph_manager import get_weight_tensor_port_ids from nncf.torch.nncf_network import NNCFNetwork from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT from nncf.torch.quantization.layers import QUANTIZATION_MODULES @@ -104,7 +105,7 @@ def group_conv_metatypes(self) -> List[OperatorMetatype]: @property def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: - return [] + return [om.PTScaledDotProductAttentionMetatype] @property def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: @@ -142,7 +143,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point return nncf_graph.get_input_shape_for_insertion_point(target_point) @staticmethod - def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int]: + def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: # TODO: support transpose conv and other cases return (0,) @@ -195,8 +196,8 @@ def get_statistic_collector( return collector @staticmethod - def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: - return node.metatype.weight_port_ids + def get_weight_tensor_port_ids(node: NNCFNode, graph: NNCFGraph) -> List[Optional[int]]: + return get_weight_tensor_port_ids(node, graph) @staticmethod def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 166bae404ef..dc56e6daede 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -229,7 +229,7 @@ def quantize( advanced_parameters=advanced_parameters, ) if backend == BackendType.TORCH_FX: - from nncf.experimental.torch_fx.quantization.quantize_model import quantize_impl + from nncf.experimental.torch.fx.quantization.quantize_model import quantize_impl return quantize_impl( model=model, diff --git a/nncf/torch/engine.py b/nncf/torch/engine.py index c2a7c051132..2bc17db0416 100644 --- a/nncf/torch/engine.py +++ b/nncf/torch/engine.py @@ -15,6 +15,8 @@ from torch import nn from nncf.common.engine import Engine +from nncf.common.utils.backend import BackendType +from nncf.common.utils.backend import get_backend class PTEngine(Engine): @@ -30,7 +32,8 @@ def __init__(self, model: nn.Module): """ self._model = model - self._model.eval() + if get_backend(model) == BackendType.TORCH: + self._model.eval() def infer( self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index d8d8a1a50c8..3b998c40531 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -554,7 +554,6 @@ class PTAddMetatype(PTOperatorMetatype): "__radd__", ], NamespaceTarget.TORCH: ["add"], - NamespaceTarget.ATEN: ["add_"], } hw_config_names = [HWConfigOpName.ADD] num_expected_input_edges = 2 @@ -572,7 +571,6 @@ class PTSubMetatype(PTOperatorMetatype): "__rsub__", ], NamespaceTarget.TORCH: ["sub"], - NamespaceTarget.ATEN: ["sub_"], } hw_config_names = [HWConfigOpName.SUBTRACT] num_expected_input_edges = 2 diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py index b31a7aae5ca..27479fe6a50 100644 --- a/tests/post_training/pipelines/lm_weight_compression.py +++ b/tests/post_training/pipelines/lm_weight_compression.py @@ -19,6 +19,7 @@ import numpy as np import openvino as ov import torch +from datasets import load_dataset from memory_profiler import memory_usage from optimum.exporters.openvino.convert import export_from_model from optimum.intel.openvino import OVModelForCausalLM @@ -27,7 +28,6 @@ from whowhatbench import Evaluator import nncf -from datasets import load_dataset from tests.post_training.pipelines.base import BackendType from tests.post_training.pipelines.base import BaseTestPipeline from tests.post_training.pipelines.base import StatsFromOutput diff --git a/tests/torch_fx/__init__.py b/tests/torch/fx/__init__.py similarity index 100% rename from tests/torch_fx/__init__.py rename to tests/torch/fx/__init__.py diff --git a/tests/torch_fx/helpers.py b/tests/torch/fx/helpers.py similarity index 100% rename from tests/torch_fx/helpers.py rename to tests/torch/fx/helpers.py diff --git a/tests/torch_fx/test_sanity.py b/tests/torch/fx/test_sanity.py similarity index 91% rename from tests/torch_fx/test_sanity.py rename to tests/torch/fx/test_sanity.py index 197c2f95472..f4dab10b965 100644 --- a/tests/torch_fx/test_sanity.py +++ b/tests/torch/fx/test_sanity.py @@ -28,12 +28,17 @@ import nncf from nncf.common.logging.track_progress import track from nncf.torch.dynamic_graph.patch_pytorch import disable_patching -from tests.torch_fx.helpers import TinyImagenetDatasetManager +from tests.torch.fx.helpers import TinyImagenetDatasetManager IMAGE_SIZE = 64 BATCH_SIZE = 128 +@pytest.fixture(name="tiny_imagenet_dataset", scope="module") +def tiny_imagenet_dataset_fixture(): + return TinyImagenetDatasetManager(IMAGE_SIZE, BATCH_SIZE).create_data_loaders() + + @dataclass class SanitySampleCase: model_id: str @@ -47,7 +52,7 @@ class SanitySampleCase: SanitySampleCase( "resnet18", "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth", - 55.23, + 55.2, 51, 58, ), @@ -113,13 +118,12 @@ def count_q_dq(model: torch.fx.GraphModule): @pytest.mark.parametrize("test_case", MODELS) -def test_sanity(test_case: SanitySampleCase): +def test_sanity(test_case: SanitySampleCase, tiny_imagenet_dataset): with disable_patching(): + torch.manual_seed(42) device = torch.device("cpu") model = get_model(test_case.model_id, test_case.checkpoint_url, device) - _, val_dataloader, calibration_dataset = TinyImagenetDatasetManager( - IMAGE_SIZE, BATCH_SIZE - ).create_data_loaders() + _, val_dataloader, calibration_dataset = tiny_imagenet_dataset def transform_fn(data_item): return data_item[0].to(device) @@ -134,7 +138,7 @@ def transform_fn(data_item): quantized_model = torch.compile(quantized_model, backend="openvino") top1_int8 = validate(val_dataloader, quantized_model, device) - assert np.isclose(top1_int8, test_case.top1_int8_ref, atol=1e-2) + assert np.isclose(top1_int8, test_case.top1_int8_ref, atol=0.1) num_q, num_dq = count_q_dq(quantized_model) assert num_q == test_case.ref_num_q diff --git a/tests/torch/requirements.txt b/tests/torch/requirements.txt index bbd3a45c57e..be82652d65f 100644 --- a/tests/torch/requirements.txt +++ b/tests/torch/requirements.txt @@ -19,3 +19,8 @@ datasets==2.14.7 evaluate==0.3.0 openvino timm==0.9.2 + + +# Required for torch/fx tests +torchvision +fastdownload==0.0.7 diff --git a/tests/torch/sparsity/movement/helpers/run_recipe.py b/tests/torch/sparsity/movement/helpers/run_recipe.py index 383552932d5..77b3140a967 100644 --- a/tests/torch/sparsity/movement/helpers/run_recipe.py +++ b/tests/torch/sparsity/movement/helpers/run_recipe.py @@ -20,6 +20,7 @@ import torch.nn import torch.nn.functional as F import torch.utils.data +from datasets import Dataset from transformers import AutoModelForAudioClassification from transformers import AutoModelForImageClassification from transformers import AutoModelForSequenceClassification @@ -33,7 +34,6 @@ from transformers import SwinConfig from transformers import Wav2Vec2Config -from datasets import Dataset from nncf import NNCFConfig from nncf.experimental.torch.sparsity.movement.scheduler import MovementSchedulerParams from nncf.torch.dynamic_graph.io_handling import FillerInputElement diff --git a/tests/torch/sparsity/movement/helpers/trainer.py b/tests/torch/sparsity/movement/helpers/trainer.py index 2af37c5b2f4..89ffeb6c865 100644 --- a/tests/torch/sparsity/movement/helpers/trainer.py +++ b/tests/torch/sparsity/movement/helpers/trainer.py @@ -14,6 +14,7 @@ import numpy as np import torch +from datasets import Dataset # pylint: disable=no-name-in-module from transformers import TrainingArguments from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback @@ -21,7 +22,6 @@ from transformers.trainer_callback import TrainerState from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR -from datasets import Dataset # pylint: disable=no-name-in-module from nncf.api.compression import CompressionAlgorithmController from nncf.common.compression import BaseCompressionAlgorithmController from nncf.common.utils.tensorboard import prepare_for_tensorboard diff --git a/tests/torch/sparsity/movement/test_model_saving.py b/tests/torch/sparsity/movement/test_model_saving.py index 9d2401609fc..979b86a7b18 100644 --- a/tests/torch/sparsity/movement/test_model_saving.py +++ b/tests/torch/sparsity/movement/test_model_saving.py @@ -18,6 +18,7 @@ import pytest import torch from addict import Dict +from datasets import Dataset from onnx import numpy_helper from openvino._offline_transformations import apply_fused_names_cleanup from openvino._offline_transformations import apply_moc_transformations @@ -28,7 +29,6 @@ from scipy.special import softmax from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR -from datasets import Dataset from nncf.torch import create_compressed_model from nncf.torch.checkpoint_loading import load_state from tests.torch.helpers import PTTensorListComparator diff --git a/tests/torch/sparsity/movement/training_scripts/run_glue.py b/tests/torch/sparsity/movement/training_scripts/run_glue.py index d0f5b14269e..360832a5bb7 100644 --- a/tests/torch/sparsity/movement/training_scripts/run_glue.py +++ b/tests/torch/sparsity/movement/training_scripts/run_glue.py @@ -12,13 +12,12 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple +import datasets import evaluate import jstyleson import numpy as np from transformers.training_args import ParallelMode -import datasets - # isort: off from nncf import NNCFConfig from nncf.api.compression import CompressionAlgorithmController diff --git a/tests/torch_fx/requirements.txt b/tests/torch_fx/requirements.txt deleted file mode 100644 index 99ee43ce754..00000000000 --- a/tests/torch_fx/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -fastdownload==0.0.7 \ No newline at end of file diff --git a/torch_compile_ex_release.py b/torch_compile_ex_release.py deleted file mode 100644 index 7bd0addf02e..00000000000 --- a/torch_compile_ex_release.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Enable torch inductor freezing feature first -import os - -os.environ["TORCHINDUCTOR_FREEZING"] = "1" - - -import argparse -import copy -import time -from collections import defaultdict - -import openvino.torch # noqa -import torch - -# Optional: using the C++ wrapper instead of default Python wrapper -import torch._inductor.config as config -import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq -import torchvision.models as models -from torch._export import capture_pre_autograd_graph -from torch.ao.quantization.quantize_pt2e import convert_pt2e -from torch.ao.quantization.quantize_pt2e import prepare_pt2e -from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer -from torch.fx.passes.graph_drawer import FxGraphDrawer - -from nncf.experimental.torch_fx.model_transformer import QPARAMPerChannel -from nncf.experimental.torch_fx.model_transformer import QPARAMSPerTensor -from nncf.experimental.torch_fx.model_transformer import insert_qdq_to_model -from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter # noqa - - -def get_exported_model_from_nn_module(module, example_inputs): - with torch.no_grad(): - return capture_pre_autograd_graph(module, example_inputs) - - -NNCF_IMPL = True - - -def get_qsetup(exported_model, example_inputs): - quantizer = X86InductorQuantizer() - quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) - - prepared_model = prepare_pt2e(exported_model, quantizer) - prepared_model(*example_inputs) - converted_model = convert_pt2e(prepared_model) - g = FxGraphDrawer(converted_model, "resnet18_int8") - g.get_dot_graph().write_svg("resnet18_int8_compiled.svg") - qsetup = defaultdict(lambda: dict()) - - for node in converted_model.graph.nodes: - if "dequantize" in node.name: - quantize = node.all_input_nodes[0] - # place = "activations" - # if len(quantize.all_input_nodes) > 1: - # place = "weights" - if "per_tensor" in node.name: - params = QPARAMSPerTensor(*node.args[1:]) - else: - params = [] - for i in range(1, 3): - name = node.args[i].target - params.append(getattr(converted_model, name)) - params = QPARAMPerChannel(*(params + list(node.args[3:]))) - - target_node_name = quantize.all_input_nodes[0].name - qsetup[target_node_name] = params - return qsetup - - -def quantize(model, example_inputs): - if NNCF_IMPL: - # Use NNCF here on exported model - # to create a quantized model which is compatible with - # convert_pt2e function - pass - # 1. Convert torch.graph to NNCFGraph. - # # 2. Analize nncf grpah for SQ/CA - # # 3. Collect statistics - # # 4. Update params - # 5. Analize nncf graph for quantization - # 6. Insert observers - # 7. prepared_model(*example_inputs) - # 8. convert_pt2e(prepared_model) - import nncf - - calibration_dataset = nncf.Dataset(example_inputs) - exported_model = get_exported_model_from_nn_module(model, example_inputs) - quantized_model = nncf.quantize(exported_model, calibration_dataset) - g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf") - g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg") - return quantized_model - - else: - # g = FxGraphDrawer(exported_model, "resnet18") - # g.get_dot_graph().write_svg("resnet18_compiled.svg") - - # MOCK NNCF QUANTIZATION - exported_model = get_exported_model_from_nn_module(model, example_inputs) - qsetup = get_qsetup(exported_model, example_inputs) - exported_model = get_exported_model_from_nn_module(model, example_inputs) - exported_model = insert_qdq_to_model(exported_model, qsetup) - g = FxGraphDrawer(exported_model, "resnet18_int8") - g.get_dot_graph().write_svg("resnet18_int8_compiled_manually.svg") - return exported_model - - return None # converted_model - - -config.cpp_wrapper = True - - -def measure_time(model, example_inputs, num_iters): - with torch.no_grad(): - model(*example_inputs) - total_time = 0 - for i in range(0, num_iters): - start_time = time.time() - model(*example_inputs) - total_time += time.time() - start_time - average_time = (total_time / num_iters) * 1000 - return average_time - - -def get_dummy_dataset(): - traced_bs = 1 - x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last) - example_inputs = (x,) - return example_inputs - - -def main_nncf(model_name, num_iters): - model = models.__dict__[model_name](pretrained=True) - model = model.eval() - - example_inputs = get_dummy_dataset() - import nncf - - calibration_dataset = nncf.Dataset(example_inputs) - quantized_model = nncf.quantize(model, calibration_dataset) - - import openvino as ov - - ov_model = ov.convert_model(quantized_model.cpu(), example_input=example_inputs[0]) - ov.serialize(ov_model, "./model_cache_nncf/model.xml") - - -def main(model_name, num_iters): - model = models.__dict__[model_name](pretrained=True) - model = model.eval() - - example_inputs = get_dummy_dataset() - - converted_model = quantize(copy.deepcopy(model), example_inputs) - - print("original model execution time: ", measure_time(model, example_inputs, num_iters)) - - native_optimized_model_fp32 = torch.compile(model) - print( - "Torch Inductor FP32 model execution time: ", - measure_time(native_optimized_model_fp32, example_inputs, num_iters), - ) - - native_optimized_model_int8 = torch.compile(converted_model) - print( - "Torch Inductor INT8 model execution time: ", - measure_time(native_optimized_model_int8, example_inputs, num_iters), - ) - - ov_optimized_model_fp32 = torch.compile(model, backend="openvino") - print( - "Torch.compile OpenVINO FP32 model execution time: ", - measure_time(ov_optimized_model_fp32, example_inputs, num_iters), - ) - - ov_optimized_model_int8 = torch.compile( - converted_model, backend="openvino", options={"model_caching": True, "cache_dir": "./model_cache"} - ) - print( - "Torch.compile OpenVINO INT8 model execution time: ", - measure_time(ov_optimized_model_int8, example_inputs, num_iters), - ) - - import intel_extension_for_pytorch # noqa - - ipex_optimized_model_fp32 = torch.compile(model, backend="ipex") - print( - "Torch.compile IPEX FP32 model execution time: ", - measure_time(ipex_optimized_model_fp32, example_inputs, num_iters), - ) - - ipex_optimized_model_int8 = torch.compile(converted_model, backend="ipex") - print( - "Torch.compile IPEX INT8 model execution time: ", - measure_time(ipex_optimized_model_int8, example_inputs, num_iters), - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--num_iters", help="number of inference iterations", type=int, default=100) - parser.add_argument("--model", help="torchvision model name", type=str, default="resnet18") - args = parser.parse_args() - model_name = args.model - num_iters = args.num_iters - main(model_name, num_iters) - # main_nncf(model_name, num_iters) diff --git a/yolo_fx_bad_metrics_repro.py b/yolo_fx_bad_metrics_repro.py deleted file mode 100644 index b5c05d6bbcb..00000000000 --- a/yolo_fx_bad_metrics_repro.py +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright (c) 2024 Intel Corporation -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# http://www.apache.org/licenses/LICENSE-2.0 -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Dict, Tuple - -import numpy as np -import torch -from tqdm import tqdm -from ultralytics.data.utils import check_det_dataset -from ultralytics.engine.validator import BaseValidator as Validator -from ultralytics.models.yolo import YOLO -from ultralytics.utils.torch_utils import de_parallel - - -def print_statistics(stats: np.ndarray, total_images: int, total_objects: int) -> None: - mp, mr, map50, mean_ap = ( - stats["metrics/precision(B)"], - stats["metrics/recall(B)"], - stats["metrics/mAP50(B)"], - stats["metrics/mAP50-95(B)"], - ) - s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95") - print(s) - pf = "%20s" + "%12i" * 2 + "%12.3g" * 4 # print format - print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap)) - - -def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]: - # custom = {"rect": True, "batch": 1} # method defaults - # rect: false forces to resize all input pictures to one size - custom = {"rect": False, "batch": 1} # method defaults - args = {**model.overrides, **custom, "mode": "val"} # highest priority args on the right - - validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks) - stride = 32 # default stride - validator.stride = stride # used in get_dataloader() for padding - validator.data = check_det_dataset(data) - validator.init_metrics(de_parallel(model)) - - data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch) - return validator, data_loader - - -def validate(model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> Tuple[Dict, int, int]: - with torch.no_grad(): - for batch in data_loader: - batch = validator.preprocess(batch) - preds = model(batch["img"]) - preds = validator.postprocess(preds) - validator.update_metrics(preds, batch) - stats = validator.get_stats() - return stats, validator.seen, validator.nt_per_class.sum() - - -def main(torch_fx): - # ultralytics @ git+https://github.com/THU-MIG/yolov10.git@2c36ab0f108efdd17c7e290564bb845ccb6844d8 - # pip install git+https://github.com/THU-MIG/yolov10.git - # pip install huggingface-hub - # yolo_model = YOLO("yolov10n.pt") - - yolo_model = YOLO("yolov8n") - - model_type = "torch" - model = yolo_model.model - if torch_fx: - model = torch.compile(model) - model_type = "FX" - print(f"FP32 {model_type} model validation results:") - validator, data_loader = prepare_validation(yolo_model, "coco128.yaml") - stats, total_images, total_objects = validate(model, tqdm(data_loader), validator) - print_statistics(stats, total_images, total_objects) - - -if __name__ == "__main__": - print("Torch model:") - main(torch_fx=False) - print("Torch FX model:") - main(torch_fx=True) From bf0f357ece1d0492349c6c8d3afe359d49ba649a Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 10 Jul 2024 17:59:09 +0200 Subject: [PATCH 3/7] view_to_reshape is removed Comments --- nncf/common/utils/backend.py | 4 +- .../torch/fx/model_transformer.py | 29 ++++++ .../torch/fx/quantization/quantize_model.py | 27 ++---- nncf/experimental/torch/fx/transformations.py | 96 ++++++++++++------- .../fast_bias_correction/torch_fx_backend.py | 2 +- .../algorithms/min_max/torch_fx_backend.py | 4 +- 6 files changed, 106 insertions(+), 56 deletions(-) diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index ee3f9b75768..a76c0b8670d 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -58,8 +58,8 @@ def is_torch_model(model: TModel) -> bool: :param model: A target model. :return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False. """ - import torch # type: ignore - import torch.fx + import torch # type: ignore + import torch.fx # type: ignore return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module) diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py index b4db5ed4fa7..33cb703fde8 100644 --- a/nncf/experimental/torch/fx/model_transformer.py +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -49,6 +49,12 @@ def __init__(self, model: torch.fx.GraphModule): ] def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: + """ + Transforms the target model according to given transformation layout. + + :param transformation_layout: Given transformation layout. + :return: Target model transformered according to the given transformation layout. + """ # TODO(dlyakhov): Manage priorities of transformations. transformations = transformation_layout.transformations aggregated_transformations = defaultdict(list) @@ -72,6 +78,15 @@ def _apply_model_extraction( model: torch.fx.GraphModule, transformations: List[PTModelExtractionCommand], ) -> torch.fx.GraphModule: + """ + Returns a submodel extracted from the given model by the given transformation. + + :param model: Given model. + :param transformations: List of one transformation which specifies + how to retrieve a submodule from the model. In case list contains + more than one element this function raises an assert. + :return: Returns a submodel extracted from the given model by the given transformation. + """ transformation = transformations[-1] assert len(transformation.input_node_names) == 1 assert transformation.input_node_names == transformation.output_node_names @@ -97,6 +112,13 @@ def _apply_model_extraction( @staticmethod def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: + """ + Retrieves a node with the specified name from the grpah. + + :param graph: Given torch fx graph. + :param name: Target node name. + :return: A graph node with the given name. + """ for node in graph.nodes: if node.name == name: return node @@ -107,6 +129,13 @@ def _apply_transformation( model: torch.fx.GraphModule, transformations: List[FXApplyTransformationCommand], ) -> torch.fx.GraphModule: + """ + Applies transformations to the given model. + + :param model: Target model. + :param transformations: Transformations to apply to the model. + :return: Target model after all transformations were applied. + """ for transformation in transformations: transformation.tranformation_fn(model) return model diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index 08bb73ee854..8a5348754ef 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -28,10 +28,8 @@ from nncf.common.quantization.structs import QuantizationPreset from nncf.common.quantization.structs import QuantizationScheme from nncf.data import Dataset -from nncf.experimental.torch.fx.transformations import merge_conv_and_bias -from nncf.experimental.torch.fx.transformations import separate_conv_and_bias -from nncf.experimental.torch.fx.transformations import separate_linear_and_bias -from nncf.experimental.torch.fx.transformations import view_to_reshape +from nncf.experimental.torch.fx.transformations import apply_quantization_transformations +from nncf.experimental.torch.fx.transformations import revert_quantization_transformations from nncf.parameters import ModelType from nncf.parameters import QuantizationMode from nncf.parameters import TargetDevice @@ -76,13 +74,12 @@ def quantize_impl( if advanced_parameters is None: advanced_parameters = AdvancedQuantizationParameters() - # Default quantization mode is assymmetric + # Default quantization mode is asymmetric activations_quantization_params = advanced_parameters.activations_quantization_params if activations_quantization_params is None: activations_quantization_params = QuantizationParameters() - - activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC - advanced_parameters.activations_quantization_params = activations_quantization_params + activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC + advanced_parameters.activations_quantization_params = activations_quantization_params quantization_algorithm = PostTrainingQuantization( preset=preset, @@ -102,18 +99,14 @@ def quantize_impl( # To make it easier for bias correction algorithms, # biases are being separated by the followng calls. - separate_linear_and_bias(copied_model) - separate_conv_and_bias(copied_model) - - # View requires at least one dimension spans - # across two contiguous subspaces and reshape is not. - # To prevent error during statistics collection - # all view operation are translated to reshape. - view_to_reshape(copied_model) + apply_quantization_transformations(copied_model) nncf_graph = NNCFGraphFactory.create(copied_model) quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) - merge_conv_and_bias(quantized_model) + + # Revert applied transformation to keep original model + # bias configuration. + revert_quantization_transformations(quantized_model) # Magic. Without this call compiled model # is not preformant diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index dc8a2122adb..097338dc928 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -15,7 +15,6 @@ import torch.fx from torch.ao.quantization.fx.utils import create_getattr_from_value from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node -from torch.ao.quantization.pt2e.utils import _is_conv from torch.quantization.fake_quantize import FakeQuantize from nncf.common.graph.graph import NNCFNode @@ -118,7 +117,10 @@ def qdq_insertion_tranformation_builder( def qdq_insertion_tranformation(model: torch.fx.GraphModule): if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: - raise RuntimeError + raise RuntimeError( + "Insertion of shared qdq pair for the weights is not supported." + " Please use non shared qdq pairs for the weights quantization." + ) for target_point in target_points: target_node = _get_target_node(model.graph, target_point) insert_one_qdq_before_node(model, target_node, quantizer) @@ -166,14 +168,14 @@ def insert_one_qdq_before_node(model: torch.fx.GraphModule, target_node: torch.f # 2. replace activation_post_process node with quantize and dequantize graph = model.graph - # TODO: use metatype to get correct input_port_id + # TODO(dlyakhov): use metatype to get correct input_port_id # Do not quantize already quantized nodes # inserting_before handle only order in the graph generated code. # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes with graph.inserting_before(target_node): quantize_op_inputs = [target_node] for key, value_or_node in qparams.items(): - # TODO: we can add the information of whether a value needs to + # TODO(dlyakhov): we can add the information of whether a value needs to # be registered as an attribute in qparams dict itself if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))): # For scale and zero_point values we register them as buffers in the root module. @@ -183,7 +185,7 @@ def insert_one_qdq_before_node(model: torch.fx.GraphModule, target_node: torch.f # tracing where it may consider tensor overload as opposed to default. # With extra check of scale and zero_point being scalar, it makes # sure that the default overload can be used. - # TODO: maybe need more complex attr name here + # TODO(dlaykhov): maybe need more complex attr name here qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node) quantize_op_inputs.append(qparam_node) else: @@ -231,9 +233,7 @@ def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torc target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: target_node = target_node.all_input_nodes[target_point.input_port_id] - elif target_type == TargetType.OPERATOR_POST_HOOK: - pass - else: + elif target_type != TargetType.OPERATOR_POST_HOOK: raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") return target_node @@ -263,14 +263,42 @@ def _set_module_to_the_graph_module( return module_name_in_model +def apply_quantization_transformations(model: torch.fx.Graph): + """ + Applies quantization transformations to the model. + :param model: Model to apply transformations to. + """ + separate_conv_and_bias(model) + separate_linear_and_bias(model) + + +def revert_quantization_transformations(model: torch.fx.Graph): + """ + Reverts quantization transformations from the model. + :param model: Model to revert transformations from. + """ + merge_conv_and_bias(model) + merge_linear_and_bias(model) + + def _is_linear(n: torch.fx.Node) -> bool: """ - Returns true if given node is a linear node, else False. + Return whether the node refers to an aten linear op. :param n: The given node. :return: True if given node is a linear node, else False. """ - return n.op == "call_function" and n.target in [torch.ops.aten.linear.default] + return n.op == "call_function" and n.target in (torch.ops.aten.linear.default,) + + +def _is_conv(n: torch.fx.Node): + """ + Return whether the node refers to an aten conv op. + """ + return n.op == "call_function" and n.target in ( + torch.ops.aten.conv1d.default, + torch.ops.aten.conv2d.default, + ) def separate_linear_and_bias(model: torch.fx.GraphModule): @@ -291,7 +319,7 @@ def separate_linear_and_bias(model: torch.fx.GraphModule): while linear_bias_node.op != "get_attr": # Assume zero argument is on a path to the constant linear_bias_node = linear_bias_node.args[0] - conv_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) + linear_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) args = list(n.args) args[2] = None linear_node.args = tuple(args) @@ -300,7 +328,7 @@ def separate_linear_and_bias(model: torch.fx.GraphModule): model, model.graph, linear_bias_node.name + "_", - conv_bias_value, + linear_bias_value, ) with model.graph.inserting_after(new_linear_bias_node): add_node = model.graph.create_node( @@ -316,26 +344,6 @@ def separate_linear_and_bias(model: torch.fx.GraphModule): model.recompile() -def view_to_reshape(model: torch.fx.GraphModule): - """ - Replaces all instances of view to a reshape call. - - :param model: Target model. - """ - for n in model.graph.nodes: - if not (n.op == "call_function" and n.target in [torch.ops.aten.view.default]): - continue - with model.graph.inserting_after(n): - reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {}) - reshape.meta = n.meta - - for user in list(n.users): - user.replace_input_with(n, reshape) - - model.graph.eliminate_dead_code() - model.recompile() - - def separate_conv_and_bias(model: torch.fx.GraphModule): """ Separates one joined conv+bias node to two nodes: conv and bias. @@ -384,14 +392,34 @@ def separate_conv_and_bias(model: torch.fx.GraphModule): def merge_conv_and_bias(model: torch.fx.GraphModule): """ - Separates one joined conv+bias node to two nodes: conv and bias. + Merges two separate conv and bias nodes to a one node: conv+bias. Needed as nncf does not expect joined conv :param model: Target model. """ + _merge_node_and_bias(model, _is_conv) + + +def merge_linear_and_bias(model: torch.fx.GraphModule): + """ + Merges two separate linear and bias nodes to a one node: linear+bias. + + :param model: Target model. + """ + _merge_node_and_bias(model, _is_linear) + + +def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[torch.fx.Node], bool]): + """ + Merges two separate node and bias node to a one node: node+bias. + Check which node should be merged by the given `is_target_node` predicate. + + :param model: Target model. + :param is_target_node: Predicate to specify nodes which shoudld be merged with the bias + """ add_node_targets = (torch.ops.aten.add_.Tensor,) for n in model.graph.nodes: - if not _is_conv(n): + if not is_target_node(n): continue if len(n.args) > 2 and n.args[2] is not None: continue diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py index e32d31a9c7d..6af561ff8c4 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -82,7 +82,7 @@ def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, ch @staticmethod def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: - # TODO: make a node_name_vs_node map to speed up the process + # TODO(dlyakhov): make a node_name_vs_node map to speed up the process from nncf.experimental.torch.fx.model_transformer import FXModelTransformer bias_node = nncf_graph.get_next_nodes(node)[0] diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 3cde7eb552e..ed29351c577 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -144,7 +144,7 @@ def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point @staticmethod def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint, ndims: int) -> Tuple[int]: - # TODO: support transpose conv and other cases + # TODO(dlyakhov): support transpose conv and other cases return (0,) @staticmethod @@ -220,7 +220,7 @@ def _get_input_scale_shape( ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: is_weights = target_point.is_weight_target_point() if is_weights: - # TODO: support transpose conv/ make channel_idx common + # TODO(dlyakhov): support transpose conv/ make channel_idx common channel_idx = 0 else: channel_idx = 1 # channel dim for activations From 1dd098a31560ea1ae11823f253caa29cf98af074 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Mon, 22 Jul 2024 15:11:29 +0200 Subject: [PATCH 4/7] module_insertion_tranformation_builder is removed Comments --- .../torch/fx/quantization/quantize_model.py | 18 -------- nncf/experimental/torch/fx/transformations.py | 44 +++++-------------- .../algorithms/min_max/torch_fx_backend.py | 4 +- tests/torch/fx/test_sanity.py | 2 +- 4 files changed, 14 insertions(+), 54 deletions(-) diff --git a/nncf/experimental/torch/fx/quantization/quantize_model.py b/nncf/experimental/torch/fx/quantization/quantize_model.py index 8a5348754ef..01aebf68c1f 100644 --- a/nncf/experimental/torch/fx/quantization/quantize_model.py +++ b/nncf/experimental/torch/fx/quantization/quantize_model.py @@ -18,7 +18,6 @@ from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat from torch.ao.quantization.pt2e.utils import _disallow_eval_train -from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ from torch.fx import GraphModule from torch.fx.passes.infra.pass_manager import PassManager @@ -26,7 +25,6 @@ from nncf.common.factory import NNCFGraphFactory from nncf.common.logging import nncf_logger from nncf.common.quantization.structs import QuantizationPreset -from nncf.common.quantization.structs import QuantizationScheme from nncf.data import Dataset from nncf.experimental.torch.fx.transformations import apply_quantization_transformations from nncf.experimental.torch.fx.transformations import revert_quantization_transformations @@ -34,7 +32,6 @@ from nncf.parameters import QuantizationMode from nncf.parameters import TargetDevice from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters -from nncf.quantization.advanced_parameters import QuantizationParameters from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization from nncf.scopes import IgnoredScope @@ -72,15 +69,6 @@ def quantize_impl( copied_model = deepcopy(model) - if advanced_parameters is None: - advanced_parameters = AdvancedQuantizationParameters() - # Default quantization mode is asymmetric - activations_quantization_params = advanced_parameters.activations_quantization_params - if activations_quantization_params is None: - activations_quantization_params = QuantizationParameters() - activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC - advanced_parameters.activations_quantization_params = activations_quantization_params - quantization_algorithm = PostTrainingQuantization( preset=preset, target_device=target_device, @@ -91,12 +79,6 @@ def quantize_impl( advanced_parameters=advanced_parameters, ) - # BatchNorm operations have 3 output ports, - # to make it easier for alorithms to work - # with the target graph BatchNorm operations - # are being fused - _fuse_conv_bn_(copied_model) - # To make it easier for bias correction algorithms, # biases are being separated by the followng calls. apply_quantization_transformations(copied_model) diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 097338dc928..32876d1fd55 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -14,6 +14,7 @@ import torch import torch.fx from torch.ao.quantization.fx.utils import create_getattr_from_value +from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_ from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node from torch.quantization.fake_quantize import FakeQuantize @@ -25,35 +26,6 @@ TransformationFNType = Callable[[torch.fx.GraphModule], None] -def module_insertion_tranformation_builder( - module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] -) -> TransformationFNType: - """ - Returns transformation which inserts given module to a target model and calls given module - after each target points. For each target node all original ouputs are being replaced - by outputs of corresponded module call. - - :param module_to_insert: Given torch.nn.Module to insert. - :param target_points: Target points to insert the target module. - :returns: Transformation which inserts given module to a target model and calls given module - after each target points. For each target node all original ouputs - are being replaced by outputs of corresponded module call. - """ - - def module_insertion_transformation(model: torch.fx.GraphModule): - module_attr_name = _set_module_to_the_graph_module(model, module_to_insert, target_points) - graph = model.graph - for target_point in target_points: - target_node = _get_target_node(graph, target_point) - new_node = _insert_call_module(graph, target_node, module_attr_name) - for user in list(target_node.users): - if user is new_node: - continue - user.replace_input_with(target_node, new_node) - - return module_insertion_transformation - - def leaf_module_insertion_transformation_builder( module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] ) -> TransformationFNType: @@ -123,12 +95,12 @@ def qdq_insertion_tranformation(model: torch.fx.GraphModule): ) for target_point in target_points: target_node = _get_target_node(model.graph, target_point) - insert_one_qdq_before_node(model, target_node, quantizer) + insert_one_qdq_after_node(model, target_node, quantizer) return qdq_insertion_tranformation -def insert_one_qdq_before_node(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize): +def insert_one_qdq_after_node(model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize): """ Inserts quantize-dequantize after the target node to the target model. @@ -229,6 +201,7 @@ def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torc :param target_point: A target point to find the target node. :return: TorchFX graph node correspondent to the target point. """ + # TODO(dlyakhov): Support node insertion on a specific input port id. target_type = target_point.target_type target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: @@ -263,16 +236,21 @@ def _set_module_to_the_graph_module( return module_name_in_model -def apply_quantization_transformations(model: torch.fx.Graph): +def apply_quantization_transformations(model: torch.fx.GraphModule) -> None: """ Applies quantization transformations to the model. :param model: Model to apply transformations to. """ + # BatchNorm operations have 3 output ports, + # to make it easier for alorithms to work + # with the target graph BatchNorm operations + # are being fused + _fuse_conv_bn_(model) separate_conv_and_bias(model) separate_linear_and_bias(model) -def revert_quantization_transformations(model: torch.fx.Graph): +def revert_quantization_transformations(model: torch.fx.GraphModule) -> None: """ Reverts quantization transformations from the model. :param model: Model to revert transformations from. diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index ed29351c577..6edd74d13a6 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -340,8 +340,8 @@ def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[O return types @staticmethod - def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]: - return [] + def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> Set[str]: + return set() @staticmethod def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: diff --git a/tests/torch/fx/test_sanity.py b/tests/torch/fx/test_sanity.py index f4dab10b965..a3f0d828260 100644 --- a/tests/torch/fx/test_sanity.py +++ b/tests/torch/fx/test_sanity.py @@ -52,7 +52,7 @@ class SanitySampleCase: SanitySampleCase( "resnet18", "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth", - 55.2, + 55.35, 51, 58, ), From 826fed9118277405d43759280f80582348dd08c6 Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Wed, 24 Jul 2024 17:12:06 +0200 Subject: [PATCH 5/7] Comment --- nncf/experimental/torch/fx/transformations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 32876d1fd55..87403857e4c 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -224,8 +224,9 @@ def _set_module_to_the_graph_module( :return: A graph module attribute name which keep given module. """ module_to_insert = module_to_insert + # TODO(dlyakhov) Make module name human readable. module_name_in_model = ( - ";".join( + "__".join( "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points ) + "_" From 3b4a992c307b674ed36ac21c2a61db84346555eb Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Thu, 25 Jul 2024 18:39:17 +0200 Subject: [PATCH 6/7] Comments --- nncf/experimental/torch/fx/commands.py | 37 +++++++++++++++++++ .../torch/fx/model_transformer.py | 33 ++--------------- .../torch/fx/nncf_graph_builder.py | 5 ++- nncf/experimental/torch/fx/node_utils.py | 28 ++++++++++++++ .../torch/fx/statistics/aggregator.py | 2 +- nncf/experimental/torch/fx/transformations.py | 23 +++++------- .../fast_bias_correction/torch_fx_backend.py | 11 +++--- .../algorithms/min_max/torch_fx_backend.py | 3 +- 8 files changed, 90 insertions(+), 52 deletions(-) create mode 100644 nncf/experimental/torch/fx/commands.py create mode 100644 nncf/experimental/torch/fx/node_utils.py diff --git a/nncf/experimental/torch/fx/commands.py b/nncf/experimental/torch/fx/commands.py new file mode 100644 index 00000000000..831f177cac7 --- /dev/null +++ b/nncf/experimental/torch/fx/commands.py @@ -0,0 +1,37 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Callable, Union + +import torch.fx + +from nncf.common.graph.transformations.commands import Command +from nncf.common.graph.transformations.commands import TransformationPriority +from nncf.common.graph.transformations.commands import TransformationType + + +class FXApplyTransformationCommand(Command): + """ + Command to apply given transformation to a model. + """ + + def __init__( + self, + transformation_fn: Callable[[torch.fx.GraphModule], None], + priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, + ): + """ + :param transformation_fn: Target transformation function. + :param priority: Transformation priority. + """ + super().__init__(TransformationType.INSERT) + self.tranformation_fn = transformation_fn + self.priority = priority diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py index 33cb703fde8..4be8f306051 100644 --- a/nncf/experimental/torch/fx/model_transformer.py +++ b/nncf/experimental/torch/fx/model_transformer.py @@ -10,31 +10,18 @@ # limitations under the License. from collections import defaultdict -from typing import Callable, List, Union +from typing import List import torch import torch.fx from torch.fx.passes.split_utils import split_by_tags from nncf.common.graph.model_transformer import ModelTransformer -from nncf.common.graph.transformations.commands import Command -from nncf.common.graph.transformations.commands import TransformationPriority -from nncf.common.graph.transformations.commands import TransformationType +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand from nncf.torch.graph.transformations.commands import PTModelExtractionCommand from nncf.torch.graph.transformations.layout import PTTransformationLayout -class FXApplyTransformationCommand(Command): - def __init__( - self, - transformation_fn: Callable[[torch.fx.GraphModule], None], - priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY, - ): - super().__init__(TransformationType.INSERT) - self.tranformation_fn = transformation_fn - self.priority = priority - - class FXModelTransformer(ModelTransformer): """ Applies transformations upon Torch FX model. @@ -107,23 +94,11 @@ def _apply_model_extraction( continue node.tag = tags[i] + # TODO(dlyakhov): reduce memory consumption by + # more optimal splitting implementation. splitted_gm = split_by_tags(model, tags) return splitted_gm.extracted - @staticmethod - def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: - """ - Retrieves a node with the specified name from the grpah. - - :param graph: Given torch fx graph. - :param name: Target node name. - :return: A graph node with the given name. - """ - for node in graph.nodes: - if node.name == name: - return node - raise RuntimeError(f"Node with name {name} is not found") - @staticmethod def _apply_transformation( model: torch.fx.GraphModule, diff --git a/nncf/experimental/torch/fx/nncf_graph_builder.py b/nncf/experimental/torch/fx/nncf_graph_builder.py index e90d6bf7fa7..0863cab72ee 100644 --- a/nncf/experimental/torch/fx/nncf_graph_builder.py +++ b/nncf/experimental/torch/fx/nncf_graph_builder.py @@ -57,7 +57,7 @@ def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMe node_type = node.op node_metatype = UnknownMetatype if node_metatype is UnknownMetatype: - nncf_logger.info(f"Unknown metatype for node: {node}") + nncf_logger.debug(f"Unknown metatype for node: {node}") return node_type, node_metatype @staticmethod @@ -134,7 +134,8 @@ def get_edge_params( tensor = source_node.meta["val"] tensor_shape = tuple(tensor.shape) else: - nncf_logger.info(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") + # TODO(dlyakhov): Refactor algorithms to always have knowns edges shapes. + nncf_logger.debug(f"Edge shape between {source_node.name} and {dist_node.name} is unknown.") tensor_shape = None input_port_id = dist_node.all_input_nodes.index(source_node) diff --git a/nncf/experimental/torch/fx/node_utils.py b/nncf/experimental/torch/fx/node_utils.py new file mode 100644 index 00000000000..96b994cf320 --- /dev/null +++ b/nncf/experimental/torch/fx/node_utils.py @@ -0,0 +1,28 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: + """ + Retrieves a node with the specified name from the grpah. + Raises a runtime error if graph does not contain node with + the given name. + + :param graph: Given torch fx graph. + :param name: Target node name. + :return: A graph node with the given name. + """ + for node in graph.nodes: + if node.name == name: + return node + raise RuntimeError(f"Node with name {name} is not found") diff --git a/nncf/experimental/torch/fx/statistics/aggregator.py b/nncf/experimental/torch/fx/statistics/aggregator.py index fd9e8e1e386..bf45c4cea0b 100644 --- a/nncf/experimental/torch/fx/statistics/aggregator.py +++ b/nncf/experimental/torch/fx/statistics/aggregator.py @@ -21,7 +21,7 @@ from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer from nncf.common.tensor_statistics.aggregator import StatisticsAggregator from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand from nncf.experimental.torch.fx.transformations import leaf_module_insertion_transformation_builder from nncf.tensor import Tensor from nncf.torch.nncf_network import NNCFNetwork diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py index 87403857e4c..47ae266ba1b 100644 --- a/nncf/experimental/torch/fx/transformations.py +++ b/nncf/experimental/torch/fx/transformations.py @@ -18,9 +18,10 @@ from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node from torch.quantization.fake_quantize import FakeQuantize +import nncf from nncf.common.graph.graph import NNCFNode from nncf.common.graph.transformations.commands import TargetType -from nncf.experimental.torch.fx.model_transformer import FXModelTransformer +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name from nncf.torch.graph.transformations.commands import PTTargetPoint TransformationFNType = Callable[[torch.fx.GraphModule], None] @@ -62,11 +63,16 @@ def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> T def bias_update_transformation(model: torch.fx.GraphModule): graph = model.graph target_node_name = node.node_name - graph_node = FXModelTransformer.get_graph_node_by_name(graph, target_node_name) + graph_node = get_graph_node_by_name(graph, target_node_name) + if len(graph_node.users) != 1: + raise nncf.InternalError(f"Node with bias have {len(graph_node.users)} users, 1 expected.") + bias_node = next(iter(graph_node.users)) with graph.inserting_before(bias_node): new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value) + args = list(bias_node.args) + # A bias node suppose to have constant on the second input port. args[1] = new_constant bias_node.args = tuple(args) graph.eliminate_dead_code() @@ -203,7 +209,7 @@ def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint) -> torc """ # TODO(dlyakhov): Support node insertion on a specific input port id. target_type = target_point.target_type - target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) + target_node = get_graph_node_by_name(graph, target_point.target_node_name) if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: target_node = target_node.all_input_nodes[target_point.input_port_id] elif target_type != TargetType.OPERATOR_POST_HOOK: @@ -345,16 +351,7 @@ def separate_conv_and_bias(model: torch.fx.GraphModule): conv_node.args = tuple(args) with model.graph.inserting_after(conv_node): new_conv_bias_node = create_getattr_from_value( - model, - model.graph, - conv_bias_node.name + "_", - conv_bias_value.reshape( - ( - 1, - -1, - ) - + (1,) * (dims - 2) - ), + model, model.graph, conv_bias_node.name + "_", conv_bias_value.reshape((1, -1) + (1,) * (dims - 2)) ) with model.graph.inserting_after(new_conv_bias_node): add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {}) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py index 6af561ff8c4..d808448307e 100644 --- a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -22,7 +22,8 @@ from nncf.common.graph.definitions import NNCFGraphNodeType from nncf.common.graph.transformations.commands import TargetType from nncf.experimental.common.tensor_statistics.collectors import TensorCollector -from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand +from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name from nncf.experimental.torch.fx.transformations import bias_update_transformation_builder from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend from nncf.tensor import Tensor @@ -82,11 +83,9 @@ def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, ch @staticmethod def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor: - # TODO(dlyakhov): make a node_name_vs_node map to speed up the process - from nncf.experimental.torch.fx.model_transformer import FXModelTransformer - bias_node = nncf_graph.get_next_nodes(node)[0] - graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name) + # TODO(dlyakhov): make a node_name_vs_node map to speed up the process + graph_bias_node = get_graph_node_by_name(model.graph, bias_node.node_name) return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model)) @staticmethod @@ -100,7 +99,7 @@ def process_model_output(raw_data: Dict, output_name: str) -> Tensor: @staticmethod def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: weight_node = nncf_graph.get_previous_nodes(node)[1] - return weight_node.node_type == "dequantize_per_channel" + return "dequantize" in weight_node.node_type @staticmethod def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool: diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py index 6edd74d13a6..c095836e674 100644 --- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -27,7 +27,7 @@ from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP from nncf.experimental.common.tensor_statistics.collectors import TensorCollector from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic -from nncf.experimental.torch.fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder from nncf.parameters import ModelType from nncf.parameters import TargetDevice @@ -254,6 +254,7 @@ def _create_quantizer( quantizer = quantizer_cls(quantizer_spec) # Fill it with minmax + # TODO(dlyakhov) Prevent creation of intermediate objects like nncf quantizer. FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) # Convert to the torch fake quantizer torch_fq = convert_to_torch_fakequantizer(quantizer) From 12eb153dcf8b138f0569c5b1f497ccff04a72bad Mon Sep 17 00:00:00 2001 From: dlyakhov Date: Fri, 26 Jul 2024 09:44:28 +0200 Subject: [PATCH 7/7] TODO to use find_nodes instead of get_graph_node_by_name --- nncf/experimental/torch/fx/node_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/nncf/experimental/torch/fx/node_utils.py b/nncf/experimental/torch/fx/node_utils.py index 96b994cf320..5d03d5e355d 100644 --- a/nncf/experimental/torch/fx/node_utils.py +++ b/nncf/experimental/torch/fx/node_utils.py @@ -12,6 +12,8 @@ import torch +# TODO(dlyakhov): Use torch.fx.graph.find_nodes method instead after +# torch version update (>= 2.4) def get_graph_node_by_name(graph: torch.fx.Graph, name: str) -> torch.fx.Node: """ Retrieves a node with the specified name from the grpah.