Skip to content

Commit

Permalink
Use rather than for output equality tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Aug 25, 2023
1 parent cbbc13e commit d2d1c15
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tests/tensor/test_slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,7 @@ def test_perform(self):
scipy_val = scipy.linalg.kron(a[np.newaxis, :], b).flatten()
else:
scipy_val = scipy.linalg.kron(a, b)
utt.assert_allclose(out, scipy_val)
np.testing.assert_allclose(out, scipy_val)

def test_numpy_2d(self):
for shp0 in [(2, 3)]:
Expand Down Expand Up @@ -578,7 +578,7 @@ def test_solve_discrete_lyapunov_via_direct_complex():
A = rng.normal(size=(N, N)) + rng.normal(size=(N, N)) * 1j
Q = rng.normal(size=(N, N))
X = f(A, Q)
utt.assert_allclose(A @ X @ A.conj().T - X + Q, 0.0)
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)

# TODO: the .conj() method currently does not have a gradient; add this test when gradients are implemented.
# utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)
Expand All @@ -596,7 +596,7 @@ def test_solve_discrete_lyapunov_via_bilinear():

X = f(A, Q)

utt.assert_allclose(A @ X @ A.conj().T - X + Q, 0.0)
np.testing.assert_array_less(A @ X @ A.conj().T - X + Q, 1e-12)
utt.verify_grad(solve_discrete_lyapunov, pt=[A, Q], rng=rng)


Expand All @@ -611,7 +611,9 @@ def test_solve_continuous_lyapunov():
Q = rng.normal(size=(N, N))
X = f(A, Q)

utt.assert_allclose(A @ X + X @ A.conj().T, Q)
Q_recovered = A @ X + X @ A.conj().T

np.testing.assert_allclose(Q_recovered.squeeze(), Q)
utt.verify_grad(solve_continuous_lyapunov, pt=[A, Q], rng=rng)


Expand All @@ -634,7 +636,7 @@ def test_solve_discrete_are_forward():
)

atol = 1e-4 if config.floatX == "float32" else 1e-12
utt.assert_allclose(res, np.zeros_like(res), atol=atol)
np.testing.assert_allclose(res, np.zeros_like(res), atol=atol)


def test_solve_discrete_are_grad():
Expand Down

0 comments on commit d2d1c15

Please sign in to comment.