From 730511bdd9d39a0c06d271ff9f051b3b25d9ee93 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Thu, 2 Nov 2023 09:39:20 +0100 Subject: [PATCH] Adapt to upstream changes wrt. native support for BFloat16 (#51) --- src/bfloat16.jl | 245 ++++++++++++++++++++++++++++++++++++----------- test/runtests.jl | 14 ++- 2 files changed, 202 insertions(+), 57 deletions(-) diff --git a/src/bfloat16.jl b/src/bfloat16.jl index 218d48a..ebe02d6 100644 --- a/src/bfloat16.jl +++ b/src/bfloat16.jl @@ -2,7 +2,7 @@ import Base: isfinite, isnan, precision, iszero, eps, typemin, typemax, floatmin, floatmax, sign_mask, exponent_mask, significand_mask, exponent_bits, significand_bits, exponent_bias, - exponent_one, exponent_half, + exponent_one, exponent_half, leading_zeros, signbit, exponent, significand, frexp, ldexp, round, Int16, Int32, Int64, +, -, *, /, ^, ==, <, <=, >=, >, !=, inv, @@ -13,33 +13,66 @@ import Base: isfinite, isnan, precision, iszero, eps, asin, acos, atan, acsc, asec, acot, sinh, cosh, tanh, csch, sech, coth, asinh, acosh, atanh, acsch, asech, acoth, - bitstring - -primitive type BFloat16 <: AbstractFloat 16 end + bitstring, isinteger + +import Printf + +# LLVM 11 added support for BFloat16 in the IR; Julia 1.11 added support for generating +# code that uses the `bfloat` IR type, together with the necessary runtime functions. +# However, not all LLVM targets support `bfloat`. If the target can store/load BFloat16s +# (and supports synthesizing constants) we can use the `bfloat` IR type, otherwise we fall +# back to defining a primitive type that will be represented as an `i16`. If, in addition, +# the target supports BFloat16 arithmetic, we can use LLVM intrinsics. +# - x86: storage and arithmetic support in LLVM 15 +# - aarch64: storage support in LLVM 17 +const llvm_storage = if isdefined(Core, :BFloat16) + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + elseif Sys.ARCH == :aarch64 && Base.libllvm_version >= v"17" + true + else + false + end +else + false +end +const llvm_arithmetic = if llvm_storage + using Core: BFloat16 + if Sys.ARCH in [:x86_64, :i686] && Base.libllvm_version >= v"15" + true + else + false + end +else + primitive type BFloat16 <: AbstractFloat 16 end + false +end Base.reinterpret(::Type{Unsigned}, x::BFloat16) = reinterpret(UInt16, x) Base.reinterpret(::Type{Signed}, x::BFloat16) = reinterpret(Int16, x) # Floating point property queries for f in (:sign_mask, :exponent_mask, :exponent_one, - :exponent_half, :significand_mask) + :exponent_half, :significand_mask) @eval $(f)(::Type{BFloat16}) = UInt16($(f)(Float32) >> 16) end - Base.exponent_bias(::Type{BFloat16}) = 127 Base.exponent_bits(::Type{BFloat16}) = 8 Base.significand_bits(::Type{BFloat16}) = 7 Base.signbit(x::BFloat16) = (reinterpret(Unsigned, x) & 0x8000) !== 0x0000 function Base.significand(x::BFloat16) - result = abs_significand(x) - ifelse(signbit(x), -result, result) -end - -@inline function abs_significand(x::BFloat16) - usig = Base.significand_mask(BFloat16) & reinterpret(Unsigned, x) - isig = Int16(usig) - 1 + isig / BFloat16(2)^7 + xu = reinterpret(Unsigned, x) + xs = xu & ~sign_mask(BFloat16) + xs >= exponent_mask(BFloat16) && return x # NaN or Inf + if xs <= (~exponent_mask(BFloat16) & ~sign_mask(BFloat16)) # x is subnormal + xs == 0 && return x # +-0 + m = unsigned(leading_zeros(xs) - exponent_bits(BFloat16)) + xs <<= m + xu = xs | (xu & sign_mask(BFloat16)) + end + xu = (xu & ~exponent_mask(BFloat16)) | exponent_one(BFloat16) + return reinterpret(BFloat16, xu) end Base.exponent(x::BFloat16) = @@ -65,16 +98,24 @@ isnan(x::BFloat16) = (reinterpret(Unsigned,x) & ~sign_mask(BFloat16)) > exponent precision(::Type{BFloat16}) = 8 eps(::Type{BFloat16}) = Base.bitcast(BFloat16, 0x3c00) -round(x::BFloat16, r::RoundingMode{:Up}) = BFloat16(ceil(Float32(x))) -round(x::BFloat16, r::RoundingMode{:Down}) = BFloat16(floor(Float32(x))) -round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x))) +## Rounding ## +if llvm_arithmetic + round(x::BFloat16, ::RoundingMode{:ToZero}) = Base.trunc_llvm(x) + round(x::BFloat16, ::RoundingMode{:Down}) = Base.floor_llvm(x) + round(x::BFloat16, ::RoundingMode{:Up}) = Base.ceil_llvm(x) + round(x::BFloat16, ::RoundingMode{:Nearest}) = Base.rint_llvm(x) +else + round(x::BFloat16, r::RoundingMode{:ToZero}) = BFloat16(trunc(Float32(x))) + round(x::BFloat16, r::RoundingMode{:Down}) = BFloat16(floor(Float32(x))) + round(x::BFloat16, r::RoundingMode{:Up}) = BFloat16(ceil(Float32(x))) + round(x::BFloat16, r::RoundingMode{:Nearest}) = BFloat16(round(Float32(x))) +end +# round(::Type{Signed}, x::BFloat16, r::RoundingMode) = round(Int, x, r) +# round(::Type{Unsigned}, x::BFloat16, r::RoundingMode) = round(UInt, x, r) +# round(::Type{Integer}, x::BFloat16, r::RoundingMode) = round(Int, x, r) Base.trunc(bf::BFloat16) = signbit(bf) ? ceil(bf) : floor(bf) -Int64(x::BFloat16) = Int64(Float32(x)) -Int32(x::BFloat16) = Int32(Float32(x)) -Int16(x::BFloat16) = Int16(Float32(x)) - ## floating point traits ## """ InfB16 @@ -100,56 +141,88 @@ Base.trunc(::Type{BFloat16}, x::Float32) = reinterpret(BFloat16, (reinterpret(UInt32, x) >> 16) % UInt16 ) -# Conversion from Float32 -function BFloat16(x::Float32) - isnan(x) && return NaNB16 - # Round to nearest even (matches TensorFlow and our convention for - # rounding to lower precision floating point types). - h = reinterpret(UInt32, x) - h += 0x7fff + ((h >> 16) & 1) - return reinterpret(BFloat16, (h >> 16) % UInt16) -end +if llvm_arithmetic + BFloat16(x::Float32) = Base.fptrunc(BFloat16, x) + BFloat16(x::Float64) = Base.fptrunc(BFloat16, x) + + # XXX: can LLVM do this natively? + BFloat16(x::Float16) = BFloat16(Float32(x)) +else + # Conversion from Float32 + function BFloat16(x::Float32) + isnan(x) && return NaNB16 + # Round to nearest even (matches TensorFlow and our convention for + # rounding to lower precision floating point types). + h = reinterpret(UInt32, x) + h += 0x7fff + ((h >> 16) & 1) + return reinterpret(BFloat16, (h >> 16) % UInt16) + end -# Conversion from Float64 -function BFloat16(x::Float64) - BFloat16(Float32(x)) -end + # Conversion from Float64 + function BFloat16(x::Float64) + BFloat16(Float32(x)) + end -# Conversion from Float16 -function BFloat16(x::Float16) - BFloat16(Float32(x)) + # Conversion from Float16 + function BFloat16(x::Float16) + BFloat16(Float32(x)) + end end # Conversion from Integer -function BFloat16(x::Integer) - convert(BFloat16, convert(Float32, x)) +if llvm_arithmetic + for st in (Int8, Int16, Int32, Int64) + @eval begin + BFloat16(x::($st)) = Base.sitofp(BFloat16, x) + end + end + for ut in (Bool, UInt8, UInt16, UInt32, UInt64) + @eval begin + BFloat16(x::($ut)) = Base.uitofp(BFloat16, x) + end + end +else + BFloat16(x::Integer) = convert(BFloat16, convert(Float32, x)) end +# TODO: optimize +BFloat16(x::UInt128) = convert(BFloat16, Float64(x)) +BFloat16(x::Int128) = convert(BFloat16, Float64(x)) # Conversion to Float16 function Base.Float16(x::BFloat16) Float16(Float32(x)) end -# Expansion to Float32 -function Base.Float32(x::BFloat16) - reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16) -end +if llvm_arithmetic + Base.Float32(x::BFloat16) = Base.fpext(Float32, x) + Base.Float64(x::BFloat16) = Base.fpext(Float64, x) +else + # Expansion to Float32 + function Base.Float32(x::BFloat16) + reinterpret(Float32, UInt32(reinterpret(Unsigned, x)) << 16) + end -# Expansion to Float64 -function Base.Float64(x::BFloat16) - Float64(Float32(x)) + # Expansion to Float64 + function Base.Float64(x::BFloat16) + Float64(Float32(x)) + end end -# Truncation to integer types -Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) -Base.trunc(::Type{T}, x::BFloat16) where {T<:Integer} = trunc(T, Float32(x)) - # Basic arithmetic -for f in (:+, :-, :*, :/, :^) - @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) +if llvm_arithmetic + +(x::T, y::T) where {T<:BFloat16} = Base.add_float(x, y) + -(x::T, y::T) where {T<:BFloat16} = Base.sub_float(x, y) + *(x::T, y::T) where {T<:BFloat16} = Base.mul_float(x, y) + /(x::T, y::T) where {T<:BFloat16} = Base.div_float(x, y) + -(x::BFloat16) = Base.neg_float(x) + ^(x::BFloat16, y::BFloat16) = BFloat16(Float32(x)^Float32(y)) +else + for f in (:+, :-, :*, :/, :^) + @eval ($f)(x::BFloat16, y::BFloat16) = BFloat16($(f)(Float32(x), Float32(y))) + end + -(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) end --(x::BFloat16) = reinterpret(BFloat16, reinterpret(Unsigned, x) ⊻ sign_mask(BFloat16)) -^(x::BFloat16, y::Integer) = BFloat16(^(Float32(x), y)) +^(x::BFloat16, y::Integer) = BFloat16(Float32(x)^y) const ZeroBFloat16 = BFloat16(0.0f0) const OneBFloat16 = BFloat16(1.0f0) @@ -185,7 +258,68 @@ for t in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt end # Wide multiplication -Base.widemul(x::BFloat16, y::BFloat16) = Float32(x) * Float32(y) +Base.widemul(x::BFloat16, y::BFloat16) = widen(x) * widen(y) + +# Truncation to integer types +if llvm_arithmetic + for Ti in (Int8, Int16, Int32, Int64) + @eval begin + Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptosi($Ti, x) + end + end + for Ti in (UInt8, UInt16, UInt32, UInt64) + @eval begin + Base.unsafe_trunc(::Type{$Ti}, x::BFloat16) = Base.fptoui($Ti, x) + end + end +else + Base.unsafe_trunc(T::Type{<:Integer}, x::BFloat16) = unsafe_trunc(T, Float32(x)) +end +for Ti in (Int8, Int16, Int32, Int64, Int128, UInt8, UInt16, UInt32, UInt64, UInt128) + if Ti <: Unsigned || sizeof(Ti) < 2 + # Here `BFloat16(typemin(Ti))-1` is exact, so we can compare the lower-bound + # directly. `BFloat16(typemax(Ti))+1` is either always exactly representable, or + # rounded to `Inf` (e.g. when `Ti==UInt128 && BFloat16==Float32`). + @eval begin + function Base.trunc(::Type{$Ti}, x::BFloat16) + if $(BFloat16(typemin(Ti))-one(BFloat16)) < x < $(BFloat16(typemax(Ti))+one(BFloat16)) + return Base.unsafe_trunc($Ti,x) + else + throw(InexactError(:trunc, $Ti, x)) + end + end + function (::Type{$Ti})(x::BFloat16) + if ($(BFloat16(typemin(Ti))) <= x <= $(BFloat16(typemax(Ti)))) && isinteger(x) + return Base.unsafe_trunc($Ti,x) + else + throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x)) + end + end + end + else + # Here `eps(BFloat16(typemin(Ti))) > 1`, so the only value which can be + # truncated to `BFloat16(typemin(Ti)` is itself. Similarly, + # `BFloat16(typemax(Ti))` is inexact and will be rounded up. This assumes that + # `BFloat16(typemin(Ti)) > -Inf`, which is true for these types, but not for + # `Float16` or larger integer types. + @eval begin + function Base.trunc(::Type{$Ti}, x::BFloat16) + if $(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti))) + return unsafe_trunc($Ti,x) + else + throw(InexactError(:trunc, $Ti, x)) + end + end + function (::Type{$Ti})(x::BFloat16) + if ($(BFloat16(typemin(Ti))) <= x < $(BFloat16(typemax(Ti)))) && isinteger(x) + return unsafe_trunc($Ti,x) + else + throw(InexactError($(Expr(:quote,Ti.name.name)), $Ti, x)) + end + end + end + end +end # Showing function Base.show(io::IO, x::BFloat16) @@ -200,6 +334,7 @@ function Base.show(io::IO, x::BFloat16) hastypeinfo || print(io, ")") end end +Printf.tofloat(x::BFloat16) = Float32(x) # Random import Random: rand, randn, randexp, AbstractRNG, Sampler diff --git a/test/runtests.jl b/test/runtests.jl index 44f4bcd..8196e9d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,7 @@ using Test, BFloat16s, Printf, Random +@info "Testing BFloat16s" BFloat16s.llvm_storage BFloat16s.llvm_arithmetic + @testset "comparisons" begin @test BFloat16(1) < BFloat16(2) @test BFloat16(1f0) < BFloat16(2f0) @@ -27,6 +29,14 @@ end @test Int64(BFloat16(10)) == Int64(10) end +@testset "abi" begin + f() = BFloat16(1) + @test f() == BFloat16(1) + + g(x) = x+BFloat16(1) + @test g(BFloat16(2)) == BFloat16(3) +end + @testset "functions" begin @test abs(BFloat16(-10)) == BFloat16(10) @test BFloat16(2) ^ BFloat16(4) == BFloat16(16) @@ -58,7 +68,7 @@ end ("%.2a", "0x1.3cp+0"), ("%.2A", "0X1.3CP+0")), num in (BFloat16(1.234),) - @test @eval(@sprintf($fmt, $num) == $val) + @eval @test @sprintf($fmt, $num) == $val end @test (@sprintf "%f" BFloat16(Inf)) == "Inf" @test (@sprintf "%f" BFloat16(NaN)) == "NaN" @@ -73,7 +83,7 @@ end ("%-+10.5g", "+123.5 "), ("%010.5g", "00000123.5")), num in (BFloat16(123.5),) - @test @eval(@sprintf($fmt, $num) == $val) + @eval @test @sprintf($fmt, $num) == $val end @test( @sprintf( "%10.5g", BFloat16(-123.5) ) == " -123.5") @test( @sprintf( "%010.5g", BFloat16(-123.5) ) == "-0000123.5")