
RFC-0016: Masked reductions and normalizations #27

Open · wants to merge 1 commit into master
Conversation

@cpuhrsch cpuhrsch commented Aug 26, 2021

This RFC discusses the semantics and implementation details of masked reduction and normalization operators.

Rendered


ezyang commented Aug 27, 2021

Feels like it should be prototypable with __torch_function__ (or maybe __torch_dispatch__?)
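For illustration only (not part of the thread): a minimal sketch of how such a prototype could intercept torch.sum through __torch_function__. The MaskedTensor wrapper and its behavior here are assumptions, not an agreed-upon design.

```python
import torch

class MaskedTensor:
    """Toy pairing of a dense tensor with a boolean mask (True = valid)."""
    def __init__(self, data, mask):
        self.data, self.mask = data, mask

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.sum:
            mt = args[0]
            # Zero out invalid entries so they cannot contribute to the sum.
            filled = mt.data.masked_fill(~mt.mask, 0)
            return torch.sum(filled, *args[1:], **kwargs)
        return NotImplemented

mt = MaskedTensor(torch.tensor([1., float('inf'), 3.]),
                  torch.tensor([True, False, True]))
print(torch.sum(mt))  # tensor(4.)
```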


**Input types:** A mask is a boolean tensor. It accompanies a dense Tensor of the same shape. If an entry is True the corresponding element at the same index in the paired dense Tensor is a "valid" value. If it is False it is not. "valid" here means that this value is meant to be included in the computation and otherwise is meant to be ignored. This matches the semantics of masked_scatter, masked_select and masked_fill.

**Fully masked rows:** If a slice (e.g. row) is fully masked out there is no guarantee the corresponding return values are filled with any specific value such as the operation's identity value. However, given a sparse input with a row entirely zero and masked out the result is likely to be zero to maximize memory savings.


Trying to decide if this lack of guarantee is a problem wrt the MHA NaN gradient issue for fully masked-out rows. Say we output non-zero values in this case: do we put the responsibility on the user to avoid using those values in the loss calculation?

Edit: I realize this applies more to softmax (i.e. normalizations) but it looks like there's a similar statement below.
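For context (not from the thread itself): the failure mode being referenced is that softmax over a fully masked row, i.e. a row of all -inf scores, produces NaN, for example:

```
>>> import torch
>>> scores = torch.full((3,), float('-inf'))  # a fully masked-out attention row
>>> torch.softmax(scores, dim=0)
tensor([nan, nan, nan])
```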


@cpuhrsch cpuhrsch Aug 27, 2021


If an element is fully masked out, its value shouldn't matter. If the user wants to use the values of those elements in subsequent calculations, they can use the mask to fill them with the values they want. But for the purposes of our functions the user has told us that the values of those elements don't matter, so we can change them if we want.

If there are use cases that require those values to be untouched, we can add that as a feature later on. I imagine it's more difficult to write a kernel that does this with the same level of efficiency, at least for some functions.
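As a hedged illustration of that point (masked_sum is the proposed operator; the masked_fill one-liner below is only a stand-in for it): a caller who does care about fully masked slices can always overwrite them afterwards using the mask.

```python
import torch

x = torch.randn(2, 3)
mask = torch.tensor([[True, True, False],
                     [False, False, False]])  # second row is fully masked out

# Stand-in for the proposed masked_sum: fill invalid entries with the identity.
out = x.masked_fill(~mask, 0).sum(dim=1)

# The caller can enforce a specific value for fully masked rows themselves.
out = out.masked_fill(~mask.any(dim=1), 0.)
```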

@cpuhrsch

@ezyang - does this mean you'd prefer to see this released and packaged out of tree first before considering inclusion in the core?

ezyang commented Aug 30, 2021

> does this mean you'd prefer to see this released and packaged out of tree first before considering inclusion in the core?

Not necessarily; I'm referring to this part of the spec:

> Indeed the best way to describe the behavior is to implement it. Please note that this is only meant to describe semantics and is not an actual implementation.

It wouldn't be a long step from there to an executable specification that people can play around with.

@mruberry

Since the nan* reductions, like nansum, are existing masked reductions we should be sure the semantics are equivalent. This proposal just allows the mask to be specified directly rather than by value. Supporting more general value-based masking might be interesting in the future, too.

cc @heitorschueroff


cpuhrsch commented Aug 30, 2021

@ezyang - agreed, I'm wondering whether or when we should create an out-of-tree Python-only prototype for a MaskedTensor.

@mruberry - you can always get a value-based (let's say 4) mask by e.g. masked_sum(input, input != 4), or masked_sum(input, ~(input != input)) for nan.


pearu commented Aug 30, 2021

> masked_sum(input, ~(input != input)) for nan.

Nit: masked_sum(input, input == input) would work for the nan case as well.
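To make the equivalence with the existing nan* reductions concrete (a sketch only; masked_fill stands in for the proposed masked_sum):

```python
import torch

x = torch.tensor([1., float('nan'), 3.])
mask = x == x                                 # True exactly where x is not NaN

masked = x.masked_fill(~mask, 0).sum()        # stand-in for masked_sum(x, mask)
assert torch.equal(masked, torch.nansum(x))   # both give tensor(4.)
```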

ezyang commented Aug 30, 2021

> agreed, I'm wondering whether or when we should create an out-of-tree Python-only prototype for a MaskedTensor.

If it's just one person, probably sticking it in a colab is good enough. Multiple people wanting to work on the semantics ~> put it in GitHub somewhere.


@pearu pearu left a comment


I have two suggestions:

  • re the mask definition, which mismatches the one from numpy.ma
  • re the API of masked operations, to match the requirements of ReductionOpInfo



```
def masked_sum(input, dim, keepdim, dtype, mask):
```


Notice that https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_internal/common_methods_invocations.py#L860-L865 defines:

    An operator is a reduction operator if it reduces one or more dimensions of
    the input tensor to a single value. Reduction operators must implement the
    following signature:
    - `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor`

Following this definition, I suggest using

```
def masked_sum(input, mask=None, *, dim=None, keepdim=False, dtype=None): ...
```
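A minimal sketch of what that signature could look like in practice (illustration only; the masked_fill-based body is a stand-in, not a proposed kernel):

```python
import torch

def masked_sum(input, mask=None, *, dim=None, keepdim=False, dtype=None):
    # A missing mask means every element is valid.
    if mask is not None:
        input = input.masked_fill(~mask, 0)
    if dim is None:
        return input.sum(dtype=dtype)
    return input.sum(dim=dim, keepdim=keepdim, dtype=dtype)
```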

@cpuhrsch cpuhrsch (Author)

Yes, absolutely agreed. I skipped implementing all the overloads and default values for brevity in this RFC.


**Operator constraints and general signature - Reductions**

**Input types:** A mask is a boolean tensor. It accompanies a dense Tensor of the same shape. If an entry is True the corresponding element at the same index in the paired dense Tensor is a "valid" value. If it is False it is not. "valid" here means that this value is meant to be included in the computation and otherwise is meant to be ignored. This matches the semantics of masked_scatter, masked_select and masked_fill.

This definition of mask mismatches the definition of numpy.ma mask:

> When an element of the mask is False, the corresponding element of the associated array is valid
> and is said to be unmasked. When an element of the mask is True, the corresponding element of
> the associated array is said to be masked (invalid).

Possible solutions:

  1. Adjust the mask definition to match the one in numpy.ma
  2. Rename "mask" to "valid"
  3. Document the mismatch with numpy.ma

I would vote for 1 because the mismatch with the three non-arithmetic ops (masked_scatter, ...) would not be as bad as the mismatch with the arithmetic ops of the more widely used numpy.ma module, IMHO.
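For reference (not from the thread), the inversion being discussed, with numpy.ma treating True as invalid:

```python
import numpy as np

data = np.array([1., 2., 3.])

# numpy.ma: True marks an element as masked, i.e. invalid.
ma = np.ma.masked_array(data, mask=[False, True, False])
print(ma.sum())                   # 4.0 -- the masked element is ignored

# This RFC: True marks an element as valid, i.e. the logical inverse.
rfc_mask = np.array([True, False, True])
print(data[rfc_mask].sum())       # 4.0 -- same selection, inverted convention
```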

@pearu pearu

I will withdraw my vote for option 1 and suggest defining:

A mask is a boolean tensor of selection.


@cpuhrsch cpuhrsch Sep 8, 2021


I think the current definition of mask adheres to "A mask is a boolean tensor of selection", if I understand you correctly?

Interestingly enough, our MHA implementation uses True to indicate an element is meant to be ignored.

But like you mention, the other ops use the inverse of this definition.

@pearu pearu

> I think the current definition of mask adheres to "A mask is a boolean tensor of selection", if I understand you correctly?

Yes.

> Interestingly enough, our MHA implementation uses True to indicate an element is meant to be ignored.

I guess it boils down to not assuming that "mask" is a uniquely defined concept; the exact meaning of the mask must be documented in each function's docstring.

> But like you mention, the other ops use the inverse of this definition.

In addition, https://numpy.org/doc/stable/reference/generated/numpy.sum.html uses the where argument as a "mask of selection".
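Concretely (an illustration, not from the thread), numpy's where= already follows the "True means include" convention proposed here:

```python
import numpy as np

x = np.array([1., 2., 3.])

# where=: True selects the elements to include, matching this RFC's mask semantics.
print(np.sum(x, where=[True, False, True]))  # 4.0
```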

@cpuhrsch cpuhrsch (Author)

Based on our offline discussion with Ralf, I think it's safe to go with the mask semantics described in this RFC.


```
def masked_sum(input, dim, keepdim, dtype, mask):
    return torch.sum(input * mask, dim, keepdim, dtype=dtype)
```
@cpuhrsch cpuhrsch (Author)

NOTE: This actually requires masked_fill, because we have no guarantee that any masked values of input are valid.

```
>>> torch.tensor([float('inf')]) * torch.tensor([False])
tensor([nan])
>>> torch.tensor([float('inf')]).masked_fill(~torch.tensor([False]), 0)
tensor([0.])
```
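A corrected version of the semantics-defining sketch would therefore use masked_fill instead of multiplication (illustration only, following the note above):

```python
import torch

def masked_sum(input, dim, keepdim, dtype, mask):
    # masked_fill, not multiplication: inf/nan at masked-out positions must not leak.
    return torch.sum(input.masked_fill(~mask, 0), dim, keepdim, dtype=dtype)
```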
