Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dump original pir code to original_programs.py #64373

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -219,13 +219,21 @@ int64_t GetOpCount(const ::pir::Operation* op) {
void ApplyCinnPass(::pir::Program* program,
const std::function<std::shared_ptr<pir::PassManager>()>&
CreatePassManager) {
PirToPyCodeConverter(program)
.file_name("original_programs.py")
.dump_symbolic_shape(false)
.SaveIfFlagEnabled();
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program);
PirToPyCodeConverter(program)
.file_name("group_op_programs.py")
.SaveIfFlagEnabled();
ApplyGroupOpPass(program, CreatePassManager);
ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program);
PirToPyCodeConverter(program)
.file_name("fusion_op_programs.py")
.SaveIfFlagEnabled();
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include "paddle/pir/include/dialect/shape/utils/shape_analysis.h"

COMMON_DECLARE_string(logging_pir_py_code_dir);
COMMON_DECLARE_bool(logging_trunc_pir_py_code);

namespace cinn::dialect::ir {

Expand Down Expand Up @@ -92,20 +93,28 @@ int64_t GetAutoIncrementalId() {
return seq_no++;
}

using ShapeAnalysisGetterT =
std::function<std::optional<pir::ShapeConstraintIRAnalysis*>(
const pir::Program*)>;

} // namespace

struct PirToPyCodeConverterHelper {
explicit PirToPyCodeConverterHelper(const pir::Program* program)
explicit PirToPyCodeConverterHelper(
const pir::Program* program,
const ShapeAnalysisGetterT& ShapeAnalysisGetter)
: program_(program),
indent_size_(kDefaultIndentSize),
seq_no_(GetAutoIncrementalId()) {}
seq_no_(GetAutoIncrementalId()),
ShapeAnalysisGetter_(ShapeAnalysisGetter) {}

std::string Convert() { return Convert(*program_); }

private:
const pir::Program* program_;
const int indent_size_;
int64_t seq_no_;
ShapeAnalysisGetterT ShapeAnalysisGetter_;

std::string Convert(const pir::Program& program) {
auto istrings = ConvertMethodsToPyClass(program.module_op(), [&]() {
Expand Down Expand Up @@ -147,7 +156,8 @@ struct PirToPyCodeConverterHelper {
template <typename DoEachEQCstrT>
void VisitEachEQCstr(const DoEachEQCstrT& DoEachEQCstr) {
const auto& constraints_mgr = GetConstraintsMgr();
for (const auto& [lhs, rhs] : constraints_mgr.equals().GetMap()) {
if (!constraints_mgr.has_value()) return;
for (const auto& [lhs, rhs] : constraints_mgr.value()->equals().GetMap()) {
if (lhs == rhs) continue;
DoEachEQCstr(lhs, rhs);
}
Expand All @@ -165,7 +175,8 @@ struct PirToPyCodeConverterHelper {
template <typename DoEachGtOneCstrT>
void VisitEachGtOneCstr(const DoEachGtOneCstrT& DoEachGtOneCstr) {
const auto& constraints_mgr = GetConstraintsMgr();
for (const auto& dim_expr : constraints_mgr.gtones()) {
if (!constraints_mgr.has_value()) return;
for (const auto& dim_expr : constraints_mgr.value()->gtones()) {
DoEachGtOneCstr(dim_expr);
}
}
Expand All @@ -181,7 +192,8 @@ struct PirToPyCodeConverterHelper {
void VisitEachBroadcastableCstr(
const DoEachBroadcastableCstrT& DoEachBroadcastableCstr) {
const auto& constraints_mgr = GetConstraintsMgr();
const auto& broadcastables = constraints_mgr.broadcastables();
if (!constraints_mgr.has_value()) return;
const auto& broadcastables = constraints_mgr.value()->broadcastables();
for (const auto& broadcastable : broadcastables) {
const auto& [lhs, rhs] = *broadcastable;
if (lhs == rhs) continue;
Expand All @@ -198,9 +210,10 @@ struct PirToPyCodeConverterHelper {
return ss.str();
}

const symbol::ConstraintsManager& GetConstraintsMgr() {
auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_);
return shape_analysis.constraints_manager();
std::optional<const symbol::ConstraintsManager*> GetConstraintsMgr() {
const auto& shape_analysis = ShapeAnalysisGetter_(program_);
if (!shape_analysis.has_value()) return std::nullopt;
return &shape_analysis.value()->constraints_manager();
}

IStrings ConvertModuleOp(const pir::ModuleOp& module) {
Expand Down Expand Up @@ -402,6 +415,12 @@ struct PirToPyCodeConverterHelper {
}
ss << attr_name << "=" << ConvertAttr(attr);
});
VisitSymbolicAttrs(op, [&](const auto& attr_name, const auto& attrs) {
if (i++ > 0) {
ss << ", ";
}
ss << attr_name << "=" << ConvertSymbolicAttrs(attrs);
});
return ss.str();
}

Expand All @@ -410,6 +429,25 @@ struct PirToPyCodeConverterHelper {
return std::visit(AttrConverter{attr}, adt_type_id.variant());
}

static std::string ConvertSymbolicAttrs(
const std::vector<std::optional<pir::Attribute>>& attrs) {
std::ostringstream ss;
ss << "self.a_array(";
int i = 0;
for (const auto& attr : attrs) {
if (i++ > 0) {
ss << ", ";
}
if (!attr.has_value()) {
ss << "self.a_symbol(self.s_null())";
} else {
ss << ConvertAttr(attr.value());
}
}
ss << ")";
return ss.str();
}

static std::string ConvertShapeOrData(
const symbol::ShapeOrDataDimExprs& shape_or_data) {
return shape_or_data.Match(
Expand Down Expand Up @@ -781,55 +819,53 @@ struct PirToPyCodeConverterHelper {
if (attr_name == "sym_shape_str") continue;
DoEachAttr(attr_name, attr);
}
DoEachAttr("__operands_symbols_signature__",
GetOpOperandsSymbolsSignature(op));
DoEachAttr("__results_symbols_signature__",
GetOpResultsSymbolsSignature(op));
}

pir::Attribute GetOpOperandsSymbolsSignature(const pir::Operation* op) {
std::vector<pir::Attribute> attrs = GetOpOperandsSymbolDimsAttributes(op);
return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs);
}

pir::Attribute GetOpResultsSymbolsSignature(const pir::Operation* op) {
std::vector<pir::Attribute> attrs = GetOpResultsSymbolDimsAttributes(op);
return pir::ArrayAttribute::get(pir::IrContext::Instance(), attrs);
template <typename DoEachAttrT>
void VisitSymbolicAttrs(const pir::Operation* op,
const DoEachAttrT& DoEachAttr) {
DoEachAttr("__operands_symbols_signature__",
GetOpOperandsSymbolDimsAttributes(op));
DoEachAttr("__results_symbols_signature__",
GetOpResultsSymbolDimsAttributes(op));
}

std::vector<pir::Attribute> GetOpOperandsSymbolDimsAttributes(
std::vector<std::optional<pir::Attribute>> GetOpOperandsSymbolDimsAttributes(
const pir::Operation* op) {
std::vector<pir::Attribute> attrs;
std::vector<std::optional<pir::Attribute>> attrs;
attrs.reserve(op->num_operands());
for (int i = 0; i < op->num_operands(); ++i) {
attrs.push_back(GetValueSymbolDimsAttribute(op->operand_source(i)));
}
return attrs;
}

std::vector<pir::Attribute> GetOpResultsSymbolDimsAttributes(
std::vector<std::optional<pir::Attribute>> GetOpResultsSymbolDimsAttributes(
const pir::Operation* op) {
std::vector<pir::Attribute> attrs;
std::vector<std::optional<pir::Attribute>> attrs;
attrs.reserve(op->num_results());
for (int i = 0; i < op->num_results(); ++i) {
attrs.push_back(GetValueSymbolDimsAttribute(op->result(i)));
}
return attrs;
}

pir::Attribute GetValueSymbolDimsAttribute(pir::Value value) {
std::optional<pir::Attribute> GetValueSymbolDimsAttribute(pir::Value value) {
auto* ctx = pir::IrContext::Instance();
using SymbolAttr = pir::shape::SymbolAttribute;
if (!value) {
return SymbolAttr::get(ctx, symbol::TensorShapeOrDataDimExprs{});
return std::nullopt;
}
const auto* shape_or_data = GetShapeOrDataDimExprs(value);
return SymbolAttr::get(ctx, *shape_or_data);
const auto& shape_or_data = GetShapeOrDataDimExprs(value);
if (!shape_or_data.has_value()) return std::nullopt;
return SymbolAttr::get(ctx, *shape_or_data.value());
}

const symbol::ShapeOrDataDimExprs* GetShapeOrDataDimExprs(pir::Value value) {
auto& shape_analysis = pir::ShapeAnalysisManager::Instance().Get(program_);
return &shape_analysis.GetShapeOrDataForValue(value);
std::optional<const symbol::ShapeOrDataDimExprs*> GetShapeOrDataDimExprs(
pir::Value value) {
const auto& shape_analysis = ShapeAnalysisGetter_(program_);
if (!shape_analysis.has_value()) return std::nullopt;
return &shape_analysis.value()->GetShapeOrDataForValue(value);
}

std::string ConvertInputTypes(const pir::Operation* op) {
Expand Down Expand Up @@ -869,6 +905,10 @@ struct PirToPyCodeConverterHelper {
template <typename T>
using AdtTypeId = ::common::AdtTypeId<T>;

std::string operator()(AdtTypeId<cinn::dialect::ir::NullType>) {
return "self.t_null";
}

std::string operator()(AdtTypeId<::pir::VectorType>) {
std::stringstream ss;
const auto& name = ::pir::VectorType::name();
Expand Down Expand Up @@ -1130,22 +1170,39 @@ struct PirToPyCodeConverterHelper {
}
};

std::optional<pir::ShapeConstraintIRAnalysis*> GetShapeAnalysisFromManager(
const pir::Program* program) {
return &pir::ShapeAnalysisManager::Instance().Get(program);
}

std::optional<pir::ShapeConstraintIRAnalysis*> GetNullShapeAnalysis(
const pir::Program* program) {
return std::nullopt;
}

} // namespace

void PirToPyCodeConverter::SaveIfFlagEnabled(
const std::string& tag, const pir::Program& program) const {
void PirToPyCodeConverter::SaveIfFlagEnabled() const {
if (program_ == nullptr) return;
if (file_name_.empty()) return;
if (FLAGS_logging_pir_py_code_dir == "") return;
const std::string file_path =
FLAGS_logging_pir_py_code_dir + "/" + tag + ".py";
const std::string content = PirToPyCodeConverterHelper(&program).Convert();
FLAGS_logging_pir_py_code_dir + "/" + file_name_;
ShapeAnalysisGetterT ShapeAnalysisGetter =
(dump_symbolic_shape_ ? GetShapeAnalysisFromManager
: GetNullShapeAnalysis);
PirToPyCodeConverterHelper converter_helper(program_, ShapeAnalysisGetter);
const std::string content = converter_helper.Convert();
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
if (FLAGS_logging_trunc_pir_py_code) {
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
}
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::app);
if (!ofs.is_open()) return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,27 @@ namespace cinn::dialect::ir {

class PirToPyCodeConverter {
public:
PirToPyCodeConverter() = default;
explicit PirToPyCodeConverter(pir::Program* program)
: program_(program), file_name_(), dump_symbolic_shape_(true) {}
PirToPyCodeConverter(const PirToPyCodeConverter&) = delete;
PirToPyCodeConverter(PirToPyCodeConverter&&) = delete;

void SaveIfFlagEnabled(const std::string& tag,
const pir::Program& program) const;
PirToPyCodeConverter& file_name(const std::string& file_name) {
file_name_ = file_name;
return *this;
}

PirToPyCodeConverter& dump_symbolic_shape(bool val) {
dump_symbolic_shape_ = val;
return *this;
}

void SaveIfFlagEnabled() const;

private:
pir::Program* program_;
std::string file_name_;
bool dump_symbolic_shape_;
};

} // namespace cinn::dialect::ir
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@
namespace cinn::dialect::ir {

TypeAdtTypeId GetTypeAdtTypeId(const pir::Type& type) {
if (!type) {
return ::common::AdtTypeId<NullType>{};
}
#define RETURN_TYPE_TYPE_ID_IF_MATCH(cls) \
if (type.isa<cls>()) return ::common::AdtTypeId<cls>{};
FOR_EACH_PIR_ALTERNATIVE_TYPLE(RETURN_TYPE_TYPE_ID_IF_MATCH)
Expand Down
10 changes: 6 additions & 4 deletions paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

namespace pir {
class Type;

class VectorType;
class DenseTensorType;
class BFloat16Type;
Expand Down Expand Up @@ -58,13 +57,16 @@ class Complex128Type;

namespace cinn::dialect::ir {

class NullType;
class UnclassifiedType;

using TypeAdtTypeIdBase = ::common::AdtBaseTypeId<
using TypeAdtTypeIdBase =
::common::AdtBaseTypeId<NullType,
#define MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE(cls) cls,
FOR_EACH_PIR_ALTERNATIVE_TYPLE(MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE)
FOR_EACH_PIR_ALTERNATIVE_TYPLE(
MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE)
#undef MAKE_TYPE_ADT_TYPE_ID_ALTERNATIVE
UnclassifiedType>;
UnclassifiedType>;

struct TypeAdtTypeId : public TypeAdtTypeIdBase {
using TypeAdtTypeIdBase::TypeAdtTypeIdBase;
Expand Down
5 changes: 5 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,11 @@ PHI_DEFINE_EXPORTED_string(logging_pir_py_code_dir,
"",
"the logging directory to save pir py code");

PHI_DEFINE_EXPORTED_bool(logging_trunc_pir_py_code,
true,
"whether truncate the logging files under directory "
"FLAGS_logging_pir_py_code_dir");

/**
* Using PIR API in Python
* Name: enable_pir_api
Expand Down