Refactor TTIRLayoutOperandsRewriter pass (#365)
Change the pass to pattern match on the DPS interface instead; it's a generic rewrite that applies to all DPS ops anyway.

Make the initial memory space programmable; if set to device memory, the runtime will automatically upload the tensors there.
nsmithtt authored Aug 13, 2024
1 parent 4430394 commit 28c278b
Showing 27 changed files with 85 additions and 130 deletions.
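As a quick orientation before the diffs, here is a minimal sketch of what "pattern match on the DPS interface" means in MLIR terms. The class name and body are illustrative only; the authoritative implementation is the TTIRLayoutDPSOperandsRewriter added to lib/Dialect/TTIR/Transforms/Passes.cpp below.

```cpp
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"

namespace {
// Sketch: one interface-based pattern replaces the per-op
// TTIRLayoutOperandsRewriter<AddOp>, <MatmulOp>, ... instantiations.
class LayoutDPSOperandsSketch
    : public mlir::OpInterfaceRewritePattern<
          mlir::DestinationStyleOpInterface> {
public:
  using mlir::OpInterfaceRewritePattern<
      mlir::DestinationStyleOpInterface>::OpInterfaceRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::DestinationStyleOpInterface op,
                  mlir::PatternRewriter &rewriter) const final {
    // Visit every operand; DPS inits (outputs) can be told apart from plain
    // inputs through the interface, without knowing the concrete op type.
    for (mlir::OpOperand &operand : op->getOpOperands()) {
      bool isInit = op.isDpsInit(&operand);
      (void)isInit; // a real pattern would insert ttir.to_layout casts here
    }
    return mlir::failure(); // sketch only: no rewrite performed
  }
};
} // namespace
```

Registering one such pattern covers GenericOp, AddOp, MatmulOp, and every other DPS op, which is why the long per-op pattern list disappears from TTIRLayout::runOnOperation in the diff below.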
9 changes: 3 additions & 6 deletions include/ttmlir/Dialect/TTIR/IR/TTIROps.td
@@ -298,7 +298,7 @@ def TTIR_MatmulOp : TTIR_DPSOp<"matmul"> {

def AnyRankedTensorOrMemRef: AnyTypeOf<[AnyRankedTensor, AnyNon0RankedMemRef]>;

def TTIR_KernelOp : TTIR_Op<"kernel", [DestinationStyleOpInterface, AttrSizedOperandSegments]> {
def TTIR_KernelOp : TTIR_DPSOp<"kernel", [AttrSizedOperandSegments]> {
let summary = "Kernel call.";
let description = [{
A generic kernel call operation. This operation is used to pattern match by some consuming backend.
@@ -307,12 +307,9 @@ def TTIR_KernelOp : TTIR_Op<"kernel", [DestinationStyleOpInterface, AttrSizedOpe
let arguments = (ins FlatSymbolRefAttr:$op,
FlatSymbolRefAttr:$kind,
Variadic<AnyRankedTensorOrMemRef>:$inputs,
Variadic<AnyRankedTensorOrMemRef>:$outputs);
Variadic<AnyRankedTensorOrMemRef>:$outputs,
TT_OperandConstraintArrayAttr:$operand_constraints);
let results = (outs Variadic<AnyRankedTensorOrMemRef>:$results);

let extraClassDeclaration = [{
MutableOperandRange getDpsInitsMutable() { return getOutputsMutable(); }
}];
}

def TTIR_YieldOp : TTIR_Op<"yield", [Pure, ReturnLike, Terminator]> {
7 changes: 7 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -34,6 +34,13 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
let description = [{
Transition between different tensor layouts.
}];

let options = [
Option<"initMemorySpace", "init-memory-space",
"::mlir::tt::MemorySpace",
/*default=*/"::mlir::tt::MemorySpace::System",
"Set the initial memory space for tensors to start in">,
];
}

def TTIRAllocate: Pass<"ttir-allocate", "::mlir::ModuleOp"> {
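With this option in place, each backend pipeline picks where tensors start out. A minimal sketch of how the option is consumed, mirroring the pipeline changes further down (TTIRLayoutOptions, initMemorySpace, MemorySpace::DeviceL1, and createTTIRLayout come from this diff; the function name and include paths are assumptions):

```cpp
#include "mlir/Pass/PassManager.h"
#include "ttmlir/Dialect/TTIR/Transforms/Passes.h" // assumed header path

void buildExampleLayoutPipeline(mlir::OpPassManager &pm) {
  mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
  // Start tensors in device L1; the runtime then uploads them automatically.
  // MemorySpace::System keeps them host-side instead (the TTNN default below).
  layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
  pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
}
```

The TTNN pipeline keeps MemorySpace::System so tensors start host-side, while the TTMetal pipeline selects MemorySpace::DeviceL1 and lets the runtime upload them.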
76 changes: 43 additions & 33 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -98,7 +98,7 @@ class TTIRNamedToKernelRewriter : public OpRewritePattern<TTIROpTy> {

auto kernel = rewriter.create<ttir::KernelOp>(
op.getLoc(), op.getResultTypes(), kernelName, kernelKind,
op.getInputs(), op.getOutputs());
op.getInputs(), op.getOutputs(), op.getOperandConstraints());

rewriter.replaceOp(op, kernel);

@@ -343,14 +343,15 @@ inline MemorySpace uppermostMemorySpace(OperandConstraint operandConstraint) {

class TTIRLayoutTensorTypeConverter : public TypeConverter {
public:
TTIRLayoutTensorTypeConverter(MLIRContext *ctx) {
TTIRLayoutTensorTypeConverter(MLIRContext *ctx, MemorySpace initMemorySpace) {
addConversion([](Type type) { return type; });
addConversion([ctx](RankedTensorType type) -> Type {
addConversion([ctx, initMemorySpace](RankedTensorType type) -> Type {
auto layout = type.getEncoding();
if (layout) {
return type;
}
auto newLayout = LayoutAttr::get(ctx, type);
// Default to initMemorySpace, the optimizer might decide otherwise
auto newLayout = LayoutAttr::get(ctx, type, initMemorySpace);
return RankedTensorType::get(type.getShape(), type.getElementType(),
newLayout);
});
@@ -415,13 +416,12 @@ class TTIRLayoutTensorTypeRewriter : public RewritePattern {
const TypeConverter *converter;
};

static std::optional<Value>
createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
OperandConstraint operandConstraint) {
static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
Location loc, Value input,
MemorySpace desiredMemorySpace) {
auto ty = mlir::cast<RankedTensorType>(input.getType());
auto currLayout = mlir::cast<LayoutAttr>(ty.getEncoding());
auto currMemorySpace = currLayout.getMemorySpace();
auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
if (currMemorySpace == desiredMemorySpace) {
return std::nullopt;
}
@@ -440,27 +440,38 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
->getResult(0);
}

template <typename TTIROpTy>
class TTIRLayoutOperandsRewriter : public OpRewritePattern<TTIROpTy> {
static std::optional<Value>
createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
OperandConstraint operandConstraint) {
auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
return createToLayoutOp(rewriter, loc, input, desiredMemorySpace);
}

class TTIRLayoutDPSOperandsRewriter
: public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
public:
using OpRewritePattern<TTIROpTy>::OpRewritePattern;
using OpInterfaceRewritePattern<
DestinationStyleOpInterface>::OpInterfaceRewritePattern;

LogicalResult matchAndRewrite(TTIROpTy op,
LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
PatternRewriter &rewriter) const final {
if (mlir::isa<ToLayoutOp>(op.getOperation())) {
// Skip the ToLayoutOp itself
return failure();
}

assert(op->template hasTrait<TTIROp::Trait>());
auto dpsInterface = cast<DestinationStyleOpInterface>(op.getOperation());
bool modified = false;
for (auto &operand : op->getOpOperands()) {
bool isResult = dpsInterface.isDpsInit(&operand);
bool isResult = op.isDpsInit(&operand);
auto encoding =
mlir::cast<RankedTensorType>(operand.get().getType()).getEncoding();
if (not encoding) {
return failure(); // Hasn't been type converted yet
}
assert(encoding);

auto operandConstraint =
mlir::cast<OperandConstraintAttr>(
op.getOperandConstraints()[operand.getOperandNumber()])
mlir::cast<TTIROp>(op.getOperation())
.getOperandConstraints()[operand.getOperandNumber()])
.getValue();
auto desiredLayout = createToLayoutOp(rewriter, op.getLoc(),
operand.get(), operandConstraint);
@@ -495,14 +506,18 @@ class TTIRLayoutOperandsRewriter : public OpRewritePattern<TTIROpTy> {
class TTIRLayoutFuncReturnRewriter
: public OpRewritePattern<mlir::func::ReturnOp> {
public:
using OpRewritePattern<mlir::func::ReturnOp>::OpRewritePattern;
TTIRLayoutFuncReturnRewriter(MLIRContext *ctx, MemorySpace initMemorySpace)
: OpRewritePattern<mlir::func::ReturnOp>(ctx),
initMemorySpace(initMemorySpace) {}

LogicalResult matchAndRewrite(mlir::func::ReturnOp op,
PatternRewriter &rewriter) const final {
bool modified = false;
for (auto &operand : op->getOpOperands()) {
// Leave the return values in initMemorySpace, optimizer might decide
// otherwise
if (auto layout = createToLayoutOp(rewriter, op.getLoc(), operand.get(),
OperandConstraint::System);
initMemorySpace);
layout) {
rewriter.modifyOpInPlace(
op, [&]() { op.setOperand(operand.getOperandNumber(), *layout); });
@@ -511,6 +526,9 @@ class TTIRLayoutFuncReturnRewriter
}
return modified ? success() : failure();
}

private:
MemorySpace initMemorySpace;
};

class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
@@ -519,7 +537,8 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {

void runOnOperation() final {
{
TTIRLayoutTensorTypeConverter typeConverter(&getContext());
TTIRLayoutTensorTypeConverter typeConverter(&getContext(),
initMemorySpace);
RewritePatternSet patterns(&getContext());
patterns.add<TTIRLayoutTensorTypeRewriter>(typeConverter, &getContext());
FrozenRewritePatternSet patternSet(std::move(patterns));
@@ -530,18 +549,9 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
}
{
RewritePatternSet patterns(&getContext());
patterns.add<
TTIRLayoutOperandsRewriter<GenericOp>,
TTIRLayoutOperandsRewriter<AddOp>,
TTIRLayoutOperandsRewriter<MultiplyOp>,
TTIRLayoutOperandsRewriter<SubtractOp>,
TTIRLayoutOperandsRewriter<GreaterEqualOp>,
TTIRLayoutOperandsRewriter<ReluOp>, TTIRLayoutOperandsRewriter<SumOp>,
TTIRLayoutOperandsRewriter<MeanOp>,
TTIRLayoutOperandsRewriter<SoftmaxOp>,
TTIRLayoutOperandsRewriter<TransposeOp>,
TTIRLayoutOperandsRewriter<MatmulOp>, TTIRLayoutFuncReturnRewriter>(
&getContext());
patterns.add<TTIRLayoutDPSOperandsRewriter>(&getContext());
patterns.add<TTIRLayoutFuncReturnRewriter>(&getContext(),
initMemorySpace);
FrozenRewritePatternSet patternSet(std::move(patterns));
if (failed(applyPatternsAndFoldGreedily(getOperation(), patternSet))) {
signalPassFailure();
2 changes: 2 additions & 0 deletions lib/Dialect/TTMetal/Transforms/Passes.cpp
@@ -255,6 +255,8 @@ void createTTIRToTTMetalBackendPipeline(OpPassManager &pm) {
pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc());
pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
pm.addPass(mlir::tt::ttir::createTTIRGeneric());
mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
pm.addPass(mlir::tt::ttir::createTTIRLayout());
pm.addPass(mlir::tt::ttir::createTTIRGenericRegionOperandsToMemref());
pm.addPass(mlir::tt::ttir::createTTIRAllocate());
4 changes: 3 additions & 1 deletion lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -17,7 +17,9 @@ void createTTIRToTTNNBackendPipeline(
OpPassManager &pm, const TTIRToTTNNBackendPipelineOptions &options) {
pm.addPass(mlir::tt::ttir::createTTIRLoadSystemDesc());
pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
pm.addPass(mlir::tt::ttir::createTTIRLayout());
mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
layoutOptions.initMemorySpace = mlir::tt::MemorySpace::System;
pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));

if (options.gridSetPassEnabled) {
ttir::TTIRGridSetOptions gridSetOptions;
15 changes: 8 additions & 7 deletions test/ttmlir/Dialect/TTIR/test_allocate.mlir
@@ -1,12 +1,13 @@
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttir-allocate %s | FileCheck %s
// RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-allocate %s | FileCheck %s
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
#l1_ = #tt.memory_space<l1>
#layout = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
module attributes {tt.device = #tt.device<#tt.grid<8x8, (d0, d1) -> (0, d0, d1)>, [0]>, tt.system_desc = #tt.system_desc<[{arch = <wormhole_b0>, grid = 8x8, l1_size = 1499136, num_dram_channels = 12, dram_channel_size = 1073741824, noc_l1_address_align_bytes = 16, pcie_address_align_bytes = 32, noc_dram_address_align_bytes = 32}], [0], [<pcie|host_mmio>], [<0, 0, 0, 0>]>} {
func.func @forward(%arg0: tensor<64x128xf32, #layout>, %arg1: tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout> {
// CHECK: %[[C:.*]] = "ttir.alloc"[[C:.*]]
// CHECK-NOT: %[[C:.*]] = tensor.empty() : tensor<64x128xf32>
%0 = tensor.empty() : tensor<64x128xf32>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: "ttir.dealloc"[[C:.*]]
return %1 : tensor<64x128xf32>
%0 = tensor.empty() : tensor<64x128xf32, #layout>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>, tensor<64x128xf32, #layout>) -> tensor<64x128xf32, #layout>
return %1 : tensor<64x128xf32, #layout>
}
}
6 changes: 2 additions & 4 deletions test/ttmlir/Dialect/TTIR/test_grid_set.mlir
@@ -3,10 +3,8 @@
module attributes {} {
func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: #layout2 = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
// CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #layout2>
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
// CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
return %1 : tensor<64x128xf32>
}
4 changes: 1 addition & 3 deletions test/ttmlir/Dialect/TTIR/test_layout.mlir
@@ -2,10 +2,8 @@
#any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
module attributes {} {
func.func @forward(%arg0: tensor<8x64x128xf32>, %arg1: tensor<8x64x128xf32>) -> tensor<8x64x128xf32> {
// CHECK: %[[C:.*]] = tensor.empty() : tensor<8x64x128xf32, #layout>
%0 = tensor.empty() : tensor<8x64x128xf32>
// CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.to_layout"[[C:.*]]
// CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]]
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<8x64x128xf32>, tensor<8x64x128xf32>, tensor<8x64x128xf32>) -> tensor<8x64x128xf32>
return %1 : tensor<8x64x128xf32>
}
6 changes: 0 additions & 6 deletions test/ttmlir/Dialect/TTMetal/simple_eltwise.mlir
@@ -3,22 +3,16 @@

func.func @multiply(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]]
// CHECK: %[[C:.*]] = "ttmetal.host_write"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
%1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: "ttmetal.dealloc"[[C:.*]]
// CHECK: %[[C:.*]] = "ttmetal.host_read"[[C:.*]]
return %1 : tensor<64x128xf32>
}

func.func @add(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
// CHECK: %[[C:.*]] = "ttmetal.alloc"[[C:.*]]
// CHECK: %[[C:.*]] = "ttmetal.host_write"[[C:.*]]
%0 = tensor.empty() : tensor<64x128xf32>
// CHECK: %[[C:.*]] = "ttmetal.dispatch"[[C:.*]]
%1 = "ttir.add"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
// CHECK: "ttmetal.dealloc"[[C:.*]]
// CHECK: %[[C:.*]] = "ttmetal.host_read"[[C:.*]]
return %1 : tensor<64x128xf32>
}
11 changes: 5 additions & 6 deletions test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
@@ -3,18 +3,17 @@
#loc = loc("test_ops.py:17_0_0":0:0)
module @pybuda_graph attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
// CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
// CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
// CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
// CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
} loc(#loc)
} loc(#loc)
@@ -3,19 +3,19 @@
#loc = loc("test_ops.py:17_0_0":0:0)
module @pybuda_graph attributes {} {
func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
// CHECK: #layout1 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
// CHECK: #layout2 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
// CHECK: #layout3 = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
// CHECK: #[[LAYOUT_0:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
// CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
// CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
%0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
%2 = tensor.empty() : tensor<1x32x32xf32> loc(#loc6)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout2>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
%3 = "ttir.add"(%1, %arg0, %2) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc6)
%4 = tensor.empty() : tensor<1x32x32xf32> loc(#loc7)
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #layout3>
// CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_2]]>
%5 = "ttir.add"(%arg2, %arg1, %4) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc7)
// CHECK: return %20, %22 : tensor<1x32x32xf32, #layout1>, tensor<1x32x32xf32, #layout1>
// CHECK: return %[[R0:.*]], %[[R1:.*]] : tensor<1x32x32xf32, #[[LAYOUT_0]]>, tensor<1x32x32xf32, #[[LAYOUT_0]]>
return %3, %5 : tensor<1x32x32xf32>, tensor<1x32x32xf32> loc(#loc4)
} loc(#loc)
} loc(#loc)