
Losses to implement and losses naming conventions #38

Closed
albertz opened this issue Oct 21, 2021 · 32 comments

@albertz
Member

albertz commented Oct 21, 2021

Some modules we should implement:

  • CrossEntropy or CE. Should this cover both dense and sparse targets, or do we want a separate module for the sparse case, like SparseCrossEntropy or so? Should this potentially also allow for logits? log-probs?
  • KL or KullbackLeiblerDivergence
  • BinaryCrossEntropy or BCE
  • L2Dist (absolute or mean?) Or MSE or MeanSquaredError? (The mean reduction is over the feature axis, not over time or batch.)
  • L1Dist (absolute or mean?) Or MeanL1Dist?
  • Ctc or CTC or CtcLogProb
  • CosineSimilarity

I don't like the naming of the PyTorch losses too much here.
They have the postfix Loss on all of them, although these modules are generic and not necessarily just for loss computation (although that's probably their most common usage).
Also, CrossEntropyLoss is actually log-softmax + CE together, so very much like TF's tf.nn.sparse_softmax_cross_entropy_with_logits.
And there is a separate NLLLoss, which is just like CrossEntropyLoss but takes log-probs instead of logits. I find this naming confusing.

Also the question is how we should handle things like label smoothing. On the RETURNN side (and also in TF), it is just an option to the CE loss. On the PyTorch side, it was not part of the official PyTorch API for a long time. Some background here. It was only very recently added (pytorch/pytorch#7455, pytorch/pytorch#63122). This also adds it as an option label_smoothing to CrossEntropyLoss. An alternative would be that the user makes this more explicit, like:

target_prob_smooth = smooth_one_hot(targets, label_prob=0.9)
loss = cross_entropy(target_prob_smooth, out_prob)

Label smoothing has become very common, though, so maybe it makes sense to also have it just as an option.
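For illustration, the explicit variant is only a few lines; a rough NumPy sketch (smooth_one_hot is a hypothetical helper here, not an existing function):

import numpy as np

def smooth_one_hot(targets, num_classes, label_prob=0.9):
    # targets: int array of class indices, shape [B] or [B,T]
    off_value = (1.0 - label_prob) / (num_classes - 1)
    dense = np.full(targets.shape + (num_classes,), off_value)
    np.put_along_axis(dense, targets[..., None], label_prob, axis=-1)
    return dense  # [..., num_classes], each row sums to 1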

Note also that loss accumulation over the dataset and the calculation of the correct average (mean) are handled by RETURNN. All such losses would just yield a tensor of shape [B] or [B,T].
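Conceptually (just a NumPy sketch of the idea, not the actual RETURNN code), for a framewise loss of shape [B,T] that means something like:

import numpy as np

def accumulate(loss_bt, seq_lens):
    # loss_bt: [B,T] per-frame loss, seq_lens: [B] number of valid frames per sequence
    mask = np.arange(loss_bt.shape[1])[None, :] < seq_lens[:, None]  # [B,T], True for valid frames
    return loss_bt[mask].sum(), mask.sum()  # summed loss and frame count, averaged over the epoch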

@albertz
Member Author

albertz commented Oct 28, 2021

I don't like the naming of the PyTorch losses too much here.

I think this changed from the last time you looked? Because looking into https://pytorch.org/docs/stable/nn.html#loss-functions they do have Loss as postfix.

But that is one thing which I meant. This is what I don't like. CE is not just a loss. It is just CE. You can use it as a loss. But you can also use anything else as a loss. And you can also use CE for other purposes, not just as a loss.

I made this comment to say that we maybe should not exactly follow the PyTorch naming scheme in this case. Or at least we should think about it.

I assume we want to put them, similar to PyTorch, into a separate file like loss.py.

Yea sure, but this would just be our internal structuring, so this doesn't really matter too much. The user would simply use rc.nn.CE or whatever.

What would you think of as tests for the losses? Because I think this issue is something we should be careful about, when implementing by hand. Maybe run a small RETURNN setup with both the RETURNN loss and the rc (Module) loss and compare the results? Does that catch all cases though?

I don't quite understand. What exactly do you want to test? That e.g. the CE loss here and the RETURNN CE loss yield the same result? You can just write a test for that. Or for whatever else you want to test.

@Atticus1806
Contributor

Atticus1806 commented Oct 28, 2021

Ohh, I misread that, my bad.
Looking into it in PyTorch, it seems inconsistent to me too. In their Modules they just call the functions from torch.nn.functional, and in these functions e.g. CE does not have the Loss postfix, but e.g. CTC still has the postfix. Maybe we can wrap in a similar way, but then consistently have none of the functions named like losses?

What exactly do you want to test?

I think we should think about good tests for the losses since they are a crucial part of optimization. My suggestion assumed we see the RETURNN losses as correct and test against them for similar results, so we can infer the correctness of the Modules. I just wanted to say that we should think about it already when writing the Modules and not afterwards. But maybe this is something that is not as hard as it seems to me right now?
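Just to make the idea concrete, something along these lines (only a sketch; in the real test, the two values would come from the RETURNN loss and the rc Module instead of plain NumPy):

import numpy as np

def test_cross_entropy_matches_reference():
    rng = np.random.default_rng(42)
    logits = rng.normal(size=(8, 5))
    targets = rng.integers(0, 5, size=8)
    # reference: explicit softmax, then CE on the probabilities
    probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
    ref = -np.log(probs[np.arange(8), targets])
    # candidate: numerically stable log-softmax formulation
    m = logits.max(axis=-1, keepdims=True)
    log_probs = logits - m - np.log(np.exp(logits - m).sum(axis=-1, keepdims=True))
    out = -log_probs[np.arange(8), targets]
    np.testing.assert_allclose(out, ref, rtol=1e-5)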

@Atticus1806
Contributor

Atticus1806 commented Oct 28, 2021

What do you mean by wrapping?

class L1Loss(_Loss):
    __constants__ = ['reduction']

    def __init__(self, size_average=None, reduce=None, reduction: str = 'mean') -> None:
        super(L1Loss, self).__init__(size_average, reduce, reduction)

    def forward(self, input: Tensor, target: Tensor) -> Tensor:
        return F.l1_loss(input, target, reduction=self.reduction)

This is the Loss class which then refers to the loss in torch.nn.functional which then refers to a C implementation.

So you also prefer the variant without Loss postfix?

For the implementation of the function, yes, but we should also have a Module with the Loss postfix that just calls this function, because I think it is straightforward, when you just want to use it as a loss, to have that reflected in the name (which of course is easy to do).

Why should it be hard to write a test case? I don't really see any problem.

I did not put that much thought into it yet, I just could not come up with a quick idea on how to test, but if you say it's not hard then it's fine.

@Atticus1806
Contributor

What should we wrap? Why do we need to wrap anything? I don't think we need to wrap anything.

My suggestion was to implement the loss function without the Loss postfix as a standard function (like you suggested), and then additionally have a class named like the PyTorch losses, because while I agree that the functions can be used outside of a loss context, I feel like for a lot of users it would be more straightforward to actually have Loss in the name. Maybe this class could already call e.g. mark_as_loss internally or similar.

@albertz
Member Author

albertz commented Oct 28, 2021

For the implementation of the function, yes, but we should also have a Module with the Loss postfix that just calls this function, because I think it is straightforward, when you just want to use it as a loss, to have that reflected in the name.

But I don't understand that reasoning.

  • You want it both as a function and as a module? Why? Do we really want that for everything now? Or just for losses? (Related: Functional layer API, conventions #21. You are right, we currently do have it for Dropout & dropout and some others.)
  • You want that functions and modules are inconsistent with each other? The module name should have the Loss postfix, but not the function? Why?

@Atticus1806
Contributor

Atticus1806 commented Oct 28, 2021

Oh, now I think I see your problem. No, I don't necessarily want one Module and one function (that was bad wording, I guess); in general this could also be two Modules or two functions, but I would argue for having the "non-Loss" version wrapped in a "Loss" version which handles the potential loss settings.

@albertz
Member Author

albertz commented Oct 28, 2021

this could also be two Modules or two functions, but I would argue for having the "non-Loss" version wrapped in a "Loss" version which handles the potential loss settings.

I still don't really follow that reasoning.
It would just be an alias then.
Like:

cross_entropy_loss = cross_entropy

Or:

CELoss = CE

What is the point in that?

And what do you mean by "potential loss settings"? There is never anything specific to a loss. Whatever you do is just some mathematical calculation. Just at the end, you declare it as a loss.

@Atticus1806
Contributor

Atticus1806 commented Oct 28, 2021

And what do you mean by "potential loss settings"? There is never anything specific to a loss. Whatever you do is just some mathematical calculation. Just at the end, you declare it as a loss.

Yes, you are right about that. My thinking was about how I would approach a situation where I have to build a model and use a loss. While I agree with you that once you think about it, something like loss = loss_scale * CE(x) seems fine, I am not sure if I would come up with that right away, whereas loss = CELoss(x, scale=loss_scale), which internally does this, would make it clearer for a user. That could then also mark the output as a loss etc.

I think most of all this is a design decision. And I am not sure if in the end users would do this wrapping by themselves. Because if they do, then in my opinion we could also already do it here.

@albertz
Member Author

albertz commented Oct 28, 2021

Note that the mark_as_loss function (on a layer) already has such a loss_scale option. Maybe that is what you mean by loss settings. Any similar or further specific things can be added to that function, although I would argue to keep that minimal.

I don't think that scale should be added to CE or CELoss.

something like loss = loss_scale * CE(x) seems fine, I am not sure if I would come up with that right away, ...

How is that not totally obvious?

Although your code example misses the main thing which makes it a loss, which is the mark_as_loss (no matter if it is CE or CELoss). So it would be:

loss = CE(...)
loss.mark_as_loss()

And surely, our documentation would also just have a small part on the losses.

@Zettelkasten
Member

I also think calling the losses just CE (without Loss) etc. is good, and then having to always mark them explicitly. I think having an additional CELoss or so, when this is really just CE(..).mark_as_loss(), is not good. Or, if you want it this way, you can also define your own function so you can write it as Loss(CE(..)). Then this is syntax-wise also super close to CELoss(..).
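E.g. (just a sketch, assuming mark_as_loss takes a loss_scale option as mentioned above):

def Loss(x, scale=1.0):
    # declare x as a loss and return it unchanged
    x.mark_as_loss(loss_scale=scale)
    return x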

Regarding cross-entropy in particular, RETURNN currently has some heuristic for this where you can pass it a probability distribution generated from a softmax, and it will then compute the CE on the logits before the softmax. I heard this is important for numerical stability.
This heuristic in RETURNN is somewhat bad currently, because it is implicit and can easily break, and you will not notice if it does. How do we want to approach this here? If we write CE as our own module here, this would break the heuristic.
Do we want the user to directly pass the logits to CE, with the danger that they forget this, pass the probabilities, and then get these instabilities?
Or do it like PyTorch, where CrossEntropyLoss always takes logits?
Or still support this heuristic, and make it work somehow?

@albertz
Member Author

albertz commented Oct 28, 2021

Ah, this is another thing which I found somewhat inconsistent about the PyTorch loss module names: Some of them actually get logits (unnormalized), while others get log-prob or prob. E.g. CrossEntropyLoss gets logits while NLLLoss is just the same but gets log-probs. The idea to accept logits is fine (or actually good), it is just the naming that is confusing.

Surely, we can make the heuristic work. Also not so complicated. We can check (e.g. on RETURNN side) whether the input is a tensor right after a softmax op, and then get the input tensor of the softmax op instead. We make use of such logic in a couple of places. See e.g. safe_log. But you are right, it might semi-silently not work in some cases (you could still get a warning, just as you do now in RETURNN for CrossEntropyLoss...).
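The core of it is small; roughly (a simplified graph-mode sketch, not the actual RETURNN code):

def get_logits_from_prob(prob):
    # prob: a tf.Tensor in graph mode.
    # If it is directly the output of a softmax op, return the input of that op (the logits).
    if prob.op.type == "Softmax":
        return prob.op.inputs[0]
    return None  # heuristic does not apply; fall back to log(prob) with clipping (cf. safe_log)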

Or, we can make a variant of CE which accepts logits directly.

Or (just like ChoiceLayer) there can be an option which specifies this explicitly, like input_type in {"prob", "log_prob", "logits"}. In case of prob, we can still try to apply the heuristic.

But yes, this is a good point which we should decide.

@albertz
Member Author

albertz commented Oct 28, 2021

Another thing: The targets of CE could be dense (another prob distrib) or sparse (just class indices). Should this be the same function (module) (just as RETURNN CrossEntropyLoss) or separate?

(Similar to e.g. Linear vs Embedding in PyTorch. In RETURNN, both are handled by LinearLayer.)
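For reference, the two cases only differ in how the target enters the sum (NumPy sketch, log-prob inputs):

import numpy as np

def cross_entropy_dense(target_prob, out_log_prob):
    # target_prob, out_log_prob: [..., C]
    return -np.sum(target_prob * out_log_prob, axis=-1)

def cross_entropy_sparse(target_idx, out_log_prob):
    # target_idx: [...] class indices; just gathers one log-prob per frame
    return -np.take_along_axis(out_log_prob, target_idx[..., None], axis=-1)[..., 0]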

@JackTemaki
Contributor

But yes, this is a good point which we should decide.

I would like to have as few heuristics as possible. RETURNN already has many hidden heuristics for the losses that are not transparent to the user, and I think it would be good to have this as explicit as possible (in the name and documentation of the module). In the extreme case maybe something like CrossEntropyWithLogits and CrossEntropyWithProbs (not good names, but something like this).

Another topic: how would we write Modules that have optimized calls in TensorFlow? From my current knowledge, I would write L2 like this:

class L2(Module):

    def __init__(self, reduction=None):
        """
        :param str|None reduction: None, "mean" or "sum", will reduce over everything except batch
        """
        super().__init__()
        assert reduction in [None, "mean", "sum"]
        self.reduction = reduction

    def forward(self, inp1, inp2) -> LayerRef:
        out = eval(concat(inp1, inp2), eval="tf_compat.v1.squared_difference(source(0),source(1))")
        if self.reduction:
            out = reduce(out, mode=self.reduction, axes="except_batch")
        return out

So here is also the question: should L2 be written "explicitly", should it only "wrap" the TF call, and should it include the reduction or should this be separate?

@albertz
Member Author

albertz commented Oct 28, 2021

Note that reduction over batch (for framewise losses also over time) is sth which RETURNN handles. It is also important that it stays this way because RETURNN must see the batch dim and time dim to be able to properly accumulate losses over the epoch. If we want to move that logic somehow over here, it's not clear how we would be able to handle that.

Reduction over any other axes (in case of L2 or so) should (could?) be done explicitly before (as mean, sum, or whatever you like). (I think the current behavior in RETURNN is: whenever there are other axes in addition to batch or time, this is allowed, and it will just accumulate them with sum (edit: or mean).)
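I.e., for an L2-style loss, only the feature axis would get reduced, and a [B,T] tensor is returned (NumPy sketch of the intended semantics, not rc code):

import numpy as np

def mean_squared_difference(x, y):
    # x, y: [B,T,F]; reduce only over the feature axis, keep [B,T] for RETURNN to accumulate
    return np.mean((x - y) ** 2, axis=-1)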

Btw, some comments to your code example:

  • I think whenever we need to wrap some TF function, we should directly make such a wrapper function, so this is hidden away. Here for squared_difference. Although we might want to call it differently. We should not use eval directly.
  • concat is wrong in your example.

We should make use of TF optimized functions when possible. Efficiency is one main principle of RETURNN.
Although this is now some example where this is probably only really minor. But anyway.

@JackTemaki
Contributor

JackTemaki commented Oct 28, 2021

I think the current behavior in RETURNN is, whenever there are other axes in addition to batch or time, this is allowed, and it will just accumulate them with sum.

All axes are flattened and then combined with reduce_mean, so this is not the sum, right?

from returnn/tf/layers/basic.py:8904 :
out = self.reduce_func(tf.reduce_mean(out, axis=1))

axis=1 is here everything but batch...

concat is wrong in your example.

Yes of course... oops

@albertz
Member Author

albertz commented Oct 28, 2021

Btw, on terminology:

L2 is actually a norm. Or more specifically, it is the Lebesgue space using the 2-norm. Also, it is actually sqrt(sum(x**2)) and not just sum(x**2).

L2-distance is also common though, and means sqrt(sum((x - y)**2)).

For me, I would expect that a function l2 or a module L2 would calculate the norm, not the distance.
I would expect sth like l2_distance for the distance.
Maybe, to be more explicit, l2_norm for the norm.

We should very clearly reflect in our function name when we actually mean the squared-difference (so just sum((x-y)**2)) or mean-squared-difference.
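Spelled out (NumPy, reducing over the last axis only, just to pin down the terminology):

import numpy as np

def l2_norm(x):
    return np.sqrt(np.sum(x ** 2, axis=-1))

def l2_distance(x, y):
    return np.sqrt(np.sum((x - y) ** 2, axis=-1))

def squared_difference(x, y):
    # no sqrt, i.e. the squared L2 distance
    return np.sum((x - y) ** 2, axis=-1)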

@albertz
Member Author

albertz commented Oct 29, 2021

Some open questions, and my current opinion:

Module or function, or both? I would vote for function only, to keep consistent with #21, where we have the simple convention: Modules are used when they contain parameters, functions otherwise.

Do we want the _loss suffix? Or have two functions, one cross_entropy and then just an alias cross_entropy_loss = cross_entropy? I would vote for not having such aliases, and in general also not for having the _loss suffix. Of course, in some documentation section, we would list all functions such as cross_entropy which are commonly used as losses.

Do we want ce or cross_entropy? I think cross_entropy is nicer and less ambiguous in the code. (When we talk here about it, I think it's fine to refer to it as "CE" now because the context is clear.)

Do we want cross_entropy to cover both a sparse and dense target prob distrib? Or have different functions for each case, like cross_entropy_dense and cross_entropy_sparse? I think one function is fine.

Do we want cross_entropy to work on different input types (logits, log prob, prob)? Or have separate functions for each case? Like cross_entropy_logits etc. Not sure...

How would we handle label smoothing?

  • Option to cross_entropy (and all the other variants of that function, if we decide to have different functions for every variant).
  • Some extra function cross_entropy_label_smoothing or so.
  • Just do it via explicit code, which is very simple:
    target_prob_smooth = one_hot(targets, true_=0.9, false_=0.1 / (targets.dim - 1))
    loss = cross_entropy(target_prob_smooth, out_prob)
    

I tend to prefer the last option with explicit code, although I also see that this is so common that maybe the second option is also fine? I don't know...
Note that besides label smoothing, there are other things, and in the future maybe further new fancy things, which might become standard, and then we again end up with the problem of adding more and more options to the base cross_entropy function, which is sth we should try to avoid. So this is a strong argument against the first option.

What functions do we want to introduce, and using what names? (PyTorch loss functions)

  • cross_entropy (PyTorch: torch.nn.functional.cross_entropy)
  • kl_div (PyTorch: torch.nn.functional.kl_div)
  • squared_difference or squared_error (for sum((x-y)**2))
  • mean_squared_difference or mean_squared_error (for mean_F ((x-y)**2))
  • l2_distance
  • l2_norm
  • l1_distance (sum(abs(x-y)))
  • mean_abs_error (mean_F (abs(x-y)))
  • l1_norm (sum(abs(x)))
  • ctc or neg_log_prob_ctc or so

Others? (We don't really need to have all potential ones right away. This can just be extended later. We only need the most important ones for the beginning.)

@Atticus1806
Contributor

Atticus1806 commented Nov 3, 2021

Do we want cross_entropy to work on different input types (logits, log prob, prob)? Or have separate functions for each case? Like cross_entropy_logits etc. Not sure...

We want rc to be straightforward and have as little hidden stuff as possible. This is why I think having explicit names like cross_entropy_logits would be the most straightforward for the user. If we want to restrict ourselves to one function though, I would vote for forcing the user to explicitly set the "flag" (instead of it having a default value, which again would hide things in my opinion).

How would we handle label smoothing?

In general, I think this comes down to how much of the code is easy for the user to think of, and where the border for this is. I think we should always answer these kinds of questions by asking ourselves: "Is it reasonable to assume that someone who will use this can write it themselves, or would it require quite some knowledge/searching to come up with?" Once we agree on one of the two, this then directly leads to whether we should implement it, I think.

... the problem of adding more and more options to the base cross_entropy function, which is sth we should try to avoid.

Agreed, so either the 2nd or the 3rd option in my opinion. For label smoothing I tend towards 2 right now, but maybe others can also comment.

What functions do we want to introduce, and using what names?

If we name functions error, doesn't this go in the same direction as giving functions the _loss suffix? Since error, like loss, is in the end an interpretation of the difference?

@albertz
Member Author

albertz commented Nov 3, 2021

cross_entropy_logits ...

Ok. So we would have 3 functions then:

  • cross_entropy_logits
  • cross_entropy_log_prob
  • cross_entropy_prob

how much of the code is easy for the user ... someone ...

This can vary a lot, depending on the user and his/her background. Some people also might not even know about cross entropy and they just want to train it in whatever way.

In practice, people will probably anyway still use some small code snippets from someone. This can be in the documentation: some small code snippet showing how you would do label smoothing (just the code snippet I posted above).

Automatic code completion and IDE support is another thing though, where your IDE can easily show you the variants. And as label smoothing is really so common now, maybe we can anyway also add cross_entropy_logits_label_smoothing (but not for the other CE function variants, such that we don't get the combinatorial explosion).

Or I wonder how good GitHub Copilot will be at coming up with the snippet when I just write the prefix # cross entropy with label smoothing...

@albertz albertz changed the title from "Losses" to "Losses to implement and losses naming conventions" Nov 3, 2021
@albertz
Member Author

albertz commented Nov 3, 2021

Just as a random reference, Fairseq has LabelSmoothedCrossEntropyCriterion.

@albertz
Member Author

albertz commented Nov 3, 2021

Another question: Should it be cross_entropy_logits(target_prob, out_logits) or cross_entropy_logits(out_logits, target_prob) or cross_entropy_logits(*, out_logits, target_prob)?

When you look at the mathematical definition of cross entropy H(p,q), the p is the target distribution and q is the output.
This is the reverse of how it is defined for PyTorch CrossEntropyLoss, which gets (input, target).
However, it matches TensorFlow's sparse_softmax_cross_entropy_with_logits, which gets (labels, logits).
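For reference, the definition, with p the target distribution and q the model output:

H(p, q) = - sum_i p(i) * log q(i)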

For #56, it makes sense to have the convention that the first argument is usually the model output.
Edit: This is not needed anymore. There is no restriction on the order of the arguments.

@albertz
Member Author

albertz commented Nov 3, 2021

Note: We should also add (non-differentiable) functions for frame error, edit distance, etc. (See #57.)

@albertz
Member Author

albertz commented Nov 4, 2021

Related is also #17 on dim tags. Many of these losses will reduce over some axis, usually the feature dim axis. But maybe that should be explicit?

@albertz
Member Author

albertz commented Nov 8, 2021

Ok. So we would have 3 functions then:

  • cross_entropy_logits
  • cross_entropy_log_prob
  • cross_entropy_prob

I want to ask again about whether these functions should deal with both dense and sparse targets, or whether this should be separated and explicit as well. If this is separate, then we would already have 6 functions now. Is this good? And maybe there are other things?
Note that TF has sparse_softmax_cross_entropy_with_logits which is exactly one of those instances. It would also be good if we can explicitly wrap this for efficiency.

Also note, if we don't separate sparse inputs, then we need to implement it such that it can handle both dense and sparse inputs. This implies that we need one of these:

@albertz
Member Author

albertz commented Jan 4, 2022

Note that DotLayer supports sparse inputs now (returnn#741).

And we also have shape information and sparse_dim available (#47).

But I would just use DotLayer for simplicity.

However, we should somehow make sure that it uses sparse_softmax_cross_entropy_with_logits for the standard case for efficiency reasons. And also make sure it is only calculated on the non-padded frames (#68), again for efficiency.
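For reference, the TF op works per frame on logits and sparse labels (shape sketch only):

import tensorflow as tf

logits = tf.random.normal([3, 7, 10])                          # [B,T,C]
labels = tf.random.uniform([3, 7], maxval=10, dtype=tf.int32)  # [B,T] sparse class indices
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)  # [B,T] per-frame loss

The flattening (#68) would then additionally restrict this to the non-padded frames.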

albertz added a commit that referenced this issue Jan 13, 2022
@albertz
Member Author

albertz commented Jan 13, 2022

We have cross_entropy now.

I decided to go with a (required) option ...type which is "probs", "log-probs" or "logits".

It also wraps sparse_softmax_cross_entropy_with_logits when possible.
Otherwise it uses dot (both for dense or sparse).

It also has a label_smoothing option, but I am not sure if that was a good idea. As mentioned before, it should stay very minimal. But on the other hand, label smoothing is very common. But maybe I remove that again.
Edit: Removed that again.

One important aspect which is missing is the flattening (#68). This is also for efficiency because we don't want to do unnecessary softmax computations. But in any case, we should deal with this separately (#68). (Edit: Done as well.)

albertz added a commit that referenced this issue Apr 13, 2022
albertz added a commit that referenced this issue Apr 13, 2022
albertz added a commit that referenced this issue Apr 13, 2022
@albertz
Member Author

albertz commented Apr 13, 2022

I think we now have a good starting point with cross_entropy, binary_cross_entropy, mean_squared_difference, mean_absolute_difference.

One (for us) relevant remaining loss is CTC. Here the question is a bit how generic we want to have it. In any case, there should be a ctc_loss corresponding to the original CTC loss, wrapping returnn.tf.native_op.ctc_loss (same as the TF CTC loss but faster). But then we also might want to have a more generic full-sum loss, for any target FSA (e.g. generated on-the-fly by RASR). And the ctc_loss can maybe directly be based on that (internally in RETURNN it's also like that anyway).
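For shape reference, the plain TF CTC loss (which, as said, returnn.tf.native_op.ctc_loss computes as well, just faster) looks roughly like this (sketch only):

import tensorflow as tf

logits = tf.random.normal([2, 50, 40])                           # [B,T,C], last class = blank here
targets = tf.random.uniform([2, 12], maxval=39, dtype=tf.int32)  # [B,S] label indices
loss = tf.nn.ctc_loss(
    labels=targets, logits=logits,
    label_length=tf.fill([2], 12), logit_length=tf.fill([2], 50),
    logits_time_major=False, blank_index=-1)  # [B] neg log prob per sequence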

Once we have ctc_loss, I think this issue can be closed and any further loss can be implemented later once it is needed.

@albertz
Member Author

albertz commented Apr 13, 2022

We also have ctc_loss now, directly wrapping the newly introduced CtcLossLayer.

So I will close this now. In case sth is buggy, please open a new issue. In case you want to have a specific missing loss function, please open a new issue.

@albertz albertz closed this as completed Apr 13, 2022