Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce beartype & fix types #230

Merged
merged 91 commits into from
Apr 27, 2023

Conversation

st--
Copy link
Contributor

@st-- st-- commented Apr 13, 2023

Please check the type of change your PR introduces:

  • Bugfixes identified by added type checking
  • Some backwards-incompatible changes to fix behaviour
  • Code style update (formatting, renaming)

What is the current behavior?

No runtime type checking.

Resolves #202

What is the new behavior?

Runtime type checking.

st-- added 29 commits April 13, 2023 12:55
@st-- st-- changed the base branch from main to v0.6 April 13, 2023 15:43
@st--
Copy link
Contributor Author

st-- commented Apr 25, 2023

Going through the notebooks, another issue: Float[Array, "..."] specifically matches JAX-arrays, but not numpy arrays. How do JAX-functions handle numpy-arrays? Do they implicitly convert them? Should we therefore relax type hints on all user-facing parts (not just Dataset but also e.g. predict()) to allow numpy arrays?

@thomaspinder
Copy link
Collaborator

Going through the notebooks, another issue: Float[Array, "..."] specifically matches JAX-arrays, but not numpy arrays. How do JAX-functions handle numpy-arrays? Do they implicitly convert them? Should we therefore relax type hints on all user-facing parts (not just Dataset but also e.g. predict()) to allow numpy arrays?

Yes, JAX will implicitly convert them. So calling posterior(np_array, D) will produce the same result as posterior(jnp.asarray(np_array), D). I think it'd place less friction on the user if we relax the type hint to accomodate Numpy arrays.

gpjax/dataset.py Outdated Show resolved Hide resolved
gpjax/dataset.py Outdated Show resolved Hide resolved
examples/graph_kernels.pct.py Outdated Show resolved Hide resolved
st-- and others added 4 commits April 25, 2023 19:18
Co-authored-by: st-- <[email protected]>
Signed-off-by: Thomas Pinder <[email protected]>
Allow for integer responses

Co-authored-by: st-- <[email protected]>
Signed-off-by: Thomas Pinder <[email protected]>
Copy link
Collaborator

@thomaspinder thomaspinder left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All looks good to me. There's a small issue with imports, but I'm happy to merge as I've setup a pre-commit action on docs_update that will fix this. I will rebase that branch and correct.

If you're happy with the proposed approach of pushing beartype onto the user, then we can merge this.

gpjax/kernels/base.py Show resolved Hide resolved
gpjax/kernels/nonstationary/linear.py Show resolved Hide resolved
Comment on lines +33 to +35
DTypes = Union[Type[jnp.float32], Type[jnp.float64], Type[jnp.int32], Type[jnp.int64]]
ShapeT = TypeVar("ShapeT", bound=NestedT[Tuple[int, ...]])
DTypeT = TypeVar("DTypeT", bound=NestedT[jnp.dtype])
DTypeT = TypeVar("DTypeT", bound=NestedT[DTypes])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let's do that.

VariationalGaussian,
WhitenedVariationalGaussian)
from jaxtyping import install_import_hook
with install_import_hook("gpjax", "beartype.beartype"):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good way to position it. Let's push it onto users then. If it becomes an issue, we can always walk it back with little major disruption.

gpjax/dataset.py Outdated Show resolved Hide resolved
gpjax/dataset.py Outdated Show resolved Hide resolved
@st--
Copy link
Contributor Author

st-- commented Apr 26, 2023

@thomaspinder ok I think it's ready to merge now. I hope I didn't accidentally break anything along the way 😅

@thomaspinder thomaspinder merged commit 75f9ddd into JaxGaussianProcesses:v0.6 Apr 27, 2023
@st-- st-- mentioned this pull request Apr 27, 2023
@st-- st-- deleted the st/beartype branch November 22, 2023 21:38
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants