
[Hackathon No.28] implement logcumsumexp #42267

Merged: 27 commits, Jun 10, 2022
Changes from 6 commits
Commits (all by tiancaishaonvjituizi):
5c3b6bb  implement logcumsumexp (Apr 26, 2022)
4054f7c  polish (Apr 26, 2022)
b8ade29  Merge remote-tracking branch 'origin/develop' into logcumsumexp (Apr 30, 2022)
1f98cc7  fix ci (May 3, 2022)
518c75a  reformat (May 3, 2022)
e94f42c  update (May 3, 2022)
8c680e6  address reviews (May 9, 2022)
442bc00  Merge remote-tracking branch 'origin/develop' into logcumsumexp (May 9, 2022)
a3e50da  add OpTest (May 9, 2022)
0b4b8ca  use user defined grad (May 13, 2022)
3bf4cfe  add formula in docs, address reviews (May 13, 2022)
34f57f1  remove 'reference' comment (May 14, 2022)
661bff3  Update logcumsumexp_grad_kernel.h (May 14, 2022)
30241bb  Update logcumsumexp_sig.cc (May 14, 2022)
2454012  Update logcumsumexp_grad_impl.h (May 14, 2022)
1734440  decrease input size, update python (May 16, 2022)
d6a773e  Merge branch 'logcumsumexp' of github.com:tiancaishaonvjituizi/Paddle… (May 16, 2022)
3e4953a  shrink test data size (May 17, 2022)
790e616  fix sample code (May 17, 2022)
9797d10  refine docs (May 18, 2022)
3b4b8fe  update docs (May 25, 2022)
57bb711  fix docs;test=document_fix (May 27, 2022)
4601d17  Merge remote-tracking branch 'origin/develop' into logcumsumexp (May 27, 2022)
250998c  Merge remote-tracking branch 'origin/develop' into logcumsumexp (Jun 7, 2022)
6a3647c  set test timeout to 30s (Jun 7, 2022)
d6c7aa7  Merge remote-tracking branch 'origin/develop' into logcumsumexp (Jun 8, 2022)
13edc4f  reformat (Jun 9, 2022)
@@ -74,17 +74,85 @@ class CumsumGradMaker : public framework::SingleGradOpMaker<T> {
}
};

class LogcumsumexpOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input of logcumsumexp operator");
AddOutput("Out", "Output of logcumsumexp operator");
AddAttr<int>("axis",
"The dimension to accumulate along. -1 means the last "
"dimension [default -1].")
.SetDefault(-1);
AddAttr<bool>("flatten",
"Whether to compute the logcumsumexp over the flattened array. "
"[default false].")
.SetDefault(false);
AddAttr<bool>("exclusive",
"Whether to perform exclusive logcumsumexp. [default false].")
.SetDefault(false);
AddAttr<bool>("reverse",
"If true, the logcumsumexp is performed in the reversed direction. "
"[default false].")
.SetDefault(false);
AddComment(R"DOC(
Returns the logarithm of the cumulative sum of exponentials of the input elements along the given axis.
By default, the first element of the result is the same as the first element of
the input. If exclusive is true, the first element of the result is the minimum value of the dtype.
)DOC");
}
};
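
For reference, the semantics described in the comment above can be written as a scan along the chosen axis:

    y_i = \log \sum_{j \le i} \exp(x_j)        (inclusive, the default)
    y_i = \log \sum_{j <   i} \exp(x_j)        (exclusive = true)

With reverse = true the same sums run over the suffix (j \ge i, respectively j > i), and with flatten = true the scan is taken over the flattened input.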

class LogcumsumexpGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "logcumsumexp");
OP_INOUT_CHECK(ctx->HasInput("Out"), "Input", "Out", "logcumsumexp");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "logcumsumexp");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

template <typename T>
class LogcumsumexpGradMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

protected:
void Apply(GradOpPtr<T> grad_op) const override {
grad_op->SetType("logcumsumexp_grad");
grad_op->SetInput("X", this->OutputGrad("Out"));
grad_op->SetOutput("Out", this->InputGrad("X"));
grad_op->SetAttr("axis", BOOST_GET_CONST(int, this->GetAttr("axis")));
grad_op->SetAttr("flatten",
BOOST_GET_CONST(bool, this->GetAttr("flatten")));
grad_op->SetAttr("reverse",
BOOST_GET_CONST(bool, this->GetAttr("reverse")));
grad_op->SetAttr("exclusive",
BOOST_GET_CONST(bool, this->GetAttr("exclusive")));
}
};
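
The grad maker above only wires up the inputs, outputs and attributes of logcumsumexp_grad. For reference, a sketch of the standard gradient identity for the inclusive, non-reversed scan y_i = \log \sum_{k \le i} \exp(x_k) (stated here as background, not as the exact kernel formulation) is

    \frac{\partial L}{\partial x_j} = \sum_{i \ge j} \frac{\partial L}{\partial y_i} \, \exp(x_j - y_i)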

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
DECLARE_INFER_SHAPE_FUNCTOR(cumsum, CumsumInferShapeFunctor,
PD_INFER_META(phi::CumsumInferMeta));
PD_INFER_META(phi::CumInferMeta));
DECLARE_INFER_SHAPE_FUNCTOR(logcumsumexp, LogcumsumexpInferShapeFunctor,
PD_INFER_META(phi::CumInferMeta));
REGISTER_OPERATOR(cumsum, ops::CumOp, ops::CumsumOpMaker,
ops::CumsumGradMaker<paddle::framework::OpDesc>,
ops::CumsumGradMaker<paddle::imperative::OpBase>,
CumsumInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp, ops::CumOp, ops::LogcumsumexpOpMaker,
ops::LogcumsumexpGradMaker<paddle::framework::OpDesc>,
ops::LogcumsumexpGradMaker<paddle::imperative::OpBase>,
LogcumsumexpInferShapeFunctor);
REGISTER_OPERATOR(logcumsumexp_grad, ops::LogcumsumexpGradOp);

REGISTER_OP_VERSION(cumsum)
.AddCheckpoint(
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.cc
@@ -235,12 +235,12 @@ void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out) {
out->set_layout(x.layout());
}

void CumsumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out) {
tiancaishaonvjituizi (Contributor Author) commented: Cumsum and Logcumsumexp reuse the same infer meta function.

void CumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out) {
auto x_dims = x.dims();
if (flatten) {
out->set_dims(phi::make_ddim({phi::product(x_dims)}));
12 changes: 6 additions & 6 deletions paddle/phi/infermeta/unary.h
@@ -60,12 +60,12 @@ void CholeskyInferMeta(const MetaTensor& x, bool upper, MetaTensor* out);

void CreateLikeInferMeta(const MetaTensor& x, DataType dtype, MetaTensor* out);

void CumsumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out);
void CumInferMeta(const MetaTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
MetaTensor* out);

void DiagInferMeta(const MetaTensor& x,
int offset,
273 changes: 273 additions & 0 deletions paddle/phi/kernels/cpu/cum_kernel.cc
@@ -0,0 +1,273 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/cum_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/eigen/eigen_function.h"

namespace phi {

template <typename Device,
typename Dim,
typename X,
typename Out,
typename Reducer>
void ComputeImp(Device d,
const Dim& dims,
X x,
Out out,
int axis,
bool reverse,
bool exclusive,
Reducer reducer) {
if (!reverse) {
out.reshape(dims).device(d) =
x.reshape(dims).scan(axis, reducer, exclusive);
} else {
std::array<bool, Dim::count> rev;
rev.fill(false);
rev[axis] = reverse;
out.reshape(dims).device(d) = x.reshape(dims)
.reverse(rev)
.scan(axis, reducer, exclusive)
.reverse(rev);
}
}

template <typename T, typename Context, typename Reducer>
void ScanKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
Reducer reducer,
DenseTensor* out) {
auto out_dims = out->dims();

PADDLE_ENFORCE_EQ(
axis < out_dims.size() && axis >= (0 - out_dims.size()),
true,
phi::errors::OutOfRange(
"Attr(axis) is out of range, It's expected "
"to be in range of [-%d, %d]. But received Attr(axis) = %d.",
out_dims.size(),
out_dims.size() - 1,
axis));
if (axis < 0) {
axis += out_dims.size();
}

dev_ctx.template Alloc<T>(out);

int pre = 1;
int post = 1;
int mid = out_dims[axis];
for (int i = 0; i < axis; ++i) {
pre *= out_dims[i];
}
for (int i = axis + 1; i < out_dims.size(); ++i) {
post *= out_dims[i];
}

auto x0 = EigenVector<T>::Flatten(x);
auto out0 = EigenVector<T>::Flatten(*out);
auto& place = *dev_ctx.eigen_device();

using IndexT = Eigen::DenseIndex;
if (pre == 1) {
if (post == 1) {
ComputeImp(place,
Eigen::DSizes<IndexT, 1>(mid),
x0,
out0,
/* axis= */ 0,
reverse,
exclusive,
reducer);
} else {
ComputeImp(place,
Eigen::DSizes<IndexT, 2>(mid, post),
x0,
out0,
/* axis= */ 0,
reverse,
exclusive,
reducer);
}
} else {
if (post == 1) {
ComputeImp(place,
Eigen::DSizes<IndexT, 2>(pre, mid),
x0,
out0,
/* axis= */ 1,
reverse,
exclusive,
reducer);
} else {
ComputeImp(place,
Eigen::DSizes<IndexT, 3>(pre, mid, post),
x0,
out0,
/* axis= */ 1,
reverse,
exclusive,
reducer);
}
}
}
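
The (pre, mid, post) collapse above folds every dimension before the scan axis into pre and every dimension after it into post, so the Eigen scan only ever sees a 1-D, 2-D or 3-D view. A minimal standalone sketch of that bookkeeping (the helper name and the example shape are illustrative, not taken from the kernel):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative helper mirroring the pre/mid/post computation in ScanKernel.
struct PreMidPost { int64_t pre, mid, post; };

PreMidPost CollapseForScan(const std::vector<int64_t>& dims, int axis) {
  if (axis < 0) axis += static_cast<int>(dims.size());  // same axis normalization as above
  PreMidPost r{1, dims[axis], 1};
  for (int i = 0; i < axis; ++i) r.pre *= dims[i];
  for (int i = axis + 1; i < static_cast<int>(dims.size()); ++i) r.post *= dims[i];
  return r;
}

int main() {
  // A tensor of shape (2, 3, 4) scanned along axis 1 is viewed as pre=2, mid=3, post=4.
  PreMidPost r = CollapseForScan({2, 3, 4}, 1);
  std::cout << r.pre << " " << r.mid << " " << r.post << "\n";
  return 0;
}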

template <typename T, typename Context>
void CumsumKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Reducer = Eigen::internal::SumReducer<T>;
auto reducer = Reducer();
ScanKernel<T, Context, Reducer>(
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}
tiancaishaonvjituizi (Contributor Author) commented on Apr 26, 2022: everything in this file above this line was moved over from cumsum_kernel.cc, with the Reducer parameter added.


// Copied from
// https://github.com/tensorflow/tensorflow/blob/master/tensorflow/core/kernels/scan_ops.h
template <typename T>
struct LogSumExp {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
const T& b) const {
auto mi = Eigen::internal::scalar_min_op<T>()(a, b);
auto ma = Eigen::internal::scalar_max_op<T>()(a, b);

auto sub = Eigen::internal::scalar_difference_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto log1p = Eigen::internal::scalar_log1p_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();

auto logsumexp = add(log1p(exp(sub(mi, ma))), ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? ma : logsumexp;
}
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T packetOp(const T& a,
const T& b) const {
auto mi = Eigen::internal::pmin(a, b);
auto ma = Eigen::internal::pmax(a, b);
using Eigen::internal::padd;
using Eigen::internal::pcmp_lt;
using Eigen::internal::pexp;
using Eigen::internal::plog1p;
using Eigen::internal::pset1;
using Eigen::internal::psub;

auto logsumexp = padd(plog1p(pexp(psub(mi, ma))), ma);
return pselect(
pcmp_lt(ma, pset1(Eigen::NumTraits<T>::lowest())), ma, logsumexp);
}
};
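
Stripped of the Eigen functor and packet machinery, the combine above is the usual numerically stable pairwise identity log(exp(a) + exp(b)) = max(a, b) + log1p(exp(min(a, b) - max(a, b))), with a guard for the case where both operands sit at the bottom of the range. A minimal plain-C++ sketch of the same scalar math (not the packet path):

#include <algorithm>
#include <cmath>
#include <iostream>
#include <limits>

// Stable log(exp(a) + exp(b)); returns the max directly when it is -infinity,
// analogous to the NumTraits<T>::lowest() guard in the functor above.
double LogSumExpPair(double a, double b) {
  const double ma = std::max(a, b);
  const double mi = std::min(a, b);
  if (ma == -std::numeric_limits<double>::infinity()) return ma;
  return ma + std::log1p(std::exp(mi - ma));
}

int main() {
  // A naive log(exp(1000) + exp(1000)) overflows; the stable form gives ~1000.6931.
  std::cout << LogSumExpPair(1000.0, 1000.0) << "\n";
  return 0;
}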

template <typename T>
struct LogSumExpReducer {
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reduce(const T t, T* accum) const {
LogSumExp<T> logsumexp;
*accum = logsumexp(*accum, t);
}

template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE void reducePacket(const Packet& p,
Packet* accum) const {
LogSumExp<T> logsumexp;
*accum = logsumexp.packetOp(*accum, p);
}

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T initialize() const {
return -Eigen::NumTraits<T>::infinity();
}

template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet initializePacket() const {
return Eigen::internal::pset1(initialize());
}

EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T finalize(const T accum) const {
return accum;
}

template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet
finalizePacket(const Packet& vaccum) const {
return vaccum;
}

template <typename Packet>
EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T
finalizeBoth(const T saccum, const Packet& vaccum) const {
auto max_reducer = Eigen::internal::MaxReducer<T, Eigen::PropagateNaN>();
auto sum_reducer = Eigen::internal::SumReducer<T>();
auto exp = Eigen::internal::scalar_exp_op<T>();
auto cmp_lt =
Eigen::internal::scalar_cmp_op<T, T, Eigen::internal::cmp_LT>();
auto log = Eigen::internal::scalar_log_op<T>();
auto add = Eigen::internal::scalar_sum_op<T>();

using Eigen::internal::pexp;
using Eigen::internal::psub;

// `ma = max(x1, ..., xn)`
// If the max of all of the `xi` is `-infinity` then the result is
// -infinity. If the max is larger than `-infinity` then it's safe to use
// for normalization even if the other elements are `-infinity`.
//
// `logsumexp(x1, ..., xn) = ma + log (exp(x1 - ma) + ... + exp(xn - ma))`
auto ma = max_reducer.finalizeBoth(saccum, vaccum);
auto logsumexp = add(log(sum_reducer.finalizeBoth(
exp(saccum - ma), pexp(psub(vaccum, pset1(ma))))),
ma);
return cmp_lt(ma, Eigen::NumTraits<T>::lowest()) ? initialize() : logsumexp;
}
};

template <typename T, typename Context>
void LogcumsumexpKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
bool flatten,
bool exclusive,
bool reverse,
DenseTensor* out) {
using Reducer = LogSumExpReducer<T>;
auto reducer = Reducer();
ScanKernel<T, Context, Reducer>(
dev_ctx, x, axis, flatten, exclusive, reverse, reducer, out);
}

} // namespace phi

PD_REGISTER_KERNEL(cumsum,
CPU,
ALL_LAYOUT,
phi::CumsumKernel,
float,
double,
int16_t,
int,
int64_t) {}

PD_REGISTER_KERNEL(
logcumsumexp, CPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}
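
As a quick sanity check of the semantics registered here, the 1-D case can be reproduced outside the framework with a running stable accumulator and compared against the naive log of a cumulative sum of exponentials for small inputs. A self-contained sketch, independent of the Paddle code above:

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

// Reference 1-D logcumsumexp: y[i] = log(exp(x[0]) + ... + exp(x[i])),
// accumulated with the stable pairwise update y[i] = logaddexp(y[i-1], x[i]).
std::vector<double> LogCumSumExp(const std::vector<double>& x) {
  std::vector<double> y(x.size());
  double acc = -INFINITY;  // same starting point as the reducer's initialize()
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double ma = std::max(acc, x[i]);
    acc = (ma == -INFINITY) ? ma
                            : ma + std::log1p(std::exp(std::min(acc, x[i]) - ma));
    y[i] = acc;
  }
  return y;
}

int main() {
  const std::vector<double> x = {0.5, -1.0, 2.0};
  const std::vector<double> y = LogCumSumExp(x);
  double naive_sum = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    naive_sum += std::exp(x[i]);
    // For inputs this small the two agree to machine precision.
    std::cout << y[i] << " vs " << std::log(naive_sum) << "\n";
  }
  return 0;
}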