-
Notifications
You must be signed in to change notification settings - Fork 239
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
init
- Loading branch information
Showing
1 changed file
with
113 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import List | ||
|
||
import pytest | ||
import torch | ||
|
||
from nncf.common.factory import NNCFGraphFactory | ||
from nncf.quantization.algorithms.fast_bias_correction.torch_fx_backend import FXFastBiasCorrectionAlgoBackend | ||
from nncf.torch.model_graph_manager import get_fused_bias_value | ||
from nncf.torch.model_graph_manager import is_node_with_fused_bias | ||
from nncf.torch.nncf_network import NNCFNetwork | ||
from tests.post_training.test_templates.test_fast_bias_correction import TemplateTestFBCAlgorithm | ||
from tests.torch.ptq.helpers import get_nncf_network | ||
from torch._export import capture_pre_autograd_graph | ||
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter | ||
|
||
class TestTorchFXFBCAlgorithm(TemplateTestFBCAlgorithm): | ||
@staticmethod | ||
def list_to_backend_type(data: List) -> torch.Tensor: | ||
return torch.Tensor(data) | ||
|
||
@staticmethod | ||
def get_backend() -> FXFastBiasCorrectionAlgoBackend: | ||
return FXFastBiasCorrectionAlgoBackend | ||
|
||
@staticmethod | ||
def backend_specific_model(model: torch.nn.Module, tmp_dir: str): | ||
|
||
input_shape = model.INPUT_SIZE | ||
if input_shape is None: | ||
input_shape = [1, 3, 32, 32] | ||
ex_input = torch.ones(input_shape) | ||
with torch.no_grad(): | ||
model = model.eval() | ||
device = next(model.named_parameters())[1].device | ||
model.to(device) | ||
print(model) | ||
exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) | ||
return exported_model | ||
|
||
@staticmethod | ||
def fn_to_type(tensor): | ||
return torch.Tensor(tensor) | ||
|
||
@staticmethod | ||
def get_transform_fn(): | ||
def transform_fn(data_item): | ||
tensor, _ = data_item | ||
return tensor | ||
|
||
return transform_fn | ||
|
||
@staticmethod | ||
def check_bias(model: NNCFNetwork, ref_bias: list): | ||
ref_bias = torch.Tensor(ref_bias) | ||
nncf_graph = NNCFGraphFactory.create(model) | ||
for node in nncf_graph.get_all_nodes(): | ||
if not is_node_with_fused_bias(node, nncf_graph): | ||
continue | ||
bias_value = get_fused_bias_value(node, model) | ||
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189 | ||
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}" | ||
return | ||
raise ValueError("Not found node with bias") | ||
|
||
|
||
@pytest.mark.cuda | ||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skipping for CPU-only setups") | ||
class TestTorchFXCudaFBCAlgorithm(TestTorchFXFBCAlgorithm): | ||
@staticmethod | ||
def list_to_backend_type(data: List) -> torch.Tensor: | ||
return torch.Tensor(data).cuda() | ||
|
||
@staticmethod | ||
def backend_specific_model(model: bool, tmp_dir: str): | ||
input_shape = model.INPUT_SIZE | ||
if input_shape is None: | ||
input_shape = [1, 3, 32, 32] | ||
ex_input = torch.ones(input_shape) | ||
with torch.no_grad(): | ||
model = model.eval() | ||
device = next(model.named_parameters())[1].device | ||
model.to(device) | ||
model(ex_input) | ||
exported_model = capture_pre_autograd_graph(model, args=(ex_input,)) | ||
return exported_model | ||
|
||
@staticmethod | ||
def fn_to_type(tensor): | ||
return torch.Tensor(tensor).cuda() | ||
|
||
@staticmethod | ||
def check_bias(model: NNCFNetwork, ref_bias: list): | ||
ref_bias = torch.Tensor(ref_bias) | ||
nncf_graph = NNCFGraphFactory.create(model) | ||
for node in nncf_graph.get_all_nodes(): | ||
if not is_node_with_fused_bias(node, nncf_graph): | ||
continue | ||
bias_value = get_fused_bias_value(node, model).cpu() | ||
# TODO(AlexanderDokuchaev): return atol=0.0001 after fix 109189 | ||
assert torch.all(torch.isclose(bias_value, ref_bias, atol=0.02)), f"{bias_value} != {ref_bias}" | ||
return | ||
raise ValueError("Not found node with bias") |