[AMP OP&Test] where support bf16/fp16 (#51137)
* where op test

* update bfloat16

* fix

* fix windows ci

* update bfloat16 data

* fix bfloat16 x

* reset

* fix randint

* add print

* add delta

* cancel print

* code style

* update review
yangjianfengo1 authored Mar 9, 2023
1 parent 86e990d commit 2727ddd
Showing 1 changed file with 46 additions and 2 deletions.
python/paddle/fluid/tests/unittests/test_where_op.py
@@ -15,11 +15,11 @@
 import unittest
 
 import numpy as np
-from op_test import OpTest
+from op_test import OpTest, convert_float_to_uint16
 
 import paddle
 import paddle.fluid as fluid
-from paddle.fluid import Program, program_guard
+from paddle.fluid import Program, core, program_guard
 from paddle.fluid.backward import append_backward
 
 
@@ -50,6 +50,50 @@ def init_config(self):
         self.cond = np.ones((60, 2)).astype('bool')
 
 
+class TestWhereFP16OP(TestWhereOp):
+    def init_config(self):
+        self.dtype = np.float16
+        self.x = np.random.uniform((-5), 5, (60, 2)).astype(self.dtype)
+        self.y = np.random.uniform((-5), 5, (60, 2)).astype(self.dtype)
+        self.cond = np.ones((60, 2)).astype('bool')
+
+
+@unittest.skipIf(
+    not core.is_compiled_with_cuda()
+    or not core.is_bfloat16_supported(core.CUDAPlace(0)),
+    "core is not compiled with CUDA or does not support bfloat16",
+)
+class TestWhereBF16OP(OpTest):
+    def setUp(self):
+        self.op_type = 'where'
+        self.dtype = np.uint16
+        self.python_api = paddle.where
+        self.init_config()
+        self.inputs = {
+            'Condition': self.cond,
+            'X': convert_float_to_uint16(self.x),
+            'Y': convert_float_to_uint16(self.y),
+        }
+        self.outputs = {
+            'Out': convert_float_to_uint16(np.where(self.cond, self.x, self.y))
+        }
+
+    def test_check_output(self):
+        place = core.CUDAPlace(0)
+        self.check_output_with_place(place, check_eager=False)
+
+    def test_check_grad(self):
+        place = core.CUDAPlace(0)
+        self.check_grad_with_place(
+            place, ['X', 'Y'], 'Out', check_eager=False, numeric_grad_delta=0.05
+        )
+
+    def init_config(self):
+        self.x = np.random.uniform((-5), 5, (60, 2)).astype(np.float32)
+        self.y = np.random.uniform((-5), 5, (60, 2)).astype(np.float32)
+        self.cond = np.random.randint(2, size=(60, 2)).astype('bool')
+
+
 class TestWhereOp3(TestWhereOp):
     def init_config(self):
         self.x = np.random.uniform((-3), 5, (20, 2, 4)).astype('float64')
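For reference, a minimal usage sketch of what the new TestWhereFP16OP exercises: paddle.where selecting between two float16 tensors. This is an illustrative example, not part of the commit, and it assumes a CUDA build of Paddle, since the fp16 where kernel targets GPU.

    import numpy as np
    import paddle

    # where() picks elements from x where cond is True and from y elsewhere;
    # the output keeps the input dtype, so fp16 in gives fp16 out.
    cond = paddle.to_tensor(np.random.randint(2, size=(60, 2)).astype('bool'))
    x = paddle.to_tensor(np.random.uniform(-5, 5, (60, 2)).astype('float16'))
    y = paddle.to_tensor(np.random.uniform(-5, 5, (60, 2)).astype('float16'))
    out = paddle.where(cond, x, y)
    assert out.dtype == paddle.float16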

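The bfloat16 test stores its data as uint16 via op_test.convert_float_to_uint16. A minimal sketch of the idea, assuming simple truncation (the real helper may round to nearest, and these stand-alone functions are hypothetical): bfloat16 keeps float32's sign and exponent bits plus the top 7 mantissa bits, so the uint16 payload is just the high half of the float32 bit pattern.

    import numpy as np

    def float32_to_bf16_bits(x):
        # Hypothetical helper: reinterpret the float32 bits and keep the high 16.
        x = np.ascontiguousarray(x, dtype=np.float32)
        return (x.view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_float32(bits):
        # Inverse: shift the stored bits back into the high half of a float32.
        return (bits.astype(np.uint32) << 16).view(np.float32)

    x = np.random.uniform(-5, 5, (60, 2)).astype(np.float32)
    x_rt = bf16_bits_to_float32(float32_to_bf16_bits(x))
    # bf16 carries only ~3 significant decimal digits, which is likely why
    # TestWhereBF16OP above loosens the gradient check with numeric_grad_delta=0.05.
    assert np.max(np.abs(x - x_rt)) < 0.05 * np.max(np.abs(x))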