diff --git a/aa_torch_fx.py b/aa_torch_fx.py new file mode 100644 index 00000000000..339d33d1598 --- /dev/null +++ b/aa_torch_fx.py @@ -0,0 +1,456 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import copy +import re +import subprocess +import time +import warnings +from itertools import islice +from pathlib import Path + +import numpy as np +import openvino as ov +import openvino.torch # noqa +import pandas as pd +import torch +import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq +import torchvision.models as models +from sklearn.metrics import accuracy_score +from torch._export import capture_pre_autograd_graph +from torch.ao.quantization.quantize_pt2e import convert_pt2e +from torch.ao.quantization.quantize_pt2e import prepare_pt2e +from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer +from torch.fx.passes.graph_drawer import FxGraphDrawer +from torch.jit import TracerWarning +from torchao.utils import benchmark_model as ao_benchmark_model +from torchvision import datasets +from transformers import AutoImageProcessor +from transformers import AutoModelForImageClassification + +import nncf +from nncf.common.logging.track_progress import track +from nncf.common.quantization.structs import QuantizationPreset # noqa +from nncf.parameters import ModelType +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching + +warnings.filterwarnings("ignore", category=TracerWarning) +warnings.filterwarnings("ignore", category=UserWarning) + +DATASET_IMAGENET = "/home/dlyakhov/datasets/imagenet/val" + +hf_models = () + + +def hf_model_builder(model_id: str): + def build(weights): + processor = AutoImageProcessor.from_pretrained(model_id) + model = AutoModelForImageClassification.from_pretrained(model_id) + + class ModelWithProcessing(torch.nn.Module): + def __init__(self, processor, model): + super().__init__() + self.processor = processor + self.model = model + + def forward(self, x): + processed_input = processor(x, return_tensors="pt") + return model(processed_input) + + # return ModelWithProcessing(processor, model) + return model + + class DummyWeights: + def transforms(self): + return models.ResNet18_Weights.DEFAULT.transforms() + + @property + def meta(self): + return {} + + return build, DummyWeights() + + +MODELS_DICT = { + "vit_h_14": (models.vit_h_14, models.ViT_H_14_Weights.DEFAULT), + "vit_b_16": (models.vit_b_16, models.ViT_B_16_Weights.DEFAULT), + "swin_v2_t": (models.swin_v2_t, models.Swin_V2_T_Weights.DEFAULT), + "swin_v2_s": (models.swin_v2_s, models.Swin_V2_S_Weights.DEFAULT), + "resnet18": (models.resnet18, models.ResNet18_Weights.DEFAULT), + "resnet50": (models.resnet50, models.ResNet50_Weights.DEFAULT), + "mobilenet_v2": (models.mobilenet_v2, models.MobileNet_V2_Weights.DEFAULT), + "mobilenet_v3_small": (models.mobilenet_v3_small, models.MobileNet_V3_Small_Weights.DEFAULT), + "mobilenet_v3_large": (models.mobilenet_v3_large, models.MobileNet_V3_Large_Weights.DEFAULT), + 
# "densenet161": (models.densenet161, models.DenseNet161_Weights.DEFAULT), + "vgg16": (models.vgg16, models.VGG16_Weights.DEFAULT), + "efficientnet_b7": (models.efficientnet_b7, models.EfficientNet_B7_Weights.DEFAULT), + "inception_v3": (models.inception_v3, models.Inception_V3_Weights.DEFAULT), + "regnet_x_32gf": (models.regnet_x_32gf, models.RegNet_X_32GF_Weights.DEFAULT), + # "google/vit-base-patch16-224": hf_model_builder("google/vit-base-patch16-224"), + # "convnext_large": (models.convnext_large, models.ConvNeXt_Large_Weights.DEFAULT), + # "convnext_small": (models.convnext_small, models.ConvNeXt_Small_Weights.DEFAULT), +} + + +def measure_time(model, example_inputs, num_iters=1000): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def measure_time_ov(model, example_inputs, num_iters=1000): + ie = ov.Core() + compiled_model = ie.compile_model(model, "CPU") + infer_request = compiled_model.create_infer_request() + infer_request.infer(example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + infer_request.infer(example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def quantize(model, example_inputs, calibration_dataset, subset_size=300): + with torch.no_grad(): + exported_model = capture_pre_autograd_graph(model, example_inputs) + + quantizer = X86InductorQuantizer() + quantizer.set_global(xiq.get_default_x86_inductor_quantization_config()) + + prepared_model = prepare_pt2e(exported_model, quantizer) + from tqdm import tqdm + + for inp, _ in islice(tqdm(calibration_dataset), subset_size): + prepared_model(inp) + converted_model = convert_pt2e(prepared_model) + return converted_model + + +def validate(model, val_loader, subset_size=None): + dataset_size = len(val_loader) + + predictions = np.zeros((dataset_size)) + references = -1 * np.ones((dataset_size)) + + with track(total=dataset_size, description="Validation") as pbar: + + for i, (images, target) in enumerate(val_loader): + if subset_size is not None and i >= subset_size: + break + + output_data = model(images).detach().numpy() + predicted_label = np.argmax(output_data, axis=1) + predictions[i] = predicted_label.item() + references[i] = target + pbar.progress.update(pbar.task, advance=1) + acc_top1 = accuracy_score(predictions, references) * 100 + print(acc_top1) + return acc_top1 + + +def validate_ov(model, val_loader): + dataset_size = len(val_loader) + + # Initialize result tensors for async inference support. 
+ predictions = np.zeros((dataset_size)) + references = -1 * np.ones((dataset_size)) + + core = ov.Core() + compiled_model = core.compile_model(model) + + infer_queue = ov.AsyncInferQueue(compiled_model, 4) + with track(total=dataset_size, description="Validation") as pbar: + + def process_result(request, userdata): + output_data = request.get_output_tensor().data + predicted_label = np.argmax(output_data, axis=1) + predictions[userdata] = predicted_label.item() + pbar.progress.update(pbar.task, advance=1) + + infer_queue.set_callback(process_result) + + for i, (images, target) in enumerate(val_loader): + # W/A for memory leaks when using torch DataLoader and OpenVINO + image_copies = copy.deepcopy(images.numpy()) + infer_queue.start_async(image_copies, userdata=i) + references[i] = target + + infer_queue.wait_all() + + acc_top1 = accuracy_score(predictions, references) * 100 + print(acc_top1) + return acc_top1 + + +def run_benchmark(model_path: Path, shape) -> float: + command = f"benchmark_app -m {model_path} -d CPU -api async -t 15" + command += f' -shape="[{",".join(str(x) for x in shape)}]"' + cmd_output = subprocess.check_output(command, shell=True) # nosec + match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output)) + return float(match.group(1)) + + +def torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + import torch + from torchao.quantization.smoothquant import smooth_fq_linear_to_inference + from torchao.quantization.smoothquant import swap_linear_with_smooth_fq_linear + + # Fuse the int8*int8 -> int32 matmul and subsequent mul op avoiding materialization of the int32 intermediary tensor + torch._inductor.config.force_fuse_int_mm_with_mul = True + + # plug in your model + # model = torch.compile(pt_model) + model = pt_model + + # convert linear modules to smoothquant + # linear module in calibration mode + swap_linear_with_smooth_fq_linear(model) + + # Create a data loader for calibration + calibration_loader = val_loader + + # Calibrate the model + model.train() + from tqdm import tqdm + + for batch in tqdm(islice(calibration_loader, 300)): + inputs = batch[0] + model(inputs) + + # set it to inference mode + smooth_fq_linear_to_inference(model) + + # compile the model to improve performance + model = torch.compile(model, mode="max-autotune") + acc1_quant_model = validate(model, val_loader) + print(f"torch ao metric acc@1: {acc1_quant_model}") + result["torch_ao_quant_model_acc"] = acc1_quant_model + + latency = ao_benchmark_model(model, 20, example_input) + print(f"torch ao latency: {latency}") + result["torch_ao_quant_model_latency"] = latency + + +def nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + with disable_patching(): + with torch.no_grad(): + exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) + + def transform(x): + return x[0] + + quant_fx_model = nncf.quantize( + exported_model, nncf.Dataset(val_loader, transform_func=transform), model_type=ModelType.TRANSFORMER + ) + quant_compile_model = torch.compile(quant_fx_model, backend="openvino") + + # acc1_quant_model = validate(quant_compile_model, val_loader) + acc1_quant_model = -1.0 + latency_fx = measure_time(quant_compile_model, (example_input,)) + print(f"latency: {latency_fx}") + result["acc1_nncf_fx_quant_model"] = acc1_quant_model + result["torch_compile_ov_latency_nncf_fx_quant_model"] = latency_fx + + g = FxGraphDrawer(quant_compile_model, f"b_nncf_{pt_model.__class__.__name__}_int8") + 
g.get_dot_graph().write_svg(f"b_nncf_{pt_model.__class__.__name__}_int8.svg") + + # EXPORT TO OV + exported_model = torch.export.export(quant_compile_model, (example_input,)) + ov_quant_model = ov.convert_model(exported_model, example_input=example_input) + quant_file_path = output_dir / "quant.xml" + ov.save_model(ov_quant_model, quant_file_path) + + fps = run_benchmark(quant_file_path, shape_input) + print(f"fps: {fps}") + result["ov_fps_nncf_fx_quant_model"] = fps + + +def fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input): + with disable_patching(): + fp32_pt_model = copy.deepcopy(pt_model) + fp32_compile_model = torch.compile(fp32_pt_model, backend="openvino") + + quant_pt_model = quantize(fp32_compile_model, (example_input,), val_loader) + quant_compile_model = torch.compile(quant_pt_model, backend="openvino") + + g = FxGraphDrawer(quant_pt_model, f"b_pt_{pt_model.__class__.__name__}_int8") + g.get_dot_graph().write_svg(f"b_pt_{pt_model.__class__.__name__}_int8.svg") + + acc1_quant_model = validate(quant_compile_model, val_loader) + result["acc1_quant_model"] = acc1_quant_model + + latency_fx = measure_time(quant_compile_model, (example_input,)) + print(f"latency: {latency_fx}") + result["torch_compile_latency_fps_quant_model"] = latency_fx + + +def nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input): + def transform(x): + return x[0] + + nncf_model = nncf.quantize(copy.deepcopy(pt_model), nncf.Dataset(val_loader, transform_func=transform)) + + ov_nncf_model = ov.convert_model(nncf_model, example_input=example_input) + nncf_pt_file_path = output_dir / "nncf_pt.xml" + ov.save_model(ov_nncf_model, nncf_pt_file_path) + acc1_nncf_pt = validate_ov(ov_nncf_model, val_loader) + result["acc1_nncf_pt"] = acc1_nncf_pt + fps = run_benchmark(nncf_pt_file_path, shape_input) + print(f"fps: {fps}") + result["ov_fps_nncf_pt"] = fps + + +def nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input): + def transform(x): + return np.array(x[0]) + + from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters + from nncf.quantization.advanced_parameters import AdvancedSmoothQuantParameters + + advanced_params = AdvancedQuantizationParameters() + # for sq_param in [-1, 0.15, 0.5, 0.75]: + for sq_param in [0.95]: + advanced_params.smooth_quant_alphas = AdvancedSmoothQuantParameters(matmul=sq_param) + + from copy import deepcopy + + fast_bias_correction = True + nncf_ov_int8_model = nncf.quantize( + deepcopy(ov_fp32_model), + nncf.Dataset(val_loader, transform_func=transform), + fast_bias_correction=fast_bias_correction, + model_type=ModelType.TRANSFORMER, + preset=QuantizationPreset.MIXED, + advanced_parameters=advanced_params, + ) + acc1_nncf_ov = validate_ov(nncf_ov_int8_model, val_loader) + result[f"acc1_nncf_ov_{sq_param}"] = acc1_nncf_ov + for precision, model in (("int8", nncf_ov_int8_model), ("fp32", ov_fp32_model)): + nncf_ov_file_path = output_dir / f"nncf_ov_{precision}.xml" + ov.save_model(model, nncf_ov_file_path) + fps = run_benchmark(nncf_ov_file_path, shape_input) + print(f"fps_{precision}: {fps} {sq_param}") + result[f"ov_fps_nncf_ov_{precision}_{sq_param}"] = fps + + latency = measure_time_ov(model, next(iter(val_loader))[0], num_iters=10_000) + print(f"latency_{precision}: {latency}") + result[f"ov_latency_nncf_ov_{precision}_{sq_param}"] = latency + + +def process_model(model_name: str): + + result = {"name": model_name} + model_cls, model_weights = 
MODELS_DICT[model_name] + output_dir = Path("models") / model_name + output_dir.mkdir(exist_ok=True) + ############################################################## + # Prepare dataset + ############################################################## + + val_dataset = datasets.ImageFolder(root=DATASET_IMAGENET, transform=model_weights.transforms()) + val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, num_workers=2, shuffle=False) + + ############################################################## + # Prepare original model + ############################################################## + + pt_model = model_cls(weights=model_weights) + pt_model = pt_model.eval() + example_input = next(iter(val_loader))[0] + shape_input = list(example_input.shape) + ############################################################## + # Process FP32 Model + ############################################################## + + fp32_pt_model = copy.deepcopy(pt_model) + + orig_infer_acc1 = model_weights.meta.get("_metrics", {}).get("ImageNet-1K", {}).get("acc@1") + print(f"fp32 model metric: {orig_infer_acc1}") + # orig_infer_acc1 = validate(fp32_pt_model, val_loader) + result["acc1_fp32_openvino"] = orig_infer_acc1 + + fp32_pt_model = torch.export.export(fp32_pt_model, (example_input,)) + ov_fp32_model = ov.convert_model(fp32_pt_model, example_input=example_input) + ov_fp32_file_path = None + ov_fp32_file_path = output_dir / "fp32.xml" + ov.save_model(ov_fp32_model, ov_fp32_file_path) + # result["fps_fp32_openvino"] = run_benchmark(ov_fp32_file_path, shape_input) + # print(f"fps_fp32_openvino {result['fps_fp32_openvino']}") + + del fp32_pt_model + ############################################################## + # Process Torch AO Quantize with SQ + ############################################################## + # torch_ao_sq_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # with torch.no_grad(): + # exported_model = capture_pre_autograd_graph(pt_model, (example_input,)) + # latency_fx = measure_time(torch.compile(exported_model), (example_input,)) + # print(f"latency: {latency_fx}") + ############################################################# + + ############################################################## + # Process PT Quantize + ############################################################## + fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # Process NNCF FX Quantize + ############################################################## + # nncf_fx_2_ov_quantization(pt_model, example_input, output_dir, result, val_loader, shape_input) + + ############################################################## + # Process NNCF Quantize by PT + ############################################################## + # nncf_pt_2_ov_quantization(pt_model, val_loader, example_input, output_dir, result, shape_input) + + ############################################################## + # Process NNCF Quantize by OV + ############################################################## + # nncf_ov_2_ov_quantization(ov_fp32_model, val_loader, output_dir, result, shape_input) + + print(result) + return result + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", help="torchvision model name", type=str, default="all") + parser.add_argument("--file_name", help="output csv file_name", 
type=str, default="result.csv") + + args = parser.parse_args() + + results_list = [] + if args.model == "all": + for model_name in MODELS_DICT: + print("---------------------------------------------------") + print(f"name: {model_name}") + results_list.append(process_model(model_name)) + else: + results_list.append(process_model(args.model)) + + df = pd.DataFrame(results_list) + print(df) + df.to_csv(args.file_name) + + +if __name__ == "__main__": + main() diff --git a/examples/llm_compression/openvino/tiny_llama/main.py b/examples/llm_compression/openvino/tiny_llama/main.py index f2be54ce1aa..e5f3893f1ab 100644 --- a/examples/llm_compression/openvino/tiny_llama/main.py +++ b/examples/llm_compression/openvino/tiny_llama/main.py @@ -11,12 +11,12 @@ import time from functools import partial -import datasets import numpy as np import openvino as ov from optimum.intel.openvino import OVModelForCausalLM from transformers import AutoTokenizer +import datasets import nncf diff --git a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py index b3fbce5722b..e34b09bc2f9 100644 --- a/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py +++ b/examples/llm_compression/openvino/tiny_llama_find_hyperparams/main.py @@ -17,12 +17,12 @@ import numpy as np import openvino as ov -from datasets import load_dataset from optimum.intel import OVModelForCausalLM from transformers import AutoTokenizer from whowhatbench import Evaluator import nncf +from datasets import load_dataset from nncf.common.logging import nncf_logger DataItem = TypeVar("DataItem") diff --git a/nncf/common/factory.py b/nncf/common/factory.py index 6616f9dbe3a..d5d13605a07 100644 --- a/nncf/common/factory.py +++ b/nncf/common/factory.py @@ -41,6 +41,10 @@ def create(model: TModel) -> NNCFGraph: if model_backend == BackendType.OPENVINO: from nncf.openvino.graph.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter + return GraphConverter.create_nncf_graph(model) if model_backend == BackendType.TORCH: return model.nncf.get_graph() @@ -72,6 +76,10 @@ def create(model: TModel, inplace: bool = False) -> ModelTransformer: from nncf.torch.model_transformer import PTModelTransformer return PTModelTransformer(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.model_transformer import FXModelTransformer + + return FXModelTransformer(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific model transformer because {} is not supported!".format(model_backend.value) ) @@ -99,6 +107,10 @@ def create(model: TModel) -> Engine: from nncf.torch.engine import PTEngine return PTEngine(model) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.engine import FXEngine + + return FXEngine(model) raise nncf.UnsupportedBackendError( "Cannot create backend-specific engine because {} is not supported!".format(model_backend.value) ) @@ -151,6 +163,10 @@ def create(model: TModel, dataset: Dataset) -> aggregator.StatisticsAggregator: from nncf.torch.statistics.aggregator import PTStatisticsAggregator return PTStatisticsAggregator(dataset) + if model_backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.statistics.aggregator import FXStatisticsAggregator + + return FXStatisticsAggregator(dataset) raise 
nncf.UnsupportedBackendError( "Cannot create backend-specific statistics aggregator because {} is not supported!".format( model_backend.value diff --git a/nncf/common/graph/patterns/manager.py b/nncf/common/graph/patterns/manager.py index 824a3b4a4c5..eae1546aad5 100644 --- a/nncf/common/graph/patterns/manager.py +++ b/nncf/common/graph/patterns/manager.py @@ -48,6 +48,11 @@ def _get_backend_hw_patterns_map(backend: BackendType) -> Dict[HWFusedPatternNam if backend == BackendType.TORCH: from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS + registry = PT_HW_FUSED_PATTERNS.registry_dict + return registry + if backend == BackendType.TORCH_FX: + from nncf.torch.hardware.fused_patterns import PT_HW_FUSED_PATTERNS + registry = PT_HW_FUSED_PATTERNS.registry_dict return registry raise ValueError(f"Hardware-fused patterns not implemented for {backend} backend.") @@ -76,6 +81,11 @@ def _get_backend_ignored_patterns_map( if backend == BackendType.TORCH: from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS + registry = PT_IGNORED_PATTERNS.registry_dict + return registry + if backend == BackendType.TORCH_FX: + from nncf.torch.quantization.ignored_patterns import PT_IGNORED_PATTERNS + registry = PT_IGNORED_PATTERNS.registry_dict return registry raise ValueError(f"Ignored patterns not implemented for {backend} backend.") diff --git a/nncf/common/utils/backend.py b/nncf/common/utils/backend.py index 9dcd6a57d71..e5de38e5ca4 100644 --- a/nncf/common/utils/backend.py +++ b/nncf/common/utils/backend.py @@ -20,6 +20,7 @@ class BackendType(Enum): TORCH = "Torch" + TORCH_FX = "TorchFX" TENSORFLOW = "Tensorflow" ONNX = "ONNX" OPENVINO = "OpenVINO" @@ -33,6 +34,7 @@ def get_available_backends() -> List[BackendType]: """ frameworks = [ ("torch", BackendType.TORCH), + ("torch.fx", BackendType.TORCH_FX), ("tensorflow", BackendType.TENSORFLOW), ("onnx", BackendType.ONNX), ("openvino.runtime", BackendType.OPENVINO), @@ -51,14 +53,27 @@ def get_available_backends() -> List[BackendType]: def is_torch_model(model: TModel) -> bool: """ - Returns True if the model is an instance of torch.nn.Module, otherwise False. + Returns True if the model is an instance of torch.nn.Module and not a torch.fx.GraphModule, otherwise False. :param model: A target model. - :return: True if the model is an instance of torch.nn.Module, otherwise False. + :return: True if the model is an instance of torch.nn.Module and not torch.fx.GraphModule, otherwise False. """ import torch + import torch.fx - return isinstance(model, torch.nn.Module) + return not isinstance(model, torch.fx.GraphModule) and isinstance(model, torch.nn.Module) + + +def is_torch_fx_model(model: TModel) -> bool: + """ + Returns True if the model is an instance of torch.fx.GraphModule, otherwise False. + + :param model: A target model. + :return: True if the model is an instance of torch.fx.GraphModule, otherwise False. 
+ """ + import torch.fx + + return isinstance(model, torch.fx.GraphModule) def is_tensorflow_model(model: TModel) -> bool: @@ -118,6 +133,9 @@ def get_backend(model: TModel) -> BackendType: """ available_backends = get_available_backends() + if BackendType.TORCH_FX in available_backends and is_torch_fx_model(model): + return BackendType.TORCH_FX + if BackendType.TORCH in available_backends and is_torch_model(model): return BackendType.TORCH diff --git a/nncf/experimental/torch_fx/__init__.py b/nncf/experimental/torch_fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/engine.py b/nncf/experimental/torch_fx/engine.py new file mode 100644 index 00000000000..5f9dc2ac221 --- /dev/null +++ b/nncf/experimental/torch_fx/engine.py @@ -0,0 +1,48 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, Tuple, Union + +import torch +from torch import nn + +from nncf.common.engine import Engine + + +class FXEngine(Engine): + """ + Engine for the Pytorch FX backend. + """ + + def __init__(self, model: nn.Module): + """ + Constructor. + + :param model: Pytorch module to infer. + """ + + self._model = model + + def infer( + self, input_data: Union[torch.Tensor, Tuple[torch.Tensor], Dict[str, torch.Tensor]] + ) -> Union[torch.Tensor, Dict[str, Any]]: + """ + Runs Torch model on the provided input. + + :param input_data: Inputs for the model. + :return: Model outputs. + """ + + if isinstance(input_data, dict): + return self._model(**input_data) + if isinstance(input_data, tuple): + return self._model(*input_data) + return self._model(input_data) diff --git a/nncf/experimental/torch_fx/model_transformer.py b/nncf/experimental/torch_fx/model_transformer.py new file mode 100644 index 00000000000..48b3cf0c1f1 --- /dev/null +++ b/nncf/experimental/torch_fx/model_transformer.py @@ -0,0 +1,183 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import defaultdict
+from typing import Callable, List, Union
+
+import torch
+import torch.fx
+from torch.fx.passes.split_utils import split_by_tags
+
+from nncf.common.graph.model_transformer import ModelTransformer
+from nncf.common.graph.transformations.commands import Command
+from nncf.common.graph.transformations.commands import TargetType
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.graph.transformations.commands import TransformationType
+from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
+from nncf.torch.graph.transformations.commands import PTTargetPoint
+from nncf.torch.graph.transformations.layout import PTTransformationLayout
+
+
+class FXModuleInsertionCommand(Command):
+    def __init__(
+        self,
+        target_points: List[PTTargetPoint],
+        module_to_insert: torch.nn.Module,
+        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
+    ):
+        super().__init__(TransformationType.INSERT)
+        self.target_points = target_points
+        self.module_to_insert = module_to_insert
+        self.priority = priority
+
+
+class FXApplyTransformationCommand(Command):
+    def __init__(
+        self,
+        transformation_fn: Callable[[torch.fx.GraphModule], None],
+        priority: Union[TransformationPriority, int] = TransformationPriority.DEFAULT_PRIORITY,
+    ):
+        super().__init__(TransformationType.INSERT)
+        self.transformation_fn = transformation_fn
+        self.priority = priority
+
+
+class FXModelTransformer(ModelTransformer):
+    """
+    Applies transformations to a torch.fx.GraphModule model.
+ """ + + # TODO: manage priorities of transformations + + def __init__(self, model: torch.fx.GraphModule): + super().__init__(model) + + self._command_transformation_ordered_pairs = [ + # TODO: Move the module insertion command to a transformation + (FXApplyTransformationCommand, self._apply_transformation), + (FXModuleInsertionCommand, self._apply_module_insertion), + (PTModelExtractionCommand, self._apply_model_extraction), + ] + + def transform(self, transformation_layout: PTTransformationLayout) -> torch.fx.GraphModule: + transformations = transformation_layout.transformations + aggregated_transformations = defaultdict(list) + for transformation in transformations: + aggregated_transformations[transformation.__class__].append(transformation) + + model = self._model + for transformation_cls, transformation_fn in self._command_transformation_ordered_pairs: + transformations = aggregated_transformations[transformation_cls] + if transformations: + model = transformation_fn(model, transformations) + + # Do not eliminate dead code as + # the dead code is coputing statistics :) + # model.graph.eliminate_dead_code() + model.recompile() + return model + + @staticmethod + def _apply_model_extraction( + model: torch.fx.GraphModule, + transformations: List[PTModelExtractionCommand], + ) -> torch.fx.GraphModule: + transformation = transformations[-1] + assert len(transformation.input_node_names) == 1 + assert transformation.input_node_names == transformation.output_node_names + node_name = transformation.input_node_names[0] + + tags = ["before", "extracted", "after"] + i = 0 + for node in model.graph.nodes: + if node.name == node_name: + node.tag = tags[1] + weights = [node.all_input_nodes[1]] + while weights: + w_node = weights.pop() + assert w_node.tag in tags[0:2] + w_node.tag = tags[1] + weights.extend(w_node.all_input_nodes) + i = 2 + continue + node.tag = tags[i] + + splitted_gm = split_by_tags(model, tags) + return splitted_gm.extracted + + @staticmethod + def _apply_module_insertion( + model: torch.fx.GraphModule, + transformations: List[FXModuleInsertionCommand], + ) -> torch.fx.GraphModule: + """ + Applies insertion of PTSharedFnInsertionCommand commands. For each command method inserts + a torch module to the torch.fx.GraphModule and inserts call hooks for each command target points. + + :param model: Model to apply transformations. + :param transformations: List of the bias correction transformations. + :param device: Target device for the insertion functions. Applies only to + functions which are subclassed from torch.nn.Module. Do nothing in case device is None. + :return: A modified torch.fx.GraphModule. 
+ """ + for transformation in transformations: + # Set fn to the model as an attribute + module_to_insert = transformation.module_to_insert + module_name_in_model = ( + ";".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) + for tp in transformation.target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + # Insert call_module nodes to the model + for target_point in transformation.target_points: + FXModelTransformer._create_call_module_node(model.graph, target_point, module_name_in_model) + return model + + @staticmethod + def get_graph_node_by_name(graph, name): + for node in graph.nodes: + if node.name == name: + return node + raise RuntimeError(f"Node with name {name} is not found") + + @staticmethod + def _get_target_node(graph: torch.fx.Graph, target_point: PTTargetPoint): + target_type = target_point.target_type + target_node = FXModelTransformer.get_graph_node_by_name(graph, target_point.target_node_name) + if target_type in [TargetType.OPERATOR_PRE_HOOK, TargetType.OPERATION_WITH_WEIGHTS]: + target_node = target_node.all_input_nodes[target_point.input_port_id] + elif target_type == TargetType.OPERATOR_POST_HOOK: + pass + else: + raise RuntimeError(f"Unsupported target type: {target_type} for target_point: {target_point}") + return target_node + + @staticmethod + def _create_call_module_node(graph: torch.fx.Graph, target_point: PTTargetPoint, module_name: str): + target_node = FXModelTransformer._get_target_node(graph, target_point) + with graph.inserting_after(target_node): + graph.create_node("call_module", module_name, (target_node,), {}, name=module_name + "_graph_node") + + @staticmethod + def _apply_transformation( + model: torch.fx.GraphModule, + transformations: List[FXApplyTransformationCommand], + ) -> torch.fx.GraphModule: + for transformation in transformations: + transformation.tranformation_fn(model) + return model diff --git a/nncf/experimental/torch_fx/nncf_graph_builder.py b/nncf/experimental/torch_fx/nncf_graph_builder.py new file mode 100644 index 00000000000..9990ee3bf2f --- /dev/null +++ b/nncf/experimental/torch_fx/nncf_graph_builder.py @@ -0,0 +1,167 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from itertools import chain
+from typing import Tuple
+
+import torch.fx
+from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
+
+import nncf.torch.graph.operator_metatypes as om
+from nncf.common.graph import NNCFGraph
+from nncf.common.graph import NNCFNode
+from nncf.common.graph.layer_attributes import Dtype
+from nncf.common.graph.operator_metatypes import UnknownMetatype
+from nncf.common.logging import nncf_logger
+from nncf.experimental.torch_fx.transformations import separate_conv_and_bias
+from nncf.experimental.torch_fx.transformations import separate_linear_and_bias
+from nncf.experimental.torch_fx.transformations import view_to_reshape
+from nncf.torch.graph.graph import PTNNCFGraph
+from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
+
+
+class GraphConverter:
+    """
+    Builds the NNCFGraph from a torch.fx.GraphModule.
+    """
+
+    @staticmethod
+    def _get_leaf_node(module: torch.nn.Module, node: torch.fx.Node) -> torch.nn.Module:
+        py_obj = module
+        assert isinstance(node.target, str)
+        atoms = node.target.split(".")
+        for atom in atoms:
+            if not hasattr(py_obj, atom):
+                raise RuntimeError(str(py_obj) + " does not have attribute " + atom + "!")
+            py_obj = getattr(py_obj, atom)
+        return py_obj
+
+    @staticmethod
+    def _get_node_type_and_metatype(node: torch.fx.Node) -> Tuple[str, om.OperatorMetatype]:
+        if node.op == "placeholder":
+            node_type = "input"
+            node_metatype = om.PTInputNoopMetatype
+        elif node.op == "output":
+            node_type = "output"
+            node_metatype = om.PTOutputNoopMetatype
+        elif node.op == "get_attr":
+            node_type = "get_attr"
+            node_metatype = om.PTConstNoopMetatype
+        elif node.op in ("call_function",):
+            if hasattr(node.target, "overloadpacket"):
+                node_type = str(node.target.overloadpacket).split(".")[1]
+            elif node.target.__name__ == "getitem":
+                node_type = "__getitem__"
+            else:
+                # TODO: get correct node types for these nodes as well
+                node_type = str(node.target)
+            node_metatype = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
+        else:
+            node_type = node.op
+            node_metatype = UnknownMetatype
+        if node_metatype is UnknownMetatype:
+            nncf_logger.info(f"Unknown metatype for node: {node}")
+        return node_type, node_metatype
+
+    @staticmethod
+    def create_nncf_graph(model: torch.fx.GraphModule) -> NNCFGraph:
+        """
+        Creates an NNCFGraph from a GraphModule: every node of the model with a valid
+        metatype is added to the NNCFGraph, and the corresponding edges are added with
+        their shape, dtype, and input/output port ids.
+
+        :param model: torch.fx.GraphModule instance.
+        :return: NNCFGraph.
+        """
+
+        _fuse_conv_bn_(model)
+        # BN is fused into the conv bias; the resulting joined conv+bias op
+        # needs to be split back apart for nncf
+        separate_linear_and_bias(model)
+        separate_conv_and_bias(model)
+        view_to_reshape(model)
+
+        nncf_graph = PTNNCFGraph()
+
+        for source_node in model.graph.nodes:
+
+            node_type, node_metatype = GraphConverter._get_node_type_and_metatype(source_node)
+
+            nncf_node = nncf_graph.add_nncf_node(
+                node_name=source_node.name,
+                node_type=node_type,
+                node_metatype=node_metatype,
+            )
+
+            def get_module_params_or_buffers():
+                for pname, ptensor in chain(leaf_module.named_parameters(), leaf_module.named_buffers()):
+                    pname1 = source_node.name + "." + pname
+                    nncf_param_node = nncf_graph.add_nncf_node(
+                        pname1,
+                        "parameter" if isinstance(ptensor, torch.nn.Parameter) else "buffer",
+                        om.PTConstNoopMetatype,
+                    )
+                    # TODO: Use valid tensor_shape, input_port_id, output_port_id
+                    nncf_graph.add_edge_between_nncf_nodes(
+                        nncf_param_node, nncf_node, tensor_shape=[1, 1, 1, 1], input_port_id=0, output_port_id=0
+                    )
+
+            if source_node.op == "call_module":
+                leaf_module = GraphConverter._get_leaf_node(model, source_node)
+
+                if not isinstance(leaf_module, torch.fx.GraphModule):
+                    get_module_params_or_buffers()
+
+        for source_node in model.graph.nodes:
+
+            source_nncf_node = nncf_graph.get_node_by_name(source_node.name)
+            for idx, dist_node in enumerate(source_node.users):
+                dist_node_id = nncf_graph.get_node_by_name(dist_node.name).node_id
+                input_port_id, output_port_id, tensor_shape = GraphConverter.get_edge_params(
+                    model, source_node, source_nncf_node, dist_node, idx
+                )
+
+                nncf_graph.add_edge_between_nncf_nodes(
+                    source_nncf_node.node_id,
+                    dist_node_id,
+                    tensor_shape=tensor_shape,
+                    input_port_id=input_port_id,
+                    output_port_id=output_port_id,
+                    dtype=Dtype.FLOAT,
+                )
+
+        return nncf_graph
+
+    @staticmethod
+    def get_edge_params(
+        model: torch.fx.GraphModule,
+        source_node: torch.fx.Node,
+        source_nncf_node: NNCFNode,
+        dist_node: torch.fx.Node,
+        output_idx: int,
+    ):
+        output_port_id = 0
+        if source_node.op in ("get_attr",):
+            tensor_shape = tuple(getattr(model, source_node.target).shape)
+        elif "val" in source_node.meta:
+            if source_nncf_node.metatype is om.PTBatchNormMetatype:
+                tensor = source_node.meta["val"][0]
+            elif source_nncf_node.metatype is om.PTSplitMetatype:
+                tensor = source_node.meta["val"][output_idx]
+                # Assume every split output corresponds to a unique output_port_id
+                output_port_id = output_idx
+            else:
+                tensor = source_node.meta["val"]
+            tensor_shape = tuple(tensor.shape)
+        else:
+            nncf_logger.info(
+                f"Edge shape between {source_node.name} and {dist_node.name} is unknown. Using [1,1,1,1] instead."
+            )
+            tensor_shape = [1, 1, 1, 1]
+
+        input_port_id = dist_node.all_input_nodes.index(source_node)
+        return input_port_id, output_port_id, tensor_shape
diff --git a/nncf/experimental/torch_fx/quantization/__init__.py b/nncf/experimental/torch_fx/quantization/__init__.py
new file mode 100644
index 00000000000..2e49d63977d
--- /dev/null
+++ b/nncf/experimental/torch_fx/quantization/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/nncf/experimental/torch_fx/quantization/quantize_model.py b/nncf/experimental/torch_fx/quantization/quantize_model.py
new file mode 100644
index 00000000000..0f40800fb49
--- /dev/null
+++ b/nncf/experimental/torch_fx/quantization/quantize_model.py
@@ -0,0 +1,105 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
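
A short sketch of driving the GraphConverter above on a captured FX graph. The capture call mirrors its use in the benchmark script at the top of this patch; the model is a stand-in:

```python
# Hedged sketch: building an NNCFGraph from a captured torch.fx graph.
import torch
from torch._export import capture_pre_autograd_graph

from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.BatchNorm2d(8), torch.nn.ReLU()).eval()
example_input = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    exported = capture_pre_autograd_graph(model, (example_input,))

nncf_graph = GraphConverter.create_nncf_graph(exported)
# After conv/BN fusion and bias separation, inspect what NNCF will see:
for node in nncf_graph.get_all_nodes():
    print(node.node_name, node.node_type)
```
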
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from copy import deepcopy +from typing import Optional + +import torch +import torch.fx +from torch.ao.quantization.pt2e.duplicate_dq_pass import DuplicateDQPass +from torch.ao.quantization.pt2e.port_metadata_pass import PortNodeMetaForQDQ +from torch.ao.quantization.pt2e.qat_utils import _fold_conv_bn_qat +from torch.ao.quantization.pt2e.utils import _disallow_eval_train +from torch.fx import GraphModule +from torch.fx.passes.infra.pass_manager import PassManager + +import nncf +from nncf.common.factory import NNCFGraphFactory +from nncf.common.quantization.structs import QuantizationPreset +from nncf.common.quantization.structs import QuantizationScheme +from nncf.data import Dataset +from nncf.experimental.torch_fx.transformations import merge_conv_and_bias +from nncf.parameters import ModelType +from nncf.parameters import QuantizationMode +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import AdvancedQuantizationParameters +from nncf.quantization.advanced_parameters import QuantizationParameters +from nncf.quantization.algorithms.post_training.algorithm import PostTrainingQuantization +from nncf.scopes import IgnoredScope + +DEFAULT_RANGE_TYPE = "mean_min_max" + + +def quantize_impl( + model: torch.fx.GraphModule, + calibration_dataset: Dataset, + mode: Optional[QuantizationMode] = None, + preset: Optional[QuantizationPreset] = None, + target_device: TargetDevice = TargetDevice.ANY, + subset_size: int = 300, + fast_bias_correction: bool = True, + model_type: Optional[ModelType] = None, + ignored_scope: Optional[IgnoredScope] = None, + advanced_parameters: Optional[AdvancedQuantizationParameters] = None, +) -> torch.nn.Module: + """ + Implementation of the `quantize()` method for the Torch FX backend. 
+ """ + if fast_bias_correction is False: + raise ValueError(f"fast_bias_correction={fast_bias_correction} is not supported") + if target_device == TargetDevice.CPU_SPR: + raise nncf.InternalError("target_device == CPU_SPR is not supported") + if mode is not None: + raise ValueError(f"mode={mode} is not supported") + + original_graph_meta = model.meta + + copied_model = deepcopy(model) + + if advanced_parameters is None: + advanced_parameters = AdvancedQuantizationParameters() + # torch.fx supports only assymetric activations quantization + # force to use only this type of quantization + activations_quantization_params = advanced_parameters.activations_quantization_params + if activations_quantization_params is None: + activations_quantization_params = QuantizationParameters() + + activations_quantization_params.mode = QuantizationScheme.ASYMMETRIC + advanced_parameters.activations_quantization_params = activations_quantization_params + + quantization_algorithm = PostTrainingQuantization( + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) + nncf_graph = NNCFGraphFactory.create(copied_model) + quantized_model = quantization_algorithm.apply(copied_model, nncf_graph, dataset=calibration_dataset) + merge_conv_and_bias(quantized_model) + + # Magic. Without this call compiled model + # is not preformant + quantized_model = GraphModule(quantized_model, quantized_model.graph) + + quantized_model = _fold_conv_bn_qat(quantized_model) + pm = PassManager([DuplicateDQPass()]) + + quantized_model = pm(quantized_model).graph_module + pm = PassManager([PortNodeMetaForQDQ()]) + quantized_model = pm(quantized_model).graph_module + + quantized_model.meta.update(original_graph_meta) + quantized_model = _disallow_eval_train(quantized_model) + + return quantized_model diff --git a/nncf/experimental/torch_fx/statistics/__init__.py b/nncf/experimental/torch_fx/statistics/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/nncf/experimental/torch_fx/statistics/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nncf/experimental/torch_fx/statistics/aggregator.py b/nncf/experimental/torch_fx/statistics/aggregator.py new file mode 100644 index 00000000000..f1ce6ff05ec --- /dev/null +++ b/nncf/experimental/torch_fx/statistics/aggregator.py @@ -0,0 +1,99 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict
+
+import numpy as np
+import torch
+
+from nncf.common.factory import TModel
+from nncf.common.graph.graph import NNCFGraph
+from nncf.common.graph.transformations.commands import TransformationPriority
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.common.tensor_statistics.aggregator import StatisticPointsContainer
+from nncf.common.tensor_statistics.aggregator import StatisticsAggregator
+from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
+from nncf.experimental.torch_fx.model_transformer import FXModuleInsertionCommand
+from nncf.tensor import Tensor
+from nncf.torch.return_types import maybe_get_values_from_torch_return_type
+
+
+class TensorCollectorModule(torch.nn.Module):
+    """
+    torch.nn.Module which calls the given collector in its forward.
+    """
+
+    def __init__(self, collector: TensorCollector):
+        super().__init__()
+        self._collector = collector
+
+    def forward(self, x: torch.Tensor):
+        """
+        Registers the input tensor with the wrapped collector and returns it unchanged.
+
+        :param x: Tensor to register.
+        :return: The same tensor, unchanged.
+        """
+        x_unwrapped = maybe_get_values_from_torch_return_type(x)
+        self._collector.register_input_for_all_reducers(Tensor(x_unwrapped))
+        return x
+
+
+class FXStatisticsAggregator(StatisticsAggregator):
+    HOOKS_GROUP_NAME = "statistics_hooks"
+
+    def collect_statistics(self, model: torch.fx.GraphModule, graph: NNCFGraph) -> None:
+        with torch.no_grad():
+            super().collect_statistics(model, graph)
+        # All statistics are collected as dead code,
+        # so eliminating dead code removes the statistics collectors
+        # from the target model. No additional code is required
+        # for that, hooray!
+        model.graph.eliminate_dead_code()
+        model.recompile()
+
+    def _register_statistics(self, outputs: Dict[str, Tensor], statistic_points: StatisticPointsContainer) -> None:
+        return
+
+    def _get_transformation_layout_extra_outputs(
+        self, statistic_points: StatisticPointsContainer
+    ) -> TransformationLayout:
+        transformation_layout = TransformationLayout()
+        transformation_commands = []
+
+        for _statistic_points in statistic_points.values():
+            for _statistic_point in _statistic_points:
+                for collectors in _statistic_point.algorithm_to_tensor_collectors.values():
+                    for collector in collectors:
+                        transformation_commands.append(
+                            FXModuleInsertionCommand(
+                                [_statistic_point.target_point],
+                                TensorCollectorModule(collector),
+                                TransformationPriority.FP32_TENSOR_STATISTICS_OBSERVATION,
+                            )
+                        )
+
+        for transformation_command in transformation_commands:
+            transformation_layout.register(transformation_command)
+
+        return transformation_layout
+
+    @staticmethod
+    def _get_merged_statistic_points(
+        statistic_points: StatisticPointsContainer, model: TModel, graph: NNCFGraph
+    ) -> StatisticPointsContainer:
+        # TODO: migrate to the experimental statistic collector and use the common merging algorithm
+        return statistic_points
+
+    @staticmethod
+    def _process_outputs(outputs: Dict[str, np.ndarray]) -> Dict[str, Tensor]:
+        return outputs
diff --git a/nncf/experimental/torch_fx/transformations.py b/nncf/experimental/torch_fx/transformations.py
new file mode 100644
index 00000000000..d572c06b120
--- /dev/null
+++ b/nncf/experimental/torch_fx/transformations.py
@@ -0,0 +1,350 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
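
A sketch of the observer contract TensorCollectorModule relies on: the hook returns its input unchanged, so the inserted call_module node becomes dead code once statistics are gathered and is removed by eliminate_dead_code(). The recording collector below is a made-up stand-in for nncf's TensorCollector:

```python
# Hedged sketch: the hook is a pure observer, output is the input itself.
import torch

from nncf.experimental.torch_fx.statistics.aggregator import TensorCollectorModule


class _RecordingCollector:
    def __init__(self):
        self.inputs = []

    def register_input_for_all_reducers(self, x):
        self.inputs.append(x)


collector = _RecordingCollector()
hook = TensorCollectorModule(collector)
x = torch.randn(2, 3)
assert hook(x) is x
assert len(collector.inputs) == 1
```
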
+ +import math +from typing import Callable, List, Optional + +import torch +import torch.fx +from torch.ao.quantization.fx.utils import create_getattr_from_value +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node +from torch.ao.quantization.pt2e.utils import _is_conv +from torch.quantization.fake_quantize import FakeQuantize + +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.torch_fx.model_transformer import FXModelTransformer +from nncf.torch.graph.transformations.commands import PTTargetPoint + + +def fake_quantize_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + def fake_quantize_insertion_transformation(model: torch.fx.GraphModule): + module_attr_name = _set_module_to_the_graph_module(model, quantizer, target_points) + graph = model.graph + for target_point in target_points: + target_node = FXModelTransformer._get_target_node(model.graph, target_point) + with graph.inserting_after(target_node): + fq_node = graph.create_node( + "call_module", module_attr_name, (target_node,), {}, name=module_attr_name + "_quantizer" + ) + for user in list(target_node.users): + if user is fq_node: + continue + user.replace_input_with(target_node, fq_node) + + return fake_quantize_insertion_transformation + + +def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor): + def bias_update_transformation(model: torch.fx.GraphModule): + graph = model.graph + target_node_name = node.node_name + graph_node = FXModelTransformer.get_graph_node_by_name(graph, target_node_name) + bias_node = next(iter(graph_node.users)) + with graph.inserting_before(bias_node): + new_constant = create_getattr_from_value(model, graph, target_node_name + "_shifted_bias", value) + args = list(bias_node.args) + args[1] = new_constant + bias_node.args = tuple(args) + graph.eliminate_dead_code() + + return bias_update_transformation + + +def qdq_insertion_tranformation_builder(quantizer: FakeQuantize, target_points: List[PTTargetPoint]): + def qdq_insertion_tranformation(model: torch.fx.GraphModule): + if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1: + raise RuntimeError + for target_point in target_points: + target_node = FXModelTransformer._get_target_node(model.graph, target_point) + insert_one_qdq(model, target_node, quantizer, target_point) + + return qdq_insertion_tranformation + + +def insert_one_qdq( + model: torch.fx.GraphModule, target_node: torch.fx.Node, quantizer: FakeQuantize, target_point: PTTargetPoint +): + # Copied from torch.ao.quantization.quantize_pt2e.convert_pt2e + # 1. 
extract information for inserting q/dq node from activation_post_process + node_type = "call_function" + quantize_op: Optional[Callable] = None + # scale, zero_point = activation_post_process.calculate_qparams() # type: ignore[attr-defined, operator] + dtype = torch.int8 if quantizer.quant_min < 0 else torch.uint8 + if quantizer.is_per_channel: + qparams = { + "_scale_": quantizer.scale, + "_zero_point_": quantizer.zero_point, + "_axis_": quantizer.ch_axis, + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_channel.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_channel.default + else: + qparams = { + "_scale_": float(quantizer.scale), + "_zero_point_": int(quantizer.zero_point), + "_quant_min_": quantizer.quant_min, + "_quant_max_": quantizer.quant_max, + "_dtype_": dtype, + } + quantize_op = torch.ops.quantized_decomposed.quantize_per_tensor.default + dequantize_op = torch.ops.quantized_decomposed.dequantize_per_tensor.default + + # 2. replace activation_post_process node with quantize and dequantize + graph = model.graph + # TODO: use metatype to get correct input_port_id + # Do not quantize already quantized nodes + # inserting_before handle only order in the graph generated code. + # so, inserting quantize-dequantize and all constant nodes before the usage of the nodes + with graph.inserting_before(target_node): + quantize_op_inputs = [target_node] + for key, value_or_node in qparams.items(): + # TODO: we can add the information of whether a value needs to + # be registered as an attribute in qparams dict itself + if key in ["_scale_", "_zero_point_"] and (not isinstance(value_or_node, (float, int))): + # For scale and zero_point values we register them as buffers in the root module. + # However, note that when the values are not tensors, as in the case of + # per_tensor quantization, they will be treated as literals. + # However, registering them as a node seems to cause issue with dynamo + # tracing where it may consider tensor overload as opposed to default. + # With extra check of scale and zero_point being scalar, it makes + # sure that the default overload can be used. + # TODO: maybe need more complex attr name here + qparam_node = create_getattr_from_value(model, graph, target_node.name + key, value_or_node) + quantize_op_inputs.append(qparam_node) + else: + # for qparams that are not scale/zero_point (like axis, dtype) we store + # them as literals in the graph. + quantize_op_inputs.append(value_or_node) + with graph.inserting_after(target_node): + quantized_node = graph.create_node(node_type, quantize_op, tuple(quantize_op_inputs), {}) + # use the same qparams from quantize op + dq_inputs = [quantized_node] + quantize_op_inputs[1:] + user_dq_nodes = [] + with graph.inserting_after(quantized_node): + for user in target_node.users: + if user is quantized_node: + continue + user_dq_nodes.append((user, graph.call_function(dequantize_op, tuple(dq_inputs), {}))) + + for user, dq_node in user_dq_nodes: + user.replace_input_with(target_node, dq_node) + + +def _set_module_to_the_graph_module( + model: torch.fx.GraphModule, module_to_insert: torch.nn.Module, target_points: List[PTTargetPoint] +) -> str: + """ + Sets given module to the given torch.fx.GraphModule with unique name. 
+ """ + module_to_insert = module_to_insert + module_name_in_model = ( + ";".join( + "_".join((tp.target_node_name, str(tp.input_port_id), str(tp.target_type.value))) for tp in target_points + ) + + "_" + + str(id(module_to_insert)) + ) + assert not hasattr(model, module_name_in_model) + setattr(model, module_name_in_model, module_to_insert) + return module_name_in_model + + +def _is_linear(n: torch.fx.Node): + return n.op == "call_function" and n.target in [torch.ops.aten.linear.default] + + +def separate_linear_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined linear+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_linear(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + linear_node = n + linear_bias_node = linear_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(linear_bias_node, model) + args = list(n.args) + args[2] = None + linear_node.args = tuple(args) + with model.graph.inserting_after(linear_node): + new_linear_bias_node = create_getattr_from_value( + model, + model.graph, + linear_bias_node.name + "_", + conv_bias_value, + ) + with model.graph.inserting_after(new_linear_bias_node): + add_node = model.graph.create_node( + "call_function", add_node_target, (linear_node, new_linear_bias_node), {} + ) + for user in list(linear_node.users): + if user is add_node: + continue + user.replace_input_with(linear_node, add_node) + if "val" in linear_node.meta: + add_node.meta["val"] = linear_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def view_to_reshape(model: torch.fx.GraphModule): + for n in model.graph.nodes: + if not (n.op == "call_function" and n.target in [torch.ops.aten.view.default]): + continue + with model.graph.inserting_after(n): + reshape = model.graph.create_node("call_function", torch.ops.aten.reshape.default, tuple(n.args), {}) + reshape.meta = n.meta + + for user in list(n.users): + user.replace_input_with(n, reshape) + + model.graph.eliminate_dead_code() + model.recompile() + + +def separate_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. + Needed as nncf does not expect joined conv + """ + add_node_target = torch.ops.aten.add_.Tensor + for n in model.graph.nodes: + if not _is_conv(n): + continue + if len(n.args) < 3 or n.args[2] is None: + continue + conv_node = n + dims = len(_get_tensor_constant_from_node(conv_node.args[1], model).shape) + conv_bias_node = conv_node.args[2] + conv_bias_value = _get_tensor_constant_from_node(conv_bias_node, model) + args = list(n.args) + args[2] = None + conv_node.args = tuple(args) + with model.graph.inserting_after(conv_node): + new_conv_bias_node = create_getattr_from_value( + model, + model.graph, + conv_bias_node.name + "_", + conv_bias_value.reshape( + ( + 1, + -1, + ) + + (1,) * (dims - 2) + ), + ) + with model.graph.inserting_after(new_conv_bias_node): + add_node = model.graph.create_node("call_function", add_node_target, (conv_node, new_conv_bias_node), {}) + for user in list(conv_node.users): + if user is add_node: + continue + user.replace_input_with(conv_node, add_node) + + if "val" in conv_node.meta: + add_node.meta["val"] = conv_node.meta["val"] + model.graph.eliminate_dead_code() + model.recompile() + + +def merge_conv_and_bias(model: torch.fx.GraphModule): + """ + Separates one joined conv+bias node to two nodes: conv and bias. 
+def merge_conv_and_bias(model: torch.fx.GraphModule):
+    """
+    Merges each convolution node and its following bias-add node back into
+    a single conv node with bias. Inverse of separate_conv_and_bias.
+    """
+    add_node_targets = (torch.ops.aten.add_.Tensor,)
+    for n in model.graph.nodes:
+        if not _is_conv(n):
+            continue
+        if len(n.args) > 2 and n.args[2] is not None:
+            continue
+        bias_node = next(iter(n.users))
+        if len(n.users) > 1 or bias_node.target not in add_node_targets:
+            continue
+        conv_node = n
+        const_node = None
+        for node in bias_node.all_input_nodes:
+            if node is not conv_node:
+                const_node = node
+                break
+        assert const_node is not None
+        bias_value = _get_tensor_constant_from_node(const_node, model).squeeze()
+        with model.graph.inserting_before(conv_node):
+            new_bias_node = create_getattr_from_value(model, model.graph, const_node.name + "_", bias_value)
+        args = list(conv_node.args)
+        args[2] = new_bias_node
+        conv_node.args = tuple(args)
+        for user in list(bias_node.users):
+            user.replace_input_with(bias_node, conv_node)
+
+    model.graph.eliminate_dead_code()
+    model.recompile()
+
+
+def _is_scaled_dot_product_attention(n: torch.fx.Node):
+    return n.op == "call_function" and n.target in [torch.ops.aten.scaled_dot_product_attention.default]
+
+
+def _unfold_sdp(model: torch.fx.GraphModule, node: torch.fx.Node):
+    transpose_target = torch.ops.aten.transpose.int
+    matmul_target = torch.ops.aten.matmul.default
+    mul_target = torch.ops.aten.multiply.Scalar
+    softmax_target = torch.ops.aten.softmax.int
+
+    query, key, value = node.args
+    q, k, v = (n.meta["val"] for n in node.args)
+    head_dim = q.shape[-1]
+    scale_factor = 1 / math.sqrt(head_dim)
+
+    with model.graph.inserting_before(node):
+        k_transposed = model.graph.create_node("call_function", transpose_target, (key, -2, -1), {})
+        k = k.transpose(-2, -1)
+        k_transposed.meta["val"] = torch.clone(k)
+
+        sa = model.graph.create_node("call_function", matmul_target, (query, k_transposed), {})
+        attn_value = q @ k
+        sa.meta["val"] = torch.clone(attn_value)
+
+        sa_scaled = model.graph.create_node("call_function", mul_target, (sa, float(scale_factor)), {})
+        sa_scaled.meta["val"] = torch.clone(attn_value)
+
+        softmax = model.graph.create_node("call_function", softmax_target, (sa_scaled, -1), {})
+        softmax.meta["val"] = torch.clone(attn_value)
+
+        result = model.graph.create_node("call_function", matmul_target, (softmax, value), {})
+        r = attn_value @ v
+        result.meta["val"] = torch.clone(r)
+
+    for user in list(node.users):
+        user.replace_input_with(node, result)
+    model.graph.eliminate_dead_code()
+
+
+def unfold_scaled_dot_product_attention(model: torch.fx.GraphModule):
+    for n in model.graph.nodes:
+        if not _is_scaled_dot_product_attention(n):
+            continue
+        args = n.args
+        if len(args) > 3:
+            raise NotImplementedError(
+                f"Unfolding of scaled dot product attention node {n} with more than 3 inputs is not implemented yet"
+            )
+        _unfold_sdp(model, n)
+    model.graph.eliminate_dead_code()
+    model.recompile()
diff --git a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py
index 0b4e322d9d8..c0fcdae9e6f 100644
--- a/nncf/quantization/algorithms/fast_bias_correction/algorithm.py
+++ b/nncf/quantization/algorithms/fast_bias_correction/algorithm.py
@@ -93,7 +93,7 @@ def __init__(
     @property
     def available_backends(self) -> List[BackendType]:
-        return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH]
+        return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]

     def _set_backend_entity(self, model: TModel) -> None:
         """
@@ -116,6 +116,12 @@ def _set_backend_entity(self,
model: TModel) -> None: from nncf.quantization.algorithms.fast_bias_correction.torch_backend import PTFastBiasCorrectionAlgoBackend self._backend_entity = PTFastBiasCorrectionAlgoBackend() + elif model_backend == BackendType.TORCH_FX: + from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import ( + FXFastBiasCorrectionAlgoBackend, + ) + + self._backend_entity = FXFastBiasCorrectionAlgoBackend() else: raise nncf.UnsupportedBackendError( "Cannot return backend-specific entity because {} is not supported!".format(model_backend.value) diff --git a/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py new file mode 100644 index 00000000000..089afd4ab11 --- /dev/null +++ b/nncf/quantization/algorithms/fast_bias_correction/torch_fx_backend.py @@ -0,0 +1,116 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Tuple + +import numpy as np +import torch +import torch.fx +from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node + +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph import NNCFGraph +from nncf.common.graph import NNCFNode +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.transformations.commands import TargetType +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch_fx.transformations import bias_update_transformation_builder +from nncf.quantization.algorithms.fast_bias_correction.backend import FastBiasCorrectionAlgoBackend +from nncf.tensor import Tensor +from nncf.torch.graph.transformations.commands import PTModelExtractionCommand +from nncf.torch.graph.transformations.commands import PTTargetPoint +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.tensor_statistics.collectors import get_mean_statistic_collector + + +class FXFastBiasCorrectionAlgoBackend(FastBiasCorrectionAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXFastBiasCorrectionAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_bias_correction_command( + node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph + ) -> FXApplyTransformationCommand: + return FXApplyTransformationCommand(bias_update_transformation_builder(node, 
bias_value.data))
+
+    @staticmethod
+    def model_extraction_command(
+        input_ids: List[Tuple[str, int]], output_ids: List[Tuple[str, int]]
+    ) -> PTModelExtractionCommand:
+        return PTModelExtractionCommand([input_ids[0][0]], [output_ids[0][0]])
+
+    @staticmethod
+    def mean_statistic_collector(
+        channel_axis: int,
+        inplace: bool,
+        num_samples: Optional[int] = None,
+        window_size: Optional[int] = None,
+    ) -> TensorCollector:
+        return get_mean_statistic_collector(num_samples, channel_axis, window_size)
+
+    @staticmethod
+    def get_sub_input_output_names(subgraph: NNCFNetwork) -> Tuple[str, str]:
+        # PyTorch does not assign names to the inputs/outputs of an extracted model
+        return None, None
+
+    @staticmethod
+    def create_input_data(shape: Tuple[int], data: List[Tensor], input_name: str, channel_axis: int) -> torch.Tensor:
+        blob = torch.zeros(shape, dtype=data[0].data.dtype, device=data[0].data.device)
+        for j, idx in enumerate(np.ndindex(blob.shape[channel_axis])):
+            index = tuple(slice(None) if i != channel_axis else idx for i in range(blob.ndim))
+            blob[index] = data[j].data
+        return blob
+
+    @staticmethod
+    def get_bias_value(node: NNCFNode, nncf_graph: NNCFGraph, model: torch.fx.GraphModule) -> Tensor:
+        # TODO: make a node_name_vs_node map to speed up the process
+        from nncf.experimental.torch_fx.model_transformer import FXModelTransformer
+
+        bias_node = nncf_graph.get_next_nodes(node)[0]
+        graph_bias_node = FXModelTransformer.get_graph_node_by_name(model.graph, bias_node.node_name)
+        return Tensor(_get_tensor_constant_from_node(graph_bias_node.all_input_nodes[1], model))
+
+    @staticmethod
+    def get_activation_port_ids_for_bias_node(node: NNCFNode) -> Tuple[int, int]:
+        return 0, 0
+
+    @staticmethod
+    def process_model_output(raw_data: Dict, output_name: str) -> Tensor:
+        return Tensor(raw_data)
+
+    @staticmethod
+    def is_quantized_weights(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
+        weight_node = nncf_graph.get_previous_nodes(node)[1]
+        return weight_node.node_type == "dequantize_per_channel"
+
+    @staticmethod
+    def is_node_with_bias(node: NNCFNode, nncf_graph: NNCFGraph) -> bool:
+        # Assumes that all biases were unfused
+        if node.metatype in (om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype):
+            next_nodes = nncf_graph.get_next_nodes(node)
+            if len(next_nodes) != 1:
+                return False
+            return next_nodes[0].metatype in (om.PTAddMetatype,)
+        return False
+
+    @staticmethod
+    def get_node_names_for_input_output_statistics(node: NNCFNode, nncf_graph: NNCFGraph) -> Tuple[str, str]:
+        return node.node_name, node.node_name
diff --git a/nncf/quantization/algorithms/min_max/algorithm.py b/nncf/quantization/algorithms/min_max/algorithm.py
index 2ffd27b81dd..941ca25d5f1 100644
--- a/nncf/quantization/algorithms/min_max/algorithm.py
+++ b/nncf/quantization/algorithms/min_max/algorithm.py
@@ -328,7 +328,7 @@ def _init_cache(self) -> None:
     @property
     def available_backends(self) -> List[BackendType]:
-        return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH]
+        return [BackendType.ONNX, BackendType.OPENVINO, BackendType.TORCH, BackendType.TORCH_FX]

     def _get_quantizer_constraints(
         self,
@@ -381,6 +381,10 @@ def _set_backend_entity(self, model: TModel) -> None:
            from nncf.quantization.algorithms.min_max.openvino_backend import OVMinMaxAlgoBackend

            self._backend_entity = OVMinMaxAlgoBackend()
+        elif model_backend == BackendType.TORCH_FX:
+            from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend
+
+            self._backend_entity = FXMinMaxAlgoBackend()
        elif model_backend ==
BackendType.TORCH: from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py new file mode 100644 index 00000000000..c5403386441 --- /dev/null +++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py @@ -0,0 +1,351 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, List, Optional, Set, Tuple + +import torch + +import nncf +import nncf.torch.graph.operator_metatypes as om +from nncf.common.graph.definitions import NNCFGraphNodeType +from nncf.common.graph.graph import NNCFGraph +from nncf.common.graph.graph import NNCFNode +from nncf.common.graph.operator_metatypes import OperatorMetatype +from nncf.common.graph.transformations.commands import TargetType +from nncf.common.graph.transformations.commands import TransformationCommand +from nncf.common.hardware.config import HWConfig +from nncf.common.quantization.structs import QuantizationScheme as QuantizationMode +from nncf.common.quantization.structs import QuantizerConfig +from nncf.experimental.common.tensor_statistics.collectors import AGGREGATORS_MAP +from nncf.experimental.common.tensor_statistics.collectors import TensorCollector +from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic +from nncf.experimental.torch_fx.model_transformer import FXApplyTransformationCommand +from nncf.experimental.torch_fx.transformations import qdq_insertion_tranformation_builder +from nncf.parameters import ModelType +from nncf.parameters import TargetDevice +from nncf.quantization.advanced_parameters import StatisticsType +from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend +from nncf.quantization.fake_quantize import FakeConvertParameters +from nncf.quantization.fake_quantize import FakeQuantizeParameters +from nncf.quantization.range_estimator import AggregatorType +from nncf.quantization.range_estimator import RangeEstimatorParameters +from nncf.torch.graph.graph import PTNNCFGraph +from nncf.torch.graph.graph import PTTargetPoint +from nncf.torch.graph.transformations.commands import PTSharedFnInsertionCommand +from nncf.torch.hardware.config import PTHWConfig +from nncf.torch.nncf_network import NNCFNetwork +from nncf.torch.quantization.default_quantization import DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT +from nncf.torch.quantization.layers import QUANTIZATION_MODULES +from nncf.torch.quantization.layers import AsymmetricQuantizer +from nncf.torch.quantization.layers import BaseQuantizer +from nncf.torch.quantization.layers import PTQuantizerSpec +from nncf.torch.quantization.layers import get_scale_shape +from nncf.torch.quantization.strip import convert_to_torch_fakequantizer +from nncf.torch.tensor_statistics.collectors import PT_REDUCERS_MAP + + +class FXMinMaxAlgoBackend(MinMaxAlgoBackend): + TARGET_TYPE_TO_PT_INS_TYPE_MAP = { + TargetType.PRE_LAYER_OPERATION: 
TargetType.OPERATOR_PRE_HOOK, + TargetType.POST_LAYER_OPERATION: TargetType.OPERATOR_POST_HOOK, + } + + @property + def mat_mul_metatypes(self) -> List[OperatorMetatype]: + return [om.PTLinearMetatype, om.PTMatMulMetatype] + + @property + def post_processing_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def shapeof_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def dropout_metatypes(self) -> List[OperatorMetatype]: + return [om.PTDropoutMetatype] + + @property + def read_variable_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def conv_metatypes(self) -> List[OperatorMetatype]: + return [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype] + + @property + def overflow_fix_metatypes(self) -> List[OperatorMetatype]: + return [ + om.PTConv1dMetatype, + om.PTConv2dMetatype, + om.PTConv3dMetatype, + om.PTLinearMetatype, + om.PTConvTranspose1dMetatype, + om.PTConvTranspose2dMetatype, + om.PTConvTranspose3dMetatype, + ] + + @property + def add_metatypes(self) -> List[OperatorMetatype]: + return [om.PTAddMetatype] + + @property + def group_conv_metatypes(self) -> List[OperatorMetatype]: + return self.conv_metatypes + + @property + def scaled_dot_product_attention_metatypes(self) -> List[OperatorMetatype]: + return [] + + @property + def scales_unification_map(self) -> Dict[OperatorMetatype, OperatorMetatype]: + return {om.PTCatMetatype: self.overflow_fix_metatypes} + + @property + def hw_config(self) -> HWConfig: + return PTHWConfig + + @property + def quant_trait_op_dict(self) -> Dict[int, OperatorMetatype]: + return DEFAULT_PT_QUANT_TRAIT_TO_OP_DICT + + @staticmethod + def get_start_nodes_for_activation_path_tracing(nncf_graph: PTNNCFGraph) -> List[NNCFNode]: + return nncf_graph.get_input_nodes() + + @staticmethod + def target_point(target_type: TargetType, target_node_name: str, port_id: int) -> PTTargetPoint: + if NNCFGraphNodeType.INPUT_NODE in target_node_name or target_type == TargetType.POST_LAYER_OPERATION: + port_id = None + if target_type in FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP: + target_type = FXMinMaxAlgoBackend.TARGET_TYPE_TO_PT_INS_TYPE_MAP[target_type] + return PTTargetPoint(target_type, target_node_name, input_port_id=port_id) + + @staticmethod + def create_convert_insertion_command( + target_point: PTTargetPoint, + parameters: FakeConvertParameters, + ) -> TransformationCommand: + raise nncf.InternalError("FakeConvert insertion not implemented in PyTorch backend!") + + @staticmethod + def get_target_point_shape(nncf_graph: PTNNCFGraph, node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int, ...]: + return nncf_graph.get_input_shape_for_insertion_point(target_point) + + @staticmethod + def get_weight_quantization_axes(node: NNCFNode, target_point: PTTargetPoint) -> Tuple[int]: + # TODO: support transpose conv and other cases + return (0,) + + @staticmethod + def get_statistic_collector( + range_estimator_params: RangeEstimatorParameters, + use_abs_max: bool, + reduction_axes: Optional[Tuple[int, ...]], + aggregation_axes: Optional[Tuple[int, ...]], + inplace: bool, + num_samples: Optional[int] = None, + ) -> TensorCollector: + collector = TensorCollector(MinMaxTensorStatistic) + for params, container_key in zip( + [range_estimator_params.min, range_estimator_params.max], + [MinMaxTensorStatistic.MIN_STAT, MinMaxTensorStatistic.MAX_STAT], + ): + if params.statistics_type not in PT_REDUCERS_MAP: + raise nncf.InternalError( + f"Statistic type: {params.statistics_type} is not supported 
for Torch PTQ backend yet." + ) + + if params.aggregator_type not in AGGREGATORS_MAP: + raise nncf.InternalError( + f"Aggregator type: {params.aggregator_type} is not supported for Torch PTQ backend yet." + ) + + statistic_type = params.statistics_type + if statistic_type in [StatisticsType.QUANTILE, StatisticsType.ABS_QUANTILE]: + # TODO(dlyakhov): merge two quantile aggregators in one + if container_key == MinMaxTensorStatistic.MIN_STAT: + quantile = params.quantile_outlier_prob + else: + quantile = 1 - params.quantile_outlier_prob + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes, quantile=[quantile]) + else: + if use_abs_max and statistic_type == StatisticsType.MAX: + statistic_type = StatisticsType.ABS_MAX + reducer = PT_REDUCERS_MAP[statistic_type](reduction_axes=reduction_axes) + + kwargs = { + "num_samples": num_samples, + "aggregation_axes": aggregation_axes, + } + if params.aggregator_type in [AggregatorType.MEAN_NO_OUTLIERS, AggregatorType.MEDIAN_NO_OUTLIERS]: + kwargs.update({"quantile": params.quantile_outlier_prob}) + aggregator = AGGREGATORS_MAP[params.aggregator_type](**kwargs) + + collector.register_statistic_branch(container_key, reducer, aggregator) + return collector + + @staticmethod + def get_weight_tensor_port_ids(node: NNCFNode) -> List[Optional[int]]: + return node.metatype.weight_port_ids + + @staticmethod + def get_weight_name(nncf_graph: NNCFGraph, target_point: PTTargetPoint) -> str: + weighted_node = nncf_graph.get_node_by_name(target_point.target_node_name) + weight = nncf_graph.get_previous_nodes(weighted_node)[target_point.input_port_id] + return weight.node_name + + @staticmethod + def should_quantize_weight(weight_name: str, quantized_weight_names: Set[str]) -> bool: + # If the nodes share one weight tensor, we should have only one quantizer on that + return weight_name not in quantized_weight_names + + @staticmethod + def get_weight_config(config: QuantizerConfig, model: NNCFNetwork) -> QuantizerConfig: + return config + + @staticmethod + def _get_input_scale_shape( + nncf_graph: NNCFGraph, target_point: PTTargetPoint, per_channel: bool + ) -> Tuple[Tuple[int, ...], Tuple[int, ...], int]: + is_weights = target_point.is_weight_target_point() + if is_weights: + # TODO: support transpose conv/ make channel_idx common + channel_idx = 0 + else: + channel_idx = 1 # channel dim for activations + + input_shape = nncf_graph.get_input_shape_for_insertion_point(target_point) + scale_shape = tuple( + get_scale_shape(input_shape, is_weights=is_weights, per_channel=per_channel, channel_idx=channel_idx) + ) + + return input_shape, scale_shape, channel_idx + + @staticmethod + def _create_quantizer( + quantizer_config: QuantizerConfig, + scale_shape: Tuple, + parameters: FakeQuantizeParameters, + target_type: TargetType, + ) -> BaseQuantizer: + mode = quantizer_config.mode + quantizer_cls = QUANTIZATION_MODULES.get(mode) + narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC + quantizer_spec = PTQuantizerSpec.from_config( + quantizer_config, + narrow_range=narrow_range, + scale_shape=scale_shape, + half_range=False, + logarithm_scale=False, + is_quantized_on_export=False, + compression_lr_multiplier=None, + ) + quantizer = quantizer_cls(quantizer_spec) + + # Fill it with minmax + FXMinMaxAlgoBackend._fill_quantizer_parameters(quantizer, parameters, quantizer_spec.scale_shape) + # Convert to the torch fake quantizer + torch_fq = convert_to_torch_fakequantizer(quantizer) + return torch_fq + + 
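The `narrow_range` flag chosen in `_create_quantizer` above follows the usual signed-int8 convention; the level ranges below illustrate that general convention and are not values taken from the patch:

```python
# narrow_range=False (activations): full signed int8 range
activation_levels = (-128, 127)
# narrow_range=True (symmetric weights): -128 is dropped so the range is symmetric around zero
weight_levels = (-127, 127)
```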
@staticmethod + def _fill_quantizer_parameters(quantizer: BaseQuantizer, parameters: FakeQuantizeParameters, scale_shape) -> None: + if isinstance(quantizer, AsymmetricQuantizer): + quantizer.input_low = torch.nn.Parameter(parameters.input_low.data.reshape(scale_shape)) + input_range = parameters.input_high - parameters.input_low + # Subtract eps from the input_range to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.input_range = torch.nn.Parameter((input_range.data - quantizer.eps).reshape(scale_shape)) + else: + quantizer.signed = bool(torch.any(parameters.input_low.data < 0)) + # Subtract eps from the scale to make quantizer parameters equal to + # original parameters on the forward call. + quantizer.scale = torch.nn.Parameter((parameters.input_high.data - quantizer.eps).reshape(scale_shape)) + + @staticmethod + def create_quantizer_insertion_command( + nncf_graph: NNCFGraph, + target_point: PTTargetPoint, + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> FXApplyTransformationCommand: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_point, quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_point.target_type + ) + transformation = qdq_insertion_tranformation_builder(quantizer, [target_point]) + return FXApplyTransformationCommand(transformation) + + @staticmethod + def create_unified_scales_quantizers_insertion_commands( + nncf_graph: NNCFGraph, + target_points: List[PTTargetPoint], + quantizer_config: QuantizerConfig, + parameters: FakeQuantizeParameters, + ) -> List[PTSharedFnInsertionCommand]: + _, scale_shape, _ = FXMinMaxAlgoBackend._get_input_scale_shape( + nncf_graph, target_points[0], quantizer_config.per_channel + ) + + quantizer = FXMinMaxAlgoBackend._create_quantizer( + quantizer_config, scale_shape, parameters, target_points[0].target_type + ) + + transformations = [] + for tp in target_points: + transformation = qdq_insertion_tranformation_builder(quantizer, [tp]) + transformations.append(FXApplyTransformationCommand(transformation)) + return transformations + + @staticmethod + def get_ignored_metatypes(model_type: ModelType, device: TargetDevice) -> List[OperatorMetatype]: + types = [] + if model_type == ModelType.TRANSFORMER: + types = [ + om.PTAddMetatype, + om.PTPowerMetatype, + om.PTSubMetatype, + om.PTAvgPool2dMetatype, + om.PTAvgPool3dMetatype, + om.PTMeanMetatype, + om.PTSumMetatype, + om.PTReduceL2, + om.PTDivMetatype, + om.PTMaxMetatype, + om.PTSqueezeMetatype, + om.PTLayerNormMetatype, + om.PTModuleLayerNormMetatype, + om.PTGroupNormMetatype, + om.PTModuleGroupNormMetatype, + # Batchnorm + om.PTBatchNormMetatype, + om.PTModuleBatchNormMetatype, + ] + if device != TargetDevice.CPU_SPR: + types.append(om.PTMulMetatype) + return types + + @staticmethod + def get_ignored_names_by_layer_attributes(nncf_graph: NNCFGraph) -> List[str]: + return [] + + @staticmethod + def get_weight_nodes(nncf_graph: NNCFGraph) -> List[NNCFNode]: + retval = set() + for node in nncf_graph.get_all_nodes(): + if node.metatype in [om.PTConv1dMetatype, om.PTConv2dMetatype, om.PTConv3dMetatype, om.PTLinearMetatype]: + retval.add(node) + return list(retval) diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 5632bedf5b6..bdc096a4982 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -228,7 +228,21 @@ def 
quantize( ignored_scope=ignored_scope, advanced_parameters=advanced_parameters, ) + if backend == BackendType.TORCH_FX: + from nncf.experimental.torch_fx.quantization.quantize_model import quantize_impl + return quantize_impl( + model=model, + calibration_dataset=calibration_dataset, + mode=mode, + preset=preset, + target_device=target_device, + subset_size=subset_size, + fast_bias_correction=fast_bias_correction, + model_type=model_type, + ignored_scope=ignored_scope, + advanced_parameters=advanced_parameters, + ) raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/nncf/torch/dynamic_graph/patch_pytorch.py b/nncf/torch/dynamic_graph/patch_pytorch.py index 5148496fea6..5d20a0d7ba6 100644 --- a/nncf/torch/dynamic_graph/patch_pytorch.py +++ b/nncf/torch/dynamic_graph/patch_pytorch.py @@ -371,7 +371,7 @@ def patch_torch_operators(): functions_to_patch = {} for namespace in NamespaceTarget: - if namespace == NamespaceTarget.EXTERNAL: + if namespace in [NamespaceTarget.ATEN, NamespaceTarget.EXTERNAL]: continue functions_to_patch[namespace] = get_all_functions_from_namespace(namespace) diff --git a/nncf/torch/dynamic_graph/structs.py b/nncf/torch/dynamic_graph/structs.py index c767790a92c..d8cf563107f 100644 --- a/nncf/torch/dynamic_graph/structs.py +++ b/nncf/torch/dynamic_graph/structs.py @@ -22,6 +22,7 @@ class NamespaceTarget(Enum): TORCH_TENSOR = "torch.tensor" TORCH_NN_PARAMETER = "torch.nn.parameter" TORCH = "torch" + ATEN = "aten" EXTERNAL = "external_function" diff --git a/nncf/torch/graph/operator_metatypes.py b/nncf/torch/graph/operator_metatypes.py index 0e0bca6e770..802e5740bc7 100644 --- a/nncf/torch/graph/operator_metatypes.py +++ b/nncf/torch/graph/operator_metatypes.py @@ -55,6 +55,7 @@ class PTOperatorMetatype(OperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: [], NamespaceTarget.TORCH: [], + NamespaceTarget.ATEN: [], } subtypes: List[Type["PTOperatorMetatype"]] = [] @@ -527,7 +528,7 @@ class PTGELUMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTSILUMetatype(PTOperatorMetatype): name = "SiluOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"]} + module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["silu"], NamespaceTarget.ATEN: ["silu_"]} @PT_OPERATOR_METATYPES.register() @@ -546,6 +547,7 @@ class PTAddMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["add", "__add__", "__iadd__", "__radd__"], NamespaceTarget.TORCH: ["add"], + NamespaceTarget.ATEN: ["add_"], } hw_config_names = [HWConfigOpName.ADD] num_expected_input_edges = 2 @@ -557,6 +559,7 @@ class PTSubMetatype(PTOperatorMetatype): module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["sub", "__sub__", "__isub__", "__rsub__"], NamespaceTarget.TORCH: ["sub"], + NamespaceTarget.ATEN: ["sub_"], } hw_config_names = [HWConfigOpName.SUBTRACT] num_expected_input_edges = 2 @@ -690,13 +693,19 @@ class PTThresholdMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register(is_subtype=True) class PTModuleBatchNormMetatype(PTModuleOperatorSubtype): name = "BatchNormOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], + } @PT_OPERATOR_METATYPES.register() class PTBatchNormMetatype(PTOperatorMetatype): name = "BatchNormOp" - module_to_function_names = 
{NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"], + NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"], + } subtypes = [PTModuleBatchNormMetatype] weight_port_ids = [3] bias_port_id = 4 @@ -825,7 +834,8 @@ class PTGatherMetatype(PTOperatorMetatype): name = "GatherOp" module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["index_select", "__getitem__"], - NamespaceTarget.TORCH: ["gather", "index_select", "where"], + NamespaceTarget.TORCH: ["gather", "index_select", "select", "where"], + NamespaceTarget.ATEN: ["slice"], } @@ -840,7 +850,7 @@ class PTReshapeMetatype(PTOperatorMetatype): name = "ReshapeOp" module_to_function_names = { NamespaceTarget.TORCH_TENSOR: ["reshape", "view", "flatten", "unsqueeze"], - NamespaceTarget.TORCH: ["flatten", "unsqueeze"], + NamespaceTarget.TORCH: ["flatten", "unflatten", "unsqueeze"], } hw_config_names = [HWConfigOpName.RESHAPE, HWConfigOpName.UNSQUEEZE, HWConfigOpName.FLATTEN] @@ -862,6 +872,7 @@ class PTSplitMetatype(PTOperatorMetatype): NamespaceTarget.TORCH_NN_FUNCTIONAL: [], NamespaceTarget.TORCH_TENSOR: ["split", "chunk", "unbind"], NamespaceTarget.TORCH: ["split", "chunk", "unbind"], + NamespaceTarget.ATEN: ["split_with_sizes"], } hw_config_names = [HWConfigOpName.SPLIT, HWConfigOpName.CHUNK] @@ -1027,7 +1038,10 @@ class PTSqrtMetatype(PTOperatorMetatype): @PT_OPERATOR_METATYPES.register() class PTInterpolateMetatype(PTOperatorMetatype): name = "InterpolateOp" - module_to_function_names = {NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"]} + module_to_function_names = { + NamespaceTarget.TORCH_NN_FUNCTIONAL: ["interpolate"], + NamespaceTarget.ATEN: ["upsample_nearest2d", "upsample_nearest_exact2d"], + } hw_config_names = [HWConfigOpName.INTERPOLATE] num_expected_input_edges = 1 diff --git a/nncf/torch/graph/pattern_operations.py b/nncf/torch/graph/pattern_operations.py index d9957871d87..dc0e5b43af2 100644 --- a/nncf/torch/graph/pattern_operations.py +++ b/nncf/torch/graph/pattern_operations.py @@ -67,7 +67,7 @@ ) ARITHMETIC_OPERATIONS = { - GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__"], + GraphPattern.METATYPE_ATTR: ["__iadd__", "__add__", "__mul__", "__rmul__", "__truediv__", "add_"], GraphPattern.LABEL_ATTR: "ARITHMETIC", } diff --git a/tests/post_training/pipelines/lm_weight_compression.py b/tests/post_training/pipelines/lm_weight_compression.py index fcab0a20f88..de8eeebee1f 100644 --- a/tests/post_training/pipelines/lm_weight_compression.py +++ b/tests/post_training/pipelines/lm_weight_compression.py @@ -19,7 +19,6 @@ import numpy as np import openvino as ov import torch -from datasets import load_dataset from memory_profiler import memory_usage from optimum.exporters.openvino.convert import export_from_model from optimum.intel.openvino import OVModelForCausalLM @@ -28,6 +27,7 @@ from whowhatbench import Evaluator import nncf +from datasets import load_dataset from tests.post_training.pipelines.base import BackendType from tests.post_training.pipelines.base import BaseTestPipeline from tests.post_training.pipelines.base import StatsFromOutput diff --git a/tests/torch/sparsity/movement/helpers/run_recipe.py b/tests/torch/sparsity/movement/helpers/run_recipe.py index 77b3140a967..383552932d5 100644 --- a/tests/torch/sparsity/movement/helpers/run_recipe.py +++ b/tests/torch/sparsity/movement/helpers/run_recipe.py @@ -20,7 +20,6 @@ import torch.nn import torch.nn.functional as F import 
torch.utils.data -from datasets import Dataset from transformers import AutoModelForAudioClassification from transformers import AutoModelForImageClassification from transformers import AutoModelForSequenceClassification @@ -34,6 +33,7 @@ from transformers import SwinConfig from transformers import Wav2Vec2Config +from datasets import Dataset from nncf import NNCFConfig from nncf.experimental.torch.sparsity.movement.scheduler import MovementSchedulerParams from nncf.torch.dynamic_graph.io_handling import FillerInputElement diff --git a/tests/torch/sparsity/movement/helpers/trainer.py b/tests/torch/sparsity/movement/helpers/trainer.py index 89ffeb6c865..2af37c5b2f4 100644 --- a/tests/torch/sparsity/movement/helpers/trainer.py +++ b/tests/torch/sparsity/movement/helpers/trainer.py @@ -14,7 +14,6 @@ import numpy as np import torch -from datasets import Dataset # pylint: disable=no-name-in-module from transformers import TrainingArguments from transformers.trainer import Trainer from transformers.trainer_callback import TrainerCallback @@ -22,6 +21,7 @@ from transformers.trainer_callback import TrainerState from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from datasets import Dataset # pylint: disable=no-name-in-module from nncf.api.compression import CompressionAlgorithmController from nncf.common.compression import BaseCompressionAlgorithmController from nncf.common.utils.tensorboard import prepare_for_tensorboard diff --git a/tests/torch/sparsity/movement/test_model_saving.py b/tests/torch/sparsity/movement/test_model_saving.py index 27a9655591a..c7949afeb82 100644 --- a/tests/torch/sparsity/movement/test_model_saving.py +++ b/tests/torch/sparsity/movement/test_model_saving.py @@ -18,7 +18,6 @@ import pytest import torch from addict import Dict -from datasets import Dataset from onnx import numpy_helper from openvino._offline_transformations import apply_fused_names_cleanup from openvino._offline_transformations import apply_moc_transformations @@ -29,6 +28,7 @@ from scipy.special import softmax from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from datasets import Dataset from nncf.torch import create_compressed_model from nncf.torch.checkpoint_loading import load_state from tests.torch.helpers import PTTensorListComparator diff --git a/tests/torch/sparsity/movement/training_scripts/run_glue.py b/tests/torch/sparsity/movement/training_scripts/run_glue.py index 360832a5bb7..d0f5b14269e 100644 --- a/tests/torch/sparsity/movement/training_scripts/run_glue.py +++ b/tests/torch/sparsity/movement/training_scripts/run_glue.py @@ -12,12 +12,13 @@ from pathlib import Path from typing import Dict, List, Optional, Tuple -import datasets import evaluate import jstyleson import numpy as np from transformers.training_args import ParallelMode +import datasets + # isort: off from nncf import NNCFConfig from nncf.api.compression import CompressionAlgorithmController diff --git a/tests/torch_fx/__init__.py b/tests/torch_fx/__init__.py new file mode 100644 index 00000000000..2e49d63977d --- /dev/null +++ b/tests/torch_fx/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/torch_fx/helpers.py b/tests/torch_fx/helpers.py new file mode 100644 index 00000000000..8bbc721e0fa --- /dev/null +++ b/tests/torch_fx/helpers.py @@ -0,0 +1,105 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pathlib import Path + +import torch +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.datasets as datasets +import torchvision.transforms as transforms +from fastdownload import FastDownload + + +class TinyImagenetDatasetManager: + DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip" + DATASET_PATH = "~/.cache/nncf/tests/datasets" + + def __init__(self, image_size: int, batch_size: int) -> None: + self.image_size = image_size + self.batch_size = batch_size + + @staticmethod + def download_dataset() -> Path: + downloader = FastDownload(base=TinyImagenetDatasetManager.DATASET_PATH, archive="downloaded", data="extracted") + return downloader.get(TinyImagenetDatasetManager.DATASET_URL) + + @staticmethod + def prepare_tiny_imagenet_200(dataset_dir: Path): + # Format validation set the same way as train set is formatted. 
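+        # Tiny ImageNet ships all validation images in a flat val/images folder
+        # with labels listed in val_annotations.txt; the code below moves every
+        # image into a per-class subfolder so that torchvision's ImageFolder can
+        # read the validation split just like the training split.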
+ val_data_dir = dataset_dir / "val" + val_images_dir = val_data_dir / "images" + if not val_images_dir.exists(): + return + + val_annotations_file = val_data_dir / "val_annotations.txt" + with open(val_annotations_file, "r") as f: + val_annotation_data = map(lambda line: line.split("\t")[:2], f.readlines()) + for image_filename, image_label in val_annotation_data: + from_image_filepath = val_images_dir / image_filename + to_image_dir = val_data_dir / image_label + if not to_image_dir.exists(): + to_image_dir.mkdir() + to_image_filepath = to_image_dir / image_filename + from_image_filepath.rename(to_image_filepath) + val_annotations_file.unlink() + val_images_dir.rmdir() + + def create_data_loaders(self): + dataset_path = TinyImagenetDatasetManager.download_dataset() + + TinyImagenetDatasetManager.prepare_tiny_imagenet_200(dataset_path) + print(f"Successfully downloaded and prepared dataset at: {dataset_path}") + + train_dir = dataset_path / "train" + val_dir = dataset_path / "val" + + normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + train_dataset = datasets.ImageFolder( + train_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + normalize, + ] + ), + ) + val_dataset = datasets.ImageFolder( + val_dir, + transforms.Compose( + [ + transforms.Resize(self.image_size), + transforms.ToTensor(), + normalize, + ] + ), + ) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True, sampler=None + ) + + val_loader = torch.utils.data.DataLoader( + val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True + ) + + # Creating separate dataloader with batch size = 1 + # as dataloaders with batches > 1 are not supported yet. + calibration_dataset = torch.utils.data.DataLoader( + val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True + ) + + return train_loader, val_loader, calibration_dataset diff --git a/tests/torch_fx/requirements.txt b/tests/torch_fx/requirements.txt new file mode 100644 index 00000000000..99ee43ce754 --- /dev/null +++ b/tests/torch_fx/requirements.txt @@ -0,0 +1 @@ +fastdownload==0.0.7 \ No newline at end of file diff --git a/tests/torch_fx/test_sanity.py b/tests/torch_fx/test_sanity.py new file mode 100644 index 00000000000..197c2f95472 --- /dev/null +++ b/tests/torch_fx/test_sanity.py @@ -0,0 +1,141 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
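The test below is a plain pytest case; a minimal sketch of invoking it directly, assuming the repository root is on PYTHONPATH and the extra requirement from tests/torch_fx/requirements.txt is installed:

```python
from tests.torch_fx.test_sanity import MODELS, test_sanity

test_sanity(MODELS[0])  # the resnet18 sanity case
```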
+ + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np +import openvino.torch # noqa +import pytest +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.models as models +from torch._export import capture_pre_autograd_graph + +import nncf +from nncf.common.logging.track_progress import track +from nncf.torch.dynamic_graph.patch_pytorch import disable_patching +from tests.torch_fx.helpers import TinyImagenetDatasetManager + +IMAGE_SIZE = 64 +BATCH_SIZE = 128 + + +@dataclass +class SanitySampleCase: + model_id: str + checkpoint_url: str + top1_int8_ref: float + ref_num_q: int + ref_num_dq: int + + +MODELS = ( + SanitySampleCase( + "resnet18", + "https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth", + 55.23, + 51, + 58, + ), +) + + +def get_model(model_id: str, checkpoint_url: str, device: torch.device) -> torch.nn.Module: + num_classes = 200 # 200 is for Tiny ImageNet, default is 1000 for ImageNet + model = getattr(models, model_id)(weights=None) + # Update the last FC layer for Tiny ImageNet number of classes. + model.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True) + model.to(device) + checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location=torch.device("cpu"), progress=False) + model.load_state_dict(checkpoint["state_dict"]) + return model + + +def validate(val_loader: torch.utils.data.DataLoader, model: torch.nn.Module, device: torch.device) -> float: + top1_sum = 0.0 + with torch.no_grad(): + for images, target in track(val_loader, total=len(val_loader), description="Validation:"): + images = images.to(device) + target = target.to(device) + + # Compute output. + output = model(images) + + # Measure accuracy and record loss. + [acc1] = accuracy(output, target, topk=(1,)) + top1_sum += acc1.item() + + num_samples = len(val_loader) + top1_avg = top1_sum / num_samples + return top1_avg + + +def accuracy(output: torch.Tensor, target: torch.tensor, topk: Tuple[int, ...] 
= (1,)): + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def count_q_dq(model: torch.fx.GraphModule): + q, dq = 0, 0 + for node in model.graph.nodes: + if node.op == "call_function" and hasattr(node.target, "overloadpacket"): + node_type = str(node.target.overloadpacket).split(".")[1] + if node_type in ["quantize_per_tensor", "quantize_per_channel"]: + q += 1 + elif node_type in ["dequantize_per_tensor", "dequantize_per_channel"]: + dq += 1 + return q, dq + + +@pytest.mark.parametrize("test_case", MODELS) +def test_sanity(test_case: SanitySampleCase): + with disable_patching(): + device = torch.device("cpu") + model = get_model(test_case.model_id, test_case.checkpoint_url, device) + _, val_dataloader, calibration_dataset = TinyImagenetDatasetManager( + IMAGE_SIZE, BATCH_SIZE + ).create_data_loaders() + + def transform_fn(data_item): + return data_item[0].to(device) + + calibration_dataset = nncf.Dataset(calibration_dataset, transform_fn) + + with torch.no_grad(): + ex_input = next(iter(calibration_dataset.get_inference_data())) + model.eval() + exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + quantized_model = torch.compile(quantized_model, backend="openvino") + + top1_int8 = validate(val_dataloader, quantized_model, device) + assert np.isclose(top1_int8, test_case.top1_int8_ref, atol=1e-2) + + num_q, num_dq = count_q_dq(quantized_model) + assert num_q == test_case.ref_num_q + assert num_dq == test_case.ref_num_dq diff --git a/torch_compile_ex_release.py b/torch_compile_ex_release.py new file mode 100644 index 00000000000..7bd0addf02e --- /dev/null +++ b/torch_compile_ex_release.py @@ -0,0 +1,217 @@ +# Copyright (c) 2024 Intel Corporation +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
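For reference, the script below can be driven via the CLI flags defined in its `__main__` block or programmatically; a minimal sketch of the latter:

```python
# equivalent to: python torch_compile_ex_release.py --model resnet18 --num_iters 100
from torch_compile_ex_release import main

main("resnet18", 100)
```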
+
+# Enable the torch inductor freezing feature first: the environment variable
+# must be set before torch._inductor is imported.
+import os
+
+os.environ["TORCHINDUCTOR_FREEZING"] = "1"
+
+
+import argparse
+import copy
+import time
+from collections import defaultdict
+
+import openvino.torch  # noqa
+import torch
+
+# Optional: use the C++ wrapper instead of the default Python wrapper
+import torch._inductor.config as config
+import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
+import torchvision.models as models
+from torch._export import capture_pre_autograd_graph
+from torch.ao.quantization.quantize_pt2e import convert_pt2e
+from torch.ao.quantization.quantize_pt2e import prepare_pt2e
+from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
+from torch.fx.passes.graph_drawer import FxGraphDrawer
+
+from nncf.experimental.torch_fx.model_transformer import QPARAMPerChannel
+from nncf.experimental.torch_fx.model_transformer import QPARAMSPerTensor
+from nncf.experimental.torch_fx.model_transformer import insert_qdq_to_model
+from nncf.experimental.torch_fx.nncf_graph_builder import GraphConverter  # noqa
+
+
+def get_exported_model_from_nn_module(module, example_inputs):
+    with torch.no_grad():
+        return capture_pre_autograd_graph(module, example_inputs)
+
+
+NNCF_IMPL = True
+
+
+def get_qsetup(exported_model, example_inputs):
+    quantizer = X86InductorQuantizer()
+    quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())
+
+    prepared_model = prepare_pt2e(exported_model, quantizer)
+    prepared_model(*example_inputs)
+    converted_model = convert_pt2e(prepared_model)
+    g = FxGraphDrawer(converted_model, "resnet18_int8")
+    g.get_dot_graph().write_svg("resnet18_int8_compiled.svg")
+    qsetup = defaultdict(lambda: dict())
+
+    for node in converted_model.graph.nodes:
+        if "dequantize" in node.name:
+            quantize = node.all_input_nodes[0]
+            if "per_tensor" in node.name:
+                params = QPARAMSPerTensor(*node.args[1:])
+            else:
+                params = []
+                for i in range(1, 3):
+                    name = node.args[i].target
+                    params.append(getattr(converted_model, name))
+                params = QPARAMPerChannel(*(params + list(node.args[3:])))
+
+            target_node_name = quantize.all_input_nodes[0].name
+            qsetup[target_node_name] = params
+    return qsetup
+
+
+def quantize(model, example_inputs):
+    if NNCF_IMPL:
+        # Use NNCF on the exported model to create a quantized model
+        # that is compatible with the convert_pt2e function:
+        # 1. Convert the torch FX graph to an NNCFGraph.
+        # 2. Analyze the NNCF graph for SmoothQuant / channel alignment (SQ/CA).
+        # 3. Collect statistics.
+        # 4. Update params.
+        # 5. Analyze the NNCF graph for quantization.
+        # 6. Insert observers.
+        # 7. prepared_model(*example_inputs)
+        # 8.
convert_pt2e(prepared_model) + import nncf + + calibration_dataset = nncf.Dataset(example_inputs) + exported_model = get_exported_model_from_nn_module(model, example_inputs) + quantized_model = nncf.quantize(exported_model, calibration_dataset) + g = FxGraphDrawer(quantized_model, "resnet18_quantized_native_nncf") + g.get_dot_graph().write_svg("resnet18_quantized_native_nncf.svg") + return quantized_model + + else: + # g = FxGraphDrawer(exported_model, "resnet18") + # g.get_dot_graph().write_svg("resnet18_compiled.svg") + + # MOCK NNCF QUANTIZATION + exported_model = get_exported_model_from_nn_module(model, example_inputs) + qsetup = get_qsetup(exported_model, example_inputs) + exported_model = get_exported_model_from_nn_module(model, example_inputs) + exported_model = insert_qdq_to_model(exported_model, qsetup) + g = FxGraphDrawer(exported_model, "resnet18_int8") + g.get_dot_graph().write_svg("resnet18_int8_compiled_manually.svg") + return exported_model + + return None # converted_model + + +config.cpp_wrapper = True + + +def measure_time(model, example_inputs, num_iters): + with torch.no_grad(): + model(*example_inputs) + total_time = 0 + for i in range(0, num_iters): + start_time = time.time() + model(*example_inputs) + total_time += time.time() - start_time + average_time = (total_time / num_iters) * 1000 + return average_time + + +def get_dummy_dataset(): + traced_bs = 1 + x = torch.randn(traced_bs, 3, 224, 224).contiguous(memory_format=torch.channels_last) + example_inputs = (x,) + return example_inputs + + +def main_nncf(model_name, num_iters): + model = models.__dict__[model_name](pretrained=True) + model = model.eval() + + example_inputs = get_dummy_dataset() + import nncf + + calibration_dataset = nncf.Dataset(example_inputs) + quantized_model = nncf.quantize(model, calibration_dataset) + + import openvino as ov + + ov_model = ov.convert_model(quantized_model.cpu(), example_input=example_inputs[0]) + ov.serialize(ov_model, "./model_cache_nncf/model.xml") + + +def main(model_name, num_iters): + model = models.__dict__[model_name](pretrained=True) + model = model.eval() + + example_inputs = get_dummy_dataset() + + converted_model = quantize(copy.deepcopy(model), example_inputs) + + print("original model execution time: ", measure_time(model, example_inputs, num_iters)) + + native_optimized_model_fp32 = torch.compile(model) + print( + "Torch Inductor FP32 model execution time: ", + measure_time(native_optimized_model_fp32, example_inputs, num_iters), + ) + + native_optimized_model_int8 = torch.compile(converted_model) + print( + "Torch Inductor INT8 model execution time: ", + measure_time(native_optimized_model_int8, example_inputs, num_iters), + ) + + ov_optimized_model_fp32 = torch.compile(model, backend="openvino") + print( + "Torch.compile OpenVINO FP32 model execution time: ", + measure_time(ov_optimized_model_fp32, example_inputs, num_iters), + ) + + ov_optimized_model_int8 = torch.compile( + converted_model, backend="openvino", options={"model_caching": True, "cache_dir": "./model_cache"} + ) + print( + "Torch.compile OpenVINO INT8 model execution time: ", + measure_time(ov_optimized_model_int8, example_inputs, num_iters), + ) + + import intel_extension_for_pytorch # noqa + + ipex_optimized_model_fp32 = torch.compile(model, backend="ipex") + print( + "Torch.compile IPEX FP32 model execution time: ", + measure_time(ipex_optimized_model_fp32, example_inputs, num_iters), + ) + + ipex_optimized_model_int8 = torch.compile(converted_model, backend="ipex") + print( + "Torch.compile 
IPEX INT8 model execution time: ",
+        measure_time(ipex_optimized_model_int8, example_inputs, num_iters),
+    )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--num_iters", help="number of inference iterations", type=int, default=100)
+    parser.add_argument("--model", help="torchvision model name", type=str, default="resnet18")
+    args = parser.parse_args()
+    model_name = args.model
+    num_iters = args.num_iters
+    main(model_name, num_iters)
+    # main_nncf(model_name, num_iters)
diff --git a/yolo_fx_bad_metrics_repro.py b/yolo_fx_bad_metrics_repro.py
new file mode 100644
index 00000000000..b5c05d6bbcb
--- /dev/null
+++ b/yolo_fx_bad_metrics_repro.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Dict, Tuple
+
+import torch
+from tqdm import tqdm
+from ultralytics.data.utils import check_det_dataset
+from ultralytics.engine.validator import BaseValidator as Validator
+from ultralytics.models.yolo import YOLO
+from ultralytics.utils.torch_utils import de_parallel
+
+
+def print_statistics(stats: Dict[str, float], total_images: int, total_objects: int) -> None:
+    mp, mr, map50, mean_ap = (
+        stats["metrics/precision(B)"],
+        stats["metrics/recall(B)"],
+        stats["metrics/mAP50(B)"],
+        stats["metrics/mAP50-95(B)"],
+    )
+    s = ("%20s" + "%12s" * 6) % ("Class", "Images", "Labels", "Precision", "Recall", "mAP@.5", "mAP@.5:.95")
+    print(s)
+    pf = "%20s" + "%12i" * 2 + "%12.3g" * 4  # print format
+    print(pf % ("all", total_images, total_objects, mp, mr, map50, mean_ap))
+
+
+def prepare_validation(model: YOLO, data: str) -> Tuple[Validator, torch.utils.data.DataLoader]:
+    # rect=False forces all input images to be resized to a single size;
+    # the method defaults would be {"rect": True, "batch": 1}.
+    custom = {"rect": False, "batch": 1}
+    args = {**model.overrides, **custom, "mode": "val"}  # highest priority args on the right
+
+    validator = model._smart_load("validator")(args=args, _callbacks=model.callbacks)
+    stride = 32  # default stride
+    validator.stride = stride  # used in get_dataloader() for padding
+    validator.data = check_det_dataset(data)
+    validator.init_metrics(de_parallel(model))
+
+    data_loader = validator.get_dataloader(validator.data.get(validator.args.split), validator.args.batch)
+    return validator, data_loader
+
+
+def validate(model, data_loader: torch.utils.data.DataLoader, validator: Validator) -> Tuple[Dict, int, int]:
+    with torch.no_grad():
+        for batch in data_loader:
+            batch = validator.preprocess(batch)
+            preds = model(batch["img"])
+            preds = validator.postprocess(preds)
+            validator.update_metrics(preds, batch)
+        stats = validator.get_stats()
+    return stats, validator.seen, validator.nt_per_class.sum()
+
+
+def main(torch_fx):
+    # ultralytics @ git+https://github.com/THU-MIG/yolov10.git@2c36ab0f108efdd17c7e290564bb845ccb6844d8
+    # pip install git+https://github.com/THU-MIG/yolov10.git
+    # pip install huggingface-hub
+    # yolo_model =
YOLO("yolov10n.pt") + + yolo_model = YOLO("yolov8n") + + model_type = "torch" + model = yolo_model.model + if torch_fx: + model = torch.compile(model) + model_type = "FX" + print(f"FP32 {model_type} model validation results:") + validator, data_loader = prepare_validation(yolo_model, "coco128.yaml") + stats, total_images, total_objects = validate(model, tqdm(data_loader), validator) + print_statistics(stats, total_images, total_objects) + + +if __name__ == "__main__": + print("Torch model:") + main(torch_fx=False) + print("Torch FX model:") + main(torch_fx=True)