Skip to content

Commit

Permalink
Merge branch 'fx_resnet_test' of https://github.com/anzr299/nncf into…
Browse files Browse the repository at this point in the history
… fx_resnet_test
  • Loading branch information
anzr299 committed Jul 14, 2024
2 parents 2741084 + 227f10b commit d94ebb7
Showing 1 changed file with 14 additions and 0 deletions.
14 changes: 14 additions & 0 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,35 @@ def torchvision_model_builder(model_id: str, input_shape: Tuple[int,]):
return ModelCase(model, model_id, input_shape)



def ultralytics_model_builder(model_id: str, input_shape: Tuple[int,]):
model_config = model_id + ".yaml" # Initialize the model with random weights instead of downloading them.
model = YOLO(model_config)
model = model.model
ex_input = torch.ones(input_shape)
model.eval()
model(ex_input) # inferring from model to avoid anchor mutation in YOLOv8
model.eval()
model(ex_input) # inferring from model to avoid anchor mutation in YOLOv8
return ModelCase(model, model_id, input_shape)



TEST_MODELS = (
torchvision_model_builder("resnet18", (1, 3, 224, 224)),
torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)),
torchvision_model_builder("vit_b_16", (1, 3, 224, 224)),
torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)),
ultralytics_model_builder("yolov8n", (1, 3, 224, 224)),
torchvision_model_builder("resnet18", (1, 3, 224, 224)),
torchvision_model_builder("mobilenet_v3_small", (1, 3, 224, 224)),
torchvision_model_builder("vit_b_16", (1, 3, 224, 224)),
torchvision_model_builder("swin_v2_s", (1, 3, 224, 224)),
ultralytics_model_builder("yolov8n", (1, 3, 224, 224)),
)



def get_dot_filename(model_name):
return model_name + ".dot"

Expand All @@ -99,17 +109,21 @@ def get_ref_metatypes_from_json(
model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]]
) -> Dict[NNCFNodeName, Type[OperatorMetatype]]:


model_json_name = get_json_filename(model_name)
complete_path = get_full_path_to_json(model_json_name)


json_parent_dir = Path(complete_path).parent


if os.getenv("NNCF_TEST_REGEN_JSON") is not None:
if not os.path.exists(json_parent_dir):
os.makedirs(json_parent_dir)
with safe_open(complete_path, "w") as file:
json.dump(model_metatypes, file)


with safe_open(complete_path, "r") as file:
return json.load(file)

Expand Down

0 comments on commit d94ebb7

Please sign in to comment.