Skip to content

Commit

Permalink
Matrix(Diagonal)
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Aug 21, 2022
1 parent 7dd6ede commit f885295
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 0 deletions.
3 changes: 3 additions & 0 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,9 @@ end
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))

# fix Matrix(Diagonal(param([1,2,3]))) after https://github.com/JuliaLang/julia/pull/44615
(::Type{Matrix})(d::Diagonal{<:Any,<:TrackedArray}) = diagm(0 => d.diag)

x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
Expand Down
1 change: 1 addition & 0 deletions test/tracker.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@ end
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))

@test gradtest(x -> diagm(0 => x), rand(3))
@test gradtest(x -> Matrix(Diagonal(x)), rand(3))

@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))
Expand Down

0 comments on commit f885295

Please sign in to comment.