From 9df989339f9554a4fb160c4396d065f964758207 Mon Sep 17 00:00:00 2001 From: Christian Guinard <28689358+christiangnrd@users.noreply.github.com> Date: Tue, 1 Oct 2024 17:46:53 -0300 Subject: [PATCH] More bfloat support and test fixes --- src/device/intrinsics/math.jl | 4 ++-- src/device/intrinsics/simd.jl | 2 +- test/device/intrinsics.jl | 9 +++++---- test/runtests.jl | 8 +++++++- 4 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index e7544c1f0..125f41bf8 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -418,7 +418,7 @@ end j = fma(1.442695f0, a, 12582912.0f0) j = j - 12582912.0f0 i = unsafe_trunc(Int32, j) - f = fma(j, -6.93145752f-1, a) # log_2_hi + f = fma(j, -6.93145752f-1, a) # log_2_hi f = fma(j, -1.42860677f-6, f) # log_2_lo # approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2] @@ -460,4 +460,4 @@ end end return r -end \ No newline at end of file +end diff --git a/src/device/intrinsics/simd.jl b/src/device/intrinsics/simd.jl index e8815797d..8d83e92bd 100644 --- a/src/device/intrinsics/simd.jl +++ b/src/device/intrinsics/simd.jl @@ -7,7 +7,7 @@ function convert_origin(origin::NTuple{2, Int64}) return (VecElement{Int64}(origin[1]-1), VecElement{Int64}(origin[2]-1)) end -for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf18")) +for (jltype, suffix) in ((:Float16, "f16"), (:Float32, "f32"), (:BFloat16, "bf16")) for as in (AS.Device, AS.ThreadGroup) @eval begin @device_function simdgroup_load( diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 3b5155b7b..a90812242 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -275,9 +275,9 @@ end end @testset "parametrically typed" begin - typs = [Int32, Int64, Float32] + types = [Int32, Int64, Float32] metal_support() >= v"3.1" && push!(types, BFloat16) - @testset for typ in typs + @testset for typ in types function kernel(d::MtlDeviceArray{T}, n) where {T} t = thread_position_in_threadgroup_1d() tr = n-t+1 @@ -405,8 +405,9 @@ end return end - a = MtlArray(rand(typ, 8, 8)) - b = MtlArray(rand(typ, 8, 8)) + #Use `ones` for figuring out issues + a = MtlArray(ones(typ, 8, 8)) + b = MtlArray(ones(typ, 8, 8)) c = MtlArray(zeros(typ, 8, 8)) @metal threads=(8, 8) kernel(a, b, c) @test Array(a) * Array(b) ≈ Array(c) diff --git a/test/runtests.jl b/test/runtests.jl index 0554e8689..584dc3183 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -81,6 +81,12 @@ const gpuarr_eltypes = [Int16, Int32, Int64, ComplexF16, ComplexF32] const gpuarr_eltypes_nobf16 = copy(gpuarr_eltypes) +# don't test BFloat16 for unsupported operations +nobf16_tests = ["random", "reductions/reducedim!", + "reductions/mapreducedim!_large", "reductions/mapreduce", + "reductions/== isequal", "reductions/minimum maximum extrema", + "reductions/sum prod", "reductions/mapreducedim!", "reductions/reduce"] + # Add BFloat16 for tests that use it Metal.metal_support() >= v"3.1" && push!(gpuarr_eltypes, BFloat16) @@ -90,7 +96,7 @@ for name in keys(TestSuite.tests) continue end - tmp_eltypes = name in ["random"] ? gpuarr_eltypes_nobf16 : gpuarr_eltypes + tmp_eltypes = name in nobf16_tests ? gpuarr_eltypes_nobf16 : gpuarr_eltypes push!(tests, "gpuarrays$(Base.Filesystem.path_separator)$name") test_runners["gpuarrays$(Base.Filesystem.path_separator)$name"] = ()->TestSuite.tests[name](MtlArray;eltypes=tmp_eltypes)