RFC-0016: Masked reductions and normalizations #27
base: master
Conversation
This RFC discusses semantics and implementation details of masked reduction and normalization operators.
Feels like it should be prototypable with
**Input types:** A mask is a boolean tensor that accompanies a dense Tensor of the same shape. If an entry is True, the corresponding element at the same index in the paired dense Tensor is a "valid" value; if it is False, it is not. "Valid" here means the value is meant to be included in the computation; otherwise it is meant to be ignored. This matches the semantics of masked_scatter, masked_select and masked_fill.
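A small sketch of this convention in existing PyTorch: masked_select keeps exactly the entries where the mask is True, i.e. True means "valid".

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0])
mask = torch.tensor([True, False, True])

# True marks "valid" entries: masked_select keeps exactly those.
print(torch.masked_select(x, mask))  # tensor([1., 3.])
```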
**Fully masked rows:** If a slice (e.g. a row) is fully masked out, there is no guarantee that the corresponding return values are filled with any specific value, such as the operation's identity. However, given a sparse input with a row that is entirely zero and masked out, the result is likely to be zero to maximize memory savings.
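For illustration, one possible masked_fill-based implementation (a sketch, not anything mandated by this RFC) happens to return the identity, 0, for a fully masked row; the point of the paragraph above is that this is incidental, not guaranteed.

```python
import torch

x = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[True, True], [False, False]])  # second row fully masked

# This particular sketch yields 0 for the fully masked row, but the RFC
# deliberately does not guarantee any specific value there.
out = torch.sum(x.masked_fill(~mask, 0), dim=1)
print(out)  # tensor([3., 0.])
```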
Trying to decide if this lack of guarantee is a problem wrt the MHA NaN gradient issue for fully-masked out rows. Say we output non-zero values for this case, do we put the responsibility onto the user to avoid using those values in the loss calculation?
Edit: I realize this applies more to softmax (i.e. normalizations) but it looks like there's a similar statement below.
If an element is fully masked out, its value shouldn't matter. If the user wants to use the values of those elements in subsequent calculations, they can use the mask to fill them with the values they want. But for the purposes of our functions, the user has told us that the values of those elements don't matter, so we can change them if we want.
If there are use cases that require those values to be untouched, we can add that as a feature later on. I imagine it's more difficult to write a kernel that does this with the same level of efficiency, at least for some functions.
@ezyang - does this mean you'd prefer to see this released and packaged out of tree first before considering inclusion in the core?
Not necessarily; I'm referring to this part of the spec:
wouldn't be a long step to have an executable specification that people can play around with.
Since the nan* reductions, like nansum, are existing masked reductions, we should be sure the semantics are equivalent. This proposal just allows the mask to be specified directly rather than by value. Supporting more general value-based masking might be interesting in the future, too.
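To make the equivalence concrete: nansum is effectively a value-based masked sum, and the same result can be obtained with an explicit mask (the masked_fill-based rewrite below is an illustrative sketch, not the proposed implementation).

```python
import torch

x = torch.tensor([1.0, float('nan'), 3.0])

# nansum is a value-based masked reduction: NaNs are treated as invalid.
print(torch.nansum(x))  # tensor(4.)

# The equivalent explicit mask:
mask = ~torch.isnan(x)
print(torch.sum(x.masked_fill(~mask, 0)))  # tensor(4.)
```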
Nit:
If it's just one person, probably sticking it in a colab is good enough. Multiple people wanting to work on the semantics ~> put it in GitHub somewhere.
I have two suggestions:
- re the mask definition, which mismatches the one from numpy.ma
- re the API of masked operations, to match the requirements of ReductionOpInfo
```
def masked_sum(input, dim, keepdim, dtype, mask):
```
Notice that https://github.com/pytorch/pytorch/blob/72274e2a2fd55019ec860e1743dbdc5b0c5a5624/torch/testing/_internal/common_methods_invocations.py#L860-L865 defines:
An operator is a reduction operator if it reduces one or more dimensions of
the input tensor to a single value. Reduction operators must implement the
following signature:
- `op(input, *args, *, dim=None, keepdim=False, **kwargs) -> Tensor`
Following this definition, I suggest using
```python
def masked_sum(input, mask=None, *, dim=None, keepdim=False, dtype=None): ...
```
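A minimal sketch of how this suggested signature could be fleshed out (the masked_fill handling and the treatment of `dim=None` are my assumptions for illustration, not part of the comment above):

```python
import torch

def masked_sum(input, mask=None, *, dim=None, keepdim=False, dtype=None):
    # Sketch only: zero out invalid entries, then reduce as usual.
    if mask is not None:
        input = input.masked_fill(~mask, 0)
    if dim is None:
        return input.sum(dtype=dtype)
    return input.sum(dim, keepdim=keepdim, dtype=dtype)

print(masked_sum(torch.tensor([1.0, 2.0, 3.0]),
                 torch.tensor([True, False, True])))  # tensor(4.)
```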
Yes, absolutely agreed. I skipped implementing all the overloads and default values for brevity in this RFC.
**Operator constraints and general signature - Reductions**
**Input types:** A mask is a boolean tensor that accompanies a dense Tensor of the same shape. If an entry is True, the corresponding element at the same index in the paired dense Tensor is a "valid" value; if it is False, it is not. "Valid" here means the value is meant to be included in the computation; otherwise it is meant to be ignored. This matches the semantics of masked_scatter, masked_select and masked_fill.
This definition of mask mismatches the definition of numpy.ma mask:
When an element of the mask is False, the corresponding element of the associated array is valid
and is said to be unmasked. When an element of the mask is True, the corresponding element of
the associated array is said to be masked (invalid).
Possible solutions:
- Adjust the mask definition to match the one in numpy.ma
- Rename "mask" to "valid"
- Document the mismatch with numpy.ma
I would vote for 1, because the mismatch with the three non-arithmetic ops masked_scatter, ... would not be as bad as the mismatch with the arithmetic ops of the more widely used numpy.ma module, IMHO.
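For reference, the numpy.ma convention is the inverse of this RFC's: True marks an *invalid* (masked) entry, which is excluded from reductions.

```python
import numpy.ma as ma

# In numpy.ma, True marks *invalid* (masked) entries -- the opposite of
# the convention proposed in this RFC.
a = ma.masked_array([1.0, 2.0, 3.0], mask=[False, True, False])
print(a.sum())  # 4.0 -- the masked 2.0 is excluded
```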
I will withdraw my vote for option 1 and suggest defining:
A mask is a boolean tensor of selection.
I think the current definition of mask adheres to "A mask is a boolean tensor of selection.", if I understand you correctly?
Interestingly enough our MHA implementation uses True to indicate an element is meant to be ignored.
But like you mention, the other ops use the inverse of this definition.
> I think the current definition of mask adheres to "A mask is a boolean tensor of selection.", if I understand you correctly?
yes
> Interestingly enough our MHA implementation uses True to indicate an element is meant to be ignored.
I guess it boils down to not assuming that "mask" is a uniquely defined concept; the exact meaning of the mask must be documented in each function's docstring.
> But like you mention, the other ops use the inverse of this definition.
In addition, https://numpy.org/doc/stable/reference/generated/numpy.sum.html uses the where argument as a "mask of selection".
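That is, numpy.sum's where parameter follows the same True-means-included convention as this RFC:

```python
import numpy as np

x = np.array([1.0, 2.0, 3.0])
# `where=True` entries are included, matching the RFC's mask convention.
print(np.sum(x, where=np.array([True, False, True])))  # 4.0
```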
Based on our offline discussion with Ralf I think it's safe to go for mask semantics that adhere to those described in this RFC.
```
def masked_sum(input, dim, keepdim, dtype, mask):
    return torch.sum(input * mask, dim, keepdim, dtype=dtype)
```
NOTE: This actually requires masked_fill, because we have no guarantee that any masked values of input are valid.
```
>>> torch.tensor([float('inf')]) * torch.tensor([False])
tensor([nan])
>>> torch.tensor([float('inf')]).masked_fill(~torch.tensor([False]), 0)
tensor([0.])
```
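A corrected sketch along these lines, keeping the same positional signature (this is an illustration of the masked_fill point, not a proposed kernel):

```python
import torch

def masked_sum(input, dim, keepdim, dtype, mask):
    # masked_fill zeroes out invalid entries, so inf/nan in masked
    # positions cannot contaminate the result the way `input * mask` can.
    return torch.sum(input.masked_fill(~mask, 0), dim, keepdim, dtype=dtype)

x = torch.tensor([[float('inf'), 2.0]])
mask = torch.tensor([[False, True]])
print(masked_sum(x, 1, False, None, mask))  # tensor([2.])
```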