Skip to content

Commit

Permalink
Fix the topk regression issue (apache#12197) (apache#12202)
Browse files Browse the repository at this point in the history
* Fix the topk regression issue (apache#12197)

* Add comments
  • Loading branch information
ciyongch authored and szha committed Aug 16, 2018
1 parent 82382ef commit 53ccc66
Showing 1 changed file with 26 additions and 22 deletions.
48 changes: 26 additions & 22 deletions src/operator/tensor/ordering_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,8 @@ void TopKImpl(RunContext ctx,
// 3. Assign results to the ret blob
// When returning indices, only update(modulo) required elements instead of full elements
// to avoid redundant calculation.
// Cast `ret_indices` from int to real_t could introduce conversion error when the element_num
// is large enough.
if (param.ret_typ == topk_enum::kReturnMask) {
Tensor<xpu, 2, real_t> ret_mask =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(ret[0].Size(), 1), s);
Expand All @@ -452,20 +454,21 @@ void TopKImpl(RunContext ctx,
} else if (param.ret_typ == topk_enum::kReturnIndices) {
if (do_transpose) {
Tensor<xpu, 3, real_t> ret_indices = ret[0].FlatTo3D<xpu, real_t>(axis, axis, s);
ret_indices = tcast<real_t>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k),
Shape3(0, 2, 1)));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
ret_indices = tcast<real_t>(F<mshadow_op::mod>(
transpose(slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k),
Shape3(0, 2, 1)),
element_num));
} else {
Tensor<xpu, 2, real_t> ret_indices =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
ret_indices = tcast<real_t>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
ret_indices = tcast<real_t>(F<mshadow_op::mod>(
slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
0, k),
element_num));
}
} else {
if (do_transpose) {
Expand All @@ -476,23 +479,24 @@ void TopKImpl(RunContext ctx,
Shape3(ret_value.shape_[0], ret_value.shape_[2], element_num)),
0, k),
Shape3(0, 2, 1));
ret_indices = tcast<real_t>(transpose(
slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k),
Shape3(0, 2, 1)));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
ret_indices = tcast<real_t>(F<mshadow_op::mod>(
transpose(slice<2>(inplace_reshape(indices,
Shape3(ret_indices.shape_[0],
ret_indices.shape_[2],
element_num)),
0, k),
Shape3(0, 2, 1)),
element_num));
} else {
Tensor<xpu, 2, real_t> ret_value =
ret[0].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
Tensor<xpu, 2, real_t> ret_indices =
ret[1].get_with_shape<xpu, 2, real_t>(Shape2(batch_size, k), s);
ret_value = slice<1>(inplace_reshape(sorted_dat, Shape2(batch_size, element_num)), 0, k);
ret_indices = tcast<real_t>(slice<1>(
inplace_reshape(indices, Shape2(batch_size, element_num)), 0, k));
ret_indices = F<mshadow_op::mod>(ret_indices, element_num);
ret_indices = tcast<real_t>(F<mshadow_op::mod>(
slice<1>(inplace_reshape(indices, Shape2(batch_size, element_num)),
0, k),
element_num));
}
}
}
Expand Down

0 comments on commit 53ccc66

Please sign in to comment.