From 675fecac6d1c90e083ab1173374fce3b2e8c6e13 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Nov 2023 10:46:37 -0500 Subject: [PATCH 01/10] Add Tridiagonal construction rule --- src/rulesets/LinearAlgebra/structured.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 245335774..705c372da 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -268,3 +268,16 @@ function rrule(::typeof(logdet), X::Union{Diagonal, AbstractTriangular}) end return y, logdet_pullback end + +##### +##### Tridiagonal +##### + +function rrule(::Type{Tridiagonal}, dl, d, du) + y = Tridiagonal(dl, d, du) + @views function ∇Tridiagonal(∂y) + return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y), + diag(∂y[1:(end - 1), 2:end])) + end + return y, ∇Tridiagonal +end From 3459acb1b02f31a663299b8f28f103ef62d2c110 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Nov 2023 10:48:58 -0500 Subject: [PATCH 02/10] Update structured.jl --- test/rulesets/LinearAlgebra/structured.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 1b83cc394..46f089dbf 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -161,4 +161,9 @@ end end end + + @testset "Tridiagonal" begin + res, pb = rrule(Tridiagonal, [1, 4], [2, 3, 4], [5, 3]) + @test pb(10*res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30]) + end end From a275f1ff33a6ef20613803e6475750e5d93fbdd2 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Nov 2023 10:49:18 -0500 Subject: [PATCH 03/10] Update src/rulesets/LinearAlgebra/structured.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/LinearAlgebra/structured.jl | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 705c372da..d2ecdaba9 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -276,8 +276,12 @@ end function rrule(::Type{Tridiagonal}, dl, d, du) y = Tridiagonal(dl, d, du) @views function ∇Tridiagonal(∂y) - return (NoTangent(), diag(∂y[2:end, 1:(end - 1)]), diag(∂y), - diag(∂y[1:(end - 1), 2:end])) + return ( + NoTangent(), + diag(∂y[2:end, 1:(end - 1)]), + diag(∂y), + diag(∂y[1:(end - 1), 2:end]), + ) end return y, ∇Tridiagonal end From e823ed8c92d2054df5ab400290dd431455abd637 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Wed, 22 Nov 2023 10:53:30 -0500 Subject: [PATCH 04/10] Update test/rulesets/LinearAlgebra/structured.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rulesets/LinearAlgebra/structured.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 46f089dbf..b64152204 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -164,6 +164,6 @@ @testset "Tridiagonal" begin res, pb = rrule(Tridiagonal, [1, 4], [2, 3, 4], [5, 3]) - @test pb(10*res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30]) + @test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30]) end end From 9ba86d8052b532e8e021f446403a63d4d0959612 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 23 Nov 2023 07:41:01 -0500 Subject: [PATCH 05/10] Update test/rulesets/LinearAlgebra/structured.jl Co-authored-by: Frames White --- test/rulesets/LinearAlgebra/structured.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index b64152204..3c0c7f92f 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -163,7 +163,7 @@ end @testset "Tridiagonal" begin - res, pb = rrule(Tridiagonal, [1, 4], [2, 3, 4], [5, 3]) + test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0]) @test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30]) end end From ee952df7f1092ec632fbcbfefacd2e2b24961f19 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 23 Nov 2023 07:43:34 -0500 Subject: [PATCH 06/10] Update src/rulesets/LinearAlgebra/structured.jl --- src/rulesets/LinearAlgebra/structured.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index d2ecdaba9..f2d21a274 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -278,9 +278,9 @@ function rrule(::Type{Tridiagonal}, dl, d, du) @views function ∇Tridiagonal(∂y) return ( NoTangent(), - diag(∂y[2:end, 1:(end - 1)]), + diag(∂y, -1), diag(∂y), - diag(∂y[1:(end - 1), 2:end]), + diag(∂y, 1), ) end return y, ∇Tridiagonal From 58957361ecfcf78488c831729cb66b26a473b893 Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 23 Nov 2023 09:34:05 -0500 Subject: [PATCH 07/10] Update src/rulesets/LinearAlgebra/structured.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/LinearAlgebra/structured.jl | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index f2d21a274..fa96dd126 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -276,12 +276,7 @@ end function rrule(::Type{Tridiagonal}, dl, d, du) y = Tridiagonal(dl, d, du) @views function ∇Tridiagonal(∂y) - return ( - NoTangent(), - diag(∂y, -1), - diag(∂y), - diag(∂y, 1), - ) + return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1)) end return y, ∇Tridiagonal end From dfbd363302a7bbeceb9c59078984d3c36b131d5d Mon Sep 17 00:00:00 2001 From: Christopher Rackauckas Date: Thu, 23 Nov 2023 09:34:14 -0500 Subject: [PATCH 08/10] Update test/rulesets/LinearAlgebra/structured.jl Co-authored-by: Frames White --- test/rulesets/LinearAlgebra/structured.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 3c0c7f92f..c46460f46 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -164,6 +164,5 @@ @testset "Tridiagonal" begin test_rrule(Tridiagonal, [1.0, 4.0], [2.0, 3.0, 4.0], [5.0, 3.0]) - @test pb(10 * res) == (NoTangent(), [10, 40], [20, 30, 40], [50, 30]) end end From 938623b9496c0a3f38a9d1479ed8a1c477a86620 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 20 May 2024 21:48:27 +0800 Subject: [PATCH 09/10] unthunk input to Tridiagonal_pullback --- src/rulesets/LinearAlgebra/structured.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index fa96dd126..722e37791 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -275,8 +275,9 @@ end function rrule(::Type{Tridiagonal}, dl, d, du) y = Tridiagonal(dl, d, du) - @views function ∇Tridiagonal(∂y) + @views function Tridiagonal_pullback(ȳ) + ∂y = unthunk(ȳ) return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1)) end - return y, ∇Tridiagonal + return y, Tridiagonal_pullback end From ce288b9c571f636b73fd8e128543c9cde979edb1 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 20 May 2024 21:50:34 +0800 Subject: [PATCH 10/10] remove unneded veiws macrod --- src/rulesets/LinearAlgebra/structured.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index 722e37791..e48416739 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -275,7 +275,7 @@ end function rrule(::Type{Tridiagonal}, dl, d, du) y = Tridiagonal(dl, d, du) - @views function Tridiagonal_pullback(ȳ) + function Tridiagonal_pullback(ȳ) ∂y = unthunk(ȳ) return (NoTangent(), diag(∂y, -1), diag(∂y), diag(∂y, 1)) end