Skip to content

Commit

Permalink
Add mechanism for overriding with corresponding unittests.
Browse files Browse the repository at this point in the history
  • Loading branch information
vcanicTT committed Nov 27, 2024
1 parent c0ef845 commit 295dc19
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 26 deletions.
12 changes: 6 additions & 6 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ class OptimizerOverridesHandler {

// Setters for the overrides
// These are used to enable/disable the optimizer passes
void setOptimizerPass(bool);
void setEnableOptimizer(bool);
// These are used to enable/disable the memory configurations
void setMemoryConfig(bool);
void setMemoryReconfig(bool);
void setMemoryLayoutAnalysis(bool);
void setEnableMemoryLayoutAnalysisPolicy(bool);
void setMemoryLayoutAnalysisPolicy(MemoryLayoutAnalysisPolicyType);
Expand All @@ -105,9 +105,9 @@ class OptimizerOverridesHandler {

// Getters for the overrides
// These are used to get the current state of the optimizer passes
bool getOptimizerPass() const;
bool getEnableOptimizer() const;
// These are used to get the current state of the memory configurations
bool getMemoryConfig() const;
bool getMemoryReconfig() const;
bool getMemoryLayoutAnalysis() const;
bool getEnableMemoryLayoutAnalysisPolicy() const;
MemoryLayoutAnalysisPolicyType getMemoryLayoutAnalysisPolicy() const;
Expand Down Expand Up @@ -136,10 +136,10 @@ class OptimizerOverridesHandler {

private:
// Flags for enabling/disabling the optimizer passes
bool enableOptimizerPass = true;
bool enableOptimizer = true;

// Flags for enabling/disabling the memory configurations
bool enableMemoryConfig = true;
bool enableMemoryReconfig = true;
bool enableMemoryLayoutAnalysis = true;

// Input layout overrides
Expand Down
20 changes: 10 additions & 10 deletions lib/Dialect/TTNN/Utils/OptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -189,12 +189,12 @@ void InputLayoutOverrideParser::print(
os << "\n";
}

void OptimizerOverridesHandler::setOptimizerPass(bool value) {
enableOptimizerPass = value;
void OptimizerOverridesHandler::setEnableOptimizer(bool value) {
enableOptimizer = value;
}

void OptimizerOverridesHandler::setMemoryConfig(bool value) {
enableMemoryConfig = value;
void OptimizerOverridesHandler::setMemoryReconfig(bool value) {
enableMemoryReconfig = value;
}
void OptimizerOverridesHandler::setMemoryLayoutAnalysis(bool value) {
enableMemoryLayoutAnalysis = value;
Expand Down Expand Up @@ -227,12 +227,12 @@ void OptimizerOverridesHandler::setMeshShape(std::vector<int64_t> value) {
meshShape = value;
}

bool OptimizerOverridesHandler::getOptimizerPass() const {
return enableOptimizerPass;
bool OptimizerOverridesHandler::getEnableOptimizer() const {
return enableOptimizer;
}

bool OptimizerOverridesHandler::getMemoryConfig() const {
return enableMemoryConfig;
bool OptimizerOverridesHandler::getMemoryReconfig() const {
return enableMemoryReconfig;
}
bool OptimizerOverridesHandler::getMemoryLayoutAnalysis() const {
return enableMemoryLayoutAnalysis;
Expand Down Expand Up @@ -268,11 +268,11 @@ std::string OptimizerOverridesHandler::toString() const {

std::string options = "";

if (enableOptimizerPass) {
if (enableOptimizer) {
options += "enable-optimizer=true ";
}

if (enableMemoryConfig) {
if (enableMemoryReconfig) {
options += "memreconfig-enabled=true ";
}

Expand Down
20 changes: 10 additions & 10 deletions test/unittests/Optimizer/TestOptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -195,24 +195,24 @@ class TestOptimizerOverrides : public ::testing::Test {
void TearDown() override {}
};

// Test the setOptimizerPass method
// Test the setEnableOptimizer method
TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) {

optimizerOverridesHandler.setOptimizerPass(true);
ASSERT_TRUE(optimizerOverridesHandler.getOptimizerPass());
optimizerOverridesHandler.setEnableOptimizer(true);
ASSERT_TRUE(optimizerOverridesHandler.getEnableOptimizer());

optimizerOverridesHandler.setOptimizerPass(false);
ASSERT_FALSE(optimizerOverridesHandler.getOptimizerPass());
optimizerOverridesHandler.setEnableOptimizer(false);
ASSERT_FALSE(optimizerOverridesHandler.getEnableOptimizer());
}

// Test the setMemoryConfig method
TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) {

optimizerOverridesHandler.setMemoryConfig(true);
ASSERT_TRUE(optimizerOverridesHandler.getMemoryConfig());
optimizerOverridesHandler.setMemoryReconfig(true);
ASSERT_TRUE(optimizerOverridesHandler.getMemoryReconfig());

optimizerOverridesHandler.setMemoryConfig(false);
ASSERT_FALSE(optimizerOverridesHandler.getMemoryConfig());
optimizerOverridesHandler.setMemoryReconfig(false);
ASSERT_FALSE(optimizerOverridesHandler.getMemoryReconfig());
}

// Test the setMemoryLayoutAnalysis method
Expand Down Expand Up @@ -411,7 +411,7 @@ TEST_F(TestOptimizerOverrides, TestToString) {
"override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32";

optimizerOverridesHandler.setMemoryLayoutAnalysis(true);
optimizerOverridesHandler.setMemoryConfig(true);
optimizerOverridesHandler.setMemoryReconfig(true);
optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", {0});
optimizerOverridesHandler.addOutputLayoutOverride(
"add_1_2", {1, 1}, BufferType::DRAM, TensorMemoryLayout::Interleaved,
Expand Down

0 comments on commit 295dc19

Please sign in to comment.