Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Pten】Auto-geneate kernel signature in C++ API #39281

Merged
merged 1 commit into from
Jan 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 0 additions & 128 deletions paddle/pten/api/include/kernel_signature.h

This file was deleted.

10 changes: 5 additions & 5 deletions python/paddle/utils/code_gen/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def __init__(self, api_item_yaml):
# args:
# inputs:
# names : [], list of input names
# input_info : {input_name : type}
# attrs:
# names : [], list of attribute names
# attr_info : { attr_name : (type, default_values)}
Expand Down Expand Up @@ -91,8 +92,8 @@ def gene_output(self, output_type_list):

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.out_type_list,
self.kernel['param'])
outputs_args, output_create = self.gene_output(self.out_type_list)
return f"""
Expand All @@ -103,8 +104,8 @@ def gene_api_code(self):
{input_tensors}
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create}

auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.api}_kernel>();
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});

return out;
Expand Down Expand Up @@ -136,7 +137,6 @@ def source_include(header_file_path):

#include "glog/logging.h"

#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
Expand Down
19 changes: 12 additions & 7 deletions python/paddle/utils/code_gen/backward_api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,19 +108,24 @@ def gene_output(self, output_type_list):
output_create = ""

if len(output_type_list) == 1:
return_type = output_type_list[0]
kernel_output = 'dense_out'
output_create = f"""
{self.return_type} out;
auto dense_out = SetKernelOutput(out_meta, kernel_backend, &out);"""

elif len(output_type_list) > 1:
output_create = f"""
{self.return_type} out;"""
{self.return_type} out({len(output_type_list)});"""

for i, out_type_item in enumerate(output_type_list):
kernel_output = kernel_output + f'dense_out_{i}, '
get_out_code = f'&out[{i}][0]' if out_type_item == 'Tensor' else f'&out[{i}]'
if out_type_item == 'Tensor':
get_out_code = f'&out[{i}][0]'
output_create = output_create + f"""
out[{i}].emplace_back();"""

else:
get_out_code = f'&out[{i}]'
output_create = output_create + f"""
auto dense_out_{i} = SetKernelOutput(std::get<{i}>(out_meta), kernel_backend, {get_out_code});"""

Expand All @@ -134,8 +139,8 @@ def gene_output(self, output_type_list):

def gene_api_code(self):
if self.is_base_api:
input_tensors, kernel_args = gen_utils.get_kernel_args(
self.args['inputs']['names'], self.args['attrs'],
input_tensors, kernel_args, kernel_signature = gen_utils.get_kernel_args(
self.args['inputs'], self.args['attrs'], self.output_type_list,
self.kernel['param'])
outputs_args, output_create = self.gene_output(
self.output_type_list)
Expand All @@ -149,7 +154,8 @@ def gene_api_code(self):
{gen_utils.gene_infer_meta(self.args['inputs']['names'], self.args['attrs']['names'], self.infer_meta)}
{output_create}

auto* kernel_fn = kernel.GetVariadicKernelFn<pten::{self.backward_api}_kernel>();
using kernel_signature = {kernel_signature};
auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
(*kernel_fn)({kernel_args}, {outputs_args});

return out;
Expand Down Expand Up @@ -197,7 +203,6 @@ def source_include(header_file_path):

#include "glog/logging.h"

#include "paddle/pten/api/include/kernel_signature.h"
#include "paddle/pten/api/lib/api_registry.h"
#include "paddle/pten/api/lib/api_utils.h"
#include "paddle/pten/api/lib/kernel_dispatch.h"
Expand Down
29 changes: 27 additions & 2 deletions python/paddle/utils/code_gen/gen_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,21 @@ def gene_infer_meta(input_names, attr_names, infer_meta) -> str:
"""


def get_kernel_args(input_names, attrs, kernel_param):
def get_kernel_args(inputs, attrs, out_type_list, kernel_param):
input_trans_map = {
'const Tensor&': 'const pten::DenseTensor&',
'const Tensor &': 'const pten::DenseTensor&',
'const std::vector<Tensor>&': 'const std::vector<pten::DenseTensor>&',
'const std::vector<Tensor> &': 'const std::vector<pten::DenseTensor>&'
}
out_trans_map = {
'Tensor': 'pten::DenseTensor*',
'std::vector<Tensor>': 'std::vector<pten::DenseTensor*>&'
}
input_names = inputs['names']
input_infos = inputs['input_info']
kernel_args_type_list = ['const platform::DeviceContext&']

input_tensor_code = ""
for input_name in input_names:
# set input code
Expand All @@ -302,15 +316,26 @@ def get_kernel_args(input_names, attrs, kernel_param):
for param in kernel_param:
if param in input_names:
kernel_args = kernel_args + "*" + PREFIX_TENSOR_NAME + param + ", "
kernel_args_type_list.append(input_trans_map[input_infos[param]])
elif param in attr_names:
# set attr for kernel_context
if 'ScalarArray' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::ScalarArray&')
param = 'pten::ScalarArray(' + param + ')'
elif 'Scalar' in attrs['attr_info'][param][0]:
kernel_args_type_list.append('const pten::Scalar&')
param = 'pten::Scalar(' + param + ')'
else:
kernel_args_type_list.append(attrs['attr_info'][param][0])
kernel_args = kernel_args + param + ", "
elif isinstance(param, bool):
kernel_args = kernel_args + str(param).lower() + ", "
else:
kernel_args = kernel_args + str(param) + ", "
return input_tensor_code, kernel_args[:-2]

for out_type in out_type_list:
kernel_args_type_list.append(out_trans_map[out_type])

kernel_signature = "void(*)(" + ", ".join(kernel_args_type_list) + ")"

return input_tensor_code, kernel_args[:-2], kernel_signature