Commit
[Typing] Add type annotations for `op_generator/op_gen.py`/`op_generator/op_interface_gen.py` (PaddlePaddle#66384)
AndPuQing authored and Dale1314 committed Jul 28, 2024
1 parent bb6e04a commit 70af598
Showing 2 changed files with 82 additions and 54 deletions.
117 changes: 69 additions & 48 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -11,13 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import argparse
import logging
import math
import os
import pathlib
import sys
from typing import Any, NamedTuple

import yaml
from decomp_interface_gen_op_list import (
@@ -350,52 +352,52 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
}


def to_phi_and_fluid_op_name(op_item):
class OpNamePair(NamedTuple):
phi_name: str
fluid_name: str


def to_phi_and_fluid_op_name(op_item: str) -> OpNamePair:
# Template: - op : phi_name (fluid_name)
names = op_item.split('(')
if len(names) == 1:
phi_fluid_name = names[0].strip()
return phi_fluid_name, phi_fluid_name
return OpNamePair(phi_fluid_name, phi_fluid_name)
else:
phi_name = names[0].strip()
fluid_name = names[1].split(')')[0].strip()
return phi_name, fluid_name
return OpNamePair(phi_name, fluid_name)


def to_phi_and_fluid_grad_op_name(op_item):
def to_phi_and_fluid_grad_op_name(op_item: str) -> list[OpNamePair]:
# Template: sum_grad (reduce_sum_grad), sum_double_grad
rtn = []
all_names = op_item.split(', ')
for name in all_names:
backward_phi_name, backward_fluid_name = to_phi_and_fluid_op_name(name)
rtn.append([backward_phi_name, backward_fluid_name])
return rtn
return list(map(to_phi_and_fluid_op_name, op_item.split(', ')))


# =====================================
# Parse Op Compat From Yaml
# =====================================
class OpCompatParser:
def __init__(self, ops_compat_yaml_file):
def __init__(self, ops_compat_yaml_file: str):
self.ops_compat_yaml_file = ops_compat_yaml_file
with open(self.ops_compat_yaml_file, "r") as f:
self.ops_compat = yaml.safe_load(f)

def get_compat(self, op_name):
def get_compat(self, op_name: str):
for compat in self.ops_compat:
forward_phi_name, forward_fluid_name = to_phi_and_fluid_op_name(
compat['op']
)
if op_name == forward_phi_name:
name_pair = to_phi_and_fluid_op_name(compat['op'])
if op_name == name_pair.phi_name:
return compat
elif 'backward' in compat.keys():
bkw_names = to_phi_and_fluid_grad_op_name(compat['backward'])
for name in bkw_names:
if op_name == name[0]:
if op_name == name.phi_name:
return compat
return None

def parse_support_tensor(self, op):
def parse_support_tensor(
self, op
) -> tuple[dict[str, dict[str, bool]], dict[str, dict[str, bool]]]:
scalar_item = {}
int_array_item = {}
for support_tensor_attr in op['support_tensor']:
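As a quick illustration of the new OpNamePair return type (a standalone sketch assuming the two helpers above are in scope; the yaml-style entries are made up to match the templates in the comments):

pair = to_phi_and_fluid_op_name("sum (reduce_sum)")
assert pair.phi_name == "sum" and pair.fluid_name == "reduce_sum"

pair = to_phi_and_fluid_op_name("relu")  # no "(fluid_name)" part
assert pair == ("relu", "relu")  # a NamedTuple still unpacks and compares like a plain tuple

grads = to_phi_and_fluid_grad_op_name("sum_grad (reduce_sum_grad), sum_double_grad")
assert grads[0].fluid_name == "reduce_sum_grad" and grads[1].phi_name == "sum_double_grad"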
@@ -423,6 +425,9 @@ def __init__(self, op_yaml_item, op_compat_item, yaml_file):
self.yaml_file = yaml_file
self.is_sparse_op = self.parse_op_type()
self.op_phi_name = self.parse_op_phi_name()
self.class_name: str | None = None
self.kernel_input_type_list: list[str] | None = None
self.kernel_output_type_list: list[str] | None = None

self.kernel_map = self.parse_kernel_map()
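Pre-declaring these attributes as Optional in __init__ gives the type checker a single place to learn their types before other parse methods fill them in. A tiny sketch of the pattern (the class below is a hypothetical stand-in, not code from the patch):

from __future__ import annotations

class _Holder:
    def __init__(self) -> None:
        # Declared up front so the checker knows the eventual types.
        self.class_name: str | None = None
        self.kernel_input_type_list: list[str] | None = None

    def set_class_name(self, name: str) -> None:
        self.class_name = name  # assigning a str matches the declared type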

@@ -754,7 +759,7 @@ def parse_non_mutable_attribute(self):
op_non_mutable_attribute_default_value_list,
)

def parse_op_type(self):
def parse_op_type(self) -> bool:
if self.yaml_file.endswith(
"sparse_ops.parsed.yaml"
) or self.yaml_file.endswith("sparse_backward.parsed.yaml"):
@@ -1050,7 +1055,7 @@ def parse_backward_name(self):
else:
return None

def get_phi_dtype_name(self, name):
def get_phi_dtype_name(self, name: str):
name = name.replace('Scalar', 'phi::Scalar')
name = name.replace('IntArray', 'phi::IntArray')
name = name.replace('DataLayout', 'phi::DataLayout')
@@ -1070,7 +1075,9 @@ def get_phi_dtype_name(self, name):
return name


def get_input_grad_semantic(op_info, op_info_items):
def get_input_grad_semantic(
op_info: OpInfoParser, op_info_items: dict[str, OpInfoParser]
):
input_grad_semantics = []
num_inputs = len(op_info.input_name_list)

@@ -1105,7 +1112,9 @@ def get_input_grad_semantic(op_info, op_info_items):
return input_grad_semantics


def get_mutable_attribute_grad_semantic(op_info, op_info_items):
def get_mutable_attribute_grad_semantic(
op_info: OpInfoParser, op_info_items: dict[str, OpInfoParser]
):
mutable_attribute_grad_semantics = []
fwd_mutable_attribute_list = op_info.mutable_attribute_name_list

@@ -1135,7 +1144,7 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
return mutable_attribute_grad_semantics


def split_ops(op_info_items: dict, cc_file, split_nums):
def split_ops(op_info_items: dict[str, Any], cc_file: str, split_nums: int):
op_list = list(op_info_items.keys())
ops_max_size = math.ceil(len(op_list) / split_nums)
split_op_info_items = []
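The visible part of split_ops sizes each chunk with math.ceil; a toy example of that arithmetic (hypothetical op names and counts, not the real op list):

import math

op_list = [f"op_{i}" for i in range(10)]  # pretend there are 10 ops
split_nums = 3
ops_max_size = math.ceil(len(op_list) / split_nums)  # ceil(10 / 3) == 4
chunks = [op_list[i : i + ops_max_size] for i in range(0, len(op_list), ops_max_size)]
# -> three chunks of 4, 4 and 2 ops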
@@ -1241,7 +1250,11 @@ def GenOneDnnExtraAttrsDefaultValue(onednn_extra_args):


def AutoCodeGen(
args, op_info_items, all_op_info_items, namespaces, dialect_name
args: argparse.Namespace,
op_info_items: dict[str, OpInfoParser],
all_op_info_items: dict[str, OpInfoParser],
namespaces: list[str],
dialect_name: str,
):
# (3) CodeGen: Traverse op_info_items and generate
ops_name_list = [] # all op class name store in this list
@@ -1444,7 +1457,7 @@ def AutoCodeGen(
op_dialect_name = (
dialect_name
+ "."
+ kernel_func_name
+ kernel_func_name # type: ignore
+ "_"
+ op_dialect_name_inplace_suffix
)
@@ -1457,7 +1470,7 @@ def AutoCodeGen(
op_dialect_name = (
dialect_name
+ "."
+ kernel_func_name
+ kernel_func_name # type: ignore
+ op_dialect_name_suffix
)
if kernel_func_name is None:
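The `# type: ignore` comments above appear to exist because kernel_func_name can still be None on this path (note the None check just below the concatenation). An alternative sketch, with an illustrative helper that is not from the patch, narrows the Optional first so no ignore comment is needed:

from __future__ import annotations

def _join_dialect_name(dialect_name: str, kernel_func_name: str | None, suffix: str) -> str:
    if kernel_func_name is None:
        # Narrowing the Optional first lets the checker accept the concatenation.
        raise ValueError("kernel_func_name is not set for this op")
    return dialect_name + "." + kernel_func_name + suffix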
@@ -1891,12 +1904,12 @@ def AutoCodeGen(
extra_args=extra_args,
skip_transform_inputs=skip_transform_inputs,
data_format_tensors=data_format_tensors,
is_onednn_only="true"
if op_info.is_onednn_only
else "false",
dynamic_fallback="true"
if op_info.dynamic_fallback
else "false",
is_onednn_only=(
"true" if op_info.is_onednn_only else "false"
),
dynamic_fallback=(
"true" if op_info.dynamic_fallback else "false"
),
)
# generate op verify function str
op_verify_str = ''
@@ -2100,19 +2113,19 @@ def AutoCodeGen(


def OpGenerator(
args,
op_yaml_files,
op_compat_yaml_file,
namespaces,
dialect_name,
op_def_h_file,
op_info_file,
op_def_cc_file,
op_vjp_cc_file,
op_cc_split_num,
bwd_op_cc_split_num,
onednn_yaml_file,
ops_onednn_extra_yaml_file,
args: argparse.Namespace,
op_yaml_files: list[str],
op_compat_yaml_file: str,
namespaces: list[str],
dialect_name: str,
op_def_h_file: str,
op_info_file: str,
op_def_cc_file: list[str],
op_vjp_cc_file: str,
op_cc_split_num: int,
bwd_op_cc_split_num: int,
onednn_yaml_file: str | None,
ops_onednn_extra_yaml_file: str | None,
):
# (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
if os.path.exists(op_def_h_file):
@@ -2127,6 +2140,10 @@ def OpGenerator(
op_compat_parser = OpCompatParser(op_compat_yaml_file)

if dialect_name == "onednn_op":
if onednn_yaml_file is None or ops_onednn_extra_yaml_file is None:
raise ValueError(
"onednn_op should provide onednn_yaml_file and ops_onednn_extra_yaml_file"
)
with open(ops_onednn_extra_yaml_file, "r") as f:
ops_onednn_extra = yaml.safe_load(f)
ops_onednn_extra_map = {}
@@ -2155,8 +2172,8 @@ def OpGenerator(
ops_onednn_extra_map[op_name] = item
op_yaml_files.insert(0, onednn_yaml_file)

op_infos = []
all_op_info_items = {}
op_infos: list[dict[str, OpInfoParser]] = []
all_op_info_items: dict[str, OpInfoParser] = {}
new_op_def_cc_file = []
first_file = True
onednn_only_op_list = []
@@ -2188,7 +2205,11 @@ def OpGenerator(
):
op_compat_item = op_compat_item.pop('scalar')

if 'support_tensor' in op.keys() and op['support_tensor']:
if (
op_compat_item is not None
and 'support_tensor' in op.keys()
and op['support_tensor']
):
(
scalar_item,
int_array_item,
19 changes: 13 additions & 6 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -13,8 +13,15 @@
# limitations under the License.

# generator interfaces
from __future__ import annotations

from typing import TYPE_CHECKING

from vjp_interface_black_list import vjp_interface_black_list

if TYPE_CHECKING:
from op_gen import OpInfoParser

CHECK_INPUT_TEMPLATE = """
PADDLE_ENFORCE_EQ(
inputs_.size(),
@@ -106,11 +113,11 @@


def gen_op_vjp_str(
op_class_name,
op_grad_name,
op_phi_name,
op_info,
op_grad_info,
op_class_name: str,
op_grad_name: str,
op_phi_name: str,
op_info: OpInfoParser,
op_grad_info: OpInfoParser,
):
bw_input_list = op_grad_info.input_name_list
fwd_input_and_mutable_attr_name_list = (
@@ -272,7 +279,7 @@ def gen_op_vjp_str(
return str


def gen_exclusive_interface_str(op_info, op_info_items):
def gen_exclusive_interface_str(op_info: OpInfoParser, op_info_items):
exclusive_interface_str = ""
if op_info.op_phi_name[0] not in vjp_interface_black_list:
exclusive_interface_str += "\n static std::vector<std::vector<pir::Value>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::Value>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
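The TYPE_CHECKING block above lets OpInfoParser be used purely as an annotation without importing op_gen at runtime, presumably to avoid an import cycle between the two generator modules. A minimal self-contained sketch of the same pattern, with a hypothetical helper function:

from __future__ import annotations  # annotations are not evaluated at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by the type checker, never imported at runtime.
    from op_gen import OpInfoParser

def first_phi_name(op_info: OpInfoParser) -> str:
    # Mirrors how the generator indexes op_phi_name[0] above.
    return op_info.op_phi_name[0]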
