Skip to content

Commit

Permalink
chore: fix mypy overload error by introducing intermediate variable
Browse files Browse the repository at this point in the history
  • Loading branch information
benwandrew committed Sep 21, 2023
1 parent bc8fcab commit 10de126
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions src/autora/theorist/darts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,14 +267,10 @@ def count_parameters_in_MB(model: Network) -> int:
Arguments:
model: model to count the parameters for
"""
return (
np.sum( # type: ignore
np.prod(v.size())
for name, v in model.named_parameters()
if "auxiliary" not in name
)
/ 1e6
)
for name, v in model.named_parameters():
if "auxiliary" not in name:
count_var = np.prod(v.size())
return np.sum(count_var) / 1e6


def save(model: torch.nn.Module, model_path: str, exp_folder: Optional[str] = None):
Expand Down

0 comments on commit 10de126

Please sign in to comment.