Skip to content

Commit

Permalink
Fix flakiness on StochasticDepth test (#4758)
Browse files Browse the repository at this point in the history
* Fix flakiness on the TestStochasticDepth test.

* Fix minor bug when p=1.0

* Remove device and dtype setting.
  • Loading branch information
datumbox authored Oct 27, 2021
1 parent 5ea2348 commit bbfda42
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
23 changes: 20 additions & 3 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,13 +1149,15 @@ def _create_masks(image, masks):


class TestStochasticDepth:
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", [0.2, 0.5, 0.8])
@pytest.mark.parametrize("mode", ["batch", "row"])
def test_stochastic_depth(self, mode, p):
def test_stochastic_depth_random(self, seed, mode, p):
torch.manual_seed(seed)
stats = pytest.importorskip("scipy.stats")
batch_size = 5
x = torch.ones(size=(batch_size, 3, 4, 4))
layer = ops.StochasticDepth(p=p, mode=mode).to(device=x.device, dtype=x.dtype)
layer = ops.StochasticDepth(p=p, mode=mode)
layer.__repr__()

trials = 250
Expand All @@ -1173,7 +1175,22 @@ def test_stochastic_depth(self, mode, p):
num_samples += batch_size

p_value = stats.binom_test(counts, num_samples, p=p)
assert p_value > 0.0001
assert p_value > 0.01

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("p", (0, 1))
@pytest.mark.parametrize("mode", ["batch", "row"])
def test_stochastic_depth(self, seed, mode, p):
torch.manual_seed(seed)
batch_size = 5
x = torch.ones(size=(batch_size, 3, 4, 4))
layer = ops.StochasticDepth(p=p, mode=mode)

out = layer(x)
if p == 0:
assert out.equal(x)
elif p == 1:
assert out.equal(torch.zeros_like(x))


class TestUtils:
Expand Down
4 changes: 3 additions & 1 deletion torchvision/ops/stochastic_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,9 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
else:
size = [1] * input.ndim
noise = torch.empty(size, dtype=input.dtype, device=input.device)
noise = noise.bernoulli_(survival_rate).div_(survival_rate)
noise = noise.bernoulli_(survival_rate)
if survival_rate > 0.0:
noise.div_(survival_rate)
return input * noise


Expand Down

0 comments on commit bbfda42

Please sign in to comment.