Skip to content

Commit

Permalink
Init test fast bias correction
Browse files Browse the repository at this point in the history
init
  • Loading branch information
anzr299 committed Jul 15, 2024
1 parent 6ba8a34 commit 60cc138
Showing 1 changed file with 113 additions and 0 deletions.
113 changes: 113 additions & 0 deletions tests/torch/fx/test_fast_bias_correction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List

import pytest
import torch

from nncf.common.factory import NNCFGraphFactory
from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import FXFastBiasCorrectionAlgoBackend
from nncf.torch.model_graph_manager import get_fused_bias_value
from nncf.torch.model_graph_manager import is_node_with_fused_bias
from nncf.torch.nncf_network import NNCFNetwork
from tests.post_training.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm
from tests.torch.ptq.helpers import get_nncf_network
from torch._export import capture_pre_autograd_graph
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter

class TestTorchFXFBCAlgorithm(TemplateTestFBCAlgorithm):
    """FastBiasCorrection template tests specialized for the torch.fx backend (CPU)."""

    @staticmethod
    def list_to_backend_type(data: List) -> torch.Tensor:
        """Convert a plain Python list into the backend's tensor type."""
        return torch.Tensor(data)

    @staticmethod
    def get_backend() -> FXFastBiasCorrectionAlgoBackend:
        """Return the FBC algorithm backend class under test."""
        return FXFastBiasCorrectionAlgoBackend

    @staticmethod
    def backend_specific_model(model: torch.nn.Module, tmp_dir: str):
        """Export an eager model to a torch.fx graph via pre-autograd capture.

        :param model: Eager model to export; expected to expose an ``INPUT_SIZE``
            attribute (``None`` falls back to ``[1, 3, 32, 32]``).
        :param tmp_dir: Unused; kept for template-interface compatibility.
        :return: The captured torch.fx graph module.
        """
        input_shape = model.INPUT_SIZE
        if input_shape is None:
            input_shape = [1, 3, 32, 32]
        ex_input = torch.ones(input_shape)
        with torch.no_grad():
            model = model.eval()
            # Fix: dropped the no-op device round-trip (moving the model to the
            # device it was already on) and a leftover debug print of the model.
            exported_model = capture_pre_autograd_graph(model, args=(ex_input,))
        return exported_model

    @staticmethod
    def fn_to_type(tensor):
        """Wrap arbitrary array-like data into a backend tensor."""
        return torch.Tensor(tensor)

    @staticmethod
    def get_transform_fn():
        """Return a transform that extracts the input tensor from a (data, target) item."""

        def transform_fn(data_item):
            tensor, _ = data_item
            return tensor

        return transform_fn

    @staticmethod
    def check_bias(model: NNCFNetwork, ref_bias: list):
        """Assert that the first fused-bias node in *model* carries *ref_bias*.

        :raises ValueError: If the graph contains no node with a fused bias.
        """
        ref_bias = torch.Tensor(ref_bias)
        nncf_graph = NNCFGraphFactory.create(model)
        for node in nncf_graph.get_all_nodes():
            if not is_node_with_fused_bias(node, nncf_graph):
                continue
            bias_value = get_fused_bias_value(node, model)
            # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
            assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
            return
        raise ValueError("Not found node with bias")


@pytest.mark.cuda
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping for CPU-only setups")
class TestTorchFXCudaFBCAlgorithm(TestTorchFXFBCAlgorithm):
    """CUDA variant of the torch.fx FastBiasCorrection tests."""

    @staticmethod
    def list_to_backend_type(data: List) -> torch.Tensor:
        """Convert a plain Python list into a CUDA tensor."""
        return torch.Tensor(data).cuda()

    @staticmethod
    def backend_specific_model(model: torch.nn.Module, tmp_dir: str):
        """Export an eager model to a torch.fx graph via pre-autograd capture.

        Fix: the ``model`` parameter was annotated ``bool``; it is a module,
        matching the parent-class signature. Also dropped the no-op device
        round-trip (moving the model to the device it was already on).

        NOTE(review): the model is captured on its original device here;
        confirm whether the CUDA path should move model/input to the GPU
        before export.

        :param model: Eager model to export; expected to expose an ``INPUT_SIZE``
            attribute (``None`` falls back to ``[1, 3, 32, 32]``).
        :param tmp_dir: Unused; kept for template-interface compatibility.
        :return: The captured torch.fx graph module.
        """
        input_shape = model.INPUT_SIZE
        if input_shape is None:
            input_shape = [1, 3, 32, 32]
        ex_input = torch.ones(input_shape)
        with torch.no_grad():
            model = model.eval()
            # Warm-up forward pass before capture (kept from the original).
            model(ex_input)
            exported_model = capture_pre_autograd_graph(model, args=(ex_input,))
        return exported_model

    @staticmethod
    def fn_to_type(tensor):
        """Wrap arbitrary array-like data into a CUDA tensor."""
        return torch.Tensor(tensor).cuda()

    @staticmethod
    def check_bias(model: NNCFNetwork, ref_bias: list):
        """Assert that the first fused-bias node in *model* carries *ref_bias*.

        Bias values are moved to CPU before comparison against the CPU reference.

        :raises ValueError: If the graph contains no node with a fused bias.
        """
        ref_bias = torch.Tensor(ref_bias)
        nncf_graph = NNCFGraphFactory.create(model)
        for node in nncf_graph.get_all_nodes():
            if not is_node_with_fused_bias(node, nncf_graph):
                continue
            bias_value = get_fused_bias_value(node, model).cpu()
            # TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189
            assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}"
            return
        raise ValueError("Not found node with bias")

0 comments on commit 60cc138

Please sign in to comment.