From 09499f8ce5ccf320f993e932221797250e6dd0fa Mon Sep 17 00:00:00 2001 From: Willow Ahrens Date: Fri, 25 Oct 2024 14:48:32 -0400 Subject: [PATCH] quick fix --- src/scheduler/optimize.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/scheduler/optimize.jl b/src/scheduler/optimize.jl index 9a479a715..93c7b84c4 100644 --- a/src/scheduler/optimize.jl +++ b/src/scheduler/optimize.jl @@ -121,10 +121,9 @@ function push_fields(root) (@rule reorder(mapjoin(~op, ~args...), ~idxs...) => mapjoin(op, map(arg -> reorder(arg, ~idxs...), args)...)), (@rule reorder(aggregate(~op, ~init, ~arg, ~idxs...), ~idxs_2...) => begin - idxs_3 = setdiff(getfields(arg), idxs) - reidx = Dict(map(Pair, idxs_3, idxs_2)...) - idxs_4 = map(idx -> get(reidx, idx, idx), getfields(arg)) - aggregate(op, init, reorder(arg, idxs_4...), idxs...) + #TODO it should be correct to write this, but subsequent phases interpret singleton dimensions as canonical ones when we do this. + #aggregate(op, init, reorder(arg, idxs_2..., idxs...), idxs...) + aggregate(op, init, reorder(arg, intersect(getfields(arg), idxs_2)..., idxs...), idxs...) end), (@rule reorder(reorder(~arg, ~idxs...), ~idxs_2...) => reorder(~arg, ~idxs_2...)),