From eebda28d66c809a52424a73193799c49512c168e Mon Sep 17 00:00:00 2001 From: Michael-T-McCann Date: Wed, 7 Feb 2024 09:12:03 -0700 Subject: [PATCH] Correctly check for nested tuple in map_func_over_tuple_of_tuples --- scico/numpy/_wrappers.py | 6 +++--- scico/test/numpy/test_numpy.py | 10 +++++++++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/scico/numpy/_wrappers.py b/scico/numpy/_wrappers.py index 72db1af14..b14be9006 100644 --- a/scico/numpy/_wrappers.py +++ b/scico/numpy/_wrappers.py @@ -17,6 +17,8 @@ import jax.numpy as jnp +import scico.numpy as snp + from ._blockarray import BlockArray @@ -83,9 +85,7 @@ def mapped(*args, **kwargs): map_arg_val = bound_args.arguments.pop(map_arg_name) - if not isinstance(map_arg_val, tuple) or not all( - isinstance(x, tuple) for x in map_arg_val - ): # not nested tuple + if not snp.util.is_nested(map_arg_val): # not nested tuple return func(*args, **kwargs) # no mapping # map diff --git a/scico/test/numpy/test_numpy.py b/scico/test/numpy/test_numpy.py index 8dcf380f2..1d3efed4a 100644 --- a/scico/test/numpy/test_numpy.py +++ b/scico/test/numpy/test_numpy.py @@ -248,9 +248,17 @@ def test_ufunc_conj(): def test_create_zeros(): A = snp.zeros(2) assert np.all(A == 0) + assert isinstance(A, jax.Array) + + A = snp.zeros((2,)) + assert isinstance(A, jax.Array) A = snp.zeros(((2,), (2,))) assert all(snp.all(A == 0)) + assert isinstance(A, snp.BlockArray) + + A = snp.zeros(()) + assert isinstance(A, jax.Array) # from issue 499 def test_create_ones(): @@ -261,7 +269,7 @@ def test_create_ones(): assert all(snp.all(A == 1)) -def test_create_zeros(): +def test_create_empty(): A = snp.empty(2) assert np.all(A == 0)