Skip to content

Commit

Permalink
[part 3]change type of function args (#38887)
Browse files Browse the repository at this point in the history
* code clean

* [part 3]change type of function args
  • Loading branch information
zhangting2020 authored Jan 12, 2022
1 parent f120148 commit 0efcae8
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 70 deletions.
30 changes: 15 additions & 15 deletions paddle/fluid/operators/controlflow/bitwise_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ limitations under the License. */
namespace paddle {
namespace operators {

#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T& a, const T& b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool& a, const bool& b) const { \
return a bool_expr b; \
} \
#define BITWISE_BINARY_FUNCTOR(func, expr, bool_expr) \
template <typename T> \
struct Bitwise##func##Functor { \
using ELEM_TYPE = T; \
HOSTDEVICE T operator()(const T a, const T b) const { return a expr b; } \
}; \
\
template <> \
struct Bitwise##func##Functor<bool> { \
using ELEM_TYPE = bool; \
HOSTDEVICE bool operator()(const bool a, const bool b) const { \
return a bool_expr b; \
} \
};

BITWISE_BINARY_FUNCTOR(And, &, &&)
Expand All @@ -45,13 +45,13 @@ BITWISE_BINARY_FUNCTOR(Xor, ^, !=)
template <typename T>
struct BitwiseNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE T operator()(const T& a) const { return ~a; }
HOSTDEVICE T operator()(const T a) const { return ~a; }
};

template <>
struct BitwiseNotFunctor<bool> {
using ELEM_TYPE = bool;
HOSTDEVICE bool operator()(const bool& a) const { return !a; }
HOSTDEVICE bool operator()(const bool a) const { return !a; }
};

template <typename DeviceContext, typename Functor>
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/controlflow/compare_all_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace operators {
template <typename T>
struct EqualReduceFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
Expand Down
12 changes: 6 additions & 6 deletions paddle/fluid/operators/controlflow/compare_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,31 @@ namespace operators {
template <typename T>
struct LessThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a < b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a < b; }
};

template <typename T>
struct LessEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a <= b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a <= b; }
};

template <typename T>
struct GreaterThanFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a > b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a > b; }
};

template <typename T>
struct GreaterEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a >= b; }
HOSTDEVICE bool operator()(const T a, const T b) const { return a >= b; }
};

template <typename T>
struct EqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
HOSTDEVICE bool operator()(const T a, const T b) const {
if (std::is_floating_point<T>::value) {
// This branch will be optimized while compiling if T is integer. It is
// safe to cast a and b to double.
Expand All @@ -63,7 +63,7 @@ struct EqualFunctor {
template <typename T>
struct NotEqualFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
HOSTDEVICE bool operator()(const T a, const T b) const {
return !EqualFunctor<T>()(a, b);
}
};
Expand Down
28 changes: 4 additions & 24 deletions paddle/fluid/operators/controlflow/logical_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -18,26 +18,6 @@ namespace plat = paddle::platform;
namespace paddle {
namespace operators {

#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T* args) const { \
return static_cast<bool>(args[0]) op static_cast<bool>(args[1]); \
} \
};

LOGICAL_BINARY_FUNCTOR(CudaOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(CudaAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(CudaXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR

template <typename T>
struct CudaNotFunctor {
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T* args) const { return !args[0]; }
};

template <typename Functor>
class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
Expand Down Expand Up @@ -76,8 +56,8 @@ class BinaryLogicalOpKernel<platform::CUDADeviceContext, Functor>
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<float>>, \
ops::BinaryLogicalOpKernel<plat::CUDADeviceContext, ops::func<double>>);

REGISTER_LOGICAL_CUDA_KERNEL(logical_or, CudaOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, CudaAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, CudaXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, CudaNotFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_or, LogicalOrFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_and, LogicalAndFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_xor, LogicalXorFunctor)
REGISTER_LOGICAL_CUDA_KERNEL(logical_not, LogicalNotFunctor)
#undef REGISTER_LOGICAL_CUDA_KERNEL
42 changes: 18 additions & 24 deletions paddle/fluid/operators/controlflow/logical_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,38 +19,32 @@ limitations under the License. */
namespace paddle {
namespace operators {

template <typename T>
struct LogicalAndFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a && b; }
};
#define LOGICAL_BINARY_FUNCTOR(func_name, op) \
template <typename T> \
struct func_name { \
using ELEMENT_TYPE = T; \
HOSTDEVICE bool operator()(const T a, const T b) const { \
return static_cast<bool>(a) op static_cast<bool>(b); \
} \
};

template <typename T>
struct LogicalOrFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const { return a || b; }
};
LOGICAL_BINARY_FUNCTOR(LogicalOrFunctor, ||)
LOGICAL_BINARY_FUNCTOR(LogicalAndFunctor, &&)
LOGICAL_BINARY_FUNCTOR(LogicalXorFunctor, ^)
#undef LOGICAL_BINARY_FUNCTOR

template <typename T>
struct LogicalNotFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a) const { return !a; }
};

template <typename T>
struct LogicalXorFunctor {
using ELEM_TYPE = T;
HOSTDEVICE bool operator()(const T& a, const T& b) const {
return (a || b) && !(a && b);
}
using ELEMENT_TYPE = T;
HOSTDEVICE bool operator()(const T a) const { return !a; }
};

template <typename DeviceContext, typename Functor>
class BinaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* y = context.Input<framework::Tensor>("Y");
auto* out = context.Output<framework::Tensor>("Out");
Expand All @@ -62,10 +56,10 @@ class BinaryLogicalOpKernel

template <typename DeviceContext, typename Functor>
class UnaryLogicalOpKernel
: public framework::OpKernel<typename Functor::ELEM_TYPE> {
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using T = typename Functor::ELEM_TYPE;
using T = typename Functor::ELEMENT_TYPE;
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
Functor unary_func;
Expand Down

0 comments on commit 0efcae8

Please sign in to comment.