The `__init__` code snippet of the original `MatrixOperator` is as follows:

```python
class MatrixOperator(LinearOperator):
    """Linear operator implementing matrix multiplication."""

    def __init__(self, A: JaxArray):
        # Some Code
        super().__init__(input_shape=A.shape[1], output_shape=A.shape[0], input_dtype=self.A.dtype)
```
This code causes an error when computing the Hessian if the input `x` is not a square matrix. For example, the following code reproduces the bug:
```python
from scico import linop, loss
import numpy as np

A = linop.MatrixOperator(A=np.random.random((5, 5)))
x = np.random.random((5, 3))
f = loss.SquaredL2Loss(y=x, A=A)
print(f.hessian)
```
The above code raises the following error: `ValueError: Cannot evaluate <class 'scico.linop.Identity'> with input_shape=(5, 3) on array with shape=(5,)`
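To illustrate where the mismatch comes from, here is a minimal NumPy sketch (not scico code) of an operator that, like the original `MatrixOperator`, derives its input shape from `A` alone as `(A.shape[1],)`. The class name `FixedShapeMatrixOp` is hypothetical, and its shape check only mimics the kind of declared-shape check behind the `ValueError` above. The matrix product itself accepts the 5×3 `x`; only the declared shape rejects it.

```python
import numpy as np


class FixedShapeMatrixOp:
    """Hypothetical stand-in for the original MatrixOperator (not scico code).

    The input shape is derived from A alone, so only vectors of length
    A.shape[1] are accepted.
    """

    def __init__(self, A: np.ndarray):
        self.A = A
        self.input_shape = (A.shape[1],)   # as in the original __init__
        self.output_shape = (A.shape[0],)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        # Reject any input whose shape differs from the declared input shape
        if x.shape != self.input_shape:
            raise ValueError(
                f"Cannot evaluate with input_shape={self.input_shape} "
                f"on array with shape={x.shape}"
            )
        return self.A @ x


rng = np.random.default_rng(0)
A = rng.random((5, 5))
x = rng.random((5, 3))

print((A @ x).shape)  # (5, 3): the matrix product itself is well defined
try:
    FixedShapeMatrixOp(A)(x)
except ValueError as err:
    print("rejected:", err)
```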
The error arises when the Hessian of the linear operator `A` is constructed in the following code:
```python
class SquaredL2Loss(Loss):
    @property
    def hessian(self) -> linop.LinearOperator:
        r"""Compute the Hessian of linear operator `A`.

        If `self.A` is a :class:`scico.linop.LinearOperator`, returns a
        :class:`scico.linop.LinearOperator` corresponding to the Hessian
        :math:`2 \alpha \mathrm{A^H W A}`. Otherwise not implemented.
        """
        A = self.A
        W = self.W
        if isinstance(A, linop.LinearOperator):
            return linop.LinearOperator(
                input_shape=A.input_shape,  # <- input shape taken from A
                output_shape=A.input_shape,  # <- output shape also set to A's input shape
                eval_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),  # type: ignore
                adj_fn=lambda x: 2 * self.scale * A.adj(W(A(x))),  # type: ignore
                input_dtype=A.input_dtype,
            )
        raise NotImplementedError(
            f"Hessian is not implemented for {type(self)} when A is {type(A)}; "
            "must be LinearOperator."
        )
```
This happens because the `input_shape` of the returned linear operator is `A.input_shape`, which is inconsistent with the shape of the input `x` when `x` is not a square matrix. In addition, the `output_shape` of the returned linear operator is also set to `A.input_shape`; I am not sure whether this is a separate bug. The following minor modification to `MatrixOperator` fixes the problem:
```python
from typing import Optional, Union
import jax
import numpy as np

# JaxArray, Shape, BlockShape, DeviceArray and LinearOperator come from the
# surrounding scico module (imports omitted here).


class MatrixOperator(LinearOperator):
    """Linear operator implementing matrix multiplication."""

    def __init__(
        self,
        A: JaxArray,
        input_shape: Optional[Union[Shape, BlockShape]] = None,
        output_shape: Optional[Union[Shape, BlockShape]] = None,
    ):
        """
        Args:
            A: Dense array. The action of the created LinearOperator will
                implement matrix multiplication with `A`.
            input_shape: Input shape of the operator; defaults to `A.shape[1]`.
            output_shape: Output shape of the operator; defaults to `A.shape[0]`.
        """
        self.A: JaxArray  #: Dense array implementing this matrix
        # If A is an ndarray, make sure it gets converted to a DeviceArray
        if isinstance(A, DeviceArray):
            self.A = A
        elif isinstance(A, np.ndarray):
            self.A = jax.device_put(A)
        else:
            raise TypeError(f"Expected np.ndarray or DeviceArray, got {type(A)}")
        # Can only do rank-2 arrays
        if A.ndim != 2:
            raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}")
        # Use the given shapes if provided, otherwise fall back to the matrix dimensions
        super_input_shape = A.shape[1] if input_shape is None else input_shape
        super_output_shape = A.shape[0] if output_shape is None else output_shape
        super().__init__(
            input_shape=super_input_shape, output_shape=super_output_shape, input_dtype=A.dtype
        )
```
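A minimal NumPy sketch of the idea behind this change, assuming only the declared shapes matter. The names `ShapedMatrixOp` and `squared_l2_hessian` are hypothetical (not scico API); the sketch takes $W = I$ and real arrays, so $A^H = A^T$. For $f(x) = \alpha \lVert y - Ax \rVert_2^2$ the gradient is $2\alpha A^T(Ax - y)$ and the Hessian is $2\alpha A^T A$, which maps the input space of `A` back to itself. That suggests `output_shape=A.input_shape` in `hessian` is consistent rather than a second bug. The check compares the Hessian-vector product with a difference quotient of the analytic gradient.

```python
from typing import Optional, Tuple

import numpy as np


class ShapedMatrixOp:
    """Hypothetical NumPy analogue of the patched MatrixOperator (not scico code)."""

    def __init__(
        self,
        A: np.ndarray,
        input_shape: Optional[Tuple[int, ...]] = None,
        output_shape: Optional[Tuple[int, ...]] = None,
    ):
        if A.ndim != 2:
            raise TypeError(f"Expected a two-dimensional array, got array of shape {A.shape}")
        self.A = A
        # Same defaults as the proposed fix: fall back to vector shapes
        self.input_shape = (A.shape[1],) if input_shape is None else tuple(input_shape)
        self.output_shape = (A.shape[0],) if output_shape is None else tuple(output_shape)

    def __call__(self, x: np.ndarray) -> np.ndarray:
        if x.shape != self.input_shape:
            raise ValueError(f"expected shape {self.input_shape}, got {x.shape}")
        return self.A @ x

    def adj(self, y: np.ndarray) -> np.ndarray:
        if y.shape != self.output_shape:
            raise ValueError(f"expected shape {self.output_shape}, got {y.shape}")
        return self.A.conj().T @ y


def squared_l2_hessian(op: ShapedMatrixOp, alpha: float = 1.0):
    """Hessian-vector product of alpha * ||y - A x||^2 with W = I (hypothetical helper).

    The returned function maps op.input_shape -> op.input_shape.
    """
    return lambda v: 2 * alpha * op.adj(op(v))


rng = np.random.default_rng(0)
A = rng.random((5, 5))
y = rng.random((5, 3))
alpha = 0.5
op = ShapedMatrixOp(A, input_shape=(5, 3), output_shape=(5, 3))
H = squared_l2_hessian(op, alpha)


def grad(x: np.ndarray) -> np.ndarray:
    # Analytic gradient of alpha * ||y - A x||^2
    return 2 * alpha * op.adj(op(x) - y)


x0 = rng.random((5, 3))
v = rng.random((5, 3))
eps = 1e-3
# The gradient is affine in x, so this quotient equals H(v) up to rounding
fd = (grad(x0 + eps * v) - grad(x0)) / eps

print(H(v).shape)  # (5, 3)
print(np.allclose(H(v), fd, atol=1e-6))  # True
```

With the default arguments, `ShapedMatrixOp(A)` keeps the old vector-only behaviour (`input_shape == (5,)`), so passing explicit shapes is opt-in.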
This problem is resolved in #360. Your example to reproduce the bug (with minor modifications to account for the new interface of the MatrixOperator initializer)