Skip to content

Commit

Permalink
FLAGS_logging_pir_py_code (PaddlePaddle#63981)
Browse files Browse the repository at this point in the history
* FLAGS_logging_pir_py_code

* FLAGS_logging_pir_py_code_dir

---------

Co-authored-by: jiahy0825 <[email protected]>
  • Loading branch information
2 people authored and co63oc committed May 10, 2024
1 parent c316322 commit a3caffd
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 17 deletions.
8 changes: 2 additions & 6 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,10 @@ void ApplyCinnPass(::pir::Program* program,
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
LOG(INFO) << "====[pir-to-py-code group-ops begin]===" << std::endl
<< PirToPyCodeConverter().Convert(*program);
LOG(INFO) << "====[pir-to-py-code group-ops end]===";
PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program);
ApplyGroupOpPass(program, CreatePassManager);
ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
LOG(INFO) << "====[pir-to-py-code fusion-ops begin]===" << std::endl
<< PirToPyCodeConverter().Convert(*program);
LOG(INFO) << "====[pir-to-py-code fusion-ops end]===";
PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program);
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ class FusionOpPattern : public pir::OpRewritePattern<cinn::dialect::FusionOp> {
pir::PatternRewriter& rewriter) const { // NOLINT
auto it = op_handler_map().find(op->name());
if (it == op_handler_map().end()) {
LOG(WARNING) << "No fallback handler for op: " << op->name();
VLOG(4) << "No fallback handler for op: " << op->name();
return std::nullopt;
}
return (this->*(it->second))(op, rewriter);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/pir_to_py_code_converter.h"
#include <atomic>
#include <iomanip>
#include <mutex>
#include <sstream>
#include <unordered_set>
#include <variant>
Expand All @@ -23,6 +24,7 @@
#include "paddle/cinn/hlir/dialect/operator/transforms/type_adt_type_id.h"
#include "paddle/common/adt_type_id.h"
#include "paddle/common/ddim.h"
#include "paddle/common/flags.h"
#include "paddle/common/overloaded.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
Expand All @@ -31,6 +33,7 @@
#include "paddle/pir/include/core/program.h"
#include "paddle/pir/include/dialect/control_flow/ir/cf_op.h"
#include "paddle/pir/include/dialect/shape/ir/shape_attribute.h"
COMMON_DECLARE_string(logging_pir_py_code_dir);

namespace cinn::dialect::ir {

Expand Down Expand Up @@ -79,8 +82,27 @@ struct OpPyCode {

constexpr int kDefaultIndentSize = 2;

namespace {

int64_t GetAutoIncrementalId() {
static std::atomic<int64_t> seq_no(0);
return seq_no++;
}

} // namespace

struct PirToPyCodeConverterHelper {
PirToPyCodeConverterHelper() : indent_size_(kDefaultIndentSize) {}
explicit PirToPyCodeConverterHelper(const pir::Program* program)
: program_(program),
indent_size_(kDefaultIndentSize),
seq_no_(GetAutoIncrementalId()) {}

std::string Convert() { return Convert(*program_); }

private:
const pir::Program* program_;
const int indent_size_;
int64_t seq_no_;

std::string Convert(const pir::Program& program) {
auto istrings = ConvertMethodsToPyClass(program.module_op(), [&]() {
Expand All @@ -92,7 +114,6 @@ struct PirToPyCodeConverterHelper {
return ConvertIStringsToString(istrings);
}

private:
IStrings DefineInit(const pir::ModuleOp& module) {
IStrings def_init;
def_init.push_back(IString("def __init__(self):"));
Expand Down Expand Up @@ -783,15 +804,15 @@ struct PirToPyCodeConverterHelper {
IStrings ret;
{
std::stringstream ss;
ss << "class " << GetPyClassName(module) << ":";
ss << "class " << GetPyClassName() << ":";
ret.push_back(IString(ss.str()));
}
PushBackIndented(&ret, GetBody());
return ret;
}

std::string GetPyClassName(const pir::ModuleOp& module) {
return std::string("Program");
std::string GetPyClassName() {
return std::string("PirProgram_") + std::to_string(seq_no_);
}

std::string ConvertIStringsToString(const IStrings& istrings) {
Expand Down Expand Up @@ -819,14 +840,29 @@ struct PirToPyCodeConverterHelper {
ret->push_back(Indent(istring));
}
}

const int indent_size_;
};

} // namespace

std::string PirToPyCodeConverter::Convert(const pir::Program& program) const {
return PirToPyCodeConverterHelper().Convert(program);
void PirToPyCodeConverter::SaveIfFlagEnabled(
const std::string& tag, const pir::Program& program) const {
if (FLAGS_logging_pir_py_code_dir == "") return;
const std::string file_path =
FLAGS_logging_pir_py_code_dir + "/" + tag + ".py";
const std::string content = PirToPyCodeConverterHelper(&program).Convert();
static std::mutex mutex;
std::unique_lock<std::mutex> lock(mutex);
static std::unordered_map<std::string, std::once_flag> once_flags;
std::call_once(once_flags[file_path], [&] {
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::trunc);
ofs.close();
});
std::ofstream ofs;
ofs.open(file_path.c_str(), std::ios::out | std::ios::app);
if (!ofs.is_open()) return;
ofs << content << std::endl;
ofs.close();
}

} // namespace cinn::dialect::ir
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ class PirToPyCodeConverter {
PirToPyCodeConverter(const PirToPyCodeConverter&) = delete;
PirToPyCodeConverter(PirToPyCodeConverter&&) = delete;

std::string Convert(const pir::Program& program) const;
void SaveIfFlagEnabled(const std::string& tag,
const pir::Program& program) const;
};

} // namespace cinn::dialect::ir
4 changes: 4 additions & 0 deletions paddle/common/flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,10 @@ PHI_DEFINE_EXPORTED_bool(enable_pir_with_pt_in_dy2st,
true,
"Enable new IR in executor");

PHI_DEFINE_EXPORTED_string(logging_pir_py_code_dir,
"",
"the logging directory to save pir py code");

/**
* Using PIR API in Python
* Name: enable_pir_api
Expand Down

0 comments on commit a3caffd

Please sign in to comment.