Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: patch more enzyme issues
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 18, 2024
1 parent ddb384c commit 08f8448
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 12 deletions.
11 changes: 6 additions & 5 deletions src/impl/batched_mul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,15 @@ function batched_matmul!(z::AbstractArray{zT, 3}, ::LoopedArrayOp,
return
end

function batched_matmul_cpu!(z::AbstractArray{zT, 3},
x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3}) where {zT, xT, yT}
function batched_matmul_cpu!(
z::AbstractArray{zT, 3}, x::AbstractArray{xT, 3}, y::AbstractArray{yT, 3},
α::Number=true, β::Number=false) where {zT, xT, yT}
if can_loopvec_args(batchview(z, 1), batchview(x, 1), batchview(y, 1)) &&
!unsafe_known(explicit_blas_loaded())
batched_matmul_loopvec_impl!(z, x, y)
batched_matmul_loopvec_impl!(z, x, y, α, β)
return
end
NNlib.batched_mul!(z, x, y)
NNlib.batched_mul!(z, x, y, α, β)
return
end

Expand Down Expand Up @@ -120,7 +121,7 @@ end
# This is type-piracy but needed to fix a blocking issue. TODO: upstream to NNlib
# Enzyme causes a "active variables passed by value to jl_new_task are not yet supported"
# warning without this patch.
for func in (NNlib.batched_mul!, batched_matmul_cpu!)
for func in (NNlib.batched_mul!, batched_matmul_loopvec_impl!)
@eval begin
function EnzymeRules.augmented_primal(
cfg::EnzymeRules.RevConfigWidth, ::EnzymeCore.Const{typeof($(func))},
Expand Down
7 changes: 7 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ const KA = KernelAbstractions

is_extension_loaded(::Val) = False()

CRC.@non_differentiable is_extension_loaded(::Any...)
EnzymeRules.inactive_noinl(::typeof(is_extension_loaded), ::Any...) = nothing

# Simple Operations -- no rrules needed
ofeltype_array(::Type{T}, x::AbstractArray{T}) where {T} = x
function ofeltype_array(
Expand Down Expand Up @@ -328,4 +331,8 @@ end

@inline can_loopvec_args_check(::False, args...) = false

CRC.@non_differentiable can_loopvec_args_check(::Any...)

EnzymeRules.inactive_noinl(::typeof(can_loopvec_args_check), ::Any...) = nothing

end
2 changes: 1 addition & 1 deletion test/common_ops/activation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@jet apply_act_fast2(f, x)

@test @inferred(Zygote.gradient(apply_act, f, x)) isa Any
if f !== lisht || (f === lisht && T == Float32 && !ongpu)
if f !== lisht
@test @inferred(Zygote.gradient(apply_act_fast, f, x)) isa Any
end
@test @inferred(Zygote.gradient(apply_act_fast2, f, x)) isa Any
Expand Down
11 changes: 5 additions & 6 deletions test/common_ops/bias_act_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,11 @@
@jet bias_act_loss2(act, x, b)
@jet bias_act_loss3(act, x, b)

if (act !== lisht || (act === lisht && T == Float32 && !ongpu)) && T != Float16
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
elseif T != Float16
@test_broken @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any
@test_broken @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any
if act !== lisht
@test @inferred(Zygote.gradient(bias_act_loss2, act, x, b)) isa Any broken=(T ==
Float16)
@test @inferred(Zygote.gradient(bias_act_loss3, act, x, b)) isa Any broken=(T ==
Float16)
end

@test_gradients(__Fix1(bias_act_loss1, act), x, b; atol, rtol,
Expand Down

0 comments on commit 08f8448

Please sign in to comment.