[RFC] Multi-Gpu Python Frontend API #3094
Comments
Yes, that's the right summary.
Is that a question for me? We don't expose communication expressions in the python frontend. For example, see `tests/python/test_multidevice.py`, lines 43 to 49 at commit 4c9090e in the Fuser repo.
No, it's not the current implementation. One of the preseg passes (insert_resharding) inserts resharding expressions where needed.
This is ballpark what I had in mind. Speaking of the implementation, I'm unsure about creating a separate `multidevice_schedule`. To get started, I'm inclined to have something like GSPMD's mesh_split API, which merely annotates the sharding and embeds the annotation inside the definition. This should be enough, given that we expect nvFuser's sharding propagation to do most of the heavy lifting.
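For concreteness, here is a minimal sketch of what a mesh_split-style annotation could look like in the python frontend. Everything in it (`define_mesh`, `mesh_split`, and their arguments) is hypothetical and only illustrates the idea of annotating a sharding inside the definition and leaving the rest to sharding propagation.

```python
from nvfuser import FusionDefinition

# Hypothetical sketch: `define_mesh` and `mesh_split` are invented names, not an
# existing nvFuser API. The point is that the definition only carries sharding
# annotations; no communication expressions appear here.
class ColumnParallelLinear(FusionDefinition):
    def definition(self):
        mesh = self.define_mesh([0, 1, 2, 3])       # 1-D mesh over 4 devices
        x = self.define_tensor(shape=[-1, -1])      # [batch, in_features], replicated
        w = self.define_tensor(shape=[-1, -1])      # [in_features, out_features]
        # Annotate w as column-wise sharded across mesh axis 0; the shardings of
        # downstream tensors are derived by nvFuser's sharding propagation.
        self.mesh_split(w, mesh, tensor_axis=1, mesh_axis=0)
        y = self.ops.matmul(x, w)
        self.add_output(y)
```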
This was a question for myself. Communication expressions blur the line between the math definition and fusion scheduling. I think we should support them in the definition.
I'm not sure I follow. We have a strong incentive to hide them from the user, and thus from the definition, because people tend to make mistakes or communicate in a suboptimal way. We might want to expose communication for debugging, but I'd prefer exposing that via host IR.
The only direct users of the python frontend are ourselves, so we should prioritize our productivity. Do you intend to expose HostIR in python? If Thunder traces a python program with torch.distributed operations, can nvFuser consume that trace? What if I want to trace a SOTA implementation like Megatron-Core? IIUC, our implementation comes from porting their approach to NvFuser. Won't there always be a lag time for supporting their latest research?
This is exactly what I've wanted from nvfuser. Currently, a Thunder trace of ddp/fsdp has all the distributed communication ops in it, namely all-reduce, all-gather, and reduce-scatter, but nvfuser does not have a python API for communication ops, so those ops become graph-break points.
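For reference, those trace ops correspond to the standard `torch.distributed` collectives. A minimal sketch (process-group setup elided; the helper names here are just for illustration):

```python
import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has already been called elsewhere.
# These are the collectives that show up in ddp/fsdp Thunder traces and that
# currently act as graph-break points for nvFuser.

def ddp_grad_allreduce(grad: torch.Tensor) -> torch.Tensor:
    # Data parallel: sum gradients across ranks (usually divided by world size afterwards).
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    return grad

def fsdp_allgather_param(shard: torch.Tensor, world_size: int) -> torch.Tensor:
    # FSDP: gather parameter shards before compute.
    full = torch.empty(world_size * shard.numel(), dtype=shard.dtype, device=shard.device)
    dist.all_gather_into_tensor(full, shard)
    return full

def fsdp_reduce_scatter_grad(grad: torch.Tensor, world_size: int) -> torch.Tensor:
    # FSDP: reduce-scatter gradients back into per-rank shards.
    out = torch.empty(grad.numel() // world_size, dtype=grad.dtype, device=grad.device)
    dist.reduce_scatter_tensor(out, grad)
    return out
```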
Both of you are asking good questions about how nvFuser and Thunder work together. It's a large design space that I haven't explored fully, and I'm sure @kevinstephano has better ideas. I plan to let nvFuser solve the following two problems:
So to your original questions, I don't plan to let nvFuser take a tensor-parallel Thunder trace instrumented with torch.distributed operations. It's certainly doable but isn't the best investment at this moment. Instead, I think the most immediate goal is to allow nvFuser to take a data-parallel Thunder trace with tensor-parallel annotations. There, DP is implemented using torch.distributed ops (or DTensor?), and the TP intention is represented using some annotations that nvFuser can process further. Will torch.distributed ops become graph breaks? Yes, but they won't be everywhere dictating all communications, so hopefully nvFuser will still have sizable regions in which to optimize TP.
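A rough sketch of what such a mixed trace could look like; `annotate_column_sharded` is invented purely to stand in for "some TP annotation nvFuser can process", while DP stays explicit via torch.distributed:

```python
import torch
import torch.distributed as dist

def annotate_column_sharded(t: torch.Tensor, mesh_axis: int) -> torch.Tensor:
    # Placeholder for a TP annotation; a real implementation would attach
    # sharding metadata for nvFuser instead of being a no-op.
    return t

def mlp_forward(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # TP intent is only annotated; no communication is emitted here, so this
    # region stays fusible for nvFuser.
    w = annotate_column_sharded(w, mesh_axis=0)
    return torch.relu(x @ w)

def dp_grad_sync(grad: torch.Tensor) -> torch.Tensor:
    # DP remains an explicit torch.distributed collective, i.e. a graph break.
    dist.all_reduce(grad, op=dist.ReduceOp.SUM)
    return grad
```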
The closest I can find is https://gist.github.com/wujingyue/b111aa8b8d92067fc6004f5d0488dd27, the forward and backward traces for a transformer layer. You can imagine the same traces with the inputs annotated as row-wise sharded, column-wise sharded, or replicated.
The same traces as above, but with batch size > 1.
The inputs need to be annotated with the involved data-parallel schemes, and the trace also needs to contain the corresponding DP constructs.
Yes, I believe the trace needs to contain the needed DP constructs because we are talking about combining DP and TP. I'm just unsure about the exact format. Would you mind sending me a DDP'ed trace and/or teaching me how I can generate one? This'll help me think more concretely.
I'd use https://github.com/Lightning-AI/lightning-thunder/blob/main/thunder/benchmarks/benchmark_litgpt.py and run it with the appropriate options. For ddp, set the distributed mode accordingly.
Just so you know,
🚀 The feature, motivation and pitch
RFC: Multi-Gpu Python Frontend API
This RFC compares and contrasts some ideas for exposing multi-gpu support in the python frontend.
`multigpu_schedule` approach.

**Current Multi-GPU support in NvFuser.**
- Use the `multidevice_schedule` function to create the device mesh and to apply the `ParallelType` layout.
- Then, create a HostIRContainer for the communication expressions.
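A rough sketch of this pattern, loosely based on the multi-device python tests referenced above; the scheduling helper names (`_create_device_mesh`, `_set_device_mesh`, `parallelize`, `ParallelType.mesh_x`) are written from memory and may not match the current API exactly:

```python
import nvfuser
from nvfuser import DataType, FusionDefinition

num_devices = 2  # assumed world size for illustration

class Model(FusionDefinition):
    def definition(self):
        self.inp = self.define_tensor(
            shape=[num_devices, -1], contiguity=[False, False], dtype=DataType.Float
        )
        self.out = self.ops.relu(self.inp)
        self.add_output(self.out)

    def multidevice_schedule(self):
        # Create a 1-D device mesh and shard the leading axis across it (DIDx).
        mesh = self.sched._create_device_mesh(range(num_devices))
        for t in [self.inp, self.out]:
            self.sched._set_device_mesh(t, mesh)
        self.sched.parallelize(self.inp, 0, nvfuser.ParallelType.mesh_x)
```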
**DTensor in NvFuser.**

References:

API Example:

Description:
In PyTorch, a `DTensor` is a tensor with a device mesh and a layout. In NvFuser, a `DTensor` is a TensorView with a device mesh. The layout is specified by setting `ParallelType::DIDx`, `ParallelType::DIDy`, and `ParallelType::DIDz` on some IterDomains. Apply propagation rules through operations for `DTensors` in the FusionDefinition.
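A hypothetical illustration of this flow; `define_device_mesh` and `shard` are invented names standing in for "a TensorView with a device mesh and a DIDx layout":

```python
from nvfuser import FusionDefinition

# Hypothetical sketch only; this is not an existing nvFuser API.
class DTensorStyleLinear(FusionDefinition):
    def definition(self):
        mesh = self.define_device_mesh([0, 1, 2, 3])  # 1-D mesh of 4 devices
        x = self.define_tensor(shape=[-1, -1])        # [batch, in_features]
        w = self.define_tensor(shape=[-1, -1])        # [in_features, out_features]
        # Mark x as sharded on its batch axis (ParallelType::DIDx on that IterDomain);
        # w stays replicated on the mesh.
        xs = self.shard(x, mesh, tensor_axis=0, parallel="didx")
        # The output's mesh and DIDx layout are derived by the propagation rules.
        y = self.ops.matmul(xs, w)
        self.add_output(y)
```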
**Manual Multi-Gpu Definition.**

Why expose communication expressions in the python-frontend?
What is a multi-gpu matmul?
Goal: Compute `C[M, N] = A[M, K] @ B[K, N]` using a mesh of devices.

- Shard the A and B input matrices according to the C output matrix.
  - `sA` is row-wise sharded.
  - `sB` is column-wise sharded.
- Apply the matmul given the A and B shards.
  - `sC[sM, sN] = sA[sM, K] @ sB[K, sN]`
- Gather the C output shards to get the full C output matrix.
  - `sC` is gathered from all devices to create the C matrix.
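A single-process numpy simulation of the steps above; the 2x2 mesh and the shard sizes are assumptions chosen for illustration, and each loop iteration plays the role of one device:

```python
import numpy as np

# Simulate a 2-D device mesh: device (i, j) holds the i-th row shard of A and the
# j-th column shard of B, computes its C tile locally, and the tiles are then
# gathered into the full C matrix.
M, K, N = 8, 6, 8
mesh_rows, mesh_cols = 2, 2          # 2x2 mesh -> 4 "devices"
A = np.random.rand(M, K)
B = np.random.rand(K, N)

row_shards = np.split(A, mesh_rows, axis=0)   # sA: row-wise shards of A
col_shards = np.split(B, mesh_cols, axis=1)   # sB: column-wise shards of B

# Local compute on each device: sC[sM, sN] = sA[sM, K] @ sB[K, sN]
tiles = [[row_shards[i] @ col_shards[j] for j in range(mesh_cols)]
         for i in range(mesh_rows)]

# "Gather" the C tiles from all devices to form the full C matrix.
C = np.block(tiles)
assert np.allclose(C, A @ B)
```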
Multidevice Schedule:

Manual Scheduling with Multidevice Schedule:

Manual Scheduling in definition:

DTensor:
cc @wujingyue @kevinstephano