diff --git a/test/test_vmap.py b/test/test_vmap.py
index 3870783f2a..1b12f9b22c 100644
--- a/test/test_vmap.py
+++ b/test/test_vmap.py
@@ -3565,6 +3565,31 @@ def _assert_all_slices_unique(self, tensor):
         slices_equal.diagonal().zero_()
         self.assertEqual(slices_equal, torch.zeros_like(slices_equal))
 
+    def _reset_random(self, generator, orig_state, use_generator, seed):
+        return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed)
+
+    def _assert_throws_in_error_mode(self, fn, args, in_dims):
+        with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
+            vmap(fn, in_dims=in_dims, randomness="error")(*args)
+
+    def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
+        with self.assertRaisesRegex(RuntimeError,
+                                    r"Vmap does not currently support same randomness with a batched tensor input"):
+            vmap(fn, in_dims=in_dims, randomness="same")(*args)
+
+    def _in_dims(self, *batched_strings):
+
+        def get_in_dim(batched_string):
+            if batched_string == "first":
+                return 0
+            if batched_string == "last":
+                return -1
+            assert batched_string == "none"
+            return None
+
+        batched_strings = batched_strings + ("first",)  # for the always batched as first dim dummy argument
+        return tuple(get_in_dim(batched_string) for batched_string in batched_strings)
+
     @parametrize('randomness', ['error', 'same', 'different'])
     @parametrize('batched_input', [True, False])
     def test_dropout(self, device, randomness, batched_input):
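
For context, a minimal standalone sketch of how these helpers fit together. Everything below is illustrative, not part of the patch: the `op` function, tensor shapes, and the bare `vmap` import are assumptions (the patch's test file has its own imports). It shows the two failure modes the assert helpers check, and the in_dims tuple that `_in_dims("first")` would produce for one real input plus the trailing always-batched dummy argument.

import torch
from functorch import vmap  # functorch-era build; newer: from torch.func import vmap

B0 = 3

def op(t, _dummy):
    # A random operation; what vmap does with it depends on the randomness flag.
    return torch.rand_like(t)

dummy = torch.randn(B0)   # the always-batched dummy argument (dim 0)
x = torch.randn(B0, 4)    # a batched "real" input
in_dims = (0, 0)          # what _in_dims("first") would return

# randomness="error": any random op inside vmap raises.
# This is the condition _assert_throws_in_error_mode asserts.
try:
    vmap(op, in_dims=in_dims, randomness="error")(x, dummy)
except RuntimeError as e:
    assert "randomness error mode" in str(e)
else:
    raise AssertionError("expected a RuntimeError")

# randomness="same" combined with a batched tensor input is also rejected.
# This is the condition _assert_throws_in_same_mode_batched asserts.
try:
    vmap(op, in_dims=in_dims, randomness="same")(x, dummy)
except RuntimeError as e:
    assert "same randomness with a batched tensor input" in str(e)
else:
    raise AssertionError("expected a RuntimeError")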