This repository has been archived by the owner on May 11, 2023. It is now read-only.

feat: Overload jax.numpy operations for LinearOperator compatibility. #2

Open
daniel-dodd opened this issue Nov 23, 2022 · 0 comments
Labels
enhancement New feature or request

Comments

@daniel-dodd
Member

The problem:

Currently, to add a linear operator (e.g., a diagonal one) to a matrix, we would do the following:

from jaxlinop import DiagonalLinearOperator
import jax.numpy as jnp

M = jnp.array([[1., 2.], [3., 4.]])
D = DiagonalLinearOperator(jnp.array([1.,2.]))

res = D + M

print(res.to_dense())
[[2. 2.]
 [3. 6.]]

This works well. But if we tried to use jnp.add instead, it would raise an error:

res = jnp.add(D, M)
TypeError: add requires ndarray or scalar arguments, got <class 'jaxlinop.diagonal_linear_operator.DiagonalLinearOperator'> at position 0.

It would be nice to overload jnp.add so that it returns D + M instead. This would be particularly useful for matrix solves: jnp.linalg.solve(D, M) would be more familiar to JAX users than D.solve(M).

Proposed solution

Overload jax.numpy operations, with the overloading activated via a global config flag:

from jaxlinop import config

config.update("overload_jax_numpy", True)

res = jnp.add(D, M)

print(res.to_dense())
[[2. 2.]
 [3. 6.]]
@daniel-dodd daniel-dodd added the enhancement New feature or request label Nov 23, 2022