【PaddlePaddle Hackathon 4】No.56: add fp16 and bf16 support for bernoulli (#54232)
* add fp16 & bf16 bernoulli

* add check_dtype & fix error

* fix rocm error
Difers authored Jun 2, 2023
1 parent 17d6d93 commit 85d5f26
Showing 3 changed files with 59 additions and 7 deletions.
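The change extends paddle.bernoulli beyond float32/float64 on GPU. A minimal usage sketch of what this enables (an assumption here: a CUDA build, since the new float16/bfloat16 kernels are registered for GPU only):

import paddle

x = paddle.rand([2, 3]).astype('float16')  # probabilities in [0, 1)
y = paddle.bernoulli(x)                    # elementwise 0/1 draws; dtype follows x
print(y.dtype)                             # paddle.float16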
15 changes: 12 additions & 3 deletions paddle/phi/kernels/gpu/bernoulli_kernel.cu
@@ -26,6 +26,7 @@
 
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/dense_tensor.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
@@ -51,11 +52,13 @@ __global__ void bernoulli_cuda_kernel(
   for (size_t i = 4 * thread_idx; i < size; i += total_thread * 4) {
     funcs::uniform_distribution<float> dist;
     float4 rand = dist(&state);
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
 #pragma unroll
     for (size_t j = 0; j < 4; j++) {
       size_t idx = i + j;
       if (idx < size) {
-        out_data[idx] = static_cast<T>((&rand.x)[j] <= x_data[idx]);
+        out_data[idx] =
+            static_cast<T>((&rand.x)[j] <= static_cast<MPType>(x_data[idx]));
       }
     }
   }
@@ -85,5 +88,11 @@ void BernoulliKernel(const Context& ctx,
 
 }  // namespace phi
 
-PD_REGISTER_KERNEL(
-    bernoulli, GPU, ALL_LAYOUT, phi::BernoulliKernel, float, double) {}
+PD_REGISTER_KERNEL(bernoulli,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::BernoulliKernel,
+                   phi::dtype::float16,
+                   phi::dtype::bfloat16,
+                   float,
+                   double) {}
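For half-precision inputs the kernel widens the comparison: MPTypeTrait maps float16/bfloat16 to a higher-precision compute type (float, by the usual amp_type_traits convention), so the uniform draw is compared against the probability at full precision before the 0/1 result is narrowed back to T. A NumPy sketch of one such comparison, assuming MPType resolves to float32 (variable names here are illustrative only):

import numpy as np

x_fp16 = np.float16(0.3)                      # probability stored in fp16
rand = np.float32(np.random.uniform())        # uniform draw is generated as float
out = np.float16(rand <= np.float32(x_fp16))  # compare in fp32, store 0/1 in fp16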
4 changes: 3 additions & 1 deletion python/paddle/tensor/random.py
@@ -77,7 +77,9 @@ def bernoulli(x, name=None):
     if in_dynamic_mode():
         return _C_ops.bernoulli(x)
     else:
-        check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli")
+        check_variable_and_dtype(
+            x, "x", ["float32", "float64", "float16", "uint16"], "bernoulli"
+        )
 
         helper = LayerHelper("randint", **locals())
         out = helper.create_variable_for_type_inference(
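The static-graph whitelist gains "float16" and "uint16"; "uint16" is how Paddle spells bfloat16 in static-graph dtype checks. A short static-mode sketch (graph construction only; actually running it would need a GPU executor):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name='x', shape=[8, 8], dtype='float16')
    y = paddle.bernoulli(x)  # passes the widened check_variable_and_dtype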
47 changes: 44 additions & 3 deletions test/legacy_test/test_bernoulli_op.py
@@ -15,9 +15,10 @@
 import unittest
 
 import numpy as np
-from eager_op_test import OpTest
+from eager_op_test import OpTest, convert_float_to_uint16
 
 import paddle
+from paddle.fluid import core
 
 
 def output_hist(out):
@@ -31,9 +32,18 @@
 class TestBernoulliOp(OpTest):
     def setUp(self):
         self.op_type = "bernoulli"
-        self.inputs = {"X": np.random.uniform(size=(1000, 784))}
+        self.init_dtype()
+        self.init_test_case()
+        self.inputs = {"X": self.x}
         self.attrs = {}
-        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
+        self.outputs = {"Out": self.out}
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def init_test_case(self):
+        self.x = np.random.uniform(size=(1000, 784)).astype(self.dtype)
+        self.out = np.zeros((1000, 784)).astype(self.dtype)
 
     def test_check_output(self):
         self.check_output_customized(self.verify_output)
@@ -98,5 +108,36 @@ def test_fixed_random_number(self):
         paddle.enable_static()
 
 
+class TestBernoulliFP16Op(TestBernoulliOp):
+    def init_dtype(self):
+        self.dtype = np.float16
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestBernoulliBF16Op(TestBernoulliOp):
+    def init_dtype(self):
+        self.dtype = np.uint16
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place_customized(self.verify_output, place)
+
+    def init_test_case(self):
+        self.x = convert_float_to_uint16(
+            np.random.uniform(size=(1000, 784)).astype("float32")
+        )
+        self.out = convert_float_to_uint16(
+            np.zeros((1000, 784)).astype("float32")
+        )
+
+    def verify_output(self, outs):
+        hist, prob = output_hist(np.array(outs[0]))
+        np.testing.assert_allclose(hist, prob, atol=0.01)
+
+
 if __name__ == "__main__":
     unittest.main()
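The bf16 test stores inputs as uint16 bit patterns via convert_float_to_uint16, and verify_output compares the empirical output histogram against the reference probabilities to atol=0.01, since exact equality is meaningless for random draws. A hedged sketch of the bit-pattern convention (float32_to_bf16_bits is a hypothetical helper; plain truncation here, whereas the real convert_float_to_uint16 may round):

import numpy as np

def float32_to_bf16_bits(a):
    a = np.ascontiguousarray(a, dtype=np.float32)
    return (a.view(np.uint32) >> 16).astype(np.uint16)  # keep the high 16 bits

print(float32_to_bf16_bits(np.array([0.0, 0.5, 1.0], dtype=np.float32)))
# -> [    0 16128 16256]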
