Skip to content

Commit

Permalink
Add diagonal operator mapping base operator over an array axis (#521)
Browse files Browse the repository at this point in the history
* Add diagonal operators constructed by replication of a base operator

* Add documentation

* Improve error checking

* Fix allowed range for input_axis and allow negative values

* Rename class aliases as per PR review comments
  • Loading branch information
bwohlberg authored Jun 25, 2024
1 parent 81806fc commit c06bdb5
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 10 deletions.
3 changes: 2 additions & 1 deletion scico/linop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ._func import Crop, Pad, Reshape, Slice, Sum, Transpose, linop_from_function
from ._linop import ComposedLinearOperator, LinearOperator
from ._matrix import MatrixOperator
from ._stack import DiagonalStack, VerticalStack, linop_over_axes
from ._stack import DiagonalReplicated, DiagonalStack, VerticalStack, linop_over_axes
from ._util import jacobian, operator_norm, power_iteration, valid_adjoint
from .xray import Parallel2dProjector, XRayTransform

Expand All @@ -29,6 +29,7 @@
"FiniteDifference",
"SingleAxisFiniteDifference",
"Identity",
"DiagonalReplicated",
"VerticalStack",
"DiagonalStack",
"MatrixOperator",
Expand Down
86 changes: 82 additions & 4 deletions scico/linop/_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
import scico.numpy as snp
from scico.numpy import Array, BlockArray
from scico.numpy.util import normalize_axes
from scico.operator._stack import DiagonalStack as DStack
from scico.operator._stack import VerticalStack as VStack
from scico.operator._stack import DiagonalReplicated as DiagonalReplicatedOperator
from scico.operator._stack import DiagonalStack as DiagonalStackOperator
from scico.operator._stack import VerticalStack as VerticalStackOperator
from scico.typing import Axes, Shape

from ._linop import LinearOperator


class VerticalStack(VStack, LinearOperator):
class VerticalStack(VerticalStackOperator, LinearOperator):
r"""A vertical stack of linear operators.
Given linear operators :math:`A_1, A_2, \dots, A_N`, create the
Expand Down Expand Up @@ -71,7 +72,7 @@ def _adj(self, y: Union[Array, BlockArray]) -> Array: # type: ignore
return sum([op.adj(y_block) for y_block, op in zip(y, self.ops)]) # type: ignore


class DiagonalStack(DStack, LinearOperator):
class DiagonalStack(DiagonalStackOperator, LinearOperator):
r"""A diagonal stack of linear operators.
Given linear operators :math:`A_1, A_2, \dots, A_N`, create the
Expand Down Expand Up @@ -146,6 +147,83 @@ def _adj(self, y: Union[Array, BlockArray]) -> Union[Array, BlockArray]: # type
return snp.blockarray(result)


class DiagonalReplicated(DiagonalReplicatedOperator, LinearOperator):
r"""A diagonal stack constructed from a single linear operator.
Given linear operator :math:`A`, create the linear operator
.. math::
H =
\begin{pmatrix}
A & 0 & \ldots & 0\\
0 & A & \ldots & 0\\
\vdots & \vdots & \ddots & \vdots\\
0 & 0 & \ldots & A \\
\end{pmatrix} \qquad
\text{such that} \qquad
H
\begin{pmatrix}
\mb{x}_1 \\
\mb{x}_2 \\
\vdots \\
\mb{x}_N \\
\end{pmatrix}
=
\begin{pmatrix}
A(\mb{x}_1) \\
A(\mb{x}_2) \\
\vdots \\
A(\mb{x}_N) \\
\end{pmatrix} \;.
The application of :math:`A` to each component :math:`\mb{x}_k` is
computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape
for linear operator :math:`A` should exclude the array axis on which
:math:`A` is replicated to form :math:`H`. For example, if :math:`A`
has input shape `(3, 4)` and :math:`H` is constructed to replicate
on axis 0 with 2 replicates, the input shape of :math:`H` will be
`(2, 3, 4)`.
Linear operators taking :class:`.BlockArray` input are not supported.
"""

def __init__(
self,
op: LinearOperator,
replicates: int,
input_axis: int = 0,
output_axis: Optional[int] = None,
map_type: str = "auto",
**kwargs,
):
"""
Args:
op: Linear operator to replicate.
replicates: Number of replicates of `op`.
input_axis: Input axis over which `op` should be replicated.
output_axis: Index of replication axis in output array.
If ``None``, the input replication axis is used.
map_type: If "pmap" or "vmap", apply replicated mapping using
:func:`jax.pmap` or :func:`jax.vmap` respectively. If
"auto", use :func:`jax.pmap` if sufficient devices are
available for the number of replicates, otherwise use
:func:`jax.vmap`.
"""
if not isinstance(op, LinearOperator):
raise TypeError("Argument op must be of type LinearOperator.")

super().__init__(
op,
replicates,
input_axis=input_axis,
output_axis=output_axis,
map_type=map_type,
**kwargs,
)

self._adj = self.jaxmap(op.adj, in_axes=self.input_axis, out_axes=self.output_axis)


def linop_over_axes(
linop: type[LinearOperator],
input_shape: Shape,
Expand Down
5 changes: 3 additions & 2 deletions scico/operator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2021-2023 by SCICO Developers
# Copyright (C) 2021-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,11 +13,12 @@
from ._operator import Operator
from .biconvolve import BiConvolve
from ._func import operator_from_function, Abs, Angle, Exp
from ._stack import DiagonalStack, VerticalStack
from ._stack import DiagonalStack, VerticalStack, DiagonalReplicated

__all__ = [
"Operator",
"BiConvolve",
"DiagonalReplicated",
"DiagonalStack",
"VerticalStack",
"operator_from_function",
Expand Down
108 changes: 107 additions & 1 deletion scico/operator/_stack.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-
# Copyright (C) 2023 by SCICO Developers
# Copyright (C) 2023-2024 by SCICO Developers
# All rights reserved. BSD 3-clause License.
# This file is part of the SCICO package. Details of the copyright and
# user license can be found in the 'LICENSE' file distributed with the
Expand All @@ -13,6 +13,8 @@

import numpy as np

import jax

from typing_extensions import TypeGuard

import scico.numpy as snp
Expand Down Expand Up @@ -234,3 +236,107 @@ def _eval(self, x: Union[Array, BlockArray]) -> Union[Array, BlockArray]:
if self.collapse_output:
return snp.stack(result)
return snp.blockarray(result)


class DiagonalReplicated(Operator):
r"""A diagonal stack constructed from a single operator.
Given operator :math:`A`, create the operator :math:`H` such that
.. math::
H \left(
\begin{pmatrix}
\mb{x}_1 \\
\mb{x}_2 \\
\vdots \\
\mb{x}_N \\
\end{pmatrix} \right)
=
\begin{pmatrix}
A(\mb{x}_1) \\
A(\mb{x}_2) \\
\vdots \\
A(\mb{x}_N) \\
\end{pmatrix} \;.
The application of :math:`A` to each component :math:`\mb{x}_k` is
computed using :func:`jax.pmap` or :func:`jax.vmap`. The input shape
for operator :math:`A` should exclude the array axis on which
:math:`A` is replicated to form :math:`H`. For example, if :math:`A`
has input shape `(3, 4)` and :math:`H` is constructed to replicate
on axis 0 with 2 replicates, the input shape of :math:`H` will be
`(2, 3, 4)`.
Operators taking :class:`.BlockArray` input are not supported.
"""

def __init__(
self,
op: Operator,
replicates: int,
input_axis: int = 0,
output_axis: Optional[int] = None,
map_type: str = "auto",
**kwargs,
):
"""
Args:
op: Operator to replicate.
replicates: Number of replicates of `op`.
input_axis: Input axis over which `op` should be replicated.
output_axis: Index of replication axis in output array.
If ``None``, the input replication axis is used.
map_type: If "pmap" or "vmap", apply replicated mapping using
:func:`jax.pmap` or :func:`jax.vmap` respectively. If
"auto", use :func:`jax.pmap` if sufficient devices are
available for the number of replicates, otherwise use
:func:`jax.vmap`.
"""
if map_type not in ["auto", "pmap", "vmap"]:
raise ValueError("Argument map_type must be one of 'auto', 'pmap, or 'vmap'.")
if input_axis < 0:
input_axis = len(op.input_shape) + 1 + input_axis
if input_axis < 0 or input_axis > len(op.input_shape):
raise ValueError(
"Argument input_axis must be positive and less than the number of axes "
"in the input shape of op."
)
if is_nested(op.input_shape):
raise ValueError("Argument op may not be an Operator taking BlockArray input.")
if is_nested(op.output_shape):
raise ValueError("Argument op may not be an Operator with BlockArray output.")
self.op = op
self.replicates = replicates
self.input_axis = input_axis
self.output_axis = self.input_axis if output_axis is None else output_axis

if map_type == "auto":
self.jaxmap = jax.pmap if replicates <= jax.device_count() else jax.vmap
else:
if map_type == "pmap" and replicates > jax.device_count():
raise ValueError(
"Requested pmap mapping but number of replicates exceeds device count."
)
else:
self.jaxmap = jax.pmap if map_type == "pmap" else jax.vmap

eval_fn = self.jaxmap(op.__call__, in_axes=self.input_axis, out_axes=self.output_axis)

input_shape = (
op.input_shape[0 : self.input_axis] + (replicates,) + op.input_shape[self.input_axis :]
)
output_shape = (
op.output_shape[0 : self.output_axis]
+ (replicates,)
+ op.output_shape[self.output_axis :]
)

super().__init__(
input_shape=input_shape, # type: ignore
output_shape=output_shape, # type: ignore
eval_fn=eval_fn,
input_dtype=op.input_dtype,
output_dtype=op.output_dtype,
jit=False,
**kwargs,
)
23 changes: 22 additions & 1 deletion scico/test/linop/test_linop_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@
import pytest

import scico.numpy as snp
from scico.linop import Convolve, DiagonalStack, Identity, Sum, VerticalStack
from scico.linop import (
Convolve,
DiagonalReplicated,
DiagonalStack,
Identity,
Sum,
VerticalStack,
)
from scico.operator import Abs
from scico.random import randn
from scico.test.linop.test_linop import adjoint_test


Expand Down Expand Up @@ -166,3 +174,16 @@ def test_output_collapse(self):

H = DiagonalStack((A1, A2), collapse_output=False)
assert H.output_shape == (S1, S1)


class TestDiagonalReplicated:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

def test_adjoint(self):
x, key = randn((2, 3, 4), key=self.key)
A = Sum(x.shape[1:], axis=-1)
D = DiagonalReplicated(A, x.shape[0])
y = D.T(D(x))
np.testing.assert_allclose(y[0], A.T(A(x[0])))
np.testing.assert_allclose(y[1], A.T(A(x[1])))
54 changes: 53 additions & 1 deletion scico/test/operator/test_op_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
import pytest

import scico.numpy as snp
from scico.operator import Abs, DiagonalStack, Operator, VerticalStack
from scico.operator import (
Abs,
DiagonalReplicated,
DiagonalStack,
Operator,
VerticalStack,
)
from scico.random import randn

TestOpA = Operator(input_shape=(3, 4), output_shape=(2, 3, 4), eval_fn=lambda x: snp.stack((x, x)))
TestOpB = Operator(
Expand Down Expand Up @@ -140,3 +147,48 @@ def test_output_collapse(self):

H = DiagonalStack((A1, A2), collapse_output=False)
assert H.output_shape == (A1.output_shape, A1.output_shape)


class TestDiagonalReplicated:
def setup_method(self, method):
self.key = jax.random.PRNGKey(12345)

@pytest.mark.parametrize("map_type", ["auto", "vmap"])
@pytest.mark.parametrize("input_axis", [0, 1])
def test_map_auto_vmap(self, input_axis, map_type):
x, key = randn((2, 3, 4), key=self.key)
mapshape = (3, 4) if input_axis == 0 else (2, 4)
replicates = x.shape[input_axis]
A = Abs(mapshape)
D = DiagonalReplicated(A, replicates, input_axis=input_axis, map_type=map_type)
y = D(x)
assert y.shape[input_axis] == replicates

@pytest.mark.skipif(jax.device_count() < 2, reason="multiple devices required for test")
def test_map_auto_pmap(self):
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[1:])
replicates = x.shape[0]
D = DiagonalReplicated(A, replicates, map_type="pmap")
y = D(x)
assert y.shape[0] == replicates

def test_input_axis(self):
# Ensure that operators can be stacked on final axis
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[0:2])
replicates = x.shape[2]
D = DiagonalReplicated(A, replicates, input_axis=2)
y = D(x)
assert y.shape == (2, 3, 4)
D = DiagonalReplicated(A, replicates, input_axis=-1)
y = D(x)
assert y.shape == (2, 3, 4)

def test_output_axis(self):
x, key = randn((2, 3, 4), key=self.key)
A = Abs(x.shape[1:])
replicates = x.shape[0]
D = DiagonalReplicated(A, replicates, output_axis=1)
y = D(x)
assert y.shape == (3, 2, 4)

0 comments on commit c06bdb5

Please sign in to comment.