-
-
Notifications
You must be signed in to change notification settings - Fork 4.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add types to vision_transformers.py #2036
Conversation
|
||
def forward(self, x): | ||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
B, N, C = x.shape | ||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Using einops
here could improve readability, though transformers
chose not to, as it may interfere with torch.compiler or onnx.export (see huggingface/transformers#25110, arogozhnikov/einops#250 and arogozhnikov/einops#274)
The documentation is not available anymore as the PR was closed or merged. |
@Laurent2916 I do like the typing changes, but I do not want the rest (formatting, pos -> kwarg for obvious ones, inplace changes x +=, etc). Few comments on formatting:
|
Alright I reverted most of the formatting that I applied. |
@Laurent2916 awesome, thx, should be able to merge this later today |
Hi !
In continuation of #1989, here's a PR that adds the missing types to
models/vision_transformers.py
, plus a few other things.I haven't added any docstrings, as
VisionTransformer
already has a docstring.