This repository has been archived by the owner on May 11, 2023. It is now read-only.
feat: Overload jax.numpy operations for LinearOperator compatibility. #2
Labels
enhancement
New feature or request
The problem:
Currently, to add, e.g., a diagonal linear operator to a matrix, we would do the following:
This is all very nice. But if we tried to use `jnp.add` instead, this would result in an error. It would be nice to overload `jnp.add` so that it returned `D + M` instead. This would be particularly nice for matrix solves: `jnp.linalg.solve(D, M)` would be more familiar to JAX users than `D.solve(M)`.
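To make the failure concrete, here is a minimal sketch of the situation described above. `DiagonalOperator` is a hypothetical stand-in written for this illustration, not this library's actual class, and its methods are assumptions about how such an operator might behave.

```python
import jax.numpy as jnp


class DiagonalOperator:
    """Hypothetical diagonal linear operator (illustration only)."""

    def __init__(self, diag):
        self.diag = jnp.asarray(diag)

    def to_dense(self):
        # Materialise the operator as a dense matrix.
        return jnp.diag(self.diag)

    def __add__(self, other):
        # Operator + dense matrix, returned as a dense array.
        return jnp.asarray(other) + self.to_dense()

    def solve(self, rhs):
        # Solving D x = rhs for diagonal D divides row i of rhs by diag[i].
        return jnp.asarray(rhs) / self.diag[:, None]


D = DiagonalOperator([1.0, 2.0])
M = jnp.ones((2, 2))

dense_sum = D + M  # works: Python dispatches to DiagonalOperator.__add__

try:
    jnp.add(D, M)  # jnp.add does not know what a DiagonalOperator is
    jnp_add_failed = False
except TypeError:
    jnp_add_failed = True
```

The operator syntax works because Python routes `+` through the class, whereas `jnp.add` validates its arguments as arrays before doing anything else.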
Proposed solution
Overload `jax.numpy` operations, with the overloads activated by a global config.