Skip to content
This repository has been archived by the owner on May 11, 2023. It is now read-only.

Jax utils py tree refactor #20

Merged
merged 12 commits into from
Mar 16, 2023
Merged

Jax utils py tree refactor #20

merged 12 commits into from
Mar 16, 2023

Conversation

daniel-dodd
Copy link
Member

This moves the base class inheritance to simple-pytree.

We move shape and dtype to attributes of the dataclass, and ensure that dtype can be passed to each LinOps __init__ to override the dtype of the linear operator.

Create kroncker_linear_operator.py
Merging Kronecker work to Equinox branch since it restructures class properties (would have to do this again otherwise).
Stage two:

Abstract out more properties / methods as attributes to make it possible to build LinearOperators from defining a mat_mul, shape and dtype, *args, **kwargs.
Copy link
Collaborator

@thomaspinder thomaspinder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good. Before approving, I would ideally like to see two things:

  1. More documentation would be helpful. A simple notebook demonstrating the functionality of JaxLinOp and some more informative docstrings would be sufficient for now.
  2. Coverage - it looks like a report is being uploaded to codecov, but I don't see the percentage anywhere.

jaxlinop/utils.py Show resolved Hide resolved
jaxlinop/linear_operator.py Show resolved Hide resolved

if not isinstance(size, int):
raise ValueError(f"`length` must be an integer, but `length = {size}`.")
if not isinstance(shape, tuple):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we not want it be an Iterable? I'd have thought a list was valid here.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will open issue for this.


# TODO: Generalise to non-square matrices.
# TODO: Generalise to non-square matrices.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we either act on TODOs or open issues for them please

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will open issue for these.

jaxlinop/linear_operator.py Show resolved Hide resolved
@@ -175,7 +178,8 @@ def from_dense(cls, dense: Float[Array, "N N"]) -> ZeroLinearOperator:
"""

# TODO: check shapes.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment as above re. TODOs

@daniel-dodd daniel-dodd merged commit c37f7eb into main Mar 16, 2023
@daniel-dodd daniel-dodd deleted the JaxUtils-PyTree-refactor- branch March 16, 2023 18:32
Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants