
Enable conversion of all_reduce and GSPMD custom_op into TTIR dialect #1351

Merged (1 commit, Nov 26, 2024)

Conversation

@wooseokTT (Contributor) commented Nov 20, 2024

  • Overall Plan

As the first step of the multi-device support plan, this PR enables conversion of the MLIR outputs that carry the all_reduce op from JAX/OpenXLA(GSPMD)/PJRT. Several follow-up PRs will take the computation flow from TTIR down to the runtime. The detailed steps are as follows.

(1) Convert MLIR from JAX/OpenXLA/PJRT to TTIR (this PR)
(2) Lower the converted TTIR to TTNN MLIR and the flatbuffer format
(3) Parse the TTNN flatbuffer and execute it in TT Runtime

Although the current version of the code targets GSPMD-partitioned MLIR, our future plans mainly aim at supporting Shardy-based JAX/PJRT MLIR.

  • Implementation Details

In general, GSPMD-partitioned MLIR has the following computation pattern:

A. Shard inputs for multi-device computation
%0 = stablehlo.custom_call @sharding(%arg0) {mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<...>) -> tensor<...>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<...>) -> tensor<...>

B. Compute simultaneously on multiple devices
%0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], ... : (tensor<...>, tensor<...>) -> tensor<...>

C. Merge partial computation results using CCL ops
%1 = "stablehlo.all_reduce"(%0) < ... > ( ... ):
%2 = stablehlo.add %arg2, %arg3 : tensor
stablehlo.return %2 : tensor
}) : (tensor<4096x16384xf32>) -> tensor<4096x16384xf32>

D. Concatenate outputs if needed
%5 = stablehlo.custom_call @sharding(%4) {mhlo.sharding = "{manual}"} : (tensor<...>) -> tensor<...>
%6 = stablehlo.custom_call @SPMDShardToFullShape(%5) {mhlo.sharding = "{devices=[2,1,4]<=[8] last_tile_dim_replicate}"} : (tensor<...>) -> tensor<...>

We can already convert B, so this PR adds conversion for parts A, C, and D.
For C, we introduce a new TTIR all_reduce op, while for A and D we introduce a new TTIR mesh_shard op, roughly as sketched below.
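To make the mapping concrete, here is a rough sketch of the converted TTIR for A, C, and D. The op names (ttir.mesh_shard, ttir.all_reduce) are the ones this PR introduces, but the attribute names and exact syntax below are illustrative assumptions rather than the landed syntax; as DPS ops, both also take a pre-allocated destination operand (%dstN here).

// A: the @sharding + @SPMDFullToShardShape pair becomes a single mesh_shard op
%shard = "ttir.mesh_shard"(%arg0, %dst0) <{shard_direction = #tt.shard_direction<full_to_shard>, shard_shape = ...}> : (tensor<...>, tensor<...>) -> tensor<...>
// C: the all_reduce region body (stablehlo.add) collapses into a reduce-type attribute
%sum = "ttir.all_reduce"(%shard, %dst1) <{reduce_type = #tt.reduce_type<sum>, ...}> : (tensor<...>, tensor<...>) -> tensor<...>
// D: the @sharding + @SPMDShardToFullShape pair becomes a single mesh_shard op
%full = "ttir.mesh_shard"(%sum, %dst2) <{shard_direction = #tt.shard_direction<shard_to_full>, shard_shape = ...}> : (tensor<...>, tensor<...>) -> tensor<...>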

  • Changes in this PR
  1. TT_Reduce_Type is created to share the reduction computation type with the TTNN dialect
  2. AllReduceOp in TTIR is introduced to accommodate the StableHLO all_reduce op
  3. MeshShardOp in TTIR is introduced to capture the GSPMD custom sharding functions
  4. Realistic test cases are added from JAX/PJRT output (a schematic sketch follows below)
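A minimal schematic of what such a test looks like, as referenced in item 4 (the exact pipeline flag and CHECK lines here are illustrative assumptions, not copied from the PR):

// RUN: ttmlir-opt --stablehlo-to-ttir-pipeline %s | FileCheck %s
// CHECK: "ttir.mesh_shard"
%0 = stablehlo.custom_call @sharding(%arg0) {mhlo.sharding = "{devices=[2,4]<=[8]}"} : (tensor<...>) -> tensor<...>
%1 = stablehlo.custom_call @SPMDFullToShardShape(%0) {mhlo.sharding = "{manual}"} : (tensor<...>) -> tensor<...>
// CHECK: "ttir.all_reduce"
... (compute and all_reduce as in patterns B and C above)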

@github-actions bot left 6 comments:
⚠️ Clang-Tidy found issue(s) with the introduced code (1/6 through 6/6)

@wooseokTT force-pushed the wooseok/add_stablehlo_ccl_op_support branch 2 times, most recently from eb22ca1 to f944c3b on November 20, 2024 at 20:44
@nsmithtt (Contributor) left a comment:

Looks great! Minor syntactic nits

@wooseokTT force-pushed the wooseok/add_stablehlo_ccl_op_support branch from f944c3b to 6b8d573 on November 21, 2024 at 18:51
This was linked to issues Nov 21, 2024
@nsmithtt (Contributor) left a comment:

Two additional minor nits, otherwise looks good!

lib/Dialect/TTIR/IR/TTIROps.cpp (outdated review thread, resolved)
include/ttmlir/Dialect/TTIR/IR/TTIROps.td (outdated review thread, resolved)
def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> {
let summary = "Mesh shard operation";
let description = [{
MeshShard op
A reviewer (Contributor) commented:

More descriptive, please, with some examples... Same for all other ops.

@wooseokTT (author) replied:

Sure. I will give more detailed descriptions of the ops...

@wooseokTT (author) replied:

Added a detailed description. Let me know if you need further changes.

Comment on lines 1061 to 1065
template <typename srcOpTy>
LogicalResult getReduceType(srcOpTy &srcOp, ReduceType &reduceType) {
if constexpr (!std::is_same<srcOpTy, mlir::stablehlo::AllReduceOp>::value) {
return failure();
}
@mtopalovicTT (Contributor) commented Nov 22, 2024:

Would it make sense to just specialize this function for ReduceOp, so you would get a compile error instead of a pass error at runtime?

A reviewer (Contributor) replied:

I think there are multiple stablehlo multi-device funcs that could reuse this function in the future. @wooseokTT, feel free to correct me if I'm wrong.

@wooseokTT (author) replied:

Yes. We are planning to land further CCL ops, including reduce_scatter, that will use this function in the near future. The computation ops will be embedded into the all_reduce/reduce_scatter ops as a computation-type attribute, and they will be gone from the MLIR. So the pass error makes sense to me.

lib/Conversion/StableHLOToTTIR/StableHLOToTTIRPatterns.cpp (outdated review thread, resolved)
@mtopalovicTT (Contributor) left a comment:

@nsmithtt @wooseokTT

Could you provide a more detailed PR description or perhaps a short document outlining what this PR introduces? Since we’re introducing new concepts into TTIR, it would be really helpful to have a reference that explains the ideas around multi-device functionality and some implementation details. This would make it easier for those who are less familiar with the changes to understand the context and goals.

@wooseokTT force-pushed the wooseok/add_stablehlo_ccl_op_support branch from 6b8d573 to 28bcbe8 on November 23, 2024 at 00:18
  1. TT_Reduce_Type is created to share computation type with TTNN dialect
  2. AllReduceOp in TTIR is introduced to accommodate the StableHLO
     all_reduce op
  3. MeshShardOp in TTIR is introduced to capture GSPMD custom sharding
  4. Realistic test cases are added from JAX/PJRT output

The current version of importing targets GSPMD input, but our future
plans mainly focus on supporting Shardy-based JAX/PJRT output.
@wooseokTT force-pushed the wooseok/add_stablehlo_ccl_op_support branch from 28bcbe8 to 3b9531e on November 23, 2024 at 00:22
@wooseokTT (author) commented Nov 25, 2024:

(quoting @mtopalovicTT's request above for a more detailed PR description)

@mtopalovicTT I updated the PR description with the details. Let me know if you need any further description. TT-MLIR is actively evolving, so I believe further PRs will spell out more concrete concepts and details.

@mtopalovicTT (Contributor) replied:

@wooseokTT Thanks, the PR description is awesome. This makes things a lot clearer.

@wooseokTT merged commit 3d029b6 into main on Nov 26, 2024
18 checks passed
Development

Successfully merging this pull request may close these issues:
  • Push Jax test through
  • Add CCL ops to TTIR Dialect

3 participants