-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce beartype & fix types #230
Merged
Merged
Changes from 84 commits
Commits
Show all changes
91 commits
Select commit
Hold shift + click to select a range
15b808f
add beartype dependency
st-- 0eaec30
from typing import -> from beartype.typing import
st-- 781bc38
jaxtyping import_hook for @jaxtyped @beartype everywhere
st-- 67e315f
fix Type[] of class-as-argument
st-- e77732b
fix KeyArray type hint (should probably move into jaxutils though)
st-- 5175fa7
fix return value of slice_input when active_dims is None
st-- c548d73
fix return value of squared_distance
st-- 079936f
fix return type of recursive_bijectors
st-- 1005212
fix slice_input type annotations
st-- 0c57a4e
new KernelCallable type to fix kernel_fn annotations
st-- 777b042
fix kernel __call__ annotation
st-- df88c69
fix KeyArray type hint
st-- 711fcac
beartype does not like forward references; replaced with string types
st-- 1c95fa5
linops other type hint fixes
st-- 1b915cf
fix KeyArray
st-- c3e48c1
abstractions.py some type fixes
st-- 4cc9db1
fix GaussianDistribution.log_prob return type
st-- 4c4a572
fix depreciations & warnings
st-- c512195
Merge branch 'st/fix_depreciations' into st/beartype
st-- 9516fdf
fix scalar array types
st-- 2d45258
introduce ScalarBool, ScalarInt for jitted calls in abstractions
st-- ecfde64
relax LinearOperator's solve() types (can be both matrix or vector), …
st-- 066a5d7
remove _stop_grad type hints, not sure what they should be
st-- ff68725
found some more
st-- 4716773
Merge branch 'st/fix_depreciations' into st/beartype
st-- 5c90755
float -> ScalarFloat fixes
st-- c692568
linops log_det type fixes
st-- ff5211a
some more linops type fixes
st-- f3cc17e
Merge remote-tracking branch 'upstream/v0.6' into st/beartype
st-- af69a73
actually commit KeyArray and Scalar* types
st-- cc97fe2
add beartype to pyproject
st-- 97d6fdd
from beartype.typing import ...
st-- cd7c2ee
try to fix Self in gpjax/base/module
st-- 8a62aae
fix _check_shape
st-- 9359b05
gpjax.objectives: always import from gps and variational_families
st-- c687cfe
Revert "gpjax.objectives: always import from gps and variational_fami…
st-- 96d3d1e
fix gpjax.objectives imported types
st-- fc18652
<...> | None not supported by beartype; replaced by Optional[<...>]
st-- 4cbe562
gpjax.datasets: cannot specify strict array shape AND rely on _check_…
st-- 16006ac
our tfd.Distribution subclassing requires the fix introduced in jaxty…
st-- c95badb
need to import base first!
st-- 3c8969d
bugfix
st-- 243e22f
AbstractKernel: string for forward references
st-- 0c3ae8a
remove from __future__ import annotations
st-- cd6cf66
fix type annotations to make up for changes in 0c3ae8ac33e7938a2f4d5b…
st-- c31885e
pytree map functions may take a non-Module argument
st-- 7083d08
ScalarFloat
st-- 692d337
VecNOrMatNM
st-- 1b4ae52
remove unnecessary / buggy methods
st-- 7286cd0
more ScalarFloat
st-- 6e55323
ScalarFloat
st-- 8159a2d
type fixes
st-- 62c7a69
fix shape type
st-- 3a8814a
fix one KeyArray
st-- 25ed38f
more ScalarFloat corrections in kernels
st-- 76dd035
fix test_stationary accordingly for ScalarFloat params
st-- 2ba27b2
fix return type
st-- 33a3e7b
ScalarInt for Polynomial kernel and fix test for Scalar* params
st-- 4dd9ed5
fix mock in test_abstract_variational_family
st-- 54ef19c
fix link_function and variational_expectations shape annotations
st-- 3a21676
minor test fix
st-- 551f055
fix exception test for beartype
st-- 6fc3a55
fix Constant mean function
st-- bf2a483
base_kernel as kwarg in test_approximations
st-- ac70773
rename func to test_ so it actually gets collected
st-- e2b6a88
mark test_graph_kernel as broken
st-- 985d554
fix LinearOperator DTypeT
st-- 7ea37d2
Revert "fix Constant mean function"
st-- 760d470
fix test_mean_functions instead
st-- fcf6572
fix one more bug in RFF test
st-- 78397d5
Self
st-- 67f4c94
relax fit objective type
st-- 68d6371
rename gpjax.utils -> gpjax.typing
st-- 65a9bfb
Kernel = Any -> string forward reference
st-- c439a93
relax Gaussian.predict type annotation to include GaussianDistribution
st-- edaf2cb
our own `Array` type that accepts both JAX and Numpy arrays
st-- bb4d51f
some Float -> Num relaxations for graph kernel...
st-- ea1537d
ScalarFloat for GraphKernel hyperparams
st-- 495c649
fix type hints to what happens (even if it seems wrong)
st-- dca6799
type relaxation for deep_kernels.pct.py
st-- f26d52e
some more minor consistency fixes
st-- 23efd03
bugfix
st-- 1077efa
Update examples/graph_kernels.pct.py
thomaspinder c292c7e
Update gpjax/dataset.py
thomaspinder 7e16487
jaxtyping import hook for notebooks
st-- 6620802
conftest.py to apply jaxtyping import hook before loading tests
st-- bc70fdf
remove import hook from gpjax/__init__
st-- 249c9d0
Merge branch 'v0.6' of https://github.com/JaxGaussianProcesses/GPJax …
st-- c4a8af8
Update gpjax/dataset.py
st-- 62abd1a
Update gpjax/dataset.py
thomaspinder 42e21fb
fix tests of shape checks now that we have beartype
st-- File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want to push beartype onto end users?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand well enough to comment here. What are the pros/cons of this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm, mainly in some places we might have specified stricter types (array dtypes/shapes) than is strictly required by the code, so some code might have run fine if only beartype hadn't intervened. also, the explicit _check_shape error messages might be a bit more informative than the generic beartype ones. I don't know if there's also some more interaction with the jaxtyping
@jaxtyped
decorator.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if we didn't push beartype onto the end-user, then it would just be a testing utility for the package? I could easily imagine how any overly rigid beartype assertions could be annoying for an end user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah! Though on the other hand, it could help us discover more bugs not handled by the tests...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, a bunch of them were only uncovered by the notebooks. But I guess we could add beartype there as well, and then use that as a way to suggest to users that they should use e.g. beartype (and report when one of our type hints is wrong)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good way to position it. Let's push it onto users then. If it becomes an issue, we can always walk it back with little major disruption.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Umm, do you mean force it on users (by including it in the general import), or strongly suggest it to users by having it at the start of every notebook (and inside the tests)? I think I'd be in favour of the latter actually...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latter - it's consistent with how we encourage people to use float64 without enforcing it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done