Change the type hint for nn.Module.__call__ to be friendly to overrides. #74746
Labels
- enhancement: Not as big of a feature, but technically not a bug. Should be easy to fix
- module: nn: Related to torch.nn
- module: typing: Related to mypy type annotations
- triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
🚀 The feature, motivation and pitch
Currently, `nn.Module.__call__` has its type hint declared as `__call__ : Callable[..., Any] = _call_impl`. However, this declared type makes it difficult for users to override the type hint of `nn.Module.__call__` using inferred types. For example, if you paste the example code into VSCode, you can see that the type of `z1` is `Any` (wrong) while the type of `z2` is `Tensor` (correct).

Alternatives
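The problem and the proposed alternative can be sketched without torch. This is a hypothetical reconstruction, not the issue's original snippet: `AnnotatedModule` stands in for the current `nn.Module` declaration, `UnannotatedModule` for the proposed one, `Model1`/`Model2` are made-up subclasses, `float` stands in for `Tensor`, and placing the override under `if TYPE_CHECKING:` (so that only static checkers see it) is an assumption of this sketch. The checker results in the comments are the behavior reported above for `z1` and the expected, unverified behavior under the proposal for `z2`.

```python
from typing import TYPE_CHECKING, Any, Callable


class AnnotatedModule:
    """Hypothetical stand-in for the current nn.Module declaration."""

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        # The real nn.Module also runs hooks here; this stand-in only forwards.
        return self.forward(*args, **kwargs)

    # Same shape as the declaration quoted in the issue.
    __call__: Callable[..., Any] = _call_impl


class UnannotatedModule:
    """Hypothetical stand-in for the proposed, unannotated declaration."""

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError

    def _call_impl(self, *args: Any, **kwargs: Any) -> Any:
        return self.forward(*args, **kwargs)

    # No annotation: a checker infers the type from _call_impl.
    __call__ = _call_impl


class Model1(AnnotatedModule):
    def forward(self, x: float) -> float:
        return 2 * x

    if TYPE_CHECKING:
        # Seen only by static checkers; at runtime the base __call__ is used.
        # The base *declares* __call__, so the checker keeps Callable[..., Any].
        __call__ = forward


class Model2(UnannotatedModule):
    def forward(self, x: float) -> float:
        return 2 * x

    if TYPE_CHECKING:
        # The base only *infers* __call__, so this inferred override
        # (the signature of forward) is expected to take effect.
        __call__ = forward


z1 = Model1()(3.0)  # checker (reported behavior): Any
z2 = Model2()(3.0)  # checker (expected under the proposal): float
```

At runtime both calls go through the base class's `_call_impl`, so the `TYPE_CHECKING` override changes only what the type checker sees, not which code runs.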
If we instead remove the type annotation from `__call__` in `nn.Module`, the example from the pitch would work as expected.

Additional context
No response
cc @albanD @mruberry @jbschlosser @walterddr @kshitij12345 @ezyang @malfet @rgommers @xuzhao9 @gramster