Skip to content

Commit

Permalink
add testing utilities used
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Andow committed Apr 7, 2022
1 parent 29ff9c8 commit f06f70d
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3550,8 +3550,11 @@ def test_randperm(self, device, randomness, use_generator):
assert torch.allclose(vmap_result[i], expected)

def _get_image(self, batched_input, batch_size, device):
if batched_input:
if batched_input == "first" or batched_input is True:
return torch.ones([batch_size, 3, 3, 14, 14], device=device)
if batched_input == "last":
return torch.ones([3, 3, 14, 14, batch_size], device=device)
assert batched_input == "none" or batched_input is False
return torch.ones([3, 3, 14, 14], device=device)

def _assert_all_slices_equal(self, tensor):
Expand All @@ -3565,6 +3568,31 @@ def _assert_all_slices_unique(self, tensor):
slices_equal.diagonal().zero_()
self.assertEqual(slices_equal, torch.zeros_like(slices_equal))

def _reset_random(self, generator, orig_state, use_generator, seed):
return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed)

def _assert_throws_in_error_mode(self, fn, args, in_dims):
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
vmap(fn, in_dims=in_dims, randomness="error")(*args)

def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
with self.assertRaisesRegex(RuntimeError,
r"Vmap does not currently support same randomness with a batched tensor input"):
vmap(fn, in_dims=in_dims, randomness="same")(*args)

def _in_dims(self, *batched_strings):

def get_in_dim(batched_string):
if batched_string == "first":
return 0
if batched_string == "last":
return -1
assert batched_string == "none"
return None

batched_strings = batched_strings + ("first",) # for the always batched as first dim dummy argument
return tuple(get_in_dim(batched_string) for batched_string in batched_strings)

@parametrize('randomness', ['error', 'same', 'different'])
@parametrize('batched_input', [True, False])
def test_dropout(self, device, randomness, batched_input):
Expand Down

0 comments on commit f06f70d

Please sign in to comment.