
Frontend API and PyTorch backend #1120

Open

albertz opened this issue Sep 12, 2022 · 200 comments
@albertz (Member) commented Sep 12, 2022

Edit: Originally, this issue was about a proof-of-concept for a new PyTorch backend in RETURNN.
This has since evolved into a whole new generic frontend API (original idea here: #1264), which closely follows the API of RETURNN-common nn and covers multiple backends. This API is accessible to the user via returnn.frontend, and the intended usage convention is:

import returnn.frontend as rf

class MyModel(rf.Module):
  def __call__(self, a: rf.Tensor, b: rf.Tensor):
    return rf.matmul(a, b, ...)

We also made sure that our Tensor class (earlier called Data) supports any raw tensor type (type of Tensor.placeholder, or Tensor.raw_tensor now) and is backend independent. This is all in returnn.tensor now.

Currently, the following backends are relevant:

  • RETURNN layers net dict backend for TF. The raw tensor type is NameCtx.
  • PyTorch. The raw tensor type is torch.Tensor.
  • TF directly / low-level. Raw tensor is tf.Tensor.

The terminology on frontend/backend is sometimes used a bit inconsistently. By frontend or frontend API, we basically mean what the user sees, i.e. what is in the returnn.frontend namespace: functions like reduce and dot, or also modules like Linear. This API is very much the same as RETURNN-common nn.

We have an abstract base class Backend, which defines some core functions of the API and allows them to be reimplemented for different backends. The derived classes are e.g. TorchBackend, ReturnnLayersBackend, TFBackend. The earlier terminology was maybe a bit confusing: they implement some frontend functions for the specific backend, so sometimes we referred to this as "different frontend (implementations)".

The Backend class and its different backend implementations are intended to be internal to RETURNN and not directly exposed to the user. The user has some helper functions to switch between backends.

There is also a lot of code which builds on top of the backend functions. E.g. the rf.Module class, modules like rf.Linear, and functions like cross_entropy are all independent of the backend.

The user in the end needs to define the following functions in the config:

def get_model(*, epoch: int, step: int, **_unused_kwargs) -> Union[rf.Module, torch.nn.Module]:
  ...

def train_step(*, model, extern_data: TensorDict):
  ...

def forward_step(*, model, extern_data: TensorDict):
  ...

train_step would only be used in training. Here the user should call mark_as_loss on some tensors.

forward_step would be used for recognition. Here the user should call mark_as_output on some tensors. See #1336 for more discussion on this, how it would be used then for beam search, forwarding, or whatever you want to do with the outputs.

To also support PyTorch modules more directly, get_model() can also return a torch.nn.Module. See #1120 (comment).

To add losses when using a raw torch.nn.Module, the API inside train_step would look like:

rf.get_run_ctx().mark_as_loss(..., loss_raw_torch_tensor, ...)

But this is intended only to be used for external code, and for our own code, we recommend the use of the RF. But in any case, it will be easy to mix pure PT and RF code together (with the PT backend).

To get access to the step inside train_step, there will be sth like rf.global_train_step().
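To make this more concrete, here is a minimal config sketch using the RF with the PyTorch backend. This is only an illustration, not code from this issue: the data keys ("data", "classes"), the dims, and the exact rf.cross_entropy / mark_as_loss signatures are assumptions following the RETURNN-common nn style API and may differ in detail.

import returnn.frontend as rf
from returnn.tensor import Tensor, TensorDict, Dim

backend = "torch"

feature_dim = Dim(40, name="feature")
classes_dim = Dim(10, name="classes")


class MyModel(rf.Module):
    """simple linear classifier, just for illustration"""

    def __init__(self):
        super().__init__()
        self.linear = rf.Linear(feature_dim, classes_dim)

    def __call__(self, x: Tensor) -> Tensor:
        return self.linear(x)


def get_model(*, epoch: int, step: int, **_unused_kwargs) -> rf.Module:
    return MyModel()


def train_step(*, model: MyModel, extern_data: TensorDict):
    data = extern_data["data"]  # assumed data key
    targets = extern_data["classes"]  # assumed target key
    logits = model(data)
    loss = rf.cross_entropy(target=targets, estimated=logits, estimated_type="logits", axis=classes_dim)
    loss.mark_as_loss("ce")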

Related:


Some summary of the current open discussions, items, or status updates:


Done:


This issue here is also to discuss and report on implementation details of the new PyTorch backend.

The initial step would be to just get a proof-of-concept, meaning the goals currently are:

  • Get some experience with PyTorch, some better ideas how to integrate PyTorch into RETURNN, etc.
  • Get training and inference with some basic model, using existing RETURNN datasets. It would be some hybrid NN-HMM for ASR. We can take existing PyTorch code, for example some Conformer encoder from ESPnet.

Getting it compatible with the TF backend is maybe the ultimate goal, but this is on a completely different level than the proof-of-concept discussed here.

For the dataset, we would implement a PyTorch IterableDataset to wrap any RETURNN dataset.
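A rough sketch of such a wrapper (hypothetical code, assuming the standard RETURNN Dataset API with init_seq_order / load_seqs / get_data; proper epoch handling and batching are omitted here):

from torch.utils.data import IterableDataset


class DatasetWrapper(IterableDataset):
    """Wraps a RETURNN dataset and yields single sequences as dicts of numpy arrays."""

    def __init__(self, returnn_dataset):
        self._dataset = returnn_dataset

    def __iter__(self):
        self._dataset.init_seq_order(epoch=1)  # proper epoch handling omitted in this sketch
        seq_idx = 0
        while self._dataset.is_less_than_num_seqs(seq_idx):
            self._dataset.load_seqs(seq_idx, seq_idx + 1)
            yield {key: self._dataset.get_data(seq_idx, key) for key in self._dataset.get_data_keys()}
            seq_idx += 1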

albertz added a commit that referenced this issue Sep 12, 2022
albertz added a commit that referenced this issue Sep 12, 2022
albertz added a commit that referenced this issue Sep 12, 2022
@albertz (Member Author) commented Sep 12, 2022

The current goal is that you can run python3 rnn.py demos/demo-torch.config and it starts the training.

albertz added a commit that referenced this issue Sep 12, 2022
@albertz (Member Author) commented Sep 12, 2022

Note that I created the config option backend, so you do backend = "torch" in the config to enable the PyTorch backend.
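E.g. a minimal config sketch (the demo config additionally defines the dataset and model; only the backend switch is shown here):

# in the RETURNN config
backend = "torch"
task = "train"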

I also created an initial returnn.torch package/directory, with a dummy Engine in engine.py.

Feel free to create a data_pipeline.py file when you start working on implementing the dataset and related code.

@albertz (Member Author) commented Sep 12, 2022

Btw, as we work in master, please make sure the tests are passing. There is no PyTorch test at all yet, so the only test relevant for us currently is about code inspections, like PEP8 etc. Please double check that there are no warnings on the code before you commit.

For commit messages, maybe prefix them with "PyTorch" or "Torch" or so.

@albertz (Member Author) commented Sep 12, 2022

For the main train loop, see the PyTorch quickstart tutorial (slightly simplified, adapted):

import torch
from torch import nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class NeuralNetwork(nn.Module):
  ...

model = NeuralNetwork().to(device)
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    ...

# epoch loop
epochs = 5
for t in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)

@albertz (Member Author) commented Sep 12, 2022

So, in the RETURNN Engine, we would also have the same loop over the batches:

    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

We would also have such a model instance. This is somewhat like the root TFNetwork, or the root module in returnn-common. It is necessary to have such a root module to get unique param names. See the corresponding discussion in returnn-common.

As I understand the PyTorch code, I think we need to have one single loss in the end, such that we have a single loss.backward() call. Multiple backward() calls would run backprop multiple times, which would be wasteful. But this is not really a problem: we just need to sum up all the losses you potentially have. We also do this for TF/Theano.

Maybe, like in returnn-common, we can have the same mark_as_loss API. So, some initial config API suggestions:

def get_model(*, epoch: int, **_unused_kwargs) -> torch.nn.Module:

Would create the model instance. I'm not sure if we really need to or should pass extern_data here, as you should have all the relevant information anyway in the config. But maybe it makes it simpler to write it? Originally I also thought about just providing class Model(torch.nn.Module), with the API that Model() should work.

def train_step(*, extern_data: ExternData, model: Model, ctx: TrainCtx):

Here you would do:

        pred = model(extern_data.data["inputs"].placeholder)
        y = extern_data.data["targets"].placeholder  # target labels (assuming a "targets" data key)
        loss = F.cross_entropy(pred, y)
        ctx.mark_as_loss(loss)

The TrainCtx provides this mark_as_loss method, which would just collect the losses. Maybe together with a name. Then in the engine, the train loop could look sth like this:

    model.train()
    ctx = TrainCtx()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        extern_data.set(X, y)  # ...

        pred = train_step(extern_data=extern_data, model=model, ctx=ctx)
        loss = ctx.total_loss()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
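A possible minimal sketch of such a TrainCtx, just to illustrate the collect-and-sum logic (hypothetical, not the final implementation):

from typing import List, Optional
import torch


class TrainCtx:
    """Collects losses reported via mark_as_loss, to be summed for a single backward() call."""

    def __init__(self):
        self.losses: List[torch.Tensor] = []

    def mark_as_loss(self, loss: torch.Tensor, name: Optional[str] = None):
        self.losses.append(loss)

    def total_loss(self) -> torch.Tensor:
        assert self.losses, "no losses were marked"
        return sum(self.losses)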

@patrick-wilken (Contributor) commented:

I'm implementing a dataset wrapper. The code is still too ugly for a pull request, so first a comment: 😄

The general interface is clear:

from torch.utils.data import IterableDataset


class DatasetWrapper(IterableDataset):
  def __init__(self, returnn_dataset):
    ...
    
  def __iter__(self):
    ...

And then in engine.train() something like:

from torch.utils.data import DataLoader

train_data = DatasetWrapper(self.train_dataset)

data_loader = DataLoader(train_data)

for batch_index, batch_data in enumerate(data_loader):
  ...

Most intuitive would be to let DatasetWrapper.__iter__() return a generator over single sequences and then do the batching via the arguments to DataLoader. However, for IterableDataset the batch_sampler and collate_fn arguments are not available. You can set DataLoader(train_data, batch_size=42), but that would really just put a constant number of sequences into a batch, not a constant number of time frames or any of the more sophisticated logic we want to have.
So I think DatasetWrapper.__iter__() already has to provide the batches. I'm currently trying to figure out how to best reuse the existing batching code, i.e. Dataset.generate_batches(), BatchSetGenerator, FeedDictDataProvider.get_next_batch() etc. Will continue tomorrow... 😄

@patrick-wilken (Contributor) commented Sep 14, 2022

@albertz, it looks to me like having a separate loss function in the training loop is not strictly necessary. The available loss functions output normal Tensors, so you could make the loss part of the Module if you want "all calculations to be equal". But that would rather be suitable when defining a network dict, not so much when loading an existing Module, because people normally don't do it that way...

@albertz (Member Author) commented Sep 14, 2022

No, the IterableDataset is supposed to return individual sequences, not batches.

I think the DataLoader is supposed to do that. Yes, I have read somewhere that it is a bit limited in functionality in that it only supports fixed size batches. But I'm sure there are other solutions, other DataLoaderExt implementations or whatever. Maybe just check how other frameworks like Fairseq have done this.

@albertz (Member Author) commented Sep 14, 2022

it looks to me that having a separate loss function in the training loop is not strictly necessary.

I never said that? See my suggestion on how to do it, i.e. using the TrainCtx.

So you could have the loss as part of the Module

This is now totally up to the user. You could either totally decouple it, i.e. your model just is the model, without losses, and then you have separate code to calculate the losses. Or you could mix it together, define the losses already within the model (when you pass TrainCtx to it). In returnn-common, we have the same situation now.

@albertz (Member Author) commented Sep 14, 2022

Btw, I don't like having PRs in this early phase of development. PRs just slow it down. I think we can all directly work in the master. PRs are there to avoid breaking anything, and to discuss and review the code. We cannot really break anything, as it is not functional anyway. And for discussion, we can also discuss right here. And for review, you should feel responsible to do that anyway.

If sth is unclear, better discuss here before starting to implement anyway. If we have some rough understanding on how to implement it, we should just start, and then iterate on the code, in master directly.

@Icemole (Collaborator) commented Sep 15, 2022

I checked the implementation of datasets for Fairseq and k2 (lhotse repo). Actually, both inherit from torch.utils.data.Dataset as opposed to torch.utils.data.IterableDataset, so both are map-style datasets. I've seen some tendency from the community to choose map-style datasets over iterable-style datasets even for big datasets, for instance here and here. The arguments are that the DataLoader can handle the batches and iteration over the dataset (or one can do it in a custom way, see k2 below), that memory management can still be efficient with a Dataset, and that additional things have to be taken care of when using IterableDataset (see here, example 2).

k2

The __getitem__() method of the K2SpeechRecognitionDataset class directly returns a batch, which could allow for "custom" batches. Because of this, the DataLoader doesn't actually compute any batches, so when declaring it, batch_size must be equal to None. They also use a custom sampler defined here. An example usage as shown in the last link is:

dataset = K2SpeechRecognitionDataset(cuts)
sampler = SimpleCutSampler(cuts, shuffle=True)  # cuts == input data + text + potentially useful data
loader = DataLoader(dataset, sampler=sampler, batch_size=None)
for epoch in range(start_epoch, n_epochs):
    sampler.set_epoch(epoch)  # Only for shuffling purposes
    train(loader)

I wasn't able to find the definition of __len__() for that class though. Maybe if the data is never indexed, its definition can be avoided?

Fairseq

For Fairseq I focused on the SpeechToTextDataset subclass of FairseqDataset. The __getitem__() method returns an object of the class SpeechToTextDatasetItem, which also has the source audio and the target text. The collater() method is in charge of creating the batches. I assume it's the collate_fn passed to the DataLoader (maybe through *EpochBatchIterator classes), but I couldn't find any evidence of it.

The __len__() method simply returns the number of audios, calculated when the dataset is initialized.

In the LibriSpeech data preparation example, the data is processed and saved to a .tsv file, which is then loaded in the from_tsv() method when executing $ fairseq-train.

Fairseq also has a FairseqIterableDataset class which inherits from torch.utils.data.IterableDataset, but it doesn't seem to be used anywhere.

@albertz (Member Author) commented Sep 15, 2022

In general, I would trust the Fairseq developers more, esp w.r.t. how to use PyTorch in the best way. So, let's follow how they do it.

@Icemole (Collaborator) commented Sep 15, 2022

So this means that we should implement DatasetWrapper as inheriting from torch.utils.data.Dataset instead of torch.utils.data.IterableDataset?

I have found two short tutorials on how to use a big enough dataset: this short tutorial from the Stanford University on efficiently obtaining data to the dataset from disk by using a torch.utils.data.Dataset, and this Medium post whose comments I think have good insight. I leave them here in case they're useful in any way.

@albertz (Member Author) commented Sep 15, 2022

When wrapping the RETURNN dataset, I don't think any of these tutorials can be applied. You can just use the RETURNN Dataset API, nothing else.

The Torch map-style dataset unfortunately does not fit RETURNN too well, as most RETURNN datasets do not really allow for efficient random access. However, I think HDFDataset actually should be fine. If it works with that, this is ok for now.

We can also do both torch.utils.data.Dataset and torch.utils.data.IterableDataset and then let the user decide. I'm still sure that even with torch.utils.data.IterableDataset, you can do everything we need.

Remember that for the beginning, we just want to have some proof-of-concept. It's totally ok if we have static batch sizes for that.

@patrick-wilken (Contributor) commented:

The interface of RETURNN datasets is iterable-style. I would wrap that first, so Dataset instead of IterableDataset wouldn't really fit here. Maybe later a map-style wrapper for HDFDataset or something like that would be nice though.

As Nahuel said, Fairseq does not use IterableDataset. I found one usage in huggingface/transformers: there they use one instance to provide sequences and then another instance which wraps the first one and does the batching: https://github.com/huggingface/transformers/blob/main/src/transformers/trainer.py#L851
That sounds like a good idea to me.
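A rough sketch of that idea, i.e. a second IterableDataset which wraps the sequence-level one and yields batches (hypothetical; the max_seqs / max_frames criteria are just placeholders for whatever batching logic we actually want):

from torch.utils.data import IterableDataset


class BatchingIterableDataset(IterableDataset):
    """Wraps a sequence-level IterableDataset and yields lists of sequences as batches."""

    def __init__(self, dataset: IterableDataset, max_seqs: int = 32, max_frames: int = 20000):
        self._dataset = dataset
        self._max_seqs = max_seqs
        self._max_frames = max_frames

    def __iter__(self):
        batch, frames = [], 0
        for seq in self._dataset:
            seq_len = len(seq["data"])  # assuming each sequence is a dict with a "data" key
            if batch and (len(batch) >= self._max_seqs or frames + seq_len > self._max_frames):
                yield batch
                batch, frames = [], 0
            batch.append(seq)
            frames += seq_len
        if batch:
            yield batch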

@albertz (Member Author) commented Sep 21, 2022

I moved the init_seq_order into the DatasetWrapper.__iter__. I think this is cleaner.

@albertz (Member Author) commented Sep 21, 2022

python3 rnn.py demos/demo-torch.config runs now. I.e. it constructs the net, calcs and optimizes the loss.

Many things are missing. The model is always random in each epoch. No storing/saving implemented yet.

@albertz (Member Author) commented Sep 21, 2022

In this current state, I don't use ExternData. Probably we want to introduce this. Not sure.

But for getting a first proof-of-concept, we maybe also don't need it yet.

@albertz (Member Author) commented Oct 17, 2022

Current state:

  • PR PyTorch: save/load models #1137, done.
  • We discussed that it is probably simpler for everyone to not open PRs but directly push to master, as long as the code changes are just about the PyTorch backend and do not touch any other code.
    • We should still discuss about how to implement things, better before we actually do it, at least on a high level. We can discuss right here in this issue.
    • We should still review the code changes made by others. In case there are problems, these can be discussed here as well.
    • We don't want to break the CI tests, i.e. the code style must be correct. For that, we can either just use PyCharm for editing, or run the code style checks locally. @patrick-wilken had some issues with that. In case these do not resolve, let's open a separate issue and discuss the problem there.
    • Once you touch other code, e.g. refactor some TF code (e.g. move out Data/ExternData or whatever), this must be via PR.
  • We want a hybrid NN-HMM for now. Either a multi-layer BLSTM (via torch.nn.LSTM) or a Conformer, just using code from ESPnet.
  • For inference, we want to export to ONNX. It's a bit unclear how that looks like for the user. We might use extern_data to define the inputs, and have a new model_outputs just like extern_data to define the model outputs (as Data templates), and a model_forward function, which gets those inputs and returns such outputs.

@albertz (Member Author) commented Oct 21, 2022

For inference, we want to export to ONNX. It's a bit unclear how that looks like for the user. We might use extern_data to define the inputs, and have a new model_outputs just like extern_data to define the model outputs (as Data templates), and a model_forward function, which gets those inputs and returns such outputs.

This basically means that we should also abstract ExternData, Data and Dim and remove TF specific code there and move it to a common location. -> #1165 (Done now)

Such model_outputs can also be useful for the TF backends, e.g. for the compile-TF-graph script, to clearly define what outputs are expected. -> #1166
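As a sketch only (the exact option names and Data template syntax here are assumptions, not a fixed API), such a specification in the config could look like this:

# extern_data defines the inputs, model_outputs the outputs, both as Data templates
extern_data = {
    "data": {"dim": 80, "shape": (None, 80)},  # e.g. audio features [B, T, 80], T dynamic
}
model_outputs = {
    "log_probs": {"dim": 5000, "shape": (None, 5000)},  # e.g. label log-probs [B, T, 5000]
}


def model_forward(*, model, extern_data):
    # gets the inputs defined by extern_data and returns outputs matching the model_outputs templates
    ...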

albertz added a commit that referenced this issue Apr 18, 2023
albertz added a commit that referenced this issue Apr 18, 2023
#1120

Copied from RETURNN-common.
rwth-i6/returnn_common#252

On namespace (rf.encoder.conformer.Conformer), see:
#1120 (comment)

This is currently untested.
@albertz (Member Author) commented Apr 18, 2023

The Conformer seems to work now. At least test_conformer passes.

@albertz (Member Author) commented Apr 18, 2023

@Icemole Thanks for the LSTM implementation. Unfortunately, there are many issues with it. I'm just writing to let you know. I will just fix this now. Please also see my changes in my commits. Please ask directly on Slack if anything is unclear. I'm just writing here in case other people wonder.

  • To properly handle hidden state, we need to use PackedSequence, because otherwise we get the wrong hidden state.

  • The hidden state seems to be a tuple (h,c), each of shape (num_layers * num_directions, batch_size, hidden_size). You have the first dimension wrong/missing.

  • We discussed that we want to have a specific LstmState, not just State. (Frontend API and PyTorch backend #1120 (comment))

  • (For the Backend interface, I did not use State nor LstmState because it is a functional API anyway and I did not want to clutter the Backend interface too much. But this is maybe debatable.)

  • You are confusing in_dim with spatial_dim. And you are not properly passing the spatial_dim. You forgot this argument in the signature. (Well, your spatial_dim in the signature is actually the in_dim.) (Also see: Frontend API and PyTorch backend #1120 (comment) and the RC code.)

  • Your logic for index_in_dim is too complicated. You can just use get_axis_from_description. As said, please see the other code, and if anything (any single line of code) is not clear to you, please clarify.

  • Do not use is_batch_dim() to check for the batch dim. For the RF, any dim which is not one of operating/specified dims (in the LSTM case: spatial dim, feature/in dim for input, spatial dim, feature/out dim for output) is a batch dim. There can be one, multiple, or zero batch dims. The code is simply batch_dims = [d for d in source.dims if d != spatial_dim and d != in_dim].

  • You are not transposing the inputs (source, state) to the right dim order. In fact, you just leave the dim order as it is. This is wrong in general. For many Torch ops, you cannot do it like this.

  • You have one (incorrect) check for batch_first, but this ignores the order of the other dims.

  • In case of LSTM, time-major should be more efficient in any case. So do not leave it as batch-major, if it is like that. Just transpose it into the order which is most efficient.

  • In the output, the order should be left as-is, so it is wrong to use copy_template() on the original (untransposed) input.

  • You put this into a file lstm.py. But I would follow RC here and call it rec.py. (In PyTorch, it's called rnn.py.)

  • In RC, the LSTM supports operating both on a single step and on a sequence. This is determined by having spatial_dim be either single_step_dim or a real dim in the source. See returnn_common#81 ("Unify rec ...Step variants with on-seq variants?") for some discussion. I think we should just follow the same API, also as we agreed that we would just copy the RC API.

  • Why do you use plural (ff|rec)_biases? It is only a single bias, or not? I would also keep weight in singular. Keep it consistent with Linear or other modules/functions.

  • Why do you need the rec bias? Having only the ff_bias should be enough, or not? And just call it bias, not ff_bias then. Extending on that though: Maybe they implemented the gradient incorrectly, and the rec bias only gets the gradient from the rec part, and the ff bias only gets the gradient from the ff part? If this is the case, you can use 0.5 * bias for both the rec bias and ff bias and then it should be correct.

  • You should document the shape (expected dims), esp of the parameters. And also the order of parts in the weight matrices.

  • Put * into the signature to not leave all as positional args. (doc)

  • Use mark_as_default_output with specifying an explicit shape to define the dim order.

  • Do not import things directly from the returnn.frontend namespace, e.g. like State. Instead, just refer to it via rf.State.

  • It is 4 * out_dim instead of out_dim * 4. This is a bit subtle, but when you would split the dims up in the 4 parts, you want it this way. See also the RC code.

  • You write:

    Feed-forward has priority over recurrent, and weights have priority over biases. See the torch docstring
    or torch LSTMCell: https://github.com/pytorch/pytorch/blob/4bead64/aten/src/ATen/native/RNN.cpp#L1458

    I don't see any docstring on your given link. Did you put the wrong link? Where is the docstring? I don't really see from there how you get to that lstm_params order. I also tried to check the torch.nn.LSTM code but I did not fully grasp the flatten param logic.

albertz reopened this Apr 18, 2023
albertz added a commit that referenced this issue Apr 18, 2023
albertz added a commit that referenced this issue Apr 18, 2023
@albertz (Member Author) commented Apr 19, 2023

Now the LSTM also works with the TF-net-dict backend.

@albertz (Member Author) commented Apr 19, 2023

Next TODOs:

@Icemole (Collaborator) commented May 9, 2023

I just pushed the first version of the ONNX exporting tool, so that you can check it out.

I'm not sure how usable it is: I checked the demos, and it couldn't export the convolutional networks from either the torch-demo or the rf-demo (this is a problem in PyTorch, see for instance this SO post). However, it was able to export the same nets when I substituted the convolutional layers with linear layers in both demos.

Anyway, it still needs more work (check the TODOs, also testing with more modules) and to be tidied up (removing comments/unnecessary code). Mainly:

  1. What to do if the user doesn't define the kind attribute of dimensions properly? For instance, in the rf-demo there's no reference to time_dim as a spatial dimension, or in_dim as a feature dimension. In the current version I add additional checks to know the dimension I'm treating (batch/time/dim), but I think this should be best handled by properly setting the Dim.kind attribute.
  2. Somewhat related with the point above, some things are hardcoded, like dynamic axes. We'd need a way to automatically infer dynamic dimensions, but I don't know if we can do it when we can't use the functions stated in the point above because the proper Dim.kind attributes aren't set.
  3. I created a _RFTensorAsPTTensor to work in a similar way to how _RFModuleAsPTModule does. This might be avoided, but as of right now I'm not completely sure how; I'll think about it further when the tool is completed.

There might be some more issues that arise when looking at the code, but as I said this is a very preliminary version. Feel free to post some feedback here or on Slack.

@albertz (Member Author) commented May 9, 2023

torch-demo and rf-demo (this is a problem in PyTorch, see for instance this SO post).

Did you link the right SO post? Can you go into more detail on what the problems are? I.e. post error etc. Maybe link a Gist if it is too much here.

What to do if the user doesn't define the kind attribute of dimensions properly? For instance, in the rf-demo there's no reference to time_dim as a spatial dimension, or in_dim as a feature dimension.

I don't really understand. Why does the kind matter? The kind actually should not matter. It is important that the kind does not have any influence at all. It is purely for cosmetic purposes, nothing else. So it is important that nothing can break when the user does not set it.

In the current version I add additional checks to know the dimension I'm treating (batch/time/dim), but I think this should be best handled by properly setting the Dim.kind attribute.

No. You never ever should access or depend on the kind.

I don't really know what additional checks you mean. The user knows the dimensions because the user specifies them. The ONNX export script does not need to know anything additionally over what the user specifies.

Somewhat related with the point above, some things are hardcoded, like dynamic axes. We'd need a way to automatically infer dynamic dimensions, but I don't know if we can do it when we can't use the functions stated in the point above because the proper Dim.kind attributes aren't set.

As said, ignore Dim.kind.

Why not just check which dimension is dynamic? You can simply check is_dynamic(). Nothing should be hardcoded. You would just check is_dynamic(). This is again up to the user, to specify it correctly.
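For example (a sketch, not the actual export script code), the dynamic_axes argument for torch.onnx.export could be derived directly from the Tensor templates via is_dynamic():

def get_dynamic_axes(tensor):
    """Tensor (template) -> {axis index: axis name} for all dynamic dims, usable as torch.onnx.export dynamic_axes."""
    return {i: f"{tensor.name}_dyn_{i}" for i, dim in enumerate(tensor.dims) if dim.is_dynamic()}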

I created a _RFTensorAsPTTensor to work in a similar way that _RFModuleAsPTModule does. This might be avoided, but as of right now I'm not completely sure how; I'll think further when the tool is completed.

Don't do this. I don't really understand the purpose. Also, it is clear that this cannot work. You must get raw PyTorch tensors as input, and return raw PyTorch tensors. Just remove it and use raw PyTorch tensors.

Edit: I fixed most of the issues in the script now. I am also using model_outputs (#1166) to have a clear specification of the output. I think we need that because when we call export, we must already know the output names.

@Icemole (Collaborator) commented May 10, 2023

Did you link the right SO post? Can you go into more detail on what the problems are? I.e. post error etc. Maybe link a Gist if it is too much here.

The error reads:

torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::_convolution_mode' to ONNX opset version 14 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues

So I don't think it's possible to export convolutional nets as of right now. I can provide a stack trace if you want, but from this information and the SO post I can infer that the padding="same" argument is making the export fail. Because of this, people have usually implemented the equivalent of padding="same" in their own code.

As a tangential point, please have a look at the list of supported operators for exporting. Most notably, aten::_ctc_loss, aten::fft_* and many operations from the _quantized namespace aren't supported (@curufinwe please check; does this interfere with your plans of using the ONNX exporting tool?).

Thanks for the feedback. I've reviewed the commit and left a couple of comments. Most importantly, regarding the usage of is_batch_dim().

The implementation with _RFTensorAsPTTensor seemed to work, but I agree that it is undesirable when we can just use plain torch.Tensor directly.

@albertz (Member Author) commented May 10, 2023

So I don't think it's possible to export convolutional nets as of right now.

We discussed this already. You sent me this other SO post some days ago about exactly this. The problem is that padding="same" does not seem to be supported currently in ONNX. But we can just fix this in the backend, as I already explained: we need some good way inside the backend to check whether this is run for ONNX export, and in that case, we might have some custom code in some cases, e.g. for this case now. It's easy to do the padding ourselves and then set padding=0, and then it would work. But we would only do this for ONNX; otherwise we use the standard PyTorch code.
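A minimal sketch of that workaround for the 1D case (the onnx_export flag here is just a hypothetical way to signal the export mode; the odd/even kernel handling follows the usual "same" convention):

import torch.nn.functional as F


def conv1d_same(x, weight, bias=None, *, onnx_export=False):
    """Conv1d with 'same' padding; pads manually when exporting to ONNX."""
    kernel_size = weight.shape[-1]
    if onnx_export:
        left = (kernel_size - 1) // 2
        right = kernel_size - 1 - left
        x = F.pad(x, (left, right))  # pad the time axis explicitly
        return F.conv1d(x, weight, bias, padding=0)
    return F.conv1d(x, weight, bias, padding="same")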

As a tangential point, please have a look at the list of supported operators for exporting. Most notably, aten::_ctc_loss, aten::fft_* and many operations from the _quantized namespace aren't supported (@curufinwe please check; does this interfere with your plans of using the ONNX exporting tool?).

Some of them we could reimplement, maybe also with slower variants. We could also try to fix them directly on the ONNX side? The FFT functions are used by us for the feature extraction. It looks like quite a long list. I wonder a bit about quantization; I thought this is especially well supported with ONNX? I would say we handle cases individually when we stumble upon them.

@albertz (Member Author) commented May 10, 2023

Thanks for the feedback. I've reviewed the commit and left a couple of comments. Most importantly, regarding the usage of is_batch_dim().

I already commented (commit b9631be), but maybe I copy that here as well, so that it is easier to find, as this is maybe an important comment:

In all code of the backend, and also in RF, is_batch_dim() is never correct, because the batch dims are always those dims which are not specified. What is really a batch dim? The meaning here is different: In the backend, it refers to those dims which are not relevant for some op, so where it would apply the same op on all entries. That is the only meaning in the backend, when you read sth like "other dims are treated as batch dims". It does not matter what is_batch_dim() returns.

However, is_batch_dim(), and the global batch_dim object, and the batch dim in ExternData, in BatchInfo, in collate_batch, etc, that has a different meaning: It means the multiple sequences from the dataset which are put together into a mini batch.

For both those two distinct meanings, it can happen to refer to the same dim, and in practice, it often will also refer to the same dim (e.g. for an LSTM with input shape [B,T,F], this B usually means both things), but not necessarily, and in the backend, or anywhere in the RF, it would be wrong to make this assumption.

So this global batch dim, from the mini batch, this is kind of a singleton. is_batch_dim() is supposed to return True only for this. (It gets a bit more complicated in the TF-net-dict with some more special cases, but those should not matter for the RETURNN frontend.)

You are right, it checks the kind internally, which is also bad, but we assume there is really only this one dimension where this is True. But this is a historic implementation detail which will probably also be cleaned up at some later point.

Further, this global batch dim tag is treated a bit differently from other dim tags, again for historical reasons. We probably should clean this up. E.g. dyn_size_ext should normally be defined when the dim is dynamic, as it normally is, but this is often not defined properly for the batch dim. For the RF, there is some code in place to actually clean this up.

@albertz (Member Author) commented May 16, 2023

Some new milestone achieved: after I already verified that the Conformer encoder works correctly in the RF with the PyTorch backend (by comparing it to @mmz33's pure TF-net-dict implementation), I have now also verified that I get exactly the same decoder outputs, including the final label logits (again by comparing it to @mmz33's pure TF-net-dict AED implementation), again with the RF with the PyTorch backend.
