
Kullback-Leibler divergence #131

Merged: 4 commits merged into main on Apr 13, 2022
Conversation

Atticus1806 (Contributor)

This is a draft for the Kullback-Leibler divergence (loss).

The idea is to use this with .mark_as_loss(). This is only a draft since I am not 100% sure about all the modalities that might still need to be taken care of (or are already handled by the marking).

In my mind, the open points are:

  • conversion to an actual (batch-wise) loss
  • potential inconsistencies / stability concerns with the input, e.g. PyTorch (https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html) assumes estimated to be in log-space already and makes that optional for target. Do we want to enforce/handle this in the same way? (See the sketch below.)
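
For reference, a minimal sketch of the PyTorch convention mentioned above (the shapes and the use of `log_target` are illustrative only, not part of this PR):

```python
import torch
import torch.nn.functional as F

# PyTorch's convention: the estimate ("input") must already be log-probabilities,
# while the target is probabilities by default and log-probabilities if log_target=True.
estimated = torch.randn(8, 10).log_softmax(dim=-1)   # log-space estimate
target = torch.randn(8, 10).softmax(dim=-1)          # probability-space target

loss = F.kl_div(estimated, target, reduction="batchmean", log_target=False)
```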

albertz (Member) commented Apr 13, 2022

It should be consistent with cross_entropy regarding the log-space logic etc.

Also see the discussion in #38.

I don't think it should also perform the reduction over time or other dims. It should just calculate the KL. The final reduction and accumulation over time and batch would be handled by RETURNN.
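
A minimal sketch (plain PyTorch, names assumed) of what "just calculate the KL" means here: reduce only over the class axis and leave the batch/time accumulation to the framework:

```python
import torch

def kl_per_frame(target_probs: torch.Tensor, est_log_probs: torch.Tensor, class_dim: int = -1):
    """KL(target || estimated) per frame; no reduction over batch or time.

    Reduces only over the class axis, so the result keeps shape [batch, time]
    and the framework (here: RETURNN) can do the final reduction/accumulation.
    """
    # Mask out zero-probability targets to avoid 0 * log(0) = NaN.
    kl_terms = torch.where(
        target_probs > 0,
        target_probs * (target_probs.log() - est_log_probs),
        torch.zeros_like(target_probs))
    return kl_terms.sum(dim=class_dim)
```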

Atticus1806 marked this pull request as ready for review on April 13, 2022, 13:17
Atticus1806 (Contributor, Author)

Okay, I added the checks for the type of estimated the same way it is done for cross entropy (ce), so I think this should be ready.
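
A hypothetical outline of what such a type check could look like (plain PyTorch standing in for the actual nn/loss.py code; the name estimated_type and its accepted values mirror the cross-entropy handling referred to above, but the merged code may differ in detail):

```python
import torch

def to_log_probs(estimated: torch.Tensor, estimated_type: str, class_dim: int = -1) -> torch.Tensor:
    """Normalize the estimate to log-space depending on how it is given.

    Accepted values ("logits", "probs", "log-probs") are assumptions based on
    the cross-entropy discussion in this thread, not the exact merged code.
    """
    if estimated_type == "logits":
        return torch.log_softmax(estimated, dim=class_dim)
    if estimated_type == "probs":
        return estimated.clamp_min(1e-30).log()  # guard against log(0)
    if estimated_type == "log-probs":
        return estimated
    raise ValueError(f"invalid estimated_type {estimated_type!r}")
```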

albertz merged commit 7c62226 into main on Apr 13, 2022
albertz deleted the kullback_leibler branch on April 13, 2022, 14:58
2 participants