Numpy flip operator
* Implement flip

* fix some bugs and add GPU test

* register param and edit test

* add testcase for backward

* remove print

* optimize 0-dim and 0-shape

* adjust format and add doc in _symbol.py

* fix bug in symbol

* add flip in __all__

* fix format error

* import ndarray

* move flip implementation to np_matrix_op and remove test in gpu

* delete redundant blank line
Ying committed Sep 10, 2019
1 parent 9675a2d commit 554517b
Showing 7 changed files with 386 additions and 3 deletions.
70 changes: 69 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
@@ -33,7 +33,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
-    'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
+    'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'flip']


@set_module('mxnet.ndarray.numpy')
@@ -2363,3 +2363,71 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):  # pylint:
0.2025
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.ndarray.numpy')
def flip(x, axis=None, out=None, **kwargs):
r"""
flip(x, axis=None, out=None)
Reverse the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered.
Parameters
----------
m : ndarray or scalar
Input array.
axis : None or int or tuple of ints, optional
Axis or axes along which to flip over. The default,
axis=None, will flip over all of the axes of the input array.
If axis is negative it counts from the last to the first axis.
If axis is a tuple of ints, flipping is performed on all of the axes
specified in the tuple.
out : ndarray or scalar, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.
Returns
-------
out : ndarray or scalar
A view of `m` with the entries of axis reversed. Since a view is
returned, this operation is done in constant time.
Examples
--------
>>> A = np.arange(8).reshape((2,2,2))
>>> A
array([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> np.flip(A, 0)
array([[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]]])
>>> np.flip(A, 1)
array([[[2, 3],
[0, 1]],
[[6, 7],
[4, 5]]])
>>> np.flip(A)
array([[[7, 6],
[5, 4]],
[[3, 2],
[1, 0]]])
>>> np.flip(A, (0, 2))
array([[[5, 4],
[7, 6]],
[[1, 0],
[3, 2]]])
"""
from ...numpy import ndarray
if isinstance(x, numeric_types):
return _np.flip(x, axis, **kwargs)
elif isinstance(x, ndarray):
return _npi.flip(x, axis, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))
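
The dispatch above sends plain Python scalars to the official NumPy flip and
ndarrays to the _npi.flip backend operator. A quick smoke test of both paths,
as a sketch assuming a build that contains this commit:

    from mxnet import np, npx
    npx.set_np()

    print(np.flip(3.0))        # scalar path: delegated to official NumPy
    a = np.arange(4).reshape(2, 2)
    print(np.flip(a, axis=0))  # ndarray path: dispatched to _npi.flip
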
64 changes: 63 additions & 1 deletion python/mxnet/numpy/multiarray.py
Expand Up @@ -52,7 +52,7 @@
'degrees', 'log2', 'log1p', 'rint', 'radians', 'reciprocal', 'square', 'negative',
'fix', 'ceil', 'floor', 'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh',
'tensordot', 'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate',
-    'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
+    'stack', 'mean', 'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'flip']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
@@ -3808,3 +3808,65 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=None):
0.2025
"""
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.numpy')
def flip(x, axis=None, out=None, **kwargs):
r"""
flip(x, axis=None, out=None)
Reverse the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered.
Parameters
----------
m : ndarray or scalar
Input array.
axis : None or int or tuple of ints, optional
Axis or axes along which to flip over. The default,
axis=None, will flip over all of the axes of the input array.
If axis is negative it counts from the last to the first axis.
If axis is a tuple of ints, flipping is performed on all of the axes
specified in the tuple.
out : ndarray or scalar, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.
Returns
-------
out : ndarray or scalar
A view of `m` with the entries of axis reversed. Since a view is
returned, this operation is done in constant time.
Examples
--------
>>> A = np.arange(8).reshape((2,2,2))
>>> A
array([[[0, 1],
[2, 3]],
[[4, 5],
[6, 7]]])
>>> np.flip(A, 0)
array([[[4, 5],
[6, 7]],
[[0, 1],
[2, 3]]])
>>> np.flip(A, 1)
array([[[2, 3],
[0, 1]],
[[6, 7],
[4, 5]]])
>>> np.flip(A)
array([[[7, 6],
[5, 4]],
[[3, 2],
[1, 0]]])
>>> np.flip(A, (0, 2))
array([[[5, 4],
[7, 6]],
[[1, 0],
[3, 2]]])
"""
return _mx_nd_np.flip(x, axis, out=out, **kwargs)
40 changes: 39 additions & 1 deletion python/mxnet/symbol/numpy/_symbol.py
Expand Up @@ -35,7 +35,7 @@
'rint', 'radians', 'reciprocal', 'square', 'negative', 'fix', 'ceil', 'floor',
'trunc', 'logical_not', 'arcsinh', 'arccosh', 'arctanh', 'tensordot',
'linspace', 'expand_dims', 'tile', 'arange', 'split', 'concatenate', 'stack', 'mean',
-    'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var']
+    'maximum', 'minimum', 'swapaxes', 'clip', 'argmax', 'std', 'var', 'flip']


def _num_outputs(sym):
@@ -2678,4 +2678,42 @@ def var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False):  # pylint:
return _npi.var(a, axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims, out=out)


@set_module('mxnet.symbol.numpy')
def flip(x, axis=None, out=None, **kwargs):
r"""
flip(x, axis=None, out=None)
Reverse the order of elements in an array along the given axis.
The shape of the array is preserved, but the elements are reordered.
Parameters
----------
m : _Symbol or scalar
Input array.
axis : None or int or tuple of ints, optional
Axis or axes along which to flip over. The default,
axis=None, will flip over all of the axes of the input array.
If axis is negative it counts from the last to the first axis.
If axis is a tuple of ints, flipping is performed on all of the axes
specified in the tuple.
out : _Symbol or scalar, optional
Alternative output array in which to place the result. It must have
the same shape and type as the expected output.
Returns
-------
out : _Symbol or scalar
A view of `m` with the entries of axis reversed. Since a view is
returned, this operation is done in constant time.
"""
if isinstance(x, numeric_types):
return _np.flip(x, axis, **kwargs)
elif isinstance(x, _Symbol):
return _npi.flip(x, axis, out=out, **kwargs)
else:
raise TypeError('type {} not supported'.format(str(type(x))))
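
A _Symbol is normally produced by Gluon hybridization rather than constructed
directly, so the symbolic branch above is what runs inside a hybridized block.
A sketch of that route (the F.np.flip call pattern follows this repository's
test style and is an assumption here, not part of the diff):

    from mxnet import np, npx
    from mxnet.gluon import HybridBlock
    npx.set_np()

    class TestFlip(HybridBlock):
        def __init__(self, axis=None, **kwargs):
            super(TestFlip, self).__init__(**kwargs)
            self._axis = axis

        def hybrid_forward(self, F, x):
            # F is mxnet.ndarray in imperative mode and mxnet.symbol after
            # hybridize(), so this exercises the _Symbol path above.
            return F.np.flip(x, axis=self._axis)

    net = TestFlip(axis=0)
    net.hybridize()
    print(net(np.arange(6).reshape(2, 3)))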


_set_np_symbol_class(_Symbol)
75 changes: 75 additions & 0 deletions src/operator/numpy/np_matrix_op-inl.h
Expand Up @@ -60,6 +60,81 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
}
}

struct FlipParam : public dmlc::Parameter<FlipParam> {
  mxnet::Tuple<int> axis;
  DMLC_DECLARE_PARAMETER(FlipParam) {
    DMLC_DECLARE_FIELD(axis)
    .describe("The axis along which to flip elements.");
  }
};

struct flip0dim_shared_kernel {
  template<typename DType>
  MSHADOW_XINLINE static void Map(int i,
                                  DType* out_data,
                                  const DType* in_data) {
    out_data[i] = in_data[i];
  }
};

#define FLIP_MAX_DIM 10
#define FLIP_MIN_DIM -1

template<typename xpu>
void NumpyFlipForwardImpl(const OpContext& ctx,
                          const std::vector<TBlob>& inputs,
                          const std::vector<TBlob>& outputs,
                          const std::vector<index_t>& stride_,
                          const std::vector<index_t>& trailing_,
                          const index_t& flip_index);

template<typename xpu>
void NumpyFlipForward(const nnvm::NodeAttrs& attrs,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  const FlipParam& param = nnvm::get<FlipParam>(attrs.parsed);
  mxnet::Tuple<int> axistemp;
  CHECK_EQ(inputs[0].type_flag_, outputs[0].type_flag_);
  CHECK_LT(param.axis.ndim(), FLIP_MAX_DIM);
  CHECK_GE(param.axis.ndim(), FLIP_MIN_DIM);
  if (param.axis.ndim() == FLIP_MIN_DIM) {
    if (inputs[0].shape_.ndim() == 0) {
      mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
        mxnet_op::Kernel<flip0dim_shared_kernel, xpu>::Launch(s, inputs[0].Size(),
          outputs[0].dptr<DType>(), inputs[0].dptr<DType>());
      });
      return;
    }
    std::vector<int> temp;
    for (int i = 0; i < inputs[0].shape_.ndim(); i++) {
      temp.push_back(i);
    }
    axistemp.assign(temp.begin(), temp.end());
  } else {
    axistemp = param.axis;
  }

  const mxnet::TShape& ishape = inputs[0].shape_;
  if (ishape.ProdShape(0, ishape.ndim()) == 0) {
    return;  // zero shape
  }
  std::vector<index_t> stride_(axistemp.ndim());
  std::vector<index_t> trailing_(axistemp.ndim());
  index_t flip_index = 0;
  for (int axis : axistemp) {
    CHECK_LT(axis, ishape.ndim());
    stride_[flip_index] = ishape[axis];
    trailing_[flip_index] = 1;
    for (int i2 = axis + 1; i2 < ishape.ndim(); ++i2) {
      trailing_[flip_index] *= ishape[i2];
    }
    flip_index++;
  }
  NumpyFlipForwardImpl<xpu>(ctx, inputs, outputs, stride_, trailing_, flip_index);
}
} // namespace op
} // namespace mxnet
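
For each flipped axis k, stride_[k] records the length of that axis and
trailing_[k] the number of elements one step along it spans (the product of
all later dimensions). The reverse kernel launched by the per-device
implementations below uses these two arrays to remap every flat element
index. A minimal Python sketch of that index arithmetic (the helper name is
illustrative, not part of the codebase):

    def reverse_index(idx, strides, trailings):
        # Map a flat destination index back to its source index with the
        # given axes flipped. strides[k] is the length of the k-th flipped
        # axis; trailings[k] is the product of the dimensions after it.
        for stride, trailing in zip(strides, trailings):
            low = idx % trailing  # offset inside the trailing block
            high, pos = divmod(idx // trailing, stride)
            idx = (high * stride + stride - 1 - pos) * trailing + low
        return idx

    # Flipping axis 0 of a 2x3 array: stride = 2, trailing = 3.
    print([reverse_index(i, [2], [3]) for i in range(6)])  # [3, 4, 5, 0, 1, 2]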

46 changes: 46 additions & 0 deletions src/operator/numpy/np_matrix_op.cc
Expand Up @@ -345,5 +345,51 @@ Examples::
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to stack")
.add_arguments(StackParam::__FIELDS__());

template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<TBlob>& outputs,
                               const std::vector<index_t>& stride_,
                               const std::vector<index_t>& trailing_,
                               const index_t& flip_index) {
  mshadow::Stream<cpu> *s = ctx.get_stream<cpu>();
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    mxnet_op::Kernel<reverse, cpu>::Launch(s, inputs[0].Size(), flip_index,
        inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
        stride_.data(), trailing_.data());
  });
}

DMLC_REGISTER_PARAMETER(FlipParam);

NNVM_REGISTER_OP(_npi_flip)
.set_num_outputs(1)
.set_num_inputs(1)
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<nnvm::FListInputNames>("FListInputNames",
  [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"data"};
  })
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_backward_npi_flip"})
.add_argument("data", "NDArray-or-Symbol", "Input data array")
.add_arguments(FlipParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_npi_flip)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<FlipParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FResourceRequest>("FResourceRequest",
  [](const NodeAttrs& attrs) {
    return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
  })
.set_attr<FCompute>("FCompute<cpu>", NumpyFlipForward<cpu>);
} // namespace op
} // namespace mxnet
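
Reversal is its own inverse, so the backward pass simply flips the incoming
gradient along the same axes; that is why _backward_npi_flip is registered
with the same NumpyFlipForward compute function. A small end-to-end gradient
check, sketched under the assumption of a build that contains this commit:

    from mxnet import np, npx, autograd
    npx.set_np()

    x = np.arange(6.0).reshape(2, 3)
    x.attach_grad()
    with autograd.record():
        y = np.flip(x, axis=1)
    y.backward()   # implicit head gradient of ones
    print(x.grad)  # all ones: flip only permutes elements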
34 changes: 34 additions & 0 deletions src/operator/numpy/np_matrix_op.cu
Expand Up @@ -46,5 +46,39 @@ NNVM_REGISTER_OP(_backward_np_concat)
NNVM_REGISTER_OP(_npi_stack)
.set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);

template<>
void NumpyFlipForwardImpl<gpu>(const OpContext& ctx,
                               const std::vector<TBlob>& inputs,
                               const std::vector<TBlob>& outputs,
                               const std::vector<index_t>& stride_,
                               const std::vector<index_t>& trailing_,
                               const index_t& flip_index) {
  mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
  // The reverse kernel needs stride_ and trailing_ in device memory, so both
  // arrays are packed into one kTempSpace byte buffer: the first flip_index
  // index_t values hold the strides, the next flip_index the trailings.
  mshadow::Tensor<gpu, 1, uint8_t> workspace =
      ctx.requested[0].get_space_typed<gpu, 1, uint8_t>(
          mshadow::Shape1(flip_index * sizeof(index_t) * 2), s);

  auto stride_workspace = workspace.dptr_;
  auto trailing_workspace = workspace.dptr_ + flip_index * sizeof(index_t);

  cudaMemcpyAsync(stride_workspace, thrust::raw_pointer_cast(stride_.data()),
                  stride_.size() * sizeof(index_t),
                  cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));
  cudaMemcpyAsync(trailing_workspace, thrust::raw_pointer_cast(trailing_.data()),
                  trailing_.size() * sizeof(index_t),
                  cudaMemcpyHostToDevice, mshadow::Stream<gpu>::GetStream(s));

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    mxnet_op::Kernel<reverse, gpu>::Launch(s, inputs[0].Size(), flip_index,
        inputs[0].dptr<DType>(), outputs[0].dptr<DType>(),
        reinterpret_cast<index_t*>(stride_workspace),
        reinterpret_cast<index_t*>(trailing_workspace));
  });
}

NNVM_REGISTER_OP(_npi_flip)
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);

NNVM_REGISTER_OP(_backward_npi_flip)
.set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
} // namespace op
} // namespace mxnet
