Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Optimizer] Greedy solution for join nodes in L1 Interleaved policy #1162

Merged
merged 4 commits into from
Dec 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class L1ChainConfig {
std::unordered_set<Edge> &memReconfigEdges);

bool isEmpty() { return opL1MemSpecs.empty(); }
void addOpL1MemSpec(OpL1MemSpec &&spec) {
void addOpL1MemSpec(OpL1MemSpec spec) {
assert(state == L1ChainState::InBuild);
l1ChainedOps.insert(spec.op);
opL1MemSpecs.push_back(std::move(spec));
Expand Down
97 changes: 97 additions & 0 deletions include/ttmlir/Dialect/TTNN/Analysis/L1InterleavedPolicy.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,43 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/MemoryLayoutAnalysisPolicy.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"

namespace mlir::tt::ttnn {

class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
public:
// Memory specification chosen for a single op: the output layout the
// policy picked for it and the L1 budget needed to schedule it.
struct OpMemSpec {
// Output layout selected for the op.
TTNNLayoutAttr layout;
// Minimum L1 memory usage required for scheduling the op
// given the layouts of all the ops that are already scheduled.
//
uint64_t requiredL1Usage;
};

// This struct holds information about the greedily chosen
// configuration of the @baseOp: 1) layouts and 2) precedence.
//
// The @layouts map represents the mapping between an op and its chosen
// layout. All the ops that are included in the @layouts map must be
// either @baseOp or one of its operands with a legal L1 Interleaved
// output layout at the moment of analyzing the @baseOp.
//
// The @precedence vector represents the order in which the op's operands
// should be scheduled. Only the op's operands that are included in the
// @layouts map are included in @precedence.
//
struct OpConfig {
Operation *baseOp;
llvm::DenseMap<Operation *, TTNNLayoutAttr> layouts;
llvm::SmallVector<Operation *> precedence;
};

// L1 memory footprint of an op from the policy's point of view.
struct L1Usage {
// L1 memory occupied by the op's own output tensor.
size_t outputL1Usage;
// Minimum L1 memory needed to schedule the op given the layouts of the
// already-scheduled ops (output usage included).
size_t requiredL1Usage;
};

public:
L1InterleavedPolicy(
Operation *rootOp, std::vector<L1ChainConfig> &l1ChainConfigs,
Expand All @@ -22,7 +55,71 @@ class L1InterleavedPolicy : public MemoryLayoutAnalysisPolicy {
: MemoryLayoutAnalysisPolicy(rootOp, l1ChainConfigs, legalLayouts,
schedule, usableL1CacheSize) {}

/**
* Retrieve the greedy OpConfig for the given base operation
* and its opsL1Usage map.
*
* @param baseOp The base operation for which the greedy configuration is
* being determined.
* @param opsL1Usage A map between the operation and its output L1 usage. All
* operations included in the opsL1Usage map must be either the baseOp or its
* operand with a legal L1 Interleaved output layout at the time of analyzing
* the baseOp.
* @return The greedy OpConfig for the baseOp.
*/
OpConfig getGreedyConfig(Operation *baseOp,
llvm::DenseMap<Operation *, L1Usage> &opsL1Usage);

void run() final;

private:
// Check if the op is analyzable. Op is analyzable if it has at least one
// legal layout.
bool isAnalyzable(Operation *op);

// Fetch op's DRAM layout from legalLayouts.
bool hasDRAMBufferType(Operation *op);
TTNNLayoutAttr getDRAMLayout(Operation *op);

// Fetch op's L1 Interleaved layout from legalLayouts.
bool hasL1BufferType(Operation *op);
TTNNLayoutAttr getL1InterleavedLayout(Operation *op);

// Returns the portion of the usable L1 cache that tensors are allowed
// to occupy.
//
// The cap constant was derived from execution data and is a stopgap;
// it will be replaced with a proper API query.
//
size_t getAvailableL1CacheSize() const {
  constexpr float kTensorL1UsageCap = 0.75;
  return static_cast<size_t>(kTensorL1UsageCap * usableL1CacheSize);
}

// Precedence schedule map for each operation. It contains the order
// in which operands need to be executed for each op.
llvm::DenseMap<Operation *, llvm::SmallVector<Operation *>> precedenceMap;
fbajraktariTT marked this conversation as resolved.
Show resolved Hide resolved

llvm::DenseSet<Operation *> visitedOps;
// Post-order DFS over the precedence graph: every precedent of @op is
// appended to @func's schedule before @op itself. @visitedOps guards
// against scheduling the same operation twice.
void buildSchedule(mlir::Operation *op, func::FuncOp &func) {
  visitedOps.insert(op);

  // Recurse into each not-yet-visited precedent first so they land in
  // the schedule ahead of @op.
  for (Operation *pred : precedenceMap[op]) {
    if (visitedOps.count(pred) == 0) {
      buildSchedule(pred, func);
    }
  }

  // All precedents are scheduled; emit the op itself.
  (*schedule)[func].push_back(op);
}

// Builds the op schedule for @func bottom-up, starting from the values
// the function returns and recursing through the precedence map.
//
// Fixes over the naive version:
//  - Walks ALL operands of a return-like op (not just operand 0), so
//    producers of every returned value get scheduled.
//  - A return with zero operands (void function) no longer triggers the
//    getOperand(0) assertion.
//  - A returned value that is a block argument has no defining op;
//    guard against passing nullptr into buildSchedule.
//
void constructSchedule(func::FuncOp &func) {
  func->walk([&](Operation *op) {
    if (!op->hasTrait<mlir::OpTrait::ReturnLike>()) {
      return;
    }
    for (Value operand : op->getOperands()) {
      Operation *outputOp = operand.getDefiningOp();
      // Skip block arguments (no defining op) and ops already scheduled
      // via an earlier return operand.
      if (outputOp && !visitedOps.count(outputOp)) {
        buildSchedule(outputOp, func);
      }
    }
  });
}
};

} // namespace mlir::tt::ttnn
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#define TTMLIR_DIALECT_TTNN_ANALYSIS_MEMORYLAYOUTANALYSIS_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Analysis/Edge.h"
#include "ttmlir/Dialect/TTNN/Analysis/L1ChainConfig.h"
#include "ttmlir/Dialect/TTNN/Analysis/TTNNAnalysis.h"
#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"

namespace mlir::tt::ttnn {

Expand Down
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,9 @@ def TTNN_TTNNLayoutAttr: TTNN_Attr<"TTNNLayout", "ttnn_layout"> {
bool hasShardedTensorMemoryLayout() const;
bool hasShardedL1TensorMemoryLayout() const;
bool hasInterleavedL1TensorMemoryLayout() const;
bool hasInterleavedDRAMTensorMemoryLayout() const;
bool hasL1BufferType() const;
bool hasDRAMBufferType() const;
bool isTiled() const;
Layout getLayout() const;
Type getElementType() const;
Expand Down
3 changes: 1 addition & 2 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
#ifndef TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Pass/PassOptions.h"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#define TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#ifndef TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#define TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H

#include <llvm/ADT/StringSwitch.h>
#include <llvm/Support/CommandLine.h>
Expand Down Expand Up @@ -49,4 +49,4 @@ struct MemoryLayoutAnalysisPolicyTypeParser

} // namespace mlir::tt

#endif // TTMLIR_DIALECT_TT_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
#endif // TTMLIR_DIALECT_TTNN_UTILS_MEMORYLAYOUTANALYSISPARAMS_H
2 changes: 1 addition & 1 deletion include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"

namespace mlir::tt::ttnn {
Expand Down
4 changes: 4 additions & 0 deletions include/ttmlir/Scheduler/Scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class Scheduler {
// Method to get the next set of schedulable operations
llvm::SmallVector<mlir::Operation *> getScheduleableOps();

// Method to check if an operation is either a TTIR op or a
// TTNN schedulable op.
bool isTTShedulableOp(mlir::Operation *op);
fbajraktariTT marked this conversation as resolved.
Show resolved Hide resolved

// Method to check if an operation can be scheduled
bool canSchedule(mlir::Operation *op);

Expand Down
2 changes: 1 addition & 1 deletion lib/Dialect/TTNN/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ add_mlir_dialect_library(MLIRTTNNAnalysis
MLIRTTNNPassesIncGen
MLIRTTOpsIncGen

LINK_LIBS
LINK_LIBS PUBLIC
fbajraktariTT marked this conversation as resolved.
Show resolved Hide resolved
MLIRScheduler
)
Loading
Loading