Skip to content

Commit

Permalink
Merge pull request #628 from finch-tensor/wma/backwards-fusion
Browse files Browse the repository at this point in the history
quick fix
  • Loading branch information
willow-ahrens authored Oct 25, 2024
2 parents 634e6d1 + 0e375e9 commit d47c3ef
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions src/scheduler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,10 +121,9 @@ function push_fields(root)
(@rule reorder(mapjoin(~op, ~args...), ~idxs...) =>
mapjoin(op, map(arg -> reorder(arg, ~idxs...), args)...)),
(@rule reorder(aggregate(~op, ~init, ~arg, ~idxs...), ~idxs_2...) => begin
idxs_3 = setdiff(getfields(arg), idxs)
reidx = Dict(map(Pair, idxs_3, idxs_2)...)
idxs_4 = map(idx -> get(reidx, idx, idx), getfields(arg))
aggregate(op, init, reorder(arg, idxs_4...), idxs...)
#TODO it should be correct to write this, but subsequent phases interpret singleton dimensions as canonical ones when we do this.
#aggregate(op, init, reorder(arg, idxs_2..., idxs...), idxs...)
aggregate(op, init, reorder(arg, intersect(getfields(arg), idxs_2)..., idxs...), idxs...)
end),
(@rule reorder(reorder(~arg, ~idxs...), ~idxs_2...) =>
reorder(~arg, ~idxs_2...)),
Expand Down

0 comments on commit d47c3ef

Please sign in to comment.