Handle matrix times matrix = vector case #227

Merged: 8 commits into JuliaDiff:master on May 12, 2023
Conversation

dkarrasch
Contributor

This "fixes" some weird behavior in the multiplication code. It occurred in https://s3.amazonaws.com/julialang-reports/nanosoldier/pkgeval/by_hash/8a3027b_vs_960870e/PositiveFactorizations.primary.log PositiveFactorizations.jl. In their case, they want to multiply a matrix by the transpose of a row matrix (i.e., vector-like) into a vector. "By chance", this works out from the pov of dimensions, but from the pov of types having such a mul! method is weird and we may not wish to continue to support this (JuliaLang/julia#49521 (comment)). To make this package (and PositiveFactorizations.jl) run smoothly across (past and upcoming) versions, I proposed to simply catch that case and reshape the output vector to a matrix. In fact, this may even turn out to be advantageous in terms of performance, because that strange method in LinearAlgebra.jl calls generic_matmatmul!, for arguments matching the following signature:

mul!(::Vector{Float64}, ::Matrix{Float64}, ::Transpose{Float64, Matrix{Float64}}, ::Bool, ::Bool)

Once we reshape the vector to a matrix, this would go down the BLAS route!
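
For illustration, here is a minimal sketch of that idea (not the exact patch in this PR; reshaped_mul! is a hypothetical helper name):

using LinearAlgebra

# If the output is a vector but both inputs are matrices, reshape the
# output to an n×1 matrix so that mul! dispatches to the BLAS-backed
# matrix-matrix method instead of the generic fallback.
function reshaped_mul!(y::AbstractVector, A::AbstractMatrix, B::AbstractMatrix)
    # reshape shares memory with y, so writing into the n×1 matrix
    # fills the original vector in place.
    mul!(reshape(y, length(y), 1), A, B)
    return y
end

A = rand(3, 4)
B = transpose(rand(1, 4))  # 4×1: vector-like by shape, but a matrix by type
y = reshaped_mul!(zeros(3), A, B)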

@codecov

codecov bot commented Apr 28, 2023

Codecov Report

Patch coverage: 100.00%; project coverage change: -3.05% ⚠️

Comparison is base (f1f3d1f) 84.51% compared to head (90f676b) 81.46%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master     #227      +/-   ##
==========================================
- Coverage   84.51%   81.46%   -3.05%     
==========================================
  Files          18       18              
  Lines        1924     1581     -343     
==========================================
- Hits         1626     1288     -338     
+ Misses        298      293       -5     
Impacted Files                          Coverage Δ
src/derivatives/linalg/arithmetic.jl    72.15% <100.00%> (+1.02%) ⬆️

... and 16 files with indirect coverage changes


@dkarrasch
Contributor Author

Gentle bump.

@devmotion
Member

I didn't quite get why exactly this method is needed. Shouldn't the re-routing happen in LinearAlgebra? Why is special handling needed in ReverseDiff?

In any case, it seems the code is not covered by tests yet.

@dkarrasch
Contributor Author

So here's the full stacktrace. It shows that the output is allocated by ReverseDiff, and then LinearAlgebra takes what it gets:

MethodError: no method matching mul!(::Vector{Float64}, ::Matrix{Float64}, ::Transpose{Float64, Matrix{Float64}}, ::Bool, ::Bool)
  
  Closest candidates are:
    mul!(::StridedVector{T}, ::StridedVecOrMat{T}, !Matched::StridedVector{T}, ::Number, ::Number) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
     @ LinearAlgebra /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:66
    mul!(!Matched::StridedMatrix{T}, ::StridedVecOrMat{T}, ::Transpose{<:Any, <:StridedVecOrMat{T}}, ::Number, ::Number) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
     @ LinearAlgebra /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:369
    mul!(!Matched::StridedMatrix{T}, ::Union{Adjoint{<:Any, <:StridedVecOrMat{T}}, Transpose{<:Any, <:StridedVecOrMat{T}}, StridedMatrix{T}, StridedVector{T}}, ::Union{Adjoint{<:Any, <:StridedVecOrMat{T}}, Transpose{<:Any, <:StridedVecOrMat{T}}, StridedMatrix{T}, StridedVector{T}}, ::Number, ::Number) where T<:Union{Float32, Float64, ComplexF64, ComplexF32}
     @ LinearAlgebra /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:356
    ...
  
  Stacktrace:
    [1] mul!
      @ /opt/julia/share/julia/stdlib/v1.10/LinearAlgebra/src/matmul.jl:251 [inlined]
    [2] reverse_mul!(output::ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, output_deriv::Matrix{Float64}, a::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, b::Matrix{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}, a_tmp::Vector{Float64}, b_tmp::Matrix{Float64})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/derivatives/linalg/arithmetic.jl:273
    [3] special_reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(*), Tuple{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Matrix{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{Vector{Float64}, Matrix{Float64}}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/derivatives/linalg/arithmetic.jl:265
    [4] reverse_exec!(instruction::ReverseDiff.SpecialInstruction{typeof(*), Tuple{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, Matrix{ReverseDiff.TrackedReal{Float64, Float64, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}}}}, ReverseDiff.TrackedArray{Float64, Float64, 2, Matrix{Float64}, Matrix{Float64}}, Tuple{Vector{Float64}, Matrix{Float64}}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/tape.jl:93
    [5] reverse_pass!(tape::Vector{ReverseDiff.AbstractInstruction})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/tape.jl:87
    [6] reverse_pass!
      @ ~/.julia/packages/ReverseDiff/Zu4v6/src/api/tape.jl:36 [inlined]
    [7] seeded_reverse_pass!(result::Vector{Float64}, output::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, input::ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, tape::ReverseDiff.GradientTape{var"#8#15"{DataType, var"#5#12"}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/utils.jl:31
    [8] seeded_reverse_pass!(result::Vector{Float64}, t::ReverseDiff.GradientTape{var"#8#15"{DataType, var"#5#12"}, ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/tape.jl:47
    [9] gradient(f::Function, input::Vector{Float64}, cfg::ReverseDiff.GradientConfig{ReverseDiff.TrackedArray{Float64, Float64, 1, Vector{Float64}, Vector{Float64}}})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/gradients.jl:24
   [10] gradient(f::Function, input::Vector{Float64})
      @ ReverseDiff ~/.julia/packages/ReverseDiff/Zu4v6/src/api/gradients.jl:22

I'm not familiar enough with AD to set up a test that goes down this route. As I wrote, it came up in PositiveFactorizations.jl when differentiating through cholesky or something similar, so it is not easily reproducible here.

> In any case, it seems the code is not covered by tests yet.

True, there's quite a lot uncovered in that code area. Any help with that would be much appreciated.
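
For reference, here is a hypothetical direct reproduction of the MethodError above (outside of AD), assuming a Julia version in which JuliaLang/julia#49521 has landed and the vector-output method is gone:

using LinearAlgebra

y  = zeros(3)               # output allocated as a Vector, as in the stacktrace
A  = rand(3, 4)
Bt = transpose(rand(1, 4))  # Transpose{Float64, Matrix{Float64}}

# MethodError: no method matching mul!(::Vector{Float64}, ::Matrix{Float64},
#                                      ::Transpose{Float64, Matrix{Float64}}, ::Bool, ::Bool)
mul!(y, A, Bt, true, true)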

@dkarrasch
Contributor Author

dkarrasch commented May 4, 2023

If you wish, there is a test in PositiveFactorizations.jl ("downstream") that covers this. Not satisfactory, but given that reverse_mul! is already largely uncovered, maybe acceptable?

EDIT: I think I found a test case.

@devmotion (Member) left a comment

I checked and it seems that ReverseDiff computes gradients for these functions even on the master branch? Are you sure the test is correct?

Resolved review threads (outdated): test/api/GradientTests.jl, test/derivatives/LinAlgTests.jl (2)
@dkarrasch
Contributor Author

I think it no longer does after I merged JuliaLang/julia#49521.

@devmotion
Member

It seems that this PR introduces method ambiguity issues, which cause test failures once collect is removed in the tests.

Can you fix these and also add a version of norm_hermitian with collect (but keep the one without as well, since it seems to cover some of the method ambiguity issues)?

If I can find some time, I'll also check whether the initial problem can be fixed in some other way, since adding new dispatches evidently is a bit problematic.

@dkarrasch
Contributor Author

Interestingly, the ambiguities do not occur on v1.8, only on very old Julia versions. I have now avoided adding another method and instead introduced branches in the existing one. Those branches, however, should be compiled away, because the conditions can be checked in the type domain.

BTW, this package has more than 3k ambiguities! My computer almost crashed when I ran Aqua.jl on it.
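
A minimal sketch of what such a type-domain branch can look like (mul_into! is a hypothetical name, not the actual reverse_mul! code):

using LinearAlgebra

function mul_into!(output, A::AbstractMatrix, B::AbstractMatrix)
    if output isa AbstractVector
        # matrix-matrix product with a vector output: reshape to n×1
        mul!(reshape(output, length(output), 1), A, B)
    else
        mul!(output, A, B)
    end
    return output
end

Since output isa AbstractVector is decidable from the concrete argument type alone, the compiler can prune the dead branch, so the check costs nothing at run time.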

Resolved review thread (outdated): test/api/GradientTests.jl
dkarrasch and others added 2 commits May 8, 2023 15:18
@devmotion devmotion merged commit d70ba91 into JuliaDiff:master May 12, 2023
@devmotion
Member

Thank you @dkarrasch!
