Skip to content

Commit

Permalink
[cherry-pick 2.6] Fix bug of put_along_axis/take_along_axis (#62065)
Browse files Browse the repository at this point in the history
* 【Hackathon 5th No.6】 为 Paddle 增强put_along_axis API -part (#59674)

* fix bug of put_along_axis (#60551)

* Improve the performence of put_along_axis (#60618)

* fix bug of put_along_axis

* improve performence of put_along_axis

* [Bug-Fix] fix compile bug of cudaxxxAsync (#60934)

---------

Co-authored-by: YibLiu <[email protected]>
  • Loading branch information
zhwesky2010 and YibinLiu666 authored Feb 28, 2024
1 parent 609f55e commit 3a083c3
Show file tree
Hide file tree
Showing 21 changed files with 2,822 additions and 177 deletions.
4 changes: 2 additions & 2 deletions paddle/phi/api/yaml/backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1760,8 +1760,8 @@
optional : boxes_num

- backward_op : put_along_axis_grad
forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign") -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor out_grad, int axis, str reduce)
forward : put_along_axis (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true) -> Tensor(out)
args : (Tensor arr, Tensor indices, Tensor values, Tensor out, Tensor out_grad, int axis, str reduce, bool include_self)
output : Tensor(arr_grad), Tensor(values_grad)
infer_meta :
func : GeneralBinaryGradInferMeta
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/op_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2432,7 +2432,7 @@
outputs :
out : Result
attrs :
{axis : Axis, reduce : Reduce}
{axis : Axis, reduce : Reduce, include_self: Include_self}

- op : pylayer
backward : pylayer_grad
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2032,7 +2032,7 @@
backward : psroi_pool_grad

- op : put_along_axis
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign")
args : (Tensor arr, Tensor indices, Tensor values, int axis, str reduce = "assign", bool include_self = true)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
Expand Down
176 changes: 176 additions & 0 deletions paddle/phi/backends/gpu/gpu_primitives.h
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,182 @@ CUDA_ATOMIC_WRAPPER(Add, complex<double>) {
CudaAtomicAdd(imag, val.imag));
}

// For atomicMul.
CUDA_ATOMIC_WRAPPER(Mul, int) {
int res = *address, old = res; // NOLINT
do {
old = res;
res = atomicCAS(address, // NOLINT
old, // NOLINT
val * old); // NOLINT
} while (old != res);
return res;
}

CUDA_ATOMIC_WRAPPER(Mul, unsigned int) {
unsigned int res = *address, old = res; // NOLINT
do {
old = res;
res = atomicCAS(address, // NOLINT
old, // NOLINT
val * old); // NOLINT
} while (old != res);
return res;
}
// CUDA API uses unsigned long long int, we cannot use uint64_t here.
// It because unsigned long long int is not necessarily uint64_t
CUDA_ATOMIC_WRAPPER(Mul, unsigned long long int) { // NOLINT
unsigned long long int old = *address, assumed; // NOLINT

do {
assumed = old;
old = atomicCAS(address, assumed, val * assumed);
} while (assumed != old);
return old;
}

CUDA_ATOMIC_WRAPPER(Mul, int64_t) {
// Here, we check long long int must be int64_t.
static_assert(sizeof(int64_t) == sizeof(long long int), // NOLINT
"long long should be int64");
long long int res = *address, old = res; // NOLINT
do {
old = res;
res = (long long int)atomicCAS( // NOLINT
(unsigned long long int *)address, // NOLINT
(unsigned long long int)old, // NOLINT
(unsigned long long int)val * (unsigned long long int)old); // NOLINT
} while (old != res);
return res;
}

CUDA_ATOMIC_WRAPPER(Mul, float) {
int *const address_as_i = reinterpret_cast<int *>(address);
int old = *address_as_i, assumed;

do {
assumed = old;
old = atomicCAS(
address_as_i, assumed, __float_as_int(val * __int_as_float(assumed)));
} while (assumed != old);

return __int_as_float(old);
}

CUDA_ATOMIC_WRAPPER(Mul, double) {
unsigned long long int *const address_as_ull = // NOLINT
reinterpret_cast<unsigned long long int *>(address); // NOLINT
unsigned long long int old = *address_as_ull, assumed; // NOLINT

do {
assumed = old;

old = atomicCAS(address_as_ull,
assumed,
__double_as_longlong(val * __longlong_as_double(assumed)));
} while (assumed != old);

return __longlong_as_double(old);
}

#ifdef PADDLE_CUDA_FP16
inline static __device__ uint32_t mul_to_low_half(uint32_t val, float x) {
phi::dtype::float16 low_half;
// The float16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half = static_cast<phi::dtype::float16>(static_cast<float>(low_half) * x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t mul_to_high_half(uint32_t val, float x) {
phi::dtype::float16 high_half;
// The float16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::float16>(static_cast<float>(high_half) * x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::float16) {
if (*address >= val) {
return *address;
}
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The float16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, mul_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::float16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The float16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(address_as_ui, assumed, mul_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::float16 ret;
ret.x = old >> 16;
return ret;
}
}
#endif

inline static __device__ uint32_t bf16_mul_to_low_half(uint32_t val, float x) {
phi::dtype::bfloat16 low_half;
// The bfloat16 in lower 16bits
low_half.x = static_cast<uint16_t>(val & 0xFFFFu);
low_half =
static_cast<phi::dtype::bfloat16>(static_cast<float>(low_half) * x);
return (val & 0xFFFF0000u) | low_half.x;
}

inline static __device__ uint32_t bf16_mul_to_high_half(uint32_t val, float x) {
phi::dtype::bfloat16 high_half;
// The bfloat16 in higher 16bits
high_half.x = static_cast<uint16_t>(val >> 16);
high_half =
static_cast<phi::dtype::bfloat16>(static_cast<float>(high_half) * x);
return (val & 0xFFFFu) | (static_cast<uint32_t>(high_half.x) << 16);
}

CUDA_ATOMIC_WRAPPER(Mul, phi::dtype::bfloat16) {
uint32_t *address_as_ui = reinterpret_cast<uint32_t *>(
reinterpret_cast<char *>(address) -
(reinterpret_cast<uintptr_t>(address) & 0x02));
float val_f = static_cast<float>(val);
uint32_t old = *address_as_ui;
uint32_t assumed;
if (((uintptr_t)address & 0x02) == 0) {
// The bfloat16 value stay at lower 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_mul_to_low_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old & 0xFFFFu;
return ret;
} else {
// The bfloat16 value stay at higher 16 bits of the address.
do {
assumed = old;
old = atomicCAS(
address_as_ui, assumed, bf16_mul_to_high_half(assumed, val_f));
} while (old != assumed);
phi::dtype::bfloat16 ret;
ret.x = old >> 16;
return ret;
}
}

// For atomicMax
USE_CUDA_ATOMIC(Max, int);
USE_CUDA_ATOMIC(Max, unsigned int);
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/kernels/cpu/cum_maxmin_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ void CummaxGradKernel(const Context& dev_ctx,
}
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
}
}

Expand All @@ -61,10 +61,10 @@ void CumminGradKernel(const Context& dev_ctx,
}
if (dtype == DataType::INT32) {
phi::funcs::cpu_scatter_add_kernel<T, int32_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
} else if (dtype == DataType::INT64) {
phi::funcs::cpu_scatter_add_kernel<T, int64_t>(
*x_grad, axis, indices, out_grad, dev_ctx);
*x_grad, axis, indices, out_grad, true, dev_ctx);
}
}

Expand Down
Loading

0 comments on commit 3a083c3

Please sign in to comment.