First of all, thanks for the awesome library. This library has made my code much more understandable, and the runtime type-checking with beartype has been immensely useful.
I'm currently working on an existing (PyTorch) codebase that did not previously use jaxtyping type hints, and I'm gradually adding type hints to the areas that I work on. As a result, I have a handful of cases where I'm not using function argument/return type hints, but am instead using `isinstance` checks, e.g.:

    import torch
    from jaxtyping import Float

    def my_func(x):
        x = ...  # some operation I'm not touching

        # my code
        assert isinstance(x, Float[torch.Tensor, "bsz channels"])
        x = do_some_other_stuff(x)
        # ... and then the rest of the code for `my_func`

In these cases, due to the way the import hooking is currently set up, I'm running into some very strange and unexpected behavior. Specifically, it seems like axis bindings in these kinds of functions just get ignored and do not get registered in the `memo_stack`.
This seems to be because, in the case above, `my_func` does not have jaxtyping type hints in its args/return types and thus will not be registered using the import hook.
For now, I've patched jaxtyping's import hook code (`_import_hook.py`) to also register all functions with `isinstance` expressions:
And then checking for this in `JaxtypingTransformer`, changing the following lines to:

This is a hacky fix but works in my case. I would love to hear your thoughts on how this should be fixed (and whether a similar fix is warranted for now).
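As a rough illustration of the detection idea described above, here is a standalone sketch. It is not the actual patch and not jaxtyping's code. It uses only the standard-library `ast` module to find functions whose bodies contain an `isinstance(...)` call, which is the kind of check an import hook could use to decide which functions to register. The helper name `functions_with_isinstance` and the `example` source string are hypothetical.

```python
import ast


def functions_with_isinstance(source: str) -> list[str]:
    """Return names of functions whose bodies contain an `isinstance(...)` call.

    Hypothetical helper for illustration only; jaxtyping's real import hook
    may detect and register functions differently.
    """
    tree = ast.parse(source)
    found = []
    for node in ast.walk(tree):
        if not isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            continue
        # Walk the whole function body, including nested functions, so an
        # outer function is also flagged if an inner one calls isinstance.
        for child in ast.walk(node):
            if (
                isinstance(child, ast.Call)
                and isinstance(child.func, ast.Name)
                and child.func.id == "isinstance"
            ):
                found.append(node.name)
                break
    return found


# The source is only parsed, never executed, so Float/torch need not exist.
example = '''
def my_func(x):
    assert isinstance(x, Float[torch.Tensor, "bsz channels"])
    return x

def untouched(y):
    return y + 1
'''

flagged = functions_with_isinstance(example)
```

A purely syntactic check like this will also flag functions whose `isinstance` calls have nothing to do with jaxtyping, so it trades precision for simplicity.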
Thanks!