FixEighOP; Unified MatrixEighFunctor function #35812

Merged: 4 commits, Sep 18, 2021
13 changes: 5 additions & 8 deletions paddle/fluid/operators/eigh_op.cc
@@ -47,12 +47,9 @@ class EighOp : public framework::OperatorWithKernel {
input_dim[rank - 2], input_dim[rank - 1]));

std::vector<int64_t> values_dim;
if (rank > 2) {
for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}
} else {
values_dim = {input_dim[1]};

for (auto i = 0; i < rank - 1; i++) {
values_dim.emplace_back(input_dim[i]);
}

ctx->SetOutputDim("Eigenvalues", framework::make_ddim(values_dim));
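Note: the hunk above collapses the old rank-2 / rank-N branches into a single rule. Because the input has already been checked to be square (the enforce above compares input_dim[rank - 2] with input_dim[rank - 1]), the eigenvalue output shape is always the input shape with its last dimension dropped. A minimal standalone sketch of that rule (hypothetical helper, not part of this PR):

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: for a batched square input of shape [..., n, n],
// the eigenvalue output has shape [..., n] for every rank >= 2.
std::vector<int64_t> EighValuesShape(const std::vector<int64_t>& input_dim) {
  std::vector<int64_t> values_dim(input_dim.begin(), input_dim.end() - 1);
  return values_dim;  // e.g. {5, 5} -> {5};  {2, 3, 5, 5} -> {2, 3, 5}
}
```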
@@ -99,9 +96,9 @@ class EighGradOp : public framework::OperatorWithKernel {
"EighGrad");
OP_INOUT_CHECK(ctx->HasInput("Eigenvectors"), "Input", "Eigenvectors",
"EighGrad");
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvalues")),
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvalues")),
"Input", "Eigenvalues@GRAD", "EighGrad");
OP_INOUT_CHECK(ctx->HasInputs(framework::GradVarName("Eigenvectors")),
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Eigenvectors")),
"Input", "Eigenvectors@GRAD", "EighGrad");
auto dims = ctx->GetInputDim("Eigenvectors");
auto x_grad_name = framework::GradVarName("X");
32 changes: 6 additions & 26 deletions paddle/fluid/operators/eigh_op.cu
@@ -14,34 +14,14 @@ limitations under the License. */

#include "paddle/fluid/operators/eigh_op.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename ValueType, typename T>
class EighGPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto input_var = ctx.Input<Tensor>("X");
auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
math::MatrixEighFunctor<ValueType, T> functor;
functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(
eigh, ops::EighGPUKernel<float, float>, ops::EighGPUKernel<double, double>,
ops::EighGPUKernel<float, paddle::platform::complex<float>>,
ops::EighGPUKernel<double, paddle::platform::complex<double>>);
eigh, ops::EighKernel<paddle::platform::CUDADeviceContext, float, float>,
ops::EighKernel<paddle::platform::CUDADeviceContext, double, double>,
ops::EighKernel<paddle::platform::CUDADeviceContext, float,
paddle::platform::complex<float>>,
ops::EighKernel<paddle::platform::CUDADeviceContext, double,
paddle::platform::complex<double>>);

REGISTER_OP_CUDA_KERNEL(
eigh_grad,
33 changes: 13 additions & 20 deletions paddle/fluid/operators/eigh_op.h
@@ -22,24 +22,17 @@ namespace operators {

using Tensor = framework::Tensor;

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;

template <typename DeviceContext, typename ValueType, typename T>
class EighKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto input_var = ctx.Input<Tensor>("X");
auto output_w_var = ctx.Output<Tensor>("Eigenvalues");
auto output_v_var = ctx.Output<Tensor>("Eigenvectors");
auto input = ctx.Input<Tensor>("X");
auto output_w = ctx.Output<Tensor>("Eigenvalues");
auto output_v = ctx.Output<Tensor>("Eigenvectors");
std::string lower = ctx.Attr<std::string>("UPLO");
bool is_lower = (lower == "L");
math::MatrixEighFunctorCPU<DeviceContext, ValueType, T> functor;
functor(ctx, *input_var, output_w_var, output_v_var, is_lower, true);
math::MatrixEighFunctor<DeviceContext, ValueType, T> functor;
functor(ctx, *input, output_w, output_v, is_lower, true);
}
};

@@ -49,30 +42,30 @@ class EighGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
auto& x_grad = *ctx.Output<framework::Tensor>(framework::GradVarName("X"));
x_grad.mutable_data<T>(ctx.GetPlace());
auto& output_w_var = *ctx.Input<Tensor>("Eigenvalues");
auto& output_v_var = *ctx.Input<Tensor>("Eigenvectors");
auto& output_w = *ctx.Input<Tensor>("Eigenvalues");
auto& output_v = *ctx.Input<Tensor>("Eigenvectors");
auto& output_w_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvalues"));
auto& output_v_grad =
*ctx.Input<Tensor>(framework::GradVarName("Eigenvectors"));

auto& dims = output_v_var.dims();
auto& dims = output_v.dims();
const int m = dims[dims.size() - 1];
auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T, ValueType>(
ctx);
auto tV = dito.Transpose(dito.Conj(output_v_var));
auto W = dito.Sub_(dito.Unsqueeze(output_w_var, -2),
dito.Unsqueeze(output_w_var, -1));
auto tV = dito.Transpose(dito.Conj(output_v));
auto W = dito.template Sub<ValueType>(dito.Unsqueeze(output_w, -2),
dito.Unsqueeze(output_w, -1));
Tensor result = dito.Matmul(tV, output_v_grad);
result.mutable_data<T>(dims, ctx.GetPlace());
std::vector<int> out_shape = framework::vectorize<int>(dims);
auto constant = dito.Fill(out_shape, 0.5);
result = dito.Sub(result, dito.Conj(dito.Transpose(result)));
result = dito.Mul(result, constant);
result = dito.Div_(result, W);
result = dito.Div(result, W);
result = dito.DiagFill(m, m, m, 0, output_w_grad, result);
x_grad = dito.Matmul(output_v_var, dito.Matmul(result, tV));
x_grad = dito.Matmul(output_v, dito.Matmul(result, tV));
}
};
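Note: the gradient rule encoded by this kernel is the usual one for a Hermitian eigendecomposition A = V diag(w) V^H: take F = (V^H dV - dV^H V) / 2, divide its off-diagonal entries by w_j - w_i, place dw on the diagonal, and map back as V F V^H. A single-matrix Eigen sketch of the same computation (illustrative names; the kernel itself works on batched Paddle tensors through the dito helpers):

```cpp
#include <Eigen/Dense>
#include <complex>

// Single-matrix reference for the eigh backward pass above.
// V: eigenvectors (columns), w: eigenvalues, dV/dw: upstream gradients.
Eigen::MatrixXcd EighBackward(const Eigen::MatrixXcd& V,
                              const Eigen::VectorXd& w,
                              const Eigen::MatrixXcd& dV,
                              const Eigen::VectorXd& dw) {
  const Eigen::Index n = w.size();
  const Eigen::MatrixXcd Vh = V.adjoint();         // V^H (tV in the kernel)
  const Eigen::MatrixXcd S = Vh * dV;              // Matmul(tV, dV)
  Eigen::MatrixXcd F = 0.5 * (S - S.adjoint());    // anti-Hermitian part
  for (Eigen::Index i = 0; i < n; ++i) {
    for (Eigen::Index j = 0; j < n; ++j) {
      if (i != j) F(i, j) /= (w(j) - w(i));        // W(i, j) = w_j - w_i
    }
  }
  F.diagonal() = dw.cast<std::complex<double>>();  // DiagFill with dL/dw
  return V * F * Vh;                               // dL/dA
}
```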

34 changes: 18 additions & 16 deletions paddle/fluid/operators/math/eigen_values_vectors.h
@@ -16,7 +16,6 @@

#include "Eigen/Core"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/svd_helper.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/dynload/cusolver.h"
@@ -26,10 +25,6 @@ namespace paddle {
namespace operators {
namespace math {

template <typename T, size_t D, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;

template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using InputMatrixMap = Eigen::Map<
@@ -67,7 +62,7 @@ inline void ComputeFloatEigenvaluesAndVectors(ValueType *x_data,

eigenvalues = eigen_solver.eigenvalues().transpose();
if (has_vectors) {
eigenvectors = eigen_solver.eigenvectors().transpose();
eigenvectors = eigen_solver.eigenvectors();
}
}
}
Expand Down Expand Up @@ -103,7 +98,7 @@ inline void ComputeComplexEigenvaluesAndVectors(T *x_data,

eigenvalues = eigen_solver.eigenvalues().transpose();
if (has_vectors) {
eigenvectors = eigen_solver.eigenvectors().transpose();
eigenvectors = eigen_solver.eigenvectors();
}
}
}
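Note: the two hunks above drop the .transpose() on the solver output; the matching dito.Transpose at the end of the CPU functor disappears further down in this file, so the pair of redundant transposes is removed together. The underlying convention, shown in a minimal standalone example, is that Eigen's SelfAdjointEigenSolver returns eigenvalues in ascending order and the corresponding eigenvectors as the columns of eigenvectors():

```cpp
#include <Eigen/Dense>
#include <iostream>

// Minimal SelfAdjointEigenSolver usage: eigenvalues come back sorted in
// ascending order, eigenvectors as the columns of eigenvectors().
int main() {
  Eigen::Matrix2d A;
  A << 2.0, 1.0,
       1.0, 2.0;
  Eigen::SelfAdjointEigenSolver<Eigen::Matrix2d> solver(A);
  std::cout << "eigenvalues:\n" << solver.eigenvalues() << "\n";   // 1, 3
  std::cout << "eigenvectors (columns):\n" << solver.eigenvectors() << "\n";
  return 0;
}
```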
@@ -117,11 +112,18 @@ inline int64_t GetBatchSize(framework::DDim dims) {
return batch_size;
}

template <typename DeviceContext, typename ValueType, typename T>
struct MatrixEighFunctor {
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
bool has_vectors);
};
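Note: this declaration is the core of the unification in the PR title: a single primary template that only declares the call operator, plus one explicit specialization per device context (the CPU specialization follows immediately below; the CUDA one appears later in this file). A toy sketch of the pattern, with illustrative names only:

```cpp
// Illustrative stand-ins; the real code specializes on Paddle's
// platform::CPUDeviceContext and platform::CUDADeviceContext.
struct FakeCPUContext {};
struct FakeGPUContext {};

template <typename DeviceContext, typename ValueType, typename T>
struct EighFunctorSketch {
  void operator()();  // interface only; no generic implementation
};

template <typename ValueType, typename T>
struct EighFunctorSketch<FakeCPUContext, ValueType, T> {
  void operator()() { /* Eigen / LAPACK-style path */ }
};

template <typename ValueType, typename T>
struct EighFunctorSketch<FakeGPUContext, ValueType, T> {
  void operator()() { /* cuSOLVER path */ }
};

// A device-agnostic kernel can then simply instantiate
//   EighFunctorSketch<DeviceContext, ValueType, T> functor;
// and the compiler picks the matching specialization.
```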

// Calculates the eigenvalues and eigenvectors of Hermitian or real
// symmetric matrices, and uses the variable has_vectors to
// control whether to return the eigenvectors.
template <typename DeviceContext, typename ValueType, typename T>
struct MatrixEighFunctorCPU {
template <typename ValueType, typename T>
struct MatrixEighFunctor<platform::CPUDeviceContext, ValueType, T> {
public:
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
@@ -134,7 +136,8 @@ struct MatrixEighFunctorCPU {
for (int64_t i = 0; i < dim_size - 2; i++) {
batch_size *= dims[i];
}
auto dito = DeviceIndependenceTensorOperations<DeviceContext, T>(ctx);
auto dito =
DeviceIndependenceTensorOperations<platform::CPUDeviceContext, T>(ctx);
Tensor input_tensor;
TensorCopy(input, ctx.GetPlace(), &input_tensor);
if (!is_lower) {
@@ -157,9 +160,6 @@ struct MatrixEighFunctorCPU {
ComputeFloatEigenvaluesAndVectors<ValueType>(
x_data, value_data, vector_data, batch_size, rows, rows, has_vectors);
}
if (has_vectors) {
*eigen_vectors = dito.Transpose(*eigen_vectors);
}
}
};

@@ -169,7 +169,7 @@ struct MatrixEighFunctorCPU {
// symmetric matrices on GPU, and uses the variable has_vectors
// to control whether to return the eigenvectors.
template <typename ValueType, typename T>
struct MatrixEighFunctor {
struct MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T> {
public:
void operator()(const framework::ExecutionContext &ctx, const Tensor &input,
Tensor *eigen_values, Tensor *eigen_vectors, bool is_lower,
@@ -278,7 +278,8 @@ struct MatrixEighFunctor {

#define EVDBUFFER_INSTANCE(ValueType, T, C, CastType) \
template <> \
inline void MatrixEighFunctor<ValueType, T>::EvdBuffer( \
inline void \
MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::EvdBuffer( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, const T *A, int lda, const ValueType *W, \
int *lwork) const { \
@@ -292,7 +293,8 @@ FUNC_WITH_TYPES(EVDBUFFER_INSTANCE);

#define EVD_INSTANCE(ValueType, T, C, CastType) \
template <> \
inline void MatrixEighFunctor<ValueType, T>::Evd( \
inline void \
MatrixEighFunctor<platform::CUDADeviceContext, ValueType, T>::Evd( \
cusolverDnHandle_t handle, cusolverEigMode_t jobz, \
cublasFillMode_t uplo, int n, T *A, int lda, ValueType *W, T *work, \
int lwork, int *devInfo) const { \
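Note: the EvdBuffer/Evd pair instantiated by these macros follows cuSOLVER's usual two-step workflow: query the workspace size, then run the decomposition in place. A minimal host-side sketch for one double-precision matrix, assuming plain cuSOLVER calls (error handling omitted; the real functor goes through Paddle's dynload wrappers, batches over matrices, and also covers float and complex types):

```cpp
#include <cuda_runtime.h>
#include <cusolverDn.h>
#include <vector>

int main() {
  const int n = 3, lda = 3;
  // Symmetric 3x3 matrix, column-major.
  std::vector<double> h_A = {4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0};

  double *d_A = nullptr, *d_W = nullptr, *d_work = nullptr;
  int *d_info = nullptr, lwork = 0;
  cudaMalloc(&d_A, sizeof(double) * n * n);
  cudaMalloc(&d_W, sizeof(double) * n);
  cudaMalloc(&d_info, sizeof(int));
  cudaMemcpy(d_A, h_A.data(), sizeof(double) * n * n, cudaMemcpyHostToDevice);

  cusolverDnHandle_t handle;
  cusolverDnCreate(&handle);

  // Step 1: workspace query (what EvdBuffer wraps).
  cusolverDnDsyevd_bufferSize(handle, CUSOLVER_EIG_MODE_VECTOR,
                              CUBLAS_FILL_MODE_LOWER, n, d_A, lda, d_W, &lwork);
  cudaMalloc(&d_work, sizeof(double) * lwork);

  // Step 2: decomposition (what Evd wraps). Eigenvectors overwrite d_A,
  // eigenvalues land in d_W in ascending order.
  cusolverDnDsyevd(handle, CUSOLVER_EIG_MODE_VECTOR, CUBLAS_FILL_MODE_LOWER, n,
                   d_A, lda, d_W, d_work, lwork, d_info);
  cudaDeviceSynchronize();

  cusolverDnDestroy(handle);
  cudaFree(d_A); cudaFree(d_W); cudaFree(d_work); cudaFree(d_info);
  return 0;
}
```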
68 changes: 24 additions & 44 deletions paddle/fluid/operators/svd_helper.h
@@ -289,10 +289,20 @@ struct DeviceIndependenceTensorOperations {
framework::Tensor Div(const framework::Tensor& x,
const framework::Tensor& y) {
framework::Tensor ret;
std::vector<int> out_shape = GetBroadcastShape({&x, &y});
ret.Resize(framework::make_ddim(out_shape));
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
context, &x, &y, -1, DivFunctor<T>(), &ret);
if (x.type() != y.type()) {
[Inline review comment, Contributor]: Aren't the x/y types still restricted here?

ret.mutable_data<T>(x.dims(), context.GetPlace());
auto x_vector = EigenVector<T>::Flatten(x);
auto y_vector = EigenVector<ValueType>::Flatten(y);
auto out_vector = EigenVector<T>::Flatten(ret);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_vector.device(place) = x_vector / y_vector;
} else {
std::vector<int> out_shape = GetBroadcastShape({&x, &y});
ret.Resize(framework::make_ddim(out_shape));
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
context, &x, &y, -1, DivFunctor<T>(), &ret);
}
return ret;
}
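Note: the new branch handles the case where x (element type T, possibly complex) is divided element-wise by y of a different element type (ValueType, real), which is exactly what the eigh backward needs when dividing a complex intermediate by the real eigenvalue-difference matrix. A standalone sketch of the mixed complex/real coefficient-wise division with dense Eigen arrays, assuming Eigen 3.3+ mixed-scalar support (the kernel itself goes through Paddle's EigenVector maps and .device(place)):

```cpp
#include <Eigen/Dense>
#include <complex>
#include <iostream>

// Element-wise complex / real division, relying on Eigen's mixed-scalar
// binary ops (ScalarBinaryOpTraits for complex<double> and double).
int main() {
  Eigen::ArrayXcd x(3);
  x << std::complex<double>(1, 2), std::complex<double>(3, -1),
      std::complex<double>(0, 4);
  Eigen::ArrayXd y(3);
  y << 2.0, 4.0, 8.0;
  Eigen::ArrayXcd out = x / y;  // (0.5,1), (0.75,-0.25), (0,0.5)
  std::cout << out << "\n";
  return 0;
}
```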
framework::Tensor Add(const framework::Tensor& x,
@@ -330,7 +340,8 @@ struct DeviceIndependenceTensorOperations {
NameInTensorMap inputs({{"X", {&x}}});
return CreateOpRunAndReturnTensor("reduce_max", inputs, attrs, out_dim);
}

// Supports subtraction for float and complex types; the default element type is T
template <typename InT = T>
framework::Tensor Sub(const framework::Tensor& x,
const framework::Tensor& y) {
framework::Tensor ret;
@@ -340,18 +351,18 @@ struct DeviceIndependenceTensorOperations {
#if defined(__NVCC__) || defined(__HIPCC__)
// For GPU, there is no need to define XxxInverseFunctor and call
// ElementwiseComputeEx in two branches.
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
context, &x, &y, -1, SubFunctor<T>(), &ret);
ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
context, &x, &y, -1, SubFunctor<InT>(), &ret);
#endif
} else {
if (x.dims().size() >= y.dims().size()) {
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
context, &x, &y, -1, SubFunctor<T>(), &ret);
ElementwiseComputeEx<SubFunctor<InT>, DeviceContext, InT>(
context, &x, &y, -1, SubFunctor<InT>(), &ret);
} else {
ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
// This is copyed from elementwise_sub, which means we
// need reverse will xrank < yrank
context, &x, &y, -1, InverseSubFunctor<T>(), &ret);
// This is copied from elementwise_sub; it means we need to
// reverse the operands when x's rank < y's rank
ElementwiseComputeEx<InverseSubFunctor<InT>, DeviceContext, InT>(
context, &x, &y, -1, InverseSubFunctor<InT>(), &ret);
}
}
return ret;
@@ -461,37 +472,6 @@ struct DeviceIndependenceTensorOperations {
return out;
}

// Support x and y are different data types
Tensor Div_(const Tensor& x, const Tensor& y) {
Tensor out;
out.mutable_data<T>(x.dims(), context.GetPlace());
auto x_vector = EigenVector<T>::Flatten(x);
auto y_vector = EigenVector<ValueType>::Flatten(y);
auto out_vector = EigenVector<T>::Flatten(out);
auto& place =
*context.template device_context<DeviceContext>().eigen_device();
out_vector.device(place) = x_vector / y_vector;
return out;
}

framework::Tensor Sub_(const framework::Tensor& x,
const framework::Tensor& y) {
framework::Tensor ret;
std::vector<int> out_shape = GetBroadcastShape({&x, &y});
ret.Resize(framework::make_ddim(out_shape));
if (x.dims().size() >= y.dims().size()) {
ElementwiseComputeEx<SubFunctor<ValueType>, DeviceContext, ValueType>(
context, &x, &y, -1, SubFunctor<ValueType>(), &ret);
} else {
ElementwiseComputeEx<InverseSubFunctor<ValueType>, DeviceContext,
ValueType>(
// This is copyed from elementwise_sub, which means we
// need reverse will xrank < yrank
context, &x, &y, -1, InverseSubFunctor<ValueType>(), &ret);
}
return ret;
}

private:
const framework::ExecutionContext& context;
BlasT<DeviceContext, T> GetBlas() {
4 changes: 2 additions & 2 deletions python/paddle/fluid/tests/unittests/test_eigh_op.py
@@ -140,7 +140,7 @@ def test_in_static_mode(self):
self.check_static_complex_result()

def test_in_dynamic_mode(self):
paddle.disable_static(self.place)
paddle.disable_static()
input_real_data = paddle.to_tensor(self.real_data)
expected_w, expected_v = np.linalg.eigh(self.real_data)
actual_w, actual_v = paddle.linalg.eigh(input_real_data)
@@ -152,7 +152,7 @@ def test_in_dynamic_mode(self):
self.compare_result(actual_w, actual_v.numpy(), expected_w, expected_v)

def test_eigh_grad(self):
paddle.disable_static(self.place)
paddle.disable_static()
x = paddle.to_tensor(self.complex_data, stop_gradient=False)
w, v = paddle.linalg.eigh(x)
(w.sum() + paddle.abs(v).sum()).backward()