Skip to content

Commit

Permalink
Merge pull request #4 from lanl/lukepfister/radon_test_tol
Browse files Browse the repository at this point in the history
Set test_radon tolerances based on CPU/GPU
  • Loading branch information
lukepfister authored Sep 23, 2021
2 parents 82f0132 + 0bf03eb commit c378ed0
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 16 deletions.
18 changes: 10 additions & 8 deletions scico/test/linop/test_linop.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from scico.typing import PRNGKey


def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None):
def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4):
"""Check the validity of A.conj().T as the adjoint for a LinearOperator A
Compares the quantity sum(x.conj() * A.conj().T @ A @ x) against
Expand All @@ -27,6 +27,7 @@ def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None):
Args:
A : LinearOperator to test
key: PRNGKey for generating `x`.
rtol: Relative tolerance
"""

# Generate a signal in the domain of A
Expand All @@ -37,20 +38,20 @@ def adjoint_AtA_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None):
AtAx = A.conj().T @ Ax
num = snp.sum(x.conj() * AtAx)
den = snp.linalg.norm(Ax) ** 2
np.testing.assert_allclose(num / den, 1, rtol=1e-4)
np.testing.assert_allclose(num / den, 1, rtol=rtol)

AtAx = A.H @ Ax
num = snp.sum(x.conj() * AtAx)
den = snp.linalg.norm(Ax) ** 2
np.testing.assert_allclose(num / den, 1, rtol=1e-4)
np.testing.assert_allclose(num / den, 1, rtol=rtol)

AtAx = A.adj(Ax)
num = snp.sum(x.conj() * AtAx)
den = snp.linalg.norm(Ax) ** 2
np.testing.assert_allclose(num / den, 1, rtol=1e-4)
np.testing.assert_allclose(num / den, 1, rtol=rtol)


def adjoint_AAt_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None):
def adjoint_AAt_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None, rtol: float = 1e-4):
"""Check the validity of A as the adjoint for a LinearOperator A.conj().T
Compares the quantity sum(y.conj() * A @ A.conj().T @ y) against
Expand All @@ -59,6 +60,7 @@ def adjoint_AAt_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None):
Args:
A : LinearOperator to test
key: PRNGKey for generating `x`.
rtol: Relative tolerance
"""
# Generate a signal in the domain of A^T
y, key = randn(A.output_shape, dtype=A.output_dtype, key=key)
Expand All @@ -67,19 +69,19 @@ def adjoint_AAt_test(A: linop.LinearOperator, key: Optional[PRNGKey] = None):
AAty = A @ Aty
num = snp.sum(y.conj() * AAty)
den = snp.linalg.norm(Aty) ** 2
np.testing.assert_allclose(num / den, 1, rtol=1e-4)
np.testing.assert_allclose(num / den, 1, rtol=rtol)

Aty = A.H @ y
AAty = A @ Aty
num = snp.sum(y.conj() * AAty)
den = snp.linalg.norm(Aty) ** 2
np.testing.assert_allclose(num / den, 1, rtol=1e-4)
np.testing.assert_allclose(num / den, 1, rtol=rtol)

Aty = A.adj(y)
AAty = A @ Aty
num = snp.sum(y.conj() * AAty)
den = snp.linalg.norm(Aty) ** 2
np.testing.assert_allclose(num / den, 1, rtol=1e-4)
np.testing.assert_allclose(num / den, 1, rtol=rtol)


class AbsMatOp(linop.LinearOperator):
Expand Down
24 changes: 16 additions & 8 deletions scico/test/linop/test_radon.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
N = 128


def get_tol():
if jax.devices()[0].device_kind == "cpu":
rtol = 5e-5
else:
rtol = 7e-2
return rtol


class ParallelBeamProjectorTest:
def __init__(self, volume_geometry):
N_proj = 180 # number of projection angles
Expand Down Expand Up @@ -44,28 +52,28 @@ def test_ATA_call(testobj):
# Test for the call-based interface
Ax = testobj.A(testobj.x)
ATAx = testobj.A.adj(Ax)
np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5)
np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=get_tol())


def test_ATA_matmul(testobj):
# Test for the matmul interface
Ax = testobj.A @ testobj.x
ATAx = testobj.A.T @ Ax
np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=5e-5)
np.testing.assert_allclose(np.sum(testobj.x * ATAx), np.linalg.norm(Ax) ** 2, rtol=get_tol())


def test_AAT_call(testobj):
# Test for the call-based interface
ATy = testobj.A.adj(testobj.y)
AATy = testobj.A(ATy)
np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=5e-5)
np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=get_tol())


def test_AAT_matmul(testobj):
# Test for the matmul interface
ATy = testobj.A.T @ testobj.y
AATy = testobj.A @ ATy
np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=5e-5)
np.testing.assert_allclose(np.sum(testobj.y * AATy), np.linalg.norm(ATy) ** 2, rtol=get_tol())


def test_grad(testobj):
Expand All @@ -74,18 +82,18 @@ def test_grad(testobj):
A = testobj.A
x = testobj.x
g = lambda x: jax.numpy.linalg.norm(A(x)) ** 2
np.testing.assert_allclose(scico.grad(g)(x), 2 * A.adj(A(x)), rtol=5e-5)
np.testing.assert_allclose(scico.grad(g)(x), 2 * A.adj(A(x)), rtol=get_tol())


def test_adjoint_grad(testobj):
A = testobj.A
x = testobj.x
Ax = A @ x
f = lambda y: jax.numpy.linalg.norm(A.T(y)) ** 2
np.testing.assert_allclose(scico.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=5e-5)
np.testing.assert_allclose(scico.grad(f)(Ax), 2 * A(A.adj(Ax)), rtol=get_tol())


def test_adjoint(testobj):
A = testobj.A
adjoint_AAt_test(A)
adjoint_AtA_test(A)
adjoint_AAt_test(A, rtol=get_tol())
adjoint_AtA_test(A, rtol=get_tol())

0 comments on commit c378ed0

Please sign in to comment.