Skip to content

Commit

Permalink
Override mechanism (#1417)
Browse files Browse the repository at this point in the history
* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.

* Add mechanism for overriding with corresponding unittests.
  • Loading branch information
vcanicTT authored Nov 29, 2024
1 parent 3eca3d8 commit 99331c7
Show file tree
Hide file tree
Showing 10 changed files with 980 additions and 206 deletions.
17 changes: 11 additions & 6 deletions include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,18 +27,23 @@ struct MemoryLayoutAnalysisPolicyTypeParser
return false;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
llvm::StringRef policy;
static std::string toString(const MemoryLayoutAnalysisPolicyType &value) {
std::string res;
switch (value) {
case MemoryLayoutAnalysisPolicyType::DFSharding:
policy = "DFSharding";
res += "DFSharding";
break;
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
policy = "L1Interleaved";
res += "L1Interleaved";
break;
}
os << "memory-layout-analysis-policy=" << policy << "\n";
return res;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
os << "memory-layout-analysis-policy="
<< MemoryLayoutAnalysisPolicyTypeParser::toString(value) << "\n";
}
};

Expand Down
3 changes: 2 additions & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Pass/PassOptions.h"

Expand Down
116 changes: 82 additions & 34 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,50 +5,98 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h"

namespace mlir::tt::ttnn {

struct OutputLayoutOverrideParams {
SmallVector<int64_t, 2> grid;
BufferType bufferType;
TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
Layout memoryLayout; // ROW_MAJOR / TILE
tt::DataType dataType;
};
class OptimizerOverridesHandler {
public:
OptimizerOverridesHandler() {};
~OptimizerOverridesHandler() {};

struct InputLayoutOverrideParams {
SmallVector<int64_t> operandIdxes;
};
// Setters for the overrides
// These are used to enable/disable the optimizer passes
void setEnableOptimizer(bool);
// These are used to enable/disable the memory configurations
void setMemoryReconfig(bool);
void setEnableMemoryLayoutAnalysis(bool);
void setEnableMemoryLayoutAnalysisPolicy(bool);
void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType);
// These are used to set the input/output layout overrides
void setInputLayoutOverrides(llvm::StringMap<InputLayoutOverrideParams> &);
void setOutputLayoutOverrides(llvm::StringMap<OutputLayoutOverrideParams> &);
// These are used to add system descriptor path
void setSystemDescPath(std::string);
// These are used to set the maximum number of legal layouts for grid analysis
void setMaxLegalLayouts(int64_t);
// These are used to set the mesh shape
void setMeshShape(std::vector<int64_t>);

struct OutputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> {
public:
OutputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {}
// Getters for the overrides
// These are used to get the current state of the optimizer passes
bool getEnableOptimizer() const;
// These are used to get the current state of the memory configurations
bool getMemoryReconfig() const;
bool getEnableMemoryLayoutAnalysis() const;
bool getEnableMemoryLayoutAnalysisPolicy() const;
MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const;
// These are used to get the current input/output layout overrides
llvm::StringMap<InputLayoutOverrideParams> getInputLayoutOverrides() const;
llvm::StringMap<OutputLayoutOverrideParams> getOutputLayoutOverrides() const;
// These are used to get the current system descriptor path
std::string getSystemDescPath() const;
// These are used to get the current maximum number of legal layouts for grid
// analysis
int64_t getMaxLegalLayouts() const;
// These are used to get the current mesh shape
std::vector<int64_t> getMeshShape() const;

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<OutputLayoutOverrideParams> &value);
// Method that converts the overrides to a string
std::string toString() const;

static void print(llvm::raw_ostream &os,
const llvm::StringMap<OutputLayoutOverrideParams> &value);
};
// Fill input/output layout overrides maps.
// This is used from tt-forge frontend where we define and compile the models.
void addInputLayoutOverride(StringRef, InputLayoutOverrideParams);
void addInputLayoutOverride(StringRef, SmallVector<int64_t> &);
void addOutputLayoutOverride(StringRef, OutputLayoutOverrideParams);
void addOutputLayoutOverride(StringRef, SmallVector<int64_t> &, BufferType,
TensorMemoryLayout, tt::ttnn::Layout,
tt::DataType);

struct InputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> {
public:
InputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {}
private:
// Options for the TTIR to TTNN backend pipeline,
// we use them to extract the names and the deafulat values.
TTIRToTTNNBackendPipelineOptions pipelineOptions;

// Flags for enabling/disabling the optimizer passes
bool enableOptimizer = false;

// Flags for enabling/disabling the memory configurations
bool enableMemoryReconfig = true;
bool enableMemoryLayoutAnalysis = false;

// Input layout overrides
llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides;

// Output layout overrides
llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides;

// Memory layout analysis policy
bool enableMemoryLayoutAnalysisPolicy = false;
MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy;

// System descriptor path
std::string systemDescPath;

// Maximum number of legal layouts for grid analysis
int64_t maxLegalLayouts = 0;

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<InputLayoutOverrideParams> &value);
// Mesh shape
std::vector<int64_t> meshShape;

static void print(llvm::raw_ostream &os,
const llvm::StringMap<InputLayoutOverrideParams> &value);
};
}; // class OptimizerOverridesHandler

} // namespace mlir::tt::ttnn

Expand Down
91 changes: 91 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

namespace mlir::tt::ttnn {

struct OutputLayoutOverrideParams {

SmallVector<int64_t, 2> grid;
BufferType bufferType;
TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
Layout memoryLayout; // ROW_MAJOR / TILE
mlir::tt::DataType dataType;

bool operator==(const OutputLayoutOverrideParams rhs) const {
return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] &&
bufferType == rhs.bufferType &&
tensorMemoryLayout == rhs.tensorMemoryLayout &&
memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;
}

bool operator!=(const OutputLayoutOverrideParams &rhs) const {
return !(*this == rhs);
}
};

struct InputLayoutOverrideParams {

SmallVector<int64_t> operandIdxes;

bool operator==(const InputLayoutOverrideParams &rhs) const {
if (operandIdxes.size() != rhs.operandIdxes.size()) {
return false;
}
for (std::size_t i = 0; i < operandIdxes.size(); i++) {
if (operandIdxes[i] != rhs.operandIdxes[i]) {
return false;
}
}
return true;
}

bool operator!=(const InputLayoutOverrideParams &rhs) const {
return !(*this == rhs);
}
};

struct OutputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> {
public:
OutputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<OutputLayoutOverrideParams> &value);

static std::string
toString(const llvm::StringMap<OutputLayoutOverrideParams> &);

static void print(llvm::raw_ostream &os,
const llvm::StringMap<OutputLayoutOverrideParams> &value);
};

struct InputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> {
public:
InputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<InputLayoutOverrideParams> &value);

static std::string
toString(const llvm::StringMap<InputLayoutOverrideParams> &);

static void print(llvm::raw_ostream &os,
const llvm::StringMap<InputLayoutOverrideParams> &value);
};

} // namespace mlir::tt::ttnn

#endif // TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H
2 changes: 2 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
#define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"
Expand Down
1 change: 1 addition & 0 deletions lib/Dialect/TTNN/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
add_mlir_dialect_library(TTMLIRTTNNUtils
Utils.cpp
OptimizerOverrides.cpp
PassOverrides.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/TTNN
Expand Down
Loading

0 comments on commit 99331c7

Please sign in to comment.