Add npx op 'index_add' (apache#18089)
* part cpu
* index_add forward & test
* fix wrong doc
* fix index_add_sanity_error
* index_update_test
* remove index_update & implement index_add backward
* fix sanity error
* reduce code length
* split into two files
* test CI compiler
* test CI
* test CI
* reduce mshadow & allow more dtypes
* fix sanity error
* fix conflict
* reduce fwd macro code
* reduce bwd macro code
* fix compile error
* tensor ind
* remove cudaMalloc/cudaFree
* fix windows compile error
* fix compile error
* use values instead of references
* remove pragma
* fix naive engine error
* try to pass CI
* fix sanity error
* split gradient into three nodes
* resolve comment & initialize mshadow::Shape
* fix werror

Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Ubuntu <[email protected]>
Co-authored-by: Xingjian Shi <[email protected]>
Showing 7 changed files with 876 additions and 0 deletions.
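For context, the new op roughly follows NumPy's unbuffered indexed accumulation (np.add.at): the output is a copy of a with val added at the positions selected by ind, where each row of ind indexes one leading axis of a. A minimal usage sketch; exposing the op as mx.npx.index_add is an assumption here, the exact Python namespace may differ:

import mxnet as mx

a = mx.np.zeros((2, 3))
# One row of 'ind' per indexed axis: selects positions (0, 0) and (1, 2).
ind = mx.np.array([[0, 1], [0, 2]], dtype='int32')  # 'ind' must be int32/int64
val = mx.np.array([10.0, 20.0])                     # broadcastable to (ind_num,)
out = mx.npx.index_add(a, ind, val)
# out == [[10, 0, 0],
#         [ 0, 0, 20]]; 'a' itself is unchanged (the forward copies a into out).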
src/operator/tensor/index_add-inl.h
@@ -0,0 +1,231 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file index_add-inl.h
 * \brief Function definition of index_add operator
 */
#ifndef MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_
#define MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_

#include <mxnet/operator_util.h>
#include <vector>
#include <algorithm>
#include "../mxnet_op.h"
#include "../operator_common.h"
#include "../elemwise_op_common.h"

namespace mxnet {
namespace op {

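// Shape inference: index_add(a, ind, val) always produces an output with
// a's shape; 'ind' and 'val' are validated against 'a' in the forward pass.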
inline bool IndexModifyOpShape(const nnvm::NodeAttrs& attrs,
                               mxnet::ShapeVector* in_attrs,
                               mxnet::ShapeVector* out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 1U);
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
  return true;
}

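// Type inference: 'a' and 'val' must share a dtype, and 'ind' must be integral.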
inline bool IndexModifyOpType(const nnvm::NodeAttrs& attrs,
                              std::vector<int>* in_attrs,
                              std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_NE((*in_attrs)[0], -1);
  CHECK_NE((*in_attrs)[1], -1);
  CHECK_NE((*in_attrs)[2], -1);
  CHECK_EQ((*in_attrs)[0], (*in_attrs)[2])
    << "index_add(a, ind, val) only supports a.dtype == val.dtype";
  CHECK((*in_attrs)[1] == mshadow::kInt64 ||
        (*in_attrs)[1] == mshadow::kInt32)
    << "'ind' only supports int dtypes.";
  TYPE_ASSIGN_CHECK(*out_attrs, 0, (*in_attrs)[0]);
  return (*out_attrs)[0] != -1;
}

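// Device-specific kernel launcher; per the commit message the implementation
// is split into separate CPU and GPU files.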
template<typename xpu, typename DType>
void IndexAddForwardCalc(mshadow::Stream<xpu> *s,
                         const int ind_num, DType* out,
                         const DType* val,
                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape,
                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride,
                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
                         const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape,
                         const int a_tail_size,
                         const int ind_ndim, const int* ind,
                         const int a_ndim);

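// Forward: copy 'a' into the output, then accumulate 'val' (broadcast as
// needed) at the positions selected by 'ind'.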
template<typename xpu>
void IndexAddOpForward(const nnvm::NodeAttrs& attrs,
                       const OpContext& ctx,
                       const std::vector<TBlob>& inputs,
                       const std::vector<OpReqType>& req,
                       const std::vector<TBlob>& outputs) {
  using namespace mxnet_op;
  using namespace mshadow;
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 1U);
  Stream<xpu> *s = ctx.get_stream<xpu>();
  const TBlob a = inputs[0];
  TBlob ind = inputs[1];
  TBlob val = inputs[2];
  TBlob out = outputs[0];
  CHECK_GT(a.shape_.ndim(), 0) << "The first input is a scalar, please use '+' instead.";
  int a_ndim = a.shape_.ndim();
  CHECK_LE(a_ndim, MXNET_SPECIAL_MAX_NDIM)
    << "ndim should be less than or equal to " << MXNET_SPECIAL_MAX_NDIM
    << " but got " << a_ndim << "\n";
  int val_ndim = val.shape_.ndim();
  if (val_ndim == 0) {
    val.shape_ = Shape1(1);
    val_ndim = 1;
  }
  // ind = np.array([])            -> ind.shape_.ndim() == 1
  // ind = np.array(1)             -> ind.shape_.ndim() == 0
  // ind = np.array([[0,0],[0,1]]) -> ind.shape_.ndim() == 2
  CHECK_NE(ind.shape_.Size(), 0) << "Param 'ind' is []. Please just use op 'add' instead.\n";
  CHECK_LE(ind.shape_.ndim(), 2) << "'ind' array allows 2 dimensions at most.";
  if (ind.shape_.ndim() == 0) {
    ind.shape_ = Shape2(1, 1);
  } else if (ind.shape_.ndim() == 1) {
    ind.shape_ = Shape2(1, ind.shape_[0]);
  }
  int ind_ndim = ind.shape_[0];
  int ind_num = ind.shape_[1];
  CHECK_LE(ind_ndim, a_ndim) << "IndexError: too many indices for array.";

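  // Example: with a.shape = (2, 3, 4) and ind_ndim = 2, 'val' may be at most
  // 2-D and must broadcast to (ind_num, 4), so (ind_num, 4), (1, 4), (4,)
  // and (1,) are all legal.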
  // check 'val' broadcast legality
  CHECK_LE(val_ndim, a_ndim - ind_ndim + 1)
    << "The ndim of param 'val' is " << val_ndim
    << ", but it should be less than or equal to " << a_ndim - ind_ndim + 1;
  for (int i = a_ndim - 1, j = val_ndim - 1; j >= 0; --i, --j) {
    if ((j == 0) && (val_ndim == a_ndim - ind_ndim + 1)) {
      // val_ndim == a_ndim - ind_ndim + 1: check the first dim of input 'val'
      CHECK(val.shape_[j] == ind_num || val.shape_[j] == 1)
        << "cannot broadcast from " << val.shape_[j] << " to " << ind_num;
    } else {
      CHECK(val.shape_[j] == a.shape_[i] || val.shape_[j] == 1)
        << "cannot broadcast from " << val.shape_[j] << " to " << a.shape_[i]
        << " in axis " << i;
    }
  }
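  // Right-align every shape into a fixed MXNET_SPECIAL_MAX_NDIM-length shape,
  // padding the leading axes with 1, so one kernel signature covers all ndims.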
  int a_tail_size = static_cast<int>(a.shape_.ProdShape(ind_ndim, a_ndim));
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_shape, val_shape;
  for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = a_ndim - 1; i >= 0; --i, --j) {
    a_shape[i] = (j >= 0) ? a.shape_[j] : 1;
  }
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_shape(a_shape);
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_tail_shape(a_shape);

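  // a_pre_shape keeps only the leading (indexed) axes and a_tail_shape keeps
  // only the trailing (non-indexed) axes; each collapses the other group to 1.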
  int seg = MXNET_SPECIAL_MAX_NDIM - a_ndim;
  for (int i = seg; i < ind_ndim + seg; ++i) {
    a_tail_shape[i] = 1;
  }
  for (int i = ind_ndim + seg; i < a_ndim + seg; ++i) {
    a_pre_shape[i] = 1;
  }
  for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = val_ndim - 1; i >= 0; --i, --j) {
    val_shape[i] = (j >= 0) ? val.shape_[j] : 1;
  }
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> a_pre_stride = calc_stride(a_pre_shape);
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride = calc_stride(val_shape);
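  // Initialize the output with a copy of 'a', then stage 'ind' in an int
  // workspace so the kernel reads a single index dtype.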
  mxnet_op::copy(s, out, a);
  TBlob t_ind = TBlob(ctx.requested[0].get_space_typed<xpu, 1, int>
                      (Shape1(ind.shape_.Size()), s));
  mxnet_op::copy(s, t_ind, ind);
  MSHADOW_TYPE_SWITCH(a.type_flag_, DType, {
    IndexAddForwardCalc<xpu, DType>(s, ind_num,
                                    out.dptr<DType>(), val.dptr<DType>(),
                                    a_tail_shape, a_pre_stride,
                                    val_stride, val_shape, a_shape,
                                    a_tail_size, ind_ndim,
                                    t_ind.dptr<int>(), a_ndim);
  });
}

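// Backward w.r.t. 'val': reduce the incoming gradient at the indexed
// positions back down to val's (broadcast) shape. Device-specific launcher.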
template<typename xpu>
void IndexAddOpBackwardValImpl(const OpContext& ctx,
                               const TBlob& grad_val,
                               const TBlob& ograd,
                               const TBlob& t_ind,
                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape,
                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride,
                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride,
                               const mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_shape,
                               const int tail_size, const int ind_num, const int ind_ndim,
                               const int ndim);

template<typename xpu>
inline void IndexAddOpBackwardVal(const nnvm::NodeAttrs& attrs,
                                  const OpContext& ctx,
                                  const std::vector<TBlob>& inputs,
                                  const std::vector<OpReqType>& req,
                                  const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  if (req[0] == kNullOp) {
    return;
  }
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  const TBlob& ograd = inputs[0];
  TBlob ind = inputs[1];
  const TBlob& grad_val = outputs[0];
  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
  // normalize 'ind' to 2-D and read off the number of indexed axes and indices
  if (ind.shape_.ndim() == 0) {
    ind.shape_ = Shape2(1, 1);
  } else if (ind.shape_.ndim() == 1) {
    ind.shape_ = Shape2(1, ind.shape_[0]);
  }
  int ind_ndim = ind.shape_[0];
  int ind_num = ind.shape_[1];
  int out_ndim = ograd.shape_.ndim();
  int tail_size = static_cast<int>(ograd.shape_.ProdShape(ind_ndim, out_ndim));
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_shape, val_shape;
  for (int i = MXNET_SPECIAL_MAX_NDIM - 1, j = out_ndim - 1; i >= 0; --i, --j) {
    ograd_shape[i] = (j >= 0) ? ograd.shape_[j] : 1;
  }
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_shape(ograd_shape);
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_tail_shape(ograd_shape);
  TBlob t_ind = TBlob(ctx.requested[0].get_space_typed<xpu, 1, int>
                      (Shape1(ind.shape_.Size()), s));
  mxnet_op::copy(s, t_ind, ind);
  int seg = MXNET_SPECIAL_MAX_NDIM - out_ndim;
  for (int i = seg; i < seg + ind_ndim; ++i) {
    ograd_tail_shape[i] = 1;
  }
  for (int i = seg + ind_ndim; i < seg + out_ndim; ++i) {
    ograd_pre_shape[i] = 1;
  }
  for (int i = seg + out_ndim - 1, j = grad_val.shape_.ndim() - 1; i >= seg; --i, --j) {
    val_shape[i] = (j >= 0) ? grad_val.shape_[j] : 1;
  }
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> ograd_pre_stride = mxnet_op::calc_stride(ograd_pre_shape);
  mshadow::Shape<MXNET_SPECIAL_MAX_NDIM> val_stride = mxnet_op::calc_stride(val_shape);
  IndexAddOpBackwardValImpl<xpu>(ctx, grad_val, ograd, t_ind, ograd_tail_shape, ograd_pre_stride,
                                 val_stride, val_shape, tail_size, ind_num, ind_ndim, out_ndim);
}

||
} // namespace op | ||
} // namespace mxnet | ||
|
||
#endif // MXNET_OPERATOR_TENSOR_INDEX_ADD_INL_H_ |