Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make realization order invariant to unique_name suffixes #8124

Merged
merged 4 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/Deserialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,11 +505,12 @@ void Deserializer::deserialize_function(const Serialize::Func *function, Functio
deserialize_vector<flatbuffers::String, std::string>(function->trace_tags(),
&Deserializer::deserialize_string);
const bool frozen = function->frozen();
const uint64_t definition_order = function->definition_order();
hl_function.update_with_deserialization(name, origin_name, output_types, required_types,
required_dim, args, func_schedule, init_def, updates,
debug_file, output_buffers, extern_arguments, extern_function_name,
name_mangling, extern_function_device_api, extern_proxy_expr,
trace_loads, trace_stores, trace_realizations, trace_tags, frozen);
trace_loads, trace_stores, trace_realizations, trace_tags, frozen, definition_order);
}

Stmt Deserializer::deserialize_stmt(Serialize::Stmt type_code, const void *stmt) {
Expand Down
19 changes: 18 additions & 1 deletion src/Function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ class WeakenFunctionPtrs : public IRMutator {
}
};

// Hands out consecutive integers, starting from zero, one per call. The
// counter is a function-local static atomic, so every caller draws from the
// same sequence and no two calls receive the same value.
uint64_t definition_order_counter() {
    static std::atomic<uint64_t> next_order{0};
    return next_order.fetch_add(1);
}

} // namespace

struct FunctionContents {
Expand Down Expand Up @@ -112,6 +117,8 @@ struct FunctionContents {

bool frozen = false;

uint64_t definition_order = 0;

void accept(IRVisitor *visitor) const {
func_schedule.accept(visitor);

Expand Down Expand Up @@ -352,7 +359,8 @@ void Function::update_with_deserialization(const std::string &name,
bool trace_stores,
bool trace_realizations,
const std::vector<std::string> &trace_tags,
bool frozen) {
bool frozen,
uint64_t definition_order) {
contents->name = name;
contents->origin_name = origin_name;
contents->output_types = output_types;
Expand All @@ -374,6 +382,7 @@ void Function::update_with_deserialization(const std::string &name,
contents->trace_realizations = trace_realizations;
contents->trace_tags = trace_tags;
contents->frozen = frozen;
contents->definition_order = definition_order;
}

namespace {
Expand Down Expand Up @@ -512,6 +521,7 @@ void Function::deep_copy(const FunctionPtr &copy, DeepCopyMap &copied_map) const
copy->frozen = contents->frozen;
copy->output_buffers = contents->output_buffers;
copy->func_schedule = contents->func_schedule.deep_copy(copied_map);
copy->definition_order = contents->definition_order;

// Copy the pure definition
if (contents->init_def.defined()) {
Expand Down Expand Up @@ -616,6 +626,8 @@ void Function::define(const vector<string> &args, vector<Expr> values) {
check_dims((int)args.size());
contents->args = args;

contents->definition_order = definition_order_counter();

std::vector<Expr> init_def_args;
init_def_args.resize(args.size());
for (size_t i = 0; i < args.size(); i++) {
Expand Down Expand Up @@ -902,6 +914,7 @@ void Function::define_extern(const std::string &function_name,
contents->output_types = types;
contents->extern_mangling = mangling;
contents->extern_function_device_api = device_api;
contents->definition_order = definition_order_counter();

std::vector<Expr> values;
contents->output_buffers.clear();
Expand Down Expand Up @@ -1326,5 +1339,9 @@ pair<vector<Function>, map<string, Function>> deep_copy(
return {copy_outputs, copy_env};
}

// Accessor for the definition-order stamp stored in this Function's contents.
// The stamp is drawn from definition_order_counter() in define() and
// define_extern(), copied across by deep_copy(), and restored verbatim by
// update_with_deserialization(). It stays at its default of 0 until one of
// those has run.
uint64_t Function::definition_order() const {
return contents->definition_order;
}

} // namespace Internal
} // namespace Halide
7 changes: 6 additions & 1 deletion src/Function.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,8 @@ class Function {
bool trace_stores,
bool trace_realizations,
const std::vector<std::string> &trace_tags,
bool frozen);
bool frozen,
uint64_t definition_order);

/** Get a handle on the halide function contents that this Function
* represents. */
Expand Down Expand Up @@ -347,6 +348,10 @@ class Function {
/** Define the output buffers. If the Function has types specified, this can be called at
* any time. If not, it can only be called for a Function with a pure definition. */
void create_output_buffers(const std::vector<Type> &types, int dims) const;

/** The value this Function drew from a shared, monotonically increasing
 * counter at the moment it was given its pure definition or was defined as
 * an extern stage. Used for ordering Funcs in a name-agnostic way. The
 * value is carried across deep copies and serialization round-trips. */
uint64_t definition_order() const;
};

/** Deep copy an entire Function DAG. */
Expand Down
57 changes: 54 additions & 3 deletions src/RealizationOrder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@ find_fused_groups(const map<string, Function> &env,
map<string, vector<string>> fused_groups;
map<string, string> group_name;

int counter = 0;
for (const auto &iter : env) {
const string &fn = iter.first;
if (visited.find(fn) == visited.end()) {
vector<string> group;
find_fused_groups_dfs(fn, fuse_adjacency_list, visited, group);

// Create a unique name for the fused group.
string rename = unique_name("_fg");
string rename = "_fg" + std::to_string(counter++);
fused_groups.emplace(rename, group);
for (const auto &m : group) {
group_name.emplace(m, rename);
Expand All @@ -69,7 +70,7 @@ void realization_order_dfs(const string &current,
internal_assert(iter != graph.end());

for (const string &fn : iter->second) {
internal_assert(fn != current);
internal_assert(fn != current) << fn;
if (visited.find(fn) == visited.end()) {
realization_order_dfs(fn, graph, visited, result_set, order);
} else {
Expand Down Expand Up @@ -235,6 +236,50 @@ void check_fused_stages_are_scheduled_in_order(const Function &f) {
}
}

// Reorder Funcs in a vector to have an order that's resistant to unique_name
// calls, so that multitarget builds don't get arbitrary changes to topological
// ordering, and so that machine-generated schedules (which depend on the
// topological order) are less likely to be invalidated by things that have
// happened in the same process earlier.
//
// To do this, we break each name into a prefix, the definition order counter of
// the Func, and then finally the full original name. The prefix is what you get
// after stripping off anything after a $ (to handle suffixes introduced by
// multi-character unique_name calls), and then stripping off any digits (to
// handle suffixes introduced by single-character unique_name calls).
//
// This is gross. The reason we don't just break ties by definition order alone
// is two-fold. First, it's more likely to be consistent with the realization
// order before this sorting was done. Second, consider a multi-target
// compilation scenario in which the same pipeline is compiled for many targets
// in the same process. Now say there is a Func that is shared by all targets,
// but only defined the first time it is used, halfway through the definition of
// the first Pipeline (e.g. because it's a static local in some helper
// function). Its definition order in that first target is midway through the
// pipeline, but its definition order for every subsequent target is before the
// start of the pipeline, so if we sort by definition order alone it won't show
// up in a consistent place. If we use a name prefix as the primary key, then as
// long as it has a unique name, it will still show up in a consistent place.
//
// funcs: the names to reorder; sorted in place.
// env:   lookup from name to Function, used to fetch each definition order.
//        Names that are not present in env are given a definition order of 0.
void sort_funcs_by_name_and_counter(vector<string> *funcs, const map<string, Function> &env) {
    // Sort key per entry: (name prefix, definition order, full name). The full
    // name is last so it both breaks any remaining ties deterministically and
    // can be moved straight back into the output.
    vector<std::tuple<string, uint64_t, string>> items;
    items.reserve(funcs->size());
    for (const string &full_name : *funcs) {
        // Keep only what precedes the first '$' (find() returning npos keeps
        // the whole name).
        string prefix = full_name.substr(0, full_name.find('$'));
        // Then drop trailing digits. The cast matters: handing std::isdigit a
        // negative char (any byte >= 0x80 where char is signed) is undefined
        // behaviour.
        while (!prefix.empty() &&
               std::isdigit(static_cast<unsigned char>(prefix.back()))) {
            prefix.pop_back();
        }
        const auto it = env.find(full_name);
        const uint64_t counter = (it != env.end()) ? it->second.definition_order() : 0;
        items.emplace_back(std::move(prefix), counter, full_name);
    }
    std::sort(items.begin(), items.end());
    for (size_t i = 0; i < items.size(); i++) {
        (*funcs)[i] = std::move(std::get<2>(items[i]));
    }
}

} // anonymous namespace

pair<vector<string>, vector<vector<string>>> realization_order(
Expand Down Expand Up @@ -318,6 +363,9 @@ pair<vector<string>, vector<vector<string>>> realization_order(
}
}
}
for (auto &p : graph) {
sort_funcs_by_name_and_counter(&p.second, env);
}

// Compute the realization order of the fused groups (i.e. the dummy nodes)
// and also the realization order of the functions within a fused group.
Expand Down Expand Up @@ -376,7 +424,10 @@ vector<string> topological_order(const vector<Function> &outputs,
s.push_back(callee.first);
}
}
graph.emplace(caller.first, s);
graph.emplace(caller.first, std::move(s));
}
for (auto &p : graph) {
sort_funcs_by_name_and_counter(&p.second, env);
}

vector<string> order;
Expand Down
5 changes: 4 additions & 1 deletion src/Serialization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1030,6 +1030,7 @@ Offset<Serialize::Func> Serializer::serialize_function(FlatBufferBuilder &builde
trace_tags_serialized.push_back(serialize_string(builder, tag));
}
const bool frozen = function.frozen();
const uint64_t definition_order = function.definition_order();
auto func = Serialize::CreateFunc(builder,
name_serialized,
origin_name_serialized,
Expand All @@ -1050,7 +1051,9 @@ Offset<Serialize::Func> Serializer::serialize_function(FlatBufferBuilder &builde
trace_loads,
trace_stores,
trace_realizations,
builder.CreateVector(trace_tags_serialized), frozen);
builder.CreateVector(trace_tags_serialized),
frozen,
definition_order);
return func;
}

Expand Down
3 changes: 2 additions & 1 deletion src/halide_ir.fbs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ enum SerializationVersionMinor: int {
Value = 0
}
enum SerializationVersionPatch: int {
Value = 0
Value = 1
}

// from src/IR.cpp
Expand Down Expand Up @@ -714,6 +714,7 @@ table Func {
trace_realizations: bool = false;
trace_tags: [string];
frozen: bool = false;
definition_order: uint64;
}

table Pipeline {
Expand Down
1 change: 1 addition & 0 deletions test/correctness/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,7 @@ tests(GROUPS correctness
split_fuse_rvar.cpp
split_reuse_inner_name_bug.cpp
split_store_compute.cpp
stable_realization_order.cpp
stack_allocations.cpp
stage_strided_loads.cpp
stencil_chain_in_update_definitions.cpp
Expand Down
41 changes: 41 additions & 0 deletions test/correctness/stable_realization_order.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;

int main(int argc, char **argv) {
    // Verify that the realization order is invariant to anything to do with
    // unique_name counters. Every trial builds the same pipeline out of fresh
    // Funcs, so the auto-generated names differ from trial to trial while the
    // expected order stays the same.
    for (int trial = 0; trial < 10; trial++) {
        std::map<std::string, Function> env;
        Var x, y;
        Expr s = 0;
        std::vector<Func> funcs(8);
        // All but the last Func are independent producers; the last one sums
        // them, so it depends on every other Func.
        for (size_t i = 0; i < funcs.size() - 1; i++) {
            funcs[i](x, y) = x + y;
            s += funcs[i](x, y);
            env[funcs[i].name()] = funcs[i].function();
        }
        funcs.back()(x, y) = s;
        env[funcs.back().name()] = funcs.back().function();

        const auto r = realization_order({funcs.back().function()}, env).first;
        // Ties in the realization order are supposed to be broken by any
        // alphabetical prefix of the Func name followed by time of
        // definition. All the Funcs in this test have the same name, so it
        // should just depend on time of definition.
        //
        // Check explicitly rather than with assert(), which disappears when
        // NDEBUG is defined, and fail the test on any mismatch instead of
        // only logging it.
        if (r.size() != funcs.size()) {
            printf("Unexpected realization order length: %zu != %zu\n",
                   r.size(), funcs.size());
            return 1;
        }
        for (size_t i = 0; i < funcs.size(); i++) {
            if (funcs[i].name() != r[i]) {
                printf("Unexpected realization order: %s != %s\n",
                       funcs[i].name().c_str(), r[i].c_str());
                return 1;
            }
        }
    }

    printf("Success!\n");
    return 0;
}
Loading