Skip to content

Commit

Permalink
More bfloat support and test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
christiangnrd committed Oct 10, 2024
1 parent 64ae892 commit 9df9893
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ end
j = fma(1.442695f0, a, 12582912.0f0)
j = j - 12582912.0f0
i = unsafe_trunc(Int32, j)
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -1.42860677f-6, f) # log_2_lo

# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
Expand Down Expand Up @@ -460,4 +460,4 @@ end
end

return r
end
end
2 changes: 1 addition & 1 deletion src/device/intrinsics/simd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64})
return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1))
end

for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf18"))
for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16"))
for as in (AS.Device, AS.ThreadGroup)
@eval begin
@device_function simdgroup_load(
Expand Down
9 changes: 5 additions & 4 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ end
end

@testset "parametrically typed" begin
typs = [Int32, Int64, Float32]
types = [Int32, Int64, Float32]
metal_support() >= v"3.1" && push!(types, BFloat16)
@testset for typ in typs
@testset for typ in types
function kernel(d::MtlDeviceArray{T}, n) where {T}
t = thread_position_in_threadgroup_1d()
tr = n-t+1
Expand Down Expand Up @@ -405,8 +405,9 @@ end
return
end

a = MtlArray(rand(typ, 8, 8))
b = MtlArray(rand(typ, 8, 8))
#Use `ones` for figuring out issues
a = MtlArray(ones(typ, 8, 8))
b = MtlArray(ones(typ, 8, 8))
c = MtlArray(zeros(typ, 8, 8))
@metal threads=(8, 8) kernel(a, b, c)
@test Array(a) * Array(b) Array(c)
Expand Down
8 changes: 7 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,12 @@ const gpuarr_eltypes = [Int16, Int32, Int64,
ComplexF16, ComplexF32]
const gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes)

# don't test BFloat16 for unsupported operations
nobf16_tests = ["random", "reductions/reducedim!",
"reductions/mapreducedim!_large", "reductions/mapreduce",
"reductions/== isequal", "reductions/minimum maximum extrema",
"reductions/sum prod", "reductions/mapreducedim!", "reductions/reduce"]

# Add BFloat16 for tests that use it
Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16)

Expand All @@ -90,7 +96,7 @@ for name in keys(TestSuite.tests)
continue
end

tmp_eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes
tmp_eltypes = name in nobf16_tests ? gpuarr_eltypes_nobf16 : gpuarr_eltypes

push!(tests, "gpuarrays$(Base.Filesystem.path_separator)$name")
test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray;eltypes=tmp_eltypes)
Expand Down

0 comments on commit 9df9893

Please sign in to comment.