diff --git a/ivy_models/regnet/regnet.py b/ivy_models/regnet/regnet.py index 3ed058d..cc32d10 100644 --- a/ivy_models/regnet/regnet.py +++ b/ivy_models/regnet/regnet.py @@ -5,6 +5,7 @@ from typing import Optional, Callable from collections import OrderedDict +import builtins class RegNetSpec(BaseSpec): @@ -131,11 +132,14 @@ def _forward(self, x: ivy.Array) -> ivy.Array: def _regnet_torch_weights_mapping(old_key, new_key): + W_KEY = ["conv1/weight", "conv2/weight", "conv3/weight", "downsample/0/weight"] new_mapping = new_key # if "weight" in old_key: # new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"} # elif "bias" in old_key: # new_mapping = {"key_chain": new_key, "pattern": "h -> 1 h 1 1"} + if builtins.any([kc in old_key for kc in W_KEY]): + new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"} return new_mapping diff --git a/ivy_models_tests/regnet/test_regnet.py b/ivy_models_tests/regnet/test_regnet.py index 3304ec5..34e8a19 100644 --- a/ivy_models_tests/regnet/test_regnet.py +++ b/ivy_models_tests/regnet/test_regnet.py @@ -31,6 +31,7 @@ def test_regnet_img_classification(device, fw): to_ivy=True, ) ) + img = ivy.squeeze(img, axis=0) # Create model model.v = ivy.asarray(v)