Skip to content

Commit

Permalink
Fix implementation of fused_map_reduce (#171)
Browse files Browse the repository at this point in the history
  • Loading branch information
odow authored Nov 15, 2022
1 parent 4dc6932 commit 201aafb
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 1 deletion.
9 changes: 9 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,15 @@ function promote_operation_fallback(
return promote_operation(*, promote_operation(adjoint, A), B)
end

function promote_operation_fallback(
::typeof(LinearAlgebra.dot),
::Type{<:AbstractArray{A}},
::Type{<:AbstractArray{B}},
) where {A,B}
C = promote_operation(*, A, B)
return promote_operation(+, C, C)
end

function buffer_for(::typeof(add_dot), a::Type, b::Type, c::Type)
return buffer_for(add_mul, a, promote_operation(adjoint, b), c)
end
Expand Down
4 changes: 3 additions & 1 deletion src/reduce.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,11 @@ function promote_map_reduce(op::Function, args::Vararg{Any,N}) where {N}
)
end

_concrete_eltype(x) = isempty(x) ? eltype(x) : typeof(first(x))

function fused_map_reduce(op::F, args::Vararg{Any,N}) where {F<:Function,N}
_check_same_length(args...)
T = promote_map_reduce(op, eltype.(args)...)
T = promote_map_reduce(op, _concrete_eltype.(args)...)
accumulator = neutral_element(reduce_op(op), T)
buffer = buffer_for(op, T, eltype.(args)...)
for I in zip(eachindex.(args)...)
Expand Down
13 changes: 13 additions & 0 deletions test/dispatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,17 @@ end
# On `DummyBigInt` allocates more on previous releases of Julia
# as it's dynamically allocated
dispatch_tests(DummyBigInt)

@testset "dot non-concrete vector" begin
x = [5.0, 6.0]
y = Vector{Union{Float64,String}}(x)
@test MA.operate(LinearAlgebra.dot, x, y) == LinearAlgebra.dot(x, y)
@test MA.operate(*, x', y) == x' * y
end

@testset "dot vector of vectors" begin
x = [5.0, 6.0]
z = [x, x]
@test MA.operate(LinearAlgebra.dot, z, z) == LinearAlgebra.dot(z, z)
end
end

0 comments on commit 201aafb

Please sign in to comment.