From 0b83c39218dbfd8bfa5242b24a3bfda352188ddb Mon Sep 17 00:00:00 2001
From: Nick Smith <127986401+nsmithtt@users.noreply.github.com>
Date: Sat, 31 Aug 2024 07:36:58 -0700
Subject: [PATCH] Add a new TTIR Layout pass option defaultMemorySpace (#564)

---
 .../ttmlir/Dialect/TTIR/Transforms/Passes.td  |  4 ++
 lib/Dialect/TTIR/Transforms/Passes.cpp        | 47 +++++++++++++++----
 lib/Dialect/TTMetal/Transforms/Passes.cpp     |  1 +
 lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp  |  1 +
 test/ttmlir/Dialect/TTIR/test_grid_set.mlir   |  2 +-
 .../Dialect/TTNN/multiple_add_with_loc.mlir   |  2 +-
 .../multiple_add_with_loc_grid_override.mlir  |  4 +-
 test/ttmlir/Dialect/TTNN/simple_matmul.mlir   |  2 +-
 .../Dialect/TTNN/ttir_to_ttnn_pipeline.mlir   |  2 +-
 .../ttir_to_ttnn_pipeline_custom_opt.mlir     |  2 +-
 10 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
index de7ac591a..c5a67e76c 100644
--- a/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
+++ b/include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -47,6 +47,10 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
            "::mlir::tt::MemorySpace",
            /*default=*/"::mlir::tt::MemorySpace::System",
            "Set the initial memory space for tensors to start in">,
+    Option<"defaultMemorySpace", "default-memory-space",
+           "::mlir::tt::MemorySpace",
+           /*default=*/"::mlir::tt::MemorySpace::DeviceDRAM",
+           "Set the default memory space for layout pass to prefer for operation operands, if not constrained">,
   ];
 }
 
diff --git a/lib/Dialect/TTIR/Transforms/Passes.cpp b/lib/Dialect/TTIR/Transforms/Passes.cpp
index bf303f65e..9e77b4c66 100644
--- a/lib/Dialect/TTIR/Transforms/Passes.cpp
+++ b/lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -420,13 +420,31 @@ inline MemorySpace getMemorySpace(RankedTensorType ty) {
   return getMemorySpace(layout);
 }
 
-inline MemorySpace uppermostMemorySpace(OperandConstraint operandConstraint) {
-  if (bitEnumContainsAny(operandConstraint, OperandConstraint::L1)) {
-    return MemorySpace::DeviceL1;
+inline OperandConstraint
+memorySpaceAsOperandConstraint(MemorySpace memorySpace) {
+  switch (memorySpace) {
+  case MemorySpace::System:
+  case MemorySpace::SystemMMIO:
+    return OperandConstraint::System;
+  case MemorySpace::DeviceDRAM:
+    return OperandConstraint::DRAM;
+  case MemorySpace::DeviceL1:
+    return OperandConstraint::L1;
+  }
+}
+
+inline MemorySpace getLegalMemorySpace(OperandConstraint operandConstraint,
+                                       MemorySpace defaultMemorySpace) {
+  if (bitEnumContainsAny(operandConstraint,
+                         memorySpaceAsOperandConstraint(defaultMemorySpace))) {
+    return defaultMemorySpace;
   }
   if (bitEnumContainsAny(operandConstraint, OperandConstraint::DRAM)) {
     return MemorySpace::DeviceDRAM;
   }
+  if (bitEnumContainsAny(operandConstraint, OperandConstraint::L1)) {
+    return MemorySpace::DeviceL1;
+  }
   return MemorySpace::System;
 }
 
@@ -547,8 +565,10 @@ static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
 
 static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,
                                              Location loc, Value input,
-                                             OperandConstraint operandConstraint) {
-  auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
+                                             OperandConstraint operandConstraint,
+                                             MemorySpace defaultMemorySpace) {
+  auto desiredMemorySpace =
+      getLegalMemorySpace(operandConstraint, defaultMemorySpace);
   bool tiled =
       !bitEnumContainsAny(operandConstraint, OperandConstraint::Scalar);
   return createToLayoutOp(rewriter, loc, input, desiredMemorySpace, tiled);
@@ -557,8 +577,10 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
 class TTIRLayoutDPSOperandsRewriter
     : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
 public:
-  using OpInterfaceRewritePattern<
-      DestinationStyleOpInterface>::OpInterfaceRewritePattern;
+  TTIRLayoutDPSOperandsRewriter(MLIRContext *ctx,
+                                MemorySpace defaultMemorySpace)
+      : OpInterfaceRewritePattern(ctx),
+        defaultMemorySpace(defaultMemorySpace) {}
 
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const final {
@@ -582,8 +604,9 @@
               mlir::cast<TTIROp>(op.getOperation())
                   .getOperandConstraints()[operand.getOperandNumber()])
               .getValue();
-      auto desiredLayout = createToLayoutOp(rewriter, op.getLoc(),
-                                            operand.get(), operandConstraint);
+      auto desiredLayout =
+          createToLayoutOp(rewriter, op.getLoc(), operand.get(),
+                           operandConstraint, defaultMemorySpace);
 
       if (desiredLayout) {
         rewriter.modifyOpInPlace(op, [&]() {
@@ -599,6 +622,9 @@
 
     return modified ? success() : failure();
   }
+
+private:
+  MemorySpace defaultMemorySpace;
 };
 
 class TTIRLayoutFuncReturnRewriter
@@ -650,7 +676,8 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
     }
     {
      RewritePatternSet patterns(&getContext());
-      patterns.add<TTIRLayoutDPSOperandsRewriter>(&getContext());
+      patterns.add<TTIRLayoutDPSOperandsRewriter>(&getContext(),
+                                                  defaultMemorySpace);
       patterns.add<TTIRLayoutFuncReturnRewriter>(&getContext(),
                                                  initMemorySpace);
       FrozenRewritePatternSet patternSet(std::move(patterns));
diff --git a/lib/Dialect/TTMetal/Transforms/Passes.cpp b/lib/Dialect/TTMetal/Transforms/Passes.cpp
index 4146db7c2..76f7763a8 100644
--- a/lib/Dialect/TTMetal/Transforms/Passes.cpp
+++ b/lib/Dialect/TTMetal/Transforms/Passes.cpp
@@ -840,6 +840,7 @@ void createTTIRToTTMetalBackendPipeline(OpPassManager &pm) {
   pm.addPass(mlir::tt::ttir::createTTIRGenericRegion());
   mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
   layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
+  layoutOptions.defaultMemorySpace = mlir::tt::MemorySpace::DeviceL1;
   pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
   pm.addPass(mlir::tt::ttir::createTTIRGenericRegionOperandsToMemref());
   pm.addPass(mlir::tt::ttir::createTTIRAllocate());
diff --git a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
index cb78dfd68..56c05e8fe 100644
--- a/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
+++ b/lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -23,6 +23,7 @@ void createTTIRToTTNNBackendPipeline(
   pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
   mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
   layoutOptions.initMemorySpace = mlir::tt::MemorySpace::System;
+  layoutOptions.defaultMemorySpace = mlir::tt::MemorySpace::DeviceDRAM;
   pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
 
   if (options.gridSetPassEnabled) {
diff --git a/test/ttmlir/Dialect/TTIR/test_grid_set.mlir b/test/ttmlir/Dialect/TTIR/test_grid_set.mlir
index bf6eae61e..0860ff4da 100644
--- a/test/ttmlir/Dialect/TTIR/test_grid_set.mlir
+++ b/test/ttmlir/Dialect/TTIR/test_grid_set.mlir
@@ -3,7 +3,7 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>>
     // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
index 5ba74e6f6..a8616f152 100644
--- a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
@@ -3,7 +3,7 @@
 #loc = loc("test_ops.py:17_0_0":0:0)
 module @pybuda_graph attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
     // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
diff --git a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
index ae356c481..adf62660b 100644
--- a/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
+++ b/test/ttmlir/Dialect/TTNN/multiple_add_with_loc_grid_override.mlir
@@ -4,8 +4,8 @@
 module @pybuda_graph attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
     // CHECK: #[[LAYOUT_0:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
-    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #dram>>
+    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
     // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
diff --git a/test/ttmlir/Dialect/TTNN/simple_matmul.mlir b/test/ttmlir/Dialect/TTNN/simple_matmul.mlir
index 992b0c21d..f8ee937e7 100644
--- a/test/ttmlir/Dialect/TTNN/simple_matmul.mlir
+++ b/test/ttmlir/Dialect/TTNN/simple_matmul.mlir
@@ -1,6 +1,6 @@
 // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<any_device_tile>
-// CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>>
+// CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
     %0 = tensor.empty() : tensor<64x96xbf16>
diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
index 00c67542a..cfdfde2d1 100644
--- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
+++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
@@ -2,7 +2,7 @@
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>>
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
diff --git a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
index 7b1d1ee47..e1acc7c80 100644
--- a/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
+++ b/test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline_custom_opt.mlir
@@ -2,7 +2,7 @@
 #any_device = #tt.operand_constraint<any_device>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #dram>>
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>