
Change the type hint for nn.Module.__call__ to be friendly to overrides. #74746

Open
ppeetteerrs opened this issue Mar 25, 2022 · 6 comments
Labels
enhancement Not as big of a feature, but technically not a bug. Should be easy to fix module: nn Related to torch.nn module: typing Related to mypy type annotations triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@ppeetteerrs

ppeetteerrs commented Mar 25, 2022

🚀 The feature, motivation and pitch

Currently, nn.Module.__call__ is declared as __call__ : Callable[..., Any] = _call_impl. Because this type is declared explicitly rather than inferred, it is difficult for users to override the type hint of nn.Module.__call__ in a subclass with an inferred type. For example:

from typing import Any, Callable, TypeVar, cast

from torch import Tensor, nn

C = TypeVar("C", bound=Callable[..., Any])


def proxy(f: C) -> C:
    # f is used only for its static type: the returned callable claims f's
    # signature but delegates to nn.Module.__call__ (hooks, then forward).
    # Calling nn.Module.__call__ directly avoids super(self.__class__, self),
    # which recurses forever once MyModule is subclassed again.
    return cast(C, lambda self, *x, **y: nn.Module.__call__(self, *x, **y))


class MyModule(nn.Module):
    def forward(self, x: Tensor, y: Tensor) -> Tensor:
        return x + y

    __call__ = proxy(forward)  # __call__ exists in nn.Module

    another_call = proxy(forward)  # another_call does not exist in nn.Module


x, y = Tensor(), Tensor()

model = MyModule()

z1 = model(x, y)  # Type hint of z1 is Any, same for model.__call__(x, y)

z2 = model.another_call(x, y)  # Type hint of z2 is Tensor

If you paste the above code into VSCode, the inferred type of z1 is Any (wrong), while that of z2 is Tensor (correct).

Alternatives

If we instead remove the type annotation in nn.Module and define __call__ as follows, the example above works as expected:

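from typing import Any


class Module:
    def _call_impl(self, *input: Any, **kwargs: Any) -> Any:
        ...  # actual implementation

    # No explicit annotation, so the type of __call__ is inferred from _call_impl
    __call__ = _call_impl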
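A minimal, torch-free sketch of the alternative proposed above. Everything here is hypothetical: `Module` is a stand-in for nn.Module (the real `_call_impl` also runs hooks) and `AddModule` is an illustrative subclass. The base class assigns `__call__ = _call_impl` without an annotation, and the subclass replaces it with `proxy(forward)`; per the issue, a checker such as pyright should then report `forward`'s signature for calls on the subclass instead of `Any`. At runtime the call still goes through the base `_call_impl`.

```python
from typing import Any, Callable, TypeVar, cast

C = TypeVar("C", bound=Callable[..., Any])


class Module:
    """Hypothetical stand-in for nn.Module (not PyTorch's real class)."""

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        # The real nn.Module also runs hooks here; this sketch only dispatches.
        return self.forward(*args, **kwargs)

    # No explicit annotation: the type of __call__ is inferred from _call_impl.
    __call__ = _call_impl


def proxy(f: C) -> C:
    # f is used only for its static type; the call is delegated to Module.__call__.
    return cast(C, lambda self, *args, **kwargs: Module.__call__(self, *args, **kwargs))


class AddModule(Module):
    def forward(self, x: float, y: float) -> float:
        return x + y

    # Override the inferred __call__ so checkers see forward's signature.
    __call__ = proxy(forward)


model = AddModule()
result = model(1.5, 2.5)  # dispatches through Module._call_impl to forward
```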

Additional context

No response

cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @ezyang @malfet @rgommers @xuzhao9 @gramster

@VitalyFedyunin added the module: nn, module: typing, triaged and enhancement labels on Mar 25, 2022
@jonasrauber

Actually, mypy handles your example fine: reveal_type(z1) and reveal_type(z2) both show Revealed type is "torch._tensor.Tensor". This might just be a limitation of VSCode's type checking.

@ppeetteerrs
Author

Perhaps the two checkers differ in how they infer types? For reference, I also discussed this issue in the pyright repo: microsoft/pyright#3249 (comment)

@mert-kurttutan

mert-kurttutan commented Jan 20, 2023

For nn.Linear, reveal_type reports Any instead of torch.Tensor. To reproduce this, run the code below with these versions:

pytorch=1.13
mypy=0.991
# main.py
import torch
from torch import nn


class LinWrapper(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.lin = nn.Linear(12,2)

    def forward(
        self, x: torch.Tensor
    ) -> torch.Tensor:
        out = self.lin(x)
        reveal_type(out)
        return out

Bash command to run

mypy main.py

Output:

main.py:15: note: Revealed type is "Any"

Note that when tested with a custom module (with type hints), the revealed type is torch.Tensor, just as @jonasrauber found.
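For library modules such as nn.Linear, whose `__call__` cannot be overridden, a caller-side workaround is to narrow the result explicitly. The sketch below is torch-free and relies on hypothetical stand-ins: `Linear` mimics the `__call__: Callable[..., Any]` declaration discussed in this issue, and `float` plays the role of `torch.Tensor`. `typing.cast` changes only the static type and does nothing at runtime.

```python
from typing import Any, Callable, cast


class Linear:
    """Hypothetical stand-in for nn.Linear; float plays the role of torch.Tensor."""

    def __init__(self, scale: float) -> None:
        self.scale = scale

    def forward(self, x: float) -> float:
        return self.scale * x

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    # Declared the same way as in the issue, so self.lin(x) is typed as Any.
    __call__: Callable[..., Any] = _call_impl


class LinWrapper:
    def __init__(self) -> None:
        self.lin = Linear(2.0)

    def forward(self, x: float) -> float:
        # cast only informs the type checker; it performs no runtime check.
        out = cast(float, self.lin(x))
        return out
```

Annotating the variable instead (`out: float = self.lin(x)`) should work equally well, since an `Any` value is assignable to any declared type; neither form is checked at runtime.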

@ezyang
Contributor

ezyang commented Jan 21, 2023

Can someone open a PR with the proposed change so we can see how it does in CI?

@mert-kurttutan

mert-kurttutan commented Jan 21, 2023

One possible approach is to copy the type signature of the forward method when forward has type annotations, and otherwise keep using the current default, Callable[..., Any], as the type hint.
For instance,

def get_type_signature(func):
    # if func has type annotation
    # return the annotation
    # otherwise, return Callable[..., Any]
    pass

forward_type_sign = get_type_signature(self.forward)
__call__ : forward_type_sign = _call_impl

But I am not sure how to do this reliably.
One possible solution is discussed in typing/issues/270;
it does not do exactly what I described, but it annotates a function with the type of another function.

@hchau630

Any updates on this? If not, is there a simple workaround?
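One simple workaround, offered here as an assumption rather than an official recommendation, is to declare a typed `__call__` only for type checkers inside an `if TYPE_CHECKING:` block. At runtime the block is skipped, so the base class's `__call__` (and, for the real nn.Module, its hooks) is still used. `Module` is a hypothetical torch-free stand-in mimicking nn.Module's declaration; the cost of this approach is that `forward`'s signature must be repeated by hand.

```python
from typing import TYPE_CHECKING, Any, Callable


class Module:
    """Hypothetical torch-free stand-in for nn.Module."""

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    __call__: Callable[..., Any] = _call_impl


class MyModule(Module):
    def forward(self, x: int, y: int) -> int:
        return x + y

    if TYPE_CHECKING:
        # Only type checkers see this declaration; it must mirror forward by hand.
        def __call__(self, x: int, y: int) -> int: ...
```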
