Skip to content

Commit

Permalink
add testing utilities used
Browse files Browse the repository at this point in the history
  • Loading branch information
Sam Andow committed Apr 7, 2022
1 parent 29ff9c8 commit 97e884c
Showing 1 changed file with 25 additions and 0 deletions.
25 changes: 25 additions & 0 deletions test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3565,6 +3565,31 @@ def _assert_all_slices_unique(self, tensor):
slices_equal.diagonal().zero_()
self.assertEqual(slices_equal, torch.zeros_like(slices_equal))

def _reset_random(self, generator, orig_state, use_generator, seed):
return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed)

def _assert_throws_in_error_mode(self, fn, args, in_dims):
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
vmap(fn, in_dims=in_dims, randomness="error")(*args)

def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
with self.assertRaisesRegex(RuntimeError,
r"Vmap does not currently support same randomness with a batched tensor input"):
vmap(fn, in_dims=in_dims, randomness="same")(*args)

def _in_dims(self, *batched_strings):

def get_in_dim(batched_string):
if batched_string == "first":
return 0
if batched_string == "last":
return -1
assert batched_string == "none"
return None

batched_strings = batched_strings + ("first",) # for the always batched as first dim dummy argument
return tuple(get_in_dim(batched_string) for batched_string in batched_strings)

@parametrize('randomness', ['error', 'same', 'different'])
@parametrize('batched_input', [True, False])
def test_dropout(self, device, randomness, batched_input):
Expand Down

0 comments on commit 97e884c

Please sign in to comment.