Skip to content

Commit

Permalink
add stable to argsort and sort
Browse files Browse the repository at this point in the history
  • Loading branch information
NKNaN committed Apr 15, 2024
1 parent 5322bd0 commit 1eb8f80
Show file tree
Hide file tree
Showing 13 changed files with 324 additions and 31 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@
func : angle_grad

- backward_op : argsort_grad
forward : argsort (Tensor x, int axis, bool descending) -> Tensor(out), Tensor(indices)
args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending)
forward : argsort (Tensor x, int axis, bool descending, bool stable) -> Tensor(out), Tensor(indices)
args : (Tensor indices, Tensor x, Tensor out_grad, int axis, bool descending, bool stable)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : argsort
args : (Tensor x, int axis=-1, bool descending=false)
args : (Tensor x, int axis=-1, bool descending=false, bool stable=false)
output : Tensor(out), Tensor(indices)
infer_meta :
func : ArgsortInferMeta
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
void ArgsortInferMeta(const MetaTensor& input,
int axis,
bool descending,
bool stable,
MetaTensor* output,
MetaTensor* indices) {
auto in_dims = input.dims();
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/infermeta/unary.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
void ArgsortInferMeta(const MetaTensor& input,
int axis,
bool descending,
bool stable,
MetaTensor* output,
MetaTensor* indices);

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/argsort_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending,
bool stable,
DenseTensor* in_grad);

} // namespace phi
4 changes: 4 additions & 0 deletions paddle/phi/kernels/argsort_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace phi {
* algorithm how to sort the input data.
* If descending is true, will sort by descending order,
* else if false, sort by ascending order
* @param stable Indicate whether to use stable sorting algorithm, which
* guarantees that the order of equivalent elements is
* preserved.
* @param out The sorted tensor of Argsort op, with the same shape as
* x
* @param indices The indices of a tensor giving the sorted order, with
Expand All @@ -43,6 +46,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices);

Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/cpu/argsort_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending UNUSED,
bool stable UNUSED,
DenseTensor* in_grad) {
auto in_dims = indices.dims();
auto rank = input.dims().size();
Expand Down
50 changes: 35 additions & 15 deletions paddle/phi/kernels/cpu/argsort_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ static void FullSort(Type input_height,
const DenseTensor* input,
T* t_out,
Type* t_indices,
bool descending) {
bool descending,
bool stable) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
Expand All @@ -48,18 +49,34 @@ static void FullSort(Type input_height,
col_vec.push_back(std::pair<T, Type>(e_input(i, j), j));
}
}
std::sort(col_vec.begin(),
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
if (stable) {
std::stable_sort(
col_vec.begin(),
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
} else {
std::sort(col_vec.begin(),
col_vec.end(),
[&](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
if (descending)
return (std::isnan(static_cast<double>(l.first)) &&
!std::isnan(static_cast<double>(r.first))) ||
(l.first > r.first);
else
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
}

for (Type j = 0; j < input_width; ++j) {
t_out[i * input_width + j] = col_vec[j].first;
Expand All @@ -73,6 +90,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
Expand Down Expand Up @@ -100,7 +118,8 @@ void ArgsortKernel(const Context& dev_ctx,
&input,
out_data,
ids_data,
descending);
descending,
stable);
} else {
// If not full sort do transpose
std::vector<int> trans;
Expand Down Expand Up @@ -141,7 +160,8 @@ void ArgsortKernel(const Context& dev_ctx,
&trans_inp,
t_out,
t_ind,
descending);
descending,
stable);

dev_ctx.template Alloc<int64_t>(indices);
TransposeKernel<int64_t, Context>(dev_ctx, tmp_indices, trans, indices);
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/kernels/gpu/argsort_grad_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ void ArgsortGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
int axis,
bool descending,
bool stable,
DenseTensor* in_grad) {
dev_ctx.template Alloc<T>(in_grad);
phi::funcs::set_constant(dev_ctx, in_grad, static_cast<T>(0.0));
Expand Down
31 changes: 24 additions & 7 deletions paddle/phi/kernels/gpu/argsort_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ void ArgsortKernel(const Context& dev_ctx,
const DenseTensor& input,
int axis,
bool descending,
bool stable,
DenseTensor* output,
DenseTensor* indices) {
auto in_dims = input.dims();
Expand All @@ -251,14 +252,30 @@ void ArgsortKernel(const Context& dev_ctx,
// Compared to the following 'Special case for full sort', ascending sort is
// 34 times faster and descending sort is 31 times faster.
if (size == in_dims[axis]) {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data);
if (descending) {
thrust::reverse(thrust::device, out_data, out_data + size);
thrust::reverse(thrust::device, ids_data, ids_data + size);
if (stable) {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
if (descending) {
thrust::stable_sort_by_key(thrust::device,
out_data,
out_data + size,
ids_data,
thrust::greater<KeyT>());
} else {
thrust::stable_sort_by_key(
thrust::device, out_data, out_data + size, ids_data);
}
return;
} else {
thrust::sequence(thrust::device, ids_data, ids_data + size);
thrust::copy(thrust::device, in_data, in_data + size, out_data);
thrust::sort_by_key(thrust::device, out_data, out_data + size, ids_data);
if (descending) {
thrust::reverse(thrust::device, out_data, out_data + size);
thrust::reverse(thrust::device, ids_data, ids_data + size);
}
return;
}
return;
}

// Special case for full sort, speedup ~190x.
Expand Down
18 changes: 12 additions & 6 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
__all__ = []


def argsort(x, axis=-1, descending=False, name=None):
def argsort(x, axis=-1, descending=False, stable=False, name=None):
"""
Sorts the input along the given axis, and returns the corresponding index tensor for the sorted output values. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
Expand All @@ -49,6 +49,9 @@ def argsort(x, axis=-1, descending=False, name=None):
descending (bool, optional) : Descending is a flag, if set to true,
algorithm will sort by descending order, else sort by
ascending order. Default is false.
stable (bool, optional): Whether to use stable sorting algorithm or not.
When using stable sorting algorithm, the order of equivalent elements
will be preserved. Default is False.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Expand Down Expand Up @@ -100,7 +103,7 @@ def argsort(x, axis=-1, descending=False, name=None):
[0, 2, 1, 1]]])
"""
if in_dynamic_or_pir_mode():
_, ids = _C_ops.argsort(x, axis, descending)
_, ids = _C_ops.argsort(x, axis, descending, stable)
return ids
else:
check_variable_and_dtype(
Expand Down Expand Up @@ -129,7 +132,7 @@ def argsort(x, axis=-1, descending=False, name=None):
type='argsort',
inputs={'X': x},
outputs={'Out': out, 'Indices': ids},
attrs={'axis': axis, 'descending': descending},
attrs={'axis': axis, 'descending': descending, 'stable': stable},
)
return ids

Expand Down Expand Up @@ -500,7 +503,7 @@ def nonzero(x, as_tuple=False):
return tuple(list_out)


def sort(x, axis=-1, descending=False, name=None):
def sort(x, axis=-1, descending=False, stable=False, name=None):
"""
Sorts the input along the given axis, and returns the sorted output tensor. The default sort algorithm is ascending, if you want the sort algorithm to be descending, you must set the :attr:`descending` as True.
Expand All @@ -514,6 +517,9 @@ def sort(x, axis=-1, descending=False, name=None):
descending (bool, optional) : Descending is a flag, if set to true,
algorithm will sort by descending order, else sort by
ascending order. Default is false.
stable (bool, optional): Whether to use stable sorting algorithm or not.
When using stable sorting algorithm, the order of equivalent elements
will be preserved. Default is False.
name (str, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
Returns:
Expand Down Expand Up @@ -557,7 +563,7 @@ def sort(x, axis=-1, descending=False, name=None):
[5. 7. 7. 9.]]]
"""
if in_dynamic_or_pir_mode():
outs, _ = _C_ops.argsort(x, axis, descending)
outs, _ = _C_ops.argsort(x, axis, descending, stable)
return outs
else:
helper = LayerHelper("sort", **locals())
Expand All @@ -571,7 +577,7 @@ def sort(x, axis=-1, descending=False, name=None):
type='argsort',
inputs={'X': x},
outputs={'Out': out, 'Indices': ids},
attrs={'axis': axis, 'descending': descending},
attrs={'axis': axis, 'descending': descending, 'stable': stable},
)
return out

Expand Down
Loading

0 comments on commit 1eb8f80

Please sign in to comment.