Skip to content

Commit

Permalink
Clean up post #500 bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
bwohlberg committed Feb 7, 2024
1 parent 45cf6a6 commit f9333f8
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions scico/linop/_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from functools import partial
from typing import Optional, Union

import jax.numpy as jnp

import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import broadcast_nested_shapes, is_nested
Expand Down Expand Up @@ -181,7 +179,7 @@ def __init__(
diagonal = snp.ones(((),) * len(input_shape), dtype=input_dtype)
# diagonal = snp.ones(tuple((1,) * len(s) for s in input_shape), dtype=input_dtype)
else:
diagonal = jnp.ones((), dtype=input_dtype)
diagonal = snp.ones((), dtype=input_dtype)
# diagonal = snp.ones((1,) * len(input_shape), dtype=input_dtype)
super().__init__(
diagonal=diagonal,
Expand Down Expand Up @@ -219,10 +217,10 @@ def norm(self, ord=None): # pylint: disable=W0622
"""
N = self.input_size
if ord is None or ord == "fro":
return jnp.sqrt(N)
return snp.sqrt(N)
elif ord == "nuc":
return N * jnp.ones(())
return N * snp.ones(())
elif ord in (-snp.inf, -1, -2, 1, 2, snp.inf):
return jnp.ones(())
return snp.ones(())
else:
raise ValueError(f"Invalid value {ord} for parameter ord.")

0 comments on commit f9333f8

Please sign in to comment.