Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix test error
Browse files Browse the repository at this point in the history
  • Loading branch information
Ubuntu committed Jun 16, 2020
1 parent c251837 commit 14e4b9e
Showing 1 changed file with 8 additions and 5 deletions.
13 changes: 8 additions & 5 deletions tests/python/unittest/test_numpy_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1505,9 +1505,14 @@ def check_index_update_forward(mx_ret, a, ind, val, ind_ndim, ind_num, eps):
else:
expect_tmp = val
mx_tmp = mx_ret[t_ind]
if _np.allclose(expect_tmp, mx_tmp, rtol=eps, atol=eps):
mx_ret[t_ind] = 0
a[t_ind] = 0
close_pos = _np.where(_np.isclose(expect_tmp, mx_tmp, rtol=eps, atol=eps))
if a[t_ind].ndim == 0:
if close_pos[0].size == 1:
mx_ret[t_ind] = 0
a[t_ind] = 0
else:
mx_ret[t_ind][close_pos] = 0
a[t_ind][close_pos] = 0
assert_almost_equal(mx_ret, a, rtol=eps, atol=eps)

def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_req_a, grad_req_val):
Expand Down Expand Up @@ -1585,8 +1590,6 @@ def index_update_bwd(out_grad, a_grad, ind, val_grad, ind_ndim, ind_num, grad_re
itertools.product([True, False], grad_req, grad_req, dtypes, ['int32', 'int64']):
for a_shape, ind, val_shape ,ind_ndim, ind_num in configs:
eps = 1e-3
if sys.platform.startswith('linux'):
eps = 1e-2
atype = dtype
valtype = dtype
test_index_update = TestIndexUpdate()
Expand Down

0 comments on commit 14e4b9e

Please sign in to comment.