Skip to content

Commit

Permalink
Updated regnet.py, test_regnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mobley-trent committed Sep 26, 2023
1 parent 53ef379 commit ca81270
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
4 changes: 4 additions & 0 deletions ivy_models/regnet/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Optional, Callable

from collections import OrderedDict
import builtins


class RegNetSpec(BaseSpec):
Expand Down Expand Up @@ -131,11 +132,14 @@ def _forward(self, x: ivy.Array) -> ivy.Array:


def _regnet_torch_weights_mapping(old_key, new_key):
W_KEY = ["conv1/weight", "conv2/weight", "conv3/weight", "downsample/0/weight"]
new_mapping = new_key
# if "weight" in old_key:
# new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"}
# elif "bias" in old_key:
# new_mapping = {"key_chain": new_key, "pattern": "h -> 1 h 1 1"}
if builtins.any([kc in old_key for kc in W_KEY]):
new_mapping = {"key_chain": new_key, "pattern": "b c h w -> h w c b"}

return new_mapping

Expand Down
1 change: 1 addition & 0 deletions ivy_models_tests/regnet/test_regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_regnet_img_classification(device, fw):
to_ivy=True,
)
)
img = ivy.squeeze(img, axis=0)

# Create model
model.v = ivy.asarray(v)
Expand Down

0 comments on commit ca81270

Please sign in to comment.