Enable conversion of all_reduce and GSPMD custom_op into TTIR dialect (#1351)

1. TT_ReduceType is created to share the computation type with the TTNN dialect
2. AllReduceOp in TTIR is introduced to accommodate the StableHLO all_reduce op
3. MeshShardOp in TTIR is introduced to capture GSPMD custom sharding
4. Realistic test cases are added from JAX/PJRT output

The current version of the importer targets GSPMD input, but our future
plans mainly focus on supporting Shardy-based JAX/PJRT output.
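
For reference, a minimal sketch of the GSPMD-annotated StableHLO this
conversion consumes (not taken from this commit's tests): @Sharding and
@SPMDFullToShardShape are the standard GSPMD custom_call markers, and the
sharding annotation below is a hypothetical 2x4 device split.

  // Hypothetical GSPMD input: annotate %arg0 with a 2x4 sharding, then
  // switch to per-device (manual) shapes.
  %0 = stablehlo.custom_call @Sharding(%arg0)
         {mhlo.sharding = "{devices=[2,4]0,1,2,3,4,5,6,7}"}
         : (tensor<8192x784xf32>) -> tensor<8192x784xf32>
  %1 = stablehlo.custom_call @SPMDFullToShardShape(%0)
         {mhlo.sharding = "{manual}"}
         : (tensor<8192x784xf32>) -> tensor<4096x196xf32>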
wooseokTT authored Nov 26, 2024
1 parent ebde568 commit 3d029b6
Showing 6 changed files with 722 additions and 21 deletions.
50 changes: 50 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsEnums.td
@@ -137,6 +137,7 @@ def TT_OperandConstraintSingleBank : I32BitEnumAttrCaseBit<"SingleBank", 7, "sin
def TT_OperandConstraintHeightSharded : I32BitEnumAttrCaseBit<"HeightSharded", 8, "height_sharded">;
def TT_OperandConstraintWidthSharded : I32BitEnumAttrCaseBit<"WidthSharded", 9, "width_sharded">;
def TT_OperandConstraintBlockSharded : I32BitEnumAttrCaseBit<"BlockSharded", 10, "block_sharded">;
def TT_OperandConstraintSystemScalar : I32BitEnumAttrCaseGroup<"SystemScalar", [TT_OperandConstraintSystem, TT_OperandConstraintScalar], "system_scalar">;
def TT_OperandConstraintAnyLayout : I32BitEnumAttrCaseGroup<"AnyLayout", [TT_OperandConstraintNone, TT_OperandConstraintInterleaved, TT_OperandConstraintSingleBank, TT_OperandConstraintHeightSharded, TT_OperandConstraintWidthSharded, TT_OperandConstraintBlockSharded], "any_layout">;
def TT_OperandConstraintAny : I32BitEnumAttrCaseGroup<"Any", [TT_OperandConstraintSystem, TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any">;
def TT_OperandConstraintAnyDevice : I32BitEnumAttrCaseGroup<"AnyDevice", [TT_OperandConstraintDRAM, TT_OperandConstraintL1, TT_OperandConstraintScalar, TT_OperandConstraintTile, TT_OperandConstraintAnyLayout], "any_device">;
@@ -155,6 +156,7 @@ def TT_OperandConstraint : I32BitEnumAttr<"OperandConstraint", "TT Operand Const
TT_OperandConstraintHeightSharded,
TT_OperandConstraintWidthSharded,
TT_OperandConstraintBlockSharded,
TT_OperandConstraintSystemScalar,
TT_OperandConstraintAnyLayout,
TT_OperandConstraintAny,
TT_OperandConstraintAnyDevice,
@@ -189,6 +191,54 @@ def TT_BufferAccess : I32BitEnumAttr<"BufferAccess", "TT Buffer Access",
let cppNamespace = "::mlir::tt";
}

def TT_ReduceType_Sum : I32EnumAttrCase<"Sum", 0, "sum">;
def TT_ReduceType_Mean : I32EnumAttrCase<"Mean", 1, "mean">;
def TT_ReduceType_Max : I32EnumAttrCase<"Max", 2, "max">;
def TT_ReduceType_Min : I32EnumAttrCase<"Min", 3, "min">;
def TT_ReduceType_Std : I32EnumAttrCase<"Std", 4, "std">;
def TT_ReduceType_Var : I32EnumAttrCase<"Var", 5, "var">;

def TT_ReduceType: I32EnumAttr<"ReduceType", "TT Reduce Type",
[
TT_ReduceType_Sum,
TT_ReduceType_Mean,
TT_ReduceType_Max,
TT_ReduceType_Min,
TT_ReduceType_Std,
TT_ReduceType_Var,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
}

def TT_MeshShardDirection_FullToShard : I32EnumAttrCase<"FullToShard", 0, "full_to_shard">;
def TT_MeshShardDirection_ShardToFull : I32EnumAttrCase<"ShardToFull", 1, "shard_to_full">;

def TT_MeshShardDirection: I32EnumAttr<"MeshShardDirection", "TT MeshShardDirection",
[
TT_MeshShardDirection_FullToShard,
TT_MeshShardDirection_ShardToFull,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
}

def TT_MeshShardType_Manual : I32EnumAttrCase<"Manual", 0, "manual">;
def TT_MeshShardType_Replicate : I32EnumAttrCase<"Replicate", 1, "replicate">;
def TT_MeshShardType_Maximal : I32EnumAttrCase<"Maximal", 2, "maximal">;
def TT_MeshShardType_Devices : I32EnumAttrCase<"Devices", 3, "devices">;

def TT_MeshShardType: I32EnumAttr<"MeshShardType", "TT MeshShardType",
[
TT_MeshShardType_Manual,
TT_MeshShardType_Replicate,
TT_MeshShardType_Maximal,
TT_MeshShardType_Devices,
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::tt";
}

def TT_CPURoleHost : I32EnumAttrCase<"Host", 0, "host">;
def TT_CPURoleDevice : I32EnumAttrCase<"Device", 1, "device">;

14 changes: 14 additions & 0 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -443,6 +443,20 @@ def TT_ArgumentAllocationAttr : TT_Attr<"ArgumentAllocation", "arg_alloc", []> {
let assemblyFormat = "`<` $address `,` $size `,` $memorySpace `>`";
}

def TT_ReduceTypeAttr : EnumAttr<TT_Dialect, TT_ReduceType, "reduce_type"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_ReduceTypeArrayAttr : TypedArrayAttrBase<TT_ReduceTypeAttr, "">;

def TT_MeshShardDirectionAttr : EnumAttr<TT_Dialect, TT_MeshShardDirection, "shard_direction"> {
let assemblyFormat = "`<` $value `>`";
}

def TT_MeshShardTypeAttr : EnumAttr<TT_Dialect, TT_MeshShardType, "shard_type"> {
let assemblyFormat = "`<` $value `>`";
}

//===----------------------------------------------------------------------===//
// TT type definitions
//===----------------------------------------------------------------------===//
119 changes: 98 additions & 21 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -719,27 +719,6 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {
}];
}

// CCL ops
def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> {
let summary = "All gather operation.";
let description = [{
All gather op.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
let summary = "Conv2d operation.";
let description = [{
@@ -1317,4 +1296,102 @@ def TTIR_YieldOp : TTIR_Op<"yield", [Pure, ReturnLike, Terminator]> {
let arguments = (ins Variadic<AnyRankedTensorOrMemRef>:$values);
}

//===----------------------------------------------------------------------===//
// TTIR ccl ops
//===----------------------------------------------------------------------===//

def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> {
let summary = "All gather operation.";
let description = [{
All gather op.
}];

let arguments = (ins AnyRankedTensor:$input,
AnyRankedTensor:$output,
SI32Attr:$dim,
TT_OperandConstraintArrayAttr:$operand_constraints);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}
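
A minimal usage sketch of ttir.all_gather, assuming the generic op syntax: a
hypothetical gather of four 2048-row shards along dim 0, with
operand_constraints elided as in the mesh_shard examples below.

  %1 = "ttir.all_gather"(%arg0, %0) <{dim = 0 : si32, ...}>
         : (tensor<2048x784xf32>, tensor<8192x784xf32>) -> tensor<8192x784xf32>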

def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> {
let summary = "AllReduce operation.";
let description = [{
AllReduce op.
}];

let arguments = (ins
Variadic<AnyRankedTensor>:$inputs,
AnyRankedTensor:$output,
I64ElementsAttr:$replica_groups,
SI32Attr:$dim,
OptionalAttr<SI32Attr>:$channel_handle,
UnitAttr:$use_global_device_ids,
TT_ReduceTypeAttr:$reduce_type,
TT_OperandConstraintArrayAttr:$operand_constraints
);

let results = (outs Variadic<AnyRankedTensor>:$results);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}
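
A minimal usage sketch of ttir.all_reduce, assuming the generic op syntax: the
replica_groups value below is a hypothetical single group of four devices
reducing with sum, and the remaining attributes are elided.

  %1 = "ttir.all_reduce"(%arg0, %0) <{
         replica_groups = dense<[[0, 1, 2, 3]]> : tensor<1x4xi64>,
         dim = 1 : si32,
         reduce_type = #tt.reduce_type<sum>, ...}>
         : (tensor<4096x16384xf32>, tensor<4096x16384xf32>) -> tensor<4096x16384xf32>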

def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> {
let summary = "Mesh shard operation.";
let description = [{
    MeshShard op shards the inputs (FullToShard) or concatenates the outputs (ShardToFull) for CCL ops.

    The shard_direction attribute determines whether to shard or to concatenate.

    The shard_type attribute determines how to shard or concatenate:
manual: no sharding
replicate: all devices have identical data
maximal: only one device contains full data
devices: shard_shape determines sharded dimensions

    For example, on a 2x4 mesh, the following op shards arg0 into 8 slices, with
    rows divided by 2 and columns divided by 4.

%1 = "ttir.mesh_shard"(%arg0, %0) <
{... shard_direction = #tt.shard_direction<full_to_shard>,
shard_shape = #tt.grid<2x4>,
shard_type = #tt.shard_type<devices>}> : (tensor<8192x784xf32>, ...) -> tensor<4096x196xf32>

    Conversely, the following op concatenates %4 into a single tensor by joining
    one tensor from the top row with one from the bottom row.

%6 = "ttir.mesh_shard"(%4, %5) <
{..., shard_direction = #tt.shard_direction<shard_to_full>,
shard_shape = #tt.grid<2x1>,
shard_type = #tt.shard_type<devices>}> : (tensor<4096x16384xf32>, ...) -> tensor<8192x16384xf32>
}];

let arguments = (ins
AnyRankedTensor:$input,
AnyRankedTensor:$output,
TT_MeshShardTypeAttr:$shard_type,
TT_MeshShardDirectionAttr:$shard_direction,
TT_GridAttr:$shard_shape,
TT_OperandConstraintArrayAttr:$operand_constraints
);

let results = (outs AnyRankedTensor:$result);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];

let hasVerifier = 1;
}

#endif