From 2faddfd245546b138ad602c3d8ae11d228ecd70c Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 4 Sep 2020 00:42:07 +0200 Subject: [PATCH 1/7] rrule for 3-arg dot, take 1 --- src/rulesets/LinearAlgebra/dense.jl | 12 ++++++++++++ test/rulesets/LinearAlgebra/dense.jl | 8 ++++++++ 2 files changed, 20 insertions(+) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 9800363c1..aca0fd0de 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -19,6 +19,18 @@ function rrule(::typeof(dot), x, y) return dot(x, y), dot_pullback end +function rrule(::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractVector) + Ay = A * y + z = adjoint(x) * Ay + function dot_pullback(ΔΩ) + dx = @thunk conj(ΔΩ) .* Ay + dA = @thunk conj.(ΔΩ .* x) .* transpose(y) + dy = @thunk conj(ΔΩ) .* vec(adjoint(x) * A) + return (NO_FIELDS, dx, dA, dy) + end + return z, dot_pullback +end + ##### ##### `cross` ##### diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 8b2f4437d..87e1950e5 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -24,6 +24,14 @@ frule_test(dot, (x, ẋ), (y, ẏ)) rrule_test(dot, randn(T), (x, x̄), (y, ȳ)) end + @testset "3-arg dot, {$T}" for T in (Float64, )#ComplexF64) + M, N = 3, 4 + x, A, y = randn(T, M), randn(T, M,N), randn(T, N) + # ẋ, Adot, ẏ = randn(T, M), randn(T, M,N), randn(T, N) + x̄, Abar, ȳ = randn(T, M), randn(T, M,N), randn(T, N) + # frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) + rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ)) + end end @testset "cross" begin @testset "frule" begin From 06dbd4f96875f3a712d99e0cbd0e3387ba96d780 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Fri, 4 Sep 2020 11:34:14 +0200 Subject: [PATCH 2/7] fix complex --- src/rulesets/LinearAlgebra/dense.jl | 9 +++++++-- test/rulesets/LinearAlgebra/dense.jl | 6 +++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index aca0fd0de..613b3155d 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -19,15 +19,20 @@ function rrule(::typeof(dot), x, y) return dot(x, y), dot_pullback end +function frule((_, Δx, ΔA, Δy), ::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractVector) + return dot(x, A, y), dot(Δx, A, y) + dot(x, ΔA, y) + dot(x, A, Δy) +end + function rrule(::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractVector) Ay = A * y z = adjoint(x) * Ay function dot_pullback(ΔΩ) dx = @thunk conj(ΔΩ) .* Ay - dA = @thunk conj.(ΔΩ .* x) .* transpose(y) - dy = @thunk conj(ΔΩ) .* vec(adjoint(x) * A) + dA = @thunk ΔΩ .* x .* adjoint(y) + dy = @thunk ΔΩ .* (adjoint(A) * x) return (NO_FIELDS, dx, dA, dy) end + dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) return z, dot_pullback end diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 87e1950e5..fa6406176 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -24,12 +24,12 @@ frule_test(dot, (x, ẋ), (y, ẏ)) rrule_test(dot, randn(T), (x, x̄), (y, ȳ)) end - @testset "3-arg dot, {$T}" for T in (Float64, )#ComplexF64) + @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64) M, N = 3, 4 x, A, y = randn(T, M), randn(T, M,N), randn(T, N) - # ẋ, Adot, ẏ = randn(T, M), randn(T, M,N), randn(T, N) + ẋ, Adot, ẏ = randn(T, M), randn(T, M,N), randn(T, N) x̄, Abar, ȳ = randn(T, M), randn(T, M,N), randn(T, N) - # frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) + frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ)) end end From 5db1940122683d47327f7d85f9956a37c47700cf Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Fri, 4 Sep 2020 11:52:36 +0200 Subject: [PATCH 3/7] Update test/rulesets/LinearAlgebra/dense.jl Co-authored-by: willtebbutt --- test/rulesets/LinearAlgebra/dense.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index fa6406176..5b5c8d2ca 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -26,9 +26,9 @@ end @testset "3-arg dot, Array{$T}" for T in (Float64, ComplexF64) M, N = 3, 4 - x, A, y = randn(T, M), randn(T, M,N), randn(T, N) - ẋ, Adot, ẏ = randn(T, M), randn(T, M,N), randn(T, N) - x̄, Abar, ȳ = randn(T, M), randn(T, M,N), randn(T, N) + x, A, y = randn(T, M), randn(T, M, N), randn(T, N) + ẋ, Adot, ẏ = randn(T, M), randn(T, M, N), randn(T, N) + x̄, Abar, ȳ = randn(T, M), randn(T, M, N), randn(T, N) frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ)) end From 17cc2fe448bc1533eefc7a563014da1b16672ebe Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sat, 5 Sep 2020 10:53:01 +0200 Subject: [PATCH 4/7] add method for Diagonal --- src/rulesets/LinearAlgebra/dense.jl | 12 ++++++++++++ test/rulesets/LinearAlgebra/structured.jl | 7 ++++++- 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 613b3155d..782783af9 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -36,6 +36,18 @@ function rrule(::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractV return z, dot_pullback end +function rrule(::typeof(dot), x::AbstractVector, A::Diagonal, y::AbstractVector) + z = dot(x,A,y) + function dot_pullback(ΔΩ) + dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused + dA = @thunk Diagonal(ΔΩ .* x .* conj(y)) # calculate N not N^2 elements + dy = @thunk ΔΩ .* conj.(A.diag) .* x + return (NO_FIELDS, dx, dA, dy) + end + dot_pullback(::Zero) = (NO_FIELDS, Zero(), Zero(), Zero()) + return z, dot_pullback +end + ##### ##### `cross` ##### diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 8ca1f0284..1867a4290 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -14,7 +14,12 @@ comp = Composite{typeof(res)}(; diag=10*res.diag) # this is the structure of Diagonal @test pb(comp) == (NO_FIELDS, [10, 40]) end - + @testset "dot(x, ::Diagonal, y)" begin + N = 4 + x, d, y = randn(ComplexF64, N), randn(ComplexF64, N), randn(ComplexF64, N) + D = Diagonal(d) + rrule_test(dot, rand(ComplexF64), (x,similar(x)), (D,similar(D)), (y,similar(y))) + end @testset "::Diagonal * ::AbstractVector" begin N = 3 rrule_test( From 2705982efaee1562ac6f430812fb3ffa5ee35d55 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Sun, 6 Sep 2020 16:54:11 +0200 Subject: [PATCH 5/7] restrict to AbstractVector{<:Number} etc --- src/rulesets/LinearAlgebra/dense.jl | 6 +++--- test/rulesets/LinearAlgebra/dense.jl | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/dense.jl b/src/rulesets/LinearAlgebra/dense.jl index 782783af9..f04f83321 100644 --- a/src/rulesets/LinearAlgebra/dense.jl +++ b/src/rulesets/LinearAlgebra/dense.jl @@ -19,11 +19,11 @@ function rrule(::typeof(dot), x, y) return dot(x, y), dot_pullback end -function frule((_, Δx, ΔA, Δy), ::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractVector) +function frule((_, Δx, ΔA, Δy), ::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) return dot(x, A, y), dot(Δx, A, y) + dot(x, ΔA, y) + dot(x, A, Δy) end -function rrule(::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractVector) +function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::AbstractMatrix{<:Number}, y::AbstractVector{<:Number}) Ay = A * y z = adjoint(x) * Ay function dot_pullback(ΔΩ) @@ -36,7 +36,7 @@ function rrule(::typeof(dot), x::AbstractVector, A::AbstractMatrix, y::AbstractV return z, dot_pullback end -function rrule(::typeof(dot), x::AbstractVector, A::Diagonal, y::AbstractVector) +function rrule(::typeof(dot), x::AbstractVector{<:Number}, A::Diagonal{<:Number}, y::AbstractVector{<:Number}) z = dot(x,A,y) function dot_pullback(ΔΩ) dx = @thunk conj(ΔΩ) .* A.diag .* y # A*y is this broadcast, can be fused diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index 5b5c8d2ca..f3afa93ca 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -28,7 +28,7 @@ M, N = 3, 4 x, A, y = randn(T, M), randn(T, M, N), randn(T, N) ẋ, Adot, ẏ = randn(T, M), randn(T, M, N), randn(T, N) - x̄, Abar, ȳ = randn(T, M), randn(T, M, N), randn(T, N) + x̄, Abar, ȳ = similar(x), similar(A), similar(y) frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ)) end From a2b694d5096edbd3dbc010be97c12cd334376b8d Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Wed, 21 Oct 2020 22:38:59 +0200 Subject: [PATCH 6/7] add PermutedDimsArray tests --- test/rulesets/LinearAlgebra/dense.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/rulesets/LinearAlgebra/dense.jl b/test/rulesets/LinearAlgebra/dense.jl index f3afa93ca..04a1b7298 100644 --- a/test/rulesets/LinearAlgebra/dense.jl +++ b/test/rulesets/LinearAlgebra/dense.jl @@ -32,6 +32,15 @@ frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ)) rrule_test(dot, randn(T), (x, x̄), (A, Abar), (y, ȳ)) end + permuteddimsarray(A) = PermutedDimsArray(A, (2,1)) + @testset "3-arg dot, $F{$T}" for T in (Float32, ComplexF32), F in (adjoint, permuteddimsarray) + M, N = 3, 4 + x, A, y = rand(T, M), F(rand(T, N, M)), rand(T, N) + ẋ, Adot, ẏ = rand(T, M), F(rand(T, N, M)), rand(T, N) + x̄, Abar, ȳ = similar(x), F(rand(T, N, M)), similar(y) + frule_test(dot, (x, ẋ), (A, Adot), (y, ẏ); rtol=1f-3) + rrule_test(dot, float(rand(T)), (x, x̄), (A, Abar), (y, ȳ); rtol=1f-3) + end end @testset "cross" begin @testset "frule" begin From 9201c60743cae102e9f399ee43b88f337b530298 Mon Sep 17 00:00:00 2001 From: Michael Abbott Date: Thu, 22 Oct 2020 23:00:28 +0200 Subject: [PATCH 7/7] v0.7.30 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 632a6f587..139c87a7a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.17" +version = "0.7.30" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"