Skip to content

Commit

Permalink
[PTen] Polish kernel register marco design (PaddlePaddle#38078)
Browse files Browse the repository at this point in the history
* polish register marco

* resolve compile failed

* revert needless change

* revert eager related change

* revert eager related change

* change register marco name

* polish deetails
  • Loading branch information
chenwhql authored and Caozhou1995 committed Dec 28, 2021
1 parent b009d38 commit a66810e
Show file tree
Hide file tree
Showing 19 changed files with 552 additions and 513 deletions.
18 changes: 9 additions & 9 deletions paddle/pten/api/lib/kernel_declare.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ limitations under the License. */
// the kernel declare statement is automatically generated according to the
// file name of the kernel, and this header file will be removed

PT_DECLARE_KERNEL(full_like, CPU);
PT_DECLARE_KERNEL(dot, CPU);
PT_DECLARE_KERNEL(flatten, CPU);
PT_DECLARE_KERNEL(sign, CPU);
PT_DECLARE_KERNEL(full_like, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CPU, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CPU, ALL_LAYOUT);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(full_like, CUDA);
PT_DECLARE_KERNEL(dot, CUDA);
PT_DECLARE_KERNEL(flatten, CUDA);
PT_DECLARE_KERNEL(sign, CUDA);
PT_DECLARE_KERNEL(full_like, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(dot, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(flatten, CUDA, ALL_LAYOUT);
PT_DECLARE_KERNEL(sign, CUDA, ALL_LAYOUT);
#endif

#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(flatten, XPU);
PT_DECLARE_KERNEL(flatten, XPU, ALL_LAYOUT);
#endif
6 changes: 3 additions & 3 deletions paddle/pten/api/lib/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ limitations under the License. */
#include "paddle/pten/include/core.h"
#include "paddle/pten/include/infermeta.h"

PT_DECLARE_KERNEL(copy, CPU);
PT_DECLARE_KERNEL(copy, CPU, ALL_LAYOUT);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_DECLARE_KERNEL(copy, CUDA);
PT_DECLARE_KERNEL(copy, CUDA, ALL_LAYOUT);
#endif

#ifdef PADDLE_WITH_XPU
PT_DECLARE_KERNEL(copy, XPU);
PT_DECLARE_KERNEL(copy, XPU, ALL_LAYOUT);
#endif

namespace paddle {
Expand Down
37 changes: 36 additions & 1 deletion paddle/pten/common/backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ namespace experimental {
* in the future
*/
enum class Backend : uint8_t {
// kernel backend cannot be undefined
UNDEFINED = 0,

// basic kernel backend
Expand All @@ -54,6 +53,42 @@ enum class Backend : uint8_t {

// end of backend types
NUM_BACKENDS,

/**
 * [ Why we need ALL in basic kernel key member? ]
*
 * For Tensor, ALL represents an illegal Backend, but for Kernel, some
 * kernels may be device-independent by nature, such as reshape; and
 * some kernels are also device-independent when implemented based on
 * the primitive API.
*
 * In this case, we need to provide a more concise registration method:
 * instead of registering the kernels for each device with almost
 * repetitive code, we need one registration that covers all situations,
 * so we provide the ALL field to register the kernel in a single statement.
*
* Of course, we have also considered solving this problem through different
* named macros, for example, if we define
*
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND
*
 * Based on this design pattern, the dtype and layout also have the same
 * requirements, which would cause us to define a series of macros
*
* PT_REGISTER_KERNEL_FOR_ALL_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_LAYOUT_AND_DTYPE
* PT_REGISTER_KERNEL_FOR_ALL_BACKEND_AND_LAYOUT_AND_DTYPE
*
 * This makes the system of registering macros more complicated; we think
 * this is not a simple design, so we still adopt the design of providing
 * the ALL field.
*
* Note: ALL_BACKEND only used for Kernel registration and selection
*/
ALL_BACKEND = UNDEFINED,
};

inline std::ostream& operator<<(std::ostream& os, Backend backend) {
Expand Down
4 changes: 3 additions & 1 deletion paddle/pten/common/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,9 @@ enum class DataType {
FLOAT64,
COMPLEX64,
COMPLEX128,
NUM_DATA_TYPES
NUM_DATA_TYPES,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_DTYPE = UNDEFINED,
};

inline size_t SizeOf(DataType data_type) {
Expand Down
8 changes: 4 additions & 4 deletions paddle/pten/common/layout.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,21 +20,21 @@ namespace experimental {

enum class DataLayout {
UNDEFINED = 0,
ANY,
// TODO(chenweihang): keep ANY for compatibility, remove it later
ANY = UNDEFINED,
NHWC,
NCHW,
MKLDNN,
NUM_DATA_LAYOUTS,
// See Note [ Why we need ALL in basic kernel key member? ]
ALL_LAYOUT = UNDEFINED,
};

inline std::ostream& operator<<(std::ostream& os, DataLayout layout) {
switch (layout) {
case DataLayout::UNDEFINED:
os << "Undefined";
break;
case DataLayout::ANY:
os << "Any";
break;
case DataLayout::NHWC:
os << "NHWC";
break;
Expand Down
891 changes: 444 additions & 447 deletions paddle/pten/core/kernel_registry.h

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions paddle/pten/kernels/cpu/creation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void FillConstant(const CPUContext& dev_ctx,

PT_REGISTER_KERNEL(full_like,
CPU,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
Expand All @@ -74,7 +74,7 @@ PT_REGISTER_KERNEL(full_like,

PT_REGISTER_KERNEL(full,
CPU,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
Expand Down
12 changes: 9 additions & 3 deletions paddle/pten/kernels/cpu/linalg.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ using complex128 = ::paddle::platform::complex<double>;

PT_REGISTER_KERNEL(dot,
CPU,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
Expand All @@ -84,5 +84,11 @@ PT_REGISTER_KERNEL(dot,
complex64,
complex128) {}

PT_REGISTER_KERNEL(
matmul, CPU, ANY, pten::Matmul, float, double, complex64, complex128) {}
PT_REGISTER_KERNEL(matmul,
CPU,
ALL_LAYOUT,
pten::Matmul,
float,
double,
complex64,
complex128) {}
15 changes: 7 additions & 8 deletions paddle/pten/kernels/cpu/manipulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ void Cast(const CPUContext& dev_ctx,

PT_REGISTER_KERNEL(flatten,
CPU,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
double,
Expand All @@ -95,7 +95,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CPU,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
Expand All @@ -106,7 +106,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,

PT_REGISTER_KERNEL(cast,
CPU,
ANY,
ALL_LAYOUT,
pten::Cast,
float,
double,
Expand All @@ -122,8 +122,7 @@ PT_REGISTER_KERNEL(cast,
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}

PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CPU, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CPU,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape, CPU, ALL_LAYOUT, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CPU, ALL_LAYOUT, pten::ReshapeWithXShape, ALL_DTYPE) {}
16 changes: 8 additions & 8 deletions paddle/pten/kernels/cpu/math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,11 @@ using complex128 = ::paddle::platform::complex<double>;

// NOTE(chenweihang): using bfloat16 will cause redefine with xpu bfloat16
// using bfloat16 = ::paddle::platform::bfloat16;
PT_REGISTER_KERNEL(sign, CPU, ANY, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ANY, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(sign, CPU, ALL_LAYOUT, pten::Sign, float, double) {}
PT_REGISTER_KERNEL(mean, CPU, ALL_LAYOUT, pten::Mean, float, double, bool) {}
PT_REGISTER_KERNEL(scale,
CPU,
ANY,
ALL_LAYOUT,
pten::Scale,
float,
double,
Expand All @@ -127,7 +127,7 @@ PT_REGISTER_KERNEL(scale,
int64_t) {}
PT_REGISTER_KERNEL(add,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseAdd,
float,
double,
Expand All @@ -137,7 +137,7 @@ PT_REGISTER_KERNEL(add,
complex128) {}
PT_REGISTER_KERNEL(subtract,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseSub,
float,
double,
Expand All @@ -147,7 +147,7 @@ PT_REGISTER_KERNEL(subtract,
complex128) {}
PT_REGISTER_KERNEL(divide,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseDiv,
float,
double,
Expand All @@ -157,7 +157,7 @@ PT_REGISTER_KERNEL(divide,
complex128) {}
PT_REGISTER_KERNEL(multiply,
CPU,
ANY,
ALL_LAYOUT,
pten::ElementwiseMul,
float,
double,
Expand All @@ -168,7 +168,7 @@ PT_REGISTER_KERNEL(multiply,
complex128) {}
PT_REGISTER_KERNEL(sum,
CPU,
ANY,
ALL_LAYOUT,
pten::Sum,
bool,
float,
Expand Down
2 changes: 1 addition & 1 deletion paddle/pten/kernels/cpu/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,4 @@ void Copy(const CPUContext& dev_ctx,

} // namespace pten

PT_REGISTER_KERNEL_ALL_DTYPE(copy, CPU, ANY, pten::Copy) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(copy, CPU, ALL_LAYOUT, pten::Copy, ALL_DTYPE) {}
4 changes: 2 additions & 2 deletions paddle/pten/kernels/cuda/creation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void FillConstant(const CUDAContext& dev_ctx,

PT_REGISTER_KERNEL(full_like,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillAnyLike,
float,
double,
Expand All @@ -75,7 +75,7 @@ PT_REGISTER_KERNEL(full_like,

PT_REGISTER_KERNEL(full,
CUDA,
ANY,
ALL_LAYOUT,
pten::FillConstant,
float,
double,
Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/kernels/cuda/linalg.cu
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ using complex128 = ::paddle::platform::complex<double>;

PT_REGISTER_KERNEL(dot,
CUDA,
ANY,
ALL_LAYOUT,
pten::Dot,
float,
double,
Expand All @@ -71,7 +71,7 @@ PT_REGISTER_KERNEL(dot,

PT_REGISTER_KERNEL(matmul,
CUDA,
ANY,
ALL_LAYOUT,
pten::Matmul,
float,
double,
Expand Down
14 changes: 6 additions & 8 deletions paddle/pten/kernels/cuda/manipulation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ using float16 = paddle::platform::float16;

PT_REGISTER_KERNEL(flatten,
CUDA,
ANY,
ALL_LAYOUT,
pten::Flatten,
float,
float16,
Expand All @@ -97,7 +97,7 @@ PT_REGISTER_KERNEL(flatten,
int64_t) {}
PT_REGISTER_KERNEL(flatten_with_xshape,
CUDA,
ANY,
ALL_LAYOUT,
pten::FlattenWithXShape,
float,
double,
Expand All @@ -109,7 +109,7 @@ PT_REGISTER_KERNEL(flatten_with_xshape,
#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL(cast, \
CUDA, \
ANY, \
ALL_LAYOUT, \
pten::Cast, \
float, \
double, \
Expand All @@ -132,8 +132,6 @@ PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif

PT_REGISTER_KERNEL_ALL_DTYPE(reshape, CUDA, ANY, pten::Reshape) {}
PT_REGISTER_KERNEL_ALL_DTYPE(reshape_with_xshape,
CUDA,
ANY,
pten::ReshapeWithXShape) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(reshape, CUDA, ANY, pten::Reshape, ALL_DTYPE) {}
PT_REGISTER_NO_TEMPLATE_KERNEL(
reshape_with_xshape, CUDA, ANY, pten::ReshapeWithXShape, ALL_DTYPE) {}
Loading

0 comments on commit a66810e

Please sign in to comment.