diff --git a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h index 0a443658f1..04f43f6428 100644 --- a/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h +++ b/include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h @@ -87,9 +87,9 @@ class OptimizerOverridesHandler { // Setters for the overrides // These are used to enable/disable the optimizer passes - void setOptimizerPass(bool); + void setEnableOptimizer(bool); // These are used to enable/disable the memory configurations - void setMemoryConfig(bool); + void setMemoryReconfig(bool); void setMemoryLayoutAnalysis(bool); void setEnableMemoryLayoutAnalysisPolicy(bool); void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType); @@ -105,9 +105,9 @@ class OptimizerOverridesHandler { // Getters for the overrides // These are used to get the current state of the optimizer passes - bool getOptimizerPass() const; + bool getEnableOptimizer() const; // These are used to get the current state of the memory configurations - bool getMemoryConfig() const; + bool getMemoryReconfig() const; bool getMemoryLayoutAnalysis() const; bool getEnableMemoryLayoutAnalysisPolicy() const; MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const; @@ -136,10 +136,10 @@ class OptimizerOverridesHandler { private: // Flags for enabling/disabling the optimizer passes - bool enableOptimizerPass = true; + bool enableOptimizer = true; // Flags for enabling/disabling the memory configurations - bool enableMemoryConfig = true; + bool enableMemoryReconfig = true; bool enableMemoryLayoutAnalysis = true; // Input layout overrides diff --git a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp index 616516b601..a79e9364ec 100644 --- a/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp +++ b/lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp @@ -189,12 +189,12 @@ void InputLayoutOverrideParser::print( os << "\n"; } -void OptimizerOverridesHandler::setOptimizerPass(bool value) { - enableOptimizerPass = value; +void OptimizerOverridesHandler::setEnableOptimizer(bool value) { + enableOptimizer = value; } -void OptimizerOverridesHandler::setMemoryConfig(bool value) { - enableMemoryConfig = value; +void OptimizerOverridesHandler::setMemoryReconfig(bool value) { + enableMemoryReconfig = value; } void OptimizerOverridesHandler::setMemoryLayoutAnalysis(bool value) { enableMemoryLayoutAnalysis = value; @@ -227,12 +227,12 @@ void OptimizerOverridesHandler::setMeshShape(std::vector value) { meshShape = value; } -bool OptimizerOverridesHandler::getOptimizerPass() const { - return enableOptimizerPass; +bool OptimizerOverridesHandler::getEnableOptimizer() const { + return enableOptimizer; } -bool OptimizerOverridesHandler::getMemoryConfig() const { - return enableMemoryConfig; +bool OptimizerOverridesHandler::getMemoryReconfig() const { + return enableMemoryReconfig; } bool OptimizerOverridesHandler::getMemoryLayoutAnalysis() const { return enableMemoryLayoutAnalysis; @@ -268,11 +268,11 @@ std::string OptimizerOverridesHandler::toString() const { std::string options = ""; - if (enableOptimizerPass) { + if (enableOptimizer) { options += "enable-optimizer=true "; } - if (enableMemoryConfig) { + if (enableMemoryReconfig) { options += "memreconfig-enabled=true "; } diff --git a/test/unittests/Optimizer/TestOptimizerOverrides.cpp b/test/unittests/Optimizer/TestOptimizerOverrides.cpp index 9465e708cc..04d4559f2b 100644 --- a/test/unittests/Optimizer/TestOptimizerOverrides.cpp +++ b/test/unittests/Optimizer/TestOptimizerOverrides.cpp @@ -195,24 +195,24 @@ class TestOptimizerOverrides : public ::testing::Test { void TearDown() override {} }; -// Test the setOptimizerPass method +// Test the setEnableOptimizer method TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) { - optimizerOverridesHandler.setOptimizerPass(true); - ASSERT_TRUE(optimizerOverridesHandler.getOptimizerPass()); + optimizerOverridesHandler.setEnableOptimizer(true); + ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer()); - optimizerOverridesHandler.setOptimizerPass(false); - ASSERT_FALSE(optimizerOverridesHandler.getOptimizerPass()); + optimizerOverridesHandler.setEnableOptimizer(false); + ASSERT_FALSE(optimizerOverridesHandler.getEnableOptimizer()); } // Test the setMemoryConfig method TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) { - optimizerOverridesHandler.setMemoryConfig(true); - ASSERT_TRUE(optimizerOverridesHandler.getMemoryConfig()); + optimizerOverridesHandler.setMemoryReconfig(true); + ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig()); - optimizerOverridesHandler.setMemoryConfig(false); - ASSERT_FALSE(optimizerOverridesHandler.getMemoryConfig()); + optimizerOverridesHandler.setMemoryReconfig(false); + ASSERT_FALSE(optimizerOverridesHandler.getMemoryReconfig()); } // Test the setMemoryLayoutAnalysis method @@ -411,7 +411,7 @@ TEST_F(TestOptimizerOverrides, TestToString) { "override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32"; optimizerOverridesHandler.setMemoryLayoutAnalysis(true); - optimizerOverridesHandler.setMemoryConfig(true); + optimizerOverridesHandler.setMemoryReconfig(true); optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", {0}); optimizerOverridesHandler.addOutputLayoutOverride( "add_1_2", {1, 1}, BufferType::DRAM, TensorMemoryLayout::Interleaved,