
Support torch.compile for RF #1491

Open
albertz opened this issue Jan 9, 2024 · 3 comments

@albertz (Member) commented Jan 9, 2024

I'm not really sure whether this is possible, because we have our own Tensor class which wraps around torch.Tensor, and similarly all the PyTorch functions are wrapped inside RF.

However, I found out that there is TensorDict, which also wraps around PyTorch tensors, and it is explicitly stated that it is compatible with torch.compile, so maybe it is possible. TensorDict behaves like a dict but does not inherit from dict; it only inherits from collections.abc.MutableMapping, see here.
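For illustration, a minimal probe of the core question (not RF code; WrappedTensor and add_wrapped are made-up names, and whether this traces cleanly or graph-breaks depends on the PyTorch version) could look like this:

```python
import torch


class WrappedTensor:
    """Toy stand-in for a tensor wrapper: raw torch.Tensor plus some metadata."""

    def __init__(self, raw: torch.Tensor, dims: tuple):
        self.raw = raw
        self.dims = dims


def add_wrapped(a: WrappedTensor, b: WrappedTensor) -> WrappedTensor:
    # Unwrap, call the plain PyTorch op, re-wrap -- roughly the pattern
    # the RF PyTorch backend uses for its wrapped functions.
    return WrappedTensor(a.raw + b.raw, a.dims)


compiled = torch.compile(add_wrapped)
x = WrappedTensor(torch.randn(4, 8), ("B", "F"))
y = WrappedTensor(torch.randn(4, 8), ("B", "F"))
out = compiled(x, y)
print(type(out).__name__, out.raw.shape)
```

Whether Dynamo can trace through such a wrapper without falling back to eager is exactly what would need to be checked for the RF Tensor.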

So, let us discuss here possibilities and options to support torch.compile directly with the RF PyTorch backend.

Related:

@JackTemaki (Collaborator) commented Jul 2, 2024

I just want to mention here, especially after the discussion with @NeoLegends this morning, that torch.compile also still has severe issues with dynamic shapes. This means that, in order to support compilation, we might need to manage multiple compiled versions and create batches padded to exactly those shapes. Moritz mentioned that this is also something that needs to be done when working with JAX.

I "naively" tried to wrap some of my (pure PyTorch) ASR trainings and recognitions with torch.compile, none of the experiments were in any sense successful so far (either it caused re-compiling for every batch, or just not a faster execution, both CPU and GPU). This was with PyTorch 2.2 though, not sure how much has improved since then.

@albertz (Member, Author) commented Jul 2, 2024

Can you reference some of the issues with dynamic shapes for torch.compile? With RF specifically, we might have some chance to work around those issues more easily/automatically.

With JAX, I know that people use this approach successfully. But then, I think dynamic shapes support is planned to be added to JAX (jax-ml/jax#14634). Another approach I have seen is that people actually implement only a chunk-based or frame-based model (all static shapes), and then everything dynamic happens in a Python loop around it.
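A rough sketch of that "static-shape core, dynamic Python loop outside" pattern (names like `chunk_core` are hypothetical, not RF or JAX API; shown here in PyTorch):

```python
import torch
import torch.nn.functional as F

CHUNK = 64  # fixed chunk length, so the compiled core only ever sees one shape

# Stand-in for a per-chunk model.
chunk_core = torch.compile(torch.nn.Linear(80, 512))


def encode(seq: torch.Tensor) -> torch.Tensor:
    """seq: [time, feat] with arbitrary (dynamic) time length."""
    time = seq.shape[0]
    pad = (-time) % CHUNK
    seq = F.pad(seq, (0, 0, 0, pad))  # pad the time dim up to a multiple of CHUNK
    outs = []
    for t in range(0, seq.shape[0], CHUNK):
        outs.append(chunk_core(seq[t : t + CHUNK]))  # always a static [CHUNK, 80] slice
    return torch.cat(outs, dim=0)[:time]


print(encode(torch.randn(150, 80)).shape)  # -> torch.Size([150, 512])
```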

@JackTemaki (Collaborator) commented

After reviewing some of the public issues, it seems this was mostly a problem with more special cases, or with earlier versions of PyTorch.

Nevertheless, there is e.g. bullet point #4 here:
https://pytorch.org/docs/stable/torch.compiler_dynamic_shapes.html#abridged-public-api

This matches my observation: you get a compile crash, constant re-compilation (e.g. for every batch size), or a slowdown.
There should be proper workarounds for it; it is just that you definitely cannot "just apply" torch.compile.
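One such workaround from the dynamic shapes API referenced above, sketched minimally (the Linear layer is just a placeholder): mark the varying dimensions as dynamic explicitly, so the first compile already produces shape-polymorphic code instead of recompiling once per observed batch/time size.

```python
import torch
import torch._dynamo

model = torch.compile(torch.nn.Linear(80, 512))

x = torch.randn(8, 100, 80)
torch._dynamo.mark_dynamic(x, 0)  # mark the batch dim as dynamic
torch._dynamo.mark_dynamic(x, 1)  # mark the time dim as dynamic
y = model(x)  # later calls with other batch/time sizes should reuse this compilation
```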

https://lightning.ai/docs/pytorch/latest/advanced/compile.html
Here, at the end, they note the same thing: that torch.compile often does not yield the desired behavior without more investment.

Older issues, which are not so relevant anymore, are e.g.:
pytorch/pytorch#98441
pytorch/pytorch#106466
https://discuss.pytorch.org/t/varied-batch-size-for-compiled-model/184218 (then with all the problems you get with dynamic=True)

So I would say we need to test this again with PyTorch 2.4, and spend much more time on it than a few hours. We will probably also work on this in the context of @sleepyeldrazi's work on efficient export and runtimes.
