Enable conversion of all_reduce and GSPMD custom_op into TTIR dialect #1351
Conversation
Clang-Tidy found issue(s) with the introduced code (1/6)
Clang-Tidy found issue(s) with the introduced code (2/6)
Clang-Tidy found issue(s) with the introduced code (3/6)
Clang-Tidy found issue(s) with the introduced code (4/6)
Clang-Tidy found issue(s) with the introduced code (5/6)
Clang-Tidy found issue(s) with the introduced code (6/6)
Force-pushed from eb22ca1 to f944c3b
Looks great! Minor syntactic nits
Force-pushed from f944c3b to 6b8d573
Two additional minor nits, otherwise looks good!
def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> {
  let summary = "Mesh shard operation";
  let description = [{
    MeshShard op
Please make this more descriptive and include an example... Same for all the other ops.
Sure. I will add more detailed descriptions of the ops...
Added detailed descriptions. Let me know if you need further changes.
template <typename srcOpTy>
LogicalResult getReduceType(srcOpTy &srcOp, ReduceType &reduceType) {
  if constexpr (!std::is_same<srcOpTy, mlir::stablehlo::AllReduceOp>::value) {
    return failure();
  }
Would it make sense to just specialize this function for ReduceOp, so you get a compile error instead of a pass error at runtime?
I think there are multiple stablehlo multi-device funcs that could reuse this function in the future. @wooseokTT, feel free to correct me if I'm wrong.
Yes. We are planning to land further CCL ops, including reduce_scatter, that will use this function in the near future. The computation ops will be embedded into the all_reduce/reduce_scatter ops as a computation-type attribute, and they will be gone from the MLIR. So I think the pass error makes sense here.
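To make the tradeoff discussed in this thread concrete, here is a minimal, self-contained sketch of the reviewer's suggestion. Every type in it (LogicalResult, ReduceType, FakeAllReduceOp, FakeDotGeneralOp) is a simplified stand-in, not the real MLIR/StableHLO type: an explicit specialization turns an unsupported op into a compile-time error, whereas the PR's if constexpr guard keeps it a runtime failure() so that future CCL ops such as reduce_scatter can reuse the same template.

#include <type_traits>

// Simplified stand-ins (hypothetical, not the real MLIR types).
enum class ReduceType { Sum, Mean, Max, Min };
struct LogicalResult { bool succeeded; };
inline LogicalResult success() { return {true}; }
inline LogicalResult failure() { return {false}; }
struct FakeAllReduceOp {};   // stand-in for mlir::stablehlo::AllReduceOp
struct FakeDotGeneralOp {};  // stand-in for any unsupported op

// Helper so the static_assert below only fires when the primary template is
// actually instantiated for an unsupported op type.
template <typename T>
struct always_false : std::false_type {};

// Primary template: instantiating it for an unsupported op is a compile-time
// error instead of a runtime failure().
template <typename SrcOpTy>
LogicalResult getReduceType(SrcOpTy &, ReduceType &) {
  static_assert(always_false<SrcOpTy>::value,
                "getReduceType is only implemented for reduce-like ops");
  return failure();
}

// Explicit specialization for the op that is supported today; a future
// reduce_scatter op would get its own specialization.
template <>
LogicalResult getReduceType<FakeAllReduceOp>(FakeAllReduceOp &srcOp,
                                             ReduceType &reduceType) {
  (void)srcOp;                   // the real code inspects the reduction body
  reduceType = ReduceType::Sum;  // e.g. an add reduction maps to Sum
  return success();
}

int main() {
  FakeAllReduceOp allReduce;
  ReduceType kind;
  (void)getReduceType(allReduce, kind);  // compiles and runs
  // FakeDotGeneralOp dot;
  // (void)getReduceType(dot, kind);     // would fail to compile
}

Either choice is reasonable; the sketch only illustrates what the compile-time variant would look like.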
Could you provide a more detailed PR description or perhaps a short document outlining what this PR introduces? Since we’re introducing new concepts into TTIR, it would be really helpful to have a reference that explains the ideas around multi-device functionality and some implementation details. This would make it easier for those who are less familiar with the changes to understand the context and goals.
Force-pushed from 6b8d573 to 28bcbe8
1. TT_Reduce_Type is created to share the computation type with the TTNN dialect.
2. AllReduceOp in TTIR is introduced to accommodate the StableHLO all_reduce op.
3. MeshShardOp in TTIR is introduced to capture GSPMD custom sharding.
4. Realistic test cases are added from JAX/PJRT output.
The current version of importing targets GSPMD input, but our future plans mainly focus on supporting Shardy-based JAX/PJRT output.
Force-pushed from 28bcbe8 to 3b9531e
@mtopalovicTT I updated the PR description with the details. Let me know if you need any further explanation. TT-MLIR is actively evolving, so I believe follow-up PRs will spell out more concrete concepts and details.
@wooseokTT Thanks, the PR description is awesome. This makes things a lot clearer.
As the first step of the multi-device support plan, this PR enables converting MLIR outputs that target the all_reduce op from JAX/OpenXLA(GSPMD)/PJRT. Several follow-up PRs will carry the computation flow from TTIR down to the runtime. The detailed steps are as follows.
(1) Convert MLIRs from JAX/OpenXLA/PJRT to TTIR (this PR)
(2) Pass converted TTIR to TTNN MLIR and Flatbuffer format
(3) Parse TTNN flatbuffer and execute in TT Runtime
Although the current version of the code targets GSPMD-partitioned MLIRs, our future plan mainly aims at supporting Shardy-based JAX/PJRT MLIRs.
In general, a GSPMD-partitioned MLIR has the following computation pattern:
A. Shard inputs for computation across multiple devices
%0 = stablehlo.custom_call @sharding(%arg0) {mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<...>) -> tensor<...>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<...>) -> tensor<...>
B. Simultaneous compute on multiple devices
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], ... : (tensor<...>, tensor<...>) -> tensor<...>
C. Merge partial computation results using CCL ops
%1 = "stablehlo.all_reduce"(%0) < ... > ( ... ):
%2 = stablehlo.add %arg2, %arg3 : tensor
stablehlo.return %2 : tensor
}) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32>
D. Concatenate outputs if needed
%5 = stablehlo.custom_call @sharding(%4) {mhlo.sharding = "{manual}"} : (tensor<...>) -> tensor<...>
%6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<...>) -> tensor<...>
We can already convert B, so this PR adds conversion for parts A, C, and D.
For C, we introduce the TTIR all_reduce op, while for A and D we introduce the new TTIR mesh_shard op.
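As an illustrative sketch of how A, C, and D are told apart during conversion (the enum and function names below are hypothetical; only the custom_call target strings come from the pattern above), the GSPMD boundary custom_calls in A and D can be recognized purely by their target names, after which A and D lower to the new TTIR mesh_shard op and C lowers to the TTIR all_reduce op:

#include <optional>
#include <string_view>

// Hypothetical classification helper; not code from the PR.
enum class GSPMDBoundary {
  FullToShard,  // part A: split the full tensor into per-device shards
  ShardToFull,  // part D: reassemble per-device shards into the full tensor
  Annotation,   // @sharding: carries the mhlo.sharding string
};

// Classify a stablehlo.custom_call by its target name; anything else is not a
// GSPMD sharding boundary and is handled by the existing conversion patterns.
std::optional<GSPMDBoundary> classifyCustomCall(std::string_view target) {
  if (target == "SPMDFullToShardShape")
    return GSPMDBoundary::FullToShard;
  if (target == "SPMDShardToFullShape")
    return GSPMDBoundary::ShardToFull;
  if (target == "sharding")
    return GSPMDBoundary::Annotation;
  return std::nullopt;
}

In the actual conversion, the mhlo.sharding string shown in A and D is what describes how the tensor is split or reassembled across the device mesh.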