Add mechanism for overriding with corresponding unittests.
vcanicTT committed Nov 29, 2024
1 parent 40f1737 commit 47e7e40
Showing 6 changed files with 326 additions and 444 deletions.
17 changes: 11 additions & 6 deletions include/ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h
@@ -27,18 +27,23 @@ struct MemoryLayoutAnalysisPolicyTypeParser
return false;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
llvm::StringRef policy;
static std::string toString(const MemoryLayoutAnalysisPolicyType &value) {
std::string res;
switch (value) {
case MemoryLayoutAnalysisPolicyType::DFSharding:
policy = "DFSharding";
res += "DFSharding";
break;
case MemoryLayoutAnalysisPolicyType::L1Interleaved:
policy = "L1Interleaved";
res += "L1Interleaved";
break;
}
os << "memory-layout-analysis-policy=" << policy << "\n";
return res;
}

static void print(llvm::raw_ostream &os,
const MemoryLayoutAnalysisPolicyType &value) {
os << "memory-layout-analysis-policy="
<< MemoryLayoutAnalysisPolicyTypeParser::toString(value) << "\n";
}
};

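For reference, a minimal usage sketch of the refactored helpers above (not part of the commit). It assumes the enum and parser live in the mlir::tt namespace, as the include path suggests.

#include "llvm/Support/raw_ostream.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"

void dumpPolicy() {
  using namespace mlir::tt; // assumed namespace for these utilities

  auto policy = MemoryLayoutAnalysisPolicyType::DFSharding;

  // toString returns only the enum name, e.g. "DFSharding".
  std::string name = MemoryLayoutAnalysisPolicyTypeParser::toString(policy);

  // print reuses toString and prefixes the option name, emitting
  // "memory-layout-analysis-policy=DFSharding\n".
  MemoryLayoutAnalysisPolicyTypeParser::print(llvm::outs(), policy);
  llvm::outs() << name << "\n";
}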
4 changes: 3 additions & 1 deletion include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -2,11 +2,13 @@
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#ifndef TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H
#define TTMLIR_DIALECT_TTNN_PIPELINES_TTNNPIPELINES_H

#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"
#include "ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

#include "mlir/Pass/PassOptions.h"

78 changes: 9 additions & 69 deletions include/ttmlir/Dialect/TTNN/Utils/OptimizerOverrides.h
@@ -2,84 +2,20 @@
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#ifndef TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H
#define TTMLIR_DIALECT_TTNN_UTILS_OPTIMIZEROVERRIDES_H

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TT/Utils/MemoryLayoutAnalysisParams.h"

#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

namespace mlir::tt::ttnn {

struct OutputLayoutOverrideParams {

SmallVector<int64_t, 2> grid;
BufferType bufferType;
TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
Layout memoryLayout; // ROW_MAJOR / TILE
mlir::tt::DataType dataType;

bool operator==(const OutputLayoutOverrideParams rhs) const {
return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] &&
bufferType == rhs.bufferType &&
tensorMemoryLayout == rhs.tensorMemoryLayout &&
memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;
}

bool operator!=(const OutputLayoutOverrideParams &rhs) const {
return !(*this == rhs);
}
};

struct InputLayoutOverrideParams {

SmallVector<int64_t> operandIdxes;

bool operator==(const InputLayoutOverrideParams &rhs) const {
if (operandIdxes.size() != rhs.operandIdxes.size()) {
return false;
}
for (std::size_t i = 0; i < operandIdxes.size(); i++) {
if (operandIdxes[i] != rhs.operandIdxes[i]) {
return false;
}
}
return true;
}

bool operator!=(const InputLayoutOverrideParams &rhs) const {
return !(*this == rhs);
}
};

struct OutputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> {
public:
OutputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<OutputLayoutOverrideParams> &value);

static void print(llvm::raw_ostream &os,
const llvm::StringMap<OutputLayoutOverrideParams> &value);
};

struct InputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> {
public:
InputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<InputLayoutOverrideParams> &value);

static void print(llvm::raw_ostream &os,
const llvm::StringMap<InputLayoutOverrideParams> &value);
};

class OptimizerOverridesHandler {
public:
OptimizerOverridesHandler() {};
@@ -135,6 +71,10 @@ class OptimizerOverridesHandler
tt::DataType);

private:
// Options for the TTIR to TTNN backend pipeline;
// we use them to extract the option names and the default values.
TTIRToTTNNBackendPipelineOptions pipelineOptions;

// Flags for enabling/disabling the optimizer passes
bool enableOptimizer = true;

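The new pipelineOptions member lets the handler read option names and current values from the pipeline definition instead of duplicating hard-coded strings. A hedged sketch of that pattern, not taken from the commit; the member name optimizerPassEnabled is an assumption about TTIRToTTNNBackendPipelineOptions, not something shown in this diff.

#include "llvm/Support/raw_ostream.h"
#include "ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h"

// Sketch only (assumed member name optimizerPassEnabled): serialize one
// flag using the option's registered command-line name so the emitted
// string cannot drift out of sync with the pipeline's own definitions.
std::string serializeOptimizerFlag(
    const mlir::tt::ttnn::TTIRToTTNNBackendPipelineOptions &opts) {
  std::string out;
  llvm::raw_string_ostream os(out);
  // ArgStr is the llvm::cl::Option field holding the registered option name.
  os << opts.optimizerPassEnabled.ArgStr << "="
     << (opts.optimizerPassEnabled ? "true" : "false");
  return os.str();
}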
79 changes: 79 additions & 0 deletions include/ttmlir/Dialect/TTNN/Utils/Utils.h
@@ -5,10 +5,89 @@
#ifndef TTMLIR_DIALECT_TTNN_UTILS_UTILS_H
#define TTMLIR_DIALECT_TTNN_UTILS_UTILS_H

#include <llvm/Support/CommandLine.h>

#include "ttmlir/Dialect/TT/IR/TTOpsTypes.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsAttrs.h"
#include "ttmlir/Dialect/TTNN/IR/TTNNOpsTypes.h"

namespace mlir::tt::ttnn {

struct OutputLayoutOverrideParams {

SmallVector<int64_t, 2> grid;
BufferType bufferType;
TensorMemoryLayout tensorMemoryLayout; // INTERLEAVED / SHARDED etc...
Layout memoryLayout; // ROW_MAJOR / TILE
mlir::tt::DataType dataType;

bool operator==(const OutputLayoutOverrideParams rhs) const {
return grid[0] == rhs.grid[0] && grid[1] == rhs.grid[1] &&
bufferType == rhs.bufferType &&
tensorMemoryLayout == rhs.tensorMemoryLayout &&
memoryLayout == rhs.memoryLayout && dataType == rhs.dataType;
}

bool operator!=(const OutputLayoutOverrideParams &rhs) const {
return !(*this == rhs);
}
};

struct InputLayoutOverrideParams {

SmallVector<int64_t> operandIdxes;

bool operator==(const InputLayoutOverrideParams &rhs) const {
if (operandIdxes.size() != rhs.operandIdxes.size()) {
return false;
}
for (std::size_t i = 0; i < operandIdxes.size(); i++) {
if (operandIdxes[i] != rhs.operandIdxes[i]) {
return false;
}
}
return true;
}

bool operator!=(const InputLayoutOverrideParams &rhs) const {
return !(*this == rhs);
}
};

struct OutputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>> {
public:
OutputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<OutputLayoutOverrideParams>>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<OutputLayoutOverrideParams> &value);

static std::string
toString(const llvm::StringMap<OutputLayoutOverrideParams> &);

static void print(llvm::raw_ostream &os,
const llvm::StringMap<OutputLayoutOverrideParams> &value);
};

struct InputLayoutOverrideParser
: public llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>> {
public:
InputLayoutOverrideParser(llvm::cl::Option &opt)
: llvm::cl::parser<llvm::StringMap<InputLayoutOverrideParams>>(opt) {}

bool parse(llvm::cl::Option &opt, StringRef argName, StringRef arg,
llvm::StringMap<InputLayoutOverrideParams> &value);

static std::string
toString(const llvm::StringMap<InputLayoutOverrideParams> &);

static void print(llvm::raw_ostream &os,
const llvm::StringMap<InputLayoutOverrideParams> &value);
};

} // namespace mlir::tt::ttnn

namespace mlir::tt::ttnn::utils {

// Map tt::MemorySpace to ttnn::BufferType
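To illustrate the structs and parsers now hosted in Utils.h, here is a hedged round-trip sketch (not part of the commit): it populates one output-layout override by hand and serializes it with the new static toString/print helpers. The enumerator names (BufferType::DRAM, TensorMemoryLayout::Interleaved, Layout::Tile, DataType::Float32) and the key "matmul_1" are illustrative assumptions.

#include "llvm/Support/raw_ostream.h"
#include "ttmlir/Dialect/TTNN/Utils/Utils.h"

void demoOverrideToString() {
  using namespace mlir::tt::ttnn;

  OutputLayoutOverrideParams params;
  params.grid = {8, 8};                                         // 8x8 core grid
  params.bufferType = BufferType::DRAM;                         // assumed enumerator
  params.tensorMemoryLayout = TensorMemoryLayout::Interleaved;  // assumed enumerator
  params.memoryLayout = Layout::Tile;                           // assumed enumerator
  params.dataType = mlir::tt::DataType::Float32;                // assumed enumerator

  llvm::StringMap<OutputLayoutOverrideParams> overrides;
  overrides["matmul_1"] = params; // key = name of the op being overridden

  // toString builds the serialized override string; print writes the same
  // text to the given stream.
  llvm::outs() << OutputLayoutOverrideParser::toString(overrides) << "\n";
  OutputLayoutOverrideParser::print(llvm::outs(), overrides);
}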
