Skip to content

Commit

Permalink
Merge branch 'fx_model_test_stable_diffusion' of https://github.com/a…
Browse files Browse the repository at this point in the history
…nzr299/nncf into fx_model_test_stable_diffusion
  • Loading branch information
anzr299 committed Jul 29, 2024
2 parents f33011d + 8b2f8be commit 00b8620
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import torch.utils.data.distributed
import torchvision.models as models
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionPipeline
from torch._export import capture_pre_autograd_graph

from nncf.common.graph.graph import NNCFNodeName
Expand Down Expand Up @@ -88,13 +89,19 @@ def get_full_path_to_json(model_json_name: str) -> str:

def get_ref_metatypes_from_json(
model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]], fx_dir: str
model_name: str, model_metatypes: Dict[NNCFNodeName, Type[OperatorMetatype]], fx_dir: str
) -> Dict[NNCFNodeName, Type[OperatorMetatype]]:

model_json_name = get_json_filename(model_name)
complete_path = get_full_path_to_json(model_json_name)

json_parent_dir = Path(complete_path).parent

pipeline_model_dir = fx_dir.split("/")
if len(pipeline_model_dir) > 1:
json_parent_dir = json_parent_dir / "/".join(pipeline_model_dir[1:])
complete_path = json_parent_dir / model_json_name

pipeline_model_dir = fx_dir.split("/")
if len(pipeline_model_dir) > 1:
json_parent_dir = json_parent_dir / "/".join(pipeline_model_dir[1:])
Expand All @@ -117,6 +124,7 @@ def test_models(test_case: ModelCase):
model = test_case.model_builder()
model.to(device)
model.eval()
model.eval()

with torch.no_grad():
ex_input = torch.ones(test_case.input_shape)
Expand Down

0 comments on commit 00b8620

Please sign in to comment.