Add native support for BFloat16. #51470
Conversation
I don't think twice-precision stuff matters for either Float16 or BFloat16. Twice-precision Float16 is still basically unusable because it doesn't let you represent bigger or smaller values, and twice-precision BFloat16 is basically unusable because it only gives you 4 digits of accuracy. That said, IEEE compatibility matters roughly equally for both IMO.
It does, for ranges. Base has explicit tests for this behavior, at least for Float16, and since the pre-#37510 behavior (using UInt16) always extended to Float32 and demoted back to 16 bits after every operation, we couldn't break that, so we need the demote pass for Float16. With BFloat16, there's no such pre-existing behavior, and BFloat16s.jl doesn't test ranges. We may still want the demoting behavior if we ever want to implement ranges in the same way Base does, though (i.e. using TwicePrecision).
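For context, a minimal sketch of the kind of error-free transformation that TwicePrecision builds on (Knuth's two-sum). Its correctness proof assumes every individual operation is rounded to the working precision, which is exactly what the demote pass guarantees:

```julia
# Knuth's two-sum: recovers the exact rounding error of a + b.
# Assumes each +/- below is correctly rounded to T; if the back-end
# keeps excess precision across operations, `err` comes out wrong.
function two_sum(a::T, b::T) where {T<:AbstractFloat}
    s = a + b
    v = s - a
    err = (a - (s - v)) + (b - v)
    return s, err  # a + b == s + err, exactly
end
```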
ah. ranges. yes. 😢.
It is required, since you are declaring the results to be consistent in inference, and this would violate that assumption. OTOH, maybe we should just mark these intrinsics as not-consistent / unpredictable in inference, since that annotation isn't really valid for Float64 either anyway (NaN is not consistent).
Is the issue that LLVM will not round on every operation but we will?
The issue is that it is unknown and unpredictable what LLVM will do and what result you will get each time.
So one of the nice things about us doing the demote pass is that we had equivalent behavior between the software and hardware implementations. GCC added this much later under
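To make that equivalence concrete, a minimal sketch (hypothetical `sw_add`/`sw_mul` names) of the per-operation widen-and-demote scheme that both the software fallback and the demote pass implement for Float16:

```julia
# Every 16-bit op is performed in Float32 and immediately rounded back;
# a hardware implementation must produce bit-identical results to this.
sw_add(x::Float16, y::Float16) = Float16(Float32(x) + Float32(y))
sw_mul(x::Float16, y::Float16) = Float16(Float32(x) * Float32(y))
```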
Do we need to implement conversion functions?
Yes, as LLVM can emit calls to them. They are part of this PR.
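For reference, a hedged sketch (names illustrative) of the round-to-nearest-even truncation that such a runtime conversion routine (e.g. compiler-rt's `__truncsfbf2`) has to perform:

```julia
# Round-to-nearest-even Float32 -> bfloat16 bits (illustrative helper).
function trunc_to_bf16(x::Float32)
    u = reinterpret(UInt32, x)
    isnan(x) && return UInt16(0x7FC0)  # canonicalize to a quiet NaN
    # Add a bias so the truncating shift rounds to nearest, ties to even:
    bias = 0x00007fff + ((u >> 16) & 0x00000001)
    return UInt16((u + bias) >> 16)
end
```

For example, `trunc_to_bf16(1.0f0) == 0x3f80`, the bfloat16 bit pattern of 1.0.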
Added support for BFloat16 to the Float16 demote pass.
Maybe add a test to llvmpasses? I imagine just copying the float16 one but using bfloat.
The Base part is done here. I'll work on BFloat16s.jl now, so maybe we shouldn't merge this until I've validated this functionality there. While updating the demote pass, I noticed that on X86 we don't demote Float16 when we have avx512fp16, while on ARM we require fp16fml. The latter defines scalar operations on Float16, while AVX512 only defines vector instructions. If we only have vector instructions, shouldn't we still demote? The same applies for BFloat16/avx512bf16, which doesn't even have scalar support on ARM.
What does GCC 13 do with
AVX512fp16 supports the full set of floating-point instructions, so it just uses the vectorized ones on a single value. For bfloat, though, the only things available are convert and dot product, so I think it shouldn't change outside of that. And from https://godbolt.org/z/Kc36svedM it just does the operations
Ah, interesting. Doesn't seem to be the case for AVX512bf16, so I'll have to remove that part.
Also, it looks like LLVM now implements excess precision (https://reviews.llvm.org/D136176 is probably related), so I don't think we need the demote pass for Float16 anymore. But that can happen in a follow-up PR.
Force-pushed from 2827502 to b01110b.
b01110b has me a bit concerned. We need to pass zext/sext for platform ABI reasons. As an example, on PPC you can't pass a 16-bit type in registers, so you need to extend it. IIUC the change above is breaking for custom primitive types of size 16 (or other sizes on other platforms). So we should set
After switching to LLVM for BFloat16 in #51470 (i.e., relying on `Intrinsics.sub_float` etc. instead of hand-rolling bit-twiddling implementations), we also need to provide fallback runtime implementations for these intrinsics. This is too bad; I had hoped to put as much BFloat16-related functionality as possible in BFloat16s.jl. This required modifying the unary operator preprocessor macros in order to differentiate between Float16 and BFloat16; I didn't generalize that to all intrinsics as the code is hairy enough already (and it's currently only useful for fptrunc/fpext).
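As a hedged sketch of what "relying on `Intrinsics.sub_float` etc." means on the package side (the method definitions are illustrative, not BFloat16s.jl's actual code):

```julia
# On Julia 1.11+, Core.BFloat16 is the primitive type added by #51470;
# arithmetic can be forwarded to the intrinsics instead of bit-twiddling.
const BFloat16 = Core.BFloat16

Base.:(+)(x::BFloat16, y::BFloat16) = Core.Intrinsics.add_float(x, y)
Base.:(-)(x::BFloat16, y::BFloat16) = Core.Intrinsics.sub_float(x, y)
Base.Float32(x::BFloat16) = Core.Intrinsics.fpext(Float32, x)
BFloat16(x::Float32) = Core.Intrinsics.fptrunc(BFloat16, x)
```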
`numsToZero` relies on being able to sample arbitrary `AbstractFloat` with `rand`, which seemingly isn't possible with the new `Core.BFloat16` introduced in JuliaLang/julia#51470. See JuliaLang/julia#53651 for the upstream issue tracking this. 1.11 also introduces `AnnotatedString`, which interacts badly with the local scope of `@testset` and trying to lazily `join` things that may degenerate in inference to `AbstractString`. The type assertion is a quick "fix", since other than moving `irb` outside of that scope, inference will continue to mess with the test, even though no `AnnotatedString` could ever actually be produced.
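Concretely, the breaking pattern is sampling through the abstract type (a minimal reproduction sketch; `randfloat` is a hypothetical stand-in for what `numsToZero` does):

```julia
# Generic sampling through AbstractFloat, as `numsToZero` relies on:
randfloat(::Type{T}) where {T<:AbstractFloat} = rand(T)

randfloat(Float16)        # works: Base provides rand for Float16/32/64
randfloat(Core.BFloat16)  # fails on 1.11; see JuliaLang/julia#53651
```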
This PR adds native support for the LLVM `bfloat` type, through a new `BFloat16` type. It doesn't, however, add any language-level functionality, only the bare minimum (e.g. runtime conversion routines), and it will thus still be required to use the BFloat16s.jl package.

One element that needs to be discussed is that I didn't add a BFloat16 demote pass. This means that the back-end will be able to perform multiple operations in extended precision before demoting back to 16 bits at the end, resulting in different results than if you were to perform the operations separately. That wasn't acceptable for Float16, because of TwicePrecision-like hacks and IEEE compatibility, but hopefully we don't require this for BFloat16.
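A hedged numerical sketch of the difference (assuming BFloat16s.jl's `BFloat16`, which has 8 significand bits):

```julia
using BFloat16s  # assumed; provides the BFloat16 type

x = BFloat16(1 + 1/128)  # exactly 1 + 2^-7 (the ulp at 1.0 is 2^-7)
c = BFloat16(1 + 1/64)   # exactly 1 + 2^-6

# Exactly, x*x == 1 + 2^-6 + 2^-14. With per-operation rounding (as a
# demote pass would enforce), x*x rounds to 1 + 2^-6 and y == 0. If the
# back-end instead keeps the product in extended precision across both
# operations, y comes out as 2^-14 after the final demotion.
y = x*x - c
```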
Draft, as I still need to update BFloat16s.jl, and test on other platforms.
Alternative to #50607. LLVM only supports a limited number of types, so it currently doesn't seem worth the complexity to be able to dynamically register new types with codegen.
Fixes #41075, but we'll need LLVM 17 before we can emit AVX512BF16 instructions.
cc @chriselrod @vchuravy