Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve #483 #484

Merged
merged 3 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scico/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
_wrappers.wrap_recursively(vars(), reduction_functions, _wrappers.add_full_reduction)

# wrap testing funcs
_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_func_over_blocks)
_wrappers.wrap_recursively(vars(), testing_functions, _wrappers.map_void_func_over_blocks)

# clean up
del np, jnp, _wrappers
66 changes: 48 additions & 18 deletions scico/numpy/_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,34 +96,64 @@ def mapped(*args, **kwargs):
return mapped


def _num_blocks_in_args(*args, **kwargs):
"""Count the number of BlockArray arguments."""
first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None)
if first_ba_arg is None:
first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None)
if first_ba_kwarg is None:
num_blocks = 0
else:
num_blocks = len(first_ba_kwarg)
else:
num_blocks = len(first_ba_arg)
return num_blocks


def _block_args_kwargs(num_blocks, *args, **kwargs):
"""Construct nested args/kwargs for each BlockArrays block."""
new_args = []
new_kwargs = []
for i in range(num_blocks):
new_args.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args])
new_kwargs.append(
{k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()}
)
return new_args, new_kwargs


def map_func_over_blocks(func):
"""Wrap a function so that it maps over all of its BlockArray
arguments.
"""

@wraps(func)
def mapped(*args, **kwargs):
num_blocks = _num_blocks_in_args(*args, **kwargs)
if num_blocks == 0:
return func(*args, **kwargs) # no BlockArray arguments, so no mapping
new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs)

first_ba_arg = next((arg for arg in args if isinstance(arg, BlockArray)), None)
if first_ba_arg is None:
first_ba_kwarg = next((v for k, v in kwargs.items() if isinstance(v, BlockArray)), None)
if first_ba_kwarg is None:
return func(*args, **kwargs) # no BlockArray arguments, so no mapping
num_blocks = len(first_ba_kwarg)
else:
num_blocks = len(first_ba_arg)
# run the function num_blocks times, return results in a BlockArray
return BlockArray(func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks))

# build a list of new args and kwargs, one for each block
new_args_list = []
new_kwargs_list = []
for i in range(num_blocks):
new_args_list.append([arg[i] if isinstance(arg, BlockArray) else arg for arg in args])
new_kwargs_list.append(
{k: (v[i] if isinstance(v, BlockArray) else v) for k, v in kwargs.items()}
)
return mapped

# run the function num_blocks times, return results in a BlockArray
return BlockArray(func(*new_args_list[i], **new_kwargs_list[i]) for i in range(num_blocks))

def map_void_func_over_blocks(func):
"""Wrap a function without a return value so that it maps over all
of its BlockArray arguments.
"""

@wraps(func)
def mapped(*args, **kwargs):
num_blocks = _num_blocks_in_args(*args, **kwargs)
if num_blocks == 0:
func(*args, **kwargs) # no BlockArray arguments, so no mapping
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suggest return immediately after calling func inside this if

new_args, new_kwargs = _block_args_kwargs(num_blocks, *args, **kwargs)

# run the function num_blocks times
[func(*new_args[i], **new_kwargs[i]) for i in range(num_blocks)]

return mapped

Expand Down
25 changes: 18 additions & 7 deletions scico/test/test_blockarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@

import scico.numpy as snp
from scico.numpy import BlockArray
from scico.numpy._wrapped_function_lists import testing_functions
from scico.numpy.testing import assert_array_equal
from scico.random import randn
from scico.util import rgetattr

math_ops = [op.add, op.sub, op.mul, op.truediv, op.pow] # op.floordiv doesn't work on complex
comp_ops = [op.le, op.lt, op.ge, op.gt, op.eq]
Expand Down Expand Up @@ -86,7 +88,8 @@ def test_ba_ba_operator(test_operator_obj, operator):
snp.testing.assert_allclose(x, y)


# Testing the @ interface for blockarrays of same size, and a blockarray and flattened ndarray/devicearray
# Testing the @ interface for blockarrays of same size, and a blockarray and flattened
# ndarray/devicearray
def test_ba_ba_matmul(test_operator_obj):
a = test_operator_obj.a
b = test_operator_obj.d
Expand Down Expand Up @@ -135,20 +138,20 @@ def test_ndim(test_operator_obj):


def test_getitem(test_operator_obj):
# Make a length-4 blockarray
# make a length-4 blockarray
a0 = test_operator_obj.a0
a1 = test_operator_obj.a1
b0 = test_operator_obj.b0
b1 = test_operator_obj.b1
x = BlockArray([a0, a1, b0, b1])

# Positive indexing
# positive indexing
np.testing.assert_allclose(x[0], a0)
np.testing.assert_allclose(x[1], a1)
np.testing.assert_allclose(x[2], b0)
np.testing.assert_allclose(x[3], b1)

# Negative indexing
# negative indexing
np.testing.assert_allclose(x[-4], a0)
np.testing.assert_allclose(x[-3], a1)
np.testing.assert_allclose(x[-2], b0)
Expand Down Expand Up @@ -193,9 +196,7 @@ def test_ba_ba_dot(test_operator_obj, operator):
snp.testing.assert_allclose(x, y)


###############################################################################
# Reduction tests
###############################################################################
# reduction tests
reduction_funcs = [
snp.sum,
snp.linalg.norm,
Expand Down Expand Up @@ -315,6 +316,16 @@ def test_full_nodtype(self):
assert snp.all(x == fill_value)


# testing function tests
@pytest.mark.parametrize("func", testing_functions)
def test_test_func(func):
a = snp.array([1.0, 2.0])
b = snp.blockarray((a, a))
f = rgetattr(snp, func)
retval = f(b, b)
assert retval is None


# tests added for the BlockArray refactor
@pytest.fixture
def x():
Expand Down
Loading