
Varying weights to losses depending on the length of the sequence #291

Closed

shuishida opened this issue Nov 11, 2024 · 12 comments

@shuishida

Because the mean reduce is applied after filtering with a mask, it seems that if the input sequences vary in length (and therefore the masks differ in size), then batches where short sequences dominate will be weighted higher than batches with longer sequences. Although it shouldn't be a large effect, I wonder if it would be better to set the masked loss values to zero and then apply the mean reduce, for consistency of loss weighting?

https://github.com/lucidrains/x-transformers/blob/144d9ba84955139347e798ab025457b2d7adc314/x_transformers/continuous.py#L225C1-L225C30

@lucidrains
Owner

@shuishida i think the current way is the correct behavior as typically we think token centric. but perhaps i could offer a loss_weight: Float['batch']? that you can manually pass in to weigh each sequence however you wish? what is your use case?
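For illustration, a rough sketch of how such a per-sequence loss_weight could be applied; the parameter and the broadcasting below are assumptions (the idea was only floated in this thread), not library code:

    import torch

    # hypothetical: a per-sequence loss_weight applied before the masked reduction
    loss = torch.rand(2, 3)                    # unreduced per-token loss, shape (batch, seq)
    mask = torch.tensor([[True, True, True],
                         [True, False, False]])
    loss_weight = torch.tensor([1.0, 0.5])     # one weight per sequence in the batch

    weighted = loss * loss_weight[:, None]     # broadcast each sequence's weight over its tokens
    reduced  = weighted[mask].mean()           # then the usual masked mean reduce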

@shuishida
Author

I guess my suggestion is to change

        if exists(mask):
            assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
            loss = loss[mask]

        return loss.mean()

to

        numel = loss.numel()

        if exists(mask):
            assert loss.ndim > 1, 'loss should not be reduced if mask is passed in'
            loss = loss[mask]

        return loss.sum() / numel

since the number of elements in the loss after masking varies with how many tokens the mask keeps. But if this change would affect many existing evaluations, or if there's a benefit to having somewhat varied loss magnitudes, I also understand if you want to keep it as it is.

@lucidrains
Owner

lucidrains commented Nov 11, 2024

@shuishida just to be sure we are on the same page, are you aware that the loss by default is not reduced? https://github.com/lucidrains/x-transformers/blob/main/x_transformers/continuous.py#L172 (reduction = 'none')

loss = loss[mask] should remove all the masked-out (padding) tokens

@shuishida
Author

shuishida commented Nov 11, 2024

Yes, I am aware.

So, as an example, let's say I have losses for 2 sequences, each with sequence length 3.

loss = tensor([
    [1.8, 0.6, 1.2],    # first seq
    [1.2, 1.8, 0.6]     # second seq 
])

If I don't have any masking, then after loss.mean() each token will contribute the following amount to the loss:

tensor([
    [0.3, 0.1, 0.2],
    [0.2, 0.3, 0.1] 
])

However, if we mask this with the following:

mask = tensor([
    [True, True, True],    # first seq
    [True, False, False]   # second seq 
])

then loss[mask] will yield the following:

tensor([1.8, 0.6, 1.2, 1.2])

Note that this will be a flattened vector rather than the 2D tensor we had before, and the number of elements will drop to the number of True entries in the mask. If we now apply loss.mean(), each token will contribute the following amount to the loss:

tensor([
    [0.45, 0.15, 0.3],    # first seq
    [0.3, 0.0, 0.0]       # second seq
])

However, I think it would be more natural if the first sequence's loss weren't affected by the masking in the second sequence, like this:

tensor([
    [0.3, 0.1, 0.2],    # first seq
    [0.2, 0.0, 0.0]     # second seq
])

We can achieve this if we do loss.sum() / numel (with numel counted before masking, as in the snippet above) instead of loss.mean().
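For completeness, the example above as a small runnable snippet (illustration only, not the library's code):

    import torch

    loss = torch.tensor([[1.8, 0.6, 1.2],
                         [1.2, 1.8, 0.6]])
    mask = torch.tensor([[True, True, True],
                         [True, False, False]])

    numel = loss.numel()                 # 6 tokens, counted before masking

    current  = loss[mask].mean()         # (1.8 + 0.6 + 1.2 + 1.2) / 4 = 1.2
    proposed = loss[mask].sum() / numel  # (1.8 + 0.6 + 1.2 + 1.2) / 6 = 0.8

    print(current.item(), proposed.item())   # ≈ 1.2 and ≈ 0.8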

@shuishida
Author

This is the point I wanted to make, but it's a minor problem.

@lucidrains
Owner

@shuishida yes i see, but i can also see the argument against it

when in doubt, i'll just make it a hyperparameter, give me 5 minutes

lucidrains added a commit that referenced this issue Nov 11, 2024
@lucidrains
Owner

@shuishida hey Shu, setting this to True should address the issue

@shuishida
Author

Fantastic, thank you!

@lucidrains
Owner

@shuishida happy training

@shuishida
Author

shuishida commented Nov 11, 2024

Sorry, I think there's still some misunderstanding :P

The title of the issue "Varying weights to losses depending on the length of the sequence" is describing the symptom of the issue, as opposed to a feature request.

I see that the flag you've added

        equal_loss_weight_batch = False  # setting this to True, if the mask is passed in and sequences are variable in length, each sequence will be weighted the same (as opposed to each token)

has added a new feature of sequence-level weighting, which is kind of cool, but it wasn't what I was suggesting. I was pointing out that, depending on how the sequences are masked, there is a side effect where the batch elements inadvertently get reweighted (example illustrated in #291 (comment)).

Anyways, I don't think it matters too much, but I just wanted to keep it on record in case someone comes across this issue in the future and gets confused :P
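For future readers, here is a sketch contrasting the token-centric reduction with an equal-per-sequence reduction; the equal_loss_weight_batch branch is a reconstruction of the behavior described by the flag's comment, not a verbatim copy of the library's implementation:

    import torch

    def reduce_loss(loss, mask, equal_loss_weight_batch = False):
        # loss: (batch, seq) unreduced per-token loss
        # mask: (batch, seq) bool, True marks a valid (non-padding) token

        if equal_loss_weight_batch:
            # each sequence weighted the same: per-sequence masked mean, then mean over the batch
            per_seq = (loss * mask).sum(dim = -1) / mask.sum(dim = -1).clamp(min = 1)
            return per_seq.mean()

        # default token-centric behavior: every kept token weighted the same
        return loss[mask].mean()

The difference only shows up when sequence lengths vary within a batch: the default weights every kept token equally across the whole batch, while the flag weights each sequence equally regardless of how many of its tokens survive the mask.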

@shuishida
Author

Anyways I don't want to take up any more of your time. Thank you so much for your amazing work!

@lucidrains
Owner

@shuishida ohh I see, I don't think I agree then, as that would give batches with high variance of sequence lengths less weight than the ones with low variance (less masking)

regardless, thanks for bringing it up
