From f885295f0c7dceaf6d96b5e8d9847d458c84d057 Mon Sep 17 00:00:00 2001 From: Michael Abbott <32575566+mcabbott@users.noreply.github.com> Date: Sun, 21 Aug 2022 11:05:16 -0400 Subject: [PATCH] Matrix(Diagonal) --- src/lib/array.jl | 3 +++ test/tracker.jl | 1 + 2 files changed, 4 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index 8daca8c..c69f17b 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -420,6 +420,9 @@ end LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...) @grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i)) +# fix Matrix(Diagonal(param([1,2,3]))) after https://github.com/JuliaLang/julia/pull/44615 +(::Type{Matrix})(d::Diagonal{<:Any,<:TrackedArray}) = diagm(0 => d.diag) + x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y) x::TrackedMatrix * y::TrackedMatrix = track(*, x, y) diff --git a/test/tracker.jl b/test/tracker.jl index dbd2a41..97dddcc 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -172,6 +172,7 @@ end @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(x -> diagm(0 => x), rand(3)) +@test gradtest(x -> Matrix(Diagonal(x)), rand(3)) @test gradtest(W -> inv(log.(W * W)), (5,5)) @test gradtest((A, B) -> A / B , (1,5), (5,5))