From d727befa11f9c4e2f8700372e5523f5e8f884422 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 20 Jan 2023 17:23:02 -0500 Subject: [PATCH 1/2] Add missing overloads for Tridiagonal Upstreaming from DiffEqFlux (https://github.com/SciML/DiffEqFlux.jl/pull/800) --- src/lib/array.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index 4b8f90609..7b7256cc1 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -432,6 +432,12 @@ end @adjoint LinearAlgebra.UnitLowerTriangular(A) = UnitLowerTriangular(A), Δ->(UnitLowerTriangular(Δ)-I,) @adjoint LinearAlgebra.UnitUpperTriangular(A) = UnitUpperTriangular(A), Δ->(UnitUpperTriangular(Δ)-I,) +@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),) +@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),) +@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),) +@adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end, 1:end-1]), diag(p̄), diag(p̄[1:end-1, 2:end])) + + # This is basically a hack while we don't have a working `ldiv!`. @adjoint function \(A::Cholesky, B::AbstractVecOrMat) Y, back = Zygote.pullback((U, B)->U \ (U' \ B), A.U, B) From 37a54bb9fc28d122d006e8af45cb28dd66fdd8e9 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Fri, 20 Jan 2023 17:30:23 -0500 Subject: [PATCH 2/2] Update array.jl --- src/lib/array.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 7b7256cc1..0d1b59757 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -432,9 +432,9 @@ end @adjoint LinearAlgebra.UnitLowerTriangular(A) = UnitLowerTriangular(A), Δ->(UnitLowerTriangular(Δ)-I,) @adjoint LinearAlgebra.UnitUpperTriangular(A) = UnitUpperTriangular(A), Δ->(UnitUpperTriangular(Δ)-I,) -@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),) -@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),) -@adjoint ZygoteRules.literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),) +@adjoint literal_getproperty(A::Tridiagonal, ::Val{:dl}) = A.dl, y -> Tridiagonal(dl, zeros(length(d)), zeros(length(du)),) +@adjoint literal_getproperty(A::Tridiagonal, ::Val{:d}) = A.d, y -> Tridiagonal(zeros(length(dl)), d, zeros(length(du)),) +@adjoint literal_getproperty(A::Tridiagonal, ::Val{:du}) = A.dl, y -> Tridiagonal(zeros(length(dl)), zeros(length(d), du),) @adjoint Tridiagonal(dl, d, du) = Tridiagonal(dl, d, du), p̄ -> (diag(p̄[2:end, 1:end-1]), diag(p̄), diag(p̄[1:end-1, 2:end]))