From 70af598275b983d5d63fdd13b9f0dd3e953dfc70 Mon Sep 17 00:00:00 2001
From: PuQing
Date: Fri, 26 Jul 2024 10:04:25 +0800
Subject: [PATCH] [Typing] Add type annotations for `op_generator/op_gen.py`/
 `op_generator/op_interface_gen.py` (#66384)

---
 .../fluid/pir/dialect/op_generator/op_gen.py  | 117 +++++++++++-------
 .../dialect/op_generator/op_interface_gen.py  |  19 ++-
 2 files changed, 82 insertions(+), 54 deletions(-)

diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py
index d72e11ab1db850..81e228b83ae07e 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
 
 import argparse
 import logging
@@ -18,6 +19,7 @@
 import os
 import pathlib
 import sys
+from typing import Any, NamedTuple
 
 import yaml
 from decomp_interface_gen_op_list import (
@@ -350,52 +352,52 @@ class {TEST_API} {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
 }
 
 
-def to_phi_and_fluid_op_name(op_item):
+class OpNamePair(NamedTuple):
+    phi_name: str
+    fluid_name: str
+
+
+def to_phi_and_fluid_op_name(op_item: str) -> OpNamePair:
     # Template: - op : phi_name (fluid_name)
     names = op_item.split('(')
     if len(names) == 1:
         phi_fluid_name = names[0].strip()
-        return phi_fluid_name, phi_fluid_name
+        return OpNamePair(phi_fluid_name, phi_fluid_name)
     else:
         phi_name = names[0].strip()
         fluid_name = names[1].split(')')[0].strip()
-        return phi_name, fluid_name
+        return OpNamePair(phi_name, fluid_name)
 
 
-def to_phi_and_fluid_grad_op_name(op_item):
+def to_phi_and_fluid_grad_op_name(op_item: str) -> list[OpNamePair]:
     # Template: sum_grad (reduce_sum_grad), sum_double_grad
-    rtn = []
-    all_names = op_item.split(', ')
-    for name in all_names:
-        backward_phi_name, backward_fluid_name = to_phi_and_fluid_op_name(name)
-        rtn.append([backward_phi_name, backward_fluid_name])
-    return rtn
+    return list(map(to_phi_and_fluid_op_name, op_item.split(', ')))
 
 
 # =====================================
 # Parse Op Compat From Yaml
 # =====================================
 class OpCompatParser:
-    def __init__(self, ops_compat_yaml_file):
+    def __init__(self, ops_compat_yaml_file: str):
         self.ops_compat_yaml_file = ops_compat_yaml_file
         with open(self.ops_compat_yaml_file, "r") as f:
             self.ops_compat = yaml.safe_load(f)
 
-    def get_compat(self, op_name):
+    def get_compat(self, op_name: str):
         for compat in self.ops_compat:
-            forward_phi_name, forward_fluid_name = to_phi_and_fluid_op_name(
-                compat['op']
-            )
-            if op_name == forward_phi_name:
+            name_pair = to_phi_and_fluid_op_name(compat['op'])
+            if op_name == name_pair.phi_name:
                 return compat
             elif 'backward' in compat.keys():
                 bkw_names = to_phi_and_fluid_grad_op_name(compat['backward'])
                 for name in bkw_names:
-                    if op_name == name[0]:
+                    if op_name == name.phi_name:
                         return compat
         return None
 
-    def parse_support_tensor(self, op):
+    def parse_support_tensor(
+        self, op
+    ) -> tuple[dict[str, dict[str, bool]], dict[str, dict[str, bool]]]:
         scalar_item = {}
         int_array_item = {}
         for support_tensor_attr in op['support_tensor']:
@@ -423,6 +425,9 @@ def __init__(self, op_yaml_item, op_compat_item, yaml_file):
         self.yaml_file = yaml_file
         self.is_sparse_op = self.parse_op_type()
         self.op_phi_name = self.parse_op_phi_name()
+        self.class_name: str | None = None
+        self.kernel_input_type_list: list[str] | None = None
+        self.kernel_output_type_list: list[str] | None = None
         self.kernel_map = self.parse_kernel_map()
 
@@ -754,7 +759,7 @@ def parse_non_mutable_attribute(self):
             op_non_mutable_attribute_default_value_list,
         )
 
-    def parse_op_type(self):
+    def parse_op_type(self) -> bool:
         if self.yaml_file.endswith(
             "sparse_ops.parsed.yaml"
         ) or self.yaml_file.endswith("sparse_backward.parsed.yaml"):
@@ -1050,7 +1055,7 @@ def parse_backward_name(self):
         else:
             return None
 
-    def get_phi_dtype_name(self, name):
+    def get_phi_dtype_name(self, name: str):
         name = name.replace('Scalar', 'phi::Scalar')
         name = name.replace('IntArray', 'phi::IntArray')
         name = name.replace('DataLayout', 'phi::DataLayout')
@@ -1070,7 +1075,9 @@ def get_phi_dtype_name(self, name):
     return name
 
 
-def get_input_grad_semantic(op_info, op_info_items):
+def get_input_grad_semantic(
+    op_info: OpInfoParser, op_info_items: dict[str, OpInfoParser]
+):
     input_grad_semantics = []
     num_inputs = len(op_info.input_name_list)
 
@@ -1105,7 +1112,9 @@ def get_input_grad_semantic(op_info, op_info_items):
     return input_grad_semantics
 
 
-def get_mutable_attribute_grad_semantic(op_info, op_info_items):
+def get_mutable_attribute_grad_semantic(
+    op_info: OpInfoParser, op_info_items: dict[str, OpInfoParser]
+):
     mutable_attribute_grad_semantics = []
     fwd_mutable_attribute_list = op_info.mutable_attribute_name_list
 
@@ -1135,7 +1144,7 @@ def get_mutable_attribute_grad_semantic(op_info, op_info_items):
     return mutable_attribute_grad_semantics
 
 
-def split_ops(op_info_items: dict, cc_file, split_nums):
+def split_ops(op_info_items: dict[str, Any], cc_file: str, split_nums: int):
     op_list = list(op_info_items.keys())
     ops_max_size = math.ceil(len(op_list) / split_nums)
     split_op_info_items = []
@@ -1241,7 +1250,11 @@ def GenOneDnnExtraAttrsDefaultValue(onednn_extra_args):
 
 
 def AutoCodeGen(
-    args, op_info_items, all_op_info_items, namespaces, dialect_name
+    args: argparse.Namespace,
+    op_info_items: dict[str, OpInfoParser],
+    all_op_info_items: dict[str, OpInfoParser],
+    namespaces: list[str],
+    dialect_name: str,
 ):
     # (3) CodeGen: Traverse op_info_items and generate
     ops_name_list = []  # all op class name store in this list
@@ -1444,7 +1457,7 @@ def AutoCodeGen(
                     op_dialect_name = (
                         dialect_name
                         + "."
-                        + kernel_func_name
+                        + kernel_func_name  # type: ignore
                         + "_"
                        + op_dialect_name_inplace_suffix
                     )
@@ -1457,7 +1470,7 @@ def AutoCodeGen(
                     op_dialect_name = (
                         dialect_name
                         + "."
-                        + kernel_func_name
+                        + kernel_func_name  # type: ignore
                         + op_dialect_name_suffix
                     )
                 if kernel_func_name is None:
@@ -1891,12 +1904,12 @@ def AutoCodeGen(
                 extra_args=extra_args,
                 skip_transform_inputs=skip_transform_inputs,
                 data_format_tensors=data_format_tensors,
-                is_onednn_only="true"
-                if op_info.is_onednn_only
-                else "false",
-                dynamic_fallback="true"
-                if op_info.dynamic_fallback
-                else "false",
+                is_onednn_only=(
+                    "true" if op_info.is_onednn_only else "false"
+                ),
+                dynamic_fallback=(
+                    "true" if op_info.dynamic_fallback else "false"
+                ),
             )
             # generate op verify function str
             op_verify_str = ''
@@ -2100,19 +2113,19 @@
 def OpGenerator(
-    args,
-    op_yaml_files,
-    op_compat_yaml_file,
-    namespaces,
-    dialect_name,
-    op_def_h_file,
-    op_info_file,
-    op_def_cc_file,
-    op_vjp_cc_file,
-    op_cc_split_num,
-    bwd_op_cc_split_num,
-    onednn_yaml_file,
-    ops_onednn_extra_yaml_file,
+    args: argparse.Namespace,
+    op_yaml_files: list[str],
+    op_compat_yaml_file: str,
+    namespaces: list[str],
+    dialect_name: str,
+    op_def_h_file: str,
+    op_info_file: str,
+    op_def_cc_file: list[str],
+    op_vjp_cc_file: str,
+    op_cc_split_num: int,
+    bwd_op_cc_split_num: int,
+    onednn_yaml_file: str | None,
+    ops_onednn_extra_yaml_file: str | None,
 ):
     # (1) Prepare: Delete existing old files: pd_op.h.tmp, pd_op.cc.tmp
     if os.path.exists(op_def_h_file):
@@ -2127,6 +2140,10 @@ def OpGenerator(
     op_compat_parser = OpCompatParser(op_compat_yaml_file)
 
     if dialect_name == "onednn_op":
+        if onednn_yaml_file is None or ops_onednn_extra_yaml_file is None:
+            raise ValueError(
+                "onednn_op requires both onednn_yaml_file and ops_onednn_extra_yaml_file"
+            )
         with open(ops_onednn_extra_yaml_file, "r") as f:
             ops_onednn_extra = yaml.safe_load(f)
             ops_onednn_extra_map = {}
@@ -2155,8 +2172,8 @@ def OpGenerator(
                 ops_onednn_extra_map[op_name] = item
         op_yaml_files.insert(0, onednn_yaml_file)
 
-    op_infos = []
-    all_op_info_items = {}
+    op_infos: list[dict[str, OpInfoParser]] = []
+    all_op_info_items: dict[str, OpInfoParser] = {}
     new_op_def_cc_file = []
     first_file = True
     onednn_only_op_list = []
@@ -2188,7 +2205,11 @@ def OpGenerator(
                 ):
                     op_compat_item = op_compat_item.pop('scalar')
 
-            if 'support_tensor' in op.keys() and op['support_tensor']:
+            if (
+                op_compat_item is not None
+                and 'support_tensor' in op.keys()
+                and op['support_tensor']
+            ):
                 (
                     scalar_item,
                     int_array_item,
diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
index 9a6da01236065d..b883073938222e 100644
--- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -13,8 +13,15 @@
 # limitations under the License.
 
 # generator interfaces
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
 from vjp_interface_black_list import vjp_interface_black_list
 
+if TYPE_CHECKING:
+    from op_gen import OpInfoParser
+
 CHECK_INPUT_TEMPLATE = """
   PADDLE_ENFORCE_EQ(
       inputs_.size(),
@@ -106,11 +113,11 @@
 
 
 def gen_op_vjp_str(
-    op_class_name,
-    op_grad_name,
-    op_phi_name,
-    op_info,
-    op_grad_info,
+    op_class_name: str,
+    op_grad_name: str,
+    op_phi_name: str,
+    op_info: OpInfoParser,
+    op_grad_info: OpInfoParser,
 ):
     bw_input_list = op_grad_info.input_name_list
     fwd_input_and_mutable_attr_name_list = (
@@ -272,7 +279,7 @@ def gen_op_vjp_str(
     return str
 
 
-def gen_exclusive_interface_str(op_info, op_info_items):
+def gen_exclusive_interface_str(op_info: OpInfoParser, op_info_items):
     exclusive_interface_str = ""
     if op_info.op_phi_name[0] not in vjp_interface_black_list:
         exclusive_interface_str += "\n  static std::vector<std::vector<pir::Value>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& inputs_, const std::vector<std::vector<pir::Value>>& outputs, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
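
Note on the central refactor: besides the annotations, this patch replaces the
two-element lists formerly returned by to_phi_and_fluid_op_name with the typed
OpNamePair NamedTuple, so call sites read .phi_name/.fluid_name instead of
indexing [0]/[1]. Below is a minimal standalone sketch of the resulting parsing
behavior; the input strings "relu" and "sum (reduce_sum)" are illustrative
examples of the "- op : phi_name (fluid_name)" compat template, not values
taken from any yaml file:

    from typing import NamedTuple

    class OpNamePair(NamedTuple):
        phi_name: str
        fluid_name: str

    def to_phi_and_fluid_op_name(op_item: str) -> OpNamePair:
        # "phi_name (fluid_name)" -> distinct names;
        # a bare "phi_name" is used for both.
        names = op_item.split('(')
        if len(names) == 1:
            name = names[0].strip()
            return OpNamePair(name, name)
        phi_name = names[0].strip()
        fluid_name = names[1].split(')')[0].strip()
        return OpNamePair(phi_name, fluid_name)

    assert to_phi_and_fluid_op_name("sum (reduce_sum)") == OpNamePair("sum", "reduce_sum")
    assert to_phi_and_fluid_op_name("relu") == OpNamePair("relu", "relu")

Because NamedTuple subclasses tuple, existing callers that unpack the result,
e.g. "phi, fluid = to_phi_and_fluid_op_name(x)", keep working unchanged.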