diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py index 260a1060808..2e6e52b118a 100644 --- a/tests/torch/fx/test_models.py +++ b/tests/torch/fx/test_models.py @@ -56,6 +56,7 @@ def torchvision_model_builder(model_id: str, input_shape: Tuple[int,]): return ModelCase(model, model_id, input_shape) + def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]): model_config = model_id + ".yaml" # Initialize the model with random weights instead of downloading them. model = YOLO(model_config) @@ -63,18 +64,27 @@ def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]): ex_input = torch.ones(input_shape) model.eval() model(ex_input) # inferring from model to avoid anchor mutation in YOLOv8 + model.eval() + model(ex_input) # inferring from model to avoid anchor mutation in YOLOv8 return ModelCase(model, model_id, input_shape) + TEST_MODELS = ( torchvision_model_builder("resnet18", (1, 3, 224, 224)), torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)), torchvision_model_builder("vit_b_16", (1, 3, 224, 224)), torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)), ultralytics_model_builder("yolov8n", (1, 3, 224, 224)), + torchvision_model_builder("resnet18", (1, 3, 224, 224)), + torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)), + torchvision_model_builder("vit_b_16", (1, 3, 224, 224)), + torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)), + ultralytics_model_builder("yolov8n", (1, 3, 224, 224)), ) + def get_dot_filename(model_name): return model_name + ".dot" @@ -99,17 +109,21 @@ def get_ref_metatypes_from_json( model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]] ) -> Dict[NNCFNodeName, Type[OperatorMetatype]]: + model_json_name = get_json_filename(model_name) complete_path = get_full_path_to_json(model_json_name) + json_parent_dir = Path(complete_path).parent + if os.getenv("NNCF_TEST_REGEN_JSON") is not None: if not os.path.exists(json_parent_dir): os.makedirs(json_parent_dir) with safe_open(complete_path, "w") as file: json.dump(model_metatypes, file) + with safe_open(complete_path, "r") as file: return json.load(file)