Frontend API and PyTorch backend #1120
The current goal is that you can run …
Note that I created the config option …. I also created an initial …. Feel free to create a ….
Btw, as we work in master, please make sure the tests are passing. There is no PyTorch test at all yet, so the only test relevant for us currently is about code inspections, like PEP8 etc. Please double-check that there are no warnings in the code before you commit. For commit messages, maybe prefix them with "PyTorch" or "Torch" or so.
For the main train loop, see the PyTorch quickstart tutorial (slightly simplified, adapted):

```python
device = "cuda" if torch.cuda.is_available() else "cpu"

class NeuralNetwork(nn.Module):
    ...

model = NeuralNetwork().to(device)

loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    ...

# epoch loop
epochs = 5
for t in range(epochs):
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
```
So, in the RETURNN engine, we would have such a train loop:

```python
for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    ...
```

We also would have such a …. As I understand the PyTorch code, I think we need to have one single …. Maybe, like in returnn-common, we can have the same … in the config:

```python
def get_model(*, epoch: int, **_unused_kwargs) -> torch.nn.Module: ...
```

This would create the …. Further:

```python
def train_step(*, extern_data: ExternData, model: Model, ctx: TrainCtx): ...
```

Here you would do:

```python
pred = model(extern_data.data["inputs"].placeholder)
loss = F.cross_entropy(pred, y)
ctx.mark_as_loss(loss)
```

The engine-side train function would then look sth like:

```python
model.train()
ctx = TrainCtx()
for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(device), y.to(device)
    extern_data.set(X, y)  # ...
    pred = train_step(extern_data=extern_data, model=model, ctx=ctx)
    loss = ctx.total_loss()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```
I'm implementing a dataset wrapper. The code is still too ugly for a pull request, so first a comment: 😄 The general interface is clear: …

And then in ….

Most intuitive would be to let ….
@albertz, it looks to me that having a separate loss function in the training loop is not strictly necessary. The available loss functions output normal Tensors, so you could have the loss as part of the Module if you want "all calculations to be equal". But that would rather be suitable when defining a network dict, not so much when loading an existing Module, because people normally don't do it that way...
No, the …. I think the ….
I never said that? See my suggestion on how to do it, i.e. using the ….
This is now totally up to the user. You could either totally decouple it, i.e. your model just is the model, without losses, and then you have separate code to calculate the losses. Or you could mix it together and define the losses already within the model (when you pass …).
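For illustration, a minimal sketch of the two options (hypothetical code, not from this thread; `ctx.mark_as_loss` follows the `TrainCtx` interface sketched earlier):

```python
import torch
from torch import nn
import torch.nn.functional as F

# Option 1: decoupled -- the model is just the model; losses are computed outside.
class Classifier(nn.Module):
    def __init__(self, dim_in: int, num_classes: int):
        super().__init__()
        self.linear = nn.Linear(dim_in, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)  # logits only, no loss

def train_step_decoupled(*, model: Classifier, x: torch.Tensor, y: torch.Tensor, ctx):
    logits = model(x)
    ctx.mark_as_loss(F.cross_entropy(logits, y))  # loss defined outside the model

# Option 2: mixed -- the model marks its own losses when targets and ctx are passed in.
class ClassifierWithLoss(Classifier):
    def forward(self, x: torch.Tensor, y: torch.Tensor = None, ctx=None) -> torch.Tensor:
        logits = self.linear(x)
        if y is not None and ctx is not None:
            ctx.mark_as_loss(F.cross_entropy(logits, y))
        return logits
```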
Btw, I don't like having PRs in this early phase of development. PRs just slow it down. I think we can all directly work in master. PRs are there to avoid breaking anything, and to discuss and review the code. We cannot really break anything, as it is not functional anyway. And for discussion, we can also discuss right here. And for review, you should feel responsible to do that anyway. If sth is unclear, better discuss it here before starting to implement. If we have some rough understanding of how to implement it, we should just start, and then iterate on the code, in master directly.
I checked the implementation of datasets for Fairseq and k2 (lhotse repo). Actually, both inherit from …. The ….
I wasn't able to find the definition of Fairseq ….

For Fairseq I focused on the …. The …. In the LibriSpeech data preparation example, the data is processed and saved to a …. Fairseq also has a ….
In general, I would trust the Fairseq developers more, especially w.r.t. how to use PyTorch in the best way. So, let's follow how they do it.
So this means that we should implement …. I have found two short tutorials on how to use a big enough dataset: this short tutorial from Stanford University on efficiently loading data from disk by using a ….
When wrapping the RETURNN dataset, I don't think any of these tutorials can be applied. You can just use the RETURNN Dataset API, nothing else. The Torch map-style dataset unfortunately does not fit too well to RETURNN, as most RETURNN datasets do not really allow for efficient random access. However, I think …. We can also do both ….

Remember that for the beginning, we just want to have some proof-of-concept. It's totally ok if we have static batch sizes for that.
The interface of RETURNN datasets is iterable style. I would first wrap that and …. As Nahuel said, Fairseq does not use ….
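To make the iterable-style direction concrete, here is a minimal sketch of wrapping a RETURNN dataset as a PyTorch `IterableDataset`, with a simple padding `collate_fn` (hypothetical wrapper code; it assumes the standard RETURNN `Dataset` methods `init_seq_order`, `is_less_than_num_seqs`, `load_seqs` and `get_data`, and is not the actual implementation):

```python
import numpy
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, IterableDataset

class ReturnnIterableDataset(IterableDataset):
    """Iterates over a RETURNN dataset and yields one dict of numpy arrays per sequence."""

    def __init__(self, returnn_dataset, data_keys=("data", "classes"), epoch=1):
        self.dataset = returnn_dataset
        self.data_keys = data_keys
        self.epoch = epoch

    def __iter__(self):
        self.dataset.init_seq_order(epoch=self.epoch)
        seq_idx = 0
        while self.dataset.is_less_than_num_seqs(seq_idx):
            self.dataset.load_seqs(seq_idx, seq_idx + 1)
            yield {key: numpy.asarray(self.dataset.get_data(seq_idx, key)) for key in self.data_keys}
            seq_idx += 1

def collate(batch):
    # Pad variable-length sequences to the max length within the batch (static shapes per batch).
    return {
        key: pad_sequence([torch.as_tensor(item[key]) for item in batch], batch_first=True)
        for key in batch[0]
    }

# Usage sketch:
# loader = DataLoader(ReturnnIterableDataset(returnn_dataset), batch_size=8, collate_fn=collate)
```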
I moved the ….
Many things are missing. The model is always randomly initialized anew in each epoch; no checkpoint storing/loading is implemented yet.
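For reference, the checkpoint handling that is still missing would in plain PyTorch look roughly like this (a generic sketch, not the RETURNN implementation):

```python
import torch

def save_checkpoint(model, optimizer, epoch, filename):
    # Store model parameters and optimizer state so training can resume at the given epoch.
    torch.save(
        {"epoch": epoch, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
        filename,
    )

def load_checkpoint(model, optimizer, filename):
    checkpoint = torch.load(filename, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["epoch"]
```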
In this current state, I don't use …. But for getting a first proof-of-concept, we maybe also don't need it yet.
Current state: …
This basically means that we should also abstract …. Such ….
#1120: Copied from RETURNN-common (rwth-i6/returnn_common#252). On the namespace (rf.encoder.conformer.Conformer), see #1120 (comment). This is currently untested.
The Conformer seems to work now. At least ….
@Icemole Thanks for the LSTM implementation. Unfortunately, there are many issues with it. I'm just writing to let you know. I will just fix this now. Please also see my changes in my commits. Please ask directly on Slack if anything is unclear. I'm just writing here in case other people wonder.
Now the LSTM also works with the TF-net-dict backend.
Next TODOs: …
I just pushed the first version of the ONNX exporting tool, so that you can check it out. I'm not sure how usable it is: I checked the demos, and it couldn't export the convolutional networks from either the torch-demo or the rf-demo (this is a problem from PyTorch, see for instance this SO post). However, it was able to export the same nets when I replaced the convolutional layers with linear layers in both demos. Anyway, it still needs more work (check the TODOs, also testing with more modules) and to be tidied up (removing comments/unnecessary code). Mainly: …
There might be more issues that arise when looking at the code, but as I said, this is a very preliminary version. Feel free to post feedback here or on Slack.
Did you link the right SO post? Can you go into more detail on what the problems are? I.e. post error etc. Maybe link a Gist if it is too much here.
I don't really understand. Why does the …?
No. You never ever should access or depend on the …. I don't really know what additional checks you mean. The user knows the dimensions because the user specifies them. The ONNX export script does not need to know anything beyond what the user specifies.
As said, ignore …. Why not just check what dimension is dynamic? You can simply check ….
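For context, marking dynamic dimensions with plain `torch.onnx.export` looks like this (a generic sketch with assumed model and axis names, not the RETURNN export tool itself):

```python
import torch

model = torch.nn.Linear(40, 10)  # stand-in for the network to export
dummy_input = torch.randn(1, 100, 40)  # [batch, time, feature]

torch.onnx.export(
    model,
    (dummy_input,),
    "model.onnx",
    input_names=["data"],
    output_names=["output"],
    # Axes listed here stay dynamic in the exported graph; all other axes become fixed.
    dynamic_axes={"data": {0: "batch", 1: "time"}, "output": {0: "batch", 1: "time"}},
)
```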
Don't do this. I don't really understand the purpose. Also, it is clear that this cannot work. You must get raw PyTorch tensors as input, and return raw PyTorch tensors. Just remove it and use raw PyTorch tensors.

Edit: I fixed most of the issues in the script now. I am also using ….
The error reads: …
So I don't think it's possible to export convolutional nets as of right now. I can provide a stack trace if you want, but from this information and the SO post I can infer that ….

As a tangential point, please have a look at the list of supported operators for exporting. Most notably, ….

Thanks for the feedback. I've reviewed the commit and left a couple of comments. Most importantly, regarding the usage of …. The implementation with ….
We discussed this already. You sent me this other SO post some days ago about exactly this. The problem is that ….
Some of them we could reimplement, maybe also with slower variants. We could also try to directly fix them on the ONNX side? The FFT functions are used by us for the feature extraction. It looks like quite a long list. I wonder a bit about quantization; I thought this is especially well supported with ONNX? I would say we handle cases individually when we stumble upon them.
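As an example of such a slower reimplementation, a feature-extraction FFT can be replaced by an explicit DFT basis, so the graph only contains matmul/unfold ops that ONNX export handles (a sketch with assumed helper names, not existing RETURNN code):

```python
import math
import torch

def dft_basis(n_fft: int):
    # One-sided DFT basis; real and imaginary parts, each of shape [n_fft, n_fft // 2 + 1].
    t = torch.arange(n_fft, dtype=torch.float32).unsqueeze(1)
    k = torch.arange(n_fft // 2 + 1, dtype=torch.float32).unsqueeze(0)
    angle = -2.0 * math.pi * t * k / n_fft
    return torch.cos(angle), torch.sin(angle)

def power_spectrogram_no_fft(signal: torch.Tensor, n_fft: int = 512, hop: int = 160):
    # signal: [batch, time]. Returns [batch, frames, n_fft // 2 + 1] without any torch.fft op.
    window = torch.hann_window(n_fft)
    frames = signal.unfold(1, n_fft, hop) * window  # [batch, frames, n_fft]
    cos_b, sin_b = dft_basis(n_fft)
    real, imag = frames @ cos_b, frames @ sin_b
    return real ** 2 + imag ** 2
```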
I already commented (commit b9631be), but maybe I copy that here as well, so that it is easier to find, as this is maybe an important comment: In all code of the backend, and also in the RF, …. However, ….

For both those two distinct meanings, it can happen to refer to the same dim, and in practice it often will (e.g. for an LSTM with input shape [B,T,F], this B usually means both things), but not necessarily, and in the backend, or anywhere in the RF, it would be wrong to make this assumption. So this global batch dim, from the mini batch, is kind of a singleton. You are right, it checks the ….

Further, this global batch dim tag is treated a bit differently from other dim tags, again for historical reasons. We probably should clean this up. E.g. ….
Some new milestone achieved: after I already verified that the Conformer encoder works correctly in the RF with the PyTorch backend (by comparing it to @mmz33's pure TF-net-dict implementation), I have now also verified that I get exactly the same decoder outputs, including the final label logits (again by comparing to @mmz33's pure TF-net-dict AED implementation), again with the RF with the PyTorch backend.
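Such a cross-implementation check boils down to comparing the raw outputs numerically; a trivial sketch (tolerances and the way the two outputs are obtained are assumptions):

```python
import numpy

def assert_outputs_match(logits_tf, logits_pt, rtol=1e-5, atol=1e-5):
    # Both arrays, e.g. [batch, time, vocab], computed from identical inputs and parameters.
    numpy.testing.assert_allclose(logits_tf, logits_pt, rtol=rtol, atol=atol)
```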
Edit: Originally, this issue was about a proof-of-concept for a new PyTorch backend in RETURNN.

This has somehow evolved into a whole new generic frontend API (original idea here: #1264), which very much follows the API from RETURNN-common `nn`, and covers multiple backends. This API is accessible for the user via `returnn.frontend`, and the convention for the user to use it would be like: …

We also made sure that our `Tensor` class (earlier called `Data`) supports any raw tensor type (type of `Tensor.placeholder`, or `Tensor.raw_tensor` now) and is backend independent. This is all in `returnn.tensor` now.

Currently, the following backends are relevant:

- The RETURNN layers net dict backend (TF), internally using `NameCtx`.
- The PyTorch backend, where the raw tensor type is `torch.Tensor`.
- The TF (low-level) backend, where the raw tensor type is `tf.Tensor`.

The terminology on frontend/backend is sometimes used a bit inconsistently. We mean frontend or frontend API basically to describe what the user sees, what you have in the `returnn.frontend` namespace, i.e. functions like `reduce` and `dot`, or also modules like `Linear`. This API is basically very much the same as RETURNN-common `nn`.

We have an abstract base class `Backend`, which defines some core functions of the API and allows reimplementing it for different backends. The derived classes are e.g. `TorchBackend`, `ReturnnLayersBackend`, `TFBackend`. The earlier terminology was maybe a bit confusing: they implement some frontend functions for the specific backend, so sometimes we referred to this as "different frontend (implementations)". The `Backend` class and its different backend implementations are supposed to be internal to RETURNN and not directly exposed to the user. The user has some helper functions to switch the backends.

There is also a lot of code which builds on top of the backend functions. E.g. the `rf.Module` class, modules like `rf.Linear`, would all be independent from the backend. Or also functions like `cross_entropy`.

The user in the end needs to define the following functions in the config:

- `train_step` would only be used in training. Here the user should call `mark_as_loss` on some tensors.
- `forward_step` would be used for recognition. Here the user should call `mark_as_output` on some tensors. See #1336 for more discussion on this, how it would be used then for beam search, forwarding, or whatever you want to do with the outputs.

To also support PyTorch modules more directly, `get_model()` can also return a `torch.nn.Module`. See #1120 (comment).

To add losses when using a raw `torch.nn.Module`, the API inside `train_step` would look like the sketch below. But this is intended only to be used for external code, and for our own code, we recommend the use of the RF. But in any case, it will be easy to mix pure PT and RF code together (with the PT backend).
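A minimal sketch of how such a `train_step` with a raw `torch.nn.Module` could look (illustrative only; the data keys, tensor shapes, and the exact `mark_as_loss` signature are assumptions, not the confirmed API):

```python
import torch
import torch.nn.functional as F
import returnn.frontend as rf

def train_step(*, model: torch.nn.Module, extern_data, **_kwargs):
    # Assumed data keys and shapes: "data" as [batch, feature], "classes" as [batch].
    inputs = extern_data["data"].raw_tensor
    targets = extern_data["classes"].raw_tensor
    logits = model(inputs)
    loss = F.cross_entropy(logits, targets)
    # Register the loss on the run context so the engine can accumulate and backprop it.
    rf.get_run_ctx().mark_as_loss(loss, name="ce")
```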
To get access to the step inside `train_step`, there will be sth like `rf.global_train_step()`.

Related:
- … `nn` … (#1264)
- … `nn` functions to RETURNN (returnn_common#252)

Some summary of the current open discussions, items, or status updates:

- `rf.control_flow_ctx` or not? (`rf.control_flow_ctx` or not? #1288)
- … `__init__` logic to work equally for graph-based and eager-based backends, specifically re-parameterization like weight norm (returnn_common#250)
- … `train_step_callback`, staged training (#1447)
- … `check_matched_dataset` for PT? (Frontend API and PyTorch backend #1120 (comment))

Done:

- `seq_tag` in `extern_data` (also see Add seq_idx, seq_tag to torch datapipe #1330)
- `rf.scan` for beam search (RF: Accumulate Dim in `rf.scan` for beam search #1327)
- `rf.top_k`
- `rf.cond`, `rf.while_loop`, `rf.scan` (rf.Cond and rf.Loop for eager frameworks #1282)
- `Conformer` (Frontend API and PyTorch backend #1120 (comment))
- `rf.Conv1d`, `rf.max_pool1d`, etc.
- `rf.BatchNorm`, `rf.LayerNorm` and other normalizations
- `rf.SelfAttention`, `rf.dot_attention`
- … `rf.State` as base (Frontend API and PyTorch backend #1120 (comment))
- `rf.dropout`
- `rf.cond` (minimal initial support) (see rf.Cond and rf.Loop for eager frameworks #1282)
- `rf.get_run_ctx().train_flag`
- `rf.cross_entropy`
- `rf.Linear`, some activation functions
- … (`get_model` etc.) for PT and TF (@albertz), also Design of PyTorch step-functions #1290
- … `Tensor`/`Dim`? (Frontend API and PyTorch backend #1120 (comment)) -> for now, we continue and focus on RF development, and the RF itself can be used as PyTorch extensions
- `rf.random` in tests (rf.random for tests #1283)
- … (`torch.jit.script`/`torch.jit.trace`) for ONNX? Tracing (`torch.jit.trace`) should anyway always work, except it would not cover dynamic control flow. `torch.compile` for ONNX will be the future but does not work yet? Maybe also OpenXLA instead of ONNX. Any other reasonable way to get PT models running within RASR? We can also simply use the RASR-Python bridge. (Exporting PyTorch code for use in RASR #1289)
- `preload_from_files` (preload_from_files for PT engine #1292)
- … (`ReturnnDatasetIterDataPipe`), chunking, batching, using `DataLoader2`.
This issue here is also to discuss and report on implementation details of the new PyTorch backend.

The initial step would be to just get a proof-of-concept, meaning the goals currently are: …

Getting it compatible to the TF backend is maybe the ultimate goal, but this is on a completely different level than the proof-of-concept discussed here.

For the dataset, we would implement a PyTorch `IterableDataset` to wrap any RETURNN dataset.