diff --git a/.github/workflows/build-and-test.yml b/.github/workflows/build-and-test.yml index ade377c06a..32e4eee3d6 100644 --- a/.github/workflows/build-and-test.yml +++ b/.github/workflows/build-and-test.yml @@ -47,8 +47,9 @@ jobs: fail-fast: false matrix: build: [ - {runs-on: ubuntu-latest, enable_perf: OFF, name: "run", ttrt_flags: ""}, - {runs-on: ubuntu-latest, enable_perf: ON, name: "perf", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: OFF, name: "run", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: ON, enable_op_model: OFF, name: "perf", ttrt_flags: ""}, + {runs-on: ubuntu-latest, enable_perf: OFF, enable_op_model: ON, name: "op_model" , ttrt_flags: ""} ] name: Build tt-mlir @@ -78,7 +79,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2 with: create-symlink: true - key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} # Build project @@ -97,6 +98,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=ON \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=ON \ + -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build @@ -147,7 +149,7 @@ jobs: - name: Upload Test Report uses: actions/upload-artifact@v4 with: - name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }} + name: test-reports-${{ matrix.build.runs-on }}-perf-${{ matrix.build.enable_perf }}-op_model-${{ matrix.build.enable_op_model }} path: build/test/report.xml - name: Show Test Report @@ -480,7 +482,7 @@ jobs: uses: hendrikmuhs/ccache-action@v1.2 with: create-symlink: true - key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf }}-${{ env.SDK_VERSION }} + key: ${{ matrix.build.runs-on }}-run-ON-perf-${{ matrix.build.enable_perf 
}}-op_model-${{ matrix.build.enable_op_model }}-${{ env.SDK_VERSION }} - name: Configure CMake shell: bash @@ -496,6 +498,7 @@ jobs: -DTTMLIR_ENABLE_RUNTIME_TESTS=OFF \ -DTT_RUNTIME_ENABLE_PERF_TRACE=${{ matrix.build.enable_perf }} \ -DTTMLIR_ENABLE_STABLEHLO=OFF \ + -DTTMLIR_ENABLE_OP_MODEL=${{ matrix.build.enable_op_model }} \ -S ${{ steps.strings.outputs.work-dir }} - name: Build tt-explorer diff --git a/.github/workflows/nightly-uplift.yml b/.github/workflows/nightly-uplift.yml index b8dbf3d05c..54dd758aed 100644 --- a/.github/workflows/nightly-uplift.yml +++ b/.github/workflows/nightly-uplift.yml @@ -62,8 +62,11 @@ jobs: echo "Pull Request URL - ${{ steps.create-pr.outputs.pull-request-url }}" gh pr review ${{ steps.create-pr.outputs.pull-request-number }} --approve - - name: Enable Pull Request Automerge - if: ${{ steps.create-pr.outputs.pull-request-number }} - run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}" - env: - GH_TOKEN: ${{ secrets.GH_TOKEN }} + # Note: Dissable auto-merge for now until we are more confident + # that uplift won't break the downstream projects + # + # - name: Enable Pull Request Automerge + # if: ${{ steps.create-pr.outputs.pull-request-number }} + # run: gh pr merge --squash --auto "${{ steps.create-pr.outputs.pull-request-number }}" + # env: + # GH_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/CMakeLists.txt b/CMakeLists.txt index 54fcc89d47..2927fb5602 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -8,6 +8,7 @@ endif() option(TT_RUNTIME_ENABLE_PERF_TRACE "Enable performance mode" OFF) option(TTMLIR_ENABLE_RUNTIME "Enable runtime" OFF) option(TTMLIR_ENABLE_STABLEHLO "Enable StableHLO support" OFF) +option(TTMLIR_ENABLE_OP_MODEL "Enable OpModel support" OFF) if (TTMLIR_ENABLE_STABLEHLO) add_compile_definitions(TTMLIR_ENABLE_STABLEHLO) @@ -20,6 +21,11 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(TTMLIR_ENABLE_BINDINGS_PYTHON ON CACHE BOOL "Enable Python bindings") +if (APPLE) + 
set(TTMLIR_ENABLE_OP_MODEL OFF) + message(WARNING "TTNNOpModelLib is disabled on Apple platforms. Optimizer will not get true performance.") +endif() + list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake/modules) if (TT_RUNTIME_ENABLE_PERF_TRACE) diff --git a/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h b/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h index acd5373c90..5f1feb08b2 100644 --- a/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h +++ b/include/ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h @@ -7,11 +7,15 @@ #include "mlir/IR/BuiltinOps.h" #include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" namespace mlir::tt { +void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter); + std::unique_ptr> createConvertTosaToTTIRPass(); } // namespace mlir::tt -#endif +#endif // TTMLIR_CONVERSION_TOSATOTTIR_TOSATOTTIR_H diff --git a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h b/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h index 16fafe551a..4a44e883da 100644 --- a/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h +++ b/include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h @@ -27,18 +27,23 @@ struct MemoryLayoutAnalysisPolicyTypeParser return false; } - static void print(llvm::raw_ostream &os, - const MemoryLayoutAnalysisPolicyType &value) { - llvm::StringRef policy; + static std::string toString(const MemoryLayoutAnalysisPolicyType &value) { + std::string res; switch (value) { case MemoryLayoutAnalysisPolicyType::DFSharding: - policy = "DFSharding"; + res += "DFSharding"; break; case MemoryLayoutAnalysisPolicyType::L1Interleaved: - policy = "L1Interleaved"; + res += "L1Interleaved"; break; } - os << "memory-layout-analysis-policy=" << policy << "\n"; + return res; + } + + static void print(llvm::raw_ostream &os, + const MemoryLayoutAnalysisPolicyType &value) { + os << "memory-layout-analysis-policy=" + << 
MemoryLayoutAnalysisPolicyTypeParser::toString(value) << "\n"; } }; diff --git a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h index 48c723e1cd..636d5f6238 100644 --- a/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h +++ b/include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h @@ -6,7 +6,8 @@ #define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H #include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" -#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" +#include "ttmlir/Dialect/TTNN/Utils/Utils.h" #include "mlir/Pass/PassOptions.h" diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index db24eeb287..c474106e3a 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -5,50 +5,98 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H #define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H -#include - -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h" +#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h" +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" namespace mlir::tt::ttnn { -struct OutputLayoutOverrideParams { - SmallVector grid; - BufferType bufferType; - TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... 
- Layout memoryLayout; // ROW_MAJOR / TILE - tt::DataType dataType; -}; +class OptimizerOverridesHandler { +public: + OptimizerOverridesHandler() {}; + ~OptimizerOverridesHandler() {}; -struct InputLayoutOverrideParams { - SmallVector operandIdxes; -}; + // Setters for the overrides + // These are used to enable/disable the optimizer passes + void setEnableOptimizer(bool); + // These are used to enable/disable the memory configurations + void setMemoryReconfig(bool); + void setEnableMemoryLayoutAnalysis(bool); + void setEnableMemoryLayoutAnalysisPolicy(bool); + void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType); + // These are used to set the input/output layout overrides + void setInputLayoutOverrides(llvm::StringMap &); + void setOutputLayoutOverrides(llvm::StringMap &); + // These are used to add system descriptor path + void setSystemDescPath(std::string); + // These are used to set the maximum number of legal layouts for grid analysis + void setMaxLegalLayouts(int64_t); + // These are used to set the mesh shape + void setMeshShape(std::vector); -struct OutputLayoutOverrideParser - : public llvm::cl::parser> { -public: - OutputLayoutOverrideParser(llvm::cl::Option &opt) - : llvm::cl::parser>(opt) {} + // Getters for the overrides + // These are used to get the current state of the optimizer passes + bool getEnableOptimizer() const; + // These are used to get the current state of the memory configurations + bool getMemoryReconfig() const; + bool getEnableMemoryLayoutAnalysis() const; + bool getEnableMemoryLayoutAnalysisPolicy() const; + MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const; + // These are used to get the current input/output layout overrides + llvm::StringMap getInputLayoutOverrides() const; + llvm::StringMap getOutputLayoutOverrides() const; + // These are used to get the current system descriptor path + std::string getSystemDescPath() const; + // These are used to get the current maximum number of legal layouts 
for grid + // analysis + int64_t getMaxLegalLayouts() const; + // These are used to get the current mesh shape + std::vector getMeshShape() const; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Method that converts the overrides to a string + std::string toString() const; - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; + // Fill input/output layout overrides maps. + // This is used from tt-forge frontend where we define and compile the models. + void addInputLayoutOverride(StringRef, InputLayoutOverrideParams); + void addInputLayoutOverride(StringRef, SmallVector &); + void addOutputLayoutOverride(StringRef, OutputLayoutOverrideParams); + void addOutputLayoutOverride(StringRef, SmallVector &, BufferType, + TensorMemoryLayout, tt::ttnn::Layout, + tt::DataType); -struct InputLayoutOverrideParser - : public llvm::cl::parser> { -public: - InputLayoutOverrideParser(llvm::cl::Option &opt) - : llvm::cl::parser>(opt) {} +private: + // Options for the TTIR to TTNN backend pipeline, + // we use them to extract the names and the deafulat values. 
+ TTIRToTTNNBackendPipelineOptions pipelineOptions; + + // Flags for enabling/disabling the optimizer passes + bool enableOptimizer = false; + + // Flags for enabling/disabling the memory configurations + bool enableMemoryReconfig = true; + bool enableMemoryLayoutAnalysis = false; + + // Input layout overrides + llvm::StringMap inputLayoutOverrides; + + // Output layout overrides + llvm::StringMap outputLayoutOverrides; + + // Memory layout analysis policy + bool enableMemoryLayoutAnalysisPolicy = false; + MemoryLayoutAnalysisPolicyType memoryLayoutAnalysisPolicy; + + // System descriptor path + std::string systemDescPath; + + // Maximum number of legal layouts for grid analysis + int64_t maxLegalLayouts = 0; - bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value); + // Mesh shape + std::vector meshShape; - static void print(llvm::raw_ostream &os, - const llvm::StringMap &value); -}; +}; // class OptimizerOverridesHandler } // namespace mlir::tt::ttnn diff --git a/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h new file mode 100644 index 0000000000..09e587c9c3 --- /dev/null +++ b/include/ttmlir/Dialect/TTNN/Utils/PassOverrides.h @@ -0,0 +1,91 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H +#define TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H + +#include + +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" + +namespace mlir::tt::ttnn { + +struct OutputLayoutOverrideParams { + + SmallVector grid; + BufferType bufferType; + TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc... 
+ Layout memoryLayout; // ROW_MAJOR / TILE + mlir::tt::DataType dataType; + + bool operator==(const OutputLayoutOverrideParams rhs) const { + return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] && + bufferType == rhs.bufferType && + tensorMemoryLayout == rhs.tensorMemoryLayout && + memoryLayout == rhs.memoryLayout && dataType == rhs.dataType; + } + + bool operator!=(const OutputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct InputLayoutOverrideParams { + + SmallVector operandIdxes; + + bool operator==(const InputLayoutOverrideParams &rhs) const { + if (operandIdxes.size() != rhs.operandIdxes.size()) { + return false; + } + for (std::size_t i = 0; i < operandIdxes.size(); i++) { + if (operandIdxes[i] != rhs.operandIdxes[i]) { + return false; + } + } + return true; + } + + bool operator!=(const InputLayoutOverrideParams &rhs) const { + return !(*this == rhs); + } +}; + +struct OutputLayoutOverrideParser + : public llvm::cl::parser> { +public: + OutputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +struct InputLayoutOverrideParser + : public llvm::cl::parser> { +public: + InputLayoutOverrideParser(llvm::cl::Option &opt) + : llvm::cl::parser>(opt) {} + + bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value); + + static std::string + toString(const llvm::StringMap &); + + static void print(llvm::raw_ostream &os, + const llvm::StringMap &value); +}; + +} // namespace mlir::tt::ttnn + +#endif // TTMLIR_DIALECT_TTNN_UTILS_PASSOVERRIDES_H diff --git a/include/ttmlir/Dialect/TTNN/Utils/Utils.h b/include/ttmlir/Dialect/TTNN/Utils/Utils.h index 533235a610..d7d8fbdd30 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/Utils.h +++ 
b/include/ttmlir/Dialect/TTNN/Utils/Utils.h @@ -5,6 +5,8 @@ #ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H #define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H +#include + #include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h" diff --git a/include/ttmlir/OpModel/TTNN/TTNNOpModel.h b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h new file mode 100644 index 0000000000..31ac149849 --- /dev/null +++ b/include/ttmlir/OpModel/TTNN/TTNNOpModel.h @@ -0,0 +1,24 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H +#define TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H + +#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h" + +#include + +namespace mlir::tt::op_model::ttnn { + +struct ReluOpInterface { + static bool isLegal(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); + + static std::tuple + getOpL1Usage(const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout); +}; + +} // namespace mlir::tt::op_model::ttnn +#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODEL_H diff --git a/include/ttmlir/Target/Common/types.fbs b/include/ttmlir/Target/Common/types.fbs index 2d67ee1d1c..3e7ed425f7 100644 --- a/include/ttmlir/Target/Common/types.fbs +++ b/include/ttmlir/Target/Common/types.fbs @@ -11,67 +11,67 @@ struct Dim2dRange { } enum Arch: uint { - Grayskull = 0, - Wormhole_b0 = 1, - Blackhole = 2, + Grayskull, + Wormhole_b0, + Blackhole } enum DataType: uint16 { - Float32 = 0, - Float16 = 1, - BFloat16 = 2, - BFP_Float8 = 3, - BFP_BFloat8 = 4, - BFP_Float4 = 5, - BFP_BFloat4 = 6, - BFP_Float2 = 7, - BFP_BFloat2 = 8, - UInt32 = 9, - UInt16 = 10, - UInt8 = 11, + Float32, + Float16, + BFloat16, + BFP_Float8, + BFP_BFloat8, + BFP_Float4, + BFP_BFloat4, + BFP_Float2, + BFP_BFloat2, + UInt32, + UInt16, + UInt8, } enum OOBVal: ushort { - Undef = 0, - 
Zero = 1, - One = 2, - Inf = 3, - NegInf = 4, + Undef, + Zero, + One, + Inf, + NegInf, } enum MemorySpace: ushort { - System = 0, - SystemMMIO = 1, - DeviceDRAM = 2, - DeviceL1 = 3, + System, + SystemMMIO, + DeviceDRAM, + DeviceL1, } enum ChipCapability: uint32 (bit_flags) { - PCIE = 0, - HostMMIO = 1, + PCIE, + HostMMIO, } enum TensorMemoryLayout: ushort { - None = 0, - Interleaved = 1, - SingleBank = 2, - HeightSharded = 3, - WidthSharded = 4, - BlockSharded = 5, + None, + Interleaved, + SingleBank, + HeightSharded, + WidthSharded, + BlockSharded, } enum TensorLayout: ushort { - RowMajor = 0, - Tile = 1, - Invalid = 2, + RowMajor, + Tile, + Invalid, } enum BufferType: ushort { - DRAM = 0, - L1 = 1, - SystemMemory = 2, - L1Small = 3, - Trace = 4, + DRAM, + L1, + SystemMemory, + L1Small, + Trace, } // TODO (#620): Add other fields like core_ranges, shard orientation etc. @@ -197,8 +197,8 @@ table ChipPhysicalCores { enum CPURole: uint8 { - Host = 0, - Device = 1, + Host, + Device, } table CPUDesc { @@ -223,9 +223,11 @@ table EventRef { global_id: uint32; } +// Explicit non-sequential enumeration copied over from tt-metal definition of +// `enum class MathFidelity`. 
enum MathFidelity : uint8 { - LoFi = 0, - HiFi2 = 2, - HiFi3 = 3, - HiFi4 = 4, + LoFi = 0, + HiFi2 = 2, + HiFi3 = 3, + HiFi4 = 4, } diff --git a/include/ttmlir/Target/TTMetal/program.fbs b/include/ttmlir/Target/TTMetal/program.fbs index 4fcf966020..52451234b1 100644 --- a/include/ttmlir/Target/TTMetal/program.fbs +++ b/include/ttmlir/Target/TTMetal/program.fbs @@ -3,18 +3,18 @@ include "Common/types.fbs"; namespace tt.target.metal; enum NocIndex : ushort { - Noc0 = 0, - Noc1 = 1, + Noc0, + Noc1, } enum EthType : ushort { - Sender = 0, - Receiver = 1, + Sender, + Receiver, } enum UnpackToDestMode : uint8 { - UnpackToDestFp32 = 0, - Default = 1, + UnpackToDestFp32, + Default, } table NocConfig { @@ -45,17 +45,17 @@ table KernelSource { } enum BinaryType : ushort { - BRISC = 0, - NCRISC = 1, - TRISC0 = 2, - TRISC1 = 3, - TRISC2 = 4, - ERISC = 5, + BRISC, + NCRISC, + TRISC0, + TRISC1, + TRISC2, + ERISC, } enum CoreType : ushort { - WORKER = 0, - ETH = 1, + WORKER, + ETH, } table KernelBinary { diff --git a/include/ttmlir/Target/TTNN/program.fbs b/include/ttmlir/Target/TTNN/program.fbs index 5f486bac93..39535e2f0b 100644 --- a/include/ttmlir/Target/TTNN/program.fbs +++ b/include/ttmlir/Target/TTNN/program.fbs @@ -72,46 +72,46 @@ table ArangeOp { } enum EltwiseOpType: uint32 { - Add = 0, - Multiply = 1, - Subtract = 2, - Relu = 3, - GreaterEqual = 4, - Sqrt = 5, - Div = 6, - Sigmoid = 7, - Reciprocal = 8, - Exp = 9, - Maximum = 10, - Abs = 11, - Neg = 12, - Rsqrt = 13, - Typecast = 14, - Equal = 15, - NotEqual = 16, - LessEqual = 17, - LessThan = 18, - GreaterThan = 19, - LogicalAnd = 20, - LogicalOr = 21, - LogicalNot = 22, - Cbrt = 23, - Minimum = 24, - Ceil = 25, - Sin = 26, - Cos = 27, - Log = 28, - Log1p = 29, - Expm1 = 30, - Sign = 31, - Remainder = 32, - IsFinite = 33, - Floor = 34, - Where = 35, - Gelu = 36, - LogicalXor = 37, - Clamp = 38, - LeakyRelu = 39, + Add, + Multiply, + Subtract, + Relu, + GreaterEqual, + Sqrt, + Div, + Sigmoid, + Reciprocal, + Exp, + 
Maximum, + Abs, + Neg, + Rsqrt, + Typecast, + Equal, + NotEqual, + LessEqual, + LessThan, + GreaterThan, + LogicalAnd, + LogicalOr, + LogicalNot, + Cbrt, + Minimum, + Ceil, + Sin, + Cos, + Log, + Log1p, + Expm1, + Sign, + Remainder, + IsFinite, + Floor, + Where, + Gelu, + LogicalXor, + Clamp, + LeakyRelu, } table ClampOpParams { @@ -136,9 +136,9 @@ table EltwiseOp { } enum ReductionOpType: uint32 { - Sum = 0, - Mean = 1, - Max = 2, + Sum, + Mean, + Max, } table ReductionOp { diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index c3dc3a4b71..881d6545dc 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,6 +1,7 @@ include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo) include_directories(${TTMLIR_TOOLCHAIN_DIR}/src/stablehlo-build) +add_subdirectory(OpModel) add_subdirectory(CAPI) add_subdirectory(Conversion) add_subdirectory(Dialect) diff --git a/lib/Conversion/TosaToTTIR/CMakeLists.txt b/lib/Conversion/TosaToTTIR/CMakeLists.txt index 41baf75c67..56000eb652 100644 --- a/lib/Conversion/TosaToTTIR/CMakeLists.txt +++ b/lib/Conversion/TosaToTTIR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_library(TTMLIRTosaToTTIR - TosaToTTIR.cpp + TosaToTTIRPass.cpp + TosaToTTIRPatterns.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/ttmlir/Conversion/TosaToTTIR diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp deleted file mode 100644 index 6c6a7faf56..0000000000 --- a/lib/Conversion/TosaToTTIR/TosaToTTIR.cpp +++ /dev/null @@ -1,122 +0,0 @@ -// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -// -// SPDX-License-Identifier: Apache-2.0 - -#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" -#include "ttmlir/Dialect/TT/IR/TT.h" -#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" -#include "ttmlir/Dialect/TTIR/IR/TTIR.h" -#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/FuncConversions.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" 
-#include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/DialectConversion.h" - -using namespace mlir; -using namespace tt; - -namespace mlir::tt::ttir { - -#define GEN_PASS_DEF_CONVERTTOSATOTTIR -#include "ttmlir/Conversion/Passes.h.inc" - -} // namespace mlir::tt::ttir - -namespace { - -template -class TosaToTTIROpConversionPattern : public OpConversionPattern { - using OpConversionPattern::OpConversionPattern; - -public: - LogicalResult - matchAndRewrite(SrcOp srcOp, Adaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if constexpr (std::is_same::value) { - assert(srcOp.getShift() == 0); - } - - auto outputType = mlir::cast(srcOp.getResult().getType()); - auto outputTensor = rewriter.create( - srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); - rewriter.replaceOpWithNewOp( - srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), - ValueRange(outputTensor), - rewriter.getArrayAttr( - SmallVector(adaptor.getOperands().size() + 1, - rewriter.getAttr( - OperandConstraint::AnyDeviceTile)))); - return success(); - } -}; - -struct ConvertTosaToTTIRPass - : public ttir::impl::ConvertTosaToTTIRBase { - void runOnOperation() override { - mlir::ConversionTarget target(getContext()); - - target.addIllegalDialect(); - - target.addLegalDialect(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - target.addLegalOp(); - - // For now keep the same type assuming tosa ops operate on builtin tensor. - TypeConverter typeConverter; - typeConverter.addConversion([](Type type) { - assert(isa(type) && - "only ranked tensor type supported"); - return type; - }); - RewritePatternSet patterns(&getContext()); - - // Add conversion patterns. 
- patterns - .add>( - typeConverter, &getContext()); - patterns - .add>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add< - TosaToTTIROpConversionPattern>( - typeConverter, &getContext()); - patterns.add>( - typeConverter, &getContext()); - - // Apply conversion. - if (failed( - applyFullConversion(getOperation(), target, std::move(patterns)))) { - signalPassFailure(); - return; - } - } -}; - -} // namespace - -namespace mlir::tt { - -std::unique_ptr> createConvertTosaToTTIRPass() { - return std::make_unique(); -} - -} // namespace mlir::tt diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp new file mode 100644 index 0000000000..183d58ccaa --- /dev/null +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPass.cpp @@ -0,0 +1,74 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" +#include "ttmlir/Dialect/TTIR/IR/TTIR.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace mlir::tt::ttir { + +#define GEN_PASS_DEF_CONVERTTOSATOTTIR +#include "ttmlir/Conversion/Passes.h.inc" + +} // namespace mlir::tt::ttir + +namespace { + +struct ConvertTosaToTTIRPass + : public ttir::impl::ConvertTosaToTTIRBase { + void runOnOperation() override { + mlir::ConversionTarget target(getContext()); + + 
target.addIllegalDialect(); + + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + // For now keep the same type assuming tosa ops operate on builtin tensor. + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { + assert(isa(type) && + "only ranked tensor type supported"); + return type; + }); + RewritePatternSet patterns(&getContext()); + + // Add conversion patterns. + populateTosaToTTIRPatterns(&getContext(), patterns, typeConverter); + + // Apply conversion. + if (failed( + applyFullConversion(getOperation(), target, std::move(patterns)))) { + signalPassFailure(); + return; + } + } +}; + +} // namespace + +namespace mlir::tt { + +std::unique_ptr> createConvertTosaToTTIRPass() { + return std::make_unique(); +} + +} // namespace mlir::tt diff --git a/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp new file mode 100644 index 0000000000..46eadb7899 --- /dev/null +++ b/lib/Conversion/TosaToTTIR/TosaToTTIRPatterns.cpp @@ -0,0 +1,126 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" + +#include "ttmlir/Conversion/TosaToTTIR/TosaToTTIR.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" +#include "ttmlir/Dialect/TTIR/IR/TTIROps.h" + +using namespace mlir; +using namespace mlir::tt; + +namespace { + +// TODO(sdjukic): extract this pattern into separate file and use it for both +// TOSA and StableHLO + +template +class TosaToTTIRDefaultDPSOpConversionPattern + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; 
+ +public: + LogicalResult + matchAndRewrite(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + LogicalResult legalityResult = + checkConversionLegality(srcOp, adaptor, rewriter); + if (!legalityResult.succeeded()) { + return legalityResult; + } + + RankedTensorType outputType = + mlir::cast(srcOp.getResult().getType()); + tensor::EmptyOp outputTensor = rewriter.create( + srcOp.getLoc(), outputType.getShape(), outputType.getElementType()); + rewriter.replaceOpWithNewOp( + srcOp, TypeRange(outputTensor.getType()), adaptor.getOperands(), + ValueRange(outputTensor), + rewriter.getArrayAttr( + SmallVector(adaptor.getOperands().size() + 1, + rewriter.getAttr( + OperandConstraint::AnyDeviceTile)))); + return success(); + } + +private: + virtual LogicalResult + checkConversionLegality(SrcOp srcOp, Adaptor adaptor, + ConversionPatternRewriter &rewriter) const { + return success(); + } +}; + +class TosaToTTIRMultiplyOpConversionPattern + : public TosaToTTIRDefaultDPSOpConversionPattern< + tosa::MulOp, mlir::tt::ttir::MultiplyOp> { + using TosaToTTIRDefaultDPSOpConversionPattern< + tosa::MulOp, + mlir::tt::ttir::MultiplyOp>::TosaToTTIRDefaultDPSOpConversionPattern; + +private: + LogicalResult + checkConversionLegality(tosa::MulOp srcOp, tosa::MulOp::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (srcOp.getShift() != 0) { + return rewriter.notifyMatchFailure( + srcOp, "TTIR MultiplyOp doesn't support shifted multiply."); + } + return success(); + } +}; + +void addElementwiseUnaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + + patterns.add>( + typeConverter, ctx); + patterns.add>( + typeConverter, ctx); +} + +void addElementwiseBinaryOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>( + typeConverter, ctx); + patterns.add(typeConverter, ctx); + patterns.add>(typeConverter, ctx); 
+} + +void addCompareOpsConversionPatterns(MLIRContext *ctx, + RewritePatternSet &patterns, + TypeConverter &typeConverter) { + patterns.add>(typeConverter, + ctx); +} + +} // namespace + +namespace mlir::tt { + +void populateTosaToTTIRPatterns(MLIRContext *ctx, RewritePatternSet &patterns, + TypeConverter &typeConverter) { + addElementwiseUnaryOpsConversionPatterns(ctx, patterns, typeConverter); + addElementwiseBinaryOpsConversionPatterns(ctx, patterns, typeConverter); + addCompareOpsConversionPatterns(ctx, patterns, typeConverter); +} + +} // namespace mlir::tt diff --git a/lib/Dialect/TTNN/IR/CMakeLists.txt b/lib/Dialect/TTNN/IR/CMakeLists.txt index 1620e96b5c..4b7804a5fd 100644 --- a/lib/Dialect/TTNN/IR/CMakeLists.txt +++ b/lib/Dialect/TTNN/IR/CMakeLists.txt @@ -11,10 +11,12 @@ add_mlir_dialect_library(MLIRTTNNDialect DEPENDS MLIRTTNNOpsIncGen MLIRTTOpsIncGen + TTNNOpModelLib LINK_LIBS PUBLIC TTMLIRTTNNUtils MLIRSCFToEmitC MLIRLinalgDialect MLIRMLProgramDialect + TTNNOpModelLib ) diff --git a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp index 9079a60194..344a4a4831 100644 --- a/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp +++ b/lib/Dialect/TTNN/IR/TTNNOpModelInterface.cpp @@ -5,6 +5,9 @@ #include "ttmlir/Dialect/TTNN/IR/TTNNOps.h" #include "ttmlir/Dialect/TTNN/IR/TTNNOpModelInterface.cpp.inc" +#include "ttmlir/OpModel/TTNN/TTNNOpModel.h" + +#include #include namespace mlir::tt::ttnn { @@ -22,14 +25,16 @@ size_t ReluOp::getOpPerfCycles(const std::vector &input_layouts, std::tuple ReluOp::getOpL1Usage(const std::vector &input_layouts, const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to tt-metal once we have API - return std::make_tuple(1024, 2048, 1024); + assert(input_layouts.size() == 1); + return op_model::ttnn::ReluOpInterface::getOpL1Usage(input_layouts[0], + output_layout); } bool ReluOp::isOpLegal(const std::vector &input_layouts, const TTNNLayoutAttr &output_layout) { - // TODO(mbezulj) wire to 
tt-metal once we have API - return true; + assert(input_layouts.size() == 1); + return op_model::ttnn::ReluOpInterface::isLegal(input_layouts[0], + output_layout); } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/CMakeLists.txt b/lib/Dialect/TTNN/Utils/CMakeLists.txt index f49f829e6f..f78f418642 100644 --- a/lib/Dialect/TTNN/Utils/CMakeLists.txt +++ b/lib/Dialect/TTNN/Utils/CMakeLists.txt @@ -1,6 +1,7 @@ add_mlir_dialect_library(TTMLIRTTNNUtils Utils.cpp OptimizerOverrides.cpp + PassOverrides.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/TTNN diff --git a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp index 5ef306cdb0..bbc456948e 100644 --- a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp @@ -6,187 +6,173 @@ namespace mlir::tt::ttnn { -bool OutputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, +void OptimizerOverridesHandler::setEnableOptimizer(bool value) { + enableOptimizer = value; +} + +void OptimizerOverridesHandler::setMemoryReconfig(bool value) { + enableMemoryReconfig = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysis(bool value) { + enableMemoryLayoutAnalysis = value; +} +void OptimizerOverridesHandler::setEnableMemoryLayoutAnalysisPolicy( + bool value) { + enableMemoryLayoutAnalysisPolicy = value; +} +void OptimizerOverridesHandler::setMemoryLayoutAnalysisPolicy( + MemoryLayoutAnalysisPolicyType value) { + memoryLayoutAnalysisPolicy = value; +} + +void OptimizerOverridesHandler::setInputLayoutOverrides( + llvm::StringMap &value) { + inputLayoutOverrides = value; +} +void OptimizerOverridesHandler::setOutputLayoutOverrides( llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kMaxGridSize = 2; - constexpr size_t kvPairSize = 2; - constexpr size_t kMaxLayoutOverrideParams = 5; - constexpr size_t iOpName = 0; - constexpr size_t 
iLayoutOverrideParams = 1; - constexpr size_t iGrid = 0; - constexpr size_t iMemorySpace = 1; - constexpr size_t iTensorMemoryLayout = 2; - constexpr size_t iMemoryLayout = 3; - constexpr size_t iDataType = 4; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char paramSepataor = ':'; - constexpr char gridSeparator = 'x'; - - arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for override grid sizes: " + override); - return true; - } + outputLayoutOverrides = value; +} - SmallVector layoutParamParts; - // Split into layout parameters. - opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, - paramSepataor); - if (layoutParamParts.size() != kMaxLayoutOverrideParams) { - opt.error("Invalid number of layout parameters: " + - std::to_string(layoutParamParts.size())); - return true; - } +void OptimizerOverridesHandler::setSystemDescPath(std::string value) { + systemDescPath = value; +} +void OptimizerOverridesHandler::setMaxLegalLayouts(int64_t value) { + maxLegalLayouts = value; +} +void OptimizerOverridesHandler::setMeshShape(std::vector value) { + meshShape = value; +} - // Parse grid. - SmallVector grid; - SmallVector gridParts; - layoutParamParts[iGrid].split(gridParts, gridSeparator); - for (const StringRef gridPart : gridParts) { - int64_t gridValue; - if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { - opt.error("Invalid grid size: " + gridPart); - return true; - } - grid.push_back(gridValue); - } +bool OptimizerOverridesHandler::getEnableOptimizer() const { + return enableOptimizer; +} - // Parse memory space. 
- std::optional bufferType = - symbolizeBufferType(layoutParamParts[iMemorySpace]); - if (!bufferType.has_value()) { - opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); - return true; - } +bool OptimizerOverridesHandler::getMemoryReconfig() const { + return enableMemoryReconfig; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysis() const { + return enableMemoryLayoutAnalysis; +} +bool OptimizerOverridesHandler::getEnableMemoryLayoutAnalysisPolicy() const { + return enableMemoryLayoutAnalysisPolicy; +} +MemoryLayoutAnalysisPolicyType +OptimizerOverridesHandler::getMemoryLayoutAnalysisPolicy() const { + return memoryLayoutAnalysisPolicy; +} - // Parse tensor memory layout. - std::optional tensorMemoryLayout = - symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); - if (!tensorMemoryLayout.has_value()) { - opt.error("Invalid tensor memory layout: " + - layoutParamParts[iTensorMemoryLayout]); - return true; - } +std::string OptimizerOverridesHandler::getSystemDescPath() const { + return systemDescPath; +} +int64_t OptimizerOverridesHandler::getMaxLegalLayouts() const { + return maxLegalLayouts; +} +std::vector OptimizerOverridesHandler::getMeshShape() const { + return meshShape; +} - // Parse memory layout. - std::optional memoryLayout = - mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); - if (!memoryLayout.has_value()) { - opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); - return true; - } +llvm::StringMap +OptimizerOverridesHandler::getInputLayoutOverrides() const { + return inputLayoutOverrides; +} +llvm::StringMap +OptimizerOverridesHandler::getOutputLayoutOverrides() const { + return outputLayoutOverrides; +} - // Parse data type. 
- std::optional dataType = - mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); - if (!dataType.has_value()) { - opt.error("Invalid data type: " + layoutParamParts[iDataType]); - return true; - } +std::string OptimizerOverridesHandler::toString() const { - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ - std::move(grid), bufferType.value(), tensorMemoryLayout.value(), - memoryLayout.value(), dataType.value()}; + std::string options = ""; + + if (enableOptimizer) { + options += std::string(pipelineOptions.optimizerPassEnabled.getArgStr()) + + "=true "; } - return false; -} - -void OutputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "override-output-layout="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const OutputLayoutOverrideParams ¶ms = entry.getValue(); - // Print grid values - for (size_t i = 0; i < params.grid.size(); ++i) { - os << params.grid[i]; - if (i < params.grid.size() - 1) { - os << "x"; - } - } - // Print memory space and memory layout - os << ":" << mlir::tt::ttnn::stringifyBufferType(params.bufferType); - os << ":" - << mlir::tt::ttnn::stringifyTensorMemoryLayout( - params.tensorMemoryLayout); - os << ":" << mlir::tt::ttnn::stringifyLayout(params.memoryLayout); - os << ":" << mlir::tt::DataTypeEnumToString(params.dataType); - if (++count < value.size()) { - os << ","; - } + + if (enableMemoryReconfig) { + options += + std::string(pipelineOptions.memReconfigEnabled.getArgStr()) + "=true "; } - os << "\n"; -} -bool InputLayoutOverrideParser::parse( - llvm::cl::Option &opt, StringRef argName, StringRef arg, - llvm::StringMap &value) { - SmallVector opOverrideList; - constexpr size_t kvPairSize = 2; - constexpr size_t iOpName = 0; - constexpr size_t iOperands = 1; - constexpr char opSeparator = ','; - constexpr char opNameSeparator = '='; - constexpr char opParamSeparator = ':'; - - 
arg.split(opOverrideList, opSeparator); - for (const StringRef override : opOverrideList) { - SmallVector opOverrideParts; - override.split(opOverrideParts, opNameSeparator); - if (opOverrideParts.size() != kvPairSize) { - opt.error("Invalid format for input layouts override: " + override); - return true; - } + if (enableMemoryLayoutAnalysis) { + options += + std::string(pipelineOptions.memoryLayoutAnalysisEnabled.getArgStr()) + + "=true "; + } - SmallVector operandIndexes; - SmallVector operandIndexParts; - - // Parse operand indexes. - opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); - for (const StringRef operandIndexPart : operandIndexParts) { - int64_t operandIndexValue; - if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { - opt.error("Invalid operand index: " + operandIndexPart); - return true; - } - operandIndexes.push_back(operandIndexValue); - } + if (enableMemoryLayoutAnalysisPolicy) { + options += + std::string(pipelineOptions.memoryLayoutAnalysisPolicy.getArgStr()) + + MemoryLayoutAnalysisPolicyTypeParser::toString( + memoryLayoutAnalysisPolicy) + + " "; + } - // Set parsed op overrides. - value[opOverrideParts[iOpName]] = - InputLayoutOverrideParams{std::move(operandIndexes)}; + // Create input layout overrides. + // Example: insert-memreconfig=input0=0:1,input1=0,input2=0:1:2 + if (inputLayoutOverrides.size() > 0) { + options += std::string(pipelineOptions.overrideInputLayout.getArgStr()) + + "=" + InputLayoutOverrideParser::toString(inputLayoutOverrides) + + " "; } - return false; -} - -void InputLayoutOverrideParser::print( - llvm::raw_ostream &os, - const llvm::StringMap &value) { - os << "insert-memreconfig="; - size_t count = 0; - for (const auto &entry : value) { - os << entry.getKey() << "="; - const InputLayoutOverrideParams ¶ms = entry.getValue(); - for (int64_t operandIdx : params.operandIdxes) { - os << operandIdx - << (operandIdx < static_cast(params.operandIdxes.size()) - 1 - ? 
':' - : char()); - } - if (++count < value.size()) { - os << ","; + + // Create output layout overrides. + // Example: + // override-output-layout=op1=2x2:dram:interleaved:tile:fp32,op2=4x4:l1:block_sharded:row_major:fp16 + // Example: + // override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32" + if (outputLayoutOverrides.size() > 0) { + options += + std::string(pipelineOptions.overrideOutputLayout.getArgStr()) + "=" + + OutputLayoutOverrideParser::toString(outputLayoutOverrides) + " "; + } + + if (systemDescPath.size() > 0) { + options += std::string(pipelineOptions.systemDescPath.getArgStr()) + + systemDescPath + " "; + } + + if (maxLegalLayouts > 0) { + options += std::string(pipelineOptions.maxLegalLayouts.getArgStr()) + + std::to_string(maxLegalLayouts) + " "; + } + + if (meshShape.size() > 0) { + options += std::string(pipelineOptions.meshShape.getArgStr()) + "="; + for (int64_t meshShapeValue : meshShape) { + options += std::to_string(meshShapeValue) + ","; } + // Remove the last comma. 
+ options.pop_back(); + } + + if (options[options.size() - 1] == ' ') { + options.pop_back(); } - os << "\n"; + + return options; +} + +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, InputLayoutOverrideParams params) { + inputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addInputLayoutOverride( + StringRef opName, SmallVector &operandIdxes) { + inputLayoutOverrides[opName] = + InputLayoutOverrideParams{std::move(operandIdxes)}; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, OutputLayoutOverrideParams params) { + outputLayoutOverrides[opName] = params; +} +void OptimizerOverridesHandler::addOutputLayoutOverride( + StringRef opName, SmallVector &grid, BufferType bufferType, + TensorMemoryLayout tensorMemoryLayout, tt::ttnn::Layout memoryLayout, + tt::DataType dataType) { + outputLayoutOverrides[opName] = OutputLayoutOverrideParams{ + std::move(grid), bufferType, tensorMemoryLayout, memoryLayout, dataType}; } } // namespace mlir::tt::ttnn diff --git a/lib/Dialect/TTNN/Utils/PassOverrides.cpp b/lib/Dialect/TTNN/Utils/PassOverrides.cpp new file mode 100644 index 0000000000..9c8ef2be1f --- /dev/null +++ b/lib/Dialect/TTNN/Utils/PassOverrides.cpp @@ -0,0 +1,206 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "ttmlir/Dialect/TTNN/Utils/PassOverrides.h" + +namespace mlir::tt::ttnn { + +bool OutputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kMaxGridSize = 2; + constexpr size_t kvPairSize = 2; + constexpr size_t kMaxLayoutOverrideParams = 5; + constexpr size_t iOpName = 0; + constexpr size_t iLayoutOverrideParams = 1; + constexpr size_t iGrid = 0; + constexpr size_t iMemorySpace = 1; + constexpr size_t iTensorMemoryLayout = 2; + constexpr size_t iMemoryLayout = 3; + constexpr size_t iDataType = 4; + 
constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char paramSepataor = ':'; + constexpr char gridSeparator = 'x'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for override grid sizes: " + override); + return true; + } + + SmallVector layoutParamParts; + // Split into layout parameters. + opOverrideParts[iLayoutOverrideParams].split(layoutParamParts, + paramSepataor); + if (layoutParamParts.size() != kMaxLayoutOverrideParams) { + opt.error("Invalid number of layout parameters: " + + std::to_string(layoutParamParts.size())); + return true; + } + + // Parse grid. + SmallVector grid; + SmallVector gridParts; + layoutParamParts[iGrid].split(gridParts, gridSeparator); + for (const StringRef gridPart : gridParts) { + int64_t gridValue; + if (gridPart.getAsInteger(10 /*Radix*/, gridValue)) { + opt.error("Invalid grid size: " + gridPart); + return true; + } + grid.push_back(gridValue); + } + + // Parse memory space. + std::optional bufferType = + symbolizeBufferType(layoutParamParts[iMemorySpace]); + if (!bufferType.has_value()) { + opt.error("Invalid memory space: " + layoutParamParts[iMemorySpace]); + return true; + } + + // Parse tensor memory layout. + std::optional tensorMemoryLayout = + symbolizeTensorMemoryLayout(layoutParamParts[iTensorMemoryLayout]); + if (!tensorMemoryLayout.has_value()) { + opt.error("Invalid tensor memory layout: " + + layoutParamParts[iTensorMemoryLayout]); + return true; + } + + // Parse memory layout. + std::optional memoryLayout = + mlir::tt::ttnn::symbolizeLayout(layoutParamParts[iMemoryLayout]); + if (!memoryLayout.has_value()) { + opt.error("Invalid memory layout: " + layoutParamParts[iMemoryLayout]); + return true; + } + + // Parse data type. 
+ std::optional dataType = + mlir::tt::DataTypeStringToEnum(layoutParamParts[iDataType]); + if (!dataType.has_value()) { + opt.error("Invalid data type: " + layoutParamParts[iDataType]); + return true; + } + + // Set parsed op overrides. + value[opOverrideParts[iOpName]] = OutputLayoutOverrideParams{ + std::move(grid), bufferType.value(), tensorMemoryLayout.value(), + memoryLayout.value(), dataType.value()}; + } + return false; +} + +std::string OutputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const OutputLayoutOverrideParams ¶ms = entry.getValue(); + // Print grid values + for (size_t i = 0; i < params.grid.size(); ++i) { + res += std::to_string(params.grid[i]); + if (i < params.grid.size() - 1) { + res += "x"; + } + } + // Print memory space and memory layout + res += ":" + + std::string(mlir::tt::ttnn::stringifyBufferType(params.bufferType)); + res += ":" + std::string(mlir::tt::ttnn::stringifyTensorMemoryLayout( + params.tensorMemoryLayout)); + res += + ":" + std::string(mlir::tt::ttnn::stringifyLayout(params.memoryLayout)); + res += ":" + std::string(mlir::tt::DataTypeEnumToString(params.dataType)); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void OutputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "override-output-layout="; + os << OutputLayoutOverrideParser::toString(value); + os << "\n"; +} + +bool InputLayoutOverrideParser::parse( + llvm::cl::Option &opt, StringRef argName, StringRef arg, + llvm::StringMap &value) { + SmallVector opOverrideList; + constexpr size_t kvPairSize = 2; + constexpr size_t iOpName = 0; + constexpr size_t iOperands = 1; + constexpr char opSeparator = ','; + constexpr char opNameSeparator = '='; + constexpr char opParamSeparator = ':'; + + arg.split(opOverrideList, opSeparator); + for (const StringRef override : 
opOverrideList) { + SmallVector opOverrideParts; + override.split(opOverrideParts, opNameSeparator); + if (opOverrideParts.size() != kvPairSize) { + opt.error("Invalid format for input layouts override: " + override); + return true; + } + + SmallVector operandIndexes; + SmallVector operandIndexParts; + + // Parse operand indexes. + opOverrideParts[iOperands].split(operandIndexParts, opParamSeparator); + for (const StringRef operandIndexPart : operandIndexParts) { + int64_t operandIndexValue; + if (operandIndexPart.getAsInteger(10 /*Radix*/, operandIndexValue)) { + opt.error("Invalid operand index: " + operandIndexPart); + return true; + } + operandIndexes.push_back(operandIndexValue); + } + + // Set parsed op overrides. + value[opOverrideParts[iOpName]] = + InputLayoutOverrideParams{std::move(operandIndexes)}; + } + return false; +} + +std::string InputLayoutOverrideParser::toString( + const llvm::StringMap &value) { + std::string res; + size_t count = 0; + for (const auto &entry : value) { + res += std::string(entry.getKey()) + "="; + const InputLayoutOverrideParams ¶ms = entry.getValue(); + for (int64_t operandIdx : params.operandIdxes) { + res += std::to_string(operandIdx) + ":"; + } + // Remove the last colon. 
+ res.pop_back(); + if (++count < value.size()) { + res += ","; + } + } + return res; +} + +void InputLayoutOverrideParser::print( + llvm::raw_ostream &os, + const llvm::StringMap &value) { + os << "insert-memreconfig="; + os << InputLayoutOverrideParser::toString(value); + os << "\n"; +} + +} // namespace mlir::tt::ttnn diff --git a/lib/OpModel/CMakeLists.txt b/lib/OpModel/CMakeLists.txt new file mode 100644 index 0000000000..9c34667d09 --- /dev/null +++ b/lib/OpModel/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(TTNN) diff --git a/lib/OpModel/TTNN/CMakeLists.txt b/lib/OpModel/TTNN/CMakeLists.txt new file mode 100644 index 0000000000..094b9f1ddd --- /dev/null +++ b/lib/OpModel/TTNN/CMakeLists.txt @@ -0,0 +1,40 @@ +set(LIB_NAME TTNNOpModelLib) + +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +set(SOURCES + TTNNOpModelLib.cpp +) +add_library(${LIB_NAME} STATIC ${SOURCES}) + +message(STATUS "TTMLIR_ENABLE_OP_MODEL[${TTMLIR_ENABLE_OP_MODEL}]") +if (TTMLIR_ENABLE_OPMODEL) + # Link to tt-metal libs and include directories + target_include_directories(${LIB_NAME} PUBLIC "$") + target_link_libraries(${LIB_NAME} PUBLIC TTNN_LIBRARY TTMETAL_LIBRARY) + target_compile_definitions(${LIB_NAME} PUBLIC TTMLIR_ENABLE_OPMODEL) +else() + # link stubs implementation when op model library is disabled + message(WARNING "TTNNOpModelLib is disabled. 
The optimizer will not achieve optimal performance.") +endif() + +# Specify the include directories for the library +target_include_directories(${LIB_NAME} + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/ + ${PROJECT_SOURCE_DIR}/include/ttmlir/OpModel/TTNN/) + + +# Add TTNNOpModelLib to the export set +install(TARGETS ${LIB_NAME} + EXPORT TTNNOpModelLibTargets + LIBRARY DESTINATION lib + ARCHIVE DESTINATION lib + RUNTIME DESTINATION bin + INCLUDES DESTINATION include) + +# Export the targets +export(EXPORT TTNNOpModelLibTargets + FILE "${CMAKE_CURRENT_BINARY_DIR}/TTNNOpModelLibTargets.cmake" + NAMESPACE TTNN::) diff --git a/lib/OpModel/TTNN/TTNNOpModelLib.cpp b/lib/OpModel/TTNN/TTNNOpModelLib.cpp new file mode 100644 index 0000000000..87bfc04150 --- /dev/null +++ b/lib/OpModel/TTNN/TTNNOpModelLib.cpp @@ -0,0 +1,183 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#include "TTNNOpModel.h" + +#ifdef TTMLIR_ENABLE_OPMODEL +#include "TTNNOpModelLib_Impl.h" +#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h" + +#include +#include + +#include +#include +#endif // TTMLIR_ENABLE_OPMODEL + +namespace mlir::tt::op_model::ttnn { + +#ifdef TTMLIR_ENABLE_OPMODEL +// alias to a common tt_metal types +using DataType = ::tt::tt_metal::DataType; +using Layout = ::tt::tt_metal::Layout; +using CoreRange = ::tt::tt_metal::CoreRange; +using CoreRangeSet = ::tt::tt_metal::CoreRangeSet; +using CoreCoord = ::tt::tt_metal::CoreCoord; +using ShardSpec = ::tt::tt_metal::ShardSpec; +using ShardOrientation = ::tt::tt_metal::ShardOrientation; +using TensorMemoryLayout = ::tt::tt_metal::TensorMemoryLayout; +using MemoryConfig = ::tt::tt_metal::MemoryConfig; + +namespace detail { + +DataType getDataType(const mlir::MemRefType &memref) { + + auto dataType = elementTypeToDataType(memref.getElementType()); + + switch (dataType) { + case tt::DataType::Float32: + return DataType::FLOAT32; + case tt::DataType::BFloat16: + return DataType::BFLOAT16; + case 
tt::DataType::BFP_BFloat8: + return DataType::BFLOAT8_B; + case tt::DataType::BFP_BFloat4: + return DataType::BFLOAT4_B; + case tt::DataType::UInt32: + return DataType::UINT32; + case tt::DataType::UInt16: + return DataType::UINT16; + case tt::DataType::UInt8: + return DataType::UINT8; + default: + throw std::runtime_error("Invalid element type"); + } +} + +::ttnn::SimpleShape getTensorShape(const mlir::MemRefType &memref) { + ::tt::tt_metal::SmallVector small_vector_shape( + memref.getShape().begin(), memref.getShape().end()); + return ::ttnn::SimpleShape(small_vector_shape); +} + +const std::array +getShardShape(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + const auto layoutShardTile = layout.getShardShape(); + + if (layoutShardTile.size() != 2) { + llvm::errs() << "ERROR: layout_shard_tile.size() != 2\n"; + return {0, 0}; + } + + std::array shardShape; + shardShape[0] = layoutShardTile[0]; + shardShape[1] = layoutShardTile[1]; + return shardShape; +} + +Layout getTensorLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + return layout.isTiled() ? Layout::TILE : Layout::ROW_MAJOR; +} + +CoreRangeSet getCoreRangeSet(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // TODO(mbezulj): handle more complex grid shapes + // assuming grid shape is one rect starting at (0,0) + + const auto layoutGrid = layout.getGrid(); + + const auto layoutGridShape = layoutGrid.getShape(); + if (layoutGridShape.size() != 2) { + llvm::errs() << "ERROR: layout_grid.getShape().size() == 2\n"; + return {}; + } + + return CoreRangeSet(CoreRange(CoreCoord(0, layoutGridShape[0]), + CoreCoord(0, layoutGridShape[1]))); +} + +std::optional +layout_get_shard_spec(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + // tt_ShardOrientation is not part of ttnn::TTNNLayoutAttr; + // defaulting to ROW_MAJOR. TODO: figure out if we need to expose this + return isShardedMemoryLayout(layout.getMemLayout()) + ? 
std::make_optional(ShardSpec(getCoreRangeSet(layout), + getShardShape(layout), + ShardOrientation::ROW_MAJOR, false)) + : std::nullopt; +} + +::tt::tt_metal::BufferType getBufferType(const mlir::MemRefType &memref) { + auto memorySpace = + mlir::cast(memref.getMemorySpace()).getValue(); + + switch (memorySpace) { + case tt::MemorySpace::DeviceDRAM: + return ::tt::tt_metal::BufferType::DRAM; + case tt::MemorySpace::DeviceL1: + return ::tt::tt_metal::BufferType::L1; + default: // TODO(mbezulj): handle other memory spaces + throw std::runtime_error("Unsupported memory space"); + } +} + +::tt::tt_metal::TensorMemoryLayout +getTensorMemoryLayout(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + auto tensorMemoryLayout = layout.getMemLayout(); + + switch (tensorMemoryLayout) { + case mlir::tt::ttnn::TensorMemoryLayout::Interleaved: + return ::tt::tt_metal::TensorMemoryLayout::INTERLEAVED; + case mlir::tt::ttnn::TensorMemoryLayout::SingleBank: + return ::tt::tt_metal::TensorMemoryLayout::SINGLE_BANK; + case mlir::tt::ttnn::TensorMemoryLayout::HeightSharded: + return ::tt::tt_metal::TensorMemoryLayout::HEIGHT_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::WidthSharded: + return ::tt::tt_metal::TensorMemoryLayout::WIDTH_SHARDED; + case mlir::tt::ttnn::TensorMemoryLayout::BlockSharded: + return ::tt::tt_metal::TensorMemoryLayout::BLOCK_SHARDED; + default: + throw std::runtime_error("Unsupported tensor memory layout"); + } +} + +::tt::tt_metal::MemoryConfig +getMemoryConfig(const mlir::tt::ttnn::TTNNLayoutAttr &layout) { + + auto tensorMemoryLayout = getTensorMemoryLayout(layout); + auto bufferType = getBufferType(layout.getMemref()); + + auto shardSpec = layout_get_shard_spec(layout); + return ::tt::tt_metal::MemoryConfig(tensorMemoryLayout, bufferType, + shardSpec); +} + +} // namespace detail +#endif // TTMLIR_ENABLE_OPMODEL + +//===----------------------------------------------------------------------===// +// ReluOp 
+//===----------------------------------------------------------------------===// + +bool ReluOpInterface::isLegal( + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { + +#ifdef TTMLIR_ENABLE_OPMODEL + return true; // to wire into tt-metal with the next uplift +#else + return true; +#endif // TTMLIR_ENABLE_OPMODEL +} + +std::tuple ReluOpInterface::getOpL1Usage( + const mlir::tt::ttnn::TTNNLayoutAttr &inputLayout, + const mlir::tt::ttnn::TTNNLayoutAttr &outputLayout) { +#ifdef TTMLIR_ENABLE_OPMODEL + return std::make_tuple(0, 0, 0); // to wire into tt-metal with the next uplift +#else + return std::make_tuple(0, 0, 0); +#endif // TTMLIR_ENABLE_OPMODEL +} + +} // namespace mlir::tt::op_model::ttnn diff --git a/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h b/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h new file mode 100644 index 0000000000..ed39d881a9 --- /dev/null +++ b/lib/OpModel/TTNN/TTNNOpModelLib_Impl.h @@ -0,0 +1,60 @@ +// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +// +// SPDX-License-Identifier: Apache-2.0 + +#ifndef TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H +#define TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H + +// This header resolves tt-metal warnings that would otherwise be treated as +// errors in the MLIR build. Ensure that this is the only place where tt-metal +// headers are included. 
+ +#pragma clang diagnostic push +#pragma clang diagnostic ignored "-Wcast-qual" +#pragma clang diagnostic ignored "-Wctad-maybe-unsupported" +#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" +#pragma clang diagnostic ignored "-Wignored-qualifiers" +#pragma clang diagnostic ignored "-Wvla-extension" +#pragma clang diagnostic ignored "-Wcovered-switch-default" +#pragma clang diagnostic ignored "-Wsign-compare" +#pragma clang diagnostic ignored "-Wc++20-extensions" +#pragma clang diagnostic ignored "-Wc++20-designator" +#pragma clang diagnostic ignored "-Wnon-virtual-dtor" +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunknown-warning-option" +#pragma clang diagnostic ignored "-Wsuggest-override" +#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" +#pragma clang diagnostic ignored "-Wnested-anon-types" +#pragma clang diagnostic ignored "-Wreorder-ctor" +#pragma clang diagnostic ignored "-Wmismatched-tags" +#pragma clang diagnostic ignored "-Wunused-lambda-capture" +#pragma clang diagnostic ignored "-Wmissing-field-initializers" +#pragma clang diagnostic ignored "-Wunused-private-field" +#pragma clang diagnostic ignored "-Wimplicit-fallthrough" +#pragma clang diagnostic ignored "-Wstring-conversion" +#pragma clang diagnostic ignored "-Wunneeded-internal-declaration" +#pragma clang diagnostic ignored "-Wunused-local-typedef" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wpessimizing-move" +#pragma clang diagnostic ignored "-Wparentheses" +#pragma clang diagnostic ignored "-Wdeprecated-volatile" +#pragma clang diagnostic ignored "-Wdeprecated-this-capture" +#pragma clang diagnostic ignored "-Wc++23-extensions" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" +#pragma clang diagnostic ignored "-Wlogical-op-parentheses" +#pragma clang diagnostic ignored "-Wundefined-inline" +#pragma clang diagnostic ignored "-Wc99-extensions" +#pragma clang 
diagnostic ignored "-Wc++11-narrowing" +#pragma clang diagnostic ignored "-Wzero-length-array" +#pragma clang diagnostic ignored "-Wdeprecated-declarations" + +#define FMT_HEADER_ONLY + +#include "tt_metal/common/core_coord.hpp" +#include "tt_metal/impl/buffers/buffer.hpp" +#include "ttnn/tensor/tensor.hpp" +#include "ttnn/tensor/types.hpp" + +#pragma clang diagnostic pop + +#endif // TTMLIR_OPMODEL_TTNN_TTNNOPMODELLIB_IMPL_H diff --git a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp index dfd8b9375e..4fc6fca87f 100644 --- a/runtime/lib/ttnn/operations/pool/maxpool2d.cpp +++ b/runtime/lib/ttnn/operations/pool/maxpool2d.cpp @@ -31,11 +31,14 @@ preshardForMaxPool2d(const ::tt::target::ttnn::MaxPool2dOp *op, op->dilation_width() * (op->kernel_width() - 1) - 1) / op->stride_width(); + constexpr bool en_ch_padding = false; + auto parallel_config = ::ttnn::operations::conv::conv2d::determine_parallel_config( ::ttnn::TensorMemoryLayout::HEIGHT_SHARDED, op->batch_size(), op->channels(), output_height, output_width, op->channels(), - device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR); + device.compute_with_storage_grid_size(), ShardOrientation::ROW_MAJOR, + en_ch_padding); auto sharded_memory_config = ::ttnn::operations::conv::conv2d:: create_sharded_memory_config_from_parallel_config(inputShape, parallel_config, 1); diff --git a/test/unittests/Optimizer/CMakeLists.txt b/test/unittests/Optimizer/CMakeLists.txt index 681d78ff0e..4e6ee799a7 100644 --- a/test/unittests/Optimizer/CMakeLists.txt +++ b/test/unittests/Optimizer/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(OptimizerTests TestShardSolver.cpp + TestOptimizerOverrides.cpp ) target_link_libraries(OptimizerTests diff --git a/test/unittests/Optimizer/TestOptimizerOverrides.cpp b/test/unittests/Optimizer/TestOptimizerOverrides.cpp new file mode 100644 index 0000000000..c75fde21f9 --- /dev/null +++ b/test/unittests/Optimizer/TestOptimizerOverrides.cpp 
@@ -0,0 +1,433 @@
+// SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <gtest/gtest.h>
+
+#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h"
+
+using namespace mlir::tt::ttnn;
+
+class TestOptimizerOverrides : public ::testing::Test {
+
+public:
+  OptimizerOverridesHandler optimizerOverridesHandler;
+
+  void SetUp() override {}
+
+  llvm::StringMap<InputLayoutOverrideParams> createInputLayoutOverrides() {
+
+    // struct InputLayoutOverrideParams {
+    //   SmallVector<int64_t> operandIdxes;
+    // };
+
+    llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides;
+
+    // Create input layout overrides for 3 input overrides.
+    inputLayoutOverrides["input0"] = createInputLayoutOverrideParams();
+    inputLayoutOverrides["input1"] = createInputLayoutOverrideParams();
+    inputLayoutOverrides["input2"] = createInputLayoutOverrideParams();
+
+    return inputLayoutOverrides;
+  }
+
+  InputLayoutOverrideParams createInputLayoutOverrideParams() {
+
+    InputLayoutOverrideParams inputLayoutOverrideParams;
+
+    // Create input layout override params for 2 operands.
+    // Their operand indexes are 0 and 1, respectively.
+    inputLayoutOverrideParams.operandIdxes.push_back(0);
+    inputLayoutOverrideParams.operandIdxes.push_back(1);
+
+    return inputLayoutOverrideParams;
+  }
+
+  llvm::StringMap<OutputLayoutOverrideParams> createOutputLayoutOverrides() {
+
+    llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides;
+
+    // Create output layout overrides for 3 output overrides.
+    outputLayoutOverrides["output0"] = createOutputLayoutOverrideParams_0();
+    outputLayoutOverrides["output1"] = createOutputLayoutOverrideParams_1();
+    outputLayoutOverrides["output2"] = createOutputLayoutOverrideParams_2();
+
+    return outputLayoutOverrides;
+  }
+
+  OutputLayoutOverrideParams createOutputLayoutOverrideParams_0() {
+
+    // struct OutputLayoutOverrideParams {
+    //   SmallVector<int64_t> grid;
+    //   BufferType bufferType;
+    //   TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
+    //   Layout memoryLayout; // ROW_MAJOR / TILE
+    //   mlir::tt::DataType dataType;
+    // };
+
+    OutputLayoutOverrideParams outputLayoutOverrideParams;
+
+    // Output 0 has
+    //  - grid size 2x2,
+    //  - buffer type dram
+    //  - tensor memory layout interleaved
+    //  - memory layout tile
+    //  - data type fp16.
+    outputLayoutOverrideParams.grid.push_back(2);
+    outputLayoutOverrideParams.grid.push_back(2);
+    outputLayoutOverrideParams.bufferType = BufferType::DRAM;
+    outputLayoutOverrideParams.tensorMemoryLayout =
+        TensorMemoryLayout::Interleaved;
+    outputLayoutOverrideParams.memoryLayout = Layout::Tile;
+    outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16;
+
+    return outputLayoutOverrideParams;
+  }
+
+  OutputLayoutOverrideParams createOutputLayoutOverrideParams_1() {
+
+    // struct OutputLayoutOverrideParams {
+    //   SmallVector<int64_t> grid;
+    //   BufferType bufferType;
+    //   TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
+    //   Layout memoryLayout; // ROW_MAJOR / TILE
+    //   mlir::tt::DataType dataType;
+    // };
+
+    OutputLayoutOverrideParams outputLayoutOverrideParams;
+
+    // Output 1 has
+    //  - grid size 8x4,
+    //  - buffer type l1
+    //  - tensor memory layout block_sharded
+    //  - memory layout row_major
+    //  - data type fp16.
+    outputLayoutOverrideParams.grid.push_back(8);
+    outputLayoutOverrideParams.grid.push_back(4);
+    outputLayoutOverrideParams.bufferType = BufferType::L1;
+    outputLayoutOverrideParams.tensorMemoryLayout =
+        TensorMemoryLayout::BlockSharded;
+    outputLayoutOverrideParams.memoryLayout = Layout::RowMajor;
+    outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16;
+
+    return outputLayoutOverrideParams;
+  }
+
+  OutputLayoutOverrideParams createOutputLayoutOverrideParams_2() {
+
+    // struct OutputLayoutOverrideParams {
+    //   SmallVector<int64_t> grid;
+    //   BufferType bufferType;
+    //   TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
+    //   Layout memoryLayout; // ROW_MAJOR / TILE
+    //   mlir::tt::DataType dataType;
+    // };
+
+    OutputLayoutOverrideParams outputLayoutOverrideParams;
+
+    // Output 2 has
+    //  - grid size 3x6,
+    //  - buffer type system
+    //  - tensor memory layout height_sharded
+    //  - memory layout tile
+    //  - data type fp16.
+    outputLayoutOverrideParams.grid.push_back(3);
+    outputLayoutOverrideParams.grid.push_back(6);
+    outputLayoutOverrideParams.bufferType = BufferType::SystemMemory;
+    outputLayoutOverrideParams.tensorMemoryLayout =
+        TensorMemoryLayout::HeightSharded;
+    outputLayoutOverrideParams.memoryLayout = Layout::Tile;
+    outputLayoutOverrideParams.dataType = mlir::tt::DataType::Float16;
+
+    return outputLayoutOverrideParams;
+  }
+
+  bool
+  compareInputLayoutOverrides(llvm::StringMap<InputLayoutOverrideParams> in1,
+                              llvm::StringMap<InputLayoutOverrideParams> in2) {
+    // Check if the sizes of the two input layout overrides are the same.
+    if (in1.size() != in2.size()) {
+      return false;
+    }
+    llvm::StringMap<InputLayoutOverrideParams>::iterator it1;
+    for (it1 = in1.begin(); it1 != in1.end(); it1++) {
+      // Check if the two input layout overrides have the same keys.
+      llvm::StringMap<InputLayoutOverrideParams>::iterator it2 =
+          in2.find(it1->getKey());
+      if (it2 == in2.end()) {
+        return false;
+      }
+      // Check if the two input layout overrides have the same values.
+      // The structure InputLayoutOverrideParams has overloaded operators for ==
+      // and !=, so we can compare the objects in this way.
+      if (it1->getValue() != it2->getValue()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  bool compareOutputLayoutOverrides(
+      llvm::StringMap<OutputLayoutOverrideParams> out1,
+      llvm::StringMap<OutputLayoutOverrideParams> out2) {
+    // Check if the sizes of the two output layout overrides are the same.
+    if (out1.size() != out2.size()) {
+      return false;
+    }
+    llvm::StringMap<OutputLayoutOverrideParams>::iterator it1;
+    for (it1 = out1.begin(); it1 != out1.end(); it1++) {
+      // Check if the two output layout overrides have the same keys.
+      llvm::StringMap<OutputLayoutOverrideParams>::iterator it2 =
+          out2.find(it1->getKey());
+      if (it2 == out2.end()) {
+        return false;
+      }
+      // Check if the two output layout overrides have the same values.
+      // The structure OutputLayoutOverrideParams has overloaded operators for
+      // == and !=, so we can compare the objects in this way.
+      if (it1->getValue() != it2->getValue()) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  void TearDown() override {}
+};
+
+// Test the setEnableOptimizer method
+TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) {
+
+  optimizerOverridesHandler.setEnableOptimizer(true);
+  ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer());
+
+  optimizerOverridesHandler.setEnableOptimizer(false);
+  ASSERT_FALSE(optimizerOverridesHandler.getEnableOptimizer());
+}
+
+// Test the setMemoryConfig method
+TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) {
+
+  optimizerOverridesHandler.setMemoryReconfig(true);
+  ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig());
+
+  optimizerOverridesHandler.setMemoryReconfig(false);
+  ASSERT_FALSE(optimizerOverridesHandler.getMemoryReconfig());
+}
+
+// Test the setMemoryLayoutAnalysis method
+TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysis) {
+
+  optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true);
+  ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis());
+
+  optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(false);
+  ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysis());
+}
+
+// Test the setEnableMemoryLayoutAnalysisPolicy method
+TEST_F(TestOptimizerOverrides, TestSetEnableMemoryLayoutAnalysisPolicy) {
+
+  optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(true);
+  ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy());
+
+  optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(false);
+  ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy());
+}
+
+// Test the setMemoryLayoutAnalysisPolicy method
+TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysisPolicy) {
+
+  optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy(
+      mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding);
+  ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(),
+            mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding);
+
+  optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy(
+      mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved);
+  ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(),
+            mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved);
+}
+
+// Test the setInputLayoutOverrides method
+TEST_F(TestOptimizerOverrides, TestSetInputLayoutOverrides) {
+
+  llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides =
+      createInputLayoutOverrides();
+
+  optimizerOverridesHandler.setInputLayoutOverrides(inputLayoutOverrides);
+  ASSERT_TRUE(compareInputLayoutOverrides(
+      optimizerOverridesHandler.getInputLayoutOverrides(),
+      inputLayoutOverrides));
+}
+
+// Test the setOutputLayoutOverrides method
+TEST_F(TestOptimizerOverrides, TestSetOutputLayoutOverrides) {
+
+  llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides =
+      createOutputLayoutOverrides();
+
+  optimizerOverridesHandler.setOutputLayoutOverrides(outputLayoutOverrides);
+  ASSERT_TRUE(compareOutputLayoutOverrides(
+      optimizerOverridesHandler.getOutputLayoutOverrides(),
+      outputLayoutOverrides));
+}
+
+// Test the addInputLayoutOverride method passing the whole object
+TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideObject) {
+
+  // This method is implemented across two functions in the
+  // OptimizerOverridesHandler class. The first function takes the whole object
+  // as a parameter, while the second function takes the individual parameters.
+
+  // Here, we test the first function, which takes the whole object as a
+  // parameter.
+
+  llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides =
+      createInputLayoutOverrides();
+
+  optimizerOverridesHandler.addInputLayoutOverride(
+      "input0", createInputLayoutOverrideParams());
+  optimizerOverridesHandler.addInputLayoutOverride(
+      "input1", createInputLayoutOverrideParams());
+  optimizerOverridesHandler.addInputLayoutOverride(
+      "input2", createInputLayoutOverrideParams());
+
+  ASSERT_TRUE(compareInputLayoutOverrides(
+      optimizerOverridesHandler.getInputLayoutOverrides(),
+      inputLayoutOverrides));
+}
+
+// Test the addInputLayoutOverride method passing the individual parameters
+TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideParams) {
+
+  // This method is implemented across two functions in the
+  // OptimizerOverridesHandler class. The first function takes the whole object
+  // as a parameter, while the second function takes the individual parameters.
+
+  // Here, we test the second function, which takes the individual parameters.
+
+  llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides =
+      createInputLayoutOverrides();
+
+  llvm::SmallVector<int64_t> operandIdxes1 = {0, 1};
+  llvm::SmallVector<int64_t> operandIdxes2 = {0, 1};
+  llvm::SmallVector<int64_t> operandIdxes3 = {0, 1};
+
+  optimizerOverridesHandler.addInputLayoutOverride("input0", operandIdxes1);
+  optimizerOverridesHandler.addInputLayoutOverride("input1", operandIdxes2);
+  optimizerOverridesHandler.addInputLayoutOverride("input2", operandIdxes3);
+
+  ASSERT_TRUE(compareInputLayoutOverrides(
+      optimizerOverridesHandler.getInputLayoutOverrides(),
+      inputLayoutOverrides));
+}
+
+// Test the addOutputLayoutOverride method passing the whole object
+TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideObject) {
+
+  // This method is implemented across two functions in the
+  // OptimizerOverridesHandler class. The first function takes the whole object
+  // as a parameter, while the second function takes the individual parameters.
+
+  // Here, we test the first function, which takes the whole object as a
+  // parameter.
+
+  llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides =
+      createOutputLayoutOverrides();
+
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "output0", createOutputLayoutOverrideParams_0());
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "output1", createOutputLayoutOverrideParams_1());
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "output2", createOutputLayoutOverrideParams_2());
+
+  ASSERT_TRUE(compareOutputLayoutOverrides(
+      optimizerOverridesHandler.getOutputLayoutOverrides(),
+      outputLayoutOverrides));
+}
+
+// Test the addOutputLayoutOverride method passing the individual parameters
+TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideParams) {
+
+  // This method is implemented across two functions in the
+  // OptimizerOverridesHandler class. The first function takes the whole object
+  // as a parameter, while the second function takes the individual parameters.
+
+  // Here, we test the second function, which takes the individual parameters.
+
+  llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides =
+      createOutputLayoutOverrides();
+
+  llvm::SmallVector<int64_t> grid1 = {2, 2};
+  llvm::SmallVector<int64_t> grid2 = {8, 4};
+  llvm::SmallVector<int64_t> grid3 = {3, 6};
+
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "output0", grid1, BufferType::DRAM, TensorMemoryLayout::Interleaved,
+      Layout::Tile, mlir::tt::DataType::Float16);
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "output1", grid2, BufferType::L1, TensorMemoryLayout::BlockSharded,
+      Layout::RowMajor, mlir::tt::DataType::Float16);
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "output2", grid3, BufferType::SystemMemory,
+      TensorMemoryLayout::HeightSharded, Layout::Tile,
+      mlir::tt::DataType::Float16);
+
+  ASSERT_TRUE(compareOutputLayoutOverrides(
+      optimizerOverridesHandler.getOutputLayoutOverrides(),
+      outputLayoutOverrides));
+}
+
+// Test the setSystemDescPath method
+TEST_F(TestOptimizerOverrides, TestSetSystemDescPath) {
+
+  optimizerOverridesHandler.setSystemDescPath("system_desc_path");
+  ASSERT_EQ(optimizerOverridesHandler.getSystemDescPath(), "system_desc_path");
+}
+
+// Test the setMaxLegalLayouts method
+TEST_F(TestOptimizerOverrides, TestSetMaxLegalLayouts) {
+
+  optimizerOverridesHandler.setMaxLegalLayouts(10);
+  ASSERT_EQ(optimizerOverridesHandler.getMaxLegalLayouts(), 10);
+}
+
+// Test the setMeshShape method
+TEST_F(TestOptimizerOverrides, TestSetMeshShape) {
+
+  std::vector<int64_t> meshShape;
+  meshShape.push_back(1);
+  meshShape.push_back(2);
+
+  optimizerOverridesHandler.setMeshShape(meshShape);
+  ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[0], meshShape[0]);
+  ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[1], meshShape[1]);
+}
+
+// Test the toString method
+TEST_F(TestOptimizerOverrides, TestToString) {
+
+  std::string options;
+  options +=
+      "enable-optimizer=true "; // The optimizer pass is enabled by default.
+  options += "memreconfig-enabled=true ";
+  options += "memory-layout-analysis-enabled=true ";
+  options += "insert-memreconfig=add_0_1_2=0 ";
+  options +=
+      "override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32";
+
+  llvm::SmallVector<int64_t> operandIdxes = {0};
+  llvm::SmallVector<int64_t> grid = {1, 1};
+
+  optimizerOverridesHandler.setEnableOptimizer(true);
+  optimizerOverridesHandler.setEnableMemoryLayoutAnalysis(true);
+  optimizerOverridesHandler.setMemoryReconfig(true);
+  optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", operandIdxes);
+  optimizerOverridesHandler.addOutputLayoutOverride(
+      "add_1_2", grid, BufferType::DRAM, TensorMemoryLayout::Interleaved,
+      Layout::RowMajor, mlir::tt::DataType::Float32);
+
+  ASSERT_EQ(optimizerOverridesHandler.toString(), options);
+}
diff --git a/third_party/CMakeLists.txt b/third_party/CMakeLists.txt
index c9ff431bf1..e033913e24 100644
--- a/third_party/CMakeLists.txt
+++ b/third_party/CMakeLists.txt
@@ -1,6 +1,6 @@
 include(ExternalProject)
 
-set(TT_METAL_VERSION "69870bdeaf1c9270e325810249def6a3e9f38fb4")
+set(TT_METAL_VERSION "82ba2cbad64d1e36cad446d1f2f9bd266883ae74")
if ("$ENV{ARCH_NAME}" STREQUAL "grayskull") set(ARCH_NAME "grayskull")