diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py
index 382c1a8e579..fd52c8abbe3 100644
--- a/tests/torch/fx/test_models.py
+++ b/tests/torch/fx/test_models.py
@@ -14,7 +14,7 @@ import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Dict, Tuple, Type
+from typing import Any, Dict, Tuple, Type
 
 import openvino.torch  # noqa
 import pytest
@@ -25,6 +25,7 @@ import torch.utils.data
 import torch.utils.data.distributed
 import torchvision.models as models
+from diffusers import StableDiffusionPipeline
 from torch._export import capture_pre_autograd_graph
 from ultralytics.models.yolo import YOLO
 
@@ -48,24 +49,62 @@ def fx_dir_fixture(request):
 class ModelCase:
     model: torch.nn.Module
     model_id: str
-    input_shape: Tuple[int]
+    ex_input: Any
 
 
-def torchvision_model_builder(model_id: str, input_shape: Tuple[int,]):
+def stable_diffusion_model_builder(model_id: str, prompt: str, input_shape: Tuple[int, ...]):
+    device = "cpu"
+    pipe = StableDiffusionPipeline.from_pretrained(model_id).to(device)
+
+    text_encoder = pipe.text_encoder
+    text_encoder.eval()
+
+    unet = pipe.unet
+    unet.eval()
+
+    vae = pipe.vae
+    vae.eval()
+
+    tokenizer = pipe.tokenizer
+
+    ex_input = torch.ones(input_shape)
+    text_inputs = tokenizer(prompt, return_tensors="pt")
+    text_input_ids = text_inputs.input_ids.to(device)
+    text_embeddings = text_encoder(text_input_ids)[0]
+
+    batch_size = text_input_ids.shape[0]
+    latents = torch.ones(
+        (batch_size, unet.config.in_channels, unet.config.sample_size, unet.config.sample_size), device=device
+    )
+    dummy_t = torch.tensor([0.0], device=device)
+    unet_input = (
+        latents,
+        dummy_t,
+        text_embeddings,
+    )
+
+    return (
+        model_id,
+        ModelCase(unet, "SD_UNET", unet_input),
+        ModelCase(vae, "SD_VAE", ex_input),
+        ModelCase(text_encoder, "SD_Text_Encoder", text_input_ids),
+    )
+
+
+def torchvision_model_builder(model_id: str, input_shape: Tuple[int, ...]):
     model = getattr(models, model_id)(weights=None)
-    return ModelCase(model, model_id, input_shape)
+    ex_input = torch.ones(input_shape)
+    return ModelCase(model, model_id, ex_input)
 
 
-def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]):
+def ultralytics_model_builder(model_id: str, input_shape: Tuple[int, ...]):
     model_config = model_id + ".yaml"  # Initialize the model with random weights instead of downloading them.
     model = YOLO(model_config)
     model = model.model
     ex_input = torch.ones(input_shape)
     model.eval()
-    model(ex_input)  # inferring from model to avoid anchor mutation in YOLOv8
-    model.eval()
-    model(ex_input)  # inferring from model to avoid anchor mutation in YOLOv8
-    return ModelCase(model, model_id, input_shape)
+    _ = model(ex_input)  # inferring from model to avoid anchor mutation in YOLOv8
+    return ModelCase(model, model_id, ex_input)
 
 
 TEST_MODELS = (
@@ -74,6 +114,7 @@ def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]):
     torchvision_model_builder("vit_b_16", (1, 3, 224, 224)),
     torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)),
     ultralytics_model_builder("yolov8n", (1, 3, 224, 224)),
+    stable_diffusion_model_builder("stabilityai/stable-diffusion-2-1-base", "Hello world", (1, 3, 224, 224)),
 )
 
 
@@ -92,7 +133,7 @@ def get_full_path_to_json(model_json_name: str) -> str:
 
 
 def get_ref_metatypes_from_json(
-    model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]]
+    model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]], fx_dir: str
 ) -> Dict[NNCFNodeName, Type[OperatorMetatype]]:
     model_json_name = get_json_filename(model_name)
 
@@ -100,6 +141,11 @@ def get_ref_metatypes_from_json(
 
     json_parent_dir = Path(complete_path).parent
 
+    pipeline_model_dir = fx_dir.split("/")
+    if len(pipeline_model_dir) > 1:
+        json_parent_dir = json_parent_dir / "/".join(pipeline_model_dir[1:])
+        complete_path = json_parent_dir / model_json_name
+
     if os.getenv("NNCF_TEST_REGEN_JSON") is not None:
         if not os.path.exists(json_parent_dir):
             os.makedirs(json_parent_dir)
@@ -115,21 +161,31 @@ def compare_nncf_graph_model(model: PTNNCFGraph, model_name: str, path_to_dot: s
     check_graph(model, dot_filename, path_to_dot)
 
 
-@pytest.mark.parametrize("test_case", TEST_MODELS)
-def test_models(test_case: ModelCase, fx_dir):
+def run_test(test_case: ModelCase, fx_dir):
+    device = torch.device("cpu")
     with disable_patching():
-        device = torch.device("cpu")
         model_name = test_case.model_id
         model = test_case.model
        model.to(device)
+        model.eval()
         with torch.no_grad():
-            ex_input = torch.ones(test_case.input_shape)
+            ex_input = test_case.ex_input
             path_to_dot = fx_dir
-            model.eval()
-            exported_model = capture_pre_autograd_graph(model, args=(ex_input,))
+            if not isinstance(ex_input, tuple):
+                ex_input = (ex_input,)
+            exported_model = capture_pre_autograd_graph(model, args=ex_input)
             nncf_graph = GraphConverter.create_nncf_graph(exported_model)
             compare_nncf_graph_model(nncf_graph, model_name, path_to_dot)
             model_metatypes = {n.node_name: n.metatype.name for n in nncf_graph.get_all_nodes()}
-            ref_metatypes = get_ref_metatypes_from_json(model_name, model_metatypes)
+            ref_metatypes = get_ref_metatypes_from_json(model_name, model_metatypes, fx_dir)
             assert model_metatypes == ref_metatypes
+
+
+@pytest.mark.parametrize("test_case", TEST_MODELS)
+def test_models(test_case: ModelCase, fx_dir):
+    if isinstance(test_case, tuple):
+        for model in test_case[1:]:
+            run_test(model, fx_dir + "/" + test_case[0])
+    else:
+        run_test(test_case, fx_dir)
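
Note for reviewers: the least obvious part of this patch is how per-pipeline reference files get nested. In test_models(), tuple cases append the pipeline id to fx_dir (fx_dir + "/" + test_case[0]), and get_ref_metatypes_from_json() then splits fx_dir on "/" and re-nests everything past the first segment under the reference JSON directory. Below is a minimal, self-contained sketch of that path logic; the reference root and the "fx" fixture value are illustrative assumptions, not taken from this diff, and only the split/join logic mirrors the patch.

    # Standalone sketch of the fx_dir nesting added in get_ref_metatypes_from_json().
    from pathlib import Path

    json_parent_dir = Path("tests/torch/data/reference_metatypes/fx")  # hypothetical reference root
    fx_dir = "fx" + "/" + "stabilityai/stable-diffusion-2-1-base"  # as built in test_models() for tuple cases

    pipeline_model_dir = fx_dir.split("/")
    if len(pipeline_model_dir) > 1:
        # Everything after the leading fixture segment becomes a subdirectory,
        # so each SD component's JSON lands under the pipeline's model id.
        json_parent_dir = json_parent_dir / "/".join(pipeline_model_dir[1:])

    model_json_name = "SD_UNET.json"  # illustrative; the real name comes from get_json_filename()
    print(json_parent_dir / model_json_name)
    # -> tests/torch/data/reference_metatypes/fx/stabilityai/stable-diffusion-2-1-base/SD_UNET.json

One consequence worth noting: because the pipeline id itself contains a slash ("stabilityai/stable-diffusion-2-1-base"), the references end up two directories deep; regeneration with NNCF_TEST_REGEN_JSON set creates those directories through the existing os.makedirs(json_parent_dir) branch.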