From d33ffa20f233224adcf80aa147cadf7f594dda51 Mon Sep 17 00:00:00 2001 From: Andrew Adams Date: Tue, 5 Mar 2024 09:50:07 -0800 Subject: [PATCH] Make realization order invariant to unique_name suffixes (#8124) * Make realization order invariant to unique_name suffixes * Add test * definition_order -> uint64 everywhere * Use visitation order instead of definition order --------- Co-authored-by: Steven Johnson --- src/FindCalls.cpp | 83 ++++++++++++------- src/FindCalls.h | 5 ++ src/RealizationOrder.cpp | 71 +++++++++++++++- test/correctness/CMakeLists.txt | 1 + test/correctness/stable_realization_order.cpp | 41 +++++++++ 5 files changed, 167 insertions(+), 34 deletions(-) create mode 100644 test/correctness/stable_realization_order.cpp diff --git a/src/FindCalls.cpp b/src/FindCalls.cpp index 77c5ae7645cd..9345c89dcac5 100644 --- a/src/FindCalls.cpp +++ b/src/FindCalls.cpp @@ -8,24 +8,22 @@ namespace Halide { namespace Internal { -using std::map; -using std::string; -using std::vector; - namespace { + /* Find all the internal halide calls in an expr */ class FindCalls : public IRVisitor { public: - map calls; + std::map calls; + std::vector order; using IRVisitor::visit; void include_function(const Function &f) { - map::iterator iter = calls.find(f.name()); - if (iter == calls.end()) { - calls[f.name()] = f; + auto [it, inserted] = calls.emplace(f.name(), f); + if (inserted) { + order.push_back(f); } else { - user_assert(iter->second.same_as(f)) + user_assert(it->second.same_as(f)) << "Can't compile a pipeline using multiple functions with same name: " << f.name() << "\n"; } @@ -41,64 +39,87 @@ class FindCalls : public IRVisitor { } }; -void populate_environment_helper(const Function &f, map &env, - bool recursive = true, bool include_wrappers = false) { - map::const_iterator iter = env.find(f.name()); - if (iter != env.end()) { +void populate_environment_helper(const Function &f, + std::map *env, + std::vector *order, + bool recursive = true, + bool include_wrappers 
= false) { + std::map::const_iterator iter = env->find(f.name()); + if (iter != env->end()) { user_assert(iter->second.same_as(f)) << "Can't compile a pipeline using multiple functions with same name: " << f.name() << "\n"; return; } + auto insert_func = [](const Function &f, + std::map *env, + std::vector *order) { + auto [it, inserted] = env->emplace(f.name(), f); + if (inserted) { + order->push_back(f); + } + }; + FindCalls calls; f.accept(&calls); if (f.has_extern_definition()) { for (const ExternFuncArgument &arg : f.extern_arguments()) { if (arg.is_func()) { - Function g(arg.func); - calls.calls[g.name()] = g; + insert_func(Function{arg.func}, &calls.calls, &calls.order); } } } if (include_wrappers) { for (const auto &it : f.schedule().wrappers()) { - Function g(it.second); - calls.calls[g.name()] = g; + insert_func(Function{it.second}, &calls.calls, &calls.order); } } if (!recursive) { - env.insert(calls.calls.begin(), calls.calls.end()); + for (const Function &g : calls.order) { + insert_func(g, env, order); + } } else { - env[f.name()] = f; - - for (const auto &i : calls.calls) { - populate_environment_helper(i.second, env, recursive, include_wrappers); + insert_func(f, env, order); + for (const Function &g : calls.order) { + populate_environment_helper(g, env, order, recursive, include_wrappers); } } } } // namespace -map build_environment(const vector &funcs) { - map env; +std::map build_environment(const std::vector &funcs) { + std::map env; + std::vector order; for (const Function &f : funcs) { - populate_environment_helper(f, env, true, true); + populate_environment_helper(f, &env, &order, true, true); } return env; } -map find_transitive_calls(const Function &f) { - map res; - populate_environment_helper(f, res, true, false); +std::vector called_funcs_in_order_found(const std::vector &funcs) { + std::map env; + std::vector order; + for (const Function &f : funcs) { + populate_environment_helper(f, &env, &order, true, true); + } + return order; +} + 
+std::map find_transitive_calls(const Function &f) { + std::map res; + std::vector order; + populate_environment_helper(f, &res, &order, true, false); return res; } -map find_direct_calls(const Function &f) { - map res; - populate_environment_helper(f, res, false, false); +std::map find_direct_calls(const Function &f) { + std::map res; + std::vector order; + populate_environment_helper(f, &res, &order, false, false); return res; } diff --git a/src/FindCalls.h b/src/FindCalls.h index f55140ae9162..40787d922a4f 100644 --- a/src/FindCalls.h +++ b/src/FindCalls.h @@ -36,6 +36,11 @@ std::map find_transitive_calls(const Function &f); * a map of them. */ std::map build_environment(const std::vector &funcs); +/** Returns the same Functions as build_environment, but returns a vector of + * Functions instead, where the order is the order in which the Functions were + * first encountered. This is stable to changes in the names of the Functions. */ +std::vector called_funcs_in_order_found(const std::vector &funcs); + } // namespace Internal } // namespace Halide diff --git a/src/RealizationOrder.cpp b/src/RealizationOrder.cpp index 8541c17ea862..af12ba80c228 100644 --- a/src/RealizationOrder.cpp +++ b/src/RealizationOrder.cpp @@ -41,6 +41,7 @@ find_fused_groups(const map &env, map> fused_groups; map group_name; + int counter = 0; for (const auto &iter : env) { const string &fn = iter.first; if (visited.find(fn) == visited.end()) { @@ -48,7 +49,7 @@ find_fused_groups(const map &env, find_fused_groups_dfs(fn, fuse_adjacency_list, visited, group); // Create a unique name for the fused group. 
- string rename = unique_name("_fg"); + string rename = "_fg" + std::to_string(counter++); fused_groups.emplace(rename, group); for (const auto &m : group) { group_name.emplace(m, rename); @@ -69,7 +70,7 @@ void realization_order_dfs(const string &current, internal_assert(iter != graph.end()); for (const string &fn : iter->second) { - internal_assert(fn != current); + internal_assert(fn != current) << fn; if (visited.find(fn) == visited.end()) { realization_order_dfs(fn, graph, visited, result_set, order); } else { @@ -235,8 +236,63 @@ void check_fused_stages_are_scheduled_in_order(const Function &f) { } } +// Reorder Funcs in a vector to have an order that's resistant to unique_name +// calls, so that multitarget builds don't get arbitrary changes to topological +// ordering, and so that machine-generated schedules (which depend on the +// topological order) are less likely to be invalidated by things that have +// happened in the same process earlier. +// +// To do this, we break each name into a prefix, the visitation order counter of +// the Func, and then finally the full original name. The prefix is what you get +// after stripping off anything after a $ (to handle suffixes introduced by +// multi-character unique_name calls), and then stripping off any digits (to +// handle suffixes introduced by single-character unique_name calls). The +// visitation order is when the Func is first encountered in an IRVisitor +// traversal of the entire Pipeline. +// +// This is gross. The reason we don't just break ties by visitation order alone +// is because that way it's likely to be consistent with the realization +// order before this sorting was done. 
+void sort_funcs_by_name_and_counter(vector *funcs, + const map &env, + const map &visitation_order) { + vector> items; + items.reserve(funcs->size()); + for (size_t i = 0; i < funcs->size(); i++) { + const string &full_name = (*funcs)[i]; + string prefix = split_string(full_name, "$")[0]; + while (!prefix.empty() && std::isdigit(prefix.back())) { + prefix.pop_back(); + } + auto env_it = env.find(full_name); + uint64_t counter = 0; + if (env_it != env.end()) { + auto v_it = visitation_order.find(full_name); + internal_assert(v_it != visitation_order.end()) + << "Func " << full_name + << " is somehow in the visitation order but not the environment."; + counter = v_it->second; + } + + items.emplace_back(prefix, counter, full_name); + } + std::sort(items.begin(), items.end()); + for (size_t i = 0; i < items.size(); i++) { + (*funcs)[i] = std::move(std::get<2>(items[i])); + } +} + } // anonymous namespace +map compute_visitation_order(const vector &outputs) { + vector funcs = called_funcs_in_order_found(outputs); + map result; + for (uint64_t i = 0; i < funcs.size(); i++) { + result[funcs[i].name()] = i; + } + return result; +} + pair, vector>> realization_order( const vector &outputs, map &env) { @@ -318,6 +374,10 @@ pair, vector>> realization_order( } } } + auto visitation_order = compute_visitation_order(outputs); + for (auto &p : graph) { + sort_funcs_by_name_and_counter(&p.second, env, visitation_order); + } // Compute the realization order of the fused groups (i.e. the dummy nodes) // and also the realization order of the functions within a fused group. 
@@ -376,7 +436,12 @@ vector topological_order(const vector &outputs, s.push_back(callee.first); } } - graph.emplace(caller.first, s); + graph.emplace(caller.first, std::move(s)); + } + + auto visitation_order = compute_visitation_order(outputs); + for (auto &p : graph) { + sort_funcs_by_name_and_counter(&p.second, env, visitation_order); } vector order; diff --git a/test/correctness/CMakeLists.txt b/test/correctness/CMakeLists.txt index f77393a21114..9b934b768cdd 100644 --- a/test/correctness/CMakeLists.txt +++ b/test/correctness/CMakeLists.txt @@ -295,6 +295,7 @@ tests(GROUPS correctness split_fuse_rvar.cpp split_reuse_inner_name_bug.cpp split_store_compute.cpp + stable_realization_order.cpp stack_allocations.cpp stage_strided_loads.cpp stencil_chain_in_update_definitions.cpp diff --git a/test/correctness/stable_realization_order.cpp b/test/correctness/stable_realization_order.cpp new file mode 100644 index 000000000000..f62423559327 --- /dev/null +++ b/test/correctness/stable_realization_order.cpp @@ -0,0 +1,41 @@ +#include "Halide.h" + +using namespace Halide; +using namespace Halide::Internal; + +int main(int argc, char **argv) { + // Verify that the realization order is invariant to anything to do with + // unique_name counters. + + std::vector expected; + + for (int i = 0; i < 10; i++) { + std::map env; + Var x, y; + Expr s = 0; + std::vector funcs(8); + for (size_t i = 0; i < funcs.size() - 1; i++) { + funcs[i](x, y) = x + y; + s += funcs[i](x, y); + env[funcs[i].name()] = funcs[i].function(); + } + funcs.back()(x, y) = s; + env[funcs.back().name()] = funcs.back().function(); + + auto r = realization_order({funcs.back().function()}, env).first; + // Ties in the realization order are supposed to be broken by any + // alphabetical prefix of the Func name followed by visitation + // order. All the Funcs in this test have the same name, so it + // should just depend on visitation order. 
+ assert(r.size() == funcs.size()); + for (size_t i = 0; i < funcs.size(); i++) { + if (funcs[i].name() != r[i]) { + debug(0) << "Unexpected realization order: " + << funcs[i].name() << " != " << r[i] << "\n"; + } + } + } + + printf("Success!\n"); + return 0; +}