Skip to content

Commit

Permalink
[TorchFX]: Test Fast Bias Correction algorithm (#2812)
Browse files Browse the repository at this point in the history
### Changes

Add Fast BC test in tests/torch/fx. 

### Related tickets

#2809 

### Tests

The implemented test class inherits from the template for Fast bias
correction.
  • Loading branch information
anzr299 authored Jul 30, 2024
1 parent 25eff06 commit 2a0a93e
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 4 deletions.
51 changes: 49 additions & 2 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
import torch
import torch.fx
from torch.ao.quantization.fx.utils import create_getattr_from_value
from torch.ao.quantization.pt2e.utils import _fuse_conv_bn_
from torch.ao.quantization.pt2e.utils import _get_tensor_constant_from_node
from torch.ao.quantization.pt2e.utils import fold_bn_weights_into_conv_node
from torch.quantization.fake_quantize import FakeQuantize

import nncf
Expand Down Expand Up @@ -243,6 +243,53 @@ def _set_module_to_the_graph_module(
return module_name_in_model


def _is_supported_batch_norm_for_training(node: torch.fx.Node):
"""
Return True if the given node refers to an aten batch norm op QAT supports.
"""
supported_ops = [
torch.ops.aten._native_batch_norm_legit.default,
torch.ops.aten.cudnn_batch_norm.default,
torch.ops.aten.miopen_batch_norm.default,
]
return node.target in supported_ops


def _is_bn_node(node: torch.fx.Node):
return (
_is_supported_batch_norm_for_training(node)
or node.target == torch.ops.aten._native_batch_norm_legit_no_training.default
)


def fuse_conv_bn(model: torch.fx.GraphModule) -> None:
"""
BatchNorm operations have 3 output ports, to make it easier for alorithms to work with
the target graph BatchNorm operations are being fused
:param model: Model to apply transformations to.
"""
has_bn = any(_is_bn_node(node) for node in model.graph.nodes)
if not has_bn:
return

for node in model.graph.nodes:
if node.op != "call_function" or not _is_bn_node(node):
continue
bn_node = node

node = bn_node.args[0]
if not _is_conv(node):
continue
conv_node = node
conv_weight_node = conv_node.args[1]
conv_bias_node = conv_node.args[2] if len(conv_node.args) > 2 else None
fold_bn_weights_into_conv_node(conv_node, conv_weight_node, conv_bias_node, bn_node, model)

model.graph.eliminate_dead_code()
model.recompile()


def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
"""
Applies quantization transformations to the model.
Expand All @@ -252,7 +299,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
# to make it easier for alorithms to work
# with the target graph BatchNorm operations
# are being fused
_fuse_conv_bn_(model)
fuse_conv_bn(model)
separate_conv_and_bias(model)
separate_linear_and_bias(model)

Expand Down
4 changes: 2 additions & 2 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -707,7 +707,7 @@ class PTModuleBatchNormMetatype(PTModuleOperatorSubtype):
name = "BatchNormOp"
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"],
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm"],
}


Expand All @@ -716,7 +716,7 @@ class PTBatchNormMetatype(PTOperatorMetatype):
name = "BatchNormOp"
module_to_function_names = {
NamespaceTarget.TORCH_NN_FUNCTIONAL: ["batch_norm"],
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training"],
NamespaceTarget.ATEN: ["_native_batch_norm_legit_no_training", "cudnn_batch_norm"],
}
subtypes = [PTModuleBatchNormMetatype]
weight_port_ids = [3]
Expand Down
98 changes: 98 additions & 0 deletions tests/torch/fx/test_fast_bias_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import pytest
import torch
import torch.fx
from torch._export import capture_pre_autograd_graph

from nncf.common.factory import NNCFGraphFactory
from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import FXFastBiasCorrectionAlgoBackend
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from nncf.torch.model_graph_manager import OPERATORS_WITH_BIAS_METATYPES
from tests.post_training.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm


class TestTorchFXFBCAlgorithm(TemplateTestFBCAlgorithm):
@staticmethod
def list_to_backend_type(data: List) -> torch.Tensor:
return torch.Tensor(data)

@staticmethod
def get_backend() -> FXFastBiasCorrectionAlgoBackend:
return FXFastBiasCorrectionAlgoBackend

def _get_fx_model(model: torch.nn.Module) -> torch.fx.GraphModule:
device = next(model.named_parameters())[1].device
input_shape = model.INPUT_SIZE
if input_shape is None:
input_shape = [1, 3, 32, 32]
ex_input = torch.ones(input_shape).to(device)
model.eval()
with disable_patching():
fx_model = capture_pre_autograd_graph(model, args=(ex_input,))
apply_quantization_transformations(fx_model)
return fx_model

@staticmethod
def backend_specific_model(model: torch.nn.Module, tmp_dir: str) -> torch.fx.GraphModule:
fx_model = TestTorchFXFBCAlgorithm._get_fx_model(model)
return fx_model

@staticmethod
def fn_to_type(tensor):
return torch.Tensor(tensor)

@staticmethod
def get_transform_fn():
def transform_fn(data_item):
tensor, _ = data_item
return tensor

return transform_fn

@staticmethod
def check_bias(model: torch.fx.GraphModule, ref_bias: list):
ref_bias = torch.Tensor(ref_bias)
nncf_graph = NNCFGraphFactory.create(model)
for node in nncf_graph.get_all_nodes():
if node.metatype not in OPERATORS_WITH_BIAS_METATYPES:
continue
bias_value = FXFastBiasCorrectionAlgoBackend.get_bias_value(node, nncf_graph, model)
bias_value = torch.flatten(bias_value.data).cpu()
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
return
raise ValueError("Not found node with bias")


@pytest.mark.cuda
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping for CPU-only setups")
class TestTorchFXCudaFBCAlgorithm(TestTorchFXFBCAlgorithm):
@staticmethod
def list_to_backend_type(data: List) -> torch.Tensor:
return torch.Tensor(data).cuda()

@staticmethod
def backend_specific_model(model: bool, tmp_dir: str):
fx_cuda_model = TestTorchFXFBCAlgorithm._get_fx_model(model.cuda())
return fx_cuda_model

@staticmethod
def fn_to_type(tensor):
return torch.Tensor(tensor).cuda()

@staticmethod
def check_bias(model: torch.fx.GraphModule, ref_bias: list):
TestTorchFXFBCAlgorithm.check_bias(model, ref_bias)

0 comments on commit 2a0a93e

Please sign in to comment.