Skip to content

Commit

Permalink
Version 2.0 of L1 Interleaved policy
Browse files Browse the repository at this point in the history
  • Loading branch information
fbajraktariTT committed Dec 2, 2024
1 parent cfbc6a1 commit 062a978
Show file tree
Hide file tree
Showing 32 changed files with 1,111 additions and 152 deletions.
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class L1ChainConfig {
std::unordered_set<Edge> &memReconfigEdges);

bool isEmpty() { return opL1MemSpecs.empty(); }
void addOpL1MemSpec(OpL1MemSpec &&spec) {
void addOpL1MemSpec(OpL1MemSpec spec) {
assert(state == L1ChainState::InBuild);
l1ChainedOps.insert(spec.op);
opL1MemSpecs.push_back(std::move(spec));
Expand Down
97 changes: 97 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,43 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

namespace mlir::tt::ttnn {

class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
// Layout selected for a single op together with the L1 footprint that
// this selection implies at scheduling time.
struct OpMemSpec {
TTNNLayoutAttr layout;
// Minimum L1 memory usage required for scheduling the op
// given the layouts of all the ops that are already scheduled.
//
uint64_t requiredL1Usage;
};

// This struct holds information about the greedily chosen
// configuration of the @baseOp: 1) layouts and 2) precedence.
//
// The @layouts represents the mapping between the op and its chosen
// layout. All the ops that are included in the @layouts map must be
// either @baseOp or its operand with legal L1 Interleaved output layout
// at the moment of analyzing the @baseOp.
//
// The @precedence represents the order of the op's operands in which they
// should be scheduled. Only op's operands that are included in the @layouts
// map are included in the @precedence.
//
struct OpConfig {
Operation *baseOp;
llvm::DenseMap<Operation *, TTNNLayoutAttr> layouts;
llvm::SmallVector<Operation *> precedence;
};

// L1 memory footprint associated with one op during analysis.
struct L1Usage {
// L1 memory occupied by the op's output tensor alone.
size_t outputL1Usage;
// Minimum L1 memory required to schedule the op given the layouts of
// all already-scheduled ops (mirrors OpMemSpec::requiredL1Usage).
size_t requiredL1Usage;
};

public:
L1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
Expand All @@ -22,7 +55,71 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

/**
* Retrieve the greedy OpConfig for the given base operation
* and its opsL1Usage map.
*
* @param baseOp The base operation for which the greedy configuration is
* being determined.
* @param opsL1Usage A map between the operation and its output L1 usage. All
* operations included in the opsL1Usage map must be either the baseOp or its
* operand with a legal L1 Interleaved output layout at the time of analyzing
* the baseOp.
* @return The greedy OpConfig for the baseOp.
*/
OpConfig getGreedyConfig(Operation *baseOp,
llvm::DenseMap<Operation *, L1Usage> &opsL1Usage);

void run() final;

private:
// Check if the op is analyzable. Op is analyzable if it has at least one
// legal layout.
bool isAnalyzable(Operation *op);

// Fetch op's DRAM layout from legalLayouts.
bool hasDRAMBufferType(Operation *op);
TTNNLayoutAttr getDRAMLayout(Operation *op);

// Fetch op's L1 Interleaved layout from legalLayouts.
bool hasL1BufferType(Operation *op);
TTNNLayoutAttr getL1InterleavedLayout(Operation *op);

// Returns the portion of usable L1 cache a single tensor may occupy.
// The cap is an empirically chosen constant derived from execution data
// and is expected to be replaced with a proper API query.
size_t getAvailableL1CacheSize() const {
  constexpr float kTensorL1UsageCap = 0.75f;
  return static_cast<size_t>(kTensorL1UsageCap * usableL1CacheSize);
}

// Precedence schedule map for each operation. It contains the order
// in which operands need to be executed for each op.
llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> precedenceMap;

llvm::DenseSet<Operation *> visitedOps;
// Post-order DFS over the precedence map: every precedent of `op` is
// appended to the schedule of `func` before `op` itself. `visitedOps`
// guards against scheduling the same op twice.
void buildSchedule(mlir::Operation *op, func::FuncOp &func) {
  visitedOps.insert(op);

  // Recurse into each precedent that has not been scheduled yet.
  for (Operation *precedent : precedenceMap[op]) {
    if (visitedOps.count(precedent) != 0) {
      continue;
    }
    buildSchedule(precedent, func);
  }

  (*schedule)[func].push_back(op);
}

// Build the final schedule for `func`: walk its ReturnLike ops and, for
// every returned value produced by an op, schedule that producer (and,
// transitively, its precedents) via buildSchedule.
//
// Robustness notes:
//   * a ReturnLike op may have zero operands — reading getOperand(0)
//     unconditionally would be invalid;
//   * a returned value may be a block argument, in which case
//     getDefiningOp() is null and there is nothing to schedule.
void constructSchedule(func::FuncOp &func) {
  func->walk([&](Operation *op) {
    if (!op->hasTrait<mlir::OpTrait::ReturnLike>()) {
      return;
    }
    for (Value operand : op->getOperands()) {
      // Skip block arguments; they have no defining op.
      if (Operation *outputOp = operand.getDefiningOp()) {
        buildSchedule(outputOp, func);
      }
    }
  });
}
};

} // namespace mlir::tt::ttnn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"

namespace mlir::tt::ttnn {

Expand Down
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
bool hasShardedTensorMemoryLayout() const;
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool hasInterleavedDRAMTensorMemoryLayout() const;
bool hasL1BufferType() const;
bool hasDRAMBufferType() const;
bool isTiled() const;
Layout getLayout() const;
Type getElementType() const;
Expand Down
3 changes: 1 addition & 2 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
#ifndef TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Pass/PassOptions.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#define TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#ifndef TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#define TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H

#include <llvm/ADT/StringSwitch.h>
#include <llvm/Support/CommandLine.h>
Expand Down Expand Up @@ -49,4 +49,4 @@ struct MemoryLayoutAnalysisPolicyTypeParser

} // namespace mlir::tt

#endif // TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#endif // TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"

namespace mlir::tt::ttnn {
Expand Down
4 changes: 4 additions & 0 deletions include/ttmlir/Scheduler/Scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class Scheduler {
// Method to get the next set of schedulable operations
llvm::SmallVector<mlir::Operation *> getScheduleableOps();

// Method to check if an operation is either a TTIR op or a
// TTNN schedulable op.
bool isTTShedulableOp(mlir::Operation *op);

// Method to check if an operation can be scheduled
bool canSchedule(mlir::Operation *op);

Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRTTNNAnalysis
MLIRTTNNPassesIncGen
MLIRTTOpsIncGen

LINK_LIBS
LINK_LIBS PUBLIC
MLIRScheduler
)
Loading

0 comments on commit 062a978

Please sign in to comment.