
BUG: Matrix Operator should add 'input_shape' and 'output_shape' parameters #356

Closed
Bilybill opened this issue Oct 22, 2022 · 3 comments · Fixed by #360
Labels
bug Something isn't working


Bilybill commented Oct 22, 2022

The __init__ code snippet of the original MatrixOperator is as follows:

```python
class MatrixOperator(LinearOperator):
    """Linear operator implementing matrix multiplication."""
    def __init__(self, A: JaxArray):
        # Some Code
        # Shapes are fixed by A alone: input from A.shape[1], output from A.shape[0]
        super().__init__(input_shape=A.shape[1], output_shape=A.shape[0], input_dtype=self.A.dtype)
```

This causes an error when computing the Hessian if the data are a two-dimensional array rather than a vector, since the operator only accepts one-dimensional inputs. For example, the following code reproduces the bug:

```python
from scico import linop, loss
import numpy as np
A = linop.MatrixOperator(A=np.random.random((5, 5)))
x = np.random.random((5, 3))
f = loss.SquaredL2Loss(y=x, A=A)
print(f.hessian)
```

The above code will generate the following error:

```
ValueError: Cannot evaluate <class 'scico.linop.Identity'> with input_shape=(5, 3) on array with shape=(5,)
```
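
The mismatch can be illustrated with plain NumPy. This is a mock, not scico code: `check_input` is a hypothetical helper that imitates the shape check behind the error message, the operator input is taken to be a 5-vector as with the original `MatrixOperator`, and `W` is assumed to be an identity whose input shape matches `y`, as the error message suggests.

```python
import numpy as np

def check_input(name, input_shape, x):
    """Hypothetical stand-in for the shape check that raised the error above."""
    if x.shape != tuple(input_shape):
        raise ValueError(
            f"Cannot evaluate {name} with input_shape={tuple(input_shape)} "
            f"on array with shape={x.shape}"
        )

rng = np.random.default_rng(0)
A = rng.random((5, 5))
y = rng.random((5, 3))          # data: a 5x3 array, as in the report

A_input_shape = (A.shape[1],)   # original MatrixOperator: one-dimensional input only
W_input_shape = y.shape         # assumed: W (identity) is built to match y

x = rng.random(A_input_shape)   # a valid input for A
Ax = A @ x                      # shape (5,), not (5, 3)

message = None
try:
    check_input("Identity", W_input_shape, Ax)  # mimics W(A(x)) inside the Hessian
except ValueError as err:
    message = str(err)

print(message)
```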

The error is raised when the Hessian of the linear operator 'A' is computed in the following code:

```python
class SquaredL2Loss(Loss):
    @property
    def hessian(self) -> linop.LinearOperator:
        r"""Compute the Hessian of linear operator `A`.

        If `self.A` is a :class:`scico.linop.LinearOperator`, returns a
        :class:`scico.linop.LinearOperator` corresponding to the Hessian
        :math:`2 \alpha \mathrm{A^H W A}`. Otherwise not implemented.
        """
        A = self.A
        W = self.W
        if isinstance(A, linop.LinearOperator):
            return linop.LinearOperator(
                # Both shapes below are taken from A.input_shape
                input_shape=A.input_shape,
                output_shape=A.input_shape,
                eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),  # type: ignore
                adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),  # type: ignore
                input_dtype=A.input_dtype,
            )

        raise NotImplementedError(
            f"Hessian is not implemented for {type(self)} when A is {type(A)}; "
            "must be LinearOperator."
        )
```
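
As a sanity check on the formula in the docstring, the NumPy sketch below compares the Hessian action 2αAᵀWAv with a central finite difference of the gradient of f(x) = α‖y − Ax‖². It assumes W is the identity and A is real, so that Aᴴ = Aᵀ; `alpha` stands in for `self.scale`. Because A acts on each column separately, the same formula applies unchanged to a 5×3 input, which is the case the Hessian's `input_shape` fails to allow.

```python
import numpy as np

rng = np.random.default_rng(1)
alpha = 0.5                     # stands in for self.scale
A = rng.standard_normal((5, 5))
y = rng.standard_normal((5, 3))

def grad(x):
    # Gradient of f(x) = alpha * ||y - A x||^2 with W = I
    return 2 * alpha * A.T @ (A @ x - y)

def hessian_apply(v):
    # Action of the Hessian 2 * alpha * A^T W A with W = I
    return 2 * alpha * A.T @ (A @ v)

x = rng.standard_normal((5, 3))  # a 5x3 input, as in the report
v = rng.standard_normal((5, 3))  # direction for the Hessian-vector product
eps = 1e-6

# Central finite difference of the gradient along v
fd = (grad(x + eps * v) - grad(x - eps * v)) / (2 * eps)

print(np.allclose(fd, hessian_apply(v), atol=1e-6))  # True
print(hessian_apply(v).shape)                        # (5, 3)
```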

The returned linear operator takes its input_shape from A.input_shape, which for a MatrixOperator is always one-dimensional and is therefore inconsistent with x in the example above, a 5 × 3 array. In addition, the output_shape of the returned linear operator is also set to A.input_shape; I am not sure whether this is a separate bug. The following minor modification fixes the bug:
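
On the output_shape question, a short derivation suggests that setting it to A.input_shape is consistent with the math. Assuming f(x) = α‖y − Ax‖²_W with W self-adjoint (consistent with the 2αAᴴWA in the docstring above), the gradient and Hessian are:

```latex
\begin{align}
f(\mathbf{x}) &= \alpha \, \| \mathbf{y} - A \mathbf{x} \|_W^2 , \\
\nabla f(\mathbf{x}) &= 2 \alpha \, A^H W (A \mathbf{x} - \mathbf{y}) , \\
\nabla^2 f &= 2 \alpha \, A^H W A .
\end{align}
```

Since A maps arrays of its input shape to its output shape and A^H maps them back, 2αAᴴWA takes and returns arrays of A's input shape. On this reading, the restriction lies in A.input_shape itself rather than in the choice of output_shape.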

```python
from typing import Optional, Union

# JaxArray, Shape, BlockShape, DeviceArray, LinearOperator, jax and np are
# assumed to be imported as in the surrounding scico module.


class MatrixOperator(LinearOperator):
    """Linear operator implementing matrix multiplication."""

    def __init__(
        self,
        A: JaxArray,
        input_shape: Optional[Union[Shape, BlockShape]] = None,
        output_shape: Optional[Union[Shape, BlockShape]] = None,
    ):
        """
        Args:
            A: Dense array. The action of the created LinearOperator will
                implement matrix multiplication with `A`.
            input_shape: Shape of the operator input. Defaults to
                `A.shape[1]` if None.
            output_shape: Shape of the operator output. Defaults to
                `A.shape[0]` if None.
        """
        self.A: JaxArray  #: Dense array implementing this matrix

        # if A is an ndarray, make sure it gets converted to a DeviceArray
        if isinstance(A, DeviceArray):
            self.A = A
        elif isinstance(A, np.ndarray):
            self.A = jax.device_put(A)
        else:
            raise TypeError(f"Expected np.ndarray or DeviceArray, got {type(A)}")

        # Can only do rank-2 arrays
        if A.ndim != 2:
            raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}")

        # Fall back to the original one-dimensional shapes when no shape is given
        super_input_shape = A.shape[1] if input_shape is None else input_shape
        super_output_shape = A.shape[0] if output_shape is None else output_shape

        super().__init__(
            input_shape=super_input_shape,
            output_shape=super_output_shape,
            input_dtype=A.dtype,
        )
```
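
The defaulting logic of the proposed initializer can be exercised without scico. `MatrixOperatorSketch` below is a hypothetical NumPy stand-in, not the scico class: it falls back to one-dimensional shapes when none are given and checks shapes on every call.

```python
from typing import Optional, Tuple

import numpy as np


class MatrixOperatorSketch:
    """Hypothetical stand-in for MatrixOperator with optional shape overrides."""

    def __init__(
        self,
        A: np.ndarray,
        input_shape: Optional[Tuple[int, ...]] = None,
        output_shape: Optional[Tuple[int, ...]] = None,
    ):
        if A.ndim != 2:
            raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}")
        self.A = A
        # Fall back to one-dimensional shapes, as in the original initializer
        self.input_shape = (A.shape[1],) if input_shape is None else tuple(input_shape)
        self.output_shape = (A.shape[0],) if output_shape is None else tuple(output_shape)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        if x.shape != self.input_shape:
            raise ValueError(f"Expected input of shape {self.input_shape}, got {x.shape}")
        y = self.A @ x
        if y.shape != self.output_shape:
            raise ValueError(f"Output shape {y.shape} does not match {self.output_shape}")
        return y


A = np.random.default_rng(2).random((5, 5))
op = MatrixOperatorSketch(A, input_shape=(5, 3), output_shape=(5, 3))
print(op(np.ones((5, 3))).shape)  # (5, 3)
```

Accepting input_shape and output_shape independently means they can disagree with each other or with A, and this sketch only detects that when the operator is called. A single parameter, such as the input_cols adopted in #360, leaves less room for inconsistent combinations.
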

bwohlberg added the bug (Something isn't working) label on Oct 22, 2022
bwohlberg self-assigned this on Oct 22, 2022
bwohlberg (Collaborator)

Thank you for the bug report.

bwohlberg linked a pull request on Oct 28, 2022 that will close this issue
bwohlberg (Collaborator)

This problem is resolved in #360. Your example to reproduce the bug (with minor modifications to account for the new interface of the MatrixOperator initializer)

```python
from scico import linop, loss
import numpy as np
A = linop.MatrixOperator(A=np.random.random((5, 5)), input_cols=3)
x = np.random.random((5, 3))
f = loss.SquaredL2Loss(y=x, A=A)
print(f.hessian)
```

no longer produces an error. Please let us know if you have any comments or reservations about this solution.
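
The shapes involved with the new interface can be sketched in NumPy, assuming that input_cols=k makes the operator act on arrays with k columns, so that input_shape is (A.shape[1], k) and output_shape is (A.shape[0], k). This is an inference from the example above, not a description of the #360 implementation. `shapes_for` is a hypothetical helper, and a non-square A is used to show that the Hessian returns arrays of the operator's input shape, not its output shape.

```python
import numpy as np

def shapes_for(A, input_cols=None):
    """Hypothetical helper: input/output shapes of x -> A @ x for a given column count."""
    if input_cols is None:
        return (A.shape[1],), (A.shape[0],)
    return (A.shape[1], input_cols), (A.shape[0], input_cols)

rng = np.random.default_rng(3)
A = rng.random((4, 6))          # non-square A: maps 6-vectors to 4-vectors
alpha = 0.5                     # arbitrary loss scale
in_shape, out_shape = shapes_for(A, input_cols=3)

x = rng.random(in_shape)        # operator input, shape (6, 3)
y = rng.random(out_shape)       # data, shape (4, 3), so an identity W accepts A @ x
Ax = A @ x
assert Ax.shape == y.shape

# Hessian action 2 * alpha * A^T W A x with W = I maps (6, 3) back to (6, 3)
H_x = 2 * alpha * A.T @ Ax
print(in_shape, out_shape, H_x.shape)  # (6, 3) (4, 3) (6, 3)
```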

Bilybill (Author)

Thanks for your modification; this issue can be closed.

bwohlberg added a commit that referenced this issue Oct 31, 2022
* Resolve oversight in possible input shapes

* Improve parameter name

* Add tests for matrix input shapes