Skip to content

Commit

Permalink
Updated layers.py, test_regnet.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mobley-trent committed Aug 24, 2023
1 parent 4c4b630 commit adf548e
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
15 changes: 8 additions & 7 deletions ivy_models/regnet/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,11 +417,11 @@ def from_init_params(
# Compute the block widths. Each stage has one unique block width
widths_cont = ivy.arange(depth) * w_a + w_0
block_capacity = ivy.round(ivy.log(widths_cont / w_0) / math.log(w_m))
block_widths = (
(ivy.round(ivy.divide(w_0 * ivy.pow(w_m, block_capacity), QUANT)) * QUANT)
.int()
.tolist()
)
_block_widths = (
ivy.round(ivy.divide(w_0 * ivy.pow(w_m, block_capacity), QUANT)) * QUANT
) # noqa: E501
_block_widths = _block_widths.astype(ivy.int64)
block_widths = (_block_widths).tolist()
num_stages = len(set(block_widths))

# Convert to per stage parameters
Expand All @@ -434,9 +434,10 @@ def from_init_params(
splits = [w != wp or r != rp for w, wp, r, rp in split_helper]

stage_widths = [w for w, t in zip(block_widths, splits[:-1]) if t]
stage_depths = (
ivy.diff(ivy.array([d for d, t in enumerate(splits) if t])).int().tolist()
stage_depths = ivy.diff(
ivy.array([d for d, t in enumerate(splits) if t]),
)
stage_depths = stage_depths.astype(ivy.int64).tolist()

strides = [STRIDE] * num_stages
bottleneck_multipliers = [bottleneck_multiplier] * num_stages
Expand Down
1 change: 0 additions & 1 deletion ivy_models_tests/regnet/test_regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import random
import os

ivy.set_backend("torch")

VARIANTS = {
"regnet_y_1_6gf": regnet_y_1_6gf,
Expand Down

0 comments on commit adf548e

Please sign in to comment.