[complex op No.12] add complex support for pow #62959

Open. Wants to merge 6 commits into base: develop.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/generator/generate_op.py
@@ -90,7 +90,7 @@ def restruct_io(op):

def process_scalar(op_item, scalar_configs):
    scalar_map = {
-       'Scalar': 'float',
+       'Scalar': 'Scalar',
        'Scalar(float)': 'float',
        'Scalar(double)': 'double',
        'Scalar(int)': 'int',
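This one-line change to scalar_map is what allows an attribute declared as a bare Scalar to keep its full Scalar type through operator code generation instead of being narrowed to float; the complex factor of pow depends on it. A minimal sketch of the intended effect on a generated declaration (illustrative only; the hypothetical Pow signature below is not the actual generated code):

    // Before: a bare 'Scalar' attribute was emitted as float, so the imaginary
    // part of a complex exponent was lost at generation time:
    //   void Pow(..., float factor, ...);
    // After: the attribute keeps the Scalar type, which can hold a complex value:
    //   void Pow(..., const Scalar& factor, ...);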
13 changes: 13 additions & 0 deletions paddle/phi/infermeta/unary.cc
@@ -18,6 +18,7 @@ limitations under the License. */
#include <set>

#include "paddle/common/flags.h"
#include "paddle/phi/api/lib/data_type_set.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
@@ -3482,6 +3483,18 @@ void PoolInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void PowInferMeta(const MetaTensor& x, const Scalar& y, MetaTensor* out) {
paddle::experimental::DataTypeSet dtype_set{x.dtype()};
dtype_set = dtype_set | paddle::experimental::DataTypeSet(y.dtype());
DataType promote_result = PromoteTypes(dtype_set);
if (promote_result == DataType::UNDEFINED) {
promote_result = x.dtype();
}
out->set_dims(x.dims());
out->set_dtype(promote_result);
out->set_layout(x.layout());
}

void PushDenseInferMeta(const std::vector<const MetaTensor*>& ids,
int table_id,
float scale_data_norm,
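The new PowInferMeta derives the output dtype by promoting the tensor dtype with the Scalar exponent's dtype, and falls back to x.dtype() when the promotion is undefined. A hedged illustration of the intended outcomes, assuming phi's PromoteTypes follows the usual NumPy-style promotion lattice:

    // x: complex64, factor: 2.0        -> out: complex64 (complex absorbs a real exponent)
    // x: float32,   factor: 1.0 + 0.5i -> out: promoted to a complex dtype
    // x: float32,   factor: 2.0        -> out: float32 (no promotion needed; x.dtype() kept)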
2 changes: 2 additions & 0 deletions paddle/phi/infermeta/unary.h
@@ -559,6 +559,8 @@ void WeightQuantizeInferMeta(const MetaTensor& x,
MetaTensor* out,
MetaTensor* scale);

void PowInferMeta(const MetaTensor& x, const Scalar& y, MetaTensor* out);

void RealAndImagInferMeta(const MetaTensor& x, MetaTensor* out);

void ReduceInferMeta(const MetaTensor& x,
12 changes: 9 additions & 3 deletions paddle/phi/kernels/cpu/activation_grad_kernel.cc
@@ -461,20 +461,26 @@ PD_REGISTER_KERNEL(pow_grad,
float,
double,
int,
- int64_t) {}
+ int64_t,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(pow_double_grad,
CPU,
ALL_LAYOUT,
phi::PowDoubleGradKernel,
float,
double,
int,
- int64_t) {}
+ int64_t,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(pow_triple_grad,
CPU,
ALL_LAYOUT,
phi::PowTripleGradKernel,
float,
double,
int,
- int64_t) {}
+ int64_t,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
13 changes: 11 additions & 2 deletions paddle/phi/kernels/cpu/activation_kernel.cc
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/phi/kernels/activation_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/impl/activation_impl.h"
@@ -308,5 +309,13 @@ PD_REGISTER_KERNEL(negative,
int,
int64_t) {}
PD_REGISTER_ACTIVATION_KERNEL(celu, CeluKernel)
- PD_REGISTER_KERNEL(
- pow, CPU, ALL_LAYOUT, phi::PowKernel, float, double, int, int64_t) {}
+ PD_REGISTER_KERNEL(pow,
+ CPU,
+ ALL_LAYOUT,
+ phi::PowKernel,
+ float,
+ double,
+ int,
+ int64_t,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/elementwise_grad_kernel.cc
@@ -128,7 +128,9 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
double,
int,
int64_t,
- phi::dtype::bfloat16) {}
+ phi::dtype::bfloat16,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(copysign_grad,
CPU,
4 changes: 3 additions & 1 deletion paddle/phi/kernels/cpu/elementwise_kernel.cc
@@ -156,7 +156,9 @@ PD_REGISTER_KERNEL(elementwise_pow,
double,
int,
int64_t,
- phi::dtype::bfloat16) {}
+ phi::dtype::bfloat16,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(heaviside,
CPU,
ALL_LAYOUT,
85 changes: 85 additions & 0 deletions paddle/phi/kernels/funcs/activation_functor.h
@@ -115,6 +115,29 @@ struct Real {
}
};

template <typename T>
struct RealCp {
HOSTDEVICE const T operator()(const ComplexType<T>& val) const {
return val.real;
}
};

template <typename T>
struct ImagCp {
HOSTDEVICE const T operator()(const ComplexType<T>& val) const {
return val.imag;
}
};

template <typename T>
struct ComplexAssemble {
HOSTDEVICE ComplexAssemble() {}

HOSTDEVICE ComplexType<T> operator()(const T& r, const T& i) const {
return ComplexType<T>(r, i);
}
};
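// Illustrative usage sketch (not part of the diff): the three helpers above
// let an Eigen expression over ComplexType<T> be split into real and imaginary
// planes and reassembled, which is how PowGradFunctor<ComplexType<T>> below
// builds dx. With z a hypothetical Eigen array of ComplexType<float>:
//   auto re  = z.unaryExpr(RealCp<float>());                 // real parts
//   auto im  = z.unaryExpr(ImagCp<float>());                 // imaginary parts
//   auto out = re.binaryExpr(im, ComplexAssemble<float>());  // re + i*im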

// sine'(x) = cos(x)
template <typename T>
struct SinGradFunctor : public BaseActivationFunctor<T> {
@@ -2909,18 +2932,43 @@ struct PowFunctor : public BaseActivationFunctor<T> {
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}

T factor_cp;
using ComplexAttrPair = std::vector<std::pair<const char*, T*>>;
ComplexAttrPair GetComplexAttrs() { return {{"factor", &factor_cp}}; }

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(static_cast<T>(factor)); // NOLINT
}
};

template <typename T>
struct PowFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
phi::dtype::complex<T> factor;

using ComplexAttrPair =
std::vector<std::pair<const char*, phi::dtype::complex<T>*>>;
ComplexAttrPair GetComplexAttrs() { return {{"factor", &factor}}; }

template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
out.device(d) = x.pow(factor); // NOLINT
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct PowGradFunctor : public BaseActivationFunctor<T> {
float factor;
typename BaseActivationFunctor<T>::AttrPair GetAttrs() {
return {{"factor", &factor}};
}
T factor_cp;
using ComplexAttrPair = std::vector<std::pair<const char*, T*>>;
ComplexAttrPair GetComplexAttrs() { return {{"factor", &factor_cp}}; }
template <typename Device,
typename X,
typename Out,
@@ -2934,6 +2982,43 @@ struct PowGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

template <typename T>
struct PowGradFunctor<ComplexType<T>>
: public BaseActivationFunctor<ComplexType<T>> {
ComplexType<T> factor;
using ComplexAttrPair =
std::vector<std::pair<const char*, phi::dtype::complex<T>*>>;
ComplexAttrPair GetComplexAttrs() { return {{"factor", &factor}}; }
template <typename Device,
typename X,
typename Out,
typename dOut,
typename dX>
void operator()(Device d, X x, Out out UNUSED, dOut dout, dX dx) const {
auto a_ = x.unaryExpr(RealCp<T>());
auto b_ = x.unaryExpr(ImagCp<T>());
auto c_2 = factor.real / 2;
auto d_2 = factor.imag / 2;
auto c_ = factor.real;
auto d_ = factor.imag;
auto arctan_ = (b_ / a_).unaryExpr(Atan<T>());
auto square_ = a_ * a_ + b_ * b_;
auto e_ = (square_.log() * c_2 - d_ * arctan_).exp();
auto v_ = square_.log() * d_2 + c_ * arctan_;

auto ux = e_ / square_ *
((a_ * c_ + b_ * d_) * v_.unaryExpr(Cosine<T>()) +
(b_ * c_ - a_ * d_) * v_.unaryExpr(Sine<T>()));
auto uy = e_ / square_ *
((b_ * c_ - a_ * d_) * v_.unaryExpr(Cosine<T>()) -
(b_ * d_ + a_ * c_) * v_.unaryExpr(Sine<T>()));

dx.device(d) = ux.binaryExpr(uy, ComplexAssemble<T>());
}

static constexpr ActBwdOpFwdDeps FwdDeps() { return ActBwdOpFwdDeps::kDepX; }
};

// floor(x) = flooring(x)
template <typename T>
struct FloorFunctor : public BaseActivationFunctor<T> {
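For reference, here is the math the complex PowFunctor and PowGradFunctor above encode. Write the input as $x = a + bi$ (the functors' a_, b_) and the exponent as $c = c_r + c_i i$ (factor.real, factor.imag). The forward pass is Eigen's x.pow(factor), consistent with

$$x^{c} = e^{c \ln x}, \qquad \ln x = \tfrac{1}{2}\ln(a^{2}+b^{2}) + i\arctan(b/a),$$

so with $e = \exp\!\big(\tfrac{c_r}{2}\ln(a^{2}+b^{2}) - c_i\arctan(b/a)\big)$ (the code's e_) and $v = \tfrac{c_i}{2}\ln(a^{2}+b^{2}) + c_r\arctan(b/a)$ (the code's v_), the forward value is $x^{c} = e(\cos v + i\sin v)$. (Note that $\arctan(b/a)$ agrees with $\arg x$ only up to a multiple of $\pi$, so this form assumes the principal branch.) The derivative is

$$\frac{\partial}{\partial x}\,x^{c} = c\,x^{c-1} = \frac{c}{x}\,x^{c}, \qquad \frac{c}{x} = \frac{(a c_r + b c_i) + i(a c_i - b c_r)}{a^{2}+b^{2}},$$

whose real and imaginary parts expand to

$$\operatorname{Re} = \frac{e}{a^{2}+b^{2}}\big[(a c_r + b c_i)\cos v + (b c_r - a c_i)\sin v\big], \qquad \operatorname{Im} = \frac{e}{a^{2}+b^{2}}\big[(a c_i - b c_r)\cos v + (a c_r + b c_i)\sin v\big].$$

The functor's ux is exactly $\operatorname{Re}$ and uy is $-\operatorname{Im}$, so the pair assembled into dx is the conjugate $\overline{c\,x^{c-1}}$, the conjugated-gradient convention commonly used for complex backward passes.

A small standalone check of the forward identity, using std::complex rather than phi::dtype::complex (a sanity sketch under that substitution, not phi code):

#include <cmath>
#include <complex>
#include <cstdio>

int main() {
  const std::complex<double> x(1.0, 2.0);
  const std::complex<double> c(0.5, -1.5);
  const std::complex<double> lib = std::pow(x, c);            // library pow
  const std::complex<double> ref = std::exp(c * std::log(x)); // e^{c ln x}
  // The two agree to rounding error on the principal branch.
  std::printf("std::pow:      %.12f %+.12fi\n", lib.real(), lib.imag());
  std::printf("exp(c*log(x)): %.12f %+.12fi\n", ref.real(), ref.imag());
  return 0;
}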
16 changes: 15 additions & 1 deletion paddle/phi/kernels/funcs/elementwise_functor.h
@@ -869,14 +869,28 @@ compute_pow(const T a, const T b) {
}
#else
template <typename T, typename MPType>
- inline HOSTDEVICE T compute_pow(const T a, const T b) {
+ inline HOSTDEVICE typename std::enable_if<
+ !std::is_same<T, phi::dtype::complex<float>>::value &&
+ !std::is_same<T, phi::dtype::complex<double>>::value,
+ T>::type
+ compute_pow(const T a, const T b) {
MPType a_val = static_cast<MPType>(a);
MPType b_val = static_cast<MPType>(b);
#ifdef PADDLE_WITH_XPU_KP
return static_cast<T>(pow(a_val, b_val));
#endif
return static_cast<T>(std::pow(a_val, b_val));
}
+ template <typename T, typename MPType>
+ inline HOSTDEVICE typename std::enable_if<
+ std::is_same<T, phi::dtype::complex<float>>::value ||
+ std::is_same<T, phi::dtype::complex<double>>::value,
+ T>::type
+ compute_pow(const T a, const T b) {
+ MPType a_val = static_cast<MPType>(a);
+ MPType b_val = static_cast<MPType>(b);
+ return static_cast<T>(pow(a_val, b_val));
+ }
#endif

template <typename T>
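The enable_if split above keeps real types on the existing std::pow path while complex operands resolve to the pow overload provided for phi's complex type rather than std::pow. A minimal self-contained sketch of the same dispatch pattern, using std::complex in place of phi::dtype::complex (illustrative, not the PR's code):

#include <cmath>
#include <complex>
#include <type_traits>

// Enabled only when T is not a complex type: route through scalar std::pow.
template <typename T>
typename std::enable_if<!std::is_same<T, std::complex<float>>::value &&
                            !std::is_same<T, std::complex<double>>::value,
                        T>::type
my_pow(const T a, const T b) {
  return static_cast<T>(std::pow(a, b));
}

// Enabled only for complex T: the complex overload of std::pow is selected.
template <typename T>
typename std::enable_if<std::is_same<T, std::complex<float>>::value ||
                            std::is_same<T, std::complex<double>>::value,
                        T>::type
my_pow(const T a, const T b) {
  return std::pow(a, b);
}

// Usage: my_pow(2.0, 10.0) == 1024.0, and
// my_pow(std::complex<double>(0, 1), std::complex<double>(2, 0)) is (-1, 0)
// up to rounding.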
12 changes: 9 additions & 3 deletions paddle/phi/kernels/gpu/activation_grad_kernel.cu
@@ -542,7 +542,9 @@ PD_REGISTER_KERNEL(pow_grad,
int,
int64_t,
phi::dtype::float16,
- phi::dtype::bfloat16) {}
+ phi::dtype::bfloat16,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(pow_double_grad,
GPU,
ALL_LAYOUT,
@@ -552,7 +554,9 @@ PD_REGISTER_KERNEL(pow_double_grad,
int,
int64_t,
phi::dtype::float16,
- phi::dtype::bfloat16) {}
+ phi::dtype::bfloat16,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(pow_triple_grad,
GPU,
ALL_LAYOUT,
@@ -562,4 +566,6 @@
int,
int64_t,
phi::dtype::float16,
- phi::dtype::bfloat16) {}
+ phi::dtype::bfloat16,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/activation_kernel.cu
@@ -363,7 +363,9 @@ PD_REGISTER_KERNEL(pow,
int,
int64_t,
phi::dtype::float16,
- phi::dtype::bfloat16) {}
+ phi::dtype::bfloat16,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}
PD_REGISTER_KERNEL(selu,
GPU,
ALL_LAYOUT,
4 changes: 3 additions & 1 deletion paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -315,7 +315,9 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
int,
phi::dtype::float16,
phi::dtype::bfloat16,
- int64_t) {}
+ int64_t,
+ phi::dtype::complex<float>,
+ phi::dtype::complex<double>) {}

PD_REGISTER_KERNEL(add_grad,
GPU,
13 changes: 11 additions & 2 deletions paddle/phi/kernels/impl/activation_grad_impl.h
@@ -19,10 +19,14 @@
#include "paddle/phi/backends/all_context.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/kernels/as_complex_kernel.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/complex_kernel.h"
#include "paddle/phi/kernels/elementwise_add_kernel.h"
#include "paddle/phi/kernels/elementwise_multiply_kernel.h"
#include "paddle/phi/kernels/full_kernel.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/scale_kernel.h"

namespace phi {
@@ -336,8 +340,13 @@ void PowGradKernel(const Context& dev_ctx,
EigenVector<T>::Flatten(GET_DATA_SAFELY(&x, "Input", "X", "PowGrad"));
auto* place = dev_ctx.eigen_device();
phi::funcs::PowGradFunctor<T> functor;
- auto attrs = functor.GetAttrs();
- *(attrs[0].second) = factor.to<float>();
+ if (IsComplexType(x.dtype()) || IsComplexType(factor.dtype())) {
+ auto attrs = functor.GetComplexAttrs();
+ *(attrs[0].second) = factor.to<T>();
+ } else {
+ auto attrs = functor.GetAttrs();
+ *(attrs[0].second) = factor.to<float>();
+ }
functor(*place, x_flatten, nullptr, dout_flatten, dx_flatten);
}

9 changes: 7 additions & 2 deletions paddle/phi/kernels/impl/activation_impl.h
@@ -76,8 +76,13 @@ void PowKernel(const Context& dev_ctx,
GET_DATA_SAFELY(out, "Output", "Out", "Activation"));
auto* place = dev_ctx.eigen_device();
phi::funcs::PowFunctor<T> functor;
- auto attrs = functor.GetAttrs();
- *(attrs[0].second) = factor.to<float>();
+ if (IsComplexType(x.dtype()) || IsComplexType(factor.dtype())) {
+ auto attrs = functor.GetComplexAttrs();
+ *(attrs[0].second) = factor.to<T>();
+ } else {
+ auto attrs = functor.GetAttrs();
+ *(attrs[0].second) = factor.to<float>();
+ }
functor(*place, x_flatten, out_flatten);
}

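In both PowKernel and PowGradKernel, the factor attribute is now routed by dtype: when either the input tensor or the factor is complex, the functor's complex attribute is filled via factor.to<T>() so the imaginary part is preserved; otherwise the original float attribute path is unchanged. A short illustration of why the float path alone would be lossy (illustrative, with std::complex):

    // std::complex<double> factor(2.0, 0.5);
    // static_cast<float>(factor.real()) == 2.0f: the 0.5i is silently dropped,
    // which is why a complex factor must go through GetComplexAttrs().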