Skip to content

Commit

Permalink
roll_op: support Tensor as input for shifts (#36727)
Browse files Browse the repository at this point in the history
  • Loading branch information
Feiyu Chan authored Oct 26, 2021
1 parent 236ed94 commit 7b1e30f
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 22 deletions.
39 changes: 24 additions & 15 deletions paddle/fluid/operators/roll_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,23 @@ class RollOp : public framework::OperatorWithKernel {
auto dims = ctx->Attrs().Get<std::vector<int64_t>>("axis");
auto shifts = ctx->Attrs().Get<std::vector<int64_t>>("shifts");

if (dims.size() != 0) {
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument(
"When dims.size() != 0, dims.size() "
"should be equal to "
"shifts.size(). But received "
"dims.size() = %d, shifts.size() = %d",
dims.size(), shifts.size()));
} else {
PADDLE_ENFORCE_EQ(shifts.size(), 1,
platform::errors::InvalidArgument(
"When dims.size() == 0, shifts.size() "
"should be equal to 1, But received "
"shifts.size() = %d",
shifts.size()));
if (!ctx->HasInput("ShiftsTensor")) {
if (dims.size() != 0) {
PADDLE_ENFORCE_EQ(dims.size(), shifts.size(),
platform::errors::InvalidArgument(
"When dims.size() != 0, dims.size() "
"should be equal to "
"shifts.size(). But received "
"dims.size() = %d, shifts.size() = %d",
dims.size(), shifts.size()));
} else {
PADDLE_ENFORCE_EQ(shifts.size(), 1,
platform::errors::InvalidArgument(
"When dims.size() == 0, shifts.size() "
"should be equal to 1, But received "
"shifts.size() = %d",
shifts.size()));
}
}

ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
Expand Down Expand Up @@ -105,6 +107,10 @@ class RollOpMaker : public framework::OpProtoAndCheckerMaker {
"The number of places by which the elements "
"of the tensor are shifted.")
.SetDefault({});
AddInput("ShiftsTensor",
"The number of places by which the elements of the tensor "
"are shifted.")
.AsDispensable();
AddAttr<std::vector<int64_t>>(
"axis",
"Axis along which to roll. It must have the same size "
Expand All @@ -129,6 +135,9 @@ class RollGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType("roll_grad");
op->SetInput("X", this->Input("X"));
if (this->HasInput("ShiftsTensor")) {
op->SetInput("ShiftsTensor", this->Input("ShiftsTensor"));
}
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
Expand Down
20 changes: 20 additions & 0 deletions paddle/fluid/operators/roll_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,16 @@ class RollKernel<platform::CUDADeviceContext, T>
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

auto* in_data = in->data<T>();
Expand Down Expand Up @@ -134,6 +144,16 @@ class RollGradKernel<platform::CUDADeviceContext, T>
auto* in = context.Input<LoDTensor>(framework::GradVarName("Out"));
auto* out = context.Output<LoDTensor>(framework::GradVarName("X"));
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

auto* in_data = in->data<T>();
Expand Down
17 changes: 17 additions & 0 deletions paddle/fluid/operators/roll_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
Expand Down Expand Up @@ -85,6 +87,16 @@ class RollKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
PADDLE_ENFORCE_EQ(
shifts_tensor->dims().size(), 1,
platform::errors::InvalidArgument(
"The rank of ShiftsTensor is expected to be 1, got %s",
shifts_tensor->dims().size()));
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

std::vector<T> out_vec;
Expand Down Expand Up @@ -123,6 +135,11 @@ class RollGradKernel : public framework::OpKernel<T> {
auto& input = input_var->Get<LoDTensor>();
auto* output = output_var->GetMutable<LoDTensor>();
std::vector<int64_t> shifts = context.Attr<std::vector<int64_t>>("shifts");
if (context.HasInput("ShiftsTensor")) {
const auto* shifts_tensor =
context.Input<framework::Tensor>("ShiftsTensor");
shifts = GetDataFromTensor<int64_t>(shifts_tensor);
}
std::vector<int64_t> dims = context.Attr<std::vector<int64_t>>("axis");

std::vector<T> out_vec;
Expand Down
28 changes: 28 additions & 0 deletions python/paddle/fluid/tests/unittests/test_roll_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,34 @@ def test_axis_out_range():

self.assertRaises(ValueError, test_axis_out_range)

def test_shifts_as_tensor_dygraph(self):
with fluid.dygraph.guard():
x = paddle.arange(9).reshape([3, 3])
shape = paddle.shape(x)
shifts = shape // 2
axes = [0, 1]
out = paddle.roll(x, shifts=shifts, axis=axes).numpy()
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])
self.assertTrue(np.allclose(out, expected_out))

def test_shifts_as_tensor_static(self):
with program_guard(Program(), Program()):
x = paddle.arange(9).reshape([3, 3]).astype('float32')
shape = paddle.shape(x)
shifts = shape // 2
axes = [0, 1]
out = paddle.roll(x, shifts=shifts, axis=axes)
expected_out = np.array([[8, 6, 7], [2, 0, 1], [5, 3, 4]])

exe = fluid.Executor(fluid.CPUPlace())
[out_np] = exe.run(fetch_list=[out])
self.assertTrue(np.allclose(out_np, expected_out))

if paddle.is_compiled_with_cuda():
exe = fluid.Executor(fluid.CPUPlace())
[out_np] = exe.run(fetch_list=[out])
self.assertTrue(np.allclose(out_np, expected_out))


if __name__ == "__main__":
unittest.main()
23 changes: 16 additions & 7 deletions python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,15 +696,24 @@ def roll(x, shifts, axis=None, name=None):

helper = LayerHelper("roll", **locals())
check_type(axis, 'axis', (list, tuple), 'roll')
check_type(shifts, 'shifts', (list, tuple), 'roll')

out = helper.create_variable_for_type_inference(x.dtype)

helper.append_op(
type='roll',
inputs={'X': x},
outputs={'Out': out},
attrs={'axis': axis,
'shifts': shifts})
if isinstance(shifts, Variable):
helper.append_op(
type='roll',
inputs={'X': x,
"ShiftsTensor": shifts},
outputs={'Out': out},
attrs={'axis': axis})
else:
check_type(shifts, 'shifts', (list, tuple), 'roll')
helper.append_op(
type='roll',
inputs={'X': x},
outputs={'Out': out},
attrs={'axis': axis,
'shifts': shifts})
return out


Expand Down

0 comments on commit 7b1e30f

Please sign in to comment.