Merge pull request #627 from finch-tensor/wma/backwards-fusion
Wma/backwards fusion
willow-ahrens authored Oct 25, 2024
2 parents 4795478 + 968d4eb commit 634e6d1
Showing 5 changed files with 128 additions and 8 deletions.
30 changes: 28 additions & 2 deletions benchmark/benchmarks.jl
@@ -56,7 +56,33 @@ let
A, x = ($A, $x)
@einsum y[i] += A[i, j] * x[j]
end,
seconds = 10.0 #Bug in benchmarktools, will be fixed soon.
)
end

let
N = 1_000
K = 1_000
p = 0.001
A = Tensor(Dense(Dense(Element(0.0))), rand(N, K))
B = Tensor(Dense(Dense(Element(0.0))), rand(K, N))
M = Tensor(Dense(SparseList(Element(0.0))), fsprand(N, N, p))

SUITE["high-level"]["sddmm_fused"] = @benchmarkable(
begin
M = lazy($M)
A = lazy($A)
B = lazy($B)
compute(M .* (A * B))
end,
)

SUITE["high-level"]["sddmm_unfused"] = @benchmarkable(
begin
M = $M
A = $A
B = $B
M .* (A * B)
end,
)
end

@@ -241,4 +267,4 @@ for (key, mtx) in [
x = Tensor(Dense{Int64}(Element{0.0, Float64, Int64}()), rand(size(A)[2]))
SUITE["parallel"]["SpMV_serial"][key] = @benchmarkable spmv_serial($A, $x)
SUITE["parallel"]["SpMV_threaded"][key] = @benchmarkable spmv_threaded($A, $x)
end
end
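
A minimal usage sketch, not part of this commit, of what the fused and unfused SDDMM benchmarks measure: wrapping the operands in lazy lets the scheduler see the mask and the matrix product as one expression and fuse them into a single kernel, while the eager version materializes the dense product of A and B before applying the mask. Sizes and density below are illustrative.

using Finch

N, K, p = 1_000, 1_000, 0.001
A = Tensor(Dense(Dense(Element(0.0))), rand(N, K))
B = Tensor(Dense(Dense(Element(0.0))), rand(K, N))
M = Tensor(Dense(SparseList(Element(0.0))), fsprand(N, N, p))

fused = compute(lazy(M) .* (lazy(A) * lazy(B)))  # one fused SDDMM kernel
unfused = M .* (A * B)                           # dense A * B materialized first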
15 changes: 14 additions & 1 deletion src/interface/eager.jl
@@ -49,7 +49,7 @@ function Base.reduce(op::Function, bc::Broadcasted{FinchStyle{N}}; dims=:, init
end
end

function tensordot(A::AbstractTensor, B::AbstractTensor, idxs; kw...)
function tensordot(A::Union{AbstractTensor, AbstractArray}, B::Union{AbstractTensor, AbstractArray}, idxs; kw...)
compute(tensordot(lazy(A), lazy(B), idxs; kw...))
end

@@ -79,6 +79,19 @@ Base.:*(
z::Number...
) = map(*, y, x, z...)

Base.:*(
A::AbstractTensor,
B::Union{AbstractTensor, AbstractArray}
) = tensordot(A, B, (2, 1))
Base.:*(
A::Union{AbstractTensor, AbstractArray},
B::AbstractTensor
) = tensordot(A, B, (2, 1))
Base.:*(
A::AbstractTensor,
B::AbstractTensor
) = tensordot(A, B, (2, 1))

Base.:-(x::AbstractTensor) = map(-, x)

Base.:-(x::AbstractTensor, y::Union{Base.AbstractArrayOrBroadcasted, Number}) = map(-, x, y)
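
A brief sketch, not part of this commit, of what the new eager * methods enable: * between a Finch tensor and another tensor or plain array now lowers to tensordot over the (2, 1) index pair, i.e. an ordinary matrix product. Values are illustrative.

using Finch

C = Tensor(Dense(SparseList(Element(0.0))), fsprand(5, 5, 0.5))
D = Tensor(Dense(SparseList(Element(0.0))), fsprand(5, 5, 0.5))
x = rand(5)

y = C * x   # Finch tensor times a plain vector, via tensordot(C, x, (2, 1))
Z = C * D   # Finch tensor times Finch tensor, an ordinary matrix product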
16 changes: 16 additions & 0 deletions src/interface/lazy.jl
@@ -133,6 +133,9 @@ function Base.reduce(op, arg::LazyTensor{T, N}; dims=:, init = initial_value(op,
LazyTensor{S}(identify(data), extrude, init)
end

tensordot(A::LazyTensor, B::Union{AbstractTensor, AbstractArray}, idxs; kwargs...) = tensordot(A, LazyTensor(B), idxs; kwargs...)
tensordot(A::Union{AbstractTensor, AbstractArray}, B::LazyTensor, idxs; kwargs...) = tensordot(LazyTensor(A), B, idxs; kwargs...)

# tensordot takes in two tensors `A` and `B` and performs a product and contraction
function tensordot(A::LazyTensor{T1, N1}, B::LazyTensor{T2, N2}, idxs; mult_op=*, add_op=+, init = initial_value(add_op, return_type(DefaultAlgebra(), mult_op, T1, T2))) where {T1, T2, N1, N2}
if idxs isa Number
@@ -272,6 +275,19 @@ Base.:*(
z::Number...
) = map(*, y, x, z...)

Base.:*(
A::LazyTensor,
B::Union{LazyTensor, AbstractTensor, AbstractArray}
) = tensordot(A, B, (2, 1))
Base.:*(
A::Union{LazyTensor, AbstractTensor, AbstractArray},
B::LazyTensor
) = tensordot(A, B, (2, 1))
Base.:*(
A::LazyTensor,
B::LazyTensor
) = tensordot(A, B, (2, 1))

Base.:-(x::LazyTensor) = map(-, x)

Base.:-(x::LazyTensor, y::Union{LazyTensor, AbstractTensor, Base.AbstractArrayOrBroadcasted, Number}) = map(-, x, y)
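
A short sketch, not part of this commit, of the lazy side: the promotion methods wrap an eager operand in a LazyTensor, so a mixed product joins a single lazy plan and nothing is materialized until compute is called.

using Finch

C = Tensor(Dense(SparseList(Element(0.0))), fsprand(5, 5, 0.5))
D = Tensor(Dense(SparseList(Element(0.0))), fsprand(5, 5, 0.5))

plan = lazy(C) * D      # D is promoted to a LazyTensor automatically
result = compute(plan)  # materializes the product C * D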
48 changes: 47 additions & 1 deletion src/scheduler/optimize.jl
@@ -83,7 +83,8 @@ defined in the following grammar:
EXPR := LEAF |
reorder(EXPR, FIELD...) |
relabel(EXPR, FIELD...) |
mapjoin(IMMEDIATE, EXPR...)
mapjoin(IMMEDIATE, EXPR...) |
aggregate(IMMEDIATE, IMMEDIATE, EXPR, FIELD...)
```
Pushes all reorder and relabel statements down to LEAF nodes of each EXPR.
@@ -97,6 +98,12 @@ function push_fields(root)
idxs_2 = getfields(mapjoin(op, args...))
mapjoin(op, map(arg -> relabel(reorder(arg, idxs_2...), idxs...), args)...)
end),
(@rule relabel(aggregate(~op, ~init, ~arg, ~idxs...), ~idxs_2...) => begin
idxs_3 = setdiff(getfields(arg), idxs)
reidx = Dict(map(Pair, idxs_3, idxs_2)...)
idxs_4 = map(idx -> get(reidx, idx, idx), getfields(arg))
aggregate(op, init, relabel(arg, idxs_4...), idxs...)
end),
(@rule relabel(relabel(~arg, ~idxs...), ~idxs_2...) =>
relabel(~arg, ~idxs_2...)),
(@rule relabel(reorder(~arg, ~idxs_1...), ~idxs_2...) => begin
@@ -113,6 +120,12 @@
root = Rewrite(Prewalk(Fixpoint(Chain([
(@rule reorder(mapjoin(~op, ~args...), ~idxs...) =>
mapjoin(op, map(arg -> reorder(arg, ~idxs...), args)...)),
(@rule reorder(aggregate(~op, ~init, ~arg, ~idxs...), ~idxs_2...) => begin
idxs_3 = setdiff(getfields(arg), idxs)
reidx = Dict(map(Pair, idxs_3, idxs_2)...)
idxs_4 = map(idx -> get(reidx, idx, idx), getfields(arg))
aggregate(op, init, reorder(arg, idxs_4...), idxs...)
end),
(@rule reorder(reorder(~arg, ~idxs...), ~idxs_2...) =>
reorder(~arg, ~idxs_2...)),
]))))(root)
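
A worked illustration, not part of this commit, of the new relabel-over-aggregate rule: if arg has fields (i, j, k) and the aggregate reduces over k, only the surviving fields (i, j) are renamed by the outer relabel, so

relabel(aggregate(op, init, arg, k), i2, j2)

rewrites to

aggregate(op, init, relabel(arg, i2, j2, k), k)

pushing the relabel beneath the reduction. The reorder-over-aggregate rule added just above performs the same field bookkeeping for reorders.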
@@ -405,6 +418,37 @@ function propagate_map_queries(root)
])))(root)
end

function propagate_map_queries_backward(root)
root = Rewrite(Postwalk(@rule aggregate(~op, ~init, ~arg) => mapjoin(op, init, arg)))(root)
uses = Dict{LogicNode, Int}()
defs = Dict{LogicNode, LogicNode}()
rets = getproductions(root)
for node in PostOrderDFS(root)
if node.kind === alias
uses[node] = get(uses, node, 0) + 1
elseif @capture node query(~a, ~b)
uses[a] = get(uses, a, 0) - 1
defs[a] = b
end
end
root = Rewrite(Prewalk(Chain([
(@rule query(~a, ~b) => if uses[a] == 1 && !(a in rets) plan() end),
(@rule ~a => if get(uses, a, 0) == 1 && !(a in rets) get(defs, a, a) end)
])))(root)
root = push_fields(root)
root = Rewrite(Prewalk(Chain([
(@rule mapjoin(~f::isimmediate, ~a1..., aggregate(~g::isimmediate, ~init::isimmediate, ~arg, ~idxs...), ~a2...) => begin
if isdistributive(DefaultAlgebra(), literal(g.val), literal(f.val)) &&
isannihilator(DefaultAlgebra(), literal(f.val), literal(init.val)) &&
length(getfields(aggregate(g, init, arg, idxs...))) ==
length(getfields(mapjoin(f, a1..., a2...)))
aggregate(g, init, mapjoin(f, a1..., arg, a2...), idxs...)
end
end),
])))(root)
root
end

function normalize_names(ex)
spc = Namespace()
scope = Dict()
@@ -541,6 +585,8 @@ function optimize(prgm)
#At this point in the program, all statements should be unique, so
#it is okay to name different occurrences of things.

prgm = propagate_map_queries_backward(prgm)

#these steps lift reformat, aggregate, and table nodes into separate
#queries, using subqueries as temporaries.
prgm = isolate_reformats(prgm)
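A minimal sketch, not part of this commit, of the rewrite at the heart of propagate_map_queries_backward: a mapjoin applied to an aggregate is pushed inside the reduction. For the SDDMM pattern benchmarked above,

mapjoin(*, M, aggregate(+, 0.0, mapjoin(*, A, B), k))

becomes

aggregate(+, 0.0, mapjoin(*, M, mapjoin(*, A, B)), k)

provided the map operator distributes over the reduction operator, the reduction's init annihilates the map operator, and the remaining map arguments range over the same number of fields as the reduced aggregate. The plain-Julia check below illustrates the two algebraic conditions for the * over + case.

x, a, b = 2.0, 3.0, 5.0
@assert x * (a + b) == x * a + x * b   # * distributes over +
@assert x * 0.0 == 0.0                 # 0.0 annihilates *
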
27 changes: 23 additions & 4 deletions test/test_issues.jl
@@ -945,10 +945,29 @@ using SparseArrays

#https://github.com/finch-tensor/Finch.jl/issues/615

A = Tensor(Dense(Dense(Element(0.0))), 10, 10)
res = sum(tensordot(A, A, ((1,), (2,))))
let
A = Tensor(Dense(Dense(Element(0.0))), 10, 10)
res = sum(tensordot(A, A, ((1,), (2,))))

A_lazy = Finch.LazyTensor(A)
res = sum(tensordot(A_lazy, A_lazy, ((1,), (2,)))) # fails
end

#https://github.com/finch-tensor/Finch.jl/issues/614

let
A = sprand(5, 5, 0.5)
B = sprand(5, 5, 0.5)
x = rand(5)
C = Tensor(Dense(SparseList(Element(0.0))), A)
D = Tensor(Dense(SparseList(Element(0.0))), B)

@test A * B == C * D
@test A * B == compute(lazy(C) * D)
@test A * B == compute(C * lazy(D))
@test A * x == C * x
@test A * x == compute(lazy(C) * x)
end

A_lazy = Finch.LazyTensor(A)
res = sum(tensordot(A_lazy, A_lazy, ((1,), (2,)))) # fails

end
