Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Handle loads of broadcasts in FlattenNestedRamps #8139

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 26 additions & 16 deletions src/FlattenNestedRamps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ class FlattenRamps : public IRMutator {

// If they are, we'll have a full vector of const_indices
if ((int)const_indices.size() == lanes) {

// Compute the stride for the underlying strided load
int stride = 0;
for (int c : const_indices) {
stride = (int)gcd(stride, c);
}
for (int &c : const_indices) {
c /= stride;
int stride = 0, extent = 1;
if (max_constant_offset > 0) {
for (int c : const_indices) {
stride = (int)gcd(stride, c);
}
for (int &c : const_indices) {
c /= stride;
}
// Compute the number of elements loaded
extent = (int)((max_constant_offset / stride) + 1);
}

// Compute the number of elements loaded
int extent = (int)((max_constant_offset / stride) + 1);

// If we're gathering from a very large range, it
// might be better to just do the gather rather than
// doing a big dense load and then shuffling. We
Expand All @@ -105,12 +105,22 @@ class FlattenRamps : public IRMutator {
// in the schedule somehow.
const int max_unused_lane_factor = 4;
if (extent < max_unused_lane_factor * lanes) {
Expr dense_index = Ramp::make(min_lane, make_const(min_lane.type(), stride), extent);
Expr dense_load =
Load::make(op->type.with_lanes(extent), op->name, dense_index,
op->image, op->param,
const_true(extent), ModulusRemainder{});
return Shuffle::make({dense_load}, const_indices);
if (max_constant_offset == 0) {
// It's a load of a broadcast. Convert it to a broadcast of a load
Expr load = Load::make(op->type.element_of(), op->name, min_lane,
op->image, op->param,
const_true(), ModulusRemainder{});
return Broadcast::make(load, lanes);
} else {
// Turn it into a dense load and a shuffle
Expr dense_index =
Ramp::make(min_lane, make_const(min_lane.type(), stride), extent);
Expr dense_load =
Load::make(op->type.with_lanes(extent), op->name, dense_index,
op->image, op->param,
const_true(extent), ModulusRemainder{});
return Shuffle::make({dense_load}, const_indices);
}
}
}
}
Expand Down
Loading