From f43aee0cf8228cf1243a92d27edbd3cc3a1faf16 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 10:28:54 -0400 Subject: [PATCH 1/4] ci(buildkite): add downstream testing for NeuralOperators --- .buildkite/testing.yml | 5 ++--- Project.toml | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/.buildkite/testing.yml b/.buildkite/testing.yml index a4cfaa6e..ad88470c 100644 --- a/.buildkite/testing.yml +++ b/.buildkite/testing.yml @@ -38,7 +38,6 @@ steps: - src - ext env: - RETESTITEMS_NWORKERS: 2 BACKEND_GROUP: "AMDGPU" agents: queue: "juliagpu" @@ -126,6 +125,7 @@ steps: repo: - "Boltz" - "Lux" + - "NeuralOperators" - group: ":telescope: Downstream AMD GPU" steps: @@ -143,8 +143,6 @@ steps: queue: "juliagpu" rocm: "*" rocmgpu: "*" - env: - RETESTITEMS_NWORKERS: 2 if: build.message !~ /\[skip tests\]/ && build.message !~ /\[skip downstream\]/ && build.message !~ /\[skip ci\]/ && build.branch != "main" timeout_in_minutes: 240 matrix: @@ -152,6 +150,7 @@ steps: repo: - "Boltz" - "Lux" + - "NeuralOperators" env: JULIA_PKG_SERVER: "" diff --git a/Project.toml b/Project.toml index 7225334c..6f6005b7 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxLib" uuid = "82251201-b29d-42c6-8e01-566dec8acb11" authors = ["Avik Pal and contributors"] -version = "1.3.4" +version = "1.3.5" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From a8c0f3b4615f96a8773577e16fac61ba310d8123 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 10:38:33 -0400 Subject: [PATCH 2/4] perf: restore old batched_mul --- Project.toml | 2 +- src/impl/batched_mul.jl | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/Project.toml b/Project.toml index 6f6005b7..d7fb56a4 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,7 @@ ChainRulesCore = "1.24" Compat = "4.15.0" CpuId = "0.3" DispatchDoctor = "0.4.12" -Enzyme = "0.13.1" +Enzyme = "0.13.12" EnzymeCore = "0.8.1" FastClosures = "0.3.2" ForwardDiff = 
"0.10.36" diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl index 257b4e0f..b29b5bd4 100644 --- a/src/impl/batched_mul.jl +++ b/src/impl/batched_mul.jl @@ -61,9 +61,7 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, batched_matmul_loopvec_impl!(z, x, y) return end - # Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983 - fallback_batched_matmul!(z, LoopedArrayOp(), x, y) - # NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed + NNlib.batched_mul!(z, x, y) return end @@ -80,10 +78,9 @@ end function fallback_batched_matmul!( z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - # XXX: bring back once the enzyme segfault is fixed - # @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ - # $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ - # slow." maxlog=1 + @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ + slow." maxlog=1 if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1)) From 583a6b4dcd7e1cedc35885c00093a858bc379353 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 10:49:18 -0400 Subject: [PATCH 3/4] fix: disable threading for certain devices --- src/impl/batched_mul.jl | 43 ++++++++++++++++++++++++++++++++++++----- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl index b29b5bd4..52e038bc 100644 --- a/src/impl/batched_mul.jl +++ b/src/impl/batched_mul.jl @@ -68,17 +68,17 @@ end function batched_matmul_loopvec_impl! 
end function fallback_batched_matmul( - dev, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} + opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} z = similar(x, promote_type(eltype(x), eltype(y)), size(x, 1), size(y, 2), max(size(x, 3), size(y, 3))) - fallback_batched_matmul!(z, dev, x, y) + fallback_batched_matmul!(z, opmode, x, y) return z end function fallback_batched_matmul!( - z::AbstractArray{zT, 3}, dev, x::AbstractArray{xT, 3}, + z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - @warn "Using fallback Batched Matrix Multiply routine for $(dev) with A: size = \ + @warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \ $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ slow." maxlog=1 @@ -87,6 +87,36 @@ function fallback_batched_matmul!( throw(DimensionMismatch(lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.")) end + if use_threaded_batched_matmul(get_device_type(x)) + unsafe_fallback_threaded_batched_matmul!(z, x, y) + else + unsafe_fallback_serial_batched_matmul!(z, x, y) + end + + return +end + +function unsafe_fallback_serial_batched_matmul!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} + if size(x, 3) == size(y, 3) + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, L)) + end + elseif size(x, 3) == 1 + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, 1), batchview(y, L)) + end + else # has to be size(y, 3) == 1 + for L in axes(z, 3) + mul!(batchview(z, L), batchview(x, L), batchview(y, 1)) + end + end +end + +function unsafe_fallback_threaded_batched_matmul!( + z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, + y::AbstractArray{yT, 3}) where {zT, xT, yT} old_threads = maybe_reduce_BLAS_threads(z) if size(x, 3) == size(y, 3) @@ -104,10 +134,13 @@ function fallback_batched_matmul!( end 
reset_BLAS_threads(old_threads) - return end +use_threaded_batched_matmul(::Type) = false +use_threaded_batched_matmul(::Type{CUDADevice}) = true +use_threaded_batched_matmul(::Type{CPUDevice}) = true + function CRC.rrule(::typeof(batched_matmul), x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {xT, yT} ∇batched_matmul = @closure Δ_ -> begin From 076688520f0e1875d9609a6436e405478bae8033 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Fri, 25 Oct 2024 11:10:40 -0400 Subject: [PATCH 4/4] revert: "perf: restore old batched_mul" This reverts commit a8c0f3b4615f96a8773577e16fac61ba310d8123. --- Project.toml | 2 +- src/impl/batched_mul.jl | 11 +++++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index d7fb56a4..6f6005b7 100644 --- a/Project.toml +++ b/Project.toml @@ -65,7 +65,7 @@ ChainRulesCore = "1.24" Compat = "4.15.0" CpuId = "0.3" DispatchDoctor = "0.4.12" -Enzyme = "0.13.12" +Enzyme = "0.13.1" EnzymeCore = "0.8.1" FastClosures = "0.3.2" ForwardDiff = "0.10.36" diff --git a/src/impl/batched_mul.jl b/src/impl/batched_mul.jl index 52e038bc..b8900d8e 100644 --- a/src/impl/batched_mul.jl +++ b/src/impl/batched_mul.jl @@ -61,7 +61,9 @@ function batched_matmul_cpu!(z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, batched_matmul_loopvec_impl!(z, x, y) return end - NNlib.batched_mul!(z, x, y) + # Avoid an Enzyme segfault https://github.com/EnzymeAD/Enzyme.jl/issues/1983 + fallback_batched_matmul!(z, LoopedArrayOp(), x, y) + # NNlib.batched_mul!(z, x, y) # XXX: restore once the enzyme segfault is fixed return end @@ -78,9 +80,10 @@ end function fallback_batched_matmul!( z::AbstractArray{zT, 3}, opmode, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT} - @warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \ - $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ - slow." 
maxlog=1 + # XXX: bring back once the enzyme segfault is fixed + # @warn "Using fallback Batched Matrix Multiply routine for $(opmode) with A: size = \ + # $(size(x)) eltype = $(xT) and B: size = $(size(y)) eltype = $(yT). This may be \ + # slow." maxlog=1 if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) || (size(x, 2) != size(y, 1))