Skip to content

Commit

Permalink
Explore removal of triangular rule (EnzymeAD#1614)
Browse files Browse the repository at this point in the history
* Explore removal of triangular rule

* Test triangular solve

* restore for complex

* Update Project.toml

* Update runtests.jl

* Update Project.toml
  • Loading branch information
wsmoses authored Jul 8, 2024
1 parent c799584 commit c83fcf8
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 98 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ EnzymeStaticArraysExt = "StaticArrays"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.7.5"
Enzyme_jll = "0.0.131"
Enzyme_jll = "0.0.133"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26"
LLVM = "6.1, 7, 8"
ObjectFile = "0.4"
Expand Down
94 changes: 4 additions & 90 deletions src/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -458,10 +458,10 @@ function EnzymeRules.reverse(config, func::Const{typeof(\)}, ::Type{RT}, cache,
end

const EnzymeTriangulars = Union{
UpperTriangular,
LowerTriangular,
UnitUpperTriangular,
UnitLowerTriangular
UpperTriangular{<:Complex},
LowerTriangular{<:Complex},
UnitUpperTriangular{<:Complex},
UnitLowerTriangular{<:Complex}
}

function EnzymeRules.augmented_primal(
Expand Down Expand Up @@ -803,89 +803,3 @@ function EnzymeRules.forward(func::Const{typeof(ldiv!)},
end
end
end


# y=inv(A) B
# dA −= z y^T
# dB += z, where z = inv(A^T) dy
# ->
#
# B(out)=inv(A) B(in)
# dA −= z B(out)^T
# dB = z, where z = inv(A^T) dB
# function EnzymeRules.augmented_primal(
# config,
# func::Const{typeof(ldiv!)},
# RT::Type{<:Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated}},
#
# A::Annotation{<:Cholesky},
# B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
# kwargs...
# )
# func.val(A.val, B.val; kwargs...)
#
# cache_Bout = if !isa(A, Const) && !isa(B, Const)
# if EnzymeRules.overwritten(config)[3]
# copy(B.val)
# else
# B.val
# end
# else
# nothing
# end
#
# cache_A = if !isa(B, Const)
# if EnzymeRules.overwritten(config)[2]
# copy(A.val)
# else
# A.val
# end
# else
# nothing
# end
#
# primal = if EnzymeRules.needs_primal(config)
# B.val
# else
# nothing
# end
#
# shadow = if EnzymeRules.needs_shadow(config)
# B.dval
# else
# nothing
# end
#
# return EnzymeRules.AugmentedReturn(primal, shadow, (cache_A, cache_Bout))
# end
#
# function EnzymeRules.reverse(
# config,
# func::Const{typeof(ldiv!)},
# dret,
# cache,
# A::Annotation{<:Cholesky},
# B::Union{Const, DuplicatedNoNeed, Duplicated, BatchDuplicatedNoNeed, BatchDuplicated};
# kwargs...
# )
# if !isa(B, Const)
#
# (cache_A, cache_Bout) = cache
#
# for b in 1:EnzymeRules.width(config)
#
# dB = EnzymeRules.width(config) == 1 ? B.dval : B.dval[b]
#
# # dB = z, where z = inv(A^T) dB
# # dA −= z B(out)^T
#
# func.val(cache_A, dB; kwargs...)
# if !isa(A, Const)
# dA = EnzymeRules.width(config) == 1 ? A.dval : A.dval[b]
# mul!(dA.factors, dB, transpose(cache_Bout), -1, 1)
# end
# end
# end
#
# return (nothing, nothing)
# end
32 changes: 32 additions & 0 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,38 @@ end
@test dA (-z * transpose(y))
end

function tr_solv(A, B, uplo, trans, diag, idx)
B = copy(B)
LAPACK.trtrs!(uplo, trans, diag, A, B)
return @inbounds B[idx]
end


@testset "Reverse triangular solve" begin
A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542]
B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784]
for idx in 1:15
for uplo in ('L', 'U')
for diag in ('N', 'U')
for trans in ('N', 'T')
dA = zero(A)
dB = zero(B)
Enzyme.autodiff(Reverse, tr_solv, Duplicated(A, dA), Duplicated(B, dB), Const(uplo),Const(trans), Const(diag), Const(idx))
fA = FiniteDifferences.grad(central_fdm(5, 1), A->tr_solv(A, B, uplo, trans, diag, idx), A)[1]
fB = FiniteDifferences.grad(central_fdm(5, 1), B->tr_solv(A, B, uplo, trans, diag, idx), B)[1]

if max(abs.(dA)...) >= 1e-10 || max(abs.(fA)...) >= 1e-10
@test dA fA
end
if max(abs.(dB)...) >= 1e-10 || max(abs.(fB)...) >= 1e-10
@test dB fB
end
end
end
end
end
end

function chol_lower0(x)
c = copy(x)
C, info = LinearAlgebra.LAPACK.potrf!('L', c)
Expand Down
13 changes: 6 additions & 7 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,12 @@ function euroad(f::T) where T
return g
end

euroad′(x) = first(autodiff(Reverse, euroad, Active, Active(x)))[1]

@test euroad(0.5) -log(0.5) # -log(1-x)
@test euroad′(0.5) 2.0 # d/dx -log(1-x) = 1/(1-x)
test_scalar(euroad, 0.5)
end
@noinline function womylogpdf(X::AbstractArray{<:Real})
map(womylogpdf, X)
end
Expand All @@ -736,13 +742,6 @@ end
@test res[1] [0.7]
end

euroad′(x) = first(autodiff(Reverse, euroad, Active, Active(x)))[1]

@test euroad(0.5) -log(0.5) # -log(1-x)
@test euroad′(0.5) 2.0 # d/dx -log(1-x) = 1/(1-x)
test_scalar(euroad, 0.5)
end

@testset "Nested AD" begin
tonest(x,y) = (x + y)^2

Expand Down

0 comments on commit c83fcf8

Please sign in to comment.