-
Notifications
You must be signed in to change notification settings - Fork 13
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Override mechanism #1417
Merged
Merged
Override mechanism #1417
Changes from all commits
Commits
Show all changes
14 commits
Select commit
Hold shift + click to select a range
ace99eb
Add mechanism for overriding with corresponding unittests.
vcanicTT e11fa26
Add mechanism for overriding with corresponding unittests.
vcanicTT 8c387fe
Add mechanism for overriding with corresponding unittests.
vcanicTT c3f5dd3
Add mechanism for overriding with corresponding unittests.
vcanicTT 1081138
Add mechanism for overriding with corresponding unittests.
vcanicTT d53f658
Add mechanism for overriding with corresponding unittests.
vcanicTT c41526a
Add mechanism for overriding with corresponding unittests.
vcanicTT c5b0b77
Add mechanism for overriding with corresponding unittests.
vcanicTT fbafabc
Add mechanism for overriding with corresponding unittests.
vcanicTT d2032d5
Add mechanism for overriding with corresponding unittests.
vcanicTT b23b18c
Add mechanism for overriding with corresponding unittests.
vcanicTT 3936b5a
Add mechanism for overriding with corresponding unittests.
vcanicTT bdcde2f
Add mechanism for overriding with corresponding unittests.
vcanicTT 49b814e
Add mechanism for overriding with corresponding unittests.
vcanicTT File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,50 +5,98 @@ | |
#ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H | ||
#define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
#include <llvm/Support/CommandLine.h> | ||
|
||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" | ||
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" | ||
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" | ||
#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
struct OutputLayoutOverrideParams { | ||
SmallVector<int64_t, 2> grid; | ||
BufferType bufferType; | ||
TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... | ||
Layout memoryLayout; // ROW_MAJOR / TILE | ||
tt::DataType dataType; | ||
}; | ||
class OptimizerOverridesHandler { | ||
public: | ||
OptimizerOverridesHandler() {}; | ||
~OptimizerOverridesHandler() {}; | ||
|
||
struct InputLayoutOverrideParams { | ||
SmallVector<int64_t> operandIdxes; | ||
}; | ||
// Setters for the overrides | ||
// These are used to enable/disable the optimizer passes | ||
void setEnableOptimizer(bool); | ||
// These are used to enable/disable the memory configurations | ||
void setMemoryReconfig(bool); | ||
void setEnableMemoryLayoutAnalysis(bool); | ||
void setEnableMemoryLayoutAnalysisPolicy(bool); | ||
void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType); | ||
// These are used to set the input/output layout overrides | ||
void setInputLayoutOverrides(llvm::StringMap<InputLayoutOverrideParams> &); | ||
void setOutputLayoutOverrides(llvm::StringMap<OutputLayoutOverrideParams> &); | ||
// These are used to add system descriptor path | ||
void setSystemDescPath(std::string); | ||
// These are used to set the maximum number of legal layouts for grid analysis | ||
void setMaxLegalLayouts(int64_t); | ||
// These are used to set the mesh shape | ||
void setMeshShape(std::vector<int64_t>); | ||
|
||
struct OutputLayoutOverrideParser | ||
: public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> { | ||
public: | ||
OutputLayoutOverrideParser(llvm::cl::Option &opt) | ||
: llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {} | ||
// Getters for the overrides | ||
// These are used to get the current state of the optimizer passes | ||
bool getEnableOptimizer() const; | ||
// These are used to get the current state of the memory configurations | ||
bool getMemoryReconfig() const; | ||
bool getEnableMemoryLayoutAnalysis() const; | ||
bool getEnableMemoryLayoutAnalysisPolicy() const; | ||
MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const; | ||
// These are used to get the current input/output layout overrides | ||
llvm::StringMap<InputLayoutOverrideParams> getInputLayoutOverrides() const; | ||
llvm::StringMap<OutputLayoutOverrideParams> getOutputLayoutOverrides() const; | ||
// These are used to get the current system descriptor path | ||
std::string getSystemDescPath() const; | ||
// These are used to get the current maximum number of legal layouts for grid | ||
// analysis | ||
int64_t getMaxLegalLayouts() const; | ||
// These are used to get the current mesh shape | ||
std::vector<int64_t> getMeshShape() const; | ||
|
||
bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, | ||
llvm::StringMap<OutputLayoutOverrideParams> &value); | ||
// Method that converts the overrides to a string | ||
std::string toString() const; | ||
|
||
static void print(llvm::raw_ostream &os, | ||
const llvm::StringMap<OutputLayoutOverrideParams> &value); | ||
}; | ||
// Fill input/output layout overrides maps. | ||
// This is used from tt-forge frontend where we define and compile the models. | ||
void addInputLayoutOverride(StringRef, InputLayoutOverrideParams); | ||
void addInputLayoutOverride(StringRef, SmallVector<int64_t> &); | ||
void addOutputLayoutOverride(StringRef, OutputLayoutOverrideParams); | ||
void addOutputLayoutOverride(StringRef, SmallVector<int64_t> &, BufferType, | ||
TensorMemoryLayout, tt::ttnn::Layout, | ||
tt::DataType); | ||
|
||
struct InputLayoutOverrideParser | ||
: public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> { | ||
public: | ||
InputLayoutOverrideParser(llvm::cl::Option &opt) | ||
: llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {} | ||
private: | ||
// Options for the TTIR to TTNN backend pipeline, | ||
// we use them to extract the names and the deafulat values. | ||
TTIRToTTNNBackendPipelineOptions pipelineOptions; | ||
|
||
// Flags for enabling/disabling the optimizer passes | ||
bool enableOptimizer = false; | ||
|
||
// Flags for enabling/disabling the memory configurations | ||
bool enableMemoryReconfig = true; | ||
bool enableMemoryLayoutAnalysis = false; | ||
|
||
// Input layout overrides | ||
llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides; | ||
|
||
// Output layout overrides | ||
llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides; | ||
|
||
// Memory layout analysis policy | ||
bool enableMemoryLayoutAnalysisPolicy = false; | ||
MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy; | ||
|
||
// System descriptor path | ||
std::string systemDescPath; | ||
|
||
// Maximum number of legal layouts for grid analysis | ||
int64_t maxLegalLayouts = 0; | ||
|
||
bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, | ||
llvm::StringMap<InputLayoutOverrideParams> &value); | ||
// Mesh shape | ||
std::vector<int64_t> meshShape; | ||
|
||
static void print(llvm::raw_ostream &os, | ||
const llvm::StringMap<InputLayoutOverrideParams> &value); | ||
}; | ||
}; // class OptimizerOverridesHandler | ||
|
||
} // namespace mlir::tt::ttnn | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC | ||
// | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H | ||
#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H | ||
|
||
#include <llvm/Support/CommandLine.h> | ||
|
||
#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" | ||
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" | ||
|
||
namespace mlir::tt::ttnn { | ||
|
||
struct OutputLayoutOverrideParams { | ||
|
||
SmallVector<int64_t, 2> grid; | ||
BufferType bufferType; | ||
TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... | ||
Layout memoryLayout; // ROW_MAJOR / TILE | ||
mlir::tt::DataType dataType; | ||
|
||
bool operator==(const OutputLayoutOverrideParams rhs) const { | ||
return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] && | ||
bufferType == rhs.bufferType && | ||
tensorMemoryLayout == rhs.tensorMemoryLayout && | ||
memoryLayout == rhs.memoryLayout && dataType == rhs.dataType; | ||
} | ||
|
||
bool operator!=(const OutputLayoutOverrideParams &rhs) const { | ||
return !(*this == rhs); | ||
} | ||
}; | ||
|
||
struct InputLayoutOverrideParams { | ||
|
||
SmallVector<int64_t> operandIdxes; | ||
|
||
bool operator==(const InputLayoutOverrideParams &rhs) const { | ||
if (operandIdxes.size() != rhs.operandIdxes.size()) { | ||
return false; | ||
} | ||
for (std::size_t i = 0; i < operandIdxes.size(); i++) { | ||
if (operandIdxes[i] != rhs.operandIdxes[i]) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
bool operator!=(const InputLayoutOverrideParams &rhs) const { | ||
return !(*this == rhs); | ||
} | ||
}; | ||
|
||
struct OutputLayoutOverrideParser | ||
: public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> { | ||
public: | ||
OutputLayoutOverrideParser(llvm::cl::Option &opt) | ||
: llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {} | ||
|
||
bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, | ||
llvm::StringMap<OutputLayoutOverrideParams> &value); | ||
|
||
static std::string | ||
toString(const llvm::StringMap<OutputLayoutOverrideParams> &); | ||
|
||
static void print(llvm::raw_ostream &os, | ||
const llvm::StringMap<OutputLayoutOverrideParams> &value); | ||
}; | ||
|
||
struct InputLayoutOverrideParser | ||
: public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> { | ||
public: | ||
InputLayoutOverrideParser(llvm::cl::Option &opt) | ||
: llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {} | ||
|
||
bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, | ||
llvm::StringMap<InputLayoutOverrideParams> &value); | ||
|
||
static std::string | ||
toString(const llvm::StringMap<InputLayoutOverrideParams> &); | ||
|
||
static void print(llvm::raw_ostream &os, | ||
const llvm::StringMap<InputLayoutOverrideParams> &value); | ||
}; | ||
|
||
} // namespace mlir::tt::ttnn | ||
|
||
#endif // TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
code/includes outside of area guarded by header guard; consider moving it