Skip to content

Commit

Permalink
Merge pull request #1 from daniil-lyakhov/dl/fx_resnet_test
Browse files Browse the repository at this point in the history
[TorchFX][Test] Ultralytics dependency is removed
  • Loading branch information
anzr299 authored Jul 29, 2024
2 parents 20380eb + f1838a1 commit 909fb13
Show file tree
Hide file tree
Showing 9 changed files with 2,171 additions and 226 deletions.

Large diffs are not rendered by default.

366 changes: 183 additions & 183 deletions tests/torch/data/reference_graphs/fx/yolov8n.dot

Large diffs are not rendered by default.

75 changes: 34 additions & 41 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
import json
import os
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, Tuple, Type
from typing import Callable, Dict, Tuple, Type

import openvino.torch # noqa
import pytest
Expand All @@ -26,54 +27,47 @@
import torch.utils.data.distributed
import torchvision.models as models
from torch._export import capture_pre_autograd_graph
from ultralytics.models.yolo import YOLO

from nncf.common.graph.graph import NNCFNodeName
from nncf.common.graph.operator_metatypes import OperatorMetatype
from nncf.common.utils.os import safe_open
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from nncf.torch.graph.graph import PTNNCFGraph
from tests.shared.paths import TEST_ROOT
from tests.torch.test_compressed_graph import check_graph
from tests.torch.test_models.yolov8.model import YoloV8Model


@pytest.fixture(name="fx_dir")
def fx_dir_fixture(request):
fx_dir_name = "fx"
return fx_dir_name
FX_DIR_NAME = "fx"


@dataclass
class ModelCase:
model: torch.nn.Module
model_builder: Callable[[], torch.nn.Module]
model_id: str
input_shape: Tuple[int]


def torchvision_model_builder(model_id: str, input_shape: Tuple[int,]):
model = getattr(models, model_id)(weights=None)
return ModelCase(model, model_id, input_shape)
def torchvision_model_case(model_id: str, input_shape: Tuple[int,]):
model = getattr(models, model_id)
return ModelCase(partial(model, weights=None), model_id, input_shape)


def yolo_v8_case(model_id, input_shape):
def get_model() -> torch.nn.Module:
model = YoloV8Model().eval()
# Warmup model
model(torch.empty(input_shape))
return model

def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]):
model_config = model_id + ".yaml" # Initialize the model with random weights instead of downloading them.
model = YOLO(model_config)
model = model.model
ex_input = torch.ones(input_shape)
model.eval()
model(ex_input) # inferring from model to avoid anchor mutation in YOLOv8
model.eval()
model(ex_input) # inferring from model to avoid anchor mutation in YOLOv8
return ModelCase(model, model_id, input_shape)
return ModelCase(get_model, model_id, input_shape)


TEST_MODELS = (
torchvision_model_builder("resnet18", (1, 3, 224, 224)),
torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)),
torchvision_model_builder("vit_b_16", (1, 3, 224, 224)),
torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)),
ultralytics_model_builder("yolov8n", (1, 3, 224, 224)),
torchvision_model_case("resnet18", (1, 3, 224, 224)),
torchvision_model_case("mobilenet_v3_small", (1, 3, 224, 224)),
torchvision_model_case("vit_b_16", (1, 3, 224, 224)),
torchvision_model_case("swin_v2_s", (1, 3, 224, 224)),
yolo_v8_case("yolov8n", (1, 3, 224, 224)),
)


Expand Down Expand Up @@ -110,26 +104,25 @@ def get_ref_metatypes_from_json(
return json.load(file)


def compare_nncf_graph_model(model: PTNNCFGraph, model_name: str, path_to_dot: str):
dot_filename = get_dot_filename(model_name)
check_graph(model, dot_filename, path_to_dot)


@pytest.mark.parametrize("test_case", TEST_MODELS)
def test_models(test_case: ModelCase, fx_dir):
@pytest.mark.parametrize("test_case", TEST_MODELS, ids=[m.model_id for m in TEST_MODELS])
def test_models(test_case: ModelCase):
with disable_patching():
device = torch.device("cpu")
model_name = test_case.model_id
model = test_case.model
model = test_case.model_builder()
model.to(device)

with torch.no_grad():
ex_input = torch.ones(test_case.input_shape)
path_to_dot = fx_dir
model.eval()
exported_model = capture_pre_autograd_graph(model, args=(ex_input,))
nncf_graph = GraphConverter.create_nncf_graph(exported_model)
compare_nncf_graph_model(nncf_graph, model_name, path_to_dot)
model_metatypes = {n.node_name: n.metatype.name for n in nncf_graph.get_all_nodes()}
ref_metatypes = get_ref_metatypes_from_json(model_name, model_metatypes)
assert model_metatypes == ref_metatypes
nncf_graph = GraphConverter.create_nncf_graph(exported_model)

# Check NNCFGrpah
dot_filename = get_dot_filename(model_name)
check_graph(nncf_graph, dot_filename, FX_DIR_NAME)

# Check metatypes
model_metatypes = {n.node_name: n.metatype.name for n in nncf_graph.get_all_nodes()}
ref_metatypes = get_ref_metatypes_from_json(model_name, model_metatypes)
assert model_metatypes == ref_metatypes
1 change: 0 additions & 1 deletion tests/torch/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,3 @@ timm==0.9.2
# Required for torch/fx tests
torchvision
fastdownload==0.0.7
ultralytics==8.2.56 # TODO(dlyakhov) move ultralytics requirements to the nightly test
Loading

0 comments on commit 909fb13

Please sign in to comment.