diff --git a/scico/linop/_diag.py b/scico/linop/_diag.py index 378a7d37e..38b596bba 100644 --- a/scico/linop/_diag.py +++ b/scico/linop/_diag.py @@ -15,8 +15,6 @@ from functools import partial from typing import Optional, Union -import jax.numpy as jnp - import scico.numpy as snp from scico.numpy import Array, BlockArray from scico.numpy.util import broadcast_nested_shapes, is_nested @@ -181,7 +179,7 @@ def __init__( diagonal = snp.ones(((),) * len(input_shape), dtype=input_dtype) # diagonal = snp.ones(tuple((1,) * len(s) for s in input_shape), dtype=input_dtype) else: - diagonal = jnp.ones((), dtype=input_dtype) + diagonal = snp.ones((), dtype=input_dtype) # diagonal = snp.ones((1,) * len(input_shape), dtype=input_dtype) super().__init__( diagonal=diagonal, @@ -219,10 +217,10 @@ def norm(self, ord=None): # pylint: disable=W0622 """ N = self.input_size if ord is None or ord == "fro": - return jnp.sqrt(N) + return snp.sqrt(N) elif ord == "nuc": - return N * jnp.ones(()) + return N * snp.ones(()) elif ord in (-snp.inf, -1, -2, 1, 2, snp.inf): - return jnp.ones(()) + return snp.ones(()) else: raise ValueError(f"Invalid value {ord} for parameter ord.")