
[MXNET-1453] Support inputs whose dimension is greater than 6 for Transpose and Rollaxis #18707

Merged: 4 commits, Jul 16, 2020
51 changes: 36 additions & 15 deletions src/operator/numpy/np_matrix_op-inl.h
@@ -134,10 +134,10 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
if (req[0] == kNullOp) return;
CHECK(req[0] == kWriteTo || req[0] == kAddTo)
<< "Transpose only supports kWriteTo, kNullOp and kAddTo";
<< "Transpose does not support inplace";
const NumpyTransposeParam& param = nnvm::get<NumpyTransposeParam>(attrs.parsed);
mxnet::TShape axes;
if (ndim_is_known(param.axes)) {
axes = common::CanonicalizeAxes(param.axes);
@@ -147,10 +147,14 @@ void NumpyTranspose(const nnvm::NodeAttrs& attrs,
axes[i] = axes.ndim() - 1 - i;
}
}
mshadow::Tensor<xpu, 1, dim_t> workspace =
GetTransposeExWorkspace<xpu>(ctx, axes);
if (req[0] == kAddTo) {
TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
} else {
TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
}
}
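
The req dispatch above is the behaviour the new path has to honour: kWriteTo overwrites the output, kAddTo accumulates into it. A minimal sketch of how the kAddTo branch is reached from Python, assuming the stock `mxnet.numpy` and `autograd` front ends (illustrative only, not part of this diff):

```python
import numpy as _np
from mxnet import np, npx, autograd

npx.set_np()

x = np.random.uniform(size=(2, 3, 2, 2, 2, 2, 2))  # 7-D input, beyond the old 6-D limit
x.attach_grad(grad_req='add')                       # gradients accumulate across backward calls

for _ in range(2):
    with autograd.record():
        y = np.transpose(x, (6, 5, 4, 3, 2, 1, 0))
    y.backward()

# The gradient of a pure transpose under an all-ones head gradient is all ones,
# so two accumulated backward passes leave 2s in x.grad.
assert _np.allclose(x.grad.asnumpy(), 2 * _np.ones(x.shape))
```

With `grad_req='add'` the backward transpose must add into the existing gradient buffer, which is exactly the case routed to `TransposeExImpl<xpu, true>`.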

@@ -779,13 +783,21 @@ void NumpyRollaxisCompute(const nnvm::NodeAttrs& attrs,
using namespace mshadow::expr;
CHECK_EQ(inputs.size(), 1U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(req[0], kWriteTo) << "Rollaxis does not support inplace";
mxnet::TShape axes;
if (req[0] == kNullOp) return;
CHECK(req[0] == kWriteTo || req[0] == kAddTo)
<< "Rollaxis does not support inplace";
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
mxnet::TShape axes = NumpyRollaxisShapeImpl(param.axis, param.start, inputs[0].ndim());

mshadow::Tensor<xpu, 1, dim_t> workspace =
GetTransposeExWorkspace<xpu>(ctx, axes);
if (req[0] == kAddTo) {
TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
} else {
TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
}
}
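
`NumpyRollaxisShapeImpl` (defined elsewhere, not shown in this diff) turns the `(axis, start)` pair into an ordinary transpose permutation, which is then fed to the same `TransposeExImpl` path as above. A hedged NumPy reconstruction of that permutation; `rollaxis_axes` is a hypothetical helper used only for this illustration:

```python
import numpy as np

def rollaxis_axes(axis, start, ndim):
    """Permutation such that a.transpose(axes) == np.rollaxis(a, axis, start)."""
    axes = list(range(ndim))
    axes.remove(axis)
    if axis < start:          # np.rollaxis shifts the insertion point in this case
        start -= 1
    axes.insert(start, axis)
    return tuple(axes)

a = np.arange(2 * 3 * 4 * 5).reshape(2, 3, 4, 5)
for axis in range(a.ndim):
    for start in range(a.ndim + 1):
        assert np.array_equal(np.rollaxis(a, axis, start),
                              a.transpose(rollaxis_axes(axis, start, a.ndim))), (axis, start)
```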

template<typename xpu>
@@ -796,6 +808,9 @@ void NumpyRollaxisBackward(const nnvm::NodeAttrs &attrs,
const std::vector<TBlob> &outputs) {
using namespace mshadow;
using namespace mshadow::expr;
if (req[0] == kNullOp) return;
CHECK(req[0] == kWriteTo || req[0] == kAddTo)
<< "Rollaxis Backward does not support inplace";
const NumpyRollaxisParam& param = nnvm::get<NumpyRollaxisParam>(attrs.parsed);
int axis_origin = param.axis;
int start_origin = param.start;
@@ -819,11 +834,17 @@
axis = start_origin;
start = axis_origin + 1;
}
mxnet::TShape axes;
axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, Dtype, {
TransposeImpl<xpu>(ctx.run_ctx, inputs[0], outputs[0], axes);
})
mxnet::TShape axes = NumpyRollaxisShapeImpl(axis, start, inputs[0].ndim());

mshadow::Tensor<xpu, 1, dim_t> workspace =
GetTransposeExWorkspace<xpu>(ctx, axes);
if (req[0] == kAddTo) {
TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
} else {
TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
}
}
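
The backward pass undoes a roll by rolling again with swapped parameters; the branch visible above sets `axis = start_origin` and `start = axis_origin + 1`, and its counterpart for `axis_origin < start_origin` (outside this hunk) is written below as an assumption. A small NumPy check that this inversion restores the original axis order:

```python
import numpy as np

def inverse_roll_params(axis_origin, start_origin):
    # The first branch is assumed; the second is the one shown in the hunk above.
    if axis_origin < start_origin:
        return start_origin - 1, axis_origin
    return start_origin, axis_origin + 1

a = np.arange(3 * 4 * 5 * 6).reshape(3, 4, 5, 6)
for axis in range(a.ndim):
    for start in range(a.ndim + 1):
        rolled = np.rollaxis(a, axis, start)
        back_axis, back_start = inverse_roll_params(axis, start)
        assert np.array_equal(np.rollaxis(rolled, back_axis, back_start), a), (axis, start)
```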

struct NumpyRot90Param : public dmlc::Parameter<NumpyRot90Param> {
17 changes: 13 additions & 4 deletions src/operator/numpy/np_matrix_op.cc
@@ -51,7 +51,6 @@ bool NumpyTransposeShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
mxnet::TShape& out_shp = (*out_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";

int ndim = -1;
if (ndim_is_known(shp)) {
@@ -133,6 +132,10 @@ NNVM_REGISTER_OP(_npi_transpose)
}
})
.set_attr<FCompute>("FCompute<cpu>", NumpyTranspose<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a"};
@@ -1261,7 +1264,6 @@ bool NumpyRollaxisShape(const nnvm::NodeAttrs& attrs,

// check transpose dimensions no more than 6
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";

// check axis and start range
CHECK_GE(param.axis, -shp.ndim())
@@ -1304,6 +1306,10 @@ until it lies in a given position.)code" ADD_FILELINE)
.set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{"_npi_rollaxis_backward"})
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.add_argument("data", "NDArray-or-Symbol", "Input ndarray")
.add_arguments(NumpyRollaxisParam::__FIELDS__());

@@ -1312,7 +1318,11 @@ NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_num_outputs(1)
.set_attr_parser(ParamParser<NumpyRollaxisParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>);
.set_attr<FCompute>("FCompute<cpu>", NumpyRollaxisBackward<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
});

template<>
void NumpyFlipForwardImpl<cpu>(const OpContext& ctx,
@@ -1368,7 +1378,6 @@ bool NumpyMoveaxisShape(const nnvm::NodeAttrs& attrs,
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 1U);
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
CHECK_EQ(param.source.ndim(), param.destination.ndim())
<< "source and destination not equal.";
mxnet::TShape ret(shp.ndim(), -1);
138 changes: 130 additions & 8 deletions src/operator/tensor/matrix_op-inl.h
@@ -321,15 +321,16 @@ inline bool IsIdentityTranspose(const TShape& axes) {
}

template<typename xpu, bool is_addto = false>
void TransposeImpl(RunContext ctx,
bool TransposeCommonImpl(RunContext ctx,
const TBlob& src,
const TBlob& ret,
const mxnet::TShape& axes) {
// Returns true on success, false otherwise
using namespace mshadow;
using namespace mshadow::expr;
CHECK_EQ(src.type_flag_, ret.type_flag_);
// zero-size tensor, no need to compute
if (src.shape_.Size() == 0U) return;
if (src.shape_.Size() == 0U) return true;
Stream<xpu> *s = ctx.get_stream<xpu>();
#ifdef __CUDACC__
// This transpose can be used only if there exist n and m such that:
@@ -339,7 +340,7 @@ void TransposeImpl(RunContext ctx,
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
transpose_pseudo2D<DType, is_addto>(ret, src, axes, s);
});
return;
return true;
}
#endif
// Special handle the identity case
@@ -355,7 +356,7 @@
s, ret.Size(), out.dptr_, in.dptr_);
}
});
return;
return true;
}
// Handle the general transpose case
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
@@ -413,10 +414,127 @@
break;
}
default:
LOG(FATAL) << "Transpose support at most 6 dimensions";
// return false when dimensions > 6
return false;
break;
}
});
return true;
}

template<typename xpu, bool is_addto = false>
void TransposeImpl(RunContext ctx,
const TBlob& src,
const TBlob& ret,
const mxnet::TShape& axes) {
CHECK_LE(axes.ndim(), 6) << "TransposeImpl supports at most 6 dimensions";
CHECK((TransposeCommonImpl<xpu, is_addto>(ctx, src, ret, axes))) <<
"Failed to execute TransposeImpl Operator";
}

template <bool is_addto>
struct TransposeExKernel {
/*!
* \brief Copy one input element to its transposed output position (accumulate when is_addto)
* \param tid global thread id
* \param out_data output data
* \param in_data input data
* \param strides input strides and output strides
* \param ndim the number of dimensions
*/
template <typename DType>
MSHADOW_XINLINE static void Map(int tid,
DType *out_data,
const DType *in_data,
const dim_t *strides,
const int ndim
) {
// tid is the index of input data
const dim_t* const out_strides = strides + ndim;
int k = tid;
int out_id = 0;
for (int i = 0; i < ndim; ++i) {
out_id += (k / strides[i]) * out_strides[i];
k %= strides[i];
}
if (is_addto)
out_data[out_id] += in_data[tid];
else
out_data[out_id] = in_data[tid];
}
};

template<typename xpu, bool is_addto = false>
void TransposeExImpl(RunContext ctx,
const TBlob& src,
const TBlob& ret,
const mxnet::TShape& axes,
mshadow::Tensor<xpu, 1, dim_t>& strides_xpu
) {
/*
* If ndim <= 6, it is not necessary to allocate any space for `strides_xpu`
* If ndim > 6, `strides_xpu` should be allocated `ndim * 2` elements
*/
using namespace mshadow;
using namespace mshadow::expr;
if (TransposeCommonImpl<xpu, is_addto>(ctx, src, ret, axes)) return;
CHECK_GT(axes.ndim(), 6) <<
"Failed to execute TransposeExImpl when axes.ndim() <= 6";
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(ret.type_flag_, DType, {
CHECK_EQ(strides_xpu.MSize(), axes.ndim() * 2) << \
"If ndim > 6, `strides_xpu` should be allocated `ndim * 2` elements";

const mxnet::TShape &in_shape = src.shape_;
// strides: in_strides and out_strides
const int ndim = axes.ndim();
std::vector<dim_t> strides(ndim * 2);
// compute in_strides
strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * in_shape[i + 1];
}
// compute out_strides
std::vector<dim_t> tmp_strides(ndim);
tmp_strides[ndim - 1] = 1;
for (int i = ndim - 2; i >= 0; --i) {
tmp_strides[i] = tmp_strides[i + 1] * in_shape[axes[i + 1]];
}
// reorder tmp_strides to out_strides
dim_t * const out_strides = &strides[ndim];
for (int i = 0; i < ndim; ++i) {
out_strides[axes[i]] = tmp_strides[i];
}
Shape<1> strides_shape;
strides_shape[0] = ndim * 2;
Tensor<cpu, 1, dim_t> strides_cpu(strides.data(), strides_shape);
// copy arguments into xpu context
Copy(strides_xpu, strides_cpu, s);
const DType *in = src.dptr<DType>();
DType *out = ret.dptr<DType>();
if (is_addto) {
mxnet_op::Kernel<TransposeExKernel<true>, xpu>::Launch(s,
in_shape.Size(), out, in, strides_xpu.dptr_, ndim);
} else {
mxnet_op::Kernel<TransposeExKernel<false>, xpu>::Launch(s,
in_shape.Size(), out, in, strides_xpu.dptr_, ndim);
}
});
}
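
The stride bookkeeping is easiest to follow on a small worked example. The sketch below redoes the `in_strides` / `tmp_strides` / `out_strides` construction and the per-element mapping from `TransposeExKernel::Map` in plain NumPy; the 3-D shape is purely illustrative (the C++ fallback only runs for ndim > 6, but the arithmetic is identical):

```python
import numpy as np

in_shape, axes = (2, 3, 4), (2, 0, 1)    # illustrative only
ndim = len(in_shape)

# in_strides: row-major strides of the input shape
in_strides = [0] * ndim
in_strides[-1] = 1
for i in range(ndim - 2, -1, -1):
    in_strides[i] = in_strides[i + 1] * in_shape[i + 1]

# tmp_strides: row-major strides of the output shape, then scattered through `axes`
tmp_strides = [0] * ndim
tmp_strides[-1] = 1
for i in range(ndim - 2, -1, -1):
    tmp_strides[i] = tmp_strides[i + 1] * in_shape[axes[i + 1]]
out_strides = [0] * ndim
for i in range(ndim):
    out_strides[axes[i]] = tmp_strides[i]

src = np.arange(np.prod(in_shape)).reshape(in_shape)
dst = np.empty(src.size, dtype=src.dtype)
for tid in range(src.size):              # mirrors Map(): tid is the flat input index
    k, out_id = tid, 0
    for i in range(ndim):
        out_id += (k // in_strides[i]) * out_strides[i]
        k %= in_strides[i]
    dst[out_id] = src.flat[tid]

assert np.array_equal(dst.reshape([in_shape[ax] for ax in axes]), src.transpose(axes))
```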

template<typename xpu>
mshadow::Tensor<xpu, 1, dim_t> GetTransposeExWorkspace(
const OpContext& ctx,
const mxnet::TShape& axes
) {
if (axes.ndim() > 6) {
// allocate workspace when axes.ndim() > 6
mshadow::Shape<1> strides_shape;
strides_shape[0] = axes.ndim() * 2;
return ctx.requested[0].get_space_typed<xpu, 1, dim_t>(
strides_shape, ctx.get_stream<xpu>());
}
return {};
}

// matrix transpose
@@ -441,10 +559,15 @@ void Transpose(const nnvm::NodeAttrs& attrs,
} else {
axes = common::CanonicalizeAxes(param.axes);
}

mshadow::Tensor<xpu, 1, dim_t> workspace =
GetTransposeExWorkspace<xpu>(ctx, axes);
if (req[0] == kAddTo) {
TransposeImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0], axes);
TransposeExImpl<xpu, true>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
} else {
TransposeImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0], axes);
TransposeExImpl<xpu, false>(ctx.run_ctx, inputs[0], outputs[0],
axes, workspace);
}
}

@@ -458,7 +581,6 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape& out_shp = (*out_attrs)[0];
if (!mxnet::ndim_is_known(shp) && !mxnet::ndim_is_known(out_shp))
return false; // none of the shapes is known
CHECK_LE(shp.ndim(), 6) << "Transpose support at most 6 dimensions";
if (out_shp.ndim() >= 0 && shp.ndim() >= 0)
CHECK_EQ(out_shp.ndim(), shp.ndim());
mxnet::TShape get(std::max(shp.ndim(), out_shp.ndim()), -1);
4 changes: 4 additions & 0 deletions src/operator/tensor/matrix_op.cc
@@ -351,6 +351,10 @@ Examples::
}
})
.set_attr<FCompute>("FCompute<cpu>", Transpose<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", TransposeComputeExCPU)
8 changes: 6 additions & 2 deletions tests/python/unittest/test_numpy_op.py
@@ -2439,7 +2439,11 @@ def hybrid_forward(self, F, x):
[(8, 2, 16), [(0, 2, 1), (2, 0, 1), (0, 1, 2), (2, 1, 0), (-1, -2, -3)]],
[(8, 3, 4, 8), [(0, 2, 3, 1), (1, 2, 3, 0), (0, 3, 2, 1)]],
[(8, 3, 2, 3, 8), [(0, 1, 3, 2, 4), (0, 1, 2, 3, 4), (4, 0, 1, 2, 3)]],
[(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]]
[(3, 4, 3, 4, 3, 2), [(0, 1, 3, 2, 4, 5), (2, 3, 4, 1, 0, 5), None]],
[(3, 4, 3, 4, 3, 2, 2), [(0, 1, 3, 2, 4, 5, 6),
(2, 3, 4, 1, 0, 5, 6), None]],
[(3, 4, 3, 4, 3, 2, 3, 2), [(0, 1, 3, 2, 4, 5, 7, 6),
(2, 3, 4, 1, 0, 5, 7, 6), None]],
])
@pytest.mark.parametrize('grad_req', ['write', 'add'])
def test_np_transpose(data_shape, axes_workload, hybridize, dtype, grad_req):
@@ -10116,7 +10120,7 @@ def hybrid_forward(self, F, a, *args, **kwargs):
dtypes = ['int32', 'int64', 'float16', 'float32', 'float64']
for hybridize in [False, True]:
for dtype in dtypes:
for ndim in [0, 1, 2, 3, 4, 5, 6]:
for ndim in [0, 1, 2, 3, 4, 5, 6, 7, 8]:
shape = rand_shape_nd(ndim, dim=5, allow_zero_size=True)
np_data = _np.random.uniform(low=-100, high=100, size=shape).astype(dtype)
mx_data = np.array(np_data, dtype=dtype)
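For reference, a minimal usage sketch of what the new 7-D and 8-D workloads exercise, assuming a build that contains this change and the `mxnet.numpy` `transpose` / `rollaxis` front ends:

```python
import numpy as _np
from mxnet import np, npx

npx.set_np()

data = _np.random.uniform(size=(2, 2, 3, 2, 2, 3, 2))   # 7 dimensions
x = np.array(data)

axes = (6, 4, 2, 0, 1, 3, 5)
assert _np.allclose(np.transpose(x, axes).asnumpy(), _np.transpose(data, axes))
assert _np.allclose(np.rollaxis(x, 4, 0).asnumpy(), _np.rollaxis(data, 4, 0))
```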
4 changes: 2 additions & 2 deletions tests/python/unittest/test_operator.py
@@ -2573,9 +2573,9 @@ def test_broadcasting_ele(sym_bcast):

@with_seed()
def test_transpose():
for ndim in range(1, 7):
for ndim in range(1, 10):
for t in range(5):
dims = list(np.random.randint(1, 10, size=ndim))
dims = list(np.random.randint(1, 5, size=ndim))
axes = list(range(ndim))
random.shuffle(axes)
axes = tuple(axes)
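A similar spot check through the legacy ndarray API, which `test_transpose` now covers up to 9 dimensions (the shape and permutation here are arbitrary examples):

```python
import numpy as np
import mxnet as mx

shape = (2,) * 8                                   # 8 dimensions, beyond the old limit of 6
x = mx.nd.array(np.arange(np.prod(shape)).reshape(shape))
axes = (7, 5, 3, 1, 0, 2, 4, 6)
assert np.array_equal(mx.nd.transpose(x, axes=axes).asnumpy(),
                      np.transpose(x.asnumpy(), axes))
```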