elementwise max min #7538
Conversation
```cpp
template <typename DeviceContext, typename T>
class ElementwiseMaxKernel : public framework::OpKernel<T> {
```
It seems that Max and Min share almost the same logic. We can reuse the code. Please refer to https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/reduce_op.h#L97 and https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/operators/reduce_op.h#L253
Sounds great.
fixed
@JiayiFeng We can reduce more code, but we can do that later if we have more element-wise operators to implement. Anyway, let's find a general pattern and clean up the code later.
```cpp
template <typename T>
struct MaxGradFunctor1 {
  template <typename Device, typename Dxe, typename Xe, typename Ye,
            typename Dze>
  void operator()(Device d, Dxe dx_e, Xe x_e, Ye y_e, Dze dz_e) {
    // dX receives dZ wherever x > y; the mask casts bool to T as 0/1.
    dx_e.device(d) = (x_e > y_e).template cast<T>() * dz_e;
  }
};

template <typename T>
struct MinGradFunctor1 {
  template <typename Device, typename Dxe, typename Xe, typename Ye,
            typename Dze>
  void operator()(Device d, Dxe dx_e, Xe x_e, Ye y_e, Dze dz_e) {
    // dX receives dZ wherever x < y.
    dx_e.device(d) = (x_e < y_e).template cast<T>() * dz_e;
  }
};
```
```cpp
template <typename T, typename Functor1, typename Functor2>
struct ElementwiseMinMaxGradFunctor {
  template <typename Device, typename X, typename Y, typename Z, typename dX,
            typename dY, typename dZ>
  void operator()(Device d, X x, Y y, Z z, dX dx, dY dy, dZ dz) {
    auto x_e = framework::EigenVector<T>::Flatten(*x);
    auto y_e = framework::EigenVector<T>::Flatten(*y);
    auto dz_e = framework::EigenVector<T>::Flatten(*dz);
    // dx and dy may be null when the corresponding gradient is not requested.
    if (dx) {
      auto dx_e = framework::EigenVector<T>::Flatten(*dx);
      Functor1 functor;
      functor(d, dx_e, x_e, y_e, dz_e);
    }
    if (dy) {
      auto dy_e = framework::EigenVector<T>::Flatten(*dy);
      Functor2 functor;
      functor(d, dy_e, x_e, y_e, dz_e);
    }
  }
};
```
@QiJune I agree. Finding a general pattern in the backward pass is not as easy as in the forward pass, since the backward op takes `Z`, `dZ`, `X`, and `Y`, while the forward op only takes `X` and `Y`.
LGTM
Solves #7567