
Design of PyTorch step-functions #1290

Closed
JackTemaki opened this issue Mar 28, 2023 · 12 comments
@JackTemaki
Collaborator

Related to #1120 I want to start the discussion on which step functions we want to have and what they should look like (this should be independent of PyTorch vs. RF). Currently, only train_step is implemented. One possible approach is:

  • mandatory train_step
  • optional eval_step and forward_step with an automatic fallback to train_step if no customization is needed.

Then the question is whether they all get the same parameters or not. Currently we have model, data and TrainCtx. I think this is enough for a start (any implementation should accept **kwargs anyway, so it does not break when parameters are added). Maybe TrainCtx should be renamed then if it is not only for training. A rough sketch of this proposal is below.
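To make this concrete, this is roughly what a config following the proposal could look like (names, signatures and the fallback behavior are only illustrative here, not a decided API):

```python
# Illustrative sketch of the proposal only; not the decided API.

def train_step(*, model, data, train_ctx, **kwargs):
    """Mandatory: run the forward pass and register the training losses."""
    ...

def eval_step(*, model, data, train_ctx, **kwargs):
    """Optional: if the config does not define this, the engine would fall back to train_step."""
    ...

def forward_step(*, model, data, **kwargs):
    """Optional: define outputs, e.g. for recognition or export; same fallback if not customized."""
    ...
```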

@albertz
Member

albertz commented Mar 28, 2023

TrainCtx is also not part of the current API (maybe just what you see in master, which does not reflect what we decided for the current API).

The current API is described exactly in the initial post.

We also suggest forward_step there.

The difference between train_step and forward_step is that train_step defines the losses (so it actually covers both training and eval), while forward_step defines outputs (so it would also cover beam search).

I think this is all we need.
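To illustrate that split (just a rough sketch, not verbatim the current API; the data keys are assumptions, and greedy decoding stands in for an actual search):

```python
import torch.nn.functional as F
import returnn.frontend as rf  # assuming this is the rf namespace referred to in this thread

def train_step(*, model, data, **kwargs):
    # Defines the losses; the same function covers training and eval (cross-validation).
    logits = model(data["data"])
    rf.get_run_ctx().mark_as_loss(F.cross_entropy(logits, data["classes"]), name="ce")

def forward_step(*, model, data, **kwargs):
    # Defines outputs; this is also where beam search / recognition would go.
    logits = model(data["data"])
    rf.get_run_ctx().mark_as_output(logits.argmax(dim=-1), name="hyps")  # greedy stand-in for search
```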

@JackTemaki
Collaborator Author

which does not reflect what we decided for the current API

I opened this discussion to challenge those decisions.

The questions are: is it okay not to distinguish between train and eval? Could there ever be a need to have code in eval which is different from train? (As I said, if not provided, we could always have train as the fallback.)

Then for the TrainCtx. I do not like that for a pure PyTorch model you would have to call rf.get_run_ctx(). I do not think this is intuitive, and I would vote for (as @patrick-wilken proposed) a clear separation between Engine and Frontend.

@albertz
Member

albertz commented Mar 28, 2023

Then for the TrainCtx. I do not like that for a pure PyTorch model you would have to call rf.get_run_ctx(). I do not think this is intuitive, and I would vote for (as @patrick-wilken proposed) a clear separation between Engine and Frontend.

What's the problem with that? We discussed this and we came to the conclusion it is simple enough, or actually simpler that way.

For the user, there is not really any difference anyway. Either you need to pass around train_ctx explicitly (which can be annoying), and then you call:

train_ctx.mark_as_loss(...)

Vs. (which always works):

get_train_ctx().mark_as_loss(...)

In case you use returnn.tensor.Tensor (nothing really from the RF), you can also simply call tensor.mark_as_loss(...). For this to work, such a get_train_ctx() function is basically required anyway.
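Roughly how such a global accessor could be wired up (a hypothetical sketch, not RETURNN's actual implementation; all names here are illustrative):

```python
# Hypothetical sketch of a global train/run context; not RETURNN's actual code.
_current_ctx = None

class TrainCtx:
    def __init__(self):
        self.losses = {}

    def mark_as_loss(self, loss, *, name):
        self.losses[name] = loss

def get_train_ctx():
    # The engine would set _current_ctx before calling the step function, so user
    # code (or Tensor.mark_as_loss) can reach it anywhere without passing it around.
    assert _current_ctx is not None, "not inside a step function"
    return _current_ctx
```

A Tensor.mark_as_loss(...) method would then just delegate to get_train_ctx().mark_as_loss(self, ...).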

I'm not really sure what separation you mean. Or are you just saying that this function get_run_ctx() should not be in the rf namespace but somewhere else? But is this really that relevant?

@albertz
Member

albertz commented Mar 28, 2023

The questions are: is it okay not to distinguish between train and eval? Could there ever be a need to have code in eval which is different from train? (As I said, if not provided, we could always have train as the fallback.)

You would likely use a separate config for that anyway, in case you want to evaluate something different. In that separate config, you could define any other custom train_step function.

Similarly for forward_step. You might want to export attention weights, or the encoder outputs, or so. You would just have separate configs for this, where you override forward_step to output whatever you want.
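For example, such a separate config could roughly contain (a sketch; model.encode() and the data keys are hypothetical here):

```python
import returnn.frontend as rf  # assuming the rf namespace

def forward_step(*, model, data, **kwargs):
    # Overridden in a separate config: instead of doing recognition, just export
    # the encoder outputs. model.encode() and the key names are hypothetical.
    enc_out = model.encode(data["data"])
    rf.get_run_ctx().mark_as_output(enc_out, name="encoder_output")
```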

@patrick-wilken
Contributor

Thinking about the details I'm noticing that the interface is really more complex than I thought.

That's what I had in mind so far:

| | graph mode | eager mode |
|---|---|---|
| get_model() | return Module whose __call__ method constructs the graph, without losses, without beam search | return Module whose __call__ method runs the forward pass, without losses, without beam search |
| train_step(model) | construct the graph and add loss computation | run forward pass and calculate losses |
| forward_step(model) | construct the graph and add beam search | run forward pass including beam search |

(Beam search is a placeholder here for however you want to decode.)

But: just having a __call__ function is not sufficient in general, right? Because it would have some outputs, and in train_step() / forward_step() you would only have access to these outputs, not to the internals of the model. But you want to be able to mark tensors as loss / output anywhere in the model. (Otherwise we wouldn't need this train_ctx / Tensor mark_as_loss() thing; train_step() could just calculate and return the loss.)
@albertz , in your setups in i6_experiments I see that your Module has for example encode() and decode() methods to be able to add a CTC loss after encode(). So in general, for each loss you want to add inside the model, you need a new method to split up the calculation. I wonder whether that's a good design...

For the maybe earlier intended use case of returnn_common where you define a Module, construct the network dict from it and set the network config parameter, you would just have mark_as_loss() and mark_as_output() inside the __call__() method, but with the current design that is not intended anymore, right?
But I fear that in the graph-based case this is still the easier solution, so people will just do it that way and do nothing in the step functions except running model.__call__()?
In eager mode that doesn't really work, because the losses and additional outputs would always be calculated, even when unused. But all that is needed to prevent that is to have e.g. a train / eval flag available as input to __call__, or maybe via a "ctx" object, and then use that in the code.

@albertz
Member

albertz commented Mar 30, 2023

Whether you have just a __call__ function, or sth else in your model (like a separate get_encoder_output()), this is up to you.

Whether you call mark_as_loss inside the model __call__ (or other functions), this is also up to you.

Maybe look at my existing example (using RC, but the API is very similar):

  • Here is my get_model function (in that example from_scratch_model_def).
  • Here is my train_step function (in that example from_scratch_training).

In those examples, I do all the mark_as_loss calls directly in the train_step function. The model actually does not have a __call__ function but instead an encode() function, which also returns intermediate layer outputs, because I wanted to have them for auxiliary losses. It also has a decode() function, which returns the log probs, and then I define the loss based on that (RNN-T here).

But you are really free to do it however you want. I liked having this separation, at least for these examples, but maybe there are other cases where you want to do the mark_as_loss calls somewhere deeper inside some model functions. The model could also come with a train(source, target) function. But I actually like to have the separation between model and loss, again at least for this case.
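Schematically, the structure described above looks roughly like this (a strongly simplified sketch, not the actual i6_experiments code; the real setup uses CTC auxiliary losses and an RNN-T loss, which are replaced here by framewise cross-entropy just to keep the sketch self-contained):

```python
import torch
import torch.nn.functional as F
import returnn.frontend as rf  # assuming the rf namespace

class Model(torch.nn.Module):
    """Sketch of a model with encode()/decode() instead of a single __call__; dims are arbitrary."""
    def __init__(self, feat_dim=80, hidden=512, num_classes=100):
        super().__init__()
        self.encoder = torch.nn.Linear(feat_dim, hidden)
        self.aux_head = torch.nn.Linear(hidden, num_classes)  # head for the auxiliary loss
        self.out_head = torch.nn.Linear(hidden, num_classes)

    def encode(self, source):
        # also returns an intermediate output, so the auxiliary loss can be defined outside
        enc = torch.relu(self.encoder(source))
        return enc, self.aux_head(enc).log_softmax(-1)

    def decode(self, enc):
        # returns the log probs for the output labels
        return self.out_head(enc).log_softmax(-1)

def train_step(*, model, data, **kwargs):
    # All mark_as_loss calls live here, outside the model itself.
    feats, targets = data["data"], data["classes"]
    enc, aux_log_probs = model.encode(feats)
    log_probs = model.decode(enc)
    ctx = rf.get_run_ctx()
    ctx.mark_as_loss(F.nll_loss(aux_log_probs.flatten(0, 1), targets.flatten()), name="aux")
    ctx.mark_as_loss(F.nll_loss(log_probs.flatten(0, 1), targets.flatten()), name="main")
```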

For the maybe earlier intended use case of returnn_common where you define a Module, construct the network dict from it and set the network config parameter, you would just have mark_as_loss() and mark_as_output() inside the __call__() method, but with the current design that is not intended anymore, right?

This is still possible with the current design. But I actually never really intended it to be done like this.

But none of this discussion here is really about eager-based vs graph-based, is it? Everything I said here is just the same, no matter whether graph-based or eager-based. Yes, in eager mode, everything is always calculated. You have to make sure you do not calculate it if you don't need it. You can simply check get_run_ctx() to see if the loss should be calculated (train or eval). You could also make it an explicit arg to your __call__. All of this is really up to you.
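For example, inside some module (a sketch; whether the flag comes from the run ctx or from an explicit argument is up to you, and the compute_losses attribute here is an assumed name, not necessarily what the ctx actually exposes):

```python
import torch
import returnn.frontend as rf  # assuming the rf namespace

class MyBlock(torch.nn.Module):  # hypothetical module, just for illustration
    def __init__(self, dim=512):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)

    def forward(self, x):  # reached via __call__
        y = self.linear(x)
        ctx = rf.get_run_ctx()
        # In eager mode, only compute the auxiliary loss when losses are wanted
        # (train or eval, not pure forward); the attribute name here is an assumption.
        if ctx.compute_losses:
            ctx.mark_as_loss(y.abs().mean(), name="aux_l1")  # toy auxiliary loss
        return y
```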

@albertz
Member

albertz commented Mar 30, 2023

Btw, also check how they do it in Fairseq or ESPnet. I would assume they also separate the model stuff from the loss calculation. Or at least have that in a separate function.

@JackTemaki
Collaborator Author

Btw, also check how they do it in Fairseq or ESPnet. I would assume they also separate the model stuff from the loss calculation. Or at least have that in a separate function.

Nope, for ESPNet it is hidden somewhere completely in the middle in a forward function in the model, e.g. https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/vits/vits.py#L424
Any model should then return a loss and stats (which works quite similarly to, but not exactly like, our total_loss and the loss_dict, because in our case the total is fixed to be the sum of all losses in the dict):
https://github.com/espnet/espnet/blob/master/espnet2/gan_tts/espnet_model.py#L97

This also means it is completely hidden from the outside what the losses are. No idea how they track the losses then; definitely not an approach I like. BUT:
For me it is completely reasonable to have losses defined wherever I want them to be. Currently I just register the losses within the train_step for TorchScript reasons, but I hope I can find a solution for this in the near future.

So my question is:

[to mark_as_loss() inside __call__] But I actually never really intended it to be done like this.

why? If I have a duration predictor module in my TTS I want the duration loss to be defined exactly in there. Otherwise I lose modularity.
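For example, a duration predictor that owns its loss could look like this (a sketch; rf.get_run_ctx() is the global accessor discussed above, everything else here is hypothetical):

```python
import torch
import torch.nn.functional as F
import returnn.frontend as rf  # assuming the rf namespace

class DurationPredictor(torch.nn.Module):
    """Hypothetical TTS duration predictor that defines its own loss internally."""
    def __init__(self, hidden=256):
        super().__init__()
        self.proj = torch.nn.Linear(hidden, 1)

    def forward(self, encoder_out, target_durations=None):
        pred = self.proj(encoder_out).squeeze(-1)
        if target_durations is not None:
            # the duration loss lives right here, inside the module
            rf.get_run_ctx().mark_as_loss(F.l1_loss(pred, target_durations), name="duration")
        return pred
```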

You can simply check on get_run_ctx() if the loss should be calculated

Do not forget the PyTorch way is to use self.training. This is set for all children recursively when calling Module.train() or Module.eval() and affects e.g. Dropout and BatchNorm. So this is completely fine to use, as e.g. in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/dropout.py#L59.

So we would not even need train_step in most cases. I still think it is good to have, in order to implement really custom training logic independent of the model, i.e. stuff that definitely will never be exported or saved to a model file.

Whatever we do, RETURNN should not impose any limits. If you want to define your losses in the forward/call somewhere deep in the model, it should be possible. If you want to collect all losses in the train_step, it should be possible. If you want to store extra stuff in a global context, it should be possible. If you want to pass everything manually in the forward function, well this is anyway always possible.

It should also be easy to do more custom stuff like turn-based training (https://github.com/espnet/espnet/blob/master/espnet2/train/gan_trainer.py#L147). But I think this already works now; you just return different losses based on the step. For really crazy stuff we could even allow every single function of the engine to be overridden by a function defined in the config. So if someone wants a custom train_epoch, why not...
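For turn-based training the idea would be roughly the following (a sketch; how the step count reaches train_step, here a hypothetical global_train_step kwarg, and the model.*_loss helpers are assumptions):

```python
import returnn.frontend as rf  # assuming the rf namespace

def train_step(*, model, data, global_train_step=0, **kwargs):
    # Sketch of turn-based (e.g. GAN-style) training: register different losses
    # depending on the step; generator_loss / discriminator_loss are hypothetical helpers.
    ctx = rf.get_run_ctx()
    if global_train_step % 2 == 0:
        ctx.mark_as_loss(model.generator_loss(data), name="generator")
    else:
        ctx.mark_as_loss(model.discriminator_loss(data), name="discriminator")
```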

@JackTemaki
Collaborator Author

Do not forget the PyTorch way is to use self.training. This is set for all children recursively when calling Module.train() or Module.eval() and affects e.g. Dropout and BatchNorm. So this is completely fine to use, as e.g. in https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/dropout.py#L59.

I have to correct myself here: as we want cross-validation in eval mode but WITH loss computation, we of course cannot use self.training to decide on loss computation. But it can be used for SpecAugment and so on.
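So, concretely, something like this (a sketch; the specaugment() call is just a placeholder for the actual augmentation code):

```python
import torch

class Encoder(torch.nn.Module):
    def __init__(self, feat_dim=80, hidden=512):
        super().__init__()
        self.linear = torch.nn.Linear(feat_dim, hidden)

    def forward(self, feats):
        if self.training:
            # self.training decides on data augmentation (just like Dropout/BatchNorm),
            # but NOT on loss computation, since cross-validation runs in eval mode
            # and still needs the losses.
            feats = specaugment(feats)  # placeholder for the actual SpecAugment implementation
        return self.linear(feats)
```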

@albertz
Member

albertz commented Mar 30, 2023

[to mark_as_loss() inside __call__] But I actually never really intended it to be done like this.

why? If I have a duration predictor module in my TTS I want the duration loss to be defined exactly in there. Otherwise I lose modularity.

I just meant that I never intended to use that in my setups if possible, and that's what I would also recommend. Although I also saw some potential exceptions for some unsupervised auxiliary losses or so. And I also was not 100% sure about it, and just wanted to gain some experience in actually using it.

But I actually always intended to have the possibility to define it wherever you want.

And you have that possibility in the current design via rf.get_run_ctx().mark_as_loss(...), which can be called at any place. So, I guess all is good then?

@JackTemaki
Collaborator Author

So, I guess all is good then?

Yes, although I personally do not like that the context is only available via the rf package, but this is a minor detail.

@albertz
Member

albertz commented Mar 31, 2023

I implemented that partly now.
