Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix test file
Browse files Browse the repository at this point in the history
  • Loading branch information
gyshi committed Sep 2, 2019
1 parent 9da384c commit bea9be8
Showing 1 changed file with 22 additions and 54 deletions.
76 changes: 22 additions & 54 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1522,69 +1522,37 @@ def hybrid_forward(self, F, a):
@with_seed()
@use_np
def test_np_windows():
class TestBlackman(HybridBlock):
def __init__(self, M, dtype):
super(TestBlackman, self).__init__()
class TestWindows(HybridBlock):
def __init__(self, func, M, dtype):
super(TestWindows, self).__init__()
self._func = func
self._M = M
self._dtype = dtype

def hybrid_forward(self, F, x, *args, **kwargs):
return x + F.np.blackman(M=self._M, dtype=self._dtype)

@use_np
class TestHamming(HybridBlock):
def __init__(self, M, dtype):
super(TestHamming, self).__init__()
self._M = M
self._dtype = dtype

def hybrid_forward(self, F, x, *args, **kwargs):
return x + F.np.hamming(M=self._M, dtype=self._dtype)

@use_np
class TestHanning(HybridBlock):
def __init__(self, M, dtype):
super(TestHanning, self).__init__()
self._M = M
self._dtype = dtype

def hybrid_forward(self, F, x, *args, **kwargs):
return x + F.np.hanning(M=self._M, dtype=self._dtype)
op = getattr(F.np, self._func)
assert op is not None
return x + op(M=self._M, dtype=self._dtype)

configs = [-10, -3, -1, 0, 1, 6, 10, 20]
dtypes = ['float32', 'float64']

funcs = ['hanning', 'hamming', 'blackman']
for config in configs:
for dtype in dtypes:
x = np.zeros(shape=(), dtype=dtype)
for hybridize in [False, True]:
net_hanning = TestHanning(M=config, dtype=dtype)
net_hamming = TestHamming(M=config, dtype=dtype)
net_blackman = TestBlackman(M=config, dtype=dtype)
np_out_hanning = _np.hanning(M=config)
np_out_hamming = _np.hamming(M=config)
np_out_blackman = _np.blackman(M=config)
if hybridize:
net_hanning.hybridize()
net_hamming.hybridize()
net_blackman.hybridize()

mx_out_hanning = net_hanning(x)
mx_out_hamming = net_hamming(x)
mx_out_blackman = net_blackman(x)
assert_almost_equal(mx_out_hanning.asnumpy(), np_out_hanning, rtol=1e-3, atol=1e-5)
assert_almost_equal(mx_out_hamming.asnumpy(), np_out_hamming, rtol=1e-3, atol=1e-5)
assert_almost_equal(mx_out_blackman.asnumpy(), np_out_blackman, rtol=1e-3, atol=1e-5)
# test imperative
mx_out_hanning = np.hanning(M=config, dtype=dtype)
mx_out_hamming = np.hamming(M=config, dtype=dtype)
mx_out_blackman = np.blackman(M=config, dtype=dtype)
np_out_hanning = _np.hanning(M=config)
np_out_hamming = _np.hamming(M=config)
np_out_blackman = _np.blackman(M=config)
assert_almost_equal(mx_out_hanning.asnumpy(), np_out_hanning, rtol=1e-3, atol=1e-5)
assert_almost_equal(mx_out_hamming.asnumpy(), np_out_hamming, rtol=1e-3, atol=1e-5)
assert_almost_equal(mx_out_blackman.asnumpy(), np_out_blackman, rtol=1e-3, atol=1e-5)
for func in funcs:
x = np.zeros(shape=(), dtype=dtype)
for hybridize in [False, True]:
np_func = getattr(_np, func)
mx_func = TestWindows(func, M=config, dtype=dtype)
np_out = np_func(M=config).astype(dtype)
if hybridize:
mx_func.hybridize()
mx_out = mx_func(x)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)
# test imperative
mx_out = getattr(np, func)(M=config, dtype=dtype)
np_out = np_func(M=config).astype(dtype)
assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-3, atol=1e-5)


if __name__ == '__main__':
Expand Down

0 comments on commit bea9be8

Please sign in to comment.