diff --git a/tests/python/unittest/test_executor.py b/tests/python/unittest/test_executor.py index 3117f6646481..2bc696fd4e43 100644 --- a/tests/python/unittest/test_executor.py +++ b/tests/python/unittest/test_executor.py @@ -72,7 +72,7 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None): assert_almost_equal(rhs_grad.asnumpy(), rhs_grad2, rtol=1e-5, atol=1e-5) -@with_seed(0) +@with_seed() def test_bind(): def check_bind(disable_bulk_exec): if disable_bulk_exec: @@ -97,11 +97,11 @@ def check_bind(disable_bulk_exec): dim) check_bind_with_uniform(lambda x, y: np.maximum(x, y), - lambda g, x, y: (g * (x>y), g * (y>x)), + lambda g, x, y: (g * (x>=y), g * (y>x)), dim, sf=mx.symbol.maximum) check_bind_with_uniform(lambda x, y: np.minimum(x, y), - lambda g, x, y: (g * (x