[TTIR] Remove TTIR operand_constraints #1388

Merged · 1 commit · Dec 12, 2024
3 changes: 1 addition & 2 deletions docs/src/overview.md
@@ -81,11 +81,10 @@ simpler.
 So what does MLIR look like, how does it work and get parsed? The
 hierarchy of an MLIR Module is as shown:
 ```
-#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {tt.system_desc = #tt.system_desc<[<#tt.arch<wormhole_b0>, #tt.grid<8x8>>], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
-    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
+    %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
   }
 }
109 changes: 39 additions & 70 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -185,8 +185,7 @@ class TTIR_ElementwiseOp<string mnemonic, list<Trait> traits = []> :
   }];

   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
-                       Variadic<AnyRankedTensor>:$outputs,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       Variadic<AnyRankedTensor>:$outputs);
   let results = (outs Variadic<AnyRankedTensor>:$results);
 }

@@ -199,9 +198,9 @@ class TTIR_ElementwiseTernaryOp<string mnemonic, list<Trait> traits = []> :

   let builders =
   [
-    OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out, "ArrayAttr": $operand_constraints),
+    OpBuilder<(ins "Value": $first, "Value": $second, "Value": $third, "Value": $out),
     [{
-      build($_builder, $_state, {out.getType()}, {first, second, third}, out, operand_constraints);
+      build($_builder, $_state, {out.getType()}, {first, second, third}, out);
     }]>
   ];
 }
@@ -222,9 +221,9 @@ class TTIR_ElementwiseUnaryOp<string mnemonic, list<Trait> traits = []> :

   let builders =
   [
-    OpBuilder<(ins "Value": $in, "Value": $out, "ArrayAttr": $operand_constraints),
+    OpBuilder<(ins "Value": $in, "Value": $out),
     [{
-      build($_builder, $_state, {out.getType()}, in, out, operand_constraints);
+      build($_builder, $_state, {out.getType()}, in, out);
     }]>
   ];
 }
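
With the trailing constraint array gone from the base classes, a derived op definition needs no extra plumbing. A minimal sketch against the simplified `TTIR_ElementwiseUnaryOp` above — the op itself is hypothetical, not part of this PR:

```tablegen
// Hypothetical op for illustration only; not added by this PR.
// It inherits inputs/outputs from TTIR_ElementwiseUnaryOp and, after this
// change, carries no operand_constraints argument anywhere.
def TTIR_ExampleNegOp : TTIR_ElementwiseUnaryOp<"example_neg"> {
    let summary = "Eltwise negate (illustrative).";
    let description = [{
        Example definition mirroring the existing unary elementwise ops.
    }];
}
```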
@@ -408,14 +407,13 @@ class TTIR_ElementwiseUnaryWithFloatParameterOp<string mnemonic, list<Trait> traits = []> :

   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                        Variadic<AnyRankedTensor>:$outputs,
-                       F32Attr:$parameter,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       F32Attr:$parameter);

   let builders =
   [
-    OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter, "ArrayAttr": $operand_constraints),
+    OpBuilder<(ins "Value": $in, "Value": $out, "FloatAttr":$parameter),
     [{
-      build($_builder, $_state, {out.getType()}, {in}, {out}, parameter, operand_constraints);
+      build($_builder, $_state, {out.getType()}, {in}, {out}, parameter);
     }]>
   ];
 }
@@ -452,9 +450,9 @@ class TTIR_ElementwiseBinaryOp<string mnemonic, list<Trait> traits = []> :

   let builders =
   [
-    OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out, "ArrayAttr": $operand_constraints),
+    OpBuilder<(ins "Value": $lhs, "Value": $rhs, "Value": $out),
     [{
-      build($_builder, $_state, {out.getType()}, {lhs, rhs}, out, operand_constraints);
+      build($_builder, $_state, {out.getType()}, {lhs, rhs}, out);
     }]>
   ];
 }
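
The builder change ripples into every C++ call site, which simply drops the trailing `ArrayAttr`. A minimal sketch, assuming the conventional tt-mlir header path and namespace (the helper function is illustrative):

```cpp
#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" // assumed header location
#include "mlir/IR/Builders.h"

// Create a ttir.multiply via the new (lhs, rhs, out) builder. Before this
// PR the call carried a fourth operand_constraints ArrayAttr.
static mlir::Value createMultiply(mlir::OpBuilder &builder, mlir::Location loc,
                                  mlir::Value lhs, mlir::Value rhs,
                                  mlir::Value out) {
  auto op = builder.create<mlir::tt::ttir::MultiplyOp>(loc, lhs, rhs, out);
  return op->getResult(0); // elementwise ops declare variadic results
}
```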
@@ -568,8 +566,7 @@ class TTIR_ReductionOp<string mnemonic, list<Trait> traits = []> :
   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
                        BoolAttr:$keep_dim,
-                       OptionalAttr<I32ArrayAttr>:$dim_arg,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       OptionalAttr<I32ArrayAttr>:$dim_arg);

   let results = (outs AnyRankedTensor:$result);

@@ -636,8 +633,7 @@ def TTIR_EmbeddingOp : TTIR_DPSOp<"embedding"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$weight,
-                       AnyRankedTensor:$output,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       AnyRankedTensor:$output);

   let results = (outs AnyRankedTensor:$result);

@@ -657,8 +653,7 @@ def TTIR_EmbeddingBackwardOp : TTIR_DPSOp<"embedding_backward"> {
   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$weight,
                        AnyRankedTensor:$in_gradient,
-                       AnyRankedTensor:$output,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       AnyRankedTensor:$output);

   let results = (outs AnyRankedTensor:$result);

@@ -677,8 +672,7 @@ def TTIR_SoftmaxOp : TTIR_DPSOp<"softmax"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
-                       SI32Attr:$dimension,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$dimension);

   let results = (outs AnyRankedTensor:$result);

@@ -698,8 +692,7 @@ def TTIR_TransposeOp : TTIR_DPSOp<"transpose"> {
   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
                        SI32Attr:$dim0,
-                       SI32Attr:$dim1,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$dim1);

   let results = (outs AnyRankedTensor:$result);

@@ -718,8 +711,7 @@ def TTIR_ConcatOp : TTIR_DPSOp<"concat"> {

   let arguments = (ins Variadic<AnyRankedTensor>:$inputs,
                        AnyRankedTensor:$output,
-                       SI32Attr:$dim,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$dim);

   let results = (outs AnyRankedTensor:$result);

@@ -739,8 +731,7 @@ def TTIR_UpdateCacheOp : TTIR_DPSOp<"update_cache"> {
   let arguments = (ins AnyRankedTensor:$cache,
                        AnyRankedTensor:$input,
                        AnyRankedTensor:$update_index,
-                       I32Attr:$batch_offset,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       I32Attr:$batch_offset);

   let results = (outs AnyRankedTensor:$result);

@@ -759,8 +750,7 @@ def TTIR_FillCacheOp : TTIR_DPSOp<"fill_cache"> {

   let arguments = (ins AnyRankedTensor:$cache,
                        AnyRankedTensor:$input,
-                       I32Attr:$batch_offset,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       I32Attr:$batch_offset);

   let results = (outs AnyRankedTensor:$result);

@@ -779,8 +769,7 @@ def TTIR_BroadcastOp : TTIR_DPSOp<"broadcast"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
-                       I64ArrayAttr:$dimension,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       I64ArrayAttr:$dimension);

   let results = (outs AnyRankedTensor:$result);

@@ -807,8 +796,7 @@ def TTIR_Conv2dOp : TTIR_DPSOp<"conv2d"> {
                        SI32Attr:$padding_left,
                        SI32Attr:$padding_right,
                        SI32Attr:$padding_top,
-                       SI32Attr:$padding_bottom,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$padding_bottom);

   let results = (outs AnyRankedTensor:$result);

@@ -841,8 +829,7 @@ def TTIR_ConvolutionOp : TTIR_DPSOp<"convolution"> {
                        DenseBoolArrayAttr:$window_reversal,
                        TTIR_ConvolutionLayoutAttr:$convolution_layout,
                        ConfinedAttr<I64Attr, [IntPositive]>:$feature_group_count,
-                       ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count,
-                       TT_OperandConstraintArrayAttr:$operand_constraints
+                       ConfinedAttr<I64Attr, [IntPositive]>:$batch_group_count
   );

   let results = (outs AnyRankedTensor);
@@ -869,8 +856,7 @@ def TTIR_GatherOp: TTIR_DPSOp<"gather"> {
                        DenseI64ArrayAttr:$start_index_map,
                        SI64Attr:$index_vector_dim,
                        DenseI64ArrayAttr:$slice_sizes,
-                       BoolAttr:$indices_are_sorted,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       BoolAttr:$indices_are_sorted);
   let results = (outs AnyRankedTensor:$result);
   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
@@ -891,8 +877,7 @@ def TTIR_PoolingOp : TTIR_DPSOp<"pooling", [AttrSizedOperandSegments]> {
                        DenseI64ArrayAttr:$window_strides,
                        DenseI64ArrayAttr:$base_dilations,
                        DenseI64ArrayAttr:$window_dilations,
-                       DenseI64ArrayAttr:$padding,
-                       TT_OperandConstraintArrayAttr:$operand_constraints
+                       DenseI64ArrayAttr:$padding
   );

   let results = (outs Variadic<AnyRankedTensor>);
@@ -918,8 +903,7 @@ def TTIR_MaxPool2dOp : TTIR_DPSOp<"max_pool2d"> {
                        SI32Attr:$padding_left,
                        SI32Attr:$padding_right,
                        SI32Attr:$padding_top,
-                       SI32Attr:$padding_bottom,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$padding_bottom);

   let results = (outs AnyRankedTensor:$result);

@@ -938,8 +922,7 @@ def TTIR_ReshapeOp: TTIR_DPSOp<"reshape"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
-                       I32ArrayAttr:$shape,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       I32ArrayAttr:$shape);

   let results = (outs AnyRankedTensor:$result);

@@ -964,8 +947,7 @@ def TTIR_SliceOp: TTIR_DPSOp<"slice"> {
                        AnyRankedTensor:$output,
                        I32ArrayAttr:$begins,
                        I32ArrayAttr:$ends,
-                       I32ArrayAttr:$step,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       I32ArrayAttr:$step);

   let results = (outs AnyRankedTensor:$result);

@@ -991,8 +973,7 @@ def TTIR_SelectOp: TTIR_DPSOp<"select"> {
                        SI32Attr:$dim,
                        SI32Attr:$begin,
                        SI32Attr:$length,
-                       DefaultValuedOptionalAttr<SI32Attr, "0">:$stride,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       DefaultValuedOptionalAttr<SI32Attr, "0">:$stride);

   let results = (outs AnyRankedTensor:$result);

@@ -1017,8 +998,7 @@ def TTIR_IndexOp: TTIR_DPSOp<"index"> {
                        I32Attr:$dim,
                        I32Attr:$begin,
                        I32Attr:$end,
-                       I32Attr:$step,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       I32Attr:$step);

   let results = (outs AnyRankedTensor:$result);

@@ -1038,8 +1018,7 @@ def TTIR_SqueezeOp : TTIR_DPSOp<"squeeze"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
-                       SI32Attr:$dim,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$dim);

   let results = (outs AnyRankedTensor:$result);

@@ -1058,8 +1037,7 @@ def TTIR_UnsqueezeOp : TTIR_DPSOp<"unsqueeze"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
-                       SI32Attr:$dim,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$dim);

   let results = (outs AnyRankedTensor:$result);

@@ -1087,8 +1065,7 @@ def TTIR_ClampOp : TTIR_DPSOp<"clamp"> {
   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
                        F32Attr:$min,
-                       F32Attr:$max,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       F32Attr:$max);

   let extraClassDeclaration = [{
     MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
@@ -1191,8 +1168,7 @@ def TTIR_FillOp : TTIR_DPSOp<"fill", [AllShapesMatch<["value", "result"]>]> {
   }];

   let arguments = (ins AnyRankedTensor:$output,
-                       ElementsAttr:$value,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       ElementsAttr:$value);

   let results = (outs AnyRankedTensor:$result);

@@ -1217,8 +1193,7 @@ def TTIR_LinearOp : TTIR_DPSOp<"linear"> {
   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
                        Optional<AnyRankedTensor>:$bias,
-                       AnyRankedTensor:$output,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       AnyRankedTensor:$output);

   let results = (outs AnyRankedTensor:$result);

@@ -1238,8 +1213,7 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {

   let arguments = (ins AnyRankedTensor:$a,
                        AnyRankedTensor:$b,
-                       AnyRankedTensor:$output,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       AnyRankedTensor:$output);

   let results = (outs AnyRankedTensor:$result);
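
At the IR level the change is equally mechanical: DPS ops now print without any constraint attribute. An illustrative, hand-written instance (not taken from this PR's tests):

```mlir
// Illustrative only: ttir.matmul in the new form, with operands (a, b, output)
// and no operand_constraints attribute.
func.func @matmul(%a: tensor<64x128xf32>, %b: tensor<128x96xf32>) -> tensor<64x96xf32> {
  %0 = tensor.empty() : tensor<64x96xf32>
  %1 = "ttir.matmul"(%a, %b, %0) : (tensor<64x128xf32>, tensor<128x96xf32>, tensor<64x96xf32>) -> tensor<64x96xf32>
  return %1 : tensor<64x96xf32>
}
```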

@@ -1362,8 +1336,7 @@ def TTIR_ScatterOp: TTIR_DPSOp<"scatter"> {
                        I32Attr:$index_vector_dim,
                        BoolAttr:$indices_are_sorted,
                        BoolAttr:$unique_indices,
-                       AnyRankedTensor:$output,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       AnyRankedTensor:$output);

   let regions = (region SizedRegion<1>:$update_computation);

@@ -1391,8 +1364,7 @@ def TTIR_KernelOp : TTIR_DPSOp<"kernel", [AttrSizedOperandSegments]> {
   let arguments = (ins FlatSymbolRefAttr:$op,
                        FlatSymbolRefAttr:$kind,
                        Variadic<AnyRankedTensorOrMemRef>:$inputs,
-                       Variadic<AnyRankedTensorOrMemRef>:$outputs,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       Variadic<AnyRankedTensorOrMemRef>:$outputs);
   let results = (outs Variadic<AnyRankedTensorOrMemRef>:$results);
 }

@@ -1417,8 +1389,7 @@ def TTIR_AllGatherOp : TTIR_DPSOp<"all_gather"> {

   let arguments = (ins AnyRankedTensor:$input,
                        AnyRankedTensor:$output,
-                       SI32Attr:$dim,
-                       TT_OperandConstraintArrayAttr:$operand_constraints);
+                       SI32Attr:$dim);

   let results = (outs AnyRankedTensor:$result);

@@ -1442,8 +1413,7 @@ def TTIR_AllReduceOp : TTIR_DPSOp<"all_reduce"> {
                        SI32Attr:$dim,
                        OptionalAttr<SI32Attr>:$channel_handle,
                        UnitAttr:$use_global_device_ids,
-                       TT_ReduceTypeAttr:$reduce_type,
-                       TT_OperandConstraintArrayAttr:$operand_constraints
+                       TT_ReduceTypeAttr:$reduce_type
   );

   let results = (outs Variadic<AnyRankedTensor>:$results);
@@ -1490,8 +1460,7 @@ def TTIR_MeshShardOp : TTIR_DPSOp<"mesh_shard"> {
                        AnyRankedTensor:$output,
                        TT_MeshShardTypeAttr:$shard_type,
                        TT_MeshShardDirectionAttr:$shard_direction,
-                       TT_GridAttr:$shard_shape,
-                       TT_OperandConstraintArrayAttr:$operand_constraints
+                       TT_GridAttr:$shard_shape
   );

   let results = (outs AnyRankedTensor:$result);
10 changes: 0 additions & 10 deletions include/ttmlir/Dialect/TTIR/IR/TTIROpsInterfaces.td
@@ -11,16 +11,6 @@ include "ttmlir/Dialect/TT/IR/TTOpsTypes.td"
 def TTIROpInterface : OpInterface<"TTIROp"> {
   let cppNamespace = "::mlir::tt::ttir";
   let methods = [
-      InterfaceMethod<
-        /*desc=*/[{
-          Return the constraints on the operands of this operation.
-        }],
-        /*retTy=*/"::mlir::ArrayAttr",
-        /*methodName=*/"getOperandConstraints",
-        /*args=*/(ins),
-        /*methodBody=*/"",
-        /*defaultImplementation=*/""
-      >,
       InterfaceMethod<
         /*desc=*/[{
           Get the device of the current scope.
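
Any pass that queried the removed interface method has to drop that call as well. A hypothetical fragment (names illustrative) of code that stops compiling after this change:

```cpp
// Hypothetical pass fragment. The accessor below was removed together with
// the interface method in this PR; everything else on TTIROp is unchanged.
void inspect(mlir::tt::ttir::TTIROp op) {
  // No longer compiles after this PR:
  //   mlir::ArrayAttr constraints = op.getOperandConstraints();
  // Device queries and the other interface methods are untouched.
}
```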