diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 1c3ec57f3fb7..6d10d2e9d5f3 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -134,7 +134,7 @@ Interval bounds_of_lanes(const Expr &e) { Interval ia = bounds_of_lanes(not_->a); return {!ia.max, !ia.min}; } else if (const Ramp *r = e.as<Ramp>()) { - Expr last_lane_idx = make_const(r->base.type(), r->lanes - 1); + Expr last_lane_idx = make_const(r->base.type().element_of(), r->lanes - 1); Interval ib = bounds_of_lanes(r->base); const Broadcast *b = as_scalar_broadcast(r->stride); Expr stride = b ? b->value : r->stride; @@ -875,6 +875,7 @@ class VectorSubs : public IRMutator { // generating a scalar condition that checks if // the least-true lane is true. Expr all_true = bounds_of_lanes(likely->args[0]).min; + internal_assert(all_true.type() == Bool()); // Wrap it in the same flavor of likely all_true = Call::make(Bool(), likely->name, {all_true}, Call::PureIntrinsic); diff --git a/test/correctness/fuzz_schedule.cpp b/test/correctness/fuzz_schedule.cpp index 9f0f86e3854b..a774335a07bf 100644 --- a/test/correctness/fuzz_schedule.cpp +++ b/test/correctness/fuzz_schedule.cpp @@ -202,6 +202,74 @@ int main(int argc, char **argv) { check_blur_output(buf, correct); } + // https://github.com/halide/Halide/issues/8054 + { + ImageParam input(Float(32), 2, "input"); + const float r_sigma = 0.1; + const int s_sigma = 8; + Func bilateral_grid{"bilateral_grid"}; + + Var x("x"), y("y"), z("z"), c("c"); + + // Add a boundary condition + Func clamped = Halide::BoundaryConditions::repeat_edge(input); + + // Construct the bilateral grid + RDom r(0, s_sigma, 0, s_sigma); + Expr val = clamped(x * s_sigma + r.x - s_sigma / 2, y * s_sigma + r.y - s_sigma / 2); + val = clamp(val, 0.0f, 1.0f); + + Expr zi = cast<int>(val * (1.0f / r_sigma) + 0.5f); + + Func histogram("histogram"); + histogram(x, y, z, c) = 0.0f; + histogram(x, y, zi, c) += mux(c, {val, 1.0f}); + + // Blur the grid using a five-tap filter + Func blurx("blurx"), blury("blury"), blurz("blurz"); + blurz(x, y, z, c) = (histogram(x, y, z - 2, c) + + histogram(x, y, z - 1, c) * 4 + + histogram(x, y, z, c) * 6 + + histogram(x, y, z + 1, c) * 4 + + histogram(x, y, z + 2, c)); + blurx(x, y, z, c) = (blurz(x - 2, y, z, c) + + blurz(x - 1, y, z, c) * 4 + + blurz(x, y, z, c) * 6 + + blurz(x + 1, y, z, c) * 4 + + blurz(x + 2, y, z, c)); + blury(x, y, z, c) = (blurx(x, y - 2, z, c) + + blurx(x, y - 1, z, c) * 4 + + blurx(x, y, z, c) * 6 + + blurx(x, y + 1, z, c) * 4 + + blurx(x, y + 2, z, c)); + + // Take trilinear samples to compute the output + val = clamp(input(x, y), 0.0f, 1.0f); + Expr zv = val * (1.0f / r_sigma); + zi = cast<int>(zv); + Expr zf = zv - zi; + Expr xf = cast<float>(x % s_sigma) / s_sigma; + Expr yf = cast<float>(y % s_sigma) / s_sigma; + Expr xi = x / s_sigma; + Expr yi = y / s_sigma; + Func interpolated("interpolated"); + interpolated(x, y, c) = + lerp(lerp(lerp(blury(xi, yi, zi, c), blury(xi + 1, yi, zi, c), xf), + lerp(blury(xi, yi + 1, zi, c), blury(xi + 1, yi + 1, zi, c), xf), yf), + lerp(lerp(blury(xi, yi, zi + 1, c), blury(xi + 1, yi, zi + 1, c), xf), + lerp(blury(xi, yi + 1, zi + 1, c), blury(xi + 1, yi + 1, zi + 1, c), xf), yf), + zf); + + // Normalize + bilateral_grid(x, y) = interpolated(x, y, 0) / interpolated(x, y, 1); + Pipeline p({bilateral_grid}); + + Var v6, zo, vzi; + + blury.compute_root().split(x, x, v6, 6, TailStrategy::GuardWithIf).split(z, zo, vzi, 8, TailStrategy::GuardWithIf).reorder(y, x, c, vzi, zo, v6).vectorize(vzi).vectorize(v6); + p.compile_to_module({input}, "bilateral_grid", {Target("host")}); + } + printf("Success!\n"); return 0; }