From 7a01b1179e72c56289e43c771a44944d9357c435 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 14 Dec 2023 08:34:49 -0700 Subject: [PATCH 1/3] Add a test for issue to be resolved --- scico/test/test_blockarray.py | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/scico/test/test_blockarray.py b/scico/test/test_blockarray.py index 4b85476a4..101201e7d 100644 --- a/scico/test/test_blockarray.py +++ b/scico/test/test_blockarray.py @@ -9,8 +9,10 @@ import scico.numpy as snp from scico.numpy import BlockArray +from scico.numpy._wrapped_function_lists import testing_functions from scico.numpy.testing import assert_array_equal from scico.random import randn +from scico.util import rgetattr math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex comp_ops = [op.le, op.lt, op.ge, op.gt, op.eq] @@ -86,7 +88,8 @@ def test_ba_ba_operator(test_operator_obj, operator): snp.testing.assert_allclose(x, y) -# Testing the @ interface for blockarrays of same size, and a blockarray and flattened ndarray/devicearray +# Testing the @ interface for blockarrays of same size, and a blockarray and flattened +# ndarray/devicearray def test_ba_ba_matmul(test_operator_obj): a = test_operator_obj.a b = test_operator_obj.d @@ -135,20 +138,20 @@ def test_ndim(test_operator_obj): def test_getitem(test_operator_obj): - # Make a length-4 blockarray + # make a length-4 blockarray a0 = test_operator_obj.a0 a1 = test_operator_obj.a1 b0 = test_operator_obj.b0 b1 = test_operator_obj.b1 x = BlockArray([a0, a1, b0, b1]) - # Positive indexing + # positive indexing np.testing.assert_allclose(x[0], a0) np.testing.assert_allclose(x[1], a1) np.testing.assert_allclose(x[2], b0) np.testing.assert_allclose(x[3], b1) - # Negative indexing + # negative indexing np.testing.assert_allclose(x[-4], a0) np.testing.assert_allclose(x[-3], a1) np.testing.assert_allclose(x[-2], b0) @@ -193,9 +196,7 @@ def test_ba_ba_dot(test_operator_obj, operator): snp.testing.assert_allclose(x, y) -############################################################################### -# Reduction tests -############################################################################### +# reduction tests reduction_funcs = [ snp.sum, snp.linalg.norm, @@ -315,6 +316,16 @@ def test_full_nodtype(self): assert snp.all(x == fill_value) +# testing function tests +@pytest.mark.parametrize("func", testing_functions) +def test_test_func(func): + a = snp.array([1.0, 2.0]) + b = snp.blockarray((a, a)) + f = rgetattr(snp, func) + retval = f(b, b) + assert retval is None + + # tests added for the BlockArray refactor @pytest.fixture def x(): From 36703cb258aa66876b99c7e8e5ce627c067f3711 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 14 Dec 2023 08:59:06 -0700 Subject: [PATCH 2/3] Distinct handling for functions without any return value --- scico/numpy/__init__.py | 2 +- scico/numpy/_wrappers.py | 66 +++++++++++++++++++++++++++++----------- 2 files changed, 49 insertions(+), 19 deletions(-) diff --git a/scico/numpy/__init__.py b/scico/numpy/__init__.py index 99570fe9e..83e495578 100644 --- a/scico/numpy/__init__.py +++ b/scico/numpy/__init__.py @@ -42,7 +42,7 @@ _wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction) # wrap testing funcs -_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_func_over_blocks) +_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_void_func_over_blocks) # clean up del np, jnp, _wrappers diff --git a/scico/numpy/_wrappers.py b/scico/numpy/_wrappers.py index d1f456935..1bc45ebf8 100644 --- a/scico/numpy/_wrappers.py +++ b/scico/numpy/_wrappers.py @@ -96,6 +96,32 @@ def mapped(*args, **kwargs): return mapped +def _num_blocks_in_args(*args, **kwargs): + """Count the number of BlockArray arguments.""" + first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None) + if first_ba_arg is None: + first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None) + if first_ba_kwarg is None: + num_blocks = 0 + else: + num_blocks = len(first_ba_kwarg) + else: + num_blocks = len(first_ba_arg) + return num_blocks + + +def _block_args_kwargs(num_blocks, *args, **kwargs): + """Construct nested args/kwargs for each BlockArrays block.""" + new_args = [] + new_kwargs = [] + for i in range(num_blocks): + new_args.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args]) + new_kwargs.append( + {k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()} + ) + return new_args, new_kwargs + + def map_func_over_blocks(func): """Wrap a function so that it maps over all of its BlockArray arguments. @@ -103,27 +129,31 @@ def map_func_over_blocks(func): @wraps(func) def mapped(*args, **kwargs): + num_blocks = _num_blocks_in_args(*args, **kwargs) + if num_blocks == 0: + return func(*args, **kwargs) # no BlockArray arguments, so no mapping + new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs) - first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None) - if first_ba_arg is None: - first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None) - if first_ba_kwarg is None: - return func(*args, **kwargs) # no BlockArray arguments, so no mapping - num_blocks = len(first_ba_kwarg) - else: - num_blocks = len(first_ba_arg) + # run the function num_blocks times, return results in a BlockArray + return BlockArray(func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)) - # build a list of new args and kwargs, one for each block - new_args_list = [] - new_kwargs_list = [] - for i in range(num_blocks): - new_args_list.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args]) - new_kwargs_list.append( - {k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()} - ) + return mapped - # run the function num_blocks times, return results in a BlockArray - return BlockArray(func(*new_args_list[i], **new_kwargs_list[i]) for i in range(num_blocks)) + +def map_void_func_over_blocks(func): + """Wrap a function without a return value so that it maps over all + of its BlockArray arguments. + """ + + @wraps(func) + def mapped(*args, **kwargs): + num_blocks = _num_blocks_in_args(*args, **kwargs) + if num_blocks == 0: + func(*args, **kwargs) # no BlockArray arguments, so no mapping + new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs) + + # run the function num_blocks times + [func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)] return mapped From 95a332834fb9941fd02df76e5c075bb98f98c9f3 Mon Sep 17 00:00:00 2001 From: Brendt Wohlberg Date: Thu, 14 Dec 2023 09:19:21 -0700 Subject: [PATCH 3/3] Resolve oversight identified in PR review --- scico/numpy/_wrappers.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/scico/numpy/_wrappers.py b/scico/numpy/_wrappers.py index 1bc45ebf8..72db1af14 100644 --- a/scico/numpy/_wrappers.py +++ b/scico/numpy/_wrappers.py @@ -111,7 +111,7 @@ def _num_blocks_in_args(*args, **kwargs): def _block_args_kwargs(num_blocks, *args, **kwargs): - """Construct nested args/kwargs for each BlockArrays block.""" + """Construct nested args/kwargs for each BlockArray block.""" new_args = [] new_kwargs = [] for i in range(num_blocks): @@ -133,7 +133,6 @@ def mapped(*args, **kwargs): if num_blocks == 0: return func(*args, **kwargs) # no BlockArray arguments, so no mapping new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs) - # run the function num_blocks times, return results in a BlockArray return BlockArray(func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)) @@ -150,10 +149,10 @@ def mapped(*args, **kwargs): num_blocks = _num_blocks_in_args(*args, **kwargs) if num_blocks == 0: func(*args, **kwargs) # no BlockArray arguments, so no mapping - new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs) - - # run the function num_blocks times - [func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)] + else: + new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs) + # run the function num_blocks times + [func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)] return mapped