Skip to content

Commit

Permalink
Fixed the device allocation for Transpose and Adjoint. Added tests (#97)
Browse files Browse the repository at this point in the history
A.T and A.H failed to grab the device from A as the [
`find_device`](https://github.com/wilson-labs/cola/blob/main/cola/ops/operator_base.py#L264-L280)
function was receiving as arguments `[(), {}]`. As a fix, both
`Transpose` and `Adjoint` now fetch the `device` directly from the original
operator.
  • Loading branch information
AndPotap authored Aug 29, 2024
1 parent 7ae87b6 commit d9e12c0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 22 deletions.
12 changes: 8 additions & 4 deletions cola/ops/operators.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from functools import reduce, partial
from cola.ops.operator_base import LinearOperator, Array
from cola.backends import get_library_fns
from functools import partial, reduce

import numpy as np
from plum import parametric

import cola
import numpy as np
from cola.backends import get_library_fns
from cola.ops.operator_base import Array, LinearOperator


class Dense(LinearOperator):
Expand Down Expand Up @@ -359,6 +361,7 @@ class Transpose(LinearOperator):
def __init__(self, A):
    """Lazy transpose of the operator *A* (shape is reversed, dtype kept)."""
    self.A = A
    super().__init__(dtype=A.dtype, shape=(A.shape[1], A.shape[0]))
    # Set AFTER super().__init__: the base class cannot infer the device
    # from the (empty) constructor args, so inherit it from A explicitly.
    self.device = A.device

def _matmat(self, x):
    """Compute (A^T) @ x by delegating to A's right-matmul on x^T."""
    rmm = self.A._rmatmat(x.T)
    return rmm.T
Expand All @@ -376,6 +379,7 @@ class Adjoint(LinearOperator):
def __init__(self, A):
    """Lazy conjugate-transpose (adjoint) of the operator *A*."""
    self.A = A
    super().__init__(dtype=A.dtype, shape=(A.shape[1], A.shape[0]))
    # Set AFTER super().__init__: the base class cannot infer the device
    # from the (empty) constructor args, so inherit it from A explicitly.
    self.device = A.device

def _matmat(self, x):
    """Compute (A^H) @ x as conj(A._rmatmat(conj(x)^T))^T."""
    xnp = self.xnp
    conjugated = xnp.conj(x).T
    product = self.A._rmatmat(conjugated)
    return xnp.conj(product).T
Expand Down
51 changes: 33 additions & 18 deletions tests/test_operators.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,28 @@
import pytest
import numpy as np
import pytest
from linalg.operator_market import get_test_operator, op_names

from cola.backends import all_backends, tracing_backends
from cola.fns import kron, lazify
from cola.ops import Tridiagonal
from cola.ops import Diagonal
from cola.ops import Identity
from cola.ops import I_like
from cola.ops import KronSum
from cola.ops import Sum
from cola.ops import ScalarMul
from cola.ops import Product
from cola.ops import Sliced
from cola.ops import Householder
from cola.ops import Sparse
from cola.ops import Jacobian
from cola.ops import LinearOperator
from cola.ops import Kernel
from cola.ops import Hessian
from cola.linalg.decompositions.arnoldi import get_householder_vec
from cola.ops import (
Diagonal,
Hessian,
Householder,
I_like,
Identity,
Jacobian,
Kernel,
KronSum,
LinearOperator,
Product,
ScalarMul,
Sliced,
Sparse,
Sum,
Tridiagonal,
)
from cola.utils.test_utils import get_xnp, parametrize, relative_error
from cola.backends import all_backends, tracing_backends
from linalg.operator_market import op_names, get_test_operator

_tol = 1e-6

Expand All @@ -46,6 +49,18 @@ def fn(z):
_exclude = (slice(None), slice(None), ['square_fft'])


@parametrize(tracing_backends)
def test_device_inheritance(backend):
    """Transpose (A.T) and Adjoint (A.H) must inherit the device of A."""
    xnp = get_xnp(backend)  # fixed: this line was duplicated
    dtype = xnp.float32
    Aop = Diagonal(xnp.array([0.1, -0.2], dtype=dtype, device=None))
    # Fake a non-default device so the check works without GPU hardware.
    Aop.device = "cuda:0"

    assert Aop.T.device == Aop.device
    assert Aop.H.device == Aop.device


@parametrize(tracing_backends, ['float32'], op_names).excluding[_exclude]
def test_ops_to(backend, precision, op_name):
Op = get_test_operator(backend, precision, op_name)
Expand Down

0 comments on commit d9e12c0

Please sign in to comment.