-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Introduce beartype & fix types #230
Changes from 69 commits
15b808f
0eaec30
781bc38
67e315f
e77732b
5175fa7
c548d73
079936f
1005212
0c57a4e
777b042
df88c69
711fcac
1c95fa5
1b915cf
c3e48c1
4cc9db1
4c4a572
c512195
9516fdf
2d45258
ecfde64
066a5d7
ff68725
4716773
5c90755
c692568
ff5211a
f3cc17e
af69a73
cc97fe2
97d6fdd
cd7c2ee
8a62aae
9359b05
c687cfe
96d3d1e
fc18652
4cbe562
16006ac
c95badb
3c8969d
243e22f
0c3ae8a
cd6cf66
c31885e
7083d08
692d337
1b4ae52
7286cd0
6e55323
8159a2d
62c7a69
3a8814a
25ed38f
76dd035
2ba27b2
33a3e7b
4dd9ed5
54ef19c
3a21676
551f055
6fc3a55
bf2a483
ac70773
e2b6a88
985d554
7ea37d2
760d470
fcf6572
78397d5
67f4c94
68d6371
65a9bfb
c439a93
edaf2cb
bb4d51f
ea1537d
495c649
dca6799
f26d52e
23efd03
1077efa
c292c7e
7e16487
6620802
bc70fdf
249c9d0
c4a8af8
62abd1a
42e21fb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,10 +12,9 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
from __future__ import annotations | ||
|
||
from dataclasses import dataclass | ||
from typing import Optional | ||
from beartype.typing import Optional, Union | ||
|
||
import jax.numpy as jnp | ||
from jaxtyping import Array, Float | ||
|
@@ -31,8 +30,8 @@ class Dataset(Pytree): | |
y (Optional[Float[Array, "N Q"]]): Output data. | ||
""" | ||
|
||
X: Optional[Float[Array, "N D"]] = None | ||
y: Optional[Float[Array, "N Q"]] = None | ||
X: Optional[Union[Float[Array, "N D"], Float[Array, "..."]]] = None | ||
y: Optional[Union[Float[Array, "N Q"], Float[Array, "..."]]] = None | ||
thomaspinder marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __post_init__(self) -> None: | ||
"""Checks that the shapes of X and y are compatible.""" | ||
|
@@ -54,7 +53,7 @@ def is_unsupervised(self) -> bool: | |
"""Returns `True` if the dataset is unsupervised.""" | ||
return self.X is None and self.y is not None | ||
|
||
def __add__(self, other: Dataset) -> Dataset: | ||
def __add__(self, other: "Dataset") -> "Dataset": | ||
"""Combine two datasets. Right hand dataset is stacked beneath the left.""" | ||
|
||
X = None | ||
|
@@ -84,7 +83,7 @@ def out_dim(self) -> int: | |
return self.y.shape[1] | ||
|
||
|
||
def _check_shape(X: Float[Array, "N D"], y: Float[Array, "N Q"]) -> None: | ||
def _check_shape(X: Optional[Float[Array, "..."]], y: Optional[Float[Array, "..."]]) -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I removed the beartype-shape checking. Could instead simply remove the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there any reason not to remove There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, beartype only checks arguments individually, and does not check consistency of dimensions across multiple arguments/return values, so might be better to keep the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually actually, jaxtyping itself checks those, I just had a bug in my toy example trying it out. But the error messages are still more verbose/less precise than the ones emitted by _check_shape 😞 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK. I'm guessing there's no easy way for us to customise the error messages that Beartype throws? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the answer to my above question is no, then let's keep There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It looks like this is a very hot topic right now: patrick-kidger/jaxtyping#6. |
||
"""Checks that the shapes of X and y are compatible.""" | ||
if X is not None and y is not None: | ||
if X.shape[0] != y.shape[0]: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we want to push beartype onto end users?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't understand well enough to comment here. What are the pros/cons of this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
hm, mainly in some places we might have specified stricter types (array dtypes/shapes) than is strictly required by the code, so some code might have run fine if only beartype hadn't intervened. also, the explicit _check_shape error messages might be a bit more informative than the generic beartype ones. I don't know if there's also some more interaction with the jaxtyping
@jaxtyped
decorator.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So if we didn't push beartype onto the end-user, then it would just be a testing utility for the package? I could easily imagine how any overly rigid beartype assertions could be annoying for an end user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah! Though on the other hand, it could help us discover more bugs not handled by the tests...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For example, a bunch of them were only uncovered by the notebooks. But I guess we could add beartype there as well, and then use that as a way to suggest to users that they should use e.g. beartype (and report when one of our type hints is wrong)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good way to position it. Let's push it onto users then. If it becomes an issue, we can always walk it back with little major disruption.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Umm, do you mean force it on users (by including it in the general import), or strongly suggest it to users by having it at the start of every notebook (and inside the tests)? I think I'd be in favour of the latter actually...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latter - it's consistent with how we encourage people to use float64 without enforcing it.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done