Skip to content

Commit

Permalink
Add mechanism for overriding with corresponding unittests.
Browse files Browse the repository at this point in the history
  • Loading branch information
vcanicTT committed Nov 26, 2024
1 parent 124f0e9 commit 6c756cc
Showing 1 changed file with 54 additions and 57 deletions.
111 changes: 54 additions & 57 deletions test/unittests/TestOptimizerOverrides/TestOptimizerOverrides.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,9 @@ using namespace mlir::tt::ttnn;
class TestOptimizerOverrides : public ::testing::Test {

public:
std::shared_ptr<OptimizerOverridesHandler> optimizerOverridesHandler;
OptimizerOverridesHandler optimizerOverridesHandler;

void SetUp() override {
optimizerOverridesHandler = std::make_shared<OptimizerOverridesHandler>();
}
void SetUp() override {}

llvm::StringMap<InputLayoutOverrideParams> createInputLayoutOverrides() {

Expand Down Expand Up @@ -200,55 +198,54 @@ class TestOptimizerOverrides : public ::testing::Test {
// Test the setOptimizerPass method
TEST_F(TestOptimizerOverrides, TestSetOptimizerPass) {

optimizerOverridesHandler->setOptimizerPass(true);
ASSERT_TRUE(optimizerOverridesHandler->getOptimizerPass());
optimizerOverridesHandler.setOptimizerPass(true);
ASSERT_TRUE(optimizerOverridesHandler.getOptimizerPass());

optimizerOverridesHandler->setOptimizerPass(false);
ASSERT_FALSE(optimizerOverridesHandler->getOptimizerPass());
optimizerOverridesHandler.setOptimizerPass(false);
ASSERT_FALSE(optimizerOverridesHandler.getOptimizerPass());
}

// Test the setMemoryConfig method
TEST_F(TestOptimizerOverrides, TestSetMemoryConfig) {

optimizerOverridesHandler->setMemoryConfig(true);
ASSERT_TRUE(optimizerOverridesHandler->getMemoryConfig());
optimizerOverridesHandler.setMemoryConfig(true);
ASSERT_TRUE(optimizerOverridesHandler.getMemoryConfig());

optimizerOverridesHandler->setMemoryConfig(false);
ASSERT_FALSE(optimizerOverridesHandler->getMemoryConfig());
optimizerOverridesHandler.setMemoryConfig(false);
ASSERT_FALSE(optimizerOverridesHandler.getMemoryConfig());
}

// Test the setMemoryLayoutAnalysis method
TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysis) {

optimizerOverridesHandler->setMemoryLayoutAnalysis(true);
ASSERT_TRUE(optimizerOverridesHandler->getMemoryLayoutAnalysis());
optimizerOverridesHandler.setMemoryLayoutAnalysis(true);
ASSERT_TRUE(optimizerOverridesHandler.getMemoryLayoutAnalysis());

optimizerOverridesHandler->setMemoryLayoutAnalysis(false);
ASSERT_FALSE(optimizerOverridesHandler->getMemoryLayoutAnalysis());
optimizerOverridesHandler.setMemoryLayoutAnalysis(false);
ASSERT_FALSE(optimizerOverridesHandler.getMemoryLayoutAnalysis());
}

// Test the setEnableMemoryLayoutAnalysisPolicy method
TEST_F(TestOptimizerOverrides, TestSetEnableMemoryLayoutAnalysisPolicy) {

optimizerOverridesHandler->setEnableMemoryLayoutAnalysisPolicy(true);
ASSERT_TRUE(optimizerOverridesHandler->getEnableMemoryLayoutAnalysisPolicy());
optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(true);
ASSERT_TRUE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy());

optimizerOverridesHandler->setEnableMemoryLayoutAnalysisPolicy(false);
ASSERT_FALSE(
optimizerOverridesHandler->getEnableMemoryLayoutAnalysisPolicy());
optimizerOverridesHandler.setEnableMemoryLayoutAnalysisPolicy(false);
ASSERT_FALSE(optimizerOverridesHandler.getEnableMemoryLayoutAnalysisPolicy());
}

// Test the setMemoryLayoutAnalysisPolicy method
TEST_F(TestOptimizerOverrides, TestSetMemoryLayoutAnalysisPolicy) {

optimizerOverridesHandler->setMemoryLayoutAnalysisPolicy(
optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy(
mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding);
ASSERT_EQ(optimizerOverridesHandler->getMemoryLayoutAnalysisPolicy(),
ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(),
mlir::tt::MemoryLayoutAnalysisPolicyType::DFSharding);

optimizerOverridesHandler->setMemoryLayoutAnalysisPolicy(
optimizerOverridesHandler.setMemoryLayoutAnalysisPolicy(
mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved);
ASSERT_EQ(optimizerOverridesHandler->getMemoryLayoutAnalysisPolicy(),
ASSERT_EQ(optimizerOverridesHandler.getMemoryLayoutAnalysisPolicy(),
mlir::tt::MemoryLayoutAnalysisPolicyType::L1Interleaved);
}

Expand All @@ -258,9 +255,9 @@ TEST_F(TestOptimizerOverrides, TestSetInputLayoutOverrides) {
llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides =
createInputLayoutOverrides();

optimizerOverridesHandler->setInputLayoutOverrides(inputLayoutOverrides);
optimizerOverridesHandler.setInputLayoutOverrides(inputLayoutOverrides);
ASSERT_TRUE(compareInputLayoutOverrides(
optimizerOverridesHandler->getInputLayoutOverrides(),
optimizerOverridesHandler.getInputLayoutOverrides(),
inputLayoutOverrides));
}

Expand All @@ -270,9 +267,9 @@ TEST_F(TestOptimizerOverrides, TestSetOutputLayoutOverrides) {
llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides =
createOutputLayoutOverrides();

optimizerOverridesHandler->setOutputLayoutOverrides(outputLayoutOverrides);
optimizerOverridesHandler.setOutputLayoutOverrides(outputLayoutOverrides);
ASSERT_TRUE(compareOutputLayoutOverrides(
optimizerOverridesHandler->getOutputLayoutOverrides(),
optimizerOverridesHandler.getOutputLayoutOverrides(),
outputLayoutOverrides));
}

Expand All @@ -289,15 +286,15 @@ TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideObject) {
llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides =
createInputLayoutOverrides();

optimizerOverridesHandler->addInputLayoutOverride(
optimizerOverridesHandler.addInputLayoutOverride(
"input0", createInputLayoutOverrideParams());
optimizerOverridesHandler->addInputLayoutOverride(
optimizerOverridesHandler.addInputLayoutOverride(
"input1", createInputLayoutOverrideParams());
optimizerOverridesHandler->addInputLayoutOverride(
optimizerOverridesHandler.addInputLayoutOverride(
"input2", createInputLayoutOverrideParams());

ASSERT_TRUE(compareInputLayoutOverrides(
optimizerOverridesHandler->getInputLayoutOverrides(),
optimizerOverridesHandler.getInputLayoutOverrides(),
inputLayoutOverrides));
}

Expand All @@ -313,12 +310,12 @@ TEST_F(TestOptimizerOverrides, TestAddInputLayoutOverrideParams) {
llvm::StringMap<InputLayoutOverrideParams> inputLayoutOverrides =
createInputLayoutOverrides();

optimizerOverridesHandler->addInputLayoutOverride("input0", {0, 1});
optimizerOverridesHandler->addInputLayoutOverride("input1", {0, 1});
optimizerOverridesHandler->addInputLayoutOverride("input2", {0, 1});
optimizerOverridesHandler.addInputLayoutOverride("input0", {0, 1});
optimizerOverridesHandler.addInputLayoutOverride("input1", {0, 1});
optimizerOverridesHandler.addInputLayoutOverride("input2", {0, 1});

ASSERT_TRUE(compareInputLayoutOverrides(
optimizerOverridesHandler->getInputLayoutOverrides(),
optimizerOverridesHandler.getInputLayoutOverrides(),
inputLayoutOverrides));
}

Expand All @@ -335,15 +332,15 @@ TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideObject) {
llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides =
createOutputLayoutOverrides();

optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.addOutputLayoutOverride(
"output0", createOutputLayoutOverrideParams_0());
optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.addOutputLayoutOverride(
"output1", createOutputLayoutOverrideParams_1());
optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.addOutputLayoutOverride(
"output2", createOutputLayoutOverrideParams_2());

ASSERT_TRUE(compareOutputLayoutOverrides(
optimizerOverridesHandler->getOutputLayoutOverrides(),
optimizerOverridesHandler.getOutputLayoutOverrides(),
outputLayoutOverrides));
}

Expand All @@ -359,34 +356,34 @@ TEST_F(TestOptimizerOverrides, TestAddOutputLayoutOverrideParams) {
llvm::StringMap<OutputLayoutOverrideParams> outputLayoutOverrides =
createOutputLayoutOverrides();

optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.addOutputLayoutOverride(
"output0", {2, 2}, BufferType::DRAM, TensorMemoryLayout::Interleaved,
Layout::Tile, mlir::tt::DataType::Float16);
optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.addOutputLayoutOverride(
"output1", {8, 4}, BufferType::L1, TensorMemoryLayout::BlockSharded,
Layout::RowMajor, mlir::tt::DataType::Float16);
optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.addOutputLayoutOverride(
"output2", {3, 6}, BufferType::SystemMemory,
TensorMemoryLayout::HeightSharded, Layout::Tile,
mlir::tt::DataType::Float16);

ASSERT_TRUE(compareOutputLayoutOverrides(
optimizerOverridesHandler->getOutputLayoutOverrides(),
optimizerOverridesHandler.getOutputLayoutOverrides(),
outputLayoutOverrides));
}

// Test the setSystemDescPath method
TEST_F(TestOptimizerOverrides, TestSetSystemDescPath) {

optimizerOverridesHandler->setSystemDescPath("system_desc_path");
ASSERT_EQ(optimizerOverridesHandler->getSystemDescPath(), "system_desc_path");
optimizerOverridesHandler.setSystemDescPath("system_desc_path");
ASSERT_EQ(optimizerOverridesHandler.getSystemDescPath(), "system_desc_path");
}

// Test the setMaxLegalLayouts method
TEST_F(TestOptimizerOverrides, TestSetMaxLegalLayouts) {

optimizerOverridesHandler->setMaxLegalLayouts(10);
ASSERT_EQ(optimizerOverridesHandler->getMaxLegalLayouts(), 10);
optimizerOverridesHandler.setMaxLegalLayouts(10);
ASSERT_EQ(optimizerOverridesHandler.getMaxLegalLayouts(), 10);
}

// Test the setMeshShape method
Expand All @@ -396,9 +393,9 @@ TEST_F(TestOptimizerOverrides, TestSetMeshShape) {
meshShape.push_back(1);
meshShape.push_back(2);

optimizerOverridesHandler->setMeshShape(meshShape);
ASSERT_EQ(optimizerOverridesHandler->getMeshShape()[0], meshShape[0]);
ASSERT_EQ(optimizerOverridesHandler->getMeshShape()[1], meshShape[1]);
optimizerOverridesHandler.setMeshShape(meshShape);
ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[0], meshShape[0]);
ASSERT_EQ(optimizerOverridesHandler.getMeshShape()[1], meshShape[1]);
}

// Test the toString method
Expand All @@ -413,12 +410,12 @@ TEST_F(TestOptimizerOverrides, TestToString) {
options +=
"override-output-layout=add_1_2=1x1:dram:interleaved:row_major:f32";

optimizerOverridesHandler->setMemoryLayoutAnalysis(true);
optimizerOverridesHandler->setMemoryConfig(true);
optimizerOverridesHandler->addInputLayoutOverride("add_0_1_2", {0});
optimizerOverridesHandler->addOutputLayoutOverride(
optimizerOverridesHandler.setMemoryLayoutAnalysis(true);
optimizerOverridesHandler.setMemoryConfig(true);
optimizerOverridesHandler.addInputLayoutOverride("add_0_1_2", {0});
optimizerOverridesHandler.addOutputLayoutOverride(
"add_1_2", {1, 1}, BufferType::DRAM, TensorMemoryLayout::Interleaved,
Layout::RowMajor, mlir::tt::DataType::Float32);

ASSERT_EQ(optimizerOverridesHandler->toString(), options);
ASSERT_EQ(optimizerOverridesHandler.toString(), options);
}

0 comments on commit 6c756cc

Please sign in to comment.