Skip to content

Commit

Permalink
[Release] fix multinomial (#672)
Browse files Browse the repository at this point in the history
* fix multinomial (#664)

* add testing utilities used
  • Loading branch information
Samantha Andow authored Apr 8, 2022
1 parent 18f600a commit e4dd624
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 2 deletions.
35 changes: 34 additions & 1 deletion functorch/csrc/BatchRulesRandomness.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,39 @@ std::tuple<Tensor,Tensor> native_dropout_batching_rule(const Tensor& tensor, dou
return std::make_tuple(output, mask);
}

// Batching rule for aten::multinomial under vmap.
//
// Unwraps `self` at the current vmap level, validates the randomness mode via
// check_randomness, and then calls the regular multinomial:
//   * Same randomness with an unbatched input: sample once and return the
//     plain (unbatched) result.
//   * Otherwise the batch dim is at the front of the input, i.e. [B0, N] or
//     [B0, M, N]. multinomial only takes 1-D / 2-D input, so the 3-D case is
//     flattened to [B0 * M, N], sampled, and the result is restored to
//     [B0, M, num_samples] before being re-wrapped with its batch dim at 0.
//
// `generator` is forwarded to multinomial unchanged.
Tensor multinomial_batching_rule(const Tensor& self, const int64_t num_samples, const bool replacement, const c10::optional<Generator> generator) {
  // Exclude the vmap-mode key so the nested multinomial call below is not
  // routed to this batching rule again.
  c10::impl::ExcludeDispatchKeyGuard guard(kVmapModeKey);
  auto maybe_layer = maybeCurrentDynamicLayer();
  const auto cur_level = maybe_layer->layerId();

  Tensor self_value;
  optional<int64_t> self_bdim;
  std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
  self_value = moveBatchDimToFront(self_value, self_bdim);

  const RandomnessType randomness = maybe_layer->randomness();
  // NOTE(review): presumably rejects "error" mode and "same" mode with a
  // batched input -- confirm against check_randomness.
  check_randomness(randomness, self_bdim.has_value());

  if (randomness == RandomnessType::Different && !self_bdim) {
    // Unbatched input but every batch element needs its own samples: give the
    // input an explicit leading batch dim of size batchSize().
    auto shape = self_value.sizes();
    VmapDimVector shapeVec(1, maybe_layer->batchSize());
    shapeVec.reserve(shape.size() + 1);
    shapeVec.insert(shapeVec.end(), shape.begin(), shape.end());
    self_value = self_value.expand(shapeVec);
  }

  // True when the input is [B0, M, N] with the batch dim at the front. Decided
  // once, BEFORE flattening, so the same answer drives the un-flatten below.
  const bool is_batched_matrix =
      self_value.dim() == 3 && (self_bdim || randomness == RandomnessType::Different);
  const int64_t batch_size = is_batched_matrix ? self_value.size(0) : 0;
  const int64_t num_rows = is_batched_matrix ? self_value.size(1) : 0;
  if (is_batched_matrix) {
    // [B0, M, N] -> [B0 * M, N], batch-major, so that the reshape of the
    // output below puts each row back under its own batch element.
    self_value = self_value.flatten(0, 1);
  }

  auto out = multinomial(self_value, num_samples, replacement, generator);
  if (randomness == RandomnessType::Same && !self_bdim) {
    // One shared draw for all batch elements: the result carries no batch dim.
    return out;
  }
  if (is_batched_matrix) {
    // [B0 * M, num_samples] -> [B0, M, num_samples]
    out = out.reshape({batch_size, num_rows, num_samples});
  }
  return makeBatched(out, 0, cur_level);
}

template <typename A, A a, typename C>
struct RandomBatchRuleHelper;

Expand Down Expand Up @@ -419,7 +452,7 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchVmapMode, m) {

UNARY_POINTWISE_RANDOM(_standard_gamma);
UNARY_POINTWISE_RANDOM(_sample_dirichlet);
UNARY_POINTWISE_RANDOM(multinomial);
m.impl("multinomial", multinomial_batching_rule);
UNARY_POINTWISE_RANDOM(poisson);
UNARY_POINTWISE_RANDOM(bernoulli);

Expand Down
96 changes: 95 additions & 1 deletion test/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -3561,8 +3561,11 @@ def test_randperm(self, device, randomness, use_generator):
assert torch.allclose(vmap_result[i], expected)

def _get_image(self, batched_input, batch_size, device):
if batched_input:
if batched_input == "first" or batched_input is True:
return torch.ones([batch_size, 3, 3, 14, 14], device=device)
if batched_input == "last":
return torch.ones([3, 3, 14, 14, batch_size], device=device)
assert batched_input == "none" or batched_input is False
return torch.ones([3, 3, 14, 14], device=device)

def _assert_all_slices_equal(self, tensor):
Expand All @@ -3576,6 +3579,31 @@ def _assert_all_slices_unique(self, tensor):
slices_equal.diagonal().zero_()
self.assertEqual(slices_equal, torch.zeros_like(slices_equal))

def _reset_random(self, generator, orig_state, use_generator, seed):
return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed)

def _assert_throws_in_error_mode(self, fn, args, in_dims):
    """Assert that vmapping ``fn`` over ``args`` with ``randomness="error"`` raises.

    The call must raise a ``RuntimeError`` whose message matches the
    "randomness error mode" text below.
    """
    expected_message = r"called random operation while in randomness error mode"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        # Build and invoke inside the context so either step may raise.
        vmapped_fn = vmap(fn, in_dims=in_dims, randomness="error")
        vmapped_fn(*args)

def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
    """Assert that vmapping ``fn`` over ``args`` with ``randomness="same"`` raises.

    The call must raise a ``RuntimeError`` whose message matches the
    "same randomness with a batched tensor input" text below.
    """
    expected_message = r"Vmap does not currently support same randomness with a batched tensor input"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        # Build and invoke inside the context so either step may raise.
        vmapped_fn = vmap(fn, in_dims=in_dims, randomness="same")
        vmapped_fn(*args)

def _in_dims(self, *batched_strings):

def get_in_dim(batched_string):
if batched_string == "first":
return 0
if batched_string == "last":
return -1
assert batched_string == "none"
return None

batched_strings = batched_strings + ("first",) # for the always batched as first dim dummy argument
return tuple(get_in_dim(batched_string) for batched_string in batched_strings)

@parametrize('randomness', ['error', 'same', 'different'])
@parametrize('batched_input', [True, False])
def test_dropout(self, device, randomness, batched_input):
Expand Down Expand Up @@ -3801,6 +3829,72 @@ def test_random_inplace_not_batched(self, device, use_generator, randomness):
op(passed, vmaped_value)
assert torch.allclose(unvmaped_value, passed)

@parametrize('use_generator', [True, False])
@parametrize('randomness', ['error', 'same', 'different'])
@parametrize('batched_call', [True, False])
@parametrize('batched_input', ["first", "last", "none"])
def test_multinomial(self, device, use_generator, randomness, batched_call, batched_input):
    """Check ``torch.multinomial`` under vmap for every randomness mode.

    * ``error``: the vmapped call must raise.
    * ``same`` with a batched input: the vmapped call must raise.
    * ``same`` with an unbatched input: every batch slice equals one
      un-vmapped call made from the same RNG state.
    * ``different``: slices are all distinct and the result equals one
      un-vmapped call on the explicitly batched (and, for ``batched_call``,
      flattened) input, reshaped back to the vmapped output shape.

    ``batched_call`` controls whether each per-example input is a matrix
    (multinomial's own batched form) or a vector.
    """
    num_samples = 10

    def flatten_input(input, batch_call, batch_location):
        # Collapse the image dims so multinomial sees a vector or matrix per
        # example, leaving the vmap batch dim (if any) where it is.
        if batch_call and batch_location != "none":
            final_size = 3  # [B0, B, N]
        elif not batch_call and batch_location == "none":
            final_size = 1  # [N]
        else:
            final_size = 2  # [B0, N] or [B, N]

        start_idx = final_size - 1
        end_idx = -1
        if batch_location == "last":
            start_idx -= 1
            end_idx -= 1  # gets to correct final size because using negative indices

        ret = input.flatten(start_idx, end_idx)
        assert ret.dim() == final_size
        return ret

    def op(input, _):
        # The second argument is the always-batched dummy; it is unused here.
        return torch.multinomial(input, num_samples, **kwargs)

    generator = torch.Generator(device=device)
    orig_state = generator.get_state()
    kwargs = {'generator': generator} if use_generator else {}

    B0 = 4
    seed = 1234567
    in_dims = self._in_dims(batched_input)

    always_batched = torch.randn(B0, device=device)
    passed = self._get_image(batched_input, B0, device)
    passed = flatten_input(passed, batched_call, batched_input)
    if randomness == 'error':
        self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
        return
    if randomness == 'same' and batched_input != "none":
        self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims)
        return

    # Run the vmapped call and the reference call from the same RNG state.
    generator = self._reset_random(generator, orig_state, use_generator, seed)
    vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)

    generator = self._reset_random(generator, orig_state, use_generator, seed)

    if randomness == "different":
        # Build the explicitly batched input with the batch dim in front.
        if batched_input == "none":
            passed = passed.expand(B0, *passed.shape)
        if batched_input == "last":
            passed = passed.movedim(-1, 0)
        orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1]
        passed = passed.flatten(0, 1) if batched_call else passed
        expected = op(passed, always_batched)
        # reshape is not in-place: keep the result so the expected tensor
        # regains its leading batch dims before comparison.
        expected = expected.reshape(*orig_passed_size, num_samples)
        self._assert_all_slices_unique(vmap_result)
        self.assertEqual(vmap_result, expected)
    else:
        expected = op(passed, always_batched)
        self._assert_all_slices_equal(vmap_result)
        for i in range(B0):
            self.assertEqual(vmap_result[i], expected)

def test_unsupported_random(self, device):
x = torch.randn(3, device=device)
y = x.abs()
Expand Down

0 comments on commit e4dd624

Please sign in to comment.