🚀 The feature, motivation and pitch
The kernels currently appear to support only the FP32 data type. It would be immensely helpful if the implementations also supported mixed-precision computation (FP16 and BF16), which would broaden their applicability to NLP workloads, not just graph neural networks.
How involved would enabling mixed-precision computation be? Any pointers for getting started on a PR?
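As a rough starting point for scoping the work, here is a minimal pure-Python sketch of a BF16 reference for segment_matmul. It assumes segment_matmul's semantics are: inputs is an N x K matrix whose rows are split into B segments by the B + 1 offsets in ptr, and segment i is multiplied by its own K x M matrix other[i]. BF16 is emulated by rounding each operand to the nearest bfloat16 value (FP32's 8 exponent bits, but only 7 explicit mantissa bits), while products are accumulated in higher precision, the usual mixed-precision pattern. to_bfloat16 and segment_matmul_reference are hypothetical helpers, not pyg-lib APIs; something like this could serve as a slow correctness oracle when testing a real FP16/BF16 kernel.

```python
import math
import struct


def to_bfloat16(x: float) -> float:
    """Round x to the nearest bfloat16 value (round half to even).

    bfloat16 keeps the top 16 bits of an IEEE-754 float32: the same 8-bit
    exponent, but only 7 explicit mantissa bits. x is first packed as float32;
    values outside the float32 range make struct raise OverflowError.
    """
    if math.isnan(x):
        return x
    # Reinterpret the float32 encoding of x as a 32-bit unsigned integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Round to nearest, ties to even, on the 16 low bits that get dropped.
    lsb = (bits >> 16) & 1
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def segment_matmul_reference(inputs, ptr, other, round_inputs=to_bfloat16):
    """Hypothetical slow reference for segment_matmul-style semantics.

    inputs: N rows, each a list of K floats.
    ptr:    B + 1 offsets; rows ptr[i]:ptr[i + 1] form segment i.
    other:  B matrices, each K rows of M floats.
    Returns one output row of M floats per input row covered by ptr.
    """
    out = []
    for i, weight in enumerate(other):
        # Round operands to the reduced precision before multiplying.
        w = [[round_inputs(v) for v in row] for row in weight]
        for row in inputs[ptr[i]:ptr[i + 1]]:
            x = [round_inputs(v) for v in row]
            # Accumulate in Python float (double) precision; a real mixed-
            # precision kernel would typically accumulate in FP32.
            out.append([sum(x[k] * w[k][m] for k in range(len(x)))
                        for m in range(len(w[0]))])
    return out


# Example: two segments (rows 0-1 and row 2), K = M = 2.
inputs = [[1.0, 2.0], [3.0, 4.0], [1.0 + 2**-8, 0.0]]
ptr = [0, 2, 3]
other = [[[1.0, 0.0], [0.0, 1.0]], [[2.0, 0.0], [0.0, 2.0]]]
print(segment_matmul_reference(inputs, ptr, other))
# → [[1.0, 2.0], [3.0, 4.0], [2.0, 0.0]]  (1 + 2**-8 rounds to 1.0 in bfloat16)
```

Because Python floats are doubles, accumulation here is more precise than the FP32 accumulation a real kernel would typically use, so a kernel's output should be compared against a reference like this with a tolerance rather than exact equality.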
Alternatives
No response
Additional context
No response
Adding on to @sidnb13's comment: it looks like segment_matmul just takes Tensor arguments, which are simply torch.Tensors, right? And torch.Tensor natively supports torch.bfloat16 and torch.float16 (also available as torch.half), right? The weird thing is that when you run segment_matmul on tensors cast to bfloat16, you get this error:
segment_matmul and grouped_matmul had an incomplete set of dispatch types. I've fixed the CPU implementation in pyg-lib #272 (@puririshi98, could you please take a look at the CUDA implementation?). If you find any other custom operation that lacks bf16 support, you can take a look at @yanbing-j's PRs, e.g. pytorch_scatter #316 and pytorch_scatter #375.
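To illustrate what an "incomplete set of dispatch types" means, here is a rough pure-Python analogy. In a C++/CUDA extension, a kernel is usually compiled only for the dtypes named in its ATen dispatch macro: AT_DISPATCH_FLOATING_TYPES covers float and double, while AT_DISPATCH_FLOATING_TYPES_AND2(at::kHalf, at::kBFloat16, ...) adds FP16 and BF16. Calling the op on a dtype outside that set fails at runtime. make_dispatcher, dot_kernel and the error message below are hypothetical; they do not reproduce pyg-lib's code or its exact error text.

```python
from typing import Callable, Iterable, Sequence


def dot_kernel(a: Sequence[float], b: Sequence[float]) -> float:
    """Generic kernel body, written once and reused for every listed dtype."""
    return sum(x * y for x, y in zip(a, b))


def make_dispatcher(kernel: Callable, dtypes: Iterable[str]) -> Callable:
    """Hypothetical analogue of an ATen AT_DISPATCH_* macro.

    The kernel is reachable only for the dtypes listed here; any other dtype
    raises at call time, like a custom op invoked on an unsupported dtype.
    """
    supported = frozenset(dtypes)

    def dispatch(dtype: str, a: Sequence[float], b: Sequence[float]) -> float:
        if dtype not in supported:
            raise RuntimeError(f"kernel not implemented for '{dtype}'")
        return kernel(a, b)

    return dispatch


# Incomplete set: roughly what AT_DISPATCH_FLOATING_TYPES covers.
dot_fp32_only = make_dispatcher(dot_kernel, ["float32", "float64"])
# The fix: extend the set with the reduced-precision types.
dot_all = make_dispatcher(dot_kernel, ["float32", "float64", "float16", "bfloat16"])

print(dot_all("bfloat16", [1.0, 2.0], [3.0, 4.0]))  # → 11.0
try:
    dot_fp32_only("bfloat16", [1.0, 2.0], [3.0, 4.0])
except RuntimeError as err:
    print(err)  # → kernel not implemented for 'bfloat16'
```

Widening the set only makes the dtype reachable; the kernel body still has to behave sensibly in reduced precision, and the CUDA path needs the same treatment as the CPU one.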