diff --git a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h b/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h index 16fafe551..4a44e883d 100644 --- a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h +++ b/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h @@ -27,18 +27,23 @@ struct MemoryLayoutAnalysisPolicyTypeParser return false; } - static void print(llvm::raw_ostream &os, - const MemoryLayoutAnalysisPolicyType &value) { - llvm::StringRef policy; + static std::string toString(const MemoryLayoutAnalysisPolicyType &value) { + std::string res; switch (value) { case MemoryLayoutAnalysisPolicyType::DFSharding: - policy = "DFSharding"; + res += "DFSharding"; break; case MemoryLayoutAnalysisPolicyType::L1Interleaved: - policy = "L1Interleaved"; + res += "L1Interleaved"; break; } - os << "memory-layout-analysis-policy=" << policy << "\n"; + return res; + } + + static void print(llvm::raw_ostream &os, + const MemoryLayoutAnalysisPolicyType &value) { + os << "memory-layout-analysis-policy=" + << MemoryLayoutAnalysisPolicyTypeParser::toString(value) << "\n"; } }; diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 48c723e1c..636d5f623 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -6,7 +6,8 @@ #define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H #include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" -#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "mlir/Pass/PassOptions.h" diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index db24eeb28..c474106e3 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -5,50 +5,98 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H #define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H -#include - -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" +#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" namespace mlir::tt::ttnn { -struct OutputLayoutOverrideParams { - SmallVector grid; - BufferType bufferType; - TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... 
-  Layout memoryLayout; // ROW_MAJOR / TILE
-  tt::DataType dataType;
-};
+class OptimizerOverridesHandler {
+public:
+  OptimizerOverridesHandler() {};
+  ~OptimizerOverridesHandler() {};

-struct InputLayoutOverrideParams {
-  SmallVector<int64_t> operandIdxes;
-};
+  // Setters for the overrides
+  // These are used to enable/disable the optimizer passes
+  void setEnableOptimizer(bool);
+  // These are used to enable/disable the memory configurations
+  void setMemoryReconfig(bool);
+  void setEnableMemoryLayoutAnalysis(bool);
+  void setEnableMemoryLayoutAnalysisPolicy(bool);
+  void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType);
+  // These are used to set the input/output layout overrides
+  void setInputLayoutOverrides(llvm::StringMap<InputLayoutOverrideParams> &);
+  void setOutputLayoutOverrides(llvm::StringMap<OutputLayoutOverrideParams> &);
+  // These are used to set the system descriptor path
+  void setSystemDescPath(std::string);
+  // These are used to set the maximum number of legal layouts for grid analysis
+  void setMaxLegalLayouts(int64_t);
+  // These are used to set the mesh shape
+  void setMeshShape(std::vector<int64_t>);

-struct OutputLayoutOverrideParser
-    : public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> {
-public:
-  OutputLayoutOverrideParser(llvm::cl::Option &opt)
-      : llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {}
+  // Getters for the overrides
+  // These are used to get the current state of the optimizer passes
+  bool getEnableOptimizer() const;
+  // These are used to get the current state of the memory configurations
+  bool getMemoryReconfig() const;
+  bool getEnableMemoryLayoutAnalysis() const;
+  bool getEnableMemoryLayoutAnalysisPolicy() const;
+  MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const;
+  // These are used to get the current input/output layout overrides
+  llvm::StringMap<InputLayoutOverrideParams> getInputLayoutOverrides() const;
+  llvm::StringMap<OutputLayoutOverrideParams> getOutputLayoutOverrides() const;
+  // These are used to get the current system descriptor path
+  std::string getSystemDescPath() const;
+  // These are used to get the current maximum number of legal layouts for grid
+  // analysis
+  int64_t getMaxLegalLayouts() const;
+  // These are used to get the current mesh shape
+  std::vector<int64_t> getMeshShape() const;

-  bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
-             llvm::StringMap<OutputLayoutOverrideParams> &value);
+  // Method that converts the overrides to a string
+  std::string toString() const;

-  static void print(llvm::raw_ostream &os,
-                    const llvm::StringMap<OutputLayoutOverrideParams> &value);
-};
+  // Fill input/output layout overrides maps.
+  // This is used from tt-forge frontend where we define and compile the models.
+  void addInputLayoutOverride(StringRef, InputLayoutOverrideParams);
+  void addInputLayoutOverride(StringRef, SmallVector<int64_t> &);
+  void addOutputLayoutOverride(StringRef, OutputLayoutOverrideParams);
+  void addOutputLayoutOverride(StringRef, SmallVector<int64_t> &, BufferType,
+                               TensorMemoryLayout, tt::ttnn::Layout,
+                               tt::DataType);

-struct InputLayoutOverrideParser
-    : public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> {
-public:
-  InputLayoutOverrideParser(llvm::cl::Option &opt)
-      : llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {}
+private:
+  // Options for the TTIR to TTNN backend pipeline,
+  // we use them to extract the names and the default values.
+ TTIRToTTNNBackendPipelineOptions pipelineOptions; + + // Flags for enabling/disabling the optimizer passes + bool enableOptimizer = false; + + // Flags for enabling/disabling the memory configurations + bool enableMemoryReconfig = true; + bool enableMemoryLayoutAnalysis = false; + + // Input layout overrides + llvm::StringMap inputLayoutOverrides; + + // Output layout overrides + llvm::StringMap outputLayoutOverrides; + + // Memory layout analysis policy + bool enableMemoryLayoutAnalysisPolicy = false; + MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy; + + // System descriptor path + std::string systemDescPath; + + // Maximum number of legal layouts for grid analysis + int64_t maxLegalLayouts = 0; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Mesh shape + std::vector meshShape; - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; +}; // class OptimizerOverridesHandler } // namespace mlir::tt::ttnn diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h new file mode 100644 index 000000000..09e587c9c --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -0,0 +1,91 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H +#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H + +#include + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" + +namespace mlir::tt::ttnn { + +struct OutputLayoutOverrideParams { + + SmallVector grid; + BufferType bufferType; + TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... 
+ Layout memoryLayout; // ROW_MAJOR / TILE + mlir::tt::DataType dataType; + + bool operator==(const OutputLayoutOverrideParams rhs) const { + return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] && + bufferType == rhs.bufferType && + tensorMemoryLayout == rhs.tensorMemoryLayout && + memoryLayout == rhs.memoryLayout && dataType == rhs.dataType; + } + + bool operator!=(const OutputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct InputLayoutOverrideParams { + + SmallVector operandIdxes; + + bool operator==(const InputLayoutOverrideParams &rhs) const { + if (operandIdxes.size() != rhs.operandIdxes.size()) { + return false; + } + for (std::size_t i = 0; i < operandIdxes.size(); i++) { + if (operandIdxes[i] != rhs.operandIdxes[i]) { + return false; + } + } + return true; + } + + bool operator!=(const InputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct OutputLayoutOverrideParser + : public llvm::cl::parser> { +public: + OutputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +struct InputLayoutOverrideParser + : public llvm::cl::parser> { +public: + InputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +} // namespace mlir::tt::ttnn + +#endif // TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h index 533235a61..d7d8fbdd3 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h +++ b/include/ttmlir/Dialect/TTNN/Utils/Utils.h @@ -5,6 +5,8 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H #define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H +#include + #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" diff --git a/lib/Dialect/TTNN/Utils/CMakeLists.txt b/lib/Dialect/TTNN/Utils/CMakeLists.txt index f49f829e6..f78f41864 100644 --- a/lib/Dialect/TTNN/Utils/CMakeLists.txt +++ b/lib/Dialect/TTNN/Utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(TTMLIRTTNNUtils Utils.cpp OptimizerOverrides.cpp + PassOverrides.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/TTNN diff --git a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp index 5ef306cdb..bbc456948 100644 --- a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp @@ -6,187 +6,173 @@ namespace mlir::tt::ttnn { -bool OutputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, +void OptimizerOverridesHandler::setEnableOptimizer(bool value) { + enableOptimizer = value; +} + +void OptimizerOverridesHandler::setMemoryReconfig(bool value) { + enableMemoryReconfig = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysis(bool value) { + enableMemoryLayoutAnalysis = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy( + bool value) { + enableMemoryLayoutAnalysisPolicy = value; +} +void 
OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy( + MemoryLayoutAnalysisPolicyType value) { + memoryLayoutAnalysisPolicy = value; +} + +void OptimizerOverridesHandler::setInputLayoutOverrides( + llvm::StringMap &value) { + inputLayoutOverrides = value; +} +void OptimizerOverridesHandler::setOutputLayoutOverrides( llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kMaxGridSize = 2; - constexpr size_t kvPairSize = 2; - constexpr size_t kMaxLayoutOverrideParams = 5; - constexpr size_t iOpName = 0; - constexpr size_t iLayoutOverrideParams = 1; - constexpr size_t iGrid = 0; - constexpr size_t iMemorySpace = 1; - constexpr size_t iTensorMemoryLayout = 2; - constexpr size_t iMemoryLayout = 3; - constexpr size_t iDataType = 4; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char paramSepataor = ':'; - constexpr char gridSeparator = 'x'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for override grid sizes: " + override); - return true; - } + outputLayoutOverrides = value; +} - SmallVector layoutParamParts; - // Split into layout parameters. - opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, - paramSepataor); - if (layoutParamParts.size() != kMaxLayoutOverrideParams) { - opt.error("Invalid number of layout parameters: " + - std::to_string(layoutParamParts.size())); - return true; - } +void OptimizerOverridesHandler::setSystemDescPath(std::string value) { + systemDescPath = value; +} +void OptimizerOverridesHandler::setMaxLegalLayouts(int64_t value) { + maxLegalLayouts = value; +} +void OptimizerOverridesHandler::setMeshShape(std::vector value) { + meshShape = value; +} - // Parse grid. - SmallVector grid; - SmallVector gridParts; - layoutParamParts[iGrid].split(gridParts, gridSeparator); - for (const StringRef gridPart : gridParts) { - int64_t gridValue; - if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { - opt.error("Invalid grid size: " + gridPart); - return true; - } - grid.push_back(gridValue); - } +bool OptimizerOverridesHandler::getEnableOptimizer() const { + return enableOptimizer; +} - // Parse memory space. - std::optional bufferType = - symbolizeBufferType(layoutParamParts[iMemorySpace]); - if (!bufferType.has_value()) { - opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); - return true; - } +bool OptimizerOverridesHandler::getMemoryReconfig() const { + return enableMemoryReconfig; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysis() const { + return enableMemoryLayoutAnalysis; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysisPolicy() const { + return enableMemoryLayoutAnalysisPolicy; +} +MemoryLayoutAnalysisPolicyType +OptimizerOverridesHandler::getMemoryLayoutAnalysisPolicy() const { + return memoryLayoutAnalysisPolicy; +} - // Parse tensor memory layout. 
- std::optional tensorMemoryLayout = - symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); - if (!tensorMemoryLayout.has_value()) { - opt.error("Invalid tensor memory layout: " + - layoutParamParts[iTensorMemoryLayout]); - return true; - } +std::string OptimizerOverridesHandler::getSystemDescPath() const { + return systemDescPath; +} +int64_t OptimizerOverridesHandler::getMaxLegalLayouts() const { + return maxLegalLayouts; +} +std::vector OptimizerOverridesHandler::getMeshShape() const { + return meshShape; +} - // Parse memory layout. - std::optional memoryLayout = - mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); - if (!memoryLayout.has_value()) { - opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); - return true; - } +llvm::StringMap +OptimizerOverridesHandler::getInputLayoutOverrides() const { + return inputLayoutOverrides; +} +llvm::StringMap +OptimizerOverridesHandler::getOutputLayoutOverrides() const { + return outputLayoutOverrides; +} - // Parse data type. - std::optional dataType = - mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); - if (!dataType.has_value()) { - opt.error("Invalid data type: " + layoutParamParts[iDataType]); - return true; - } +std::string OptimizerOverridesHandler::toString() const { - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ - std::move(grid), bufferType.value(), tensorMemoryLayout.value(), - memoryLayout.value(), dataType.value()}; + std::string options = ""; + + if (enableOptimizer) { + options += std::string(pipelineOptions.optimizerPassEnabled.getArgStr()) + + "=true "; } - return false; -} - -void OutputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "override-output-layout="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const OutputLayoutOverrideParams ¶ms = entry.getValue(); - // Print grid values - for (size_t i = 0; i < params.grid.size(); ++i) { - os << params.grid[i]; - if (i < params.grid.size() - 1) { - os << "x"; - } - } - // Print memory space and memory layout - os << ":" << mlir::tt::ttnn::stringifyBufferType(params.bufferType); - os << ":" - << mlir::tt::ttnn::stringifyTensorMemoryLayout( - params.tensorMemoryLayout); - os << ":" << mlir::tt::ttnn::stringifyLayout(params.memoryLayout); - os << ":" << mlir::tt::DataTypeEnumToString(params.dataType); - if (++count < value.size()) { - os << ","; - } + + if (enableMemoryReconfig) { + options += + std::string(pipelineOptions.memReconfigEnabled.getArgStr()) + "=true "; } - os << "\n"; -} -bool InputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kvPairSize = 2; - constexpr size_t iOpName = 0; - constexpr size_t iOperands = 1; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char opParamSeparator = ':'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for input layouts override: " + override); - return true; - } + if (enableMemoryLayoutAnalysis) { + options += + std::string(pipelineOptions.memoryLayoutAnalysisEnabled.getArgStr()) + + "=true "; + } - SmallVector operandIndexes; - SmallVector operandIndexParts; - - // Parse operand indexes. 
- opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); - for (const StringRef operandIndexPart : operandIndexParts) { - int64_t operandIndexValue; - if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { - opt.error("Invalid operand index: " + operandIndexPart); - return true; - } - operandIndexes.push_back(operandIndexValue); - } + if (enableMemoryLayoutAnalysisPolicy) { + options += + std::string(pipelineOptions.memoryLayoutAnalysisPolicy.getArgStr()) + + MemoryLayoutAnalysisPolicyTypeParser::toString( + memoryLayoutAnalysisPolicy) + + " "; + } - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = - InputLayoutOverrideParams{std::move(operandIndexes)}; + // Create input layout overrides. + // Example: insert-memreconfig=input0=0:1,input1=0,input2=0:1:2 + if (inputLayoutOverrides.size() > 0) { + options += std::string(pipelineOptions.overrideInputLayout.getArgStr()) + + "=" + InputLayoutOverrideParser::toString(inputLayoutOverrides) + + " "; } - return false; -} - -void InputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "insert-memreconfig="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const InputLayoutOverrideParams ¶ms = entry.getValue(); - for (int64_t operandIdx : params.operandIdxes) { - os << operandIdx - << (operandIdx < static_cast(params.operandIdxes.size()) - 1 - ? ':' - : char()); - } - if (++count < value.size()) { - os << ","; + + // Create output layout overrides. + // Example: + // override-output-layout=op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16 + // Example: + // override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" + if (outputLayoutOverrides.size() > 0) { + options += + std::string(pipelineOptions.overrideOutputLayout.getArgStr()) + "=" + + OutputLayoutOverrideParser::toString(outputLayoutOverrides) + " "; + } + + if (systemDescPath.size() > 0) { + options += std::string(pipelineOptions.systemDescPath.getArgStr()) + + systemDescPath + " "; + } + + if (maxLegalLayouts > 0) { + options += std::string(pipelineOptions.maxLegalLayouts.getArgStr()) + + std::to_string(maxLegalLayouts) + " "; + } + + if (meshShape.size() > 0) { + options += std::string(pipelineOptions.meshShape.getArgStr()) + "="; + for (int64_t meshShapeValue : meshShape) { + options += std::to_string(meshShapeValue) + ","; } + // Remove the last comma. 
+ options.pop_back(); + } + + if (options[options.size() - 1] == ' ') { + options.pop_back(); } - os << "\n"; + + return options; +} + +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, InputLayoutOverrideParams params) { + inputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, SmallVector &operandIdxes) { + inputLayoutOverrides[opName] = + InputLayoutOverrideParams{std::move(operandIdxes)}; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, OutputLayoutOverrideParams params) { + outputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, SmallVector &grid, BufferType bufferType, + TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout, + tt::DataType dataType) { + outputLayoutOverrides[opName] = OutputLayoutOverrideParams{ + std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType}; } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp new file mode 100644 index 000000000..9c8ef2be1 --- /dev/null +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" + +namespace mlir::tt::ttnn { + +bool OutputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kMaxGridSize = 2; + constexpr size_t kvPairSize = 2; + constexpr size_t kMaxLayoutOverrideParams = 5; + constexpr size_t iOpName = 0; + constexpr size_t iLayoutOverrideParams = 1; + constexpr size_t iGrid = 0; + constexpr size_t iMemorySpace = 1; + constexpr size_t iTensorMemoryLayout = 2; + constexpr size_t iMemoryLayout = 3; + constexpr size_t iDataType = 4; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char paramSepataor = ':'; + constexpr char gridSeparator = 'x'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for override grid sizes: " + override); + return true; + } + + SmallVector layoutParamParts; + // Split into layout parameters. + opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, + paramSepataor); + if (layoutParamParts.size() != kMaxLayoutOverrideParams) { + opt.error("Invalid number of layout parameters: " + + std::to_string(layoutParamParts.size())); + return true; + } + + // Parse grid. + SmallVector grid; + SmallVector gridParts; + layoutParamParts[iGrid].split(gridParts, gridSeparator); + for (const StringRef gridPart : gridParts) { + int64_t gridValue; + if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { + opt.error("Invalid grid size: " + gridPart); + return true; + } + grid.push_back(gridValue); + } + + // Parse memory space. + std::optional bufferType = + symbolizeBufferType(layoutParamParts[iMemorySpace]); + if (!bufferType.has_value()) { + opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); + return true; + } + + // Parse tensor memory layout. 
+ std::optional tensorMemoryLayout = + symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); + if (!tensorMemoryLayout.has_value()) { + opt.error("Invalid tensor memory layout: " + + layoutParamParts[iTensorMemoryLayout]); + return true; + } + + // Parse memory layout. + std::optional memoryLayout = + mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); + if (!memoryLayout.has_value()) { + opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); + return true; + } + + // Parse data type. + std::optional dataType = + mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); + if (!dataType.has_value()) { + opt.error("Invalid data type: " + layoutParamParts[iDataType]); + return true; + } + + // Set parsed op overrides. + value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ + std::move(grid), bufferType.value(), tensorMemoryLayout.value(), + memoryLayout.value(), dataType.value()}; + } + return false; +} + +std::string OutputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const OutputLayoutOverrideParams ¶ms = entry.getValue(); + // Print grid values + for (size_t i = 0; i < params.grid.size(); ++i) { + res += std::to_string(params.grid[i]); + if (i < params.grid.size() - 1) { + res += "x"; + } + } + // Print memory space and memory layout + res += ":" + + std::string(mlir::tt::ttnn::stringifyBufferType(params.bufferType)); + res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( + params.tensorMemoryLayout)); + res += + ":" + std::string(mlir::tt::ttnn::stringifyLayout(params.memoryLayout)); + res += ":" + std::string(mlir::tt::DataTypeEnumToString(params.dataType)); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void OutputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "override-output-layout="; + os << OutputLayoutOverrideParser::toString(value); + os << "\n"; +} + +bool InputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kvPairSize = 2; + constexpr size_t iOpName = 0; + constexpr size_t iOperands = 1; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char opParamSeparator = ':'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for input layouts override: " + override); + return true; + } + + SmallVector operandIndexes; + SmallVector operandIndexParts; + + // Parse operand indexes. + opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); + for (const StringRef operandIndexPart : operandIndexParts) { + int64_t operandIndexValue; + if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { + opt.error("Invalid operand index: " + operandIndexPart); + return true; + } + operandIndexes.push_back(operandIndexValue); + } + + // Set parsed op overrides. 
+ value[opOverrideParts[iOpName]] = + InputLayoutOverrideParams{std::move(operandIndexes)}; + } + return false; +} + +std::string InputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const InputLayoutOverrideParams ¶ms = entry.getValue(); + for (int64_t operandIdx : params.operandIdxes) { + res += std::to_string(operandIdx) + ":"; + } + // Remove the last colon. + res.pop_back(); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void InputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "insert-memreconfig="; + os << InputLayoutOverrideParser::toString(value); + os << "\n"; +} + +} // namespace mlir::tt::ttnn diff --git a/test/unittests/Optimizer/CMakeLists.txt b/test/unittests/Optimizer/CMakeLists.txt index 681d78ff0..4e6ee799a 100644 --- a/test/unittests/Optimizer/CMakeLists.txt +++ b/test/unittests/Optimizer/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(OptimizerTests TestShardSolver.cpp + TestOptimizerOverrides.cpp ) target_link_libraries(OptimizerTests diff --git a/test/unittests/Optimizer/TestOptimizerOverrides.cpp b/test/unittests/Optimizer/TestOptimizerOverrides.cpp new file mode 100644 index 000000000..c75fde21f --- /dev/null +++ b/test/unittests/Optimizer/TestOptimizerOverrides.cpp @@ -0,0 +1,433 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include + +#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" + +using namespace mlir::tt::ttnn; + +class TestOptimizerOverrides : public ::testing::Test { + +public: + OptimizerOverridesHandler optimizerOverridesHandler; + + void SetUp() override {} + + llvm::StringMap createInputLayoutOverrides() { + + // struct InputLayoutOverrideParams { + // SmallVector operandIdxes; + // }; + + llvm::StringMap inputLayoutOverrides; + + // Create input layout overrides for 3 input overrides. + inputLayoutOverrides["input0"] = createInputLayoutOverrideParams(); + inputLayoutOverrides["input1"] = createInputLayoutOverrideParams(); + inputLayoutOverrides["input2"] = createInputLayoutOverrideParams(); + + return inputLayoutOverrides; + } + + InputLayoutOverrideParams createInputLayoutOverrideParams() { + + InputLayoutOverrideParams inputLayoutOverrideParams; + + // Create input layout override params for 2 operands. + // Their operand indexes are 0 and 1, respectively. + inputLayoutOverrideParams.operandIdxes.push_back(0); + inputLayoutOverrideParams.operandIdxes.push_back(1); + + return inputLayoutOverrideParams; + } + + llvm::StringMap createOutputLayoutOverrides() { + + llvm::StringMap outputLayoutOverrides; + + // Create output layout overrides for 3 output overrides. + outputLayoutOverrides["output0"] = createOutputLayoutOverrideParams_0(); + outputLayoutOverrides["output1"] = createOutputLayoutOverrideParams_1(); + outputLayoutOverrides["output2"] = createOutputLayoutOverrideParams_2(); + + return outputLayoutOverrides; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_0() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... 
+ // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 0 has + // - grid size 2x2, + // - buffer type dram + // - tensor memory layout interleaved + // - memory layout tile + // - data type fp16. + outputLayoutOverrideParams.grid.push_back(2); + outputLayoutOverrideParams.grid.push_back(2); + outputLayoutOverrideParams.bufferType = BufferType::DRAM; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::Interleaved; + outputLayoutOverrideParams.memoryLayout = Layout::Tile; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_1() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 1 has + // - grid size 8x4, + // - buffer type l1 + // - tensor memory layout block_sharded + // - memory layout row_major + // - data type fp16. + outputLayoutOverrideParams.grid.push_back(8); + outputLayoutOverrideParams.grid.push_back(4); + outputLayoutOverrideParams.bufferType = BufferType::L1; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::BlockSharded; + outputLayoutOverrideParams.memoryLayout = Layout::RowMajor; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + OutputLayoutOverrideParams createOutputLayoutOverrideParams_2() { + + // struct OutputLayoutOverrideParams { + // SmallVector grid; + // BufferType; + // TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... + // Layout memoryLayout; // ROW_MAJOR / TILE + // mlir::tt::DataType dataType; + // }; + + OutputLayoutOverrideParams outputLayoutOverrideParams; + + // Output 2 has + // - grid size 3x6, + // - buffer type system + // - tensor memory layout height_sharded + // - memory layout tile + // - data type fp16. + outputLayoutOverrideParams.grid.push_back(3); + outputLayoutOverrideParams.grid.push_back(6); + outputLayoutOverrideParams.bufferType = BufferType::SystemMemory; + outputLayoutOverrideParams.tensorMemoryLayout = + TensorMemoryLayout::HeightSharded; + outputLayoutOverrideParams.memoryLayout = Layout::Tile; + outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16; + + return outputLayoutOverrideParams; + } + + bool + compareInputLayoutOverrides(llvm::StringMap in1, + llvm::StringMap in2) { + // Check if the sizes of the two input layout overrides are the same. + if (in1.size() != in2.size()) { + return false; + } + llvm::StringMap::iterator it1; + for (it1 = in1.begin(); it1 != in1.end(); it1++) { + // Check if the two input layout overrides have the same keys. + llvm::StringMap::iterator it2 = + in2.find(it1->getKey()); + if (it2 == in2.end()) { + return false; + } + // Check if the two input layout overrides have the same values. + // The structure InputLayoutOverrideParams has overloaded operators for == + // and !=, so we can compare the objects in this way. + if (it1->getValue() != it2->getValue()) { + return false; + } + } + return true; + } + + bool compareOutputLayoutOverrides( + llvm::StringMap out1, + llvm::StringMap out2) { + // Check if the sizes of the two output layout overrides are the same. 
+ if (out1.size() != out2.size()) { + return false; + } + llvm::StringMap::iterator it1; + for (it1 = out1.begin(); it1 != out1.end(); it1++) { + // Check if the two output layout overrides have the same keys. + llvm::StringMap::iterator it2 = + out2.find(it1->getKey()); + if (it2 == out2.end()) { + return false; + } + // Check if the two output layout overrides have the same values. + // The structure OutputLayoutOverrideParams has overloaded operators for + // == and !=, so we can compare the objects in this way. + if (it1->getValue() != it2->getValue()) { + return false; + } + } + return true; + } + + void TearDown() override {} +}; + +// Test the setEnableOptimizer method +TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) { + + optimizerOverridesHandler.setEnableOptimizer(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer()); + + optimizerOverridesHandler.setEnableOptimizer(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableOptimizer()); +} + +// Test the setMemoryConfig method +TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) { + + optimizerOverridesHandler.setMemoryReconfig(true); + ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig()); + + optimizerOverridesHandler.setMemoryReconfig(false); + ASSERT_FALSE(optimizerOverridesHandler.getMemoryReconfig()); +} + +// Test the setMemoryLayoutAnalysis method +TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysis) { + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis()); + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis()); +} + +// Test the setEnableMemoryLayoutAnalysisPolicy method +TEST_F(TestOptimizerOverrides, TestSetEnableMemoryLayoutAnalysisPolicy) { + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy()); + + optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy()); +} + +// Test the setMemoryLayoutAnalysisPolicy method +TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysisPolicy) { + + optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy( + mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding); + ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(), + mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding); + + optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy( + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); + ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(), + mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved); +} + +// Test the setInputLayoutOverrides method +TEST_F(TestOptimizerOverrides, TestSetInputLayoutOverrides) { + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + optimizerOverridesHandler.setInputLayoutOverrides(inputLayoutOverrides); + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the setOutputLayoutOverrides method +TEST_F(TestOptimizerOverrides, TestSetOutputLayoutOverrides) { + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + optimizerOverridesHandler.setOutputLayoutOverrides(outputLayoutOverrides); + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + 
outputLayoutOverrides)); +} + +// Test the addInputLayoutOverride method passing the whole object +TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideObject) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the first function, which takes the whole object as a + // parameter. + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + optimizerOverridesHandler.addInputLayoutOverride( + "input0", createInputLayoutOverrideParams()); + optimizerOverridesHandler.addInputLayoutOverride( + "input1", createInputLayoutOverrideParams()); + optimizerOverridesHandler.addInputLayoutOverride( + "input2", createInputLayoutOverrideParams()); + + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the addInputLayoutOverride method passing the individual parameters +TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideParams) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the second function, which takes the individual parameters. + + llvm::StringMap inputLayoutOverrides = + createInputLayoutOverrides(); + + llvm::SmallVector operandIdxes1 = {0, 1}; + llvm::SmallVector operandIdxes2 = {0, 1}; + llvm::SmallVector operandIdxes3 = {0, 1}; + + optimizerOverridesHandler.addInputLayoutOverride("input0", operandIdxes1); + optimizerOverridesHandler.addInputLayoutOverride("input1", operandIdxes2); + optimizerOverridesHandler.addInputLayoutOverride("input2", operandIdxes3); + + ASSERT_TRUE(compareInputLayoutOverrides( + optimizerOverridesHandler.getInputLayoutOverrides(), + inputLayoutOverrides)); +} + +// Test the addOutputLayoutOverride method passing the whole object +TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideObject) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the first function, which takes the whole object as a + // parameter. + + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + optimizerOverridesHandler.addOutputLayoutOverride( + "output0", createOutputLayoutOverrideParams_0()); + optimizerOverridesHandler.addOutputLayoutOverride( + "output1", createOutputLayoutOverrideParams_1()); + optimizerOverridesHandler.addOutputLayoutOverride( + "output2", createOutputLayoutOverrideParams_2()); + + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the addOutputLayoutOverride method passing the individual parameters +TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideParams) { + + // This method is implemented across two functions in the + // OptimizerOverridesHandler class. The first function takes the whole object + // as a parameter, while the second function takes the individual parameters. + + // Here, we test the second function, which takes the individual parameters. 
+ + llvm::StringMap outputLayoutOverrides = + createOutputLayoutOverrides(); + + llvm::SmallVector grid1 = {2, 2}; + llvm::SmallVector grid2 = {8, 4}; + llvm::SmallVector grid3 = {3, 6}; + + optimizerOverridesHandler.addOutputLayoutOverride( + "output0", grid1, BufferType::DRAM, TensorMemoryLayout::Interleaved, + Layout::Tile, mlir::tt::DataType::Float16); + optimizerOverridesHandler.addOutputLayoutOverride( + "output1", grid2, BufferType::L1, TensorMemoryLayout::BlockSharded, + Layout::RowMajor, mlir::tt::DataType::Float16); + optimizerOverridesHandler.addOutputLayoutOverride( + "output2", grid3, BufferType::SystemMemory, + TensorMemoryLayout::HeightSharded, Layout::Tile, + mlir::tt::DataType::Float16); + + ASSERT_TRUE(compareOutputLayoutOverrides( + optimizerOverridesHandler.getOutputLayoutOverrides(), + outputLayoutOverrides)); +} + +// Test the setSystemDescPath method +TEST_F(TestOptimizerOverrides, TestSetSystemDescPath) { + + optimizerOverridesHandler.setSystemDescPath("system_desc_path"); + ASSERT_EQ(optimizerOverridesHandler.getSystemDescPath(), "system_desc_path"); +} + +// Test the setMaxLegalLayouts method +TEST_F(TestOptimizerOverrides, TestSetMaxLegalLayouts) { + + optimizerOverridesHandler.setMaxLegalLayouts(10); + ASSERT_EQ(optimizerOverridesHandler.getMaxLegalLayouts(), 10); +} + +// Test the setMeshShape method +TEST_F(TestOptimizerOverrides, TestSetMeshShape) { + + std::vector meshShape; + meshShape.push_back(1); + meshShape.push_back(2); + + optimizerOverridesHandler.setMeshShape(meshShape); + ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[0], meshShape[0]); + ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[1], meshShape[1]); +} + +// Test the toString method +TEST_F(TestOptimizerOverrides, TestToString) { + + std::string options; + options += + "enable-optimizer=true "; // The optimizer pass is enabled by default. + options += "memreconfig-enabled=true "; + options += "memory-layout-analysis-enabled=true "; + options += "insert-memreconfig=add_0_1_2=0 "; + options += + "override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"; + + llvm::SmallVector operandIdxes = {0}; + llvm::SmallVector grid = {1, 1}; + + optimizerOverridesHandler.setEnableOptimizer(true); + optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true); + optimizerOverridesHandler.setMemoryReconfig(true); + optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", operandIdxes); + optimizerOverridesHandler.addOutputLayoutOverride( + "add_1_2", grid, BufferType::DRAM, TensorMemoryLayout::Interleaved, + Layout::RowMajor, mlir::tt::DataType::Float32); + + ASSERT_EQ(optimizerOverridesHandler.toString(), options); +}
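Usage sketch (illustrative, not part of the diff): the snippet below shows how a frontend such as tt-forge might drive the new OptimizerOverridesHandler to build the option string for the TTIR to TTNN backend pipeline, mirroring the TestToString unit test above. The op names "add_0_1_2" and "add_1_2" and the main() scaffolding are assumptions made for the example; only the OptimizerOverridesHandler calls and enum values come from this change.

#include <cstdint>
#include <string>

#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h"

int main() {
  using namespace mlir::tt::ttnn;

  OptimizerOverridesHandler handler;

  // Enable the optimizer pass and the memory-related analyses.
  handler.setEnableOptimizer(true);
  handler.setMemoryReconfig(true);
  handler.setEnableMemoryLayoutAnalysis(true);

  // Request a memory reconfig on operands 0 and 1 of a (hypothetical) op.
  llvm::SmallVector<int64_t> operandIdxes = {0, 1};
  handler.addInputLayoutOverride("add_0_1_2", operandIdxes);

  // Pin the output layout of another (hypothetical) op.
  llvm::SmallVector<int64_t> grid = {1, 1};
  handler.addOutputLayoutOverride("add_1_2", grid, BufferType::DRAM,
                                  TensorMemoryLayout::Interleaved,
                                  Layout::RowMajor,
                                  mlir::tt::DataType::Float32);

  // Produces a space-separated option string, roughly:
  // "enable-optimizer=true memreconfig-enabled=true
  //  memory-layout-analysis-enabled=true insert-memreconfig=add_0_1_2=0:1
  //  override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"
  llvm::outs() << handler.toString() << "\n";
  return 0;
}

The resulting string uses the same key=value argument names as TTIRToTTNNBackendPipelineOptions (the handler reads them via getArgStr()), which is what makes it suitable for passing to the backend pipeline from a frontend.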