Skip to content

Commit

Permalink
Avoid redundant scope lookups (#8103)
Browse files Browse the repository at this point in the history
* Avoid redundant scope lookups

This pattern has been bugging me for a long time:

```
if (scope.contains(key)) {
  Foo f = scope.get(key);
}
```

This looks up the key in the scope twice. I've finally
gotten around to fixing it. I've introduced a find method that either
returns a const pointer to the value, if it exists, or null. It also
searches any containing scopes, which are held by const pointer, so the
method has to return a const pointer.

```
if (const Foo *f = scope.find(key)) {
}
```

For cases where you want to get and then mutate, I added shallow_find,
which doesn't search enclosing scopes, but returns a mutable pointer.

We were also doing redundant scope lookups in ScopedBinding. We stored
the key in the helper object, and then did a pop on that key in the
ScopedBinding destructor. This commit changes Scope so that Scope::push
returns an opaque token that you can pass to Scope::pop to have it
remove that element without doing a fresh lookup. ScopedBinding now uses
this. Under the hood it's just an iterator on the underlying map (map
iterators are not invalidated on inserting or removing other stuff).

The net effect is to speed up local Laplacian lowering by about 5%.

I also considered making it look more like an STL class, and having find
return an iterator, but it doesn't really work. The iterator it returns
might point to an entry in an enclosing scope, in which case you can't
compare it to the .end() method of the scope you have. Scopes are
different enough from maps that the interface really needs to be
distinct.
  • Loading branch information
abadams authored Feb 22, 2024
1 parent ef31bf9 commit 57164df
Show file tree
Hide file tree
Showing 37 changed files with 305 additions and 261 deletions.
65 changes: 32 additions & 33 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,13 +406,12 @@ class Bounds : public IRVisitor {

if (const_bound) {
bounds_of_type(op->type);
if (scope.contains(op->name)) {
const Interval &scope_interval = scope.get(op->name);
if (scope_interval.has_upper_bound() && is_const(scope_interval.max)) {
interval.max = Interval::make_min(interval.max, scope_interval.max);
if (const Interval *scope_interval = scope.find(op->name)) {
if (scope_interval->has_upper_bound() && is_const(scope_interval->max)) {
interval.max = Interval::make_min(interval.max, scope_interval->max);
}
if (scope_interval.has_lower_bound() && is_const(scope_interval.min)) {
interval.min = Interval::make_max(interval.min, scope_interval.min);
if (scope_interval->has_lower_bound() && is_const(scope_interval->min)) {
interval.min = Interval::make_max(interval.min, scope_interval->min);
}
}

Expand All @@ -429,8 +428,8 @@ class Bounds : public IRVisitor {
}
}
} else {
if (scope.contains(op->name)) {
interval = scope.get(op->name);
if (const Interval *in = scope.find(op->name)) {
interval = *in;
} else if (op->type.is_vector()) {
// Uh oh, we need to take the min/max lane of some unknown vector. Treat as unbounded.
bounds_of_type(op->type);
Expand Down Expand Up @@ -2054,11 +2053,10 @@ class FindInnermostVar : public IRVisitor {
int innermost_depth = -1;

void visit(const Variable *op) override {
if (vars_depth.contains(op->name)) {
int depth = vars_depth.get(op->name);
if (depth > innermost_depth) {
if (const int *depth = vars_depth.find(op->name)) {
if (*depth > innermost_depth) {
innermost_var = op->name;
innermost_depth = depth;
innermost_depth = *depth;
}
}
}
Expand Down Expand Up @@ -2545,16 +2543,17 @@ class BoxesTouched : public IRGraphVisitor {
// If this let stmt is a redefinition of a previous one, we should
// remove the old let stmt from the 'children' map since it is
// no longer valid at this point.
if ((f.vi.instance > 0) && let_stmts.contains(op->name)) {
const Expr &val = let_stmts.get(op->name);
CollectVars collect(op->name);
val.accept(&collect);
f.old_let_vars = collect.vars;

VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1);
for (const auto &v : f.old_let_vars) {
internal_assert(vars_renaming.count(v));
children[get_var_instance(v)].erase(old_vi);
if (f.vi.instance > 0) {
if (const Expr *val = let_stmts.find(op->name)) {
CollectVars collect(op->name);
val->accept(&collect);
f.old_let_vars = collect.vars;

VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1);
for (const auto &v : f.old_let_vars) {
internal_assert(vars_renaming.count(v));
children[get_var_instance(v)].erase(old_vi);
}
}
}
let_stmts.push(op->name, op->value);
Expand Down Expand Up @@ -2756,17 +2755,17 @@ class BoxesTouched : public IRGraphVisitor {
expr_uses_var(box[i].min, l.min_name))) ||
(box[i].has_upper_bound() && (expr_uses_var(box[i].max, l.max_name) ||
expr_uses_var(box[i].max, l.min_name)))) {
internal_assert(let_stmts.contains(l.var));
const Expr &val = let_stmts.get(l.var);
v_bound = bounds_of_expr_in_scope(val, scope, func_bounds);
const Expr *val = let_stmts.find(l.var);
internal_assert(val);
v_bound = bounds_of_expr_in_scope(*val, scope, func_bounds);
bool fixed = v_bound.min.same_as(v_bound.max);
v_bound.min = simplify(v_bound.min);
v_bound.max = fixed ? v_bound.min : simplify(v_bound.max);

internal_assert(scope.contains(l.var));
const Interval &old_bound = scope.get(l.var);
v_bound.max = simplify(min(v_bound.max, old_bound.max));
v_bound.min = simplify(max(v_bound.min, old_bound.min));
const Interval *old_bound = scope.find(l.var);
internal_assert(old_bound);
v_bound.max = simplify(min(v_bound.max, old_bound->max));
v_bound.min = simplify(max(v_bound.min, old_bound->min));
}

if (box[i].has_lower_bound()) {
Expand Down Expand Up @@ -3017,14 +3016,14 @@ class BoxesTouched : public IRGraphVisitor {
}

Expr min_val, max_val;
if (scope.contains(op->name + ".loop_min")) {
min_val = scope.get(op->name + ".loop_min").min;
if (const Interval *in = scope.find(op->name + ".loop_min")) {
min_val = in->min;
} else {
min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min;
}

if (scope.contains(op->name + ".loop_max")) {
max_val = scope.get(op->name + ".loop_max").max;
if (const Interval *in = scope.find(op->name + ".loop_max")) {
max_val = in->max;
} else {
max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max;
max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max;
Expand Down
4 changes: 2 additions & 2 deletions src/CSE.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,8 +201,8 @@ class RemoveLets : public IRGraphMutator {
Scope<Expr> scope;

Expr visit(const Variable *op) override {
if (scope.contains(op->name)) {
return scope.get(op->name);
if (const Expr *e = scope.find(op->name)) {
return *e;
} else {
return op;
}
Expand Down
6 changes: 4 additions & 2 deletions src/ClampUnsafeAccesses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,10 @@ struct ClampUnsafeAccesses : IRMutator {
}

Expr visit(const Variable *var) override {
if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) {
let_var_inside_indexing.ref(var->name) = true;
if (is_inside_indexing) {
if (bool *b = let_var_inside_indexing.shallow_find(var->name)) {
*b = true;
}
}
return var;
}
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,14 @@ class SubstituteInStridedLoads : public IRMutator {
Expr visit(const Shuffle *op) override {
int stride = op->slice_stride();
const Variable *var = op->vectors[0].as<Variable>();
const Expr *vec = nullptr;
if (var &&
poisoned_vars.count(var->name) == 0 &&
op->vectors.size() == 1 &&
2 <= stride && stride <= 4 &&
op->slice_begin() < stride &&
loads.contains(var->name)) {
return Shuffle::make_slice({loads.get(var->name)}, op->slice_begin(), op->slice_stride(), op->type.lanes());
(vec = loads.find(var->name))) {
return Shuffle::make_slice({*vec}, op->slice_begin(), op->slice_stride(), op->type.lanes());
} else {
return IRMutator::visit(op);
}
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1936,8 +1936,9 @@ void CodeGen_C::visit(const Load *op) {
user_assert(is_const_one(op->predicate)) << "Predicated scalar load is not supported by C backend.\n";

string id_index = print_expr(op->index);
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type.element_of() == t.element_of());
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc &&
alloc->type.element_of() == t.element_of());
if (type_cast_needed) {
const char *const_flag = output_kind == CPlusPlusImplementation ? " const" : "";
rhs << "((" << print_type(t.element_of()) << const_flag << " *)" << name << ")";
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) {
string id_index = print_expr(op->index);

// Get the rhs just for the cache.
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == op->type);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc &&
alloc->type == op->type);

ostringstream rhs;
if (type_cast_needed) {
Expand Down
11 changes: 5 additions & 6 deletions src/CodeGen_Hexagon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,8 @@ class SloppyUnpredicateLoadsAndStores : public IRMutator {
}
}
} else if (const Variable *op = e.as<Variable>()) {
if (monotonic_vectors.contains(op->name)) {
return monotonic_vectors.get(op->name);
if (const auto *p = monotonic_vectors.find(op->name)) {
return *p;
}
} else if (const Let *op = e.as<Let>()) {
auto v = get_extreme_lanes(op->value);
Expand Down Expand Up @@ -2245,10 +2245,9 @@ void CodeGen_Hexagon::visit(const Allocate *alloc) {
codegen(alloc->body);

// If there was no early free, free it now.
if (allocations.contains(alloc->name)) {
Allocation alloc_obj = allocations.get(alloc->name);
internal_assert(alloc_obj.destructor);
trigger_destructor(alloc_obj.destructor_function, alloc_obj.destructor);
if (const Allocation *alloc_obj = allocations.find(alloc->name)) {
internal_assert(alloc_obj->destructor);
trigger_destructor(alloc_obj->destructor_function, alloc_obj->destructor);

allocations.pop(alloc->name);
sym_pop(alloc->name);
Expand Down
5 changes: 3 additions & 2 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1268,7 +1268,8 @@ void CodeGen_LLVM::sym_pop(const string &name) {

llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const {
// look in the symbol table
if (!symbol_table.contains(name)) {
llvm::Value *const *v = symbol_table.find(name);
if (!v) {
if (must_succeed) {
std::ostringstream err;
err << "Symbol not found: " << name << "\n";
Expand All @@ -1283,7 +1284,7 @@ llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const
return nullptr;
}
}
return symbol_table.get(name);
return *v;
}

bool CodeGen_LLVM::sym_exists(const string &name) const {
Expand Down
9 changes: 5 additions & 4 deletions src/CodeGen_Metal_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,8 +390,9 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Load *op) {
string id_index = print_expr(op->index);

// Get the rhs just for the cache.
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == op->type);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc &&
alloc->type == op->type);
ostringstream rhs;
if (type_cast_needed) {
rhs << "((" << get_memory_space(op->name) << " "
Expand Down Expand Up @@ -467,8 +468,8 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Store *op) {
<< id_value << "[" << i << "];\n";
}
} else {
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == t);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc && alloc->type == t);

string id_index = print_expr(op->index);
stream << get_indent();
Expand Down
8 changes: 4 additions & 4 deletions src/CodeGen_OpenCL_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -484,8 +484,8 @@ string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_array_access(const string &na
const Type &type,
const string &id_index) {
ostringstream rhs;
bool type_cast_needed = !(allocations.contains(name) &&
allocations.get(name).type == type);
const auto *alloc = allocations.find(name);
bool type_cast_needed = !(alloc && alloc->type == type);

if (type_cast_needed) {
rhs << "((" << get_memory_space(name) << " "
Expand Down Expand Up @@ -583,8 +583,8 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Store *op) {
// For atomicAdd, we check if op->value - store[index] is independent of store.
// The atomicAdd operations in OpenCL only supports integers so we also check that.
bool is_atomic_add = t.is_int_or_uint() && !expr_uses_var(delta, op->name);
bool type_cast_needed = !(allocations.contains(op->name) &&
allocations.get(op->name).type == t);
const auto *alloc = allocations.find(op->name);
bool type_cast_needed = !(alloc && alloc->type == t);
auto print_store_var = [&]() {
if (type_cast_needed) {
stream << "(("
Expand Down
4 changes: 2 additions & 2 deletions src/CodeGen_Posix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -342,8 +342,8 @@ void CodeGen_Posix::free_allocation(const std::string &name) {
}

string CodeGen_Posix::get_allocation_name(const std::string &n) {
if (allocations.contains(n)) {
return allocations.get(n).name;
if (const auto *alloc = allocations.find(n)) {
return alloc->name;
} else {
return n;
}
Expand Down
28 changes: 15 additions & 13 deletions src/CodeGen_Vulkan_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1539,10 +1539,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Load *op) {
user_assert(is_const_one(op->predicate)) << "Predicated loads not supported by SPIR-V codegen\n";

// Construct the pointer to read from
internal_assert(symbol_table.contains(op->name));
SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name);
SpvId variable_id = id_and_storage_class.first;
SpvStorageClass storage_class = id_and_storage_class.second;
const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name);
internal_assert(id_and_storage_class);
SpvId variable_id = id_and_storage_class->first;
SpvStorageClass storage_class = id_and_storage_class->second;
internal_assert(variable_id != SpvInvalidId);
internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax));

Expand Down Expand Up @@ -1576,10 +1576,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Store *op) {
op->value.accept(this);
SpvId value_id = builder.current_id();

internal_assert(symbol_table.contains(op->name));
SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name);
SpvId variable_id = id_and_storage_class.first;
SpvStorageClass storage_class = id_and_storage_class.second;
const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name);
internal_assert(id_and_storage_class);
SpvId variable_id = id_and_storage_class->first;
SpvStorageClass storage_class = id_and_storage_class->second;
internal_assert(variable_id != SpvInvalidId);
internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax));

Expand Down Expand Up @@ -1665,9 +1665,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) {
const std::string intrinsic_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + intrinsic.first;

// Intrinsics are inserted when adding the kernel
internal_assert(symbol_table.contains(intrinsic_var_name));
SpvId intrinsic_id = symbol_table.get(intrinsic_var_name).first;
SpvStorageClass storage_class = symbol_table.get(intrinsic_var_name).second;
const auto *intrin = symbol_table.find(intrinsic_var_name);
internal_assert(intrin);
SpvId intrinsic_id = intrin->first;
SpvStorageClass storage_class = intrin->second;

// extract and cast to the extent type (which is what's expected by Halide's for loops)
Type unsigned_type = UInt(32);
Expand Down Expand Up @@ -1908,8 +1909,9 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Allocate *op) {

void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Free *op) {
debug(3) << "Vulkan: Popping allocation called " << op->name << " off the symbol table\n";
internal_assert(symbol_table.contains(op->name));
SpvId variable_id = symbol_table.get(op->name).first;
const auto *id = symbol_table.find(op->name);
internal_assert(id);
SpvId variable_id = id->first;
storage_access_map.erase(variable_id);
symbol_table.pop(op->name);
}
Expand Down
8 changes: 4 additions & 4 deletions src/CodeGen_WebGPU_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -684,8 +684,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Load *op) {

// Get the allocation type, which may be different from the result type.
Type alloc_type = result_type;
if (allocations.contains(op->name)) {
alloc_type = allocations.get(op->name).type;
if (const auto *alloc = allocations.find(op->name)) {
alloc_type = alloc->type;
} else if (workgroup_allocations.count(op->name)) {
alloc_type = workgroup_allocations.at(op->name)->type;
}
Expand Down Expand Up @@ -826,8 +826,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Store *op) {

// Get the allocation type, which may be different from the value type.
Type alloc_type = value_type;
if (allocations.contains(op->name)) {
alloc_type = allocations.get(op->name).type;
if (const auto *alloc = allocations.find(op->name)) {
alloc_type = alloc->type;
} else if (workgroup_allocations.count(op->name)) {
alloc_type = workgroup_allocations.at(op->name)->type;
}
Expand Down
Loading

0 comments on commit 57164df

Please sign in to comment.