Skip to content

Commit

Permalink
Dump original pir code to original_programs.py (PaddlePaddle#64373)
Browse files Browse the repository at this point in the history
* fix typo

* refactor hash and operator== for ShapeOrDataDimExprs

* dump null Types and null symbols to pir py code

---------

Co-authored-by: jiahy0825 <[email protected]>
  • Loading branch information
2 people authored and co63oc committed May 19, 2024
1 parent b7724b5 commit 505a40b
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 50 deletions.
12 changes: 10 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,21 @@ int64_t GetOpCount(const ::pir::Operation* op) {
void ApplyCinnPass(::pir::Program* program,
const std::function<std::shared_ptr<pir::PassManager>()>&
CreatePassManager) {
PirToPyCodeConverter(program)
.file_name("original_programs.py")
.dump_symbolic_shape(false)
.SaveIfFlagEnabled();
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program);
PirToPyCodeConverter(program)
.file_name("group_op_programs.py")
.SaveIfFlagEnabled();
ApplyGroupOpPass(program, CreatePassManager);
ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program);
PirToPyCodeConverter(program)
.file_name("fusion_op_programs.py")
.SaveIfFlagEnabled();
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

COMMON_DECLARE_string(logging_pir_py_code_dir);
COMMON_DECLARE_bool(logging_trunc_pir_py_code);

namespace cinn::dialect::ir {

Expand Down Expand Up @@ -92,20 +93,28 @@ int64_t GetAutoIncrementalId() {
return seq_no++;
}

using ShapeAnalysisGetterT =
std::function<std::optional<pir::ShapeConstraintIRAnalysis*>(
const pir::Program*)>;

} // namespace

struct PirToPyCodeConverterHelper {
explicit PirToPyCodeConverterHelper(const pir::Program* program)
explicit PirToPyCodeConverterHelper(
const pir::Program* program,
const ShapeAnalysisGetterT& ShapeAnalysisGetter)
: program_(program),
indent_size_(kDefaultIndentSize),
seq_no_(GetAutoIncrementalId()) {}
seq_no_(GetAutoIncrementalId()),
ShapeAnalysisGetter_(ShapeAnalysisGetter) {}

std::string Convert() { return Convert(*program_); }

private:
const pir::Program* program_;
const int indent_size_;
int64_t seq_no_;
ShapeAnalysisGetterT ShapeAnalysisGetter_;

std::string Convert(const pir::Program& program) {
auto istrings = ConvertMethodsToPyClass(program.module_op(), [&]() {
Expand Down Expand Up @@ -147,7 +156,8 @@ struct PirToPyCodeConverterHelper {
template <typename DoEachEQCstrT>
void VisitEachEQCstr(const DoEachEQCstrT& DoEachEQCstr) {
const auto& constraints_mgr = GetConstraintsMgr();
for (const auto& [lhs, rhs] : constraints_mgr.equals().GetMap()) {
if (!constraints_mgr.has_value()) return;
for (const auto& [lhs, rhs] : constraints_mgr.value()->equals().GetMap()) {
if (lhs == rhs) continue;
DoEachEQCstr(lhs, rhs);
}
Expand All @@ -165,7 +175,8 @@ struct PirToPyCodeConverterHelper {
template <typename DoEachGtOneCstrT>
void VisitEachGtOneCstr(const DoEachGtOneCstrT& DoEachGtOneCstr) {
const auto& constraints_mgr = GetConstraintsMgr();
for (const auto& dim_expr : constraints_mgr.gtones()) {
if (!constraints_mgr.has_value()) return;
for (const auto& dim_expr : constraints_mgr.value()->gtones()) {
DoEachGtOneCstr(dim_expr);
}
}
Expand All @@ -181,7 +192,8 @@ struct PirToPyCodeConverterHelper {
void VisitEachBroadcastableCstr(
const DoEachBroadcastableCstrT& DoEachBroadcastableCstr) {
const auto& constraints_mgr = GetConstraintsMgr();
const auto& broadcastables = constraints_mgr.broadcastables();
if (!constraints_mgr.has_value()) return;
const auto& broadcastables = constraints_mgr.value()->broadcastables();
for (const auto& broadcastable : broadcastables) {
const auto& [lhs, rhs] = *broadcastable;
if (lhs == rhs) continue;
Expand All @@ -198,9 +210,10 @@ struct PirToPyCodeConverterHelper {
return ss.str();
}

const symbol::ConstraintsManager& GetConstraintsMgr() {
auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_);
return shape_analysis.constraints_manager();
std::optional<const symbol::ConstraintsManager*> GetConstraintsMgr() {
const auto& shape_analysis = ShapeAnalysisGetter_(program_);
if (!shape_analysis.has_value()) return std::nullopt;
return &shape_analysis.value()->constraints_manager();
}

IStrings ConvertModuleOp(const pir::ModuleOp& module) {
Expand Down Expand Up @@ -402,6 +415,12 @@ struct PirToPyCodeConverterHelper {
}
ss << attr_name << "=" << ConvertAttr(attr);
});
VisitSymbolicAttrs(op, [&](const auto& attr_name, const auto& attrs) {
if (i++ > 0) {
ss << ", ";
}
ss << attr_name << "=" << ConvertSymbolicAttrs(attrs);
});
return ss.str();
}

Expand All @@ -410,6 +429,25 @@ struct PirToPyCodeConverterHelper {
return std::visit(AttrConverter{attr}, adt_type_id.variant());
}

static std::string ConvertSymbolicAttrs(
const std::vector<std::optional<pir::Attribute>>& attrs) {
std::ostringstream ss;
ss << "self.a_array(";
int i = 0;
for (const auto& attr : attrs) {
if (i++ > 0) {
ss << ", ";
}
if (!attr.has_value()) {
ss << "self.a_symbol(self.s_null())";
} else {
ss << ConvertAttr(attr.value());
}
}
ss << ")";
return ss.str();
}

static std::string ConvertShapeOrData(
const symbol::ShapeOrDataDimExprs& shape_or_data) {
return shape_or_data.Match(
Expand Down Expand Up @@ -781,55 +819,53 @@ struct PirToPyCodeConverterHelper {
if (attr_name == "sym_shape_str") continue;
DoEachAttr(attr_name, attr);
}
DoEachAttr("__operands_symbols_signature__",
GetOpOperandsSymbolsSignature(op));
DoEachAttr("__results_symbols_signature__",
GetOpResultsSymbolsSignature(op));
}

pir::Attribute GetOpOperandsSymbolsSignature(const pir::Operation* op) {
std::vector<pir::Attribute> attrs = GetOpOperandsSymbolDimsAttributes(op);
return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs);
}

pir::Attribute GetOpResultsSymbolsSignature(const pir::Operation* op) {
std::vector<pir::Attribute> attrs = GetOpResultsSymbolDimsAttributes(op);
return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs);
template <typename DoEachAttrT>
void VisitSymbolicAttrs(const pir::Operation* op,
const DoEachAttrT& DoEachAttr) {
DoEachAttr("__operands_symbols_signature__",
GetOpOperandsSymbolDimsAttributes(op));
DoEachAttr("__results_symbols_signature__",
GetOpResultsSymbolDimsAttributes(op));
}

std::vector<pir::Attribute> GetOpOperandsSymbolDimsAttributes(
std::vector<std::optional<pir::Attribute>> GetOpOperandsSymbolDimsAttributes(
const pir::Operation* op) {
std::vector<pir::Attribute> attrs;
std::vector<std::optional<pir::Attribute>> attrs;
attrs.reserve(op->num_operands());
for (int i = 0; i < op->num_operands(); ++i) {
attrs.push_back(GetValueSymbolDimsAttribute(op->operand_source(i)));
}
return attrs;
}

std::vector<pir::Attribute> GetOpResultsSymbolDimsAttributes(
std::vector<std::optional<pir::Attribute>> GetOpResultsSymbolDimsAttributes(
const pir::Operation* op) {
std::vector<pir::Attribute> attrs;
std::vector<std::optional<pir::Attribute>> attrs;
attrs.reserve(op->num_results());
for (int i = 0; i < op->num_results(); ++i) {
attrs.push_back(GetValueSymbolDimsAttribute(op->result(i)));
}
return attrs;
}

pir::Attribute GetValueSymbolDimsAttribute(pir::Value value) {
std::optional<pir::Attribute> GetValueSymbolDimsAttribute(pir::Value value) {
auto* ctx = pir::IrContext::Instance();
using SymbolAttr = pir::shape::SymbolAttribute;
if (!value) {
return SymbolAttr::get(ctx, symbol::TensorShapeOrDataDimExprs{});
return std::nullopt;
}
const auto* shape_or_data = GetShapeOrDataDimExprs(value);
return SymbolAttr::get(ctx, *shape_or_data);
const auto& shape_or_data = GetShapeOrDataDimExprs(value);
if (!shape_or_data.has_value()) return std::nullopt;
return SymbolAttr::get(ctx, *shape_or_data.value());
}

const symbol::ShapeOrDataDimExprs* GetShapeOrDataDimExprs(pir::Value value) {
auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_);
return &shape_analysis.GetShapeOrDataForValue(value);
std::optional<const symbol::ShapeOrDataDimExprs*> GetShapeOrDataDimExprs(
pir::Value value) {
const auto& shape_analysis = ShapeAnalysisGetter_(program_);
if (!shape_analysis.has_value()) return std::nullopt;
return &shape_analysis.value()->GetShapeOrDataForValue(value);
}

std::string ConvertInputTypes(const pir::Operation* op) {
Expand Down Expand Up @@ -869,6 +905,10 @@ struct PirToPyCodeConverterHelper {
template <typename T>
using AdtTypeId = ::common::AdtTypeId<T>;

std::string operator()(AdtTypeId<cinn::dialect::ir::NullType>) {
return "self.t_null";
}

std::string operator()(AdtTypeId<::pir::VectorType>) {
std::stringstream ss;
const auto& name = ::pir::VectorType::name();
Expand Down Expand Up @@ -1130,22 +1170,39 @@ struct PirToPyCodeConverterHelper {
}
};

std::optional<pir::ShapeConstraintIRAnalysis*> GetShapeAnalysisFromManager(
const pir::Program* program) {
return &pir::ShapeAnalysisManager::Instance().Get(program);
}

std::optional<pir::ShapeConstraintIRAnalysis*> GetNullShapeAnalysis(
const pir::Program* program) {
return std::nullopt;
}

} // namespace

void PirToPyCodeConverter::SaveIfFlagEnabled(
const std::string& tag, const pir::Program& program) const {
void PirToPyCodeConverter::SaveIfFlagEnabled() const {
if (program_ == nullptr) return;
if (file_name_.empty()) return;
if (FLAGS_logging_pir_py_code_dir == "") return;
const std::string file_path =
FLAGS_logging_pir_py_code_dir + "/" + tag + ".py";
const std::string content = PirToPyCodeConverterHelper(&program).Convert();
FLAGS_logging_pir_py_code_dir + "/" + file_name_;
ShapeAnalysisGetterT ShapeAnalysisGetter =
(dump_symbolic_shape_ ? GetShapeAnalysisFromManager
: GetNullShapeAnalysis);
PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter);
const std::string content = converter_helper.Convert();
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
if (FLAGS_logging_trunc_pir_py_code) {
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
}
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::app);
if (!ofs.is_open()) return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,27 @@ namespace cinn::dialect::ir {

class PirToPyCodeConverter {
public:
PirToPyCodeConverter() = default;
explicit PirToPyCodeConverter(pir::Program* program)
: program_(program), file_name_(), dump_symbolic_shape_(true) {}
PirToPyCodeConverter(const PirToPyCodeConverter&) = delete;
PirToPyCodeConverter(PirToPyCodeConverter&&) = delete;

void SaveIfFlagEnabled(const std::string& tag,
const pir::Program& program) const;
PirToPyCodeConverter& file_name(const std::string& file_name) {
file_name_ = file_name;
return *this;
}

PirToPyCodeConverter& dump_symbolic_shape(bool val) {
dump_symbolic_shape_ = val;
return *this;
}

void SaveIfFlagEnabled() const;

private:
pir::Program* program_;
std::string file_name_;
bool dump_symbolic_shape_;
};

} // namespace cinn::dialect::ir
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
namespace cinn::dialect::ir {

TypeAdtTypeId GetTypeAdtTypeId(const pir::Type& type) {
if (!type) {
return ::common::AdtTypeId<NullType>{};
}
#define RETURN_TYPE_TYPE_ID_IF_MATCH(cls) \
if (type.isa<cls>()) return ::common::AdtTypeId<cls>{};
FOR_EACH_PIR_ALTERNATIVE_TYPLE(RETURN_TYPE_TYPE_ID_IF_MATCH)
Expand Down
10 changes: 6 additions & 4 deletions paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

namespace pir {
class Type;

class VectorType;
class DenseTensorType;
class BFloat16Type;
Expand Down Expand Up @@ -58,13 +57,16 @@ class Complex128Type;

namespace cinn::dialect::ir {

class NullType;
class UnclassifiedType;

using TypeAdtTypeIdBase = ::common::AdtBaseTypeId<
using TypeAdtTypeIdBase =
::common::AdtBaseTypeId<NullType,
#define MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE(cls) cls,
FOR_EACH_PIR_ALTERNATIVE_TYPLE(MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE)
FOR_EACH_PIR_ALTERNATIVE_TYPLE(
MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE)
#undef MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE
UnclassifiedType>;
UnclassifiedType>;

struct TypeAdtTypeId : public TypeAdtTypeIdBase {
using TypeAdtTypeIdBase::TypeAdtTypeIdBase;
Expand Down
5 changes: 5 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,11 @@ PHI_DEFINE_EXPORTED_string(logging_pir_py_code_dir,
"",
"the logging directory to save pir py code");

PHI_DEFINE_EXPORTED_bool(logging_trunc_pir_py_code,
true,
"whether truncate the logging files under directory "
"FLAGS_logging_pir_py_code_dir");

/**
* Using PIR API in Python
* Name: enable_pir_api
Expand Down

0 comments on commit 505a40b

Please sign in to comment.