[PIR] fix onednn dialect name (#60665)
* fix onednn dialect name
wanghuancoder authored Jan 11, 2024
1 parent 92343a0 commit f8eff51
Showing 11 changed files with 32 additions and 32 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -108,7 +108,7 @@ paddle/fluid/pir/dialect/operator/ir/pd_api.*
paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op_vjp.cc
paddle/fluid/pir/dialect/operator/ir/pd_op.*
-paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.*
+paddle/fluid/pir/dialect/operator/ir/onednn_op.*
paddle/fluid/pir/dialect/operator/ir/pd_onednn_op_info.*
paddle/fluid/pir/dialect/operator/ir/pd_op_bwd.*
paddle/fluid/pir/dialect/operator/ir/pd_op_fused.*
2 changes: 1 addition & 1 deletion paddle/fluid/framework/new_executor/pir_interpreter.cc
@@ -731,7 +731,7 @@ void PirInterpreter::BuildInstruction() {
CREATE_INSTR(PhiKernelInstruction);
}
#ifdef PADDLE_WITH_DNNL
-} else if (op.dialect()->name() == "pd_onednn_kernel") {
+} else if (op.dialect()->name() == "onednn_kernel") {
auto op_name = op.attributes()
.at("op_name")
.dyn_cast<::pir::StrAttribute>()
4 changes: 2 additions & 2 deletions paddle/fluid/ir_adaptor/translator/op_translator.cc
@@ -50,7 +50,7 @@
#include "paddle/utils/blank.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
@@ -86,7 +86,7 @@ using AttributeHandlerFn = std::function<pir::Attribute(
using DenseTensorTypeStorage = paddle::dialect::DenseTensorTypeStorage;
constexpr char kTargetDialectPrefix[] = "pd_op."; // NOLINT
#ifdef PADDLE_WITH_DNNL
-constexpr char kOneDNNTargetDialectPrefix[] = "pd_onednn_op."; // NOLINT
+constexpr char kOneDNNTargetDialectPrefix[] = "onednn_op."; // NOLINT
#endif
constexpr char kCustomOpDialectPrefix[] = "custom_op.";
constexpr char kEmptyVarName[] = "@EMPTY@"; // NOLINT
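Note: kTargetDialectPrefix and kOneDNNTargetDialectPrefix are the prefixes used when translating legacy operator names into dialect-qualified PIR op names, so this hunk only shortens the OneDNN prefix. A minimal illustrative sketch with a hypothetical helper (not the translator's real code):

#include <iostream>
#include <string>

// Hypothetical helper mirroring the renamed prefix; for illustration only.
constexpr char kOneDNNTargetDialectPrefix[] = "onednn_op.";  // NOLINT

std::string MakeOneDNNTargetOpName(const std::string& legacy_op_name) {
  // e.g. "conv2d" -> "onednn_op.conv2d" (previously "pd_onednn_op.conv2d")
  return std::string(kOneDNNTargetDialectPrefix) + legacy_op_name;
}

int main() {
  std::cout << MakeOneDNNTargetOpName("conv2d") << std::endl;  // onednn_op.conv2d
  return 0;
}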
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/CMakeLists.txt
@@ -123,9 +123,9 @@ if(WITH_MKLDNN)
set(op_onednn_info_file_tmp ${op_onednn_info_file}.tmp)

set(onednn_op_namespace paddle,onednn,dialect)
-set(onednn_dialect_name pd_onednn_op)
-set(onednn_op_header_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.h)
-set(onednn_op_source_file ${PD_DIALECT_SOURCE_DIR}/pd_onednn_op.cc)
+set(onednn_dialect_name onednn_op)
+set(onednn_op_header_file ${PD_DIALECT_SOURCE_DIR}/onednn_op.h)
+set(onednn_op_source_file ${PD_DIALECT_SOURCE_DIR}/onednn_op.cc)
set(onednn_op_header_file_tmp ${onednn_op_header_file}.tmp)
set(onednn_op_source_file_tmp ${onednn_op_source_file}.tmp)

2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h
@@ -58,7 +58,7 @@ class OneDNNKernelDialect : public pir::Dialect {
public:
explicit OneDNNKernelDialect(pir::IrContext* context);

-static const char* name() { return "pd_onednn_kernel"; }
+static const char* name() { return "onednn_kernel"; }

void PrintType(pir::Type type, std::ostream& os) const override;

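Note: the literal returned by OneDNNKernelDialect::name() is the same string PirInterpreter::BuildInstruction() compares against (see the pir_interpreter.cc hunk above), which is why both files change in the same commit. A toy sketch of that invariant, using a stub in place of the real dialect class:

#include <cassert>
#include <string>

// Stub standing in for paddle::dialect::OneDNNKernelDialect; illustration only.
struct OneDNNKernelDialectStub {
  static const char* name() { return "onednn_kernel"; }
};

int main() {
  // The interpreter-side dispatch compares an op's dialect name against the
  // same literal, so renaming the dialect means updating both call sites.
  std::string dialect_name = OneDNNKernelDialectStub::name();
  assert(dialect_name == "onednn_kernel");
  return 0;
}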
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/kernel/ir/kernel_op.h
@@ -60,7 +60,7 @@ class CustomKernelOp : public pir::Op<CustomKernelOp> {
class OneDNNPhiKernelOp : public pir::Op<OneDNNPhiKernelOp> {
public:
using Op::Op;
-static const char *name() { return "pd_onednn_kernel.phi_kernel"; }
+static const char *name() { return "onednn_kernel.phi_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
@@ -72,7 +72,7 @@ class OneDNNPhiKernelOp : public pir::Op<OneDNNPhiKernelOp> {
class OneDNNMixedPhiKernelOp : public pir::Op<OneDNNMixedPhiKernelOp> {
public:
using Op::Op;
-static const char *name() { return "pd_onednn_kernel.phi_mixed_kernel"; }
+static const char *name() { return "onednn_kernel.phi_mixed_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
@@ -84,7 +84,7 @@ class OneDNNMixedPhiKernelOp : public pir::Op<OneDNNMixedPhiKernelOp> {
class OneDNNLegacyKernelOp : public pir::Op<OneDNNLegacyKernelOp> {
public:
using Op::Op;
-static const char *name() { return "pd_onednn_kernel.legacy_kernel"; }
+static const char *name() { return "onednn_kernel.legacy_kernel"; }
static constexpr uint32_t attributes_num = 3;
static const char *attributes_name[attributes_num];
std::string op_name();
32 changes: 16 additions & 16 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -1151,20 +1151,20 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if (
op_info.backward_name
and op_info.op_phi_name[0] not in vjp_interface_black_list
-and dialect_name != "pd_onednn_op"
+and dialect_name != "onednn_op"
):
op_interfaces += ["paddle::dialect::VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(
op_info, op_info_items
)

-if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
+if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_interfaces += ["paddle::dialect::GetKernelTypeForVarInterface"]

# if op has custom vjp rule, then append a CustomVjpTrait to it
if (
op_info.op_phi_name[0] in custom_vjp_op_name_list
-and dialect_name != "pd_onednn_op"
+and dialect_name != "onednn_op"
):
op_traits += ["paddle::dialect::CustomVjpTrait"]

@@ -1186,7 +1186,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if op_name[-1] == "_":
op_traits += ["paddle::dialect::InplaceTrait"]

-if dialect_name == "pd_onednn_op":
+if dialect_name == "onednn_op":
op_traits += ["paddle::dialect::OneDNNTrait"]

if op_info.is_onednn_only:
@@ -1210,7 +1210,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
if (
op_name in decomp_interface_declare_gen_op_list
and kernel_func_name in decomp_interface_declare_gen_op_list
-and dialect_name != "pd_onednn_op"
+and dialect_name != "onednn_op"
):
op_interfaces = op_interfaces + [
"paddle::dialect::DecompInterface"
@@ -1274,7 +1274,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
build_func_with_muta_attr_is_input = ""

get_kernel_type_for_var_declare_str = ""
-if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
+if dialect_name == "pd_op" or dialect_name == "onednn_op":
get_kernel_type_for_var_declare_str = (
get_kernel_type_for_var_declare_template
)
@@ -1609,7 +1609,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
origin_op_name=op_info.op_yaml_item['name'],
)

-if dialect_name == "pd_onednn_op":
+if dialect_name == "onednn_op":
if len(op_info.onednn_extra_args) > 0:
args_name = []
for arg in op_info.onednn_extra_args:
@@ -1693,7 +1693,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):

# generate op GetKernelKeyForVar function str
op_get_kernel_type_for_var_str = ''
-if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
+if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_get_kernel_type_for_var_str = (
gen_kernel_type_for_var_str(
op_class_name,
@@ -1722,7 +1722,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
op_info.backward_name
and op_info.op_phi_name[0]
not in vjp_interface_black_list
-and dialect_name != "pd_onednn_op"
+and dialect_name != "onednn_op"
):
op_vjp_str = gen_op_vjp_str(
op_class_name,
@@ -1753,7 +1753,7 @@ def AutoCodeGen(op_info_items, all_op_info_items, namespaces, dialect_name):
ops_defined_list.append(infer_symbolic_shape_define_str)

# NOTE(chenxi67)skip if dialect_name==cinn
-if dialect_name == "cinn" or dialect_name == "pd_onednn_op":
+if dialect_name == "cinn" or dialect_name == "onednn_op":
pass
else:
ops_vjp_defined_list.append(op_vjp_str)
@@ -1850,7 +1850,7 @@ def OpGenerator(
# (2) parse yaml files
op_compat_parser = OpCompatParser(op_compat_yaml_file)

-if dialect_name == "pd_onednn_op":
+if dialect_name == "onednn_op":
with open(ops_onednn_extra_yaml_file, "r") as f:
ops_onednn_extra = yaml.safe_load(f)
ops_onednn_extra_map = {}
@@ -1885,7 +1885,7 @@ def OpGenerator(
op_info_items = {}
for op in op_yaml_items:
op_compat_item = None
-if dialect_name == "pd_op" or dialect_name == "pd_onednn_op":
+if dialect_name == "pd_op" or dialect_name == "onednn_op":
op_compat_item = op_compat_parser.get_compat(op['name'])

if (
@@ -1911,7 +1911,7 @@ def OpGenerator(
) = op_compat_parser.parse_support_tensor(op)
op_compat_item['scalar'] = scalar_item
op_compat_item['int_array'] = int_array_item
-if dialect_name == "pd_onednn_op":
+if dialect_name == "onednn_op":
if first_file:
first_file = False
op["is_onednn_only"] = True
@@ -1931,7 +1931,7 @@ def OpGenerator(
all_op_info_items[op['name']] = item

op_infos.append(op_info_items)
-if dialect_name == "pd_onednn_op":
+if dialect_name == "onednn_op":
op_infos = [all_op_info_items]

# (3) auto code gen
@@ -2044,7 +2044,7 @@ def OpGenerator(
namespace=name, input=source_file_str
) # Add namespaces

-if dialect_name == "pd_onednn_op":
+if dialect_name == "onednn_op":
op_def_h_file_tmp = (
"paddle/fluid/pir/dialect/operator/ir/pd_op.h\"\n#include \""
+ op_def_h_file
@@ -2067,7 +2067,7 @@ def OpGenerator(
vjp_source_file_str = VJP_CC_FILE_TEMPLATE.format(input=vjp_source_file_str)
if (
dialect_name != 'cinn'
-and dialect_name != 'pd_onednn_op'
+and dialect_name != 'onednn_op'
and op_vjp_cc_file
):
with open(op_vjp_cc_file, 'w') as f:
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.cc
@@ -27,7 +27,7 @@
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif

namespace paddle {
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h
@@ -23,7 +23,7 @@ class OneDNNOperatorDialect : public pir::Dialect {
public:
explicit OneDNNOperatorDialect(pir::IrContext* context);

-static const char* name() { return "pd_onednn_op"; }
+static const char* name() { return "onednn_op"; }

pir::Type ParseType(pir::IrParser& parser) override; // NOLINT
pir::Attribute ParseAttribute(pir::IrParser& parser) override; // NOLINT
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/utils/utils.cc
@@ -28,7 +28,7 @@
#include "paddle/utils/string/string_helper.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#endif

namespace paddle {
4 changes: 2 additions & 2 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
@@ -46,8 +46,8 @@
#include "paddle/utils/flags.h"

#ifdef PADDLE_WITH_DNNL
#include "paddle/fluid/pir/dialect/operator/ir/onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_onednn_dialect.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_onednn_op.h"
#include "paddle/fluid/pir/dialect/operator/trait/onednn.h"
#endif

@@ -2218,7 +2218,7 @@ void ProcessBlock(
}
}
std::string target_op_name = op_item->name();
-target_op_name.replace(0, 12, "pd_op");
+target_op_name.replace(0, 9, "pd_op");
auto op_info = ctx->GetRegisteredOpInfo(target_op_name);
if (!op_info) {
IR_THROW("Ctx should have corresponding OpInfo %s", target_op_name);
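Note: the second argument of std::string::replace here is the length of the dialect prefix being rewritten to "pd_op": the old prefix "pd_onednn_op" is 12 characters, the new "onednn_op" is 9, hence the change from 12 to 9. A standalone sketch of the rewrite, with a made-up op name:

#include <cassert>
#include <string>

int main() {
  // Hypothetical OneDNN op name under the renamed dialect.
  std::string target_op_name = "onednn_op.conv2d";
  // strlen("onednn_op") == 9 (the old "pd_onednn_op" had length 12), so the
  // first 9 characters are replaced with "pd_op" to look up the pd_op OpInfo.
  target_op_name.replace(0, 9, "pd_op");
  assert(target_op_name == "pd_op.conv2d");
  return 0;
}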
