From 0dc03ee06759e638fa3b86b9b94f02f0baf16efb Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Wed, 6 Mar 2024 11:17:59 -0800 Subject: [PATCH] Handle loads of broadcasts in FlattenNestedRamps With sufficiently perverse schedules, it's possible to end up with a load of a broadcast index (rather than a broadcast of a scalar load). This made FlattenNestedRamps divide by zero. Unfortunately this happened in a complex production pipeline, so I'm not entirely sure how to reproduce it. For that pipeline, this change fixes it and produces correct output. --- src/FlattenNestedRamps.cpp | 42 +++++++++++++++++++++++--------------- 1 file changed, 26 insertions(+), 16 deletions(-) diff --git a/src/FlattenNestedRamps.cpp b/src/FlattenNestedRamps.cpp index f48bd75c37a2..92bcf3870d5d 100644 --- a/src/FlattenNestedRamps.cpp +++ b/src/FlattenNestedRamps.cpp @@ -81,19 +81,19 @@ class FlattenRamps : public IRMutator { // If they are, we'll have a full vector of const_indices if ((int)const_indices.size() == lanes) { - // Compute the stride for the underlying strided load - int stride = 0; - for (int c : const_indices) { - stride = (int)gcd(stride, c); - } - for (int &c : const_indices) { - c /= stride; + int stride = 0, extent = 1; + if (max_constant_offset > 0) { + for (int c : const_indices) { + stride = (int)gcd(stride, c); + } + for (int &c : const_indices) { + c /= stride; + } + // Compute the number of elements loaded + extent = (int)((max_constant_offset / stride) + 1); } - // Compute the number of elements loaded - int extent = (int)((max_constant_offset / stride) + 1); - // If we're gathering from a very large range, it // might be better to just do the gather rather than // doing a big dense load and then shuffling. We @@ -105,12 +105,22 @@ class FlattenRamps : public IRMutator { // in the schedule somehow. const int max_unused_lane_factor = 4; if (extent < max_unused_lane_factor * lanes) { - Expr dense_index = Ramp::make(min_lane, make_const(min_lane.type(), stride), extent); - Expr dense_load = - Load::make(op->type.with_lanes(extent), op->name, dense_index, - op->image, op->param, - const_true(extent), ModulusRemainder{}); - return Shuffle::make({dense_load}, const_indices); + if (max_constant_offset == 0) { + // It's a load of a broadcast. Convert it to a broadcast of a load + Expr load = Load::make(op->type.element_of(), op->name, min_lane, + op->image, op->param, + const_true(), ModulusRemainder{}); + return Broadcast::make(load, lanes); + } else { + // Turn it into a dense load and a shuffle + Expr dense_index = + Ramp::make(min_lane, make_const(min_lane.type(), stride), extent); + Expr dense_load = + Load::make(op->type.with_lanes(extent), op->name, dense_index, + op->image, op->param, + const_true(extent), ModulusRemainder{}); + return Shuffle::make({dense_load}, const_indices); + } } } }