diff --git a/src/Bounds.cpp b/src/Bounds.cpp index a08bb0b9ad61..16fd69f3e8fb 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -406,13 +406,12 @@ class Bounds : public IRVisitor { if (const_bound) { bounds_of_type(op->type); - if (scope.contains(op->name)) { - const Interval &scope_interval = scope.get(op->name); - if (scope_interval.has_upper_bound() && is_const(scope_interval.max)) { - interval.max = Interval::make_min(interval.max, scope_interval.max); + if (const Interval *scope_interval = scope.find(op->name)) { + if (scope_interval->has_upper_bound() && is_const(scope_interval->max)) { + interval.max = Interval::make_min(interval.max, scope_interval->max); } - if (scope_interval.has_lower_bound() && is_const(scope_interval.min)) { - interval.min = Interval::make_max(interval.min, scope_interval.min); + if (scope_interval->has_lower_bound() && is_const(scope_interval->min)) { + interval.min = Interval::make_max(interval.min, scope_interval->min); } } @@ -429,8 +428,8 @@ class Bounds : public IRVisitor { } } } else { - if (scope.contains(op->name)) { - interval = scope.get(op->name); + if (const Interval *in = scope.find(op->name)) { + interval = *in; } else if (op->type.is_vector()) { // Uh oh, we need to take the min/max lane of some unknown vector. Treat as unbounded. bounds_of_type(op->type); @@ -2054,11 +2053,10 @@ class FindInnermostVar : public IRVisitor { int innermost_depth = -1; void visit(const Variable *op) override { - if (vars_depth.contains(op->name)) { - int depth = vars_depth.get(op->name); - if (depth > innermost_depth) { + if (const int *depth = vars_depth.find(op->name)) { + if (*depth > innermost_depth) { innermost_var = op->name; - innermost_depth = depth; + innermost_depth = *depth; } } } @@ -2545,16 +2543,17 @@ class BoxesTouched : public IRGraphVisitor { // If this let stmt is a redefinition of a previous one, we should // remove the old let stmt from the 'children' map since it is // no longer valid at this point. - if ((f.vi.instance > 0) && let_stmts.contains(op->name)) { - const Expr &val = let_stmts.get(op->name); - CollectVars collect(op->name); - val.accept(&collect); - f.old_let_vars = collect.vars; - - VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1); - for (const auto &v : f.old_let_vars) { - internal_assert(vars_renaming.count(v)); - children[get_var_instance(v)].erase(old_vi); + if (f.vi.instance > 0) { + if (const Expr *val = let_stmts.find(op->name)) { + CollectVars collect(op->name); + val->accept(&collect); + f.old_let_vars = collect.vars; + + VarInstance old_vi = VarInstance(f.vi.var, f.vi.instance - 1); + for (const auto &v : f.old_let_vars) { + internal_assert(vars_renaming.count(v)); + children[get_var_instance(v)].erase(old_vi); + } } } let_stmts.push(op->name, op->value); @@ -2756,17 +2755,17 @@ class BoxesTouched : public IRGraphVisitor { expr_uses_var(box[i].min, l.min_name))) || (box[i].has_upper_bound() && (expr_uses_var(box[i].max, l.max_name) || expr_uses_var(box[i].max, l.min_name)))) { - internal_assert(let_stmts.contains(l.var)); - const Expr &val = let_stmts.get(l.var); - v_bound = bounds_of_expr_in_scope(val, scope, func_bounds); + const Expr *val = let_stmts.find(l.var); + internal_assert(val); + v_bound = bounds_of_expr_in_scope(*val, scope, func_bounds); bool fixed = v_bound.min.same_as(v_bound.max); v_bound.min = simplify(v_bound.min); v_bound.max = fixed ? v_bound.min : simplify(v_bound.max); - internal_assert(scope.contains(l.var)); - const Interval &old_bound = scope.get(l.var); - v_bound.max = simplify(min(v_bound.max, old_bound.max)); - v_bound.min = simplify(max(v_bound.min, old_bound.min)); + const Interval *old_bound = scope.find(l.var); + internal_assert(old_bound); + v_bound.max = simplify(min(v_bound.max, old_bound->max)); + v_bound.min = simplify(max(v_bound.min, old_bound->min)); } if (box[i].has_lower_bound()) { @@ -3017,14 +3016,14 @@ class BoxesTouched : public IRGraphVisitor { } Expr min_val, max_val; - if (scope.contains(op->name + ".loop_min")) { - min_val = scope.get(op->name + ".loop_min").min; + if (const Interval *in = scope.find(op->name + ".loop_min")) { + min_val = in->min; } else { min_val = bounds_of_expr_in_scope(op->min, scope, func_bounds).min; } - if (scope.contains(op->name + ".loop_max")) { - max_val = scope.get(op->name + ".loop_max").max; + if (const Interval *in = scope.find(op->name + ".loop_max")) { + max_val = in->max; } else { max_val = bounds_of_expr_in_scope(op->extent, scope, func_bounds).max; max_val += bounds_of_expr_in_scope(op->min, scope, func_bounds).max; diff --git a/src/CSE.cpp b/src/CSE.cpp index 7d39fcc90dc5..d8ecd619db81 100644 --- a/src/CSE.cpp +++ b/src/CSE.cpp @@ -201,8 +201,8 @@ class RemoveLets : public IRGraphMutator { Scope scope; Expr visit(const Variable *op) override { - if (scope.contains(op->name)) { - return scope.get(op->name); + if (const Expr *e = scope.find(op->name)) { + return *e; } else { return op; } diff --git a/src/ClampUnsafeAccesses.cpp b/src/ClampUnsafeAccesses.cpp index 5e2e1f5d5b2e..b3dd9ddc235e 100644 --- a/src/ClampUnsafeAccesses.cpp +++ b/src/ClampUnsafeAccesses.cpp @@ -50,8 +50,10 @@ struct ClampUnsafeAccesses : IRMutator { } Expr visit(const Variable *var) override { - if (is_inside_indexing && let_var_inside_indexing.contains(var->name)) { - let_var_inside_indexing.ref(var->name) = true; + if (is_inside_indexing) { + if (bool *b = let_var_inside_indexing.shallow_find(var->name)) { + *b = true; + } } return var; } diff --git a/src/CodeGen_ARM.cpp b/src/CodeGen_ARM.cpp index 9c6525703f16..7852532183bf 100644 --- a/src/CodeGen_ARM.cpp +++ b/src/CodeGen_ARM.cpp @@ -82,13 +82,14 @@ class SubstituteInStridedLoads : public IRMutator { Expr visit(const Shuffle *op) override { int stride = op->slice_stride(); const Variable *var = op->vectors[0].as(); + const Expr *vec = nullptr; if (var && poisoned_vars.count(var->name) == 0 && op->vectors.size() == 1 && 2 <= stride && stride <= 4 && op->slice_begin() < stride && - loads.contains(var->name)) { - return Shuffle::make_slice({loads.get(var->name)}, op->slice_begin(), op->slice_stride(), op->type.lanes()); + (vec = loads.find(var->name))) { + return Shuffle::make_slice({*vec}, op->slice_begin(), op->slice_stride(), op->type.lanes()); } else { return IRMutator::visit(op); } diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index 89c18cb8ab28..b0cdcb3e956c 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -1936,8 +1936,9 @@ void CodeGen_C::visit(const Load *op) { user_assert(is_const_one(op->predicate)) << "Predicated scalar load is not supported by C backend.\n"; string id_index = print_expr(op->index); - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type.element_of() == t.element_of()); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && + alloc->type.element_of() == t.element_of()); if (type_cast_needed) { const char *const_flag = output_kind == CPlusPlusImplementation ? " const" : ""; rhs << "((" << print_type(t.element_of()) << const_flag << " *)" << name << ")"; diff --git a/src/CodeGen_D3D12Compute_Dev.cpp b/src/CodeGen_D3D12Compute_Dev.cpp index c8e45ea2ae09..4fd614cc0dfc 100644 --- a/src/CodeGen_D3D12Compute_Dev.cpp +++ b/src/CodeGen_D3D12Compute_Dev.cpp @@ -592,8 +592,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Load *op) { string id_index = print_expr(op->index); // Get the rhs just for the cache. - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == op->type); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && + alloc->type == op->type); ostringstream rhs; if (type_cast_needed) { diff --git a/src/CodeGen_Hexagon.cpp b/src/CodeGen_Hexagon.cpp index 9463a4c921aa..a77e9c7c1a76 100644 --- a/src/CodeGen_Hexagon.cpp +++ b/src/CodeGen_Hexagon.cpp @@ -221,8 +221,8 @@ class SloppyUnpredicateLoadsAndStores : public IRMutator { } } } else if (const Variable *op = e.as()) { - if (monotonic_vectors.contains(op->name)) { - return monotonic_vectors.get(op->name); + if (const auto *p = monotonic_vectors.find(op->name)) { + return *p; } } else if (const Let *op = e.as()) { auto v = get_extreme_lanes(op->value); @@ -2245,10 +2245,9 @@ void CodeGen_Hexagon::visit(const Allocate *alloc) { codegen(alloc->body); // If there was no early free, free it now. - if (allocations.contains(alloc->name)) { - Allocation alloc_obj = allocations.get(alloc->name); - internal_assert(alloc_obj.destructor); - trigger_destructor(alloc_obj.destructor_function, alloc_obj.destructor); + if (const Allocation *alloc_obj = allocations.find(alloc->name)) { + internal_assert(alloc_obj->destructor); + trigger_destructor(alloc_obj->destructor_function, alloc_obj->destructor); allocations.pop(alloc->name); sym_pop(alloc->name); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index a5c32cf83cc7..8922461524c5 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -1268,7 +1268,8 @@ void CodeGen_LLVM::sym_pop(const string &name) { llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const { // look in the symbol table - if (!symbol_table.contains(name)) { + llvm::Value *const *v = symbol_table.find(name); + if (!v) { if (must_succeed) { std::ostringstream err; err << "Symbol not found: " << name << "\n"; @@ -1283,7 +1284,7 @@ llvm::Value *CodeGen_LLVM::sym_get(const string &name, bool must_succeed) const return nullptr; } } - return symbol_table.get(name); + return *v; } bool CodeGen_LLVM::sym_exists(const string &name) const { diff --git a/src/CodeGen_Metal_Dev.cpp b/src/CodeGen_Metal_Dev.cpp index 69d47279e9ae..79060294798e 100644 --- a/src/CodeGen_Metal_Dev.cpp +++ b/src/CodeGen_Metal_Dev.cpp @@ -390,8 +390,9 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Load *op) { string id_index = print_expr(op->index); // Get the rhs just for the cache. - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == op->type); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && + alloc->type == op->type); ostringstream rhs; if (type_cast_needed) { rhs << "((" << get_memory_space(op->name) << " " @@ -467,8 +468,8 @@ void CodeGen_Metal_Dev::CodeGen_Metal_C::visit(const Store *op) { << id_value << "[" << i << "];\n"; } } else { - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == t); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && alloc->type == t); string id_index = print_expr(op->index); stream << get_indent(); diff --git a/src/CodeGen_OpenCL_Dev.cpp b/src/CodeGen_OpenCL_Dev.cpp index 52feed53f9e0..c86e483cc5a8 100644 --- a/src/CodeGen_OpenCL_Dev.cpp +++ b/src/CodeGen_OpenCL_Dev.cpp @@ -484,8 +484,8 @@ string CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::print_array_access(const string &na const Type &type, const string &id_index) { ostringstream rhs; - bool type_cast_needed = !(allocations.contains(name) && - allocations.get(name).type == type); + const auto *alloc = allocations.find(name); + bool type_cast_needed = !(alloc && alloc->type == type); if (type_cast_needed) { rhs << "((" << get_memory_space(name) << " " @@ -583,8 +583,8 @@ void CodeGen_OpenCL_Dev::CodeGen_OpenCL_C::visit(const Store *op) { // For atomicAdd, we check if op->value - store[index] is independent of store. // The atomicAdd operations in OpenCL only supports integers so we also check that. bool is_atomic_add = t.is_int_or_uint() && !expr_uses_var(delta, op->name); - bool type_cast_needed = !(allocations.contains(op->name) && - allocations.get(op->name).type == t); + const auto *alloc = allocations.find(op->name); + bool type_cast_needed = !(alloc && alloc->type == t); auto print_store_var = [&]() { if (type_cast_needed) { stream << "((" diff --git a/src/CodeGen_Posix.cpp b/src/CodeGen_Posix.cpp index af508194b06e..f812b63cce9d 100644 --- a/src/CodeGen_Posix.cpp +++ b/src/CodeGen_Posix.cpp @@ -342,8 +342,8 @@ void CodeGen_Posix::free_allocation(const std::string &name) { } string CodeGen_Posix::get_allocation_name(const std::string &n) { - if (allocations.contains(n)) { - return allocations.get(n).name; + if (const auto *alloc = allocations.find(n)) { + return alloc->name; } else { return n; } diff --git a/src/CodeGen_Vulkan_Dev.cpp b/src/CodeGen_Vulkan_Dev.cpp index 61b365f2f7aa..39dd65b67671 100644 --- a/src/CodeGen_Vulkan_Dev.cpp +++ b/src/CodeGen_Vulkan_Dev.cpp @@ -1539,10 +1539,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Load *op) { user_assert(is_const_one(op->predicate)) << "Predicated loads not supported by SPIR-V codegen\n"; // Construct the pointer to read from - internal_assert(symbol_table.contains(op->name)); - SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name); - SpvId variable_id = id_and_storage_class.first; - SpvStorageClass storage_class = id_and_storage_class.second; + const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name); + internal_assert(id_and_storage_class); + SpvId variable_id = id_and_storage_class->first; + SpvStorageClass storage_class = id_and_storage_class->second; internal_assert(variable_id != SpvInvalidId); internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax)); @@ -1576,10 +1576,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Store *op) { op->value.accept(this); SpvId value_id = builder.current_id(); - internal_assert(symbol_table.contains(op->name)); - SymbolIdStorageClassPair id_and_storage_class = symbol_table.get(op->name); - SpvId variable_id = id_and_storage_class.first; - SpvStorageClass storage_class = id_and_storage_class.second; + const SymbolIdStorageClassPair *id_and_storage_class = symbol_table.find(op->name); + internal_assert(id_and_storage_class); + SpvId variable_id = id_and_storage_class->first; + SpvStorageClass storage_class = id_and_storage_class->second; internal_assert(variable_id != SpvInvalidId); internal_assert(((uint32_t)storage_class) < ((uint32_t)SpvStorageClassMax)); @@ -1665,9 +1665,10 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const For *op) { const std::string intrinsic_var_name = std::string("k") + std::to_string(kernel_index) + std::string("_") + intrinsic.first; // Intrinsics are inserted when adding the kernel - internal_assert(symbol_table.contains(intrinsic_var_name)); - SpvId intrinsic_id = symbol_table.get(intrinsic_var_name).first; - SpvStorageClass storage_class = symbol_table.get(intrinsic_var_name).second; + const auto *intrin = symbol_table.find(intrinsic_var_name); + internal_assert(intrin); + SpvId intrinsic_id = intrin->first; + SpvStorageClass storage_class = intrin->second; // extract and cast to the extent type (which is what's expected by Halide's for loops) Type unsigned_type = UInt(32); @@ -1908,8 +1909,9 @@ void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Allocate *op) { void CodeGen_Vulkan_Dev::SPIRV_Emitter::visit(const Free *op) { debug(3) << "Vulkan: Popping allocation called " << op->name << " off the symbol table\n"; - internal_assert(symbol_table.contains(op->name)); - SpvId variable_id = symbol_table.get(op->name).first; + const auto *id = symbol_table.find(op->name); + internal_assert(id); + SpvId variable_id = id->first; storage_access_map.erase(variable_id); symbol_table.pop(op->name); } diff --git a/src/CodeGen_WebGPU_Dev.cpp b/src/CodeGen_WebGPU_Dev.cpp index 08d3a542f41b..de55113ff695 100644 --- a/src/CodeGen_WebGPU_Dev.cpp +++ b/src/CodeGen_WebGPU_Dev.cpp @@ -684,8 +684,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Load *op) { // Get the allocation type, which may be different from the result type. Type alloc_type = result_type; - if (allocations.contains(op->name)) { - alloc_type = allocations.get(op->name).type; + if (const auto *alloc = allocations.find(op->name)) { + alloc_type = alloc->type; } else if (workgroup_allocations.count(op->name)) { alloc_type = workgroup_allocations.at(op->name)->type; } @@ -826,8 +826,8 @@ void CodeGen_WebGPU_Dev::CodeGen_WGSL::visit(const Store *op) { // Get the allocation type, which may be different from the value type. Type alloc_type = value_type; - if (allocations.contains(op->name)) { - alloc_type = allocations.get(op->name).type; + if (const auto *alloc = allocations.find(op->name)) { + alloc_type = alloc->type; } else if (workgroup_allocations.count(op->name)) { alloc_type = workgroup_allocations.at(op->name)->type; } diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 8d87f4c1937e..0320e64b5ae5 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -866,28 +866,32 @@ void CodeGen_X86::visit(const Allocate *op) { } void CodeGen_X86::visit(const Load *op) { - if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { - const Ramp *ramp = op->index.as(); - internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; - Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); - LoadInst *load = builder->CreateAlignedLoad(llvm_type_of(upgrade_type_for_storage(op->type)), ptr, llvm::Align(op->type.bytes())); - add_tbaa_metadata(load, op->name, op->index); - value = load; - return; + if (const auto *mt = mem_type.find(op->name)) { + if (*mt == MemoryType::AMXTile) { + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, op->type, ramp->base); + LoadInst *load = builder->CreateAlignedLoad(llvm_type_of(upgrade_type_for_storage(op->type)), ptr, llvm::Align(op->type.bytes())); + add_tbaa_metadata(load, op->name, op->index); + value = load; + return; + } } CodeGen_Posix::visit(op); } void CodeGen_X86::visit(const Store *op) { - if (mem_type.contains(op->name) && mem_type.get(op->name) == MemoryType::AMXTile) { - Value *val = codegen(op->value); - Halide::Type value_type = op->value.type(); - const Ramp *ramp = op->index.as(); - internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; - Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base); - StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); - add_tbaa_metadata(store, op->name, op->index); - return; + if (const auto *mt = mem_type.find(op->name)) { + if (*mt == MemoryType::AMXTile) { + Value *val = codegen(op->value); + Halide::Type value_type = op->value.type(); + const Ramp *ramp = op->index.as(); + internal_assert(ramp) << "Expected AMXTile to have index ramp\n"; + Value *ptr = codegen_buffer_pointer(op->name, value_type, ramp->base); + StoreInst *store = builder->CreateAlignedStore(val, ptr, llvm::Align(value_type.bytes())); + add_tbaa_metadata(store, op->name, op->index); + return; + } } CodeGen_Posix::visit(op); } diff --git a/src/EliminateBoolVectors.cpp b/src/EliminateBoolVectors.cpp index cebfe0f0019b..62cdbdbef5b5 100644 --- a/src/EliminateBoolVectors.cpp +++ b/src/EliminateBoolVectors.cpp @@ -15,8 +15,8 @@ class EliminateBoolVectors : public IRMutator { Scope lets; Expr visit(const Variable *op) override { - if (lets.contains(op->name)) { - return Variable::make(lets.get(op->name), op->name); + if (const Type *t = lets.find(op->name)) { + return Variable::make(*t, op->name); } else { return op; } diff --git a/src/ExprUsesVar.h b/src/ExprUsesVar.h index 3bf129d259f7..84c3f7ae23d4 100644 --- a/src/ExprUsesVar.h +++ b/src/ExprUsesVar.h @@ -36,8 +36,8 @@ class ExprUsesVars : public IRGraphVisitor { void visit_name(const std::string &name) { if (vars.contains(name)) { result = true; - } else if (scope.contains(name)) { - include(scope.get(name)); + } else if (const Expr *e = scope.find(name)) { + IRGraphVisitor::include(*e); } } diff --git a/src/FindIntrinsics.cpp b/src/FindIntrinsics.cpp index febd88d2399b..d453d0134c29 100644 --- a/src/FindIntrinsics.cpp +++ b/src/FindIntrinsics.cpp @@ -1118,8 +1118,8 @@ class SubstituteInWideningLets : public IRMutator { Scope replacements; Expr visit(const Variable *op) override { - if (replacements.contains(op->name)) { - return replacements.get(op->name); + if (const Expr *e = replacements.find(op->name)) { + return *e; } else { return op; } diff --git a/src/FuseGPUThreadLoops.cpp b/src/FuseGPUThreadLoops.cpp index ef5a75344bb8..abde50d62e1f 100644 --- a/src/FuseGPUThreadLoops.cpp +++ b/src/FuseGPUThreadLoops.cpp @@ -1140,21 +1140,21 @@ class ExtractRegisterAllocations : public IRMutator { } Expr visit(const Load *op) override { - string new_name = op->name; - if (alloc_renaming.contains(op->name)) { - new_name = alloc_renaming.get(op->name); + const string *new_name = alloc_renaming.find(op->name); + if (!new_name) { + new_name = &(op->name); } - return Load::make(op->type, new_name, mutate(op->index), + return Load::make(op->type, *new_name, mutate(op->index), op->image, op->param, mutate(op->predicate), op->alignment); } Stmt visit(const Store *op) override { - string new_name = op->name; - if (alloc_renaming.contains(op->name)) { - new_name = alloc_renaming.get(op->name); + const string *new_name = alloc_renaming.find(op->name); + if (!new_name) { + new_name = &(op->name); } - return Store::make(new_name, mutate(op->value), mutate(op->index), + return Store::make(*new_name, mutate(op->value), mutate(op->index), op->param, mutate(op->predicate), op->alignment); } diff --git a/src/HexagonOptimize.cpp b/src/HexagonOptimize.cpp index b76a9eb1cfef..deabd95d1d1b 100644 --- a/src/HexagonOptimize.cpp +++ b/src/HexagonOptimize.cpp @@ -1357,8 +1357,8 @@ class EliminateInterleaves : public IRMutator { } if (const Load *load = x.as()) { - if (buffers.contains(load->name)) { - return buffers.get(load->name) != BufferState::NotInterleaved; + if (const auto *state = buffers.find(load->name)) { + return *state != BufferState::NotInterleaved; } } @@ -1398,8 +1398,8 @@ class EliminateInterleaves : public IRMutator { } if (const Load *load = x.as()) { - if (buffers.contains(load->name)) { - return buffers.get(load->name) != BufferState::NotInterleaved; + if (const auto *state = buffers.find(load->name)) { + return *state != BufferState::NotInterleaved; } } @@ -1816,34 +1816,33 @@ class EliminateInterleaves : public IRMutator { Expr value = mutate(op->value); Expr index = mutate(op->index); - if (buffers.contains(op->name)) { + if (BufferState *state = buffers.shallow_find(op->name)) { // When inspecting the stores to a buffer, update the state. - BufferState &state = buffers.ref(op->name); if (!is_const_one(predicate) || !op->value.type().is_vector()) { // TODO(psuriana): This store is predicated. Mark the buffer as // not interleaved for now. - state = BufferState::NotInterleaved; + *state = BufferState::NotInterleaved; } else if (yields_removable_interleave(value)) { // The value yields a removable interleave. If we aren't tracking // this buffer, mark it as interleaved. - if (state == BufferState::Unknown) { - state = BufferState::Interleaved; + if (*state == BufferState::Unknown) { + *state = BufferState::Interleaved; } } else if (!yields_interleave(value)) { // The value does not yield an interleave. Mark the // buffer as not interleaved. - state = BufferState::NotInterleaved; + *state = BufferState::NotInterleaved; } else { // If the buffer yields an interleave, but is not an // interleave itself, we don't want to change the // buffer state. } - internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope"); - bool &aligned_accesses = aligned_buffer_access.ref(op->name); + bool *aligned_accesses = aligned_buffer_access.shallow_find(op->name); + internal_assert(aligned_accesses) << "Buffer not found in scope"; int64_t aligned_offset = 0; if (!alignment_analyzer.is_aligned(op, &aligned_offset)) { - aligned_accesses = false; + *aligned_accesses = false; } } if (deinterleave_buffers.contains(op->name)) { @@ -1872,12 +1871,13 @@ class EliminateInterleaves : public IRMutator { // which is only true if any of the stores are // actually interleaved (and don't just yield an // interleave). - internal_assert(aligned_buffer_access.contains(op->name) && "Buffer not found in scope"); - bool &aligned_accesses = aligned_buffer_access.ref(op->name); + bool *aligned_accesses = aligned_buffer_access.shallow_find(op->name); + internal_assert(aligned_accesses) << "Buffer not found in scope"; + int64_t aligned_offset = 0; if (!alignment_analyzer.is_aligned(op, &aligned_offset)) { - aligned_accesses = false; + *aligned_accesses = false; } } else { // This is not a double vector load, so we can't diff --git a/src/LICM.cpp b/src/LICM.cpp index 641f4982a3e2..719b41442cfc 100644 --- a/src/LICM.cpp +++ b/src/LICM.cpp @@ -350,8 +350,8 @@ class GroupLoopInvariants : public IRMutator { const Scope &depth; void visit(const Variable *op) override { - if (depth.contains(op->name)) { - result = std::max(result, depth.get(op->name)); + if (const int *d = depth.find(op->name)) { + result = std::max(result, *d); } } diff --git a/src/LoopCarry.cpp b/src/LoopCarry.cpp index 050cdfbfc8d9..bfc2abc8ddf1 100644 --- a/src/LoopCarry.cpp +++ b/src/LoopCarry.cpp @@ -27,8 +27,8 @@ Expr is_linear(const Expr &e, const Scope &linear) { return Expr(); } if (const Variable *v = e.as()) { - if (linear.contains(v->name)) { - return linear.get(v->name); + if (const Expr *e = linear.find(v->name)) { + return *e; } else { return make_zero(v->type); } @@ -140,18 +140,17 @@ class StepForwards : public IRGraphMutator { using IRGraphMutator::visit; Expr visit(const Variable *op) override { - if (linear.contains(op->name)) { - Expr step = linear.get(op->name); - if (!step.defined()) { + if (const Expr *step = linear.find(op->name)) { + if (!step->defined()) { // It's non-linear success = false; return op; - } else if (is_const_zero(step)) { + } else if (is_const_zero(*step)) { // It's a known inner constant return op; } else { // It's linear - return Expr(op) + step; + return Expr(op) + *step; } } else { // It's some external constant diff --git a/src/LowerWarpShuffles.cpp b/src/LowerWarpShuffles.cpp index 79332c9336e5..ad48c37db78f 100644 --- a/src/LowerWarpShuffles.cpp +++ b/src/LowerWarpShuffles.cpp @@ -149,8 +149,8 @@ class DetermineAllocStride : public IRVisitor { } else if (const Variable *var = e.as()) { if (var->name == lane_var) { return 1; - } else if (dependent_vars.contains(var->name)) { - return dependent_vars.get(var->name); + } else if (const Expr *e = dependent_vars.find(var->name)) { + return *e; } else { return 0; } @@ -475,8 +475,9 @@ class LowerWarpShuffles : public IRMutator { if ((lt && equal(lt->a, this_lane) && is_const(lt->b)) || (le && equal(le->a, this_lane) && is_const(le->b))) { Expr condition = mutate(op->condition); - internal_assert(bounds.contains(this_lane_name)); - Interval interval = bounds.get(this_lane_name); + const Interval *in = bounds.find(this_lane_name); + internal_assert(in); + Interval interval = *in; interval.max = lt ? simplify(lt->b - 1) : le->b; ScopedBinding bind(bounds, this_lane_name, interval); Stmt then_case = mutate(op->then_case); @@ -488,10 +489,10 @@ class LowerWarpShuffles : public IRMutator { } Stmt visit(const Store *op) override { - if (allocation_info.contains(op->name)) { + if (const auto *alloc = allocation_info.find(op->name)) { Expr idx = mutate(op->index); Expr value = mutate(op->value); - Expr stride = allocation_info.get(op->name).stride; + Expr stride = alloc->stride; internal_assert(stride.defined() && warp_size.defined()); // Reduce the index to an index in my own stripe. We have @@ -639,9 +640,9 @@ class LowerWarpShuffles : public IRMutator { } Expr visit(const Load *op) override { - if (allocation_info.contains(op->name)) { + if (const auto *alloc = allocation_info.find(op->name)) { Expr idx = mutate(op->index); - Expr stride = allocation_info.get(op->name).stride; + Expr stride = alloc->stride; // Break the index into lane and stripe components Expr lane = simplify(reduce_expr(idx / stride, warp_size, bounds), true, bounds); diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index cfccce1da786..13b3c72a181d 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -110,8 +110,8 @@ void ComputeModulusRemainder::visit(const Reinterpret *) { } void ComputeModulusRemainder::visit(const Variable *op) { - if (scope.contains(op->name)) { - result = scope.get(op->name); + if (const auto *m = scope.find(op->name)) { + result = *m; } else { result = ModulusRemainder{}; } diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index dd8e17d5b177..fee151f00a22 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -280,8 +280,8 @@ class DerivativeBounds : public IRVisitor { void visit(const Variable *op) override { if (op->name == var) { result = ConstantInterval::single_point(1); - } else if (scope.contains(op->name)) { - result = scope.get(op->name); + } else if (const auto *r = scope.find(op->name)) { + result = *r; } else { result = ConstantInterval::single_point(0); } diff --git a/src/Prefetch.cpp b/src/Prefetch.cpp index c0fb1f5c9a64..144b1950c5cd 100644 --- a/src/Prefetch.cpp +++ b/src/Prefetch.cpp @@ -86,10 +86,9 @@ class InjectPrefetch : public IRMutator { using IRMutator::visit; Box get_buffer_bounds(const string &name, int dims) { - if (buffer_bounds.contains(name)) { - const Box &b = buffer_bounds.ref(name); - internal_assert((int)b.size() == dims); - return b; + if (const Box *b = buffer_bounds.find(name)) { + internal_assert((int)b->size() == dims); + return *b; } // It is an external buffer. diff --git a/src/PrintLoopNest.cpp b/src/PrintLoopNest.cpp index 52f1c319951a..9d38efaaf80a 100644 --- a/src/PrintLoopNest.cpp +++ b/src/PrintLoopNest.cpp @@ -94,12 +94,16 @@ class PrintLoopNest : public IRVisitor { Expr min_val = op->min, extent_val = op->extent; const Variable *min_var = min_val.as(); const Variable *extent_var = extent_val.as(); - if (min_var && constants.contains(min_var->name)) { - min_val = constants.get(min_var->name); + if (min_var) { + if (const Expr *e = constants.find(min_var->name)) { + min_val = *e; + } } - if (extent_var && constants.contains(extent_var->name)) { - extent_val = constants.get(extent_var->name); + if (extent_var) { + if (const Expr *e = constants.find(extent_var->name)) { + extent_val = *e; + } } if (extent_val.defined() && is_const(extent_val) && @@ -151,9 +155,8 @@ class PrintLoopNest : public IRVisitor { void visit(const LetStmt *op) override { if (is_const(op->value)) { - constants.push(op->name, op->value); + ScopedBinding bind(constants, op->name, op->value); op->body.accept(this); - constants.pop(op->name); } else { op->body.accept(this); } diff --git a/src/Scope.h b/src/Scope.h index 9d1cc43e1164..94d9eb9c165b 100644 --- a/src/Scope.h +++ b/src/Scope.h @@ -150,7 +150,39 @@ class Scope { return iter->second.top_ref(); } - /** Tests if a name is in scope */ + /** Returns a const pointer to an entry if it exists in this scope or any + * containing scope, or nullptr if it does not. Use this instead of if + * (scope.contains(foo)) { ... scope.get(foo) ... } to avoid doing two + * lookups. */ + template::value>::type> + const T2 *find(const std::string &name) const { + typename std::map>::const_iterator iter = table.find(name); + if (iter == table.end() || iter->second.empty()) { + if (containing_scope) { + return containing_scope->find(name); + } else { + return nullptr; + } + } + return &(iter->second.top_ref()); + } + + /** A version of find that returns a non-const pointer, but ignores + * containing scope. */ + template::value>::type> + T2 *shallow_find(const std::string &name) { + typename std::map>::iterator iter = table.find(name); + if (iter == table.end() || iter->second.empty()) { + return nullptr; + } else { + return &(iter->second.top_ref()); + } + } + + /** Tests if a name is in scope. If you plan to use the value if it is, call + * find instead. */ bool contains(const std::string &name) const { typename std::map>::const_iterator iter = table.find(name); if (iter == table.end() || iter->second.empty()) { @@ -173,19 +205,28 @@ class Scope { } } - /** Add a new (name, value) pair to the current scope. Hide old - * values that have this name until we pop this name. + struct PushToken { + typename std::map>::iterator iter; + }; + + /** Add a new (name, value) pair to the current scope. Hide old values that + * have this name until we pop this name. Returns a token that can be used + * to pop the same value without doing a fresh lookup. */ template::value>::type> - void push(const std::string &name, T2 &&value) { - table[name].push(std::forward(value)); + PushToken push(const std::string &name, T2 &&value) { + auto it = table.try_emplace(name).first; + it->second.push(std::forward(value)); + return PushToken{it}; } template::value>::type> - void push(const std::string &name) { - table[name].push(); + PushToken push(const std::string &name) { + auto it = table.try_emplace(name).first; + it->second.push(); + return PushToken{it}; } /** A name goes out of scope. Restore whatever its old value @@ -201,6 +242,14 @@ class Scope { } } + /** Pop a name using a token returned by push instead of a string. */ + void pop(PushToken p) { + p.iter->second.pop(); + if (p.iter->second.empty()) { + table.erase(p.iter); + } + } + /** Iterate through the scope. Does not capture any containing scope. */ class const_iterator { typename std::map>::const_iterator iter; @@ -271,20 +320,17 @@ std::ostream &operator<<(std::ostream &stream, const Scope &s) { template struct ScopedBinding { Scope *scope = nullptr; - std::string name; + typename Scope::PushToken token; ScopedBinding() = default; ScopedBinding(Scope &s, const std::string &n, T value) - : scope(&s), name(n) { - scope->push(name, std::move(value)); + : scope(&s), token(scope->push(n, std::move(value))) { } ScopedBinding(bool condition, Scope &s, const std::string &n, const T &value) - : scope(condition ? &s : nullptr), name(n) { - if (condition) { - scope->push(name, value); - } + : scope(condition ? &s : nullptr), + token(condition ? scope->push(n, value) : typename Scope::PushToken{}) { } bool bound() const { @@ -293,7 +339,7 @@ struct ScopedBinding { ~ScopedBinding() { if (scope) { - scope->pop(name); + scope->pop(token); } } @@ -301,7 +347,7 @@ struct ScopedBinding { ScopedBinding(const ScopedBinding &that) = delete; ScopedBinding(ScopedBinding &&that) noexcept : scope(that.scope), - name(std::move(that.name)) { + token(that.token) { // The move constructor must null out scope, so we don't try to pop it that.scope = nullptr; } @@ -313,20 +359,17 @@ struct ScopedBinding { template<> struct ScopedBinding { Scope<> *scope; - std::string name; + Scope<>::PushToken token; ScopedBinding(Scope<> &s, const std::string &n) - : scope(&s), name(n) { - scope->push(name); + : scope(&s), token(scope->push(n)) { } ScopedBinding(bool condition, Scope<> &s, const std::string &n) - : scope(condition ? &s : nullptr), name(n) { - if (condition) { - scope->push(name); - } + : scope(condition ? &s : nullptr), + token(condition ? scope->push(n) : Scope<>::PushToken{}) { } ~ScopedBinding() { if (scope) { - scope->pop(name); + scope->pop(token); } } @@ -334,7 +377,7 @@ struct ScopedBinding { ScopedBinding(const ScopedBinding &that) = delete; ScopedBinding(ScopedBinding &&that) noexcept : scope(that.scope), - name(std::move(that.name)) { + token(that.token) { // The move constructor must null out scope, so we don't try to pop it that.scope = nullptr; } diff --git a/src/Simplify.cpp b/src/Simplify.cpp index 339ef2917c83..61cf7886cb70 100644 --- a/src/Simplify.cpp +++ b/src/Simplify.cpp @@ -34,8 +34,8 @@ Simplify::Simplify(bool r, const Scope *bi, const Scopecontains(iter.name())) { - bounds.alignment = ai->get(iter.name()); + if (const auto *a = ai->find(iter.name())) { + bounds.alignment = *a; } if (bounds.min_defined || bounds.max_defined || bounds.alignment.modulus != 1) { @@ -74,18 +74,18 @@ std::pair, bool> Simplify::mutate_with_changes(const std::vect void Simplify::found_buffer_reference(const string &name, size_t dimensions) { for (size_t i = 0; i < dimensions; i++) { string stride = name + ".stride." + std::to_string(i); - if (var_info.contains(stride)) { - var_info.ref(stride).old_uses++; + if (auto *info = var_info.shallow_find(stride)) { + info->old_uses++; } string min = name + ".min." + std::to_string(i); - if (var_info.contains(min)) { - var_info.ref(min).old_uses++; + if (auto *info = var_info.shallow_find(min)) { + info->old_uses++; } } - if (var_info.contains(name)) { - var_info.ref(name).old_uses++; + if (auto *info = var_info.shallow_find(name)) { + info->old_uses++; } } @@ -187,8 +187,8 @@ void Simplify::ScopedFact::learn_upper_bound(const Variable *v, int64_t val) { ExprInfo b; b.max_defined = true; b.max = val; - if (simplify->bounds_and_alignment_info.contains(v->name)) { - b.intersect(simplify->bounds_and_alignment_info.get(v->name)); + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { + b.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, b); bounds_pop_list.push_back(v); @@ -198,8 +198,8 @@ void Simplify::ScopedFact::learn_lower_bound(const Variable *v, int64_t val) { ExprInfo b; b.min_defined = true; b.min = val; - if (simplify->bounds_and_alignment_info.contains(v->name)) { - b.intersect(simplify->bounds_and_alignment_info.get(v->name)); + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { + b.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, b); bounds_pop_list.push_back(v); @@ -228,10 +228,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { // TODO: Visiting it again is inefficient Simplify::ExprInfo expr_info; simplify->mutate(eq->b, &expr_info); - if (simplify->bounds_and_alignment_info.contains(v->name)) { + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { // We already know something about this variable and don't want to suppress it. - auto existing_knowledge = simplify->bounds_and_alignment_info.get(v->name); - expr_info.intersect(existing_knowledge); + expr_info.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, expr_info); bounds_pop_list.push_back(v); @@ -245,10 +244,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { // TODO: Visiting it again is inefficient Simplify::ExprInfo expr_info; simplify->mutate(eq->a, &expr_info); - if (simplify->bounds_and_alignment_info.contains(vb->name)) { + if (const auto *info = simplify->bounds_and_alignment_info.find(vb->name)) { // We already know something about this variable and don't want to suppress it. - auto existing_knowledge = simplify->bounds_and_alignment_info.get(vb->name); - expr_info.intersect(existing_knowledge); + expr_info.intersect(*info); } simplify->bounds_and_alignment_info.push(vb->name, expr_info); bounds_pop_list.push_back(vb); @@ -257,10 +255,9 @@ void Simplify::ScopedFact::learn_true(const Expr &fact) { Simplify::ExprInfo expr_info; expr_info.alignment.modulus = *modulus; expr_info.alignment.remainder = *remainder; - if (simplify->bounds_and_alignment_info.contains(v->name)) { + if (const auto *info = simplify->bounds_and_alignment_info.find(v->name)) { // We already know something about this variable and don't want to suppress it. - auto existing_knowledge = simplify->bounds_and_alignment_info.get(v->name); - expr_info.intersect(existing_knowledge); + expr_info.intersect(*info); } simplify->bounds_and_alignment_info.push(v->name, expr_info); bounds_pop_list.push_back(v); @@ -417,8 +414,8 @@ bool can_prove(Expr e, const Scope &bounds) { Expr visit(const Variable *op) override { auto it = vars.find(op->name); - if (lets.contains(op->name)) { - return Variable::make(op->type, lets.get(op->name)); + if (const std::string *n = lets.find(op->name)) { + return Variable::make(op->type, *n); } else if (it == vars.end()) { std::string name = "v" + std::to_string(count++); vars[op->name] = name; diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index a8e5fcce1a8d..b5fcc96ac0cd 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -221,35 +221,32 @@ Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { } Expr Simplify::visit(const Variable *op, ExprInfo *bounds) { - if (bounds_and_alignment_info.contains(op->name)) { - const ExprInfo &b = bounds_and_alignment_info.get(op->name); + if (const ExprInfo *b = bounds_and_alignment_info.find(op->name)) { if (bounds) { - *bounds = b; + *bounds = *b; } - if (b.min_defined && b.max_defined && b.min == b.max) { - return make_const(op->type, b.min); + if (b->min_defined && b->max_defined && b->min == b->max) { + return make_const(op->type, b->min); } } - if (var_info.contains(op->name)) { - auto &info = var_info.ref(op->name); - + if (auto *info = var_info.shallow_find(op->name)) { // if replacement is defined, we should substitute it in (unless // it's a var that has been hidden by a nested scope). - if (info.replacement.defined()) { - internal_assert(info.replacement.type() == op->type) + if (info->replacement.defined()) { + internal_assert(info->replacement.type() == op->type) << "Cannot replace variable " << op->name << " of type " << op->type - << " with expression of type " << info.replacement.type() << "\n"; - info.new_uses++; + << " with expression of type " << info->replacement.type() << "\n"; + info->new_uses++; // We want to remutate the replacement, because we may be // injecting it into a context where it is known to be a // constant (e.g. due to an if). - return mutate(info.replacement, bounds); + return mutate(info->replacement, bounds); } else { // This expression was not something deemed // substitutable - no replacement is defined. - info.old_uses++; + info->old_uses++; return op; } } else { @@ -321,15 +318,14 @@ Expr Simplify::visit(const Load *op, ExprInfo *bounds) { // unreachable loads. if (is_const_one(op->predicate)) { string alloc_extent_name = op->name + ".total_extent_bytes"; - if (bounds_and_alignment_info.contains(alloc_extent_name)) { + if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { if (index_info.max_defined && index_info.max < 0) { in_unreachable = true; return unreachable(op->type); } - const ExprInfo &alloc_info = bounds_and_alignment_info.get(alloc_extent_name); - if (alloc_info.max_defined && index_info.min_defined) { + if (alloc_info->max_defined && index_info.min_defined) { int index_min_bytes = index_info.min * op->type.bytes(); - if (index_min_bytes > alloc_info.max) { + if (index_min_bytes > alloc_info->max) { in_unreachable = true; return unreachable(op->type); } diff --git a/src/Simplify_Stmts.cpp b/src/Simplify_Stmts.cpp index 11b146ecdc6a..f6cb81345961 100644 --- a/src/Simplify_Stmts.cpp +++ b/src/Simplify_Stmts.cpp @@ -305,19 +305,19 @@ Stmt Simplify::visit(const Store *op) { // but perhaps the branch was hard to prove constant true or false. This // provides an alternative mechanism to simplify these unreachable stores. string alloc_extent_name = op->name + ".total_extent_bytes"; - if (is_const_one(op->predicate) && - bounds_and_alignment_info.contains(alloc_extent_name)) { - if (index_info.max_defined && index_info.max < 0) { - in_unreachable = true; - return Evaluate::make(unreachable()); - } - const ExprInfo &alloc_info = bounds_and_alignment_info.get(alloc_extent_name); - if (alloc_info.max_defined && index_info.min_defined) { - int index_min_bytes = index_info.min * op->value.type().bytes(); - if (index_min_bytes > alloc_info.max) { + if (is_const_one(op->predicate)) { + if (const auto *alloc_info = bounds_and_alignment_info.find(alloc_extent_name)) { + if (index_info.max_defined && index_info.max < 0) { in_unreachable = true; return Evaluate::make(unreachable()); } + if (alloc_info->max_defined && index_info.min_defined) { + int index_min_bytes = index_info.min * op->value.type().bytes(); + if (index_min_bytes > alloc_info->max) { + in_unreachable = true; + return Evaluate::make(unreachable()); + } + } } } diff --git a/src/SlidingWindow.cpp b/src/SlidingWindow.cpp index ab25ad32bc87..dfb50d714e37 100644 --- a/src/SlidingWindow.cpp +++ b/src/SlidingWindow.cpp @@ -69,10 +69,9 @@ class ExpandExpr : public IRMutator { const Scope &scope; Expr visit(const Variable *var) override { - if (scope.contains(var->name)) { - Expr expr = scope.get(var->name); - debug(4) << "Fully expanded " << var->name << " -> " << expr << "\n"; - return expr; + if (const Expr *expr = scope.find(var->name)) { + debug(4) << "Fully expanded " << var->name << " -> " << *expr << "\n"; + return *expr; } else { return var; } diff --git a/src/Solve.cpp b/src/Solve.cpp index b25719cff8c7..09245d90bf24 100644 --- a/src/Solve.cpp +++ b/src/Solve.cpp @@ -786,17 +786,15 @@ class SolveExpression : public IRMutator { if (op->name == var) { uses_var = true; return op; - } else if (scope.contains(op->name)) { - CacheEntry e = scope.get(op->name); - uses_var = uses_var || e.uses_var; - failed = failed || e.failed; - return e.expr; - } else if (external_scope.contains(op->name)) { - Expr e = external_scope.get(op->name); + } else if (const CacheEntry *e = scope.find(op->name)) { + uses_var = uses_var || e->uses_var; + failed = failed || e->failed; + return e->expr; + } else if (const Expr *e = external_scope.find(op->name)) { // Expressions in the external scope haven't been solved // yet. This will either pull its solution from the cache, // or solve it and then put it into the cache. - return mutate(e); + return mutate(*e); } else { return op; } @@ -948,13 +946,13 @@ class SolveForInterval : public IRVisitor { void visit(const Variable *op) override { internal_assert(op->type.is_bool()); - if (scope.contains(op->name)) { + if (const Expr *e = scope.find(op->name)) { pair key = {op->name, target}; auto it = solved_vars.find(key); if (it != solved_vars.end()) { result = it->second; } else { - scope.get(op->name).accept(this); + e->accept(this); solved_vars[key] = result; } } else { diff --git a/src/StageStridedLoads.cpp b/src/StageStridedLoads.cpp index feeab56a4122..723fc738ce51 100644 --- a/src/StageStridedLoads.cpp +++ b/src/StageStridedLoads.cpp @@ -103,8 +103,8 @@ class FindStridedLoads : public IRVisitor { if (stride >= 2 && stride < r->lanes && r->stride.type().is_scalar()) { const IRNode *s = scope; const Allocate *a = nullptr; - if (allocation_scope.contains(op->name)) { - a = allocation_scope.get(op->name); + if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { + a = *a_ptr; } found_loads[Key{op->name, base, stride, r->lanes, op->type, a, s}][offset].push_back(op); } @@ -161,8 +161,8 @@ class ReplaceStridedLoads : public IRMutator { protected: Expr visit(const Load *op) override { const Allocate *alloc = nullptr; - if (allocation_scope.contains(op->name)) { - alloc = allocation_scope.get(op->name); + if (const Allocate *const *a_ptr = allocation_scope.find(op->name)) { + alloc = *a_ptr; } auto it = replacements.find({alloc, op}); if (it != replacements.end()) { diff --git a/src/StmtToHTML.cpp b/src/StmtToHTML.cpp index 9c317ba35525..79cf6563551e 100644 --- a/src/StmtToHTML.cpp +++ b/src/StmtToHTML.cpp @@ -1134,8 +1134,8 @@ class HTMLCodePrinter : public IRVisitor { std::string variable(const std::string &x, const std::string &tooltip) { int id; - if (scope.contains(x)) { - id = scope.get(x); + if (const int *i = scope.find(x)) { + id = *i; } else { id = gen_unique_id(); scope.push(x, id); diff --git a/src/StorageFlattening.cpp b/src/StorageFlattening.cpp index d7e7c50002f6..13d7d6475120 100644 --- a/src/StorageFlattening.cpp +++ b/src/StorageFlattening.cpp @@ -31,10 +31,9 @@ class ExpandExpr : public IRMutator { const Scope &scope; Expr visit(const Variable *var) override { - if (scope.contains(var->name)) { - Expr expr = scope.get(var->name); + if (const Expr *e = scope.find(var->name)) { // Mutate the expression, so lets can get replaced recursively. - expr = mutate(expr); + Expr expr = mutate(*e); debug(4) << "Fully expanded " << var->name << " -> " << expr << "\n"; return expr; } else { diff --git a/src/UniquifyVariableNames.cpp b/src/UniquifyVariableNames.cpp index 26689ec34633..85a6ba521771 100644 --- a/src/UniquifyVariableNames.cpp +++ b/src/UniquifyVariableNames.cpp @@ -104,10 +104,9 @@ class UniquifyVariableNames : public IRMutator { } Expr visit(const Variable *op) override { - if (renaming.contains(op->name)) { - string new_name = renaming.get(op->name); - if (new_name != op->name) { - return Variable::make(op->type, new_name); + if (const string *new_name = renaming.find(op->name)) { + if (*new_name != op->name) { + return Variable::make(op->type, *new_name); } } return op; diff --git a/src/VectorizeLoops.cpp b/src/VectorizeLoops.cpp index 6d10d2e9d5f3..0745a34a9d39 100644 --- a/src/VectorizeLoops.cpp +++ b/src/VectorizeLoops.cpp @@ -297,8 +297,8 @@ bool is_interleaved_ramp(const Expr &e, const Scope &scope, InterleavedRam return true; } } else if (const Variable *var = e.as()) { - if (scope.contains(var->name)) { - return is_interleaved_ramp(scope.get(var->name), scope, result); + if (const Expr *e = scope.find(var->name)) { + return is_interleaved_ramp(*e, scope, result); } } return false;