Error message could include more shape information #6

Closed
awf opened this issue Jul 19, 2022 · 13 comments
Labels
feature New feature next Higher-priority items

Comments

@awf

awf commented Jul 19, 2022

The following code passes typechecking and runs without error:

```python
import jax
from typeguard import typechecked as typechecker
from jaxtyping import f32, jaxtyped

@jaxtyped
@typechecker
def standardize(x: f32["N"], eps=1e-5):
    return (x - x.mean()) / (x.std() + eps)

rng = jax.random.PRNGKey(42)

embeddings = jax.random.uniform(rng, (11,))
t1 = standardize(embeddings)
```

The following code correctly fails typechecking, but the message would ideally tell us why the shapes don't match:

```python
embeddings = jax.random.uniform(rng, (11, 13))
t1 = standardize(embeddings)
# TypeError: type of argument "x" must be jaxtyping.array_types.f32['N']; got jaxlib.xla_extension.DeviceArray instead
```

Ideally, the message would instead be something like:

```python
# TypeError: type of argument "x" must be jaxtyping.array_types.f32['N']; got jaxlib.xla_extension.DeviceArray(dtype=float32,shape=(11,13)) instead
```
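For illustration only, here is a minimal sketch of how a message in that format could be assembled from any array-like value that exposes `dtype` and `shape` attributes (JAX and NumPy arrays both do). `describe_value` and `FakeArray` are hypothetical names, not part of jaxtyping or typeguard; `FakeArray` just stands in for a JAX array so the snippet runs without JAX installed.

```python
from typing import Any


def describe_value(value: Any) -> str:
    """Return the qualified type name, plus dtype and shape if available."""
    cls = type(value)
    name = f"{cls.__module__}.{cls.__qualname__}"
    dtype = getattr(value, "dtype", None)
    shape = getattr(value, "shape", None)
    if dtype is None or shape is None:
        # Not array-like: fall back to the bare qualified name, as today.
        return name
    return f"{name}(dtype={dtype},shape={tuple(shape)})"


class FakeArray:
    """Hypothetical stand-in for a JAX array (only `dtype` and `shape`)."""

    def __init__(self, dtype: str, shape: tuple) -> None:
        self.dtype = dtype
        self.shape = shape


# The offending argument from the failing example above.
print(describe_value(FakeArray("float32", (11, 13))))
```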
@patrick-kidger
Owner

Absolutely! I completely agree.

So at the moment this is a limitation of the current approach. The checking is performed via isinstance, which simply returns True or False, and it's then up to either typeguard or beartype to take this and turn it into an error message. This means that there isn't really any way of returning this additional information about why the isinstance check failed.

I don't have a great solution in mind for this at the moment. I'd welcome any thoughts on how to accomplish this.

@awf
Author

awf commented Jul 19, 2022

I see: you're doing all your work at https://github.com/google/jaxtyping/blob/35201eb189cc004276925f96e0aa6bfc469e46be/jaxtyping/array_types.py#L102, and then typeguard says:

```python
elif not isinstance(value, expected_type):
    raise TypeError(
        'type of {} must be {}; got {} instead'.
        format(argname, qualified_name(expected_type), qualified_name(value)))
```

Hmmm.

So it turns out that printing here wouldn't be too noisy, since when your check fails we are almost certainly going to raise an error anyway:

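```python
import jax.numpy as jnp

class _MetaAbstractArray(type):
    def __instancecheck__(cls, obj):
        if not isinstance(obj, jnp.ndarray):
            print(f'jaxtyping: {obj}:{type(obj)} is not a jnp.ndarray.')
            return False

        # `_any_dtype` is the "any dtype allowed" sentinel defined elsewhere in jaxtyping.
        if cls.dtypes is not _any_dtype and obj.dtype not in cls.dtypes:
            print(f'jaxtyping: {obj} dtype ({obj.dtype}) is not in {cls.dtypes}.')
            return False

        # ... the existing shape checks continue here; each could print a
        # similar message before returning False.
```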
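A self-contained toy version of this idea is sketched below, assuming a simplified annotation type that only checks dtype and number of dimensions. `_MetaShapedArray`, `F32Vector` and `FakeArray` are hypothetical stand-ins rather than jaxtyping's actual classes. `isinstance` still only ever sees `True` or `False`; the diagnostic travels separately, on stderr.

```python
import sys


class _MetaShapedArray(type):
    """Toy metaclass: isinstance() still returns a plain bool, but each
    failing check prints its reason to stderr before returning False."""

    def __instancecheck__(cls, obj):
        shape = getattr(obj, "shape", None)
        dtype = getattr(obj, "dtype", None)
        if shape is None or dtype is None:
            print(f"shapecheck: {type(obj).__name__} is not array-like.", file=sys.stderr)
            return False
        if dtype not in cls.dtypes:
            print(f"shapecheck: dtype {dtype} is not in {cls.dtypes}.", file=sys.stderr)
            return False
        if len(shape) != cls.ndim:
            print(
                f"shapecheck: expected {cls.ndim} dimension(s), got shape {tuple(shape)}.",
                file=sys.stderr,
            )
            return False
        return True


class F32Vector(metaclass=_MetaShapedArray):
    """Hypothetical annotation meaning 'a 1-D float32 array'."""

    dtypes = ("float32",)
    ndim = 1


class FakeArray:
    """Hypothetical stand-in for a JAX array."""

    def __init__(self, dtype: str, shape: tuple) -> None:
        self.dtype = dtype
        self.shape = shape


ok = isinstance(FakeArray("float32", (11,)), F32Vector)       # True, nothing printed
bad = isinstance(FakeArray("float32", (11, 13)), F32Vector)   # False, reason on stderr
```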

@patrick-kidger
Copy link
Owner

Yeah, adding our own manual print statements might be one approach. It's not super elegant, of course, so if we did this I'd probably add a global toggle for whether to print them.

@awf
Author

awf commented Jul 19, 2022

Exactly so. It might even be a case for, ugh, an environment variable, so a usage pattern might be:

```
% python t.py
...
Error message.
% JAXTYPING=verbose python t.py
```

@GallagherCommaJack

Probably verbose should be the default? Probably >90% of exceptions for a library like this one will be thrown while the dev is looking, not in some production use case where the print statement would be an issue.

That said, it should probably still print to stderr, not stdout.
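A minimal sketch of such a toggle, assuming the `JAXTYPING` environment variable floated above. The `quiet` value, the `default` parameter and the helper names are hypothetical. Diagnostics default to on, per the point that most of these errors are seen during development, and they go to stderr.

```python
import os
import sys


def diagnostics_enabled(default: bool = True) -> bool:
    """Hypothetical toggle read from the JAXTYPING environment variable.

    JAXTYPING=verbose forces diagnostics on, JAXTYPING=quiet forces them
    off, and anything else (including unset) falls back to `default`.
    """
    setting = os.environ.get("JAXTYPING", "").strip().lower()
    if setting == "verbose":
        return True
    if setting == "quiet":
        return False
    return default


def report(message: str) -> None:
    """Print a type-check diagnostic to stderr (not stdout) when enabled."""
    if diagnostics_enabled():
        print(f"jaxtyping: {message}", file=sys.stderr)
```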

@dkamm

dkamm commented Mar 10, 2023

Hi @patrick-kidger, any updates on this? It feels like this makes jaxtyping a bit frustrating to use with a typechecker, since shape mismatches are so common.

@patrick-kidger
Owner

As it turns out, an analogous point has just been raised over on the beartype repo: beartype/beartype#216

If beartype includes a hook for this use case, then it's possible that we could add in some nicer error messages here.

Until then, my usual recommendation is to arrange for a debugger to open when things crash (e.g. `pytest --pdb` if this is at test time), and then just walk the stack trace, looking at the object that was passed.
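For a non-interactive variant of the same idea, the standard-library `traceback.walk_tb` can step through the frames of a caught exception and pull out every local variable with a `.shape`. The helper `shapes_in_traceback` below is a hypothetical sketch, and assumes the arrays of interest expose a `shape` attribute, as JAX and NumPy arrays do.

```python
import traceback
from typing import List, Tuple


def shapes_in_traceback(exc: BaseException) -> List[Tuple[str, str, tuple]]:
    """Walk a caught exception's traceback, outermost frame first, and return
    (function name, variable name, shape) for every local with a `.shape`."""
    found = []
    for frame, _lineno in traceback.walk_tb(exc.__traceback__):
        for name, value in frame.f_locals.items():
            try:
                shape = tuple(value.shape)
            except Exception:
                # No usable `.shape` (or accessing it raised): skip this local.
                continue
            found.append((frame.f_code.co_name, name, shape))
    return found
```

In practice you would call it from an `except` block around the failing call; `pytest --pdb` gives you the same frames interactively.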

@dkamm
Copy link

dkamm commented Mar 10, 2023

@patrick-kidger thanks for the swift response! Crazy how that timing worked out.

Just out of curiosity, do you think patching typeguard, as torchtyping does, could work as a temporary solution? I'm not requesting that it be added here, but figured I'd ask since it looks complicated.

@patrick-kidger
Owner

patrick-kidger commented Mar 10, 2023

In principle, anything is possible with monkey patching :)

In practice that was a crazy solution that I'm not keen to repeat!

@dkamm

dkamm commented Apr 24, 2023

@patrick-kidger it looks like typeguard 4 is adding support for a typecheck fail callback (see for example https://github.com/agronholm/typeguard/blob/master/src/typeguard/_functions.py#L116-L144). Maybe jaxtyping could make use of this when it's released?
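To show why such a callback helps, here is a toy checker, not typeguard's actual API, whose failure handling is pluggable. `check_ndim`, `default_on_fail` and `shape_aware_on_fail` are hypothetical. The default mirrors the current behaviour of discarding the reason, while the custom callback puts the shape into the message.

```python
from typing import Any, Callable

# Signature of the hypothetical failure hook: (argument name, value, reason).
OnFail = Callable[[str, Any, str], None]


def default_on_fail(argname: str, value: Any, reason: str) -> None:
    """Mimics the current behaviour: the reason is thrown away."""
    raise TypeError(f'type of argument "{argname}" is invalid; got {type(value).__name__}')


def shape_aware_on_fail(argname: str, value: Any, reason: str) -> None:
    """A custom hook that surfaces the shape information."""
    raise TypeError(f'argument "{argname}": {reason}')


def check_ndim(argname: str, value: Any, ndim: int, on_fail: OnFail = default_on_fail) -> None:
    """Toy check that `value.shape` has exactly `ndim` dimensions."""
    shape = getattr(value, "shape", None)
    if shape is not None and len(shape) == ndim:
        return
    on_fail(argname, value, f"expected {ndim} dimension(s), got shape {shape}")
```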

@patrick-kidger
Owner

Nice! Beartype also has similar plans: beartype/beartype#235

I'd be happy to add support for either/both when they're added. In fact, maybe it's worth asking if they could standardise on an API.

@knyazer
Contributor

knyazer commented Oct 1, 2023

Coming back to this, it looks like it might take quite a while for beartype/typeguard to standardize their APIs and implement them, so I think it would be nice to implement this now, even if it's guarded by a global flag. I'm guessing that a better solution than printing could be to decorate functions with another decorator, which would catch exceptions related to jaxtyping and reraise them with better messages, while still preserving the original error message. Something like this:

```python
@jaxtyping.pretty_errors
@beartype.beartype
@jaxtyping.jaxtyped
def f():
    ...
```

I imagine reraising could look similar to JAX's errors, so that a "pretty" error is printed after the original trace from the typechecker. Something like this:

```
BeartypeTypeHintViolation: blah blah blah / TypeError: blah blah blah

The above exception was the direct cause of the following exception:

In the function 'f' argument 'x':
expected:      Array["N",     dtype=float]
got:           Array["1,2",   dtype=float]

argument 'y':
expected: Array["", dtype=int]
got:      Array["", dtype=float]
```

The problem is that we would have to branch on whether the error was raised by typeguard, by beartype, or by something else, transform the culprit log into a unified format, and only then pretty-print it.

When the official API is implemented, we will need pretty-printing functionality anyway, so implementing it beforehand does not look like wasted work. And even though it is an ugly (and unstable) solution, I'm guessing that most jaxtyping users would greatly appreciate having this functionality at hand. For example, in my case runtime type checking is mostly useful during prototyping/debugging, and this would save me quite a bit of time, since I would only need to take a quick look at the trace instead of inserting `jax.debug.print("{x}", x=x)` where 'f' is called.
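A rough sketch of the proposed re-raising decorator, under a few assumptions: the type checker signals failure with a `TypeError` (as in the typeguard snippet earlier in this thread), and arrays expose `shape` and `dtype`. `pretty_errors` here is a hypothetical helper, not an existing jaxtyping function. It chains the original error with `raise ... from`, so the checker's own message still appears above the pretty one. It cannot tell a failed type check apart from a `TypeError` raised inside the function body; making that distinction is the typeguard-vs-beartype branching described above.

```python
import functools
import inspect
from typing import Any, Callable, Optional


def _describe(value: Any) -> str:
    """Short description of a value: its shape and dtype if it has them."""
    shape = getattr(value, "shape", None)
    if shape is None:
        return type(value).__name__
    return f"{type(value).__name__}[shape={tuple(shape)}, dtype={getattr(value, 'dtype', '?')}]"


def _summarise(fn: Callable, args: tuple, kwargs: dict) -> Optional[str]:
    """Describe every argument of a failed call, or None if they can't be bound."""
    try:
        bound = inspect.signature(fn).bind(*args, **kwargs)
    except (TypeError, ValueError):
        return None
    lines = [f"In the function {getattr(fn, '__name__', repr(fn))!r}:"]
    for name, value in bound.arguments.items():
        lines.append(f"  argument {name!r}: got {_describe(value)}")
    return "\n".join(lines)


def pretty_errors(fn: Callable) -> Callable:
    """Hypothetical decorator: re-raise TypeErrors from `fn` with argument
    shapes/dtypes attached, keeping the original error as the cause."""

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        try:
            return fn(*args, **kwargs)
        except TypeError as err:
            message = _summarise(fn, args, kwargs)
            if message is None:
                raise  # arguments don't fit the signature: keep the original error
            raise TypeError(message) from err

    return wrapper
```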

@marksandler2

marksandler2 commented Oct 1, 2023

FWIW, we ended up implementing a small wrapper that does this for typeguard. It is literally ~30 lines of code (of which only 2 lines are typeguard-specific, 15 lines do pretty printing, and the rest is boilerplate and comments), so instead of using `@jt.jaxtyped` we just use `@util.jaxtyped`.

patrick-kidger added a commit that referenced this issue Oct 23, 2023
Phew, this ended up being a pretty complicated change!
The basic summary is that we now support the syntax
```
@jaxtyped(typechecker=typechecker)
def f(...): ...
```
and when using this, we now get pretty error messages about what went
wrong.

(
The old syntax, i.e.
```
@jaxtyped
@typechecker
def f(...): ...
```
is still supported, but doesn't give much information.
)

The internals of this do quite a lot of magic! In particular we
dynamically create quite a lot of functions and test the provided
arguments against their signatures. The overhead should still be
minimal under `jax.jit`, though.
(TODO: what's the overhead like in non-jit situations, e.g. PyTorch?
I've tried to minimise the overhead throughout just to be sure, but
perhaps PyTorch users should stick to the old syntax?)
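The commit message only summarises the mechanism, so the following is not jaxtyping's implementation. It is a much simpler sketch of why testing arguments against a function's signature helps: binding the arguments with `inspect.signature(...).bind` and checking each one separately lets the error name the offending parameter. `first_failing_argument` is hypothetical, and it assumes every annotation is a plain class usable with `isinstance`.

```python
import inspect
from typing import Any, Callable, Optional, Tuple


def first_failing_argument(
    fn: Callable, *args: Any, **kwargs: Any
) -> Optional[Tuple[str, Any, type]]:
    """Return (name, value, expected) for the first argument that is not an
    instance of its annotation, or None if every annotated argument passes."""
    signature = inspect.signature(fn)
    bound = signature.bind(*args, **kwargs)
    bound.apply_defaults()  # defaults are checked too
    for name, value in bound.arguments.items():
        param = signature.parameters[name]
        if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
            continue  # *args / **kwargs are out of scope for this sketch
        if param.annotation is inspect.Parameter.empty:
            continue  # unannotated parameter: nothing to check
        if not isinstance(value, param.annotation):
            return name, value, param.annotation
    return None
```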
patrick-kidger added a commit that referenced this issue Oct 23, 2023
patrick-kidger added a commit that referenced this issue Oct 24, 2023
patrick-kidger added a commit that referenced this issue Oct 24, 2023
patrick-kidger added a commit that referenced this issue Oct 24, 2023
patrick-kidger added a commit that referenced this issue Nov 7, 2023
patrick-kidger added a commit that referenced this issue Nov 14, 2023
patrick-kidger added a commit that referenced this issue Nov 15, 2023
patrick-kidger added a commit that referenced this issue Nov 27, 2023
patrick-kidger added a commit that referenced this issue Nov 27, 2023
patrick-kidger added a commit that referenced this issue Nov 27, 2023
ddorn added a commit to EffiSciencesResearch/ML4G that referenced this issue Mar 7, 2024
It should be used once the newest version of torchtyping works with
Jupyter, because error reporting in older versions of jaxtyping
is poor (it doesn't report the shape).

See:
- agronholm/typeguard#364
- patrick-kidger/jaxtyping#6

Former-commit-id: aab201f
6 participants