Add a new TTIR Layout pass option defaultMemorySpace (#564)

nsmithtt authored Aug 31, 2024
1 parent 19c5407 commit 0b83c39
Showing 10 changed files with 50 additions and 17 deletions.
4 changes: 4 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -47,6 +47,10 @@ def TTIRLayout: Pass<"ttir-layout", "::mlir::ModuleOp"> {
"::mlir::tt::MemorySpace",
/*default=*/"::mlir::tt::MemorySpace::System",
"Set the initial memory space for tensors to start in">,
Option<"defaultMemorySpace", "default-memory-space",
"::mlir::tt::MemorySpace",
/*default=*/"::mlir::tt::MemorySpace::DeviceDRAM",
"Set the default memory space for layout pass to prefer for operation operands, if not constrained">,
];
}
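Since the option is registered on the pass itself, it should also be reachable from ttmlir-opt. A hypothetical RUN line in the style of the tests below (the "l1" spelling of the enum value is an assumption, not taken from this commit):

// RUN: ttmlir-opt --ttir-layout="default-memory-space=l1" %s | FileCheck %s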

47 changes: 37 additions & 10 deletions lib/Dialect/TTIR/Transforms/Passes.cpp
@@ -420,13 +420,31 @@ inline MemorySpace getMemorySpace(RankedTensorType ty) {
   return getMemorySpace(layout);
 }
 
-inline MemorySpace uppermostMemorySpace(OperandConstraint operandConstraint) {
-  if (bitEnumContainsAny(operandConstraint, OperandConstraint::L1)) {
-    return MemorySpace::DeviceL1;
+inline OperandConstraint
+memorySpaceAsOperandConstraint(MemorySpace memorySpace) {
+  switch (memorySpace) {
+  case MemorySpace::System:
+  case MemorySpace::SystemMMIO:
+    return OperandConstraint::System;
+  case MemorySpace::DeviceDRAM:
+    return OperandConstraint::DRAM;
+  case MemorySpace::DeviceL1:
+    return OperandConstraint::L1;
+  }
+}
+
+inline MemorySpace getLegalMemorySpace(OperandConstraint operandConstraint,
+                                       MemorySpace defaultMemorySpace) {
+  if (bitEnumContainsAny(operandConstraint,
+                         memorySpaceAsOperandConstraint(defaultMemorySpace))) {
+    return defaultMemorySpace;
   }
   if (bitEnumContainsAny(operandConstraint, OperandConstraint::DRAM)) {
     return MemorySpace::DeviceDRAM;
   }
+  if (bitEnumContainsAny(operandConstraint, OperandConstraint::L1)) {
+    return MemorySpace::DeviceL1;
+  }
   return MemorySpace::System;
 }
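The old uppermostMemorySpace always preferred L1; getLegalMemorySpace instead prefers the pass-wide default whenever the operand's constraint admits it, then falls back DRAM, then L1, then System. A minimal standalone sketch of that precedence, with toy enums and a plain bitmask test standing in for the dialect's OperandConstraint and MLIR's bitEnumContainsAny:

#include <cassert>
#include <cstdint>

enum class MemorySpace { System, DeviceDRAM, DeviceL1 };
enum class OperandConstraint : uint32_t { System = 1, DRAM = 2, L1 = 4 };

// Toy stand-in for MLIR's bitEnumContainsAny.
static bool containsAny(OperandConstraint c, OperandConstraint flag) {
  return (static_cast<uint32_t>(c) & static_cast<uint32_t>(flag)) != 0;
}

static OperandConstraint asConstraint(MemorySpace ms) {
  switch (ms) {
  case MemorySpace::System:
    return OperandConstraint::System;
  case MemorySpace::DeviceDRAM:
    return OperandConstraint::DRAM;
  case MemorySpace::DeviceL1:
    return OperandConstraint::L1;
  }
  return OperandConstraint::System; // unreachable
}

static MemorySpace getLegalMemorySpace(OperandConstraint c, MemorySpace dflt) {
  if (containsAny(c, asConstraint(dflt)))
    return dflt; // the constraint admits the pass-wide default
  if (containsAny(c, OperandConstraint::DRAM))
    return MemorySpace::DeviceDRAM;
  if (containsAny(c, OperandConstraint::L1))
    return MemorySpace::DeviceL1;
  return MemorySpace::System;
}

int main() {
  // An "any device" operand (DRAM|L1) follows the default...
  auto anyDevice = static_cast<OperandConstraint>(2 | 4);
  assert(getLegalMemorySpace(anyDevice, MemorySpace::DeviceDRAM) ==
         MemorySpace::DeviceDRAM);
  // ...while an L1-only constraint overrides a DRAM default.
  assert(getLegalMemorySpace(OperandConstraint::L1, MemorySpace::DeviceDRAM) ==
         MemorySpace::DeviceL1);
  return 0;
}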

@@ -547,8 +565,10 @@ static std::optional<Value> createToLayoutOp(PatternRewriter &rewriter,

 static std::optional<Value>
 createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
-                 OperandConstraint operandConstraint) {
-  auto desiredMemorySpace = uppermostMemorySpace(operandConstraint);
+                 OperandConstraint operandConstraint,
+                 MemorySpace defaultMemorySpace) {
+  auto desiredMemorySpace =
+      getLegalMemorySpace(operandConstraint, defaultMemorySpace);
   bool tiled =
       !bitEnumContainsAny(operandConstraint, OperandConstraint::Scalar);
   return createToLayoutOp(rewriter, loc, input, desiredMemorySpace, tiled);
@@ -557,8 +577,10 @@ createToLayoutOp(PatternRewriter &rewriter, Location loc, Value input,
 class TTIRLayoutDPSOperandsRewriter
     : public OpInterfaceRewritePattern<DestinationStyleOpInterface> {
 public:
-  using OpInterfaceRewritePattern<
-      DestinationStyleOpInterface>::OpInterfaceRewritePattern;
+  TTIRLayoutDPSOperandsRewriter(MLIRContext *ctx,
+                                MemorySpace defaultMemorySpace)
+      : OpInterfaceRewritePattern<DestinationStyleOpInterface>(ctx),
+        defaultMemorySpace(defaultMemorySpace) {}
 
   LogicalResult matchAndRewrite(DestinationStyleOpInterface op,
                                 PatternRewriter &rewriter) const final {
@@ -582,8 +604,9 @@ class TTIRLayoutDPSOperandsRewriter
               mlir::cast<TTIROp>(op.getOperation())
                   .getOperandConstraints()[operand.getOperandNumber()])
               .getValue();
-      auto desiredLayout = createToLayoutOp(rewriter, op.getLoc(),
-                                            operand.get(), operandConstraint);
+      auto desiredLayout =
+          createToLayoutOp(rewriter, op.getLoc(), operand.get(),
+                           operandConstraint, defaultMemorySpace);
 
       if (desiredLayout) {
         rewriter.modifyOpInPlace(op, [&]() {
@@ -599,6 +622,9 @@ class TTIRLayoutDPSOperandsRewriter

     return modified ? success() : failure();
   }
+
+private:
+  MemorySpace defaultMemorySpace;
 };
 
 class TTIRLayoutFuncReturnRewriter
@@ -650,7 +676,8 @@ class TTIRLayout : public impl::TTIRLayoutBase<TTIRLayout> {
   }
   {
     RewritePatternSet patterns(&getContext());
-    patterns.add<TTIRLayoutDPSOperandsRewriter>(&getContext());
+    patterns.add<TTIRLayoutDPSOperandsRewriter>(&getContext(),
+                                                defaultMemorySpace);
     patterns.add<TTIRLayoutFuncReturnRewriter>(&getContext(),
                                                initMemorySpace);
     FrozenRewritePatternSet patternSet(std::move(patterns));
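Because the pattern set is frozen before the rewrite driver runs, per-pass configuration such as defaultMemorySpace has to be captured when each pattern is constructed, as above. The general shape of that idiom, as a sketch with hypothetical names (ExampleRewriter and the template parameters are illustrative, not part of this commit):

#include "mlir/IR/PatternMatch.h"

// Sketch only: OpT stands in for whichever op the pattern targets, and
// MemorySpace for the type of the captured pass option.
template <typename OpT, typename MemorySpace>
struct ExampleRewriter : public mlir::OpRewritePattern<OpT> {
  ExampleRewriter(mlir::MLIRContext *ctx, MemorySpace memSpace)
      : mlir::OpRewritePattern<OpT>(ctx), memSpace(memSpace) {}

  mlir::LogicalResult
  matchAndRewrite(OpT op, mlir::PatternRewriter &rewriter) const final {
    // Consult the captured option when deciding what IR to build.
    (void)memSpace;
    return mlir::failure();
  }

private:
  MemorySpace memSpace; // captured at construction, read on every rewrite
};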
1 change: 1 addition & 0 deletions lib/Dialect/TTMetal/Transforms/Passes.cpp
@@ -840,6 +840,7 @@ void createTTIRToTTMetalBackendPipeline(OpPassManager &pm) {
   pm.addPass(mlir::tt::ttir::createTTIRGenericRegion());
   mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
   layoutOptions.initMemorySpace = mlir::tt::MemorySpace::DeviceL1;
+  layoutOptions.defaultMemorySpace = mlir::tt::MemorySpace::DeviceL1;
   pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
   pm.addPass(mlir::tt::ttir::createTTIRGenericRegionOperandsToMemref());
   pm.addPass(mlir::tt::ttir::createTTIRAllocate());
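The TTMetal pipeline pins both initMemorySpace and defaultMemorySpace to DeviceL1, so tensors start in L1 and operands stay there unless a constraint rules it out; the TTNN pipeline below instead starts tensors in System memory and defaults operands to DeviceDRAM.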
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
@@ -23,6 +23,7 @@ void createTTIRToTTNNBackendPipeline(
   pm.addPass(mlir::tt::ttir::createTTIRImplicitDevice());
   mlir::tt::ttir::TTIRLayoutOptions layoutOptions;
   layoutOptions.initMemorySpace = mlir::tt::MemorySpace::System;
+  layoutOptions.defaultMemorySpace = mlir::tt::MemorySpace::DeviceDRAM;
   pm.addPass(mlir::tt::ttir::createTTIRLayout(layoutOptions));
 
   if (options.gridSetPassEnabled) {
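With DRAM as the TTNN default, operands that the old uppermostMemorySpace would have placed in L1 now land in DRAM, which is why the FileCheck expectations in the tests below change from #l1_ to #dram.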
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTIR/test_grid_set.mlir
@@ -3,7 +3,7 @@
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
     %0 = tensor.empty() : tensor<64x128xf32>
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>>
     // CHECK: %[[C:.*]] = "ttir.multiply"[[C:.*]] -> tensor<64x128xf32, #[[LAYOUT_1]]>
     %1 = "ttir.multiply"(%arg0, %arg1, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<64x128xf32>, tensor<64x128xf32>, tensor<64x128xf32>) -> tensor<64x128xf32>
     return %1 : tensor<64x128xf32>
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/multiple_add_with_loc.mlir
@@ -3,7 +3,7 @@
 #loc = loc("test_ops.py:17_0_0":0:0)
 module @pybuda_graph attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
     // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
@@ -4,8 +4,8 @@
 module @pybuda_graph attributes {} {
   func.func @main(%arg0: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg1: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0), %arg2: tensor<1x32x32xf32> loc("test_ops.py:17_0_0":0:0)) -> (tensor<1x32x32xf32>, tensor<1x32x32xf32>) {
     // CHECK: #[[LAYOUT_0:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #system>>
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #l1_>>
-    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <4x4>, memref<8x8xf32, #dram>>
+    // CHECK: #[[LAYOUT_2:.*]] = #tt.layout<(d0, d1, d2) -> (d0 * 32 + d1, d2), undef, <8x8>, memref<4x4xf32, #dram>>
     %0 = tensor.empty() : tensor<1x32x32xf32> loc(#loc5)
     // CHECK: %[[C:.*]] = "ttnn.add"[[C:.*]] -> tensor<1x32x32xf32, #[[LAYOUT_1]]>
     %1 = "ttir.add"(%arg1, %arg2, %0) <{operandSegmentSizes = array<i32: 2, 1>, operand_constraints = [#any_device, #any_device, #any_device]}> : (tensor<1x32x32xf32>, tensor<1x32x32xf32>, tensor<1x32x32xf32>) -> tensor<1x32x32xf32> loc(#loc5)
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/simple_matmul.mlir
@@ -1,6 +1,6 @@
 // RUN: ttmlir-opt --ttir-load-system-desc --ttir-implicit-device --ttir-layout --ttnn-open-device --convert-ttir-to-ttnn %s | FileCheck %s
 #any_device_tile = #tt.operand_constraint<dram|l1|tile|any_device_tile>
-// CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #l1_>>
+// CHECK: #[[TILED_LAYOUT:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<2x4x!tt.tile<32x32, bf16>, #dram>>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>) -> tensor<64x96xbf16> {
     %0 = tensor.empty() : tensor<64x96xbf16>
2 changes: 1 addition & 1 deletion test/ttmlir/Dialect/TTNN/ttir_to_ttnn_pipeline.mlir
@@ -2,7 +2,7 @@
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <8x8>, memref<8x16xf32, #dram>>
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>
@@ -2,7 +2,7 @@
 #any_device = #tt.operand_constraint<dram|l1|scalar|tile|any_device|any_device_tile>
 module attributes {} {
   func.func @forward(%arg0: tensor<64x128xf32>, %arg1: tensor<64x128xf32>) -> tensor<64x128xf32> {
-    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #l1_>>
+    // CHECK: #[[LAYOUT_1:.*]] = #tt.layout<(d0, d1) -> (d0, d1), undef, <1x1>, memref<64x128xf32, #dram>>
     // CHECK: %[[C:.*]] = "ttnn.open_device"[[C:.*]]
     // CHECK: %[[C:.*]] = "ttnn.empty"[[C:.*]]
     %0 = tensor.empty() : tensor<64x128xf32>