
reduce code
JiangZhaoh committed Dec 27, 2019
1 parent 97341ed commit c4c8597
Showing 1 changed file with 97 additions and 76 deletions.
173 changes: 97 additions & 76 deletions src/operator/numpy/np_insert_op-inl.h
@@ -520,109 +520,130 @@ void NumpyInsertCompute(const nnvm::NodeAttrs& attrs,
     int vtype = param.val.has_value() ?
                 mshadow::DataType<double>::kFlag :
                 inputs[val_pos].type_flag_;
-    MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+    if ((param.int_ind.has_value() ||
+        (obj_is_tensor && inputs[obj_pos].shape_.ndim() == 0) ||
+        (indices_len == 1)) &&
+        param.val.has_value()) {
       MSHADOW_TYPE_SWITCH(vtype, VType, {
-        if ((param.int_ind.has_value() ||
-            (obj_is_tensor && inputs[obj_pos].shape_.ndim() == 0) ||
-            (indices_len == 1)) &&
-            param.val.has_value()) {
-          // If insert use single index and 'value' is inputed as numerical parameter
-          values = TBlob(ctx.requested[0].get_space_typed<xpu, 1, VType>(Shape1(1), s));
-          Fill(s, values, kWriteTo, param.val.value());
-        }
-        if (param.int_ind.has_value()) {
-          // 'obj' is integer, need to moveaxis
+        // If insert use single index and 'value' is inputed as numerical parameter
+        values = TBlob(ctx.requested[0].get_space_typed<xpu, 1, VType>(Shape1(1), s));
+        Fill(s, values, kWriteTo, param.val.value());
+      });
+    }
+
+    if (param.int_ind.has_value()) {
+      // 'obj' is integer, need to moveaxis
+      MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+        MSHADOW_TYPE_SWITCH(vtype, VType, {
           Kernel<InsertSingleIndexForward<ndim>, xpu>::Launch(
             s, outshape.Size(), outputs[out_pos].dptr<DType>(),
             values.dptr<VType>(), arr.dptr<DType>(),
             k_outshape, k_valshape, index, numnew,
             val_strides, old_val_strides, arr_strides, out_strides,
             axis, true, req[out_pos]);
-        } else if (obj_is_tensor && inputs[obj_pos].shape_.ndim() == 0) {
-          // 'obj' is tensor and the tensor's ndim is 0, also need to moveaxis
-          Kernel<InsertSingleIndexForward<ndim>, xpu>::Launch(
-            s, outshape.Size(), outputs[out_pos].dptr<DType>(),
-            values.dptr<VType>(), arr.dptr<DType>(),
-            k_outshape, k_valshape, N, inputs[obj_pos].dptr<int64_t>(), numnew,
-            val_strides, old_val_strides, arr_strides, out_strides,
-            axis, true, req[out_pos]);
-        } else if (indices_len == 1) {
-          if (param.step.has_value()) {
+        });
+      });
+    } else if (obj_is_tensor && inputs[obj_pos].shape_.ndim() == 0) {
+      // 'obj' is tensor and the tensor's ndim is 0, also need to moveaxis
+      MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+        MSHADOW_TYPE_SWITCH(vtype, VType, {
+          Kernel<InsertSingleIndexForward<ndim>, xpu>::Launch(
+            s, outshape.Size(), outputs[out_pos].dptr<DType>(),
+            values.dptr<VType>(), arr.dptr<DType>(),
+            k_outshape, k_valshape, N, inputs[obj_pos].dptr<int64_t>(), numnew,
+            val_strides, old_val_strides, arr_strides, out_strides,
+            axis, true, req[out_pos]);
+        });
+      });
+    } else if (indices_len == 1) {
+      if (param.step.has_value()) {
+        MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+          MSHADOW_TYPE_SWITCH(vtype, VType, {
             Kernel<InsertSingleIndexForward<ndim>, xpu>::Launch(
               s, outshape.Size(), outputs[out_pos].dptr<DType>(),
               values.dptr<VType>(), arr.dptr<DType>(),
               k_outshape, k_valshape, start, numnew,
               val_strides, old_val_strides, arr_strides, out_strides,
               axis, false, req[out_pos]);
-          } else {
+          });
+        });
+      } else {
+        MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+          MSHADOW_TYPE_SWITCH(vtype, VType, {
             Kernel<InsertSingleIndexForward<ndim>, xpu>::Launch(
               s, outshape.Size(), outputs[out_pos].dptr<DType>(),
               values.dptr<VType>(), arr.dptr<DType>(),
              k_outshape, k_valshape, N, inputs[obj_pos].dptr<int64_t>(), numnew,
               val_strides, old_val_strides, arr_strides, out_strides,
               axis, false, req[out_pos]);
-          }
-        } else {
-          // broadcast check
-          for (int i = outshape.ndim() - 1; i >= 0; --i) {
-            int sz = outshape[i];
-            if (i == axis) {
-              sz = numnew;
-            }
-            CHECK((values.shape_[i] == 1) || (values.shape_[i] == sz));
-          }
-          size_t temp_storage_bytes, temp_mem_size;
-          temp_storage_bytes = SortByKeyWorkspaceSize<int64_t, int, xpu>(indices_len, false, true);
-          temp_mem_size = indices_len * sizeof(int64_t) * 2 +
-                          indices_len * sizeof(int) +
-                          outshape[axis] * sizeof(int) * 2 +
-                          temp_storage_bytes;
-          Tensor<xpu, 1, char> temp_mem =
-            ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
-          int64_t* indices_ptr = reinterpret_cast<int64_t*>(temp_mem.dptr_);
-          int64_t* sorted_indices_ptr = reinterpret_cast<int64_t*>(indices_ptr + indices_len);
-          int* order_ptr = reinterpret_cast<int*>(sorted_indices_ptr + indices_len);
-          int* is_insert = reinterpret_cast<int*>(order_ptr + indices_len);
-          int* origin_idx = reinterpret_cast<int*>(is_insert + outshape[axis]);
-          Tensor<xpu, 1, char> temp_storage(reinterpret_cast<char*>(origin_idx + outshape[axis]),
-                                            Shape1(temp_storage_bytes), s);
-          Tensor<xpu, 1, int64_t> indices(indices_ptr, Shape1(indices_len), s);
-          Tensor<xpu, 1, int64_t> sorted_indices(sorted_indices_ptr, Shape1(indices_len), s);
-          Tensor<xpu, 1, int> order(order_ptr, Shape1(indices_len), s);
-          int num_bits = common::ilog2ui(static_cast<unsigned int>(indices_len) - 1);
-          if (param.step.has_value()) {
-            Kernel<SliceToIndices, xpu>::Launch(s, indices_len, indices_ptr, N, start, step);
-          } else {
-            Kernel<ObjToIndices, xpu>::Launch(s, indices_len, indices_ptr, N,
-                                              inputs[obj_pos].dptr<int64_t>());
-          }
-          Kernel<range_fwd, xpu>::Launch(s, indices_len, 1, 0, 1, kWriteTo, order_ptr);
-          mxnet::op::SortByKey(indices, order, true, &temp_storage, 0, num_bits, &sorted_indices);
-          Kernel<IndicesModify, xpu>::Launch(s, indices_len, indices_ptr, order_ptr);
+          });
+        });
+      }
+    } else {
+      // broadcast check
+      for (int i = outshape.ndim() - 1; i >= 0; --i) {
+        int sz = outshape[i];
+        if (i == axis) {
+          sz = numnew;
+        }
+        CHECK((values.shape_[i] == 1) || (values.shape_[i] == sz));
+      }
+      size_t temp_storage_bytes, temp_mem_size;
+      temp_storage_bytes = SortByKeyWorkspaceSize<int64_t, int, xpu>(indices_len, false, true);
+      temp_mem_size = indices_len * sizeof(int64_t) * 2 +
+                      indices_len * sizeof(int) +
+                      outshape[axis] * sizeof(int) * 2 +
+                      temp_storage_bytes;
+      Tensor<xpu, 1, char> temp_mem =
+        ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_mem_size), s);
+      int64_t* indices_ptr = reinterpret_cast<int64_t*>(temp_mem.dptr_);
+      int64_t* sorted_indices_ptr = reinterpret_cast<int64_t*>(indices_ptr + indices_len);
+      int* order_ptr = reinterpret_cast<int*>(sorted_indices_ptr + indices_len);
+      int* is_insert = reinterpret_cast<int*>(order_ptr + indices_len);
+      int* origin_idx = reinterpret_cast<int*>(is_insert + outshape[axis]);
+      Tensor<xpu, 1, char> temp_storage(reinterpret_cast<char*>(origin_idx + outshape[axis]),
+                                        Shape1(temp_storage_bytes), s);
+      Tensor<xpu, 1, int64_t> indices(indices_ptr, Shape1(indices_len), s);
+      Tensor<xpu, 1, int64_t> sorted_indices(sorted_indices_ptr, Shape1(indices_len), s);
+      Tensor<xpu, 1, int> order(order_ptr, Shape1(indices_len), s);
+      int num_bits = common::ilog2ui(static_cast<unsigned int>(indices_len) - 1);
+      if (param.step.has_value()) {
+        Kernel<SliceToIndices, xpu>::Launch(s, indices_len, indices_ptr, N, start, step);
+      } else {
+        Kernel<ObjToIndices, xpu>::Launch(s, indices_len, indices_ptr, N,
+                                          inputs[obj_pos].dptr<int64_t>());
+      }
+      Kernel<range_fwd, xpu>::Launch(s, indices_len, 1, 0, 1, kWriteTo, order_ptr);
+      mxnet::op::SortByKey(indices, order, true, &temp_storage, 0, num_bits, &sorted_indices);
+      Kernel<IndicesModify, xpu>::Launch(s, indices_len, indices_ptr, order_ptr);

-          mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, outshape[axis], is_insert);
-          Kernel<SetIsInsert, xpu>::Launch(s, indices_len, indices_ptr, is_insert);
+      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(s, outshape[axis], is_insert);
+      Kernel<SetIsInsert, xpu>::Launch(s, indices_len, indices_ptr, is_insert);

-          Kernel<SetOriginValuesIdx, xpu>::Launch(s, indices_len, indices_ptr, origin_idx);
-          Kernel<SetOriginArrIdx, xpu>::Launch(s, outshape[axis], is_insert, origin_idx);
-          if (param.val.has_value()) {
-            Kernel<InsertSeqIndicesForward<ndim>, xpu>::Launch(
-              s, outshape.Size(),
-              outputs[out_pos].dptr<DType>(),
-              param.val.value(), arr.dptr<DType>(),
-              k_outshape, is_insert, origin_idx,
-              arr_strides, out_strides, axis, req[out_pos]);
-          } else {
+      Kernel<SetOriginValuesIdx, xpu>::Launch(s, indices_len, indices_ptr, origin_idx);
+      Kernel<SetOriginArrIdx, xpu>::Launch(s, outshape[axis], is_insert, origin_idx);
+      if (param.val.has_value()) {
+        MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+          Kernel<InsertSeqIndicesForward<ndim>, xpu>::Launch(
+            s, outshape.Size(),
+            outputs[out_pos].dptr<DType>(),
+            param.val.value(), arr.dptr<DType>(),
+            k_outshape, is_insert, origin_idx,
+            arr_strides, out_strides, axis, req[out_pos]);
+        });
+      } else {
+        MSHADOW_TYPE_SWITCH(outputs[out_pos].type_flag_, DType, {
+          MSHADOW_TYPE_SWITCH(vtype, VType, {
             Kernel<InsertSeqIndicesForward<ndim>, xpu>::Launch(
               s, outshape.Size(),
               outputs[out_pos].dptr<DType>(),
               values.dptr<VType>(), arr.dptr<DType>(),
               k_outshape, k_valshape, is_insert, origin_idx,
               val_strides, arr_strides, out_strides, axis, req[out_pos]);
-          }
-        }
-      });
-    });
+          });
+        });
+      }
+    }
   });
 }
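What "reduce code" buys is visible in the hunk: before the change, the MSHADOW_TYPE_SWITCH pair wrapped the whole if/else chain, so every supported (DType, VType) combination re-expanded all of it, including type-independent work such as the broadcast check, the workspace allocation and the index sorting. After the change the switches sit inside each branch, around only the statements that actually need the concrete types, and the scalar-value path needs just the single VType switch for Fill. The sketch below is a minimal standalone illustration of that pattern; it is not MXNet code: TYPE_SWITCH only imitates the shape of MSHADOW_TYPE_SWITCH, and prepare_indices / launch_insert are hypothetical stand-ins for the untyped setup and the typed kernel launch.

// Minimal standalone sketch of the pattern applied in this commit; not MXNet code.
// TYPE_SWITCH imitates the shape of MSHADOW_TYPE_SWITCH(type, DType, {...});
// prepare_indices() and launch_insert<DType>() are hypothetical stand-ins.
#include <cstdint>
#include <cstdio>

#define TYPE_SWITCH(type_flag, DType, ...)                 \
  switch (type_flag) {                                     \
    case 0: { typedef float DType;   __VA_ARGS__ } break;  \
    case 1: { typedef double DType;  __VA_ARGS__ } break;  \
    case 2: { typedef int32_t DType; __VA_ARGS__ } break;  \
    default: std::printf("unsupported type\n"); break;     \
  }

// Type-independent preparation (in the real operator: broadcast check,
// workspace allocation, SortByKey over the insertion indices).
void prepare_indices() { std::printf("untyped setup\n"); }

// The only step that actually needs the element type.
template <typename DType>
void launch_insert() { std::printf("typed launch, %zu-byte elements\n", sizeof(DType)); }

// Before: the switch wraps everything, so the untyped setup is re-expanded
// once per supported type.
void insert_switch_outside(int type_flag) {
  TYPE_SWITCH(type_flag, DType, {
    prepare_indices();
    launch_insert<DType>();
  });
}

// After: untyped work stays outside; only the launch is stamped out per type,
// which is what shrinks the generated code.
void insert_switch_inside(int type_flag) {
  prepare_indices();
  TYPE_SWITCH(type_flag, DType, {
    launch_insert<DType>();
  });
}

int main() {
  insert_switch_outside(0);
  insert_switch_inside(1);
  return 0;
}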

