Skip to content

Commit

Permalink
Fix models and pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
anzr299 committed Jul 14, 2024
1 parent d94ebb7 commit c424725
Showing 1 changed file with 0 additions and 12 deletions.
12 changes: 0 additions & 12 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def torchvision_model_builder(model_id: str, input_shape: Tuple[int,]):
return ModelCase(model, model_id, input_shape)



def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]):
model_config = model_id + ".yaml" # Initialize the model with random weights instead of downloading them.
model = YOLO(model_config)
Expand All @@ -69,22 +68,15 @@ def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]):
return ModelCase(model, model_id, input_shape)



TEST_MODELS = (
torchvision_model_builder("resnet18", (1, 3, 224, 224)),
torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)),
torchvision_model_builder("vit_b_16", (1, 3, 224, 224)),
torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)),
ultralytics_model_builder("yolov8n", (1, 3, 224, 224)),
torchvision_model_builder("resnet18", (1, 3, 224, 224)),
torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)),
torchvision_model_builder("vit_b_16", (1, 3, 224, 224)),
torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)),
ultralytics_model_builder("yolov8n", (1, 3, 224, 224)),
)



def get_dot_filename(model_name):
return model_name + ".dot"

Expand All @@ -109,21 +101,17 @@ def get_ref_metatypes_from_json(
model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]]
) -> Dict[NNCFNodeName, Type[OperatorMetatype]]:


model_json_name = get_json_filename(model_name)
complete_path = get_full_path_to_json(model_json_name)


json_parent_dir = Path(complete_path).parent


if os.getenv("NNCF_TEST_REGEN_JSON") is not None:
if not os.path.exists(json_parent_dir):
os.makedirs(json_parent_dir)
with safe_open(complete_path, "w") as file:
json.dump(model_metatypes, file)


with safe_open(complete_path, "r") as file:
return json.load(file)

Expand Down

0 comments on commit c424725

Please sign in to comment.