Added unittest for getGreedyConfig method in L1InterleavedPolicy & fixed linker error

fbajraktariTT committed Dec 2, 2024
1 parent 542f3c5 commit 74ce175
Showing 8 changed files with 117 additions and 69 deletions.
30 changes: 15 additions & 15 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
@@ -55,6 +55,21 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

/**
* Retrieve the greedy OpConfig for the given base operation
* and its opsL1Usage map.
*
* @param baseOp The base operation for which the greedy configuration is
* being determined.
* @param opsL1Usage A map between the operation and its output L1 usage. All
* operations included in the opsL1Usage map must be either the baseOp or its
* operand with a legal L1 Interleaved output layout at the time of analyzing
* the baseOp.
* @return The greedy OpConfig for the baseOp.
*/
OpConfig getGreedyConfig(Operation *baseOp,
llvm::DenseMap<Operation *, L1Usage> &opsL1Usage);

void run() final;

private:
@@ -78,21 +93,6 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
return tensorL1UsageCap * usableL1CacheSize;
}

/**
* Retrieve the greedy OpConfig for the given base operation
* and its opsL1Usage map.
*
* @param baseOp The base operation for which the greedy configuration is
* being determined.
* @param opsL1Usage A map between the operation and its output L1 usage. All
* operations included in the opsL1Usage map must be either the baseOp or its
* operand with a legal L1 Interleaved output layout at the time of analyzing
* the baseOp.
* @return The greedy OpConfig for the baseOp.
*/
OpConfig getGreedyConfig(Operation *baseOp,
llvm::DenseMap<Operation *, L1Usage> &opsL1Usage);

// Precedence schedule map for each operation. It contains the order
// in which operands need to be executed for each op.
llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> precedenceMap;
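Since the declaration moves from the private section into the public API (the second hunk above removes the old private copy), callers such as the new unit test can now invoke getGreedyConfig directly. A minimal sketch of the call pattern, assuming baseOp, rootOp, l1ChainConfigs, legalLayouts, and schedule are prepared as in the test below; the field comments are an inferred reading of the test values, not documented semantics:

    // Sketch: exercising the now-public getGreedyConfig.
    llvm::DenseMap<mlir::Operation *, L1Usage> opsL1Usage;
    L1Usage usage;
    usage.outputL1Usage = 4;   // bytes this op's output would occupy in L1 (inferred)
    usage.requiredL1Usage = 0; // L1 still held by its scheduled subtree (inferred)
    opsL1Usage[baseOp] = usage;

    L1InterleavedPolicy policy(rootOp, l1ChainConfigs, legalLayouts, schedule,
                               usableL1CacheSize);
    OpConfig greedyConfig = policy.getGreedyConfig(baseOp, opsL1Usage);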
4 changes: 2 additions & 2 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
@@ -149,8 +149,8 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool hasInterleavedDRAMTensorMemoryLayout() const;
bool hasL1BufferTypeLayout() const;
bool hasDRAMBufferTypeLayout() const;
bool hasL1BufferType() const;
bool hasDRAMBufferType() const;
bool isTiled() const;
Layout getLayout() const;
Type getElementType() const;
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/CMakeLists.txt
@@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRTTNNAnalysis
MLIRTTNNPassesIncGen
MLIRTTOpsIncGen

LINK_LIBS
LINK_LIBS PUBLIC
MLIRScheduler
)
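A hedged note on the "fixed linker error" half of the commit message: with plain LINK_LIBS, the MLIRScheduler dependency stays private to MLIRTTNNAnalysis, so a target linking MLIRTTNNAnalysis directly — as the new OptimizerTests target below now does — would presumably hit unresolved scheduler symbols. LINK_LIBS PUBLIC propagates the dependency transitively.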
6 changes: 3 additions & 3 deletions lib/Dialect/TTNN/Analysis/L1InterleavedPolicy.cpp
@@ -12,7 +12,7 @@ uint64_t getOpOutputL1Usage(Operation *op, TTNNLayoutAttr opLayout,
DeviceAttr &deviceAttr) {
// In case the opLayout is not in L1 memory space, L1 memory usage is 0.
//
if (opLayout.hasDRAMBufferTypeLayout()) {
if (opLayout.hasDRAMBufferType()) {
return 0;
}

@@ -322,15 +322,15 @@ bool L1InterleavedPolicy::isAnalyzable(Operation *op) {
bool L1InterleavedPolicy::hasDRAMBufferType(Operation *op) {
return std::find_if(legalLayouts[op].begin(), legalLayouts[op].end(),
[](TTNNLayoutAttr layout) {
return layout.hasDRAMBufferTypeLayout();
return layout.hasDRAMBufferType();
}) != legalLayouts[op].end();
}

TTNNLayoutAttr L1InterleavedPolicy::getDRAMLayout(Operation *op) {
assert(hasDRAMBufferType(op));
auto dramLayoutIter = std::find_if(
legalLayouts[op].begin(), legalLayouts[op].end(),
[](TTNNLayoutAttr layout) { return layout.hasDRAMBufferTypeLayout(); });
[](TTNNLayoutAttr layout) { return layout.hasDRAMBufferType(); });
return *dramLayoutIter;
}

2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/MemoryLayoutAnalysis.cpp
@@ -43,7 +43,7 @@ filterDRAMAndL1Interleaved(
for (const auto &opLayouts : legalLayouts) {
std::vector<TTNNLayoutAttr> opL1InterleavedLayouts;
for (const auto &layout : opLayouts.second) {
if (layout.hasDRAMBufferTypeLayout() ||
if (layout.hasDRAMBufferType() ||
layout.hasInterleavedL1TensorMemoryLayout()) {
opL1InterleavedLayouts.push_back(layout);
}
10 changes: 5 additions & 5 deletions lib/Dialect/TTNN/IR/TTNNOpsAttrs.cpp
@@ -45,12 +45,12 @@ Layout TTNNLayoutAttr::getLayout() const {
}

// Check if the tensor memory buffer type is L1
bool TTNNLayoutAttr::hasL1BufferTypeLayout() const {
bool TTNNLayoutAttr::hasL1BufferType() const {
return isL1BufferType(getBufferType());
}

// Check if the tensor memory buffer type is DRAM
bool TTNNLayoutAttr::hasDRAMBufferTypeLayout() const {
bool TTNNLayoutAttr::hasDRAMBufferType() const {
return isDRAMBufferType(getBufferType());
}

@@ -63,21 +63,21 @@ bool TTNNLayoutAttr::hasShardedTensorMemoryLayout() const {

// Check if the tensor memory layout is sharded in L1 memory
bool TTNNLayoutAttr::hasShardedL1TensorMemoryLayout() const {
return hasL1BufferTypeLayout() &&
return hasL1BufferType() &&
(getMemLayout() == TensorMemoryLayout::HeightSharded ||
getMemLayout() == TensorMemoryLayout::WidthSharded ||
getMemLayout() == TensorMemoryLayout::BlockSharded);
}

// Check if the tensor memory layout is interleaved and in L1 memory
bool TTNNLayoutAttr::hasInterleavedL1TensorMemoryLayout() const {
return hasL1BufferTypeLayout() &&
return hasL1BufferType() &&
(getMemLayout() == TensorMemoryLayout::Interleaved);
}

// Check if the tensor memory layout is interleaved and in DRAM memory
bool TTNNLayoutAttr::hasInterleavedDRAMTensorMemoryLayout() const {
return hasDRAMBufferTypeLayout() &&
return hasDRAMBufferType() &&
(getMemLayout() == TensorMemoryLayout::Interleaved);
}

1 change: 1 addition & 0 deletions test/unittests/Optimizer/CMakeLists.txt
@@ -7,5 +7,6 @@ target_link_libraries(OptimizerTests
PRIVATE
MLIR
MLIRTTDialect
MLIRTTNNAnalysis
MLIRTTNNPipelines
)
131 changes: 89 additions & 42 deletions test/unittests/Optimizer/TestL1InterleavedPolicy.cpp
@@ -20,6 +20,9 @@

using namespace mlir::tt::ttnn;

constexpr int TensorDimX = 128;
constexpr int TensorDimY = 128;

class L1InterleavedPolicyBase : public ::testing::Test {
public:
mlir::MLIRContext context;
@@ -37,82 +40,78 @@ class L1InterleavedPolicyBase : public ::testing::Test {
module = mlir::ModuleOp::create(builder.getUnknownLoc());
builder.setInsertionPointToStart(&module->getBodyRegion().front());
createFuncOp();
// deviceAttr = mlir::tt::getCurrentScopeDevice(func);
deviceAttr = mlir::tt::getCurrentScopeDevice(func);
}

llvm::SmallVector<int64_t, 2> getDefaultTensorShape() { return {32, 32}; }
llvm::SmallVector<int64_t, 2> getTensorShape() {
return {TensorDimX, TensorDimY};
}

mlir::RankedTensorType
getTensorRankedType(llvm::SmallVector<int64_t, 2> tensorShape) {
return mlir::RankedTensorType::get(tensorShape, builder.getF32Type());
mlir::RankedTensorType getTensorRankedType() {
return mlir::RankedTensorType::get(getTensorShape(), builder.getF32Type());
}

mlir::Value createEmptyTensor(llvm::SmallVector<int64_t, 2> tensorShape) {
ShapeAttr shapeAttr = ShapeAttr::get(&context, tensorShape);
mlir::Value createEmptyTensor() {
ShapeAttr shapeAttr = ShapeAttr::get(&context, getTensorShape());
return builder.create<EmptyOp>(builder.getUnknownLoc(),
getTensorRankedType(tensorShape), nullptr,
shapeAttr, nullptr, nullptr, nullptr);
getTensorRankedType(), nullptr, shapeAttr,
nullptr, nullptr, nullptr);
}

mlir::func::FuncOp createFuncOp() {
mlir::SmallVector<mlir::Type> input;
input.push_back(getTensorRankedType(getDefaultTensorShape()));
input.push_back(getTensorRankedType());

mlir::SmallVector<mlir::Type> output;
output.push_back(getTensorRankedType(getDefaultTensorShape()));
output.push_back(getTensorRankedType());

auto funcType = builder.getType<mlir::FunctionType>(
mlir::TypeRange(input), mlir::TypeRange(output));
func = builder.create<mlir::func::FuncOp>(builder.getUnknownLoc(), "test",
funcType);

mlir::Block *block = func.addEntryBlock();
block->addArgument(getTensorRankedType(getDefaultTensorShape()),
builder.getUnknownLoc());
block->addArgument(getTensorRankedType(getDefaultTensorShape()),
builder.getUnknownLoc());
block->addArgument(getTensorRankedType(), builder.getUnknownLoc());
block->addArgument(getTensorRankedType(), builder.getUnknownLoc());

builder.setInsertionPointToStart(block);

return func;
}

void
addLayoutForOp(mlir::Operation *op, llvm::SmallVector<int64_t, 2> tensorShape,
llvm::DenseMap<mlir::Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
BufferType memorySpace,
TensorMemoryLayout tensorMemoryLayout) {
void addLayoutForOp(mlir::Operation *op,
llvm::DenseMap<mlir::Operation *,
std::vector<TTNNLayoutAttr>> &legalLayouts,
BufferType memorySpace,
TensorMemoryLayout tensorMemoryLayout) {
if (legalLayouts.find(op) == legalLayouts.end()) {
legalLayouts[op] = std::vector<TTNNLayoutAttr>{TTNNLayoutAttr::get(
&context, getTensorRankedType(tensorShape).getShape(),
builder.getF32Type(), memorySpace,
mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayout)};
&context, getTensorRankedType().getShape(), builder.getF32Type(),
memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}),
tensorMemoryLayout)};
} else {
legalLayouts[op].push_back(TTNNLayoutAttr::get(
&context, getTensorRankedType(tensorShape).getShape(),
builder.getF32Type(), memorySpace,
mlir::tt::GridAttr::get(&context, {8, 8}), tensorMemoryLayout));
&context, getTensorRankedType().getShape(), builder.getF32Type(),
memorySpace, mlir::tt::GridAttr::get(&context, {8, 8}),
tensorMemoryLayout));
}
}

void prepareOpForGreedyConfigPicker(
mlir::Operation *op, llvm::SmallVector<int64_t, 2> tensorShape,
uint64_t requiredL1Usage,
mlir::Operation *op, uint64_t outputL1Usage, uint64_t requiredL1Usage,
llvm::DenseMap<mlir::Operation *, std::vector<TTNNLayoutAttr>>
&legalLayouts,
llvm::DenseMap<mlir::Operation *, L1Usage> &opsL1Usage) {

// Add two legal layouts for the op with different buffer
// types: DRAM and L1.
addLayoutForOp(op, tensorShape, legalLayouts, BufferType::DRAM,
addLayoutForOp(op, legalLayouts, BufferType::DRAM,
TensorMemoryLayout::Interleaved);
addLayoutForOp(op, tensorShape, legalLayouts, BufferType::L1,
addLayoutForOp(op, legalLayouts, BufferType::L1,
TensorMemoryLayout::Interleaved);

L1Usage l1Usage;
// l1Usage.outputL1Usage =
// legalLayouts[op][1].getTensorSizeInBytes(tensorShape, deviceAttr);
l1Usage.outputL1Usage = outputL1Usage;
l1Usage.requiredL1Usage = requiredL1Usage;
opsL1Usage[op] = l1Usage;
}
@@ -126,21 +125,69 @@ TEST_F(L1InterleavedPolicyBase, VerifyGreedyPolicy) {
llvm::DenseMap<mlir::func::FuncOp, llvm::SmallVector<mlir::Operation *>>
schedule;
llvm::DenseMap<mlir::Operation *, L1Usage> opsL1Usage;
constexpr uint64_t usableL1CacheSize = 1024 * 1024;
constexpr uint64_t usableL1CacheSize = 15;

mlir::Value dest = createEmptyTensor(getDefaultTensorShape());
// Create operand A
mlir::Value dest = createEmptyTensor();
mlir::Value lhs = func.getBody().getBlocks().front().getArgument(0);
mlir::Value rhs = func.getBody().getBlocks().front().getArgument(1);
mlir::Operation *op =
mlir::Operation *opA =
builder.create<AddOp>(builder.getUnknownLoc(), lhs, rhs, dest);
uint64_t outputL1Usage = 2;
uint64_t requiredL1Usage = 8;
prepareOpForGreedyConfigPicker(opA, outputL1Usage, requiredL1Usage,
legalLayouts, opsL1Usage);

// Create operand B
dest = createEmptyTensor();
lhs = func.getBody().getBlocks().front().getArgument(0);
rhs = func.getBody().getBlocks().front().getArgument(1);
mlir::Operation *opB =
builder.create<AddOp>(builder.getUnknownLoc(), lhs, rhs, dest);
outputL1Usage = 3;
requiredL1Usage = 7;
prepareOpForGreedyConfigPicker(opB, outputL1Usage, requiredL1Usage,
legalLayouts, opsL1Usage);

// Create operand C
dest = createEmptyTensor();
lhs = func.getBody().getBlocks().front().getArgument(0);
rhs = func.getBody().getBlocks().front().getArgument(1);
mlir::Operation *opC =
builder.create<AddOp>(builder.getUnknownLoc(), lhs, rhs, dest);
llvm::SmallVector<int64_t, 2> opTensorShape = getDefaultTensorShape();
prepareOpForGreedyConfigPicker(op, opTensorShape, 0, legalLayouts,
opsL1Usage);
outputL1Usage = 1;
requiredL1Usage = 9;
prepareOpForGreedyConfigPicker(opC, outputL1Usage, requiredL1Usage,
legalLayouts, opsL1Usage);

// Create base op D
dest = createEmptyTensor();
lhs = func.getBody().getBlocks().front().getArgument(0);
rhs = func.getBody().getBlocks().front().getArgument(1);
mlir::Operation *opD =
builder.create<AddOp>(builder.getUnknownLoc(), lhs, rhs, dest);
outputL1Usage = 4;
requiredL1Usage = 0;
prepareOpForGreedyConfigPicker(opD, outputL1Usage, requiredL1Usage,
legalLayouts, opsL1Usage);

// Run greedy config picker policy
L1InterleavedPolicy l1InterleavedPolicy(nullptr, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize);
// OpConfig greedyConfig = l1InterleavedPolicy.getGreedyConfig(op,
// opsL1Usage);
OpConfig greedyConfig = l1InterleavedPolicy.getGreedyConfig(opD, opsL1Usage);

// Sanity checks
ASSERT_TRUE(greedyConfig.baseOp == opD);
ASSERT_TRUE(greedyConfig.layouts.size() == 4);
ASSERT_TRUE(greedyConfig.precedence.size() == 3);

// All layouts should be using L1 buffer type
for (const auto &[op, layout] : greedyConfig.layouts) {
ASSERT_TRUE(layout.hasL1BufferType());
}

ASSERT_TRUE(true);
// Precedence order for op D should be: C, A, B
ASSERT_EQ(greedyConfig.precedence[0], opC);
ASSERT_EQ(greedyConfig.precedence[1], opA);
ASSERT_EQ(greedyConfig.precedence[2], opB);
}
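A hedged reading of the expected order: with all four outputs fitting the 15-byte budget, precedence C, A, B is consistent with scheduling operands by decreasing requiredL1Usage − outputL1Usage — the classic peak-memory heuristic of running the hungriest subtree first, while the fewest sibling outputs are resident. This is inferred from the test values, not from the policy implementation:

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    int main() {
      // (operand, requiredL1Usage - outputL1Usage) taken from the test above:
      // A: 8 - 2 = 6, B: 7 - 3 = 4, C: 9 - 1 = 8.
      std::vector<std::pair<char, std::int64_t>> pressure = {
          {'A', 6}, {'B', 4}, {'C', 8}};
      std::sort(pressure.begin(), pressure.end(),
                [](const auto &l, const auto &r) { return l.second > r.second; });
      // Sorted: C (8), A (6), B (4) -- matching the asserted precedence order.
      return 0;
    }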
