From fb2cf290e1b17a4e074dd673439f7272f7916308 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 17 Feb 2024 17:28:26 -0800 Subject: [PATCH 1/4] Avoid redundant scope lookups This pattern has been bugging me for a long time: ``` if (scope.contains(key)) { Foo f = scope.get(key); } ``` This redundantly looks up the key in the scope twice. I've finally gotten around to fixing it. I've introduced a find method that either returns a const pointer to the value, if it exists, or null. It also searches any containing scopes, which are held by const pointer, so the method has to return a const pointer. ``` if (const Foo *f = scope.find(key)) { } ``` For cases where you want to get and then mutate, I added shallow_find, which doesn't search enclosing scopes, but returns a mutable pointer. We were also doing redundant scope lookups in ScopedBinding. We stored the key in the helper object, and then did a pop on that key in the ScopedBinding destructor. This commit changes Scope so that Scope::push returns an opaque token that you can pass to Scope::pop to have it remove that element without doing a fresh lookup. ScopedBinding now uses this. Under the hood it's just an iterator on the underlying map (map iterators are not invalidated on inserting or removing other stuff). The net effect is to speed up local laplacian lowering by about 5% I also considered making it look more like an stl class, and having find return an iterator, but it doesn't really work. The iterator it returns might point to an entry in an enclosing scope, in which case you can't compare it to the .end() method of the scope you have. Scopes are different enough from maps that the interface really needs to be distinct. --- src/Bounds.cpp | 65 +++++++++++----------- src/CSE.cpp | 4 +- src/ClampUnsafeAccesses.cpp | 6 ++- src/CodeGen_ARM.cpp | 5 +- src/CodeGen_C.cpp | 5 +- src/CodeGen_D3D12Compute_Dev.cpp | 5 +- src/CodeGen_Hexagon.cpp | 11 ++-- src/CodeGen_LLVM.cpp | 5 +- src/CodeGen_Metal_Dev.cpp | 9 ++-- src/CodeGen_OpenCL_Dev.cpp | 8 +-- src/CodeGen_Posix.cpp | 4 +- src/CodeGen_Vulkan_Dev.cpp | 28 +++++----- src/CodeGen_WebGPU_Dev.cpp | 8 +-- src/CodeGen_X86.cpp | 38 +++++++------ src/EliminateBoolVectors.cpp | 4 +- src/ExprUsesVar.h | 4 +- src/FindIntrinsics.cpp | 4 +- src/FuseGPUThreadLoops.cpp | 16 +++--- src/HexagonOptimize.cpp | 32 +++++------ src/LICM.cpp | 4 +- src/LoopCarry.cpp | 13 +++-- src/LowerWarpShuffles.cpp | 20 +++---- src/ModulusRemainder.cpp | 4 +- src/Monotonic.cpp | 4 +- src/Prefetch.cpp | 7 ++- src/PrintLoopNest.cpp | 15 +++--- src/Scope.h | 93 +++++++++++++++++++++++--------- src/Simplify.cpp | 43 +++++++-------- src/Simplify_Exprs.cpp | 32 +++++------ src/Simplify_Stmts.cpp | 20 +++---- src/SlidingWindow.cpp | 7 ++- src/Solve.cpp | 18 +++---- src/StageStridedLoads.cpp | 8 +-- src/StmtToHTML.cpp | 4 +- src/StorageFlattening.cpp | 5 +- src/UniquifyVariableNames.cpp | 7 ++- src/VectorizeLoops.cpp | 4 +- 37 files changed, 306 insertions(+), 263 deletions(-) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index a08bb0b9ad61..16fd69f3e8fb 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -406,13 +406,12 @@ class Bounds : public IRVisitor { if (const_bound) { bounds_of_type(op->type); - if (scope.contains(op->name)) { - const Interval &scope_interval = scope.get(op->name); - if (scope_interval.has_upper_bound() && is_const(scope_interval.max)) { - interval.max = Interval::make_min(interval.max, scope_interval.max); + if (const Interval *scope_interval = scope.find(op->name)) { + if (scope_interval->has_upper_bound() && is_const(scope_interval->max)) { + interval.max = Interval::make_min(interval.max, scope_interval->max); } - if (scope_interval.has_lower_bound() && is_const(scope_interval.min)) { - interval.min = Interval::make_max(interval.min, scope_interval.min); + if (scope_interval->has_lower_bound() && is_const(scope_interval->min)) { + interval.min = Interval::make_max(interval.min, scope_interval->min); } } @@ -429,8 +428,8 @@ class Bounds : public IRVisitor { } } } else { - if (scope.contains(op->name)) { - interval = scope.get(op->name); + if (const Interval *in = scope.find(op->name)) { + interval = *in; } else if (op->type.is_vector()) { // Uh oh, we need to take the min/max lane of some unknown vector. Treat as unbounded. bounds_of_type(op->type); @@ -2054,11 +2053,10 @@ class FindInnermostVar : public IRVisitor { int innermost_depth = -1; void visit(const Variable *op) override { - if (vars_depth.contains(op->name)) { - int depth = vars_depth.get(op->name); - if (depth > innermost_depth) { + if (const int *depth = vars_depth.find(op->name)) { + if (*depth > innermost_depth) { innermost_var = op->name; - innermost_depth = depth; + innermost_depth = *depth; } } } @@ -2545,16 +2543,17 @@ class BoxesTouched : public IRGraphVisitor { // If this let stmt is a redefinition of a previous one, we should // remove the old let stmt from the 'children' map since it is // no longer valid at this point. - if ((f.vi.instance > 0) && let_stmts.contains(op->name)) { - const Expr &val = let_stmts.get(op->name); - CollectVars collect(op->name); - val.accept(&collect); - f.old_let_vars = collect.vars; - - VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1); - for (const auto &v : f.old_let_vars) { - internal_assert(vars_renaming.count(v)); - children[get_var_instance(v)].erase(old_vi); + if (f.vi.instance > 0) { + if (const Expr *val = let_stmts.find(op->name)) { + CollectVars collect(op->name); + val->accept(&collect); + f.old_let_vars = collect.vars; + + VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1); + for (const auto &v : f.old_let_vars) { + internal_assert(vars_renaming.count(v)); + children[get_var_instance(v)].erase(old_vi); + } } } let_stmts.push(op->name, op->value); @@ -2756,17 +2755,17 @@ class BoxesTouched : public IRGraphVisitor { expr_uses_var(box[i].min, l.min_name))) || (box[i].has_upper_bound() && (expr_uses_var(box[i].max, l.max_name) || expr_uses_var(box[i].max, l.min_name)))) { - internal_assert(let_stmts.contains(l.var)); - const Expr &val = let_stmts.get(l.var); - v_bound = bounds_of_expr_in_scope(val, scope, func_bounds); + const Expr *val = let_stmts.find(l.var); + internal_assert(val); + v_bound = bounds_of_expr_in_scope(*val, scope, func_bounds); bool fixed = v_bound.min.same_as(v_bound.max); v_bound.min = simplify(v_bound.min); v_bound.max = fixed ? v_bound.min : simplify(v_bound.max); - internal_assert(scope.contains(l.var)); - const Interval &old_bound = scope.get(l.var); - v_bound.max = simplify(min(v_bound.max, old_bound.max)); - v_bound.min = simplify(max(v_bound.min, old_bound.min)); + const Interval *old_bound = scope.find(l.var); + internal_assert(old_bound); + v_bound.max = simplify(min(v_bound.max, old_bound->max)); + v_bound.min = simplify(max(v_bound.min, old_bound->min)); } if (box[i].has_lower_bound()) { @@ -3017,14 +3016,14 @@ class BoxesTouched : public IRGraphVisitor { } Expr min_val, max_val; - if (scope.contains(op->name + ".loop_min")) { - min_val = scope.get(op->name + ".loop_min").min; + if (const Interval *in = scope.find(op->name + ".loop_min")) { + min_val = in->min; } else { min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min; } - if (scope.contains(op->name + ".loop_max")) { - max_val = scope.get(op->name + ".loop_max").max; + if (const Interval *in = scope.find(op->name + ".loop_max")) { + max_val = in->max; } else { max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max; max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max; diff --git a/src/CSE.cpp b/src/CSE.cpp index 7d39fcc90dc5..d8ecd619db81 100644 --- a/src/CSE.cpp +++ b/src/CSE.cpp @@ -201,8 +201,8 @@ class RemoveLets : public IRGraphMutator { Scope scope; Expr visit(const Variable *op) override { - if (scope.contains(op->name)) { - return scope.get(op->name); + if (const Expr *e = scope.find(op->name)) { + return *e; } else { return op; } diff --git a/src/ClampUnsafeAccesses.cpp b/src/ClampUnsafeAccesses.cpp index 5e2e1f5d5b2e..b3dd9ddc235e 100644 --- a/src/ClampUnsafeAccesses.cpp +++ b/src/ClampUnsafeAccesses.cpp @@ -50,8 +50,10 @@ struct ClampUnsafeAccesses : IRMutator { } Expr visit(const Variable *var) override { - if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) { - let_var_inside_indexing.ref(var->name) = true; + if (is_inside_indexing) { + if (bool *b = let_var_inside_indexing.shallow_find(var->name)) { + *b = true; + } } return var; } diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 9c6525703f16..7852532183bf 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -82,13 +82,14 @@ class SubstituteInStridedLoads : public IRMutator { Expr visit(const Shuffle *op) override { int stride = op->slice_stride(); const Variable *var = op->vectors[0].as(); + const Expr *vec = nullptr; if (var && poisoned_vars.count(var->name) == 0 && op->vectors.size() == 1 && 2 <= stride && stride <= 4 && op->slice_begin() < stride && - loads.contains(var->name)) { - return Shuffle::make_slice({loads.get(var->name)}, op->slice_begin(), op->slice_stride(), op->type.lanes()); + (vec = loads.find(var->name))) { + return Shuffle::make_slice({*vec}, op->slice_begin(), op->slice_stride(), op->type.lanes()); } else { return IRMutator::visit(op); } diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 89c18cb8ab28..b0cdcb3e956c 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1936,8 +1936,9 @@ void CodeGen_C::visit(const Load *op) { user_assert(is_const_one(op->predicate)) << "Predicated scalar load is not supported by C backend.\n"; string id_index = print_expr(op->index); - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type.element_of() == t.element_of()); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && + alloc->type.element_of() == t.element_of()); if (type_cast_needed) { const char *const_flag = output_kind == CPlusPlusImplementation ? " const" : ""; rhs << "((" << print_type(t.element_of()) << const_flag << " *)" << name << ")"; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index c8e45ea2ae09..4fd614cc0dfc 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -592,8 +592,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) { string id_index = print_expr(op->index); // Get the rhs just for the cache. - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == op->type); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && + alloc->type == op->type); ostringstream rhs; if (type_cast_needed) { diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 9463a4c921aa..a77e9c7c1a76 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -221,8 +221,8 @@ class SloppyUnpredicateLoadsAndStores : public IRMutator { } } } else if (const Variable *op = e.as()) { - if (monotonic_vectors.contains(op->name)) { - return monotonic_vectors.get(op->name); + if (const auto *p = monotonic_vectors.find(op->name)) { + return *p; } } else if (const Let *op = e.as()) { auto v = get_extreme_lanes(op->value); @@ -2245,10 +2245,9 @@ void CodeGen_Hexagon::visit(const Allocate *alloc) { codegen(alloc->body); // If there was no early free, free it now. - if (allocations.contains(alloc->name)) { - Allocation alloc_obj = allocations.get(alloc->name); - internal_assert(alloc_obj.destructor); - trigger_destructor(alloc_obj.destructor_function, alloc_obj.destructor); + if (const Allocation *alloc_obj = allocations.find(alloc->name)) { + internal_assert(alloc_obj->destructor); + trigger_destructor(alloc_obj->destructor_function, alloc_obj->destructor); allocations.pop(alloc->name); sym_pop(alloc->name); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index a5c32cf83cc7..8922461524c5 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1268,7 +1268,8 @@ void CodeGen_LLVM::sym_pop(const string &name) { llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const { // look in the symbol table - if (!symbol_table.contains(name)) { + llvm::Value *const *v = symbol_table.find(name); + if (!v) { if (must_succeed) { std::ostringstream err; err << "Symbol not found: " << name << "\n"; @@ -1283,7 +1284,7 @@ llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const return nullptr; } } - return symbol_table.get(name); + return *v; } bool CodeGen_LLVM::sym_exists(const string &name) const { diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index 69d47279e9ae..79060294798e 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -390,8 +390,9 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Load *op) { string id_index = print_expr(op->index); // Get the rhs just for the cache. - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == op->type); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && + alloc->type == op->type); ostringstream rhs; if (type_cast_needed) { rhs << "((" << get_memory_space(op->name) << " " @@ -467,8 +468,8 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Store *op) { << id_value << "[" << i << "];\n"; } } else { - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == t); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && alloc->type == t); string id_index = print_expr(op->index); stream << get_indent(); diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 52feed53f9e0..c86e483cc5a8 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -484,8 +484,8 @@ string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_array_access(const string &na const Type &type, const string &id_index) { ostringstream rhs; - bool type_cast_needed = !(allocations.contains(name) && - allocations.get(name).type == type); + const auto *alloc = allocations.find(name); + bool type_cast_needed = !(alloc && alloc->type == type); if (type_cast_needed) { rhs << "((" << get_memory_space(name) << " " @@ -583,8 +583,8 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Store *op) { // For atomicAdd, we check if op->value - store[index] is independent of store. // The atomicAdd operations in OpenCL only supports integers so we also check that. bool is_atomic_add = t.is_int_or_uint() && !expr_uses_var(delta, op->name); - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == t); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && alloc->type == t); auto print_store_var = [&]() { if (type_cast_needed) { stream << "((" diff --git a/src/CodeGen_Posix.cpp b/src/CodeGen_Posix.cpp index af508194b06e..f812b63cce9d 100644 --- a/src/CodeGen_Posix.cpp +++ b/src/CodeGen_Posix.cpp @@ -342,8 +342,8 @@ void CodeGen_Posix::free_allocation(const std::string &name) { } string CodeGen_Posix::get_allocation_name(const std::string &n) { - if (allocations.contains(n)) { - return allocations.get(n).name; + if (const auto *alloc = allocations.find(n)) { + return alloc->name; } else { return n; } diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 61b365f2f7aa..39dd65b67671 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -1539,10 +1539,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Load *op) { user_assert(is_const_one(op->predicate)) << "Predicated loads not supported by SPIR-V codegen\n"; // Construct the pointer to read from - internal_assert(symbol_table.contains(op->name)); - SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name); - SpvId variable_id = id_and_storage_class.first; - SpvStorageClass storage_class = id_and_storage_class.second; + const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name); + internal_assert(id_and_storage_class); + SpvId variable_id = id_and_storage_class->first; + SpvStorageClass storage_class = id_and_storage_class->second; internal_assert(variable_id != SpvInvalidId); internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax)); @@ -1576,10 +1576,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Store *op) { op->value.accept(this); SpvId value_id = builder.current_id(); - internal_assert(symbol_table.contains(op->name)); - SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name); - SpvId variable_id = id_and_storage_class.first; - SpvStorageClass storage_class = id_and_storage_class.second; + const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name); + internal_assert(id_and_storage_class); + SpvId variable_id = id_and_storage_class->first; + SpvStorageClass storage_class = id_and_storage_class->second; internal_assert(variable_id != SpvInvalidId); internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax)); @@ -1665,9 +1665,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { const std::string intrinsic_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + intrinsic.first; // Intrinsics are inserted when adding the kernel - internal_assert(symbol_table.contains(intrinsic_var_name)); - SpvId intrinsic_id = symbol_table.get(intrinsic_var_name).first; - SpvStorageClass storage_class = symbol_table.get(intrinsic_var_name).second; + const auto *intrin = symbol_table.find(intrinsic_var_name); + internal_assert(intrin); + SpvId intrinsic_id = intrin->first; + SpvStorageClass storage_class = intrin->second; // extract and cast to the extent type (which is what's expected by Halide's for loops) Type unsigned_type = UInt(32); @@ -1908,8 +1909,9 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Allocate *op) { void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Free *op) { debug(3) << "Vulkan: Popping allocation called " << op->name << " off the symbol table\n"; - internal_assert(symbol_table.contains(op->name)); - SpvId variable_id = symbol_table.get(op->name).first; + const auto *id = symbol_table.find(op->name); + internal_assert(id); + SpvId variable_id = id->first; storage_access_map.erase(variable_id); symbol_table.pop(op->name); } diff --git a/src/CodeGen_WebGPU_Dev.cpp b/src/CodeGen_WebGPU_Dev.cpp index 08d3a542f41b..de55113ff695 100644 --- a/src/CodeGen_WebGPU_Dev.cpp +++ b/src/CodeGen_WebGPU_Dev.cpp @@ -684,8 +684,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Load *op) { // Get the allocation type, which may be different from the result type. Type alloc_type = result_type; - if (allocations.contains(op->name)) { - alloc_type = allocations.get(op->name).type; + if (const auto *alloc = allocations.find(op->name)) { + alloc_type = alloc->type; } else if (workgroup_allocations.count(op->name)) { alloc_type = workgroup_allocations.at(op->name)->type; } @@ -826,8 +826,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Store *op) { // Get the allocation type, which may be different from the value type. Type alloc_type = value_type; - if (allocations.contains(op->name)) { - alloc_type = allocations.get(op->name).type; + if (const auto *alloc = allocations.find(op->name)) { + alloc_type = alloc->type; } else if (workgroup_allocations.count(op->name)) { alloc_type = workgroup_allocations.at(op->name)->type; } diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 8d87f4c1937e..51cf56fabbc0 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -866,28 +866,32 @@ void CodeGen_X86::visit(const Allocate *op) { } void CodeGen_X86::visit(const Load *op) { - if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { - const Ramp *ramp = op->index.as(); - internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; - Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); - LoadInst *load = builder->CreateAlignedLoad(llvm_type_of(upgrade_type_for_storage(op->type)), ptr, llvm::Align(op->type.bytes())); - add_tbaa_metadata(load, op->name, op->index); - value = load; - return; + if (const auto *mt = mem_type.find(op->name)) { + if (*mt == MemoryType::AMXTile) { + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); + LoadInst *load = builder->CreateAlignedLoad(llvm_type_of(upgrade_type_for_storage(op->type)), ptr, llvm::Align(op->type.bytes())); + add_tbaa_metadata(load, op->name, op->index); + value = load; + return; + } } CodeGen_Posix::visit(op); } void CodeGen_X86::visit(const Store *op) { - if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { - Value *val = codegen(op->value); - Halide::Type value_type = op->value.type(); - const Ramp *ramp = op->index.as(); - internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; - Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base); - StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); - add_tbaa_metadata(store, op->name, op->index); - return; + if (const auto *mt = mem_type.find(op->name)) { + if (mem_type.get(op->name) == MemoryType::AMXTile) { + Value *val = codegen(op->value); + Halide::Type value_type = op->value.type(); + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base); + StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); + add_tbaa_metadata(store, op->name, op->index); + return; + } } CodeGen_Posix::visit(op); } diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index cebfe0f0019b..62cdbdbef5b5 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -15,8 +15,8 @@ class EliminateBoolVectors : public IRMutator { Scope lets; Expr visit(const Variable *op) override { - if (lets.contains(op->name)) { - return Variable::make(lets.get(op->name), op->name); + if (const Type *t = lets.find(op->name)) { + return Variable::make(*t, op->name); } else { return op; } diff --git a/src/ExprUsesVar.h b/src/ExprUsesVar.h index 3bf129d259f7..84c3f7ae23d4 100644 --- a/src/ExprUsesVar.h +++ b/src/ExprUsesVar.h @@ -36,8 +36,8 @@ class ExprUsesVars : public IRGraphVisitor { void visit_name(const std::string &name) { if (vars.contains(name)) { result = true; - } else if (scope.contains(name)) { - include(scope.get(name)); + } else if (const Expr *e = scope.find(name)) { + IRGraphVisitor::include(*e); } } diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index a77a7b1798f3..f142df7f05f0 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1083,8 +1083,8 @@ class SubstituteInWideningLets : public IRMutator { Scope replacements; Expr visit(const Variable *op) override { - if (replacements.contains(op->name)) { - return replacements.get(op->name); + if (const Expr *e = replacements.find(op->name)) { + return *e; } else { return op; } diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index ef5a75344bb8..abde50d62e1f 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1140,21 +1140,21 @@ class ExtractRegisterAllocations : public IRMutator { } Expr visit(const Load *op) override { - string new_name = op->name; - if (alloc_renaming.contains(op->name)) { - new_name = alloc_renaming.get(op->name); + const string *new_name = alloc_renaming.find(op->name); + if (!new_name) { + new_name = &(op->name); } - return Load::make(op->type, new_name, mutate(op->index), + return Load::make(op->type, *new_name, mutate(op->index), op->image, op->param, mutate(op->predicate), op->alignment); } Stmt visit(const Store *op) override { - string new_name = op->name; - if (alloc_renaming.contains(op->name)) { - new_name = alloc_renaming.get(op->name); + const string *new_name = alloc_renaming.find(op->name); + if (!new_name) { + new_name = &(op->name); } - return Store::make(new_name, mutate(op->value), mutate(op->index), + return Store::make(*new_name, mutate(op->value), mutate(op->index), op->param, mutate(op->predicate), op->alignment); } diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index b76a9eb1cfef..deabd95d1d1b 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -1357,8 +1357,8 @@ class EliminateInterleaves : public IRMutator { } if (const Load *load = x.as()) { - if (buffers.contains(load->name)) { - return buffers.get(load->name) != BufferState::NotInterleaved; + if (const auto *state = buffers.find(load->name)) { + return *state != BufferState::NotInterleaved; } } @@ -1398,8 +1398,8 @@ class EliminateInterleaves : public IRMutator { } if (const Load *load = x.as()) { - if (buffers.contains(load->name)) { - return buffers.get(load->name) != BufferState::NotInterleaved; + if (const auto *state = buffers.find(load->name)) { + return *state != BufferState::NotInterleaved; } } @@ -1816,34 +1816,33 @@ class EliminateInterleaves : public IRMutator { Expr value = mutate(op->value); Expr index = mutate(op->index); - if (buffers.contains(op->name)) { + if (BufferState *state = buffers.shallow_find(op->name)) { // When inspecting the stores to a buffer, update the state. - BufferState &state = buffers.ref(op->name); if (!is_const_one(predicate) || !op->value.type().is_vector()) { // TODO(psuriana): This store is predicated. Mark the buffer as // not interleaved for now. - state = BufferState::NotInterleaved; + *state = BufferState::NotInterleaved; } else if (yields_removable_interleave(value)) { // The value yields a removable interleave. If we aren't tracking // this buffer, mark it as interleaved. - if (state == BufferState::Unknown) { - state = BufferState::Interleaved; + if (*state == BufferState::Unknown) { + *state = BufferState::Interleaved; } } else if (!yields_interleave(value)) { // The value does not yield an interleave. Mark the // buffer as not interleaved. - state = BufferState::NotInterleaved; + *state = BufferState::NotInterleaved; } else { // If the buffer yields an interleave, but is not an // interleave itself, we don't want to change the // buffer state. } - internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope"); - bool &aligned_accesses = aligned_buffer_access.ref(op->name); + bool *aligned_accesses = aligned_buffer_access.shallow_find(op->name); + internal_assert(aligned_accesses) << "Buffer not found in scope"; int64_t aligned_offset = 0; if (!alignment_analyzer.is_aligned(op, &aligned_offset)) { - aligned_accesses = false; + *aligned_accesses = false; } } if (deinterleave_buffers.contains(op->name)) { @@ -1872,12 +1871,13 @@ class EliminateInterleaves : public IRMutator { // which is only true if any of the stores are // actually interleaved (and don't just yield an // interleave). - internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope"); - bool &aligned_accesses = aligned_buffer_access.ref(op->name); + bool *aligned_accesses = aligned_buffer_access.shallow_find(op->name); + internal_assert(aligned_accesses) << "Buffer not found in scope"; + int64_t aligned_offset = 0; if (!alignment_analyzer.is_aligned(op, &aligned_offset)) { - aligned_accesses = false; + *aligned_accesses = false; } } else { // This is not a double vector load, so we can't diff --git a/src/LICM.cpp b/src/LICM.cpp index 641f4982a3e2..719b41442cfc 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -350,8 +350,8 @@ class GroupLoopInvariants : public IRMutator { const Scope &depth; void visit(const Variable *op) override { - if (depth.contains(op->name)) { - result = std::max(result, depth.get(op->name)); + if (const int *d = depth.find(op->name)) { + result = std::max(result, *d); } } diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 050cdfbfc8d9..bfc2abc8ddf1 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -27,8 +27,8 @@ Expr is_linear(const Expr &e, const Scope &linear) { return Expr(); } if (const Variable *v = e.as()) { - if (linear.contains(v->name)) { - return linear.get(v->name); + if (const Expr *e = linear.find(v->name)) { + return *e; } else { return make_zero(v->type); } @@ -140,18 +140,17 @@ class StepForwards : public IRGraphMutator { using IRGraphMutator::visit; Expr visit(const Variable *op) override { - if (linear.contains(op->name)) { - Expr step = linear.get(op->name); - if (!step.defined()) { + if (const Expr *step = linear.find(op->name)) { + if (!step->defined()) { // It's non-linear success = false; return op; - } else if (is_const_zero(step)) { + } else if (is_const_zero(*step)) { // It's a known inner constant return op; } else { // It's linear - return Expr(op) + step; + return Expr(op) + *step; } } else { // It's some external constant diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index 79332c9336e5..a3499ace7e03 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -149,8 +149,8 @@ class DetermineAllocStride : public IRVisitor { } else if (const Variable *var = e.as()) { if (var->name == lane_var) { return 1; - } else if (dependent_vars.contains(var->name)) { - return dependent_vars.get(var->name); + } else if (const Expr *e = dependent_vars.find(var->name)) { + return *e; } else { return 0; } @@ -475,10 +475,10 @@ class LowerWarpShuffles : public IRMutator { if ((lt && equal(lt->a, this_lane) && is_const(lt->b)) || (le && equal(le->a, this_lane) && is_const(le->b))) { Expr condition = mutate(op->condition); - internal_assert(bounds.contains(this_lane_name)); - Interval interval = bounds.get(this_lane_name); - interval.max = lt ? simplify(lt->b - 1) : le->b; - ScopedBinding bind(bounds, this_lane_name, interval); + Interval *interval = bounds.shallow_find(this_lane_name); + internal_assert(interval); + interval->max = lt ? simplify(lt->b - 1) : le->b; + ScopedBinding bind(bounds, this_lane_name, *interval); Stmt then_case = mutate(op->then_case); Stmt else_case = mutate(op->else_case); return IfThenElse::make(condition, then_case, else_case); @@ -488,10 +488,10 @@ class LowerWarpShuffles : public IRMutator { } Stmt visit(const Store *op) override { - if (allocation_info.contains(op->name)) { + if (const auto *alloc = allocation_info.find(op->name)) { Expr idx = mutate(op->index); Expr value = mutate(op->value); - Expr stride = allocation_info.get(op->name).stride; + Expr stride = alloc->stride; internal_assert(stride.defined() && warp_size.defined()); // Reduce the index to an index in my own stripe. We have @@ -639,9 +639,9 @@ class LowerWarpShuffles : public IRMutator { } Expr visit(const Load *op) override { - if (allocation_info.contains(op->name)) { + if (const auto *alloc = allocation_info.find(op->name)) { Expr idx = mutate(op->index); - Expr stride = allocation_info.get(op->name).stride; + Expr stride = alloc->stride; // Break the index into lane and stripe components Expr lane = simplify(reduce_expr(idx / stride, warp_size, bounds), true, bounds); diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index cfccce1da786..13b3c72a181d 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -110,8 +110,8 @@ void ComputeModulusRemainder::visit(const Reinterpret *) { } void ComputeModulusRemainder::visit(const Variable *op) { - if (scope.contains(op->name)) { - result = scope.get(op->name); + if (const auto *m = scope.find(op->name)) { + result = *m; } else { result = ModulusRemainder{}; } diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index dd8e17d5b177..fee151f00a22 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -280,8 +280,8 @@ class DerivativeBounds : public IRVisitor { void visit(const Variable *op) override { if (op->name == var) { result = ConstantInterval::single_point(1); - } else if (scope.contains(op->name)) { - result = scope.get(op->name); + } else if (const auto *r = scope.find(op->name)) { + result = *r; } else { result = ConstantInterval::single_point(0); } diff --git a/src/Prefetch.cpp b/src/Prefetch.cpp index c0fb1f5c9a64..144b1950c5cd 100644 --- a/src/Prefetch.cpp +++ b/src/Prefetch.cpp @@ -86,10 +86,9 @@ class InjectPrefetch : public IRMutator { using IRMutator::visit; Box get_buffer_bounds(const string &name, int dims) { - if (buffer_bounds.contains(name)) { - const Box &b = buffer_bounds.ref(name); - internal_assert((int)b.size() == dims); - return b; + if (const Box *b = buffer_bounds.find(name)) { + internal_assert((int)b->size() == dims); + return *b; } // It is an external buffer. diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index 52f1c319951a..9d38efaaf80a 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -94,12 +94,16 @@ class PrintLoopNest : public IRVisitor { Expr min_val = op->min, extent_val = op->extent; const Variable *min_var = min_val.as(); const Variable *extent_var = extent_val.as(); - if (min_var && constants.contains(min_var->name)) { - min_val = constants.get(min_var->name); + if (min_var) { + if (const Expr *e = constants.find(min_var->name)) { + min_val = *e; + } } - if (extent_var && constants.contains(extent_var->name)) { - extent_val = constants.get(extent_var->name); + if (extent_var) { + if (const Expr *e = constants.find(extent_var->name)) { + extent_val = *e; + } } if (extent_val.defined() && is_const(extent_val) && @@ -151,9 +155,8 @@ class PrintLoopNest : public IRVisitor { void visit(const LetStmt *op) override { if (is_const(op->value)) { - constants.push(op->name, op->value); + ScopedBinding bind(constants, op->name, op->value); op->body.accept(this); - constants.pop(op->name); } else { op->body.accept(this); } diff --git a/src/Scope.h b/src/Scope.h index 9d1cc43e1164..6e3c4e57bec2 100644 --- a/src/Scope.h +++ b/src/Scope.h @@ -150,7 +150,39 @@ class Scope { return iter->second.top_ref(); } - /** Tests if a name is in scope */ + /** Returns a const pointer to an entry if it exists in this scope or any + * containing scope, or nullptr if it does not. Use this instead of if + * (scope.contains(foo)) { ... scope.get(foo) ... } to avoid doing two + * lookups. */ + template::value>::type> + const T2 *find(const std::string &name) const { + typename std::map>::const_iterator iter = table.find(name); + if (iter == table.end() || iter->second.empty()) { + if (containing_scope) { + return containing_scope->find(name); + } else { + return nullptr; + } + } + return &(iter->second.top_ref()); + } + + /** A version of find that returns a non-const pointer, but ignores + * containing scope. */ + template::value>::type> + T2 *shallow_find(const std::string &name) { + typename std::map>::iterator iter = table.find(name); + if (iter == table.end() || iter->second.empty()) { + return nullptr; + } else { + return &(iter->second.top_ref()); + } + } + + /** Tests if a name is in scope. If you plan to use the value if it is, call + * find instead. */ bool contains(const std::string &name) const { typename std::map>::const_iterator iter = table.find(name); if (iter == table.end() || iter->second.empty()) { @@ -173,19 +205,28 @@ class Scope { } } - /** Add a new (name, value) pair to the current scope. Hide old - * values that have this name until we pop this name. + struct PushToken { + typename std::map>::iterator iter; + }; + + /** Add a new (name, value) pair to the current scope. Hide old values that + * have this name until we pop this name. Returns a token that can be used + * to pop the same value without doing a fresh lookup. */ template::value>::type> - void push(const std::string &name, T2 &&value) { - table[name].push(std::forward(value)); + PushToken push(const std::string &name, T2 &&value) { + auto it = table.try_emplace(name).first; + it->second.push(std::forward(value)); + return PushToken{it}; } template::value>::type> - void push(const std::string &name) { - table[name].push(); + PushToken push(const std::string &name) { + auto it = table.try_emplace(name).first; + it->second.push(); + return PushToken{it}; } /** A name goes out of scope. Restore whatever its old value @@ -201,6 +242,14 @@ class Scope { } } + /** Pop a name using a token returned by push instead of a string. */ + void pop(PushToken p) { + p.iter->second.pop(); + if (p.iter->second.empty()) { + table.erase(p.iter); + } + } + /** Iterate through the scope. Does not capture any containing scope. */ class const_iterator { typename std::map>::const_iterator iter; @@ -271,20 +320,17 @@ std::ostream &operator<<(std::ostream &stream, const Scope &s) { template struct ScopedBinding { Scope *scope = nullptr; - std::string name; + typename Scope::PushToken token; ScopedBinding() = default; ScopedBinding(Scope &s, const std::string &n, T value) - : scope(&s), name(n) { - scope->push(name, std::move(value)); + : scope(&s), token(scope->push(n, std::move(value))) { } ScopedBinding(bool condition, Scope &s, const std::string &n, const T &value) - : scope(condition ? &s : nullptr), name(n) { - if (condition) { - scope->push(name, value); - } + : scope(condition ? &s : nullptr), + token(condition ? scope->push(n, value) : typename Scope::PushToken{}) { } bool bound() const { @@ -293,7 +339,7 @@ struct ScopedBinding { ~ScopedBinding() { if (scope) { - scope->pop(name); + scope->pop(token); } } @@ -301,7 +347,7 @@ struct ScopedBinding { ScopedBinding(const ScopedBinding &that) = delete; ScopedBinding(ScopedBinding &&that) noexcept : scope(that.scope), - name(std::move(that.name)) { + token(std::move(that.token)) { // The move constructor must null out scope, so we don't try to pop it that.scope = nullptr; } @@ -313,20 +359,17 @@ struct ScopedBinding { template<> struct ScopedBinding { Scope<> *scope; - std::string name; + Scope<>::PushToken token; ScopedBinding(Scope<> &s, const std::string &n) - : scope(&s), name(n) { - scope->push(name); + : scope(&s), token(scope->push(n)) { } ScopedBinding(bool condition, Scope<> &s, const std::string &n) - : scope(condition ? &s : nullptr), name(n) { - if (condition) { - scope->push(name); - } + : scope(condition ? &s : nullptr), + token(condition ? scope->push(n) : Scope<>::PushToken{}) { } ~ScopedBinding() { if (scope) { - scope->pop(name); + scope->pop(token); } } @@ -334,7 +377,7 @@ struct ScopedBinding { ScopedBinding(const ScopedBinding &that) = delete; ScopedBinding(ScopedBinding &&that) noexcept : scope(that.scope), - name(std::move(that.name)) { + token(std::move(that.token)) { // The move constructor must null out scope, so we don't try to pop it that.scope = nullptr; } diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 339ef2917c83..61cf7886cb70 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -34,8 +34,8 @@ Simplify::Simplify(bool r, const Scope *bi, const Scopecontains(iter.name())) { - bounds.alignment = ai->get(iter.name()); + if (const auto *a = ai->find(iter.name())) { + bounds.alignment = *a; } if (bounds.min_defined || bounds.max_defined || bounds.alignment.modulus != 1) { @@ -74,18 +74,18 @@ std::pair, bool> Simplify::mutate_with_changes(const std::vect void Simplify::found_buffer_reference(const string &name, size_t dimensions) { for (size_t i = 0; i < dimensions; i++) { string stride = name + ".stride." + std::to_string(i); - if (var_info.contains(stride)) { - var_info.ref(stride).old_uses++; + if (auto *info = var_info.shallow_find(stride)) { + info->old_uses++; } string min = name + ".min." + std::to_string(i); - if (var_info.contains(min)) { - var_info.ref(min).old_uses++; + if (auto *info = var_info.shallow_find(min)) { + info->old_uses++; } } - if (var_info.contains(name)) { - var_info.ref(name).old_uses++; + if (auto *info = var_info.shallow_find(name)) { + info->old_uses++; } } @@ -187,8 +187,8 @@ void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) { ExprInfo b; b.max_defined = true; b.max = val; - if (simplify->bounds_and_alignment_info.contains(v->name)) { - b.intersect(simplify->bounds_and_alignment_info.get(v->name)); + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { + b.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, b); bounds_pop_list.push_back(v); @@ -198,8 +198,8 @@ void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) { ExprInfo b; b.min_defined = true; b.min = val; - if (simplify->bounds_and_alignment_info.contains(v->name)) { - b.intersect(simplify->bounds_and_alignment_info.get(v->name)); + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { + b.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, b); bounds_pop_list.push_back(v); @@ -228,10 +228,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { // TODO: Visiting it again is inefficient Simplify::ExprInfo expr_info; simplify->mutate(eq->b, &expr_info); - if (simplify->bounds_and_alignment_info.contains(v->name)) { + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { // We already know something about this variable and don't want to suppress it. - auto existing_knowledge = simplify->bounds_and_alignment_info.get(v->name); - expr_info.intersect(existing_knowledge); + expr_info.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, expr_info); bounds_pop_list.push_back(v); @@ -245,10 +244,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { // TODO: Visiting it again is inefficient Simplify::ExprInfo expr_info; simplify->mutate(eq->a, &expr_info); - if (simplify->bounds_and_alignment_info.contains(vb->name)) { + if (const auto *info = simplify->bounds_and_alignment_info.find(vb->name)) { // We already know something about this variable and don't want to suppress it. - auto existing_knowledge = simplify->bounds_and_alignment_info.get(vb->name); - expr_info.intersect(existing_knowledge); + expr_info.intersect(*info); } simplify->bounds_and_alignment_info.push(vb->name, expr_info); bounds_pop_list.push_back(vb); @@ -257,10 +255,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo expr_info; expr_info.alignment.modulus = *modulus; expr_info.alignment.remainder = *remainder; - if (simplify->bounds_and_alignment_info.contains(v->name)) { + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { // We already know something about this variable and don't want to suppress it. - auto existing_knowledge = simplify->bounds_and_alignment_info.get(v->name); - expr_info.intersect(existing_knowledge); + expr_info.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, expr_info); bounds_pop_list.push_back(v); @@ -417,8 +414,8 @@ bool can_prove(Expr e, const Scope &bounds) { Expr visit(const Variable *op) override { auto it = vars.find(op->name); - if (lets.contains(op->name)) { - return Variable::make(op->type, lets.get(op->name)); + if (const std::string *n = lets.find(op->name)) { + return Variable::make(op->type, *n); } else if (it == vars.end()) { std::string name = "v" + std::to_string(count++); vars[op->name] = name; diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index a8e5fcce1a8d..b5fcc96ac0cd 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -221,35 +221,32 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { } Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { - if (bounds_and_alignment_info.contains(op->name)) { - const ExprInfo &b = bounds_and_alignment_info.get(op->name); + if (const ExprInfo *b = bounds_and_alignment_info.find(op->name)) { if (bounds) { - *bounds = b; + *bounds = *b; } - if (b.min_defined && b.max_defined && b.min == b.max) { - return make_const(op->type, b.min); + if (b->min_defined && b->max_defined && b->min == b->max) { + return make_const(op->type, b->min); } } - if (var_info.contains(op->name)) { - auto &info = var_info.ref(op->name); - + if (auto *info = var_info.shallow_find(op->name)) { // if replacement is defined, we should substitute it in (unless // it's a var that has been hidden by a nested scope). - if (info.replacement.defined()) { - internal_assert(info.replacement.type() == op->type) + if (info->replacement.defined()) { + internal_assert(info->replacement.type() == op->type) << "Cannot replace variable " << op->name << " of type " << op->type - << " with expression of type " << info.replacement.type() << "\n"; - info.new_uses++; + << " with expression of type " << info->replacement.type() << "\n"; + info->new_uses++; // We want to remutate the replacement, because we may be // injecting it into a context where it is known to be a // constant (e.g. due to an if). - return mutate(info.replacement, bounds); + return mutate(info->replacement, bounds); } else { // This expression was not something deemed // substitutable - no replacement is defined. - info.old_uses++; + info->old_uses++; return op; } } else { @@ -321,15 +318,14 @@ Expr Simplify::visit(const Load *op, ExprInfo *bounds) { // unreachable loads. if (is_const_one(op->predicate)) { string alloc_extent_name = op->name + ".total_extent_bytes"; - if (bounds_and_alignment_info.contains(alloc_extent_name)) { + if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { if (index_info.max_defined && index_info.max < 0) { in_unreachable = true; return unreachable(op->type); } - const ExprInfo &alloc_info = bounds_and_alignment_info.get(alloc_extent_name); - if (alloc_info.max_defined && index_info.min_defined) { + if (alloc_info->max_defined && index_info.min_defined) { int index_min_bytes = index_info.min * op->type.bytes(); - if (index_min_bytes > alloc_info.max) { + if (index_min_bytes > alloc_info->max) { in_unreachable = true; return unreachable(op->type); } diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 11b146ecdc6a..f6cb81345961 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -305,19 +305,19 @@ Stmt Simplify::visit(const Store *op) { // but perhaps the branch was hard to prove constant true or false. This // provides an alternative mechanism to simplify these unreachable stores. string alloc_extent_name = op->name + ".total_extent_bytes"; - if (is_const_one(op->predicate) && - bounds_and_alignment_info.contains(alloc_extent_name)) { - if (index_info.max_defined && index_info.max < 0) { - in_unreachable = true; - return Evaluate::make(unreachable()); - } - const ExprInfo &alloc_info = bounds_and_alignment_info.get(alloc_extent_name); - if (alloc_info.max_defined && index_info.min_defined) { - int index_min_bytes = index_info.min * op->value.type().bytes(); - if (index_min_bytes > alloc_info.max) { + if (is_const_one(op->predicate)) { + if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { + if (index_info.max_defined && index_info.max < 0) { in_unreachable = true; return Evaluate::make(unreachable()); } + if (alloc_info->max_defined && index_info.min_defined) { + int index_min_bytes = index_info.min * op->value.type().bytes(); + if (index_min_bytes > alloc_info->max) { + in_unreachable = true; + return Evaluate::make(unreachable()); + } + } } } diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index ab25ad32bc87..dfb50d714e37 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -69,10 +69,9 @@ class ExpandExpr : public IRMutator { const Scope &scope; Expr visit(const Variable *var) override { - if (scope.contains(var->name)) { - Expr expr = scope.get(var->name); - debug(4) << "Fully expanded " << var->name << " -> " << expr << "\n"; - return expr; + if (const Expr *expr = scope.find(var->name)) { + debug(4) << "Fully expanded " << var->name << " -> " << *expr << "\n"; + return *expr; } else { return var; } diff --git a/src/Solve.cpp b/src/Solve.cpp index b25719cff8c7..09245d90bf24 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -786,17 +786,15 @@ class SolveExpression : public IRMutator { if (op->name == var) { uses_var = true; return op; - } else if (scope.contains(op->name)) { - CacheEntry e = scope.get(op->name); - uses_var = uses_var || e.uses_var; - failed = failed || e.failed; - return e.expr; - } else if (external_scope.contains(op->name)) { - Expr e = external_scope.get(op->name); + } else if (const CacheEntry *e = scope.find(op->name)) { + uses_var = uses_var || e->uses_var; + failed = failed || e->failed; + return e->expr; + } else if (const Expr *e = external_scope.find(op->name)) { // Expressions in the external scope haven't been solved // yet. This will either pull its solution from the cache, // or solve it and then put it into the cache. - return mutate(e); + return mutate(*e); } else { return op; } @@ -948,13 +946,13 @@ class SolveForInterval : public IRVisitor { void visit(const Variable *op) override { internal_assert(op->type.is_bool()); - if (scope.contains(op->name)) { + if (const Expr *e = scope.find(op->name)) { pair key = {op->name, target}; auto it = solved_vars.find(key); if (it != solved_vars.end()) { result = it->second; } else { - scope.get(op->name).accept(this); + e->accept(this); solved_vars[key] = result; } } else { diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index feeab56a4122..723fc738ce51 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -103,8 +103,8 @@ class FindStridedLoads : public IRVisitor { if (stride >= 2 && stride < r->lanes && r->stride.type().is_scalar()) { const IRNode *s = scope; const Allocate *a = nullptr; - if (allocation_scope.contains(op->name)) { - a = allocation_scope.get(op->name); + if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { + a = *a_ptr; } found_loads[Key{op->name, base, stride, r->lanes, op->type, a, s}][offset].push_back(op); } @@ -161,8 +161,8 @@ class ReplaceStridedLoads : public IRMutator { protected: Expr visit(const Load *op) override { const Allocate *alloc = nullptr; - if (allocation_scope.contains(op->name)) { - alloc = allocation_scope.get(op->name); + if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { + alloc = *a_ptr; } auto it = replacements.find({alloc, op}); if (it != replacements.end()) { diff --git a/src/StmtToHTML.cpp b/src/StmtToHTML.cpp index 9c317ba35525..79cf6563551e 100644 --- a/src/StmtToHTML.cpp +++ b/src/StmtToHTML.cpp @@ -1134,8 +1134,8 @@ class HTMLCodePrinter : public IRVisitor { std::string variable(const std::string &x, const std::string &tooltip) { int id; - if (scope.contains(x)) { - id = scope.get(x); + if (const int *i = scope.find(x)) { + id = *i; } else { id = gen_unique_id(); scope.push(x, id); diff --git a/src/StorageFlattening.cpp b/src/StorageFlattening.cpp index d7e7c50002f6..13d7d6475120 100644 --- a/src/StorageFlattening.cpp +++ b/src/StorageFlattening.cpp @@ -31,10 +31,9 @@ class ExpandExpr : public IRMutator { const Scope &scope; Expr visit(const Variable *var) override { - if (scope.contains(var->name)) { - Expr expr = scope.get(var->name); + if (const Expr *e = scope.find(var->name)) { // Mutate the expression, so lets can get replaced recursively. - expr = mutate(expr); + Expr expr = mutate(*e); debug(4) << "Fully expanded " << var->name << " -> " << expr << "\n"; return expr; } else { diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 26689ec34633..85a6ba521771 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -104,10 +104,9 @@ class UniquifyVariableNames : public IRMutator { } Expr visit(const Variable *op) override { - if (renaming.contains(op->name)) { - string new_name = renaming.get(op->name); - if (new_name != op->name) { - return Variable::make(op->type, new_name); + if (const string *new_name = renaming.find(op->name)) { + if (*new_name != op->name) { + return Variable::make(op->type, *new_name); } } return op; diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 6d10d2e9d5f3..0745a34a9d39 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -297,8 +297,8 @@ bool is_interleaved_ramp(const Expr &e, const Scope &scope, InterleavedRam return true; } } else if (const Variable *var = e.as()) { - if (scope.contains(var->name)) { - return is_interleaved_ramp(scope.get(var->name), scope, result); + if (const Expr *e = scope.find(var->name)) { + return is_interleaved_ramp(*e, scope, result); } } return false; From 72bcf1d77c13ca89c243d0b03d8735e6a18e622b Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Sat, 17 Feb 2024 19:35:50 -0800 Subject: [PATCH 2/4] Pacify clang-tidy --- src/Scope.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/Scope.h b/src/Scope.h index 6e3c4e57bec2..94d9eb9c165b 100644 --- a/src/Scope.h +++ b/src/Scope.h @@ -347,7 +347,7 @@ struct ScopedBinding { ScopedBinding(const ScopedBinding &that) = delete; ScopedBinding(ScopedBinding &&that) noexcept : scope(that.scope), - token(std::move(that.token)) { + token(that.token) { // The move constructor must null out scope, so we don't try to pop it that.scope = nullptr; } @@ -377,7 +377,7 @@ struct ScopedBinding { ScopedBinding(const ScopedBinding &that) = delete; ScopedBinding(ScopedBinding &&that) noexcept : scope(that.scope), - token(std::move(that.token)) { + token(that.token) { // The move constructor must null out scope, so we don't try to pop it that.scope = nullptr; } From 1f8c8b5274db789939c44fa94f13897fd065aa1a Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Mon, 19 Feb 2024 08:17:42 -0800 Subject: [PATCH 3/4] Fix unintentional mutation of interval in scope --- src/LowerWarpShuffles.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index a3499ace7e03..ad48c37db78f 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -475,10 +475,11 @@ class LowerWarpShuffles : public IRMutator { if ((lt && equal(lt->a, this_lane) && is_const(lt->b)) || (le && equal(le->a, this_lane) && is_const(le->b))) { Expr condition = mutate(op->condition); - Interval *interval = bounds.shallow_find(this_lane_name); - internal_assert(interval); - interval->max = lt ? simplify(lt->b - 1) : le->b; - ScopedBinding bind(bounds, this_lane_name, *interval); + const Interval *in = bounds.find(this_lane_name); + internal_assert(in); + Interval interval = *in; + interval.max = lt ? simplify(lt->b - 1) : le->b; + ScopedBinding bind(bounds, this_lane_name, interval); Stmt then_case = mutate(op->then_case); Stmt else_case = mutate(op->else_case); return IfThenElse::make(condition, then_case, else_case); From 8d59c7ccfdc0a22fe10eedfe865ff0db716d31b8 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 20 Feb 2024 13:39:48 -0800 Subject: [PATCH 4/4] Fix accidental Scope::get --- src/CodeGen_X86.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 51cf56fabbc0..0320e64b5ae5 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -882,7 +882,7 @@ void CodeGen_X86::visit(const Load *op) { void CodeGen_X86::visit(const Store *op) { if (const auto *mt = mem_type.find(op->name)) { - if (mem_type.get(op->name) == MemoryType::AMXTile) { + if (*mt == MemoryType::AMXTile) { Value *val = codegen(op->value); Halide::Type value_type = op->value.type(); const Ramp *ramp = op->index.as();