diff --git a/src/implementations/LinearAlgebra.jl b/src/implementations/LinearAlgebra.jl index 7d6dd3e..00b301f 100644 --- a/src/implementations/LinearAlgebra.jl +++ b/src/implementations/LinearAlgebra.jl @@ -385,6 +385,12 @@ function operate( A::AbstractMatrix{S}, B::AbstractVector{T}, ) where {T,S} + # Only use the efficient in-place operate_to! if both arrays are + # concrete. Bad things can happen if S or T is abstract and we pick the + # wrong type for C. + if !(isconcretetype(S) && isconcretetype(T)) + return A * B + end C = undef_array(promote_array_mul(typeof(A), typeof(B)), axes(A, 1)) return operate_to!(C, *, A, B) end @@ -394,6 +400,12 @@ function operate( A::AbstractMatrix{S}, B::AbstractMatrix{T}, ) where {T,S} + # Only use the efficient in-place operate_to! if both arrays are + # concrete. Bad things can happen if S or T is abstract and we pick the + # wrong type for C. + if !(isconcretetype(S) && isconcretetype(T)) + return A * B + end C = undef_array( promote_array_mul(typeof(A), typeof(B)), axes(A, 1), diff --git a/test/matmul.jl b/test/matmul.jl index bacbbef..51c8027 100644 --- a/test/matmul.jl +++ b/test/matmul.jl @@ -335,3 +335,68 @@ end LinearAlgebra.mul!(ret, A, B) @test ret == A * B end + +@testset "Abstract eltype in matmul" begin + # Test that we don't initialize the output with zero(T), which might not + # exist. + for M in (Matrix, LinearAlgebra.Diagonal) + for T in (Any, Union{String,Int}) + x, x12, x22 = T[1, 2], T[1 2], M([1 2; 3 4]) + @test MA.operate(*, x, x') ≈ x * x' + @test MA.operate(*, x', x) ≈ x' * x + @test MA.operate(*, x12, x) ≈ x12 * x + @test MA.operate(*, x22, x) ≈ x22 * x + @test MA.operate(*, x', x22) ≈ x' * x22 + @test MA.operate(*, x12, x22) ≈ x12 * x22 + @test MA.operate(*, x22, x22) ≈ x22 * x22 + y = M([1.1 1.2; 1.3 1.4]) + @test MA.operate(*, y, x) ≈ y * x + @test MA.operate(*, x', y) ≈ x' * y + @test MA.operate(*, y, x12') ≈ y * x12' + @test MA.operate(*, x12, y) ≈ x12 * y + @test MA.operate(*, x22, y) ≈ x22 * y + @test MA.operate(*, y, x22) ≈ y * x22 + end + end + for T in (Any, Union{String,Int}) + x, x12, x22 = T[1, 2], T[1 2], LinearAlgebra.LowerTriangular([1 2; 3 4]) + @test MA.operate(*, x, x') ≈ x * x' + @test MA.operate(*, x', x) ≈ x' * x + @test MA.operate(*, x12, x) ≈ x12 * x + @test MA.operate(*, x22, x22) ≈ x22 * x22 + y = LinearAlgebra.LowerTriangular([1.1 1.2; 1.3 1.4]) + @test MA.operate(*, x22, y) ≈ x22 * y + @test MA.operate(*, y, x22) ≈ y * x22 + # TODO(odow): These tests are broken because `Base` is also broken. + # Although it fixed y * x12' in Julia v1.9.0. + # @test_broken MA.operate(*, x22, x) ≈ x22 * x + # @test_broken MA.operate(*, x', x22) ≈ x' * x22 + # @test_broken MA.operate(*, x12, x22) ≈ x12 * x22 + # @test_broken MA.operate(*, y, x) ≈ y * x + # @test_broken MA.operate(*, x', y) ≈ x' * y + # @test_broken MA.operate(*, y, x12') ≈ y * x12' + # @test_broken MA.operate(*, x12, y) ≈ x12 * y + end +end + +@testset "Union{Int,Float64} eltype in matmul" begin + # Test that we don't initialize the output with zero(Int), either by taking + # the first available type in the union, or by looking at the first element + # in the array. + T = Union{Int,Float64} + x, x12, x22 = T[1, 2.5], T[1 2.5], T[1 2.5; 3.5 4] + @test MA.operate(*, x, x') == x * x' + @test MA.operate(*, x', x) == x' * x + @test MA.operate(*, x12, x) == x12 * x + @test MA.operate(*, x22, x) == x22 * x + @test MA.operate(*, x', x22) == x' * x22 + @test MA.operate(*, x12, x22) == x12 * x22 + @test MA.operate(*, x22, x22) == x22 * x22 + y = [1.1 1.2; 1.3 1.4] + @test MA.operate(*, y, x) == y * x + @test MA.operate(*, x', y) == x' * y + @test MA.operate(*, y, x12') == y * x12' + @test MA.operate(*, x12, y) == x12 * y + @test MA.operate(*, x22, y) == x22 * y + @test MA.operate(*, y, x22) == y * x22 +end