
AOT inductor should generate source code instead of a library #115965

Open
mindbeast opened this issue Dec 16, 2023 · 14 comments
Labels
feature A request for a proper, new feature. module: aotinductor aot inductor module: inductor oncall: export oncall: pt2 triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@mindbeast

mindbeast commented Dec 16, 2023

🚀 The feature, motivation and pitch

AOT Inductor looks like the upcoming means of running inference from native code for models trained in PyTorch, and the replacement for TorchScript export to native code. It's clear this interface is in prototype status, but based on what is present right now, it's problematic for many users.

torch._export.aot_compile, as currently defined, produces a .so and presumably invokes nvcc for GPU models, and likely a host compiler for CPU models. This is pretty problematic for integration with many native build tools, because the export process takes over building of the inference library: cross-compilation is impossible, as is passing flags to the build tools.
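For reference, the current flow looks roughly like this (the exact signature has moved around between releases); the toolchain is invoked implicitly on the exporting machine and only the finished library path comes back:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x)

model = M()
example_args = (torch.randn(8, 8),)

# A host C++ compiler (and nvcc, for CUDA models) is invoked behind the
# scenes on this machine; there is no hook to redirect or customize the build.
so_path = torch._export.aot_compile(model, example_args)
print(so_path)  # e.g. /tmp/torchinductor_<user>/.../model.so
```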

An interface that yields source code to be fed into an existing build system, rather than directly producing a library, would potentially be much friendlier. That way PyTorch wouldn't have to manage build tools in any capacity. There is likely a lot of complexity here, because code generation probably wants to hardcode many platform-specific details, e.g. GPU type and CPU instruction sets. It would likely be wise to let callers specify platform constraints and capabilities on the aoti_compile interface, rather than attempting to infer them automatically from the local machine.
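To make the shape of the request concrete, here is a purely hypothetical sketch; neither aoti_generate_sources nor TargetSpec exists in PyTorch today, they only illustrate an interface that returns sources plus an explicit target description instead of a built library:

```python
from dataclasses import dataclass
from typing import Dict, Optional

@dataclass
class TargetSpec:                   # hypothetical: describes the deployment target
    device: str                     # e.g. "cuda" or "cpu"
    gpu_arch: Optional[str] = None  # e.g. "sm_80"; None for CPU-only targets
    cpu_isa: Optional[str] = None   # e.g. "avx2"

def aoti_generate_sources(exported_program, target: TargetSpec) -> Dict[str, str]:
    """Hypothetical API: return {filename: source} for the generated C++/CUDA
    (plus whatever build metadata an external build system needs) instead of
    invoking nvcc / a host compiler directly."""
    raise NotImplementedError  # does not exist in PyTorch today

# Intended usage: write the sources out and hand them to CMake/Bazel/etc.
# sources = aoti_generate_sources(ep, TargetSpec(device="cuda", gpu_arch="sm_80"))
# for filename, code in sources.items():
#     pathlib.Path(filename).write_text(code)
```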

Alternatives

No response

Additional context

No response

cc @ezyang @chauhang @penguinwu @voznesenskym @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @yf225 @chenyang78 @kadeng @muchulee8 @ColinPeppler @amjames @desertfire @aakhundov @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4 @anijain2305 @peterbell10 @msaroufim @wconstab @bdhirsh @zou3519

@vadimkantorov
Contributor

Part of the problem might also be that Inductor infers these from the local machine anyway, since it may run benchmarks and choose hparams on the local machine? It's also problematic for heterogeneous inference fleets or cross-compilation scenarios :(

@zou3519 zou3519 added triage review feature A request for a proper, new feature. module: inductor labels Dec 18, 2023
@mindbeast
Author

I admit to not being familiar with the plans for what is happening to TorchScript, but it would be unfortunate to lose the ability to export a model once from Python and then use that exported model from native code on different devices.

Given all the new torch.export facilities, I suppose it would be possible to generate completely device-agnostic C++ code from export, but then I'd get absolutely zero layer-level optimizations. Anyway, I'd love to know the plans here if there are any.

@zou3519 zou3519 added triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module and removed triage review labels Dec 19, 2023
@aakhundov
Contributor

aakhundov commented Dec 26, 2023

it's also problematic for heterogeneous inference fleets or cross-compilation scenarios :(

@vadimkantorov I guess, for using an AOT-compiled model in a heterogeneous inference fleet setting, one would need to compile a separate model on each HW type and select accordingly for serving? TBH, it's hard to imagine something like cross-compilation here, as performance (micro-)benchmarks are involved, and those require real hardware. Unless one could resort to heuristics, but as you've mentioned, TorchInductor actually runs the benchmarks. Or do you have some alternatives in mind?
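A minimal sketch of that serving-side selection, with one precompiled artifact per GPU generation; the artifact paths are made up, and torch._export.aot_load is the loader that newer releases expose, so the exact entry point may differ by version:

```python
import torch

# One AOTInductor artifact per hardware type, each built ahead of time on
# (or at least benchmarked on) a machine with that GPU.
SO_BY_CAPABILITY = {
    (8, 0): "/models/net_sm80.so",  # A100
    (9, 0): "/models/net_sm90.so",  # H100
}

cap = torch.cuda.get_device_capability()
so_path = SO_BY_CAPABILITY[cap]  # KeyError => no artifact for this fleet member

# Assumption: torch._export.aot_load is available (newer releases); older
# ones load the .so through the C++ AOTI model container runner instead.
runner = torch._export.aot_load(so_path, device="cuda")
out = runner(torch.randn(1, 3, 224, 224, device="cuda"))
```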

@mindbeast
Author

I'm not intimately familiar with the internals or future goals of Inductor, unfortunately, so I can't say whether this approach makes sense. Like a conventional compiler, though, there should be optimization passes that are not GPU-hardware specific (like kernel fusion) and ones that are architecture or even card specific (CUTLASS kernels, perhaps). In theory, it should be possible to run just the subset of architecture-agnostic passes that work for all GPUs, rather than something that is totally hardcoded to the specific GPU. Additionally, some optimizations will always yield better performance and won't need to be benchmarked.
If everything really is benchmark based, it would probably be reasonable to do optimizations that "tune" for the current GPU but still produce functional code on all GPUs. This would be roughly similar to gcc -mtune.

If that's a big architectural change, though, I'd settle for generated code that worked only on the "current GPU hardware". At least that way it would be possible to integrate with an existing native build system; if I wanted to deploy on multiple architectures, I'd just have to export generated code on N different machines. As a general point, though, having code generation always tied exactly to the machine being deployed adds a lot of complexity to build pipelines. It would probably be wise to have some means of specifying the architecture and generating code for it, but that would necessitate heuristics for making those decisions.

It is possible that doing something like this for deployed native code doesn't make sense. Perhaps it makes more sense as an external tool that works on the IR from torch.export or ExecuTorch? As an external tool, though, it would involve a tremendous amount of redundant optimization passes and code generation.

@ezyang
Contributor

ezyang commented Dec 29, 2023

The request for something portable like TorchScript is a reasonable thing to want, but it is not really in the design goals of AOTInductor. In fact, the .so output is explicitly intended because it minimizes the BC surface for these artifacts to just whatever C ABI the model .so relies on.

If the ask here is specifically better cross-compilation support, then instead of splitting the compilation process into two stages, it will probably be easiest to just add enough control over the toolchain etc. so that you can do it all in one go. You would still be expected to redo the export for every target you want, but that seems expected.
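For concreteness, the only knob of this kind I'm aware of today is the output path (the option key below matches the AOTInductor docs but may drift between releases); the commented-out cross-compilation options are hypothetical and only sketch what "enough control over the toolchain" might mean:

```python
import os
import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x @ x)

model, example_args = M(), (torch.randn(8, 8),)

# Real today: pin where the compiled library lands instead of a temp dir.
so_path = torch._export.aot_compile(
    model,
    example_args,
    options={"aot_inductor.output_path": os.path.join(os.getcwd(), "model.so")},
)

# Hypothetical knobs (none of these option keys exist):
# options={
#     "aot_inductor.output_path": "model.so",
#     "aot_inductor.cxx": "/opt/cross/bin/aarch64-linux-gnu-g++",
#     "aot_inductor.cxx_flags": ["-march=armv8.2-a"],
#     "aot_inductor.target_gpu_arch": "sm_80",
# }
```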

@gluefox

gluefox commented Feb 27, 2024

@ezyang Is there a path towards a device-agnostic export workflow ahead, or will there be a point in time when TorchScript is finally deprecated and no replacement is available? We had the use case of a heterogeneous server fleet above, but what about the case where we ship models to arbitrary users and they run them locally via libtorch? In that case it is impossible to compile all needed variants beforehand.

@ezyang
Contributor

ezyang commented Feb 28, 2024

The thing that is being worked on right now and is closest to what you are looking for is pre-dispatch export, but the current focus is on a runtime that is just... eager PyTorch and Python

@mindbeast
Author

@ezyang Is pre-dispatch export just torch.export? I'm not familiar with anything that ingests it other than ExecuTorch, which is really very Qualcomm-centric right now.

@vadimkantorov
Contributor

vadimkantorov commented Feb 28, 2024

probably be easiest to just add enough control on toolchain etc so that you can do it all in one go

I think, even for hacking on custom toolchains, the ability to export the source code / the current compilation commands, so as to reproduce whatever Inductor is doing under the hood, would still be very handy (and would increase transparency)

@masnesral
Contributor

I'm trying to help scrub old issues this week and struggling to figure out what to do with this one. Is anyone still following along, and is there something actionable here (even if low priority) to warrant leaving this open?

@masnesral
Contributor

Since AOT Inductor is in the title, lemme check with @desertfire. Bin, any opinions?

@vadimkantorov
Contributor

vadimkantorov commented Sep 9, 2024

@ezyang Could AOT Inductor be used to generate runtime-free code (or maybe dump the needed runtime code into a separate directory)? This is interesting for some embedded scenarios, including compiling the model code to WebAssembly/wasm-simd, for which specialized cross-compilation / LLVM options would be needed. E.g. being able to generate llm.c-like code would be useful for educational/hackery purposes and some embedded/in-browser scenarios

@ezyang
Contributor

ezyang commented Sep 9, 2024

@Chillee has been advocating for this for a bit; it's very reasonable, we just need an implementation of the C ABI that doesn't directly use full libtorch

@vadimkantorov
Contributor

vadimkantorov commented Sep 9, 2024

If the generated code is sufficiently portable C (maybe with sufficiently portable C++ runtime code), then it could also be used as a baseline for ExecuTorch / TVM in terms of execution time...

I also wonder whether Triton can now generate C code as a target? If so, maybe more operators with available Triton impls could be lowered to C code...
