Functions without type hints and import hook #197

Closed
nimashoghi opened this issue Apr 2, 2024 · 2 comments

Comments

@nimashoghi
Contributor

Hi!

First of all, thanks for the awesome library. This library has made my code much more understandable, and the runtime type-checking with beartype has been immensely useful.

I'm currently working on an existing (PyTorch) codebase that did not previously use jaxtyping type hints, and I'm gradually adding type hints to the areas I work on. As a result, I have a handful of cases where I'm not using function argument/return type hints and am instead using isinstance checks, e.g.:

import torch
from jaxtyping import Float

def my_func(x):
    x = ...  # some existing operation I'm not touching
    # my code
    assert isinstance(x, Float[torch.Tensor, "bsz channels"])
    x = do_some_other_stuff(x)
    # ... and then the rest of the code for `my_func`

In these cases, due to the way the import hook is currently set up, I'm running into some very strange and unexpected behavior. Specifically, the axis bindings made by the isinstance checks in these functions seem to be ignored: they never get registered in the memo_stack, so sizes such as bsz and channels are not checked for consistency across the different checks within the function.

This seems to be because, in the case above, my_func has no jaxtyping type hints on its arguments or return type, and so the import hook never picks it up.

For now, I've patched jaxtyping's import hook code (_import_hook.py) to also register every function that contains an isinstance call:

import ast

def _has_isinstance(func_def):
    # True if the function's AST contains a call to the builtin `isinstance`.
    for node in ast.walk(func_def):
        if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == "isinstance":
            return True
    return False

I then check for this in JaxtypingTransformer, changing the following lines of visit_FunctionDef to:

    def visit_FunctionDef(self, node: ast.FunctionDef):
        has_annotated_args = any(arg for arg in node.args.args if arg.annotation)
        has_annotated_return = bool(node.returns)
        has_isinstance = _has_isinstance(node)  # new: also catch isinstance-only functions
        if has_annotated_args or has_annotated_return or has_isinstance:
            ...  # rest of the method unchanged

This is a hacky fix, but it works in my case. I'd love to hear your thoughts on how this should be fixed (and whether a similar fix is warranted for now).
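A self-contained sketch of the patched selection rule, run on a small example module. DecoratingTransformer, decorated_functions and the decorator name instrument are hypothetical stand-ins (the real JaxtypingTransformer differs in its details), and isinstance(x, list) stands in for a jaxtyping check. With the annotation-only rule my_func is skipped; with the added isinstance rule it is picked up.

```python
import ast

# Hypothetical name of the decorator an import hook would insert.
DECORATOR_NAME = "instrument"

# Example module; isinstance(x, list) stands in for a jaxtyping isinstance check.
SOURCE = """
def annotated(x: int) -> int:
    return x

def my_func(x):
    assert isinstance(x, list)
    return x

def plain(x):
    return x
"""


def _has_annotations(node: ast.FunctionDef) -> bool:
    return any(arg.annotation for arg in node.args.args) or node.returns is not None


def _has_isinstance(node: ast.FunctionDef) -> bool:
    return any(
        isinstance(n, ast.Call) and isinstance(n.func, ast.Name) and n.func.id == "isinstance"
        for n in ast.walk(node)
    )


class DecoratingTransformer(ast.NodeTransformer):
    def __init__(self, include_isinstance: bool):
        self.include_isinstance = include_isinstance

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # visit nested functions too
        if _has_annotations(node) or (self.include_isinstance and _has_isinstance(node)):
            node.decorator_list.insert(0, ast.Name(id=DECORATOR_NAME, ctx=ast.Load()))
        return node


def decorated_functions(source: str, include_isinstance: bool) -> list[str]:
    """Names of the functions the transformer decorates."""
    tree = DecoratingTransformer(include_isinstance).visit(ast.parse(source))
    return [
        n.name
        for n in ast.walk(tree)
        if isinstance(n, ast.FunctionDef)
        and any(isinstance(d, ast.Name) and d.id == DECORATOR_NAME for d in n.decorator_list)
    ]


print(decorated_functions(SOURCE, include_isinstance=False))  # ['annotated']
print(decorated_functions(SOURCE, include_isinstance=True))   # ['annotated', 'my_func']
```

Because ast.walk visits the whole subtree, an outer function whose nested helper calls isinstance is selected as well under this rule.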

Thanks!

@patrick-kidger
Owner

Ah, good catch!

Maybe we should just remove the has_annotated_args or has_annotated_return check entirely? I don't think it's important; it was just a minor efficiency optimization.

I'd be happy to take a PR on this.
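Dropping the annotation check and decorating every function could look roughly like the following sketch. DecorateEverything, run_instrumented and instrument are hypothetical names, and the stand-in decorator only records which functions it wraps. Because decorators run at definition time, the nested helper is only wrapped once my_func is actually called.

```python
import ast

# Example module: an unannotated function with a nested helper.
SOURCE = """
def my_func(x):
    def helper(y):
        return y
    assert isinstance(x, list)
    return helper(x)
"""


class DecorateEverything(ast.NodeTransformer):
    """Hypothetical transformer: decorate every function, annotated or not."""

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        self.generic_visit(node)  # decorate nested functions as well
        node.decorator_list.insert(0, ast.Name(id="instrument", ctx=ast.Load()))
        return node


def run_instrumented(source: str) -> tuple[dict, list[str]]:
    """Execute `source` with every function decorated; return its namespace and a log."""
    seen: list[str] = []

    def instrument(fn):
        # Hypothetical stand-in for the real decorator: just record the name.
        seen.append(fn.__name__)
        return fn

    # The inserted Name nodes have no line numbers; fill them in before compiling.
    tree = ast.fix_missing_locations(DecorateEverything().visit(ast.parse(source)))
    namespace = {"instrument": instrument}
    exec(compile(tree, "<instrumented>", "exec"), namespace)
    return namespace, seen


ns, seen = run_instrumented(SOURCE)
print(seen)                # ['my_func']: decorated when the module runs
print(ns["my_func"]([1]))  # [1]
print(seen)                # ['my_func', 'helper']: helper is decorated on each call
```

The trade-off is the efficiency point mentioned above: every function now goes through a wrapper, in exchange for never silently skipping a function that only uses isinstance checks.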

@patrick-kidger
Owner

Closing as accomplished in #205, which corresponds to jaxtyping version 0.2.29.
