Skip to content

Commit

Permalink
Almost working. TODO: external source, protobuf WTF.
Browse files Browse the repository at this point in the history
Signed-off-by: Michal Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Dec 16, 2024
1 parent ade0416 commit cb256d3
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 15 deletions.
63 changes: 49 additions & 14 deletions dali/pipeline/graph/cse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
namespace dali {
namespace graph {

std::string OpSpecToString(const OpSpec &spec) {
std::string OpSpecKey(const OpSpec &spec) {
dali_proto::OpDef op;
op.set_name(spec.SchemaName());

Expand All @@ -44,42 +44,77 @@ std::string OpSpecToString(const OpSpec &spec) {
}

auto &schema = spec.GetSchemaOrDefault();
for (auto& a : spec.Arguments()) {
// filter out args that need to be dealt with on
// loading a serialized pipeline
auto name = a->get_name();

std::map<std::string_view, Argument *, std::less<>> sorted_args;
for (auto &a : spec.Arguments()) {
// Some arguments should be skipped when comparing operators
if (schema.GetArgument(name).ignore_cmp)
continue;
if (schema.HasArgument(a->get_name()))
if (schema.GetArgument(a->get_name()).ignore_cmp)
continue;

sorted_args.emplace(a->get_name(), a.get());
}

for (auto [name, a] : sorted_args) {
dali_proto::Argument *arg = op.add_args();
DaliProtoPriv arg_wrap(arg);

a->SerializeToProtobuf(&arg_wrap);
}

return op.SerializeAsString();
}

class CSE {
public:
explicit CSE(OpGraph &graph) : graph_(graph) {}
void Run(OpGraph &graph) {
for (auto &node : graph.OpNodes())
Run(&node);
for (auto output_name : graph.Outputs()) {
auto it = renamed_full_.find(output_name);

void Run() {
if (it != renamed_full_.end())
builder_.AddOutput(it->second);
else
builder_.AddOutput(std::string(output_name));
}
graph = {};
graph = std::move(builder_).GetGraph(true);
}

bool IsFoldable(const OpSpec &spec) {
return !spec.GetArgument<bool>("preserve") && !spec.GetArgument<bool>("preserve_name");
}

void Run(OpNode *node) {
for (int i = 0, ninp = node->inputs; i < ninp; ++i) {
OpSpec new_spec = node->spec;
for (int i = 0; i < new_spec.NumInput(); i++) {
auto it = renamed_.find(new_spec.InputName(i));
if (it != renamed_.end())
new_spec.RenameInput(i, it->second);
}
std::string key = OpSpecKey(new_spec);
OpNode *&norm = normalized_nodes_[key];
if (!norm || !IsFoldable(new_spec))
norm = node;

if (norm != node) {
for (int o = 0; o < node->spec.NumOutput(); o++) {
renamed_.emplace(node->spec.OutputName(o), norm->spec.OutputName(o));
renamed_full_.emplace(node->spec.Output(o), norm->spec.Output(o));
}
} else {
builder_.Add(norm->instance_name, new_spec);
}
}

std::map<std::string, OpNode *> normalized_nodes_;
std::map<OpNode *, OpNode *> renamed_;
OpGraph &graph_;
std::map<std::string, std::string, std::less<>> renamed_;
std::map<std::string, std::string, std::less<>> renamed_full_;
OpGraph::Builder builder_;
};

void EliminateCommonSubgraphs(OpGraph &graph) {
CSE cse;
cse.Run(graph);
}

} // namespace graph
Expand Down
3 changes: 2 additions & 1 deletion dali/pipeline/pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include "dali/pipeline/operator/error_reporting.h"
#include "dali/pipeline/operator/name_utils.h"
#include "dali/pipeline/graph/graph2dot.h"
#include "dali/pipeline/graph/cse.h"

namespace dali {

Expand Down Expand Up @@ -525,7 +526,7 @@ void Pipeline::Build(std::vector<PipelineOutputDesc> output_descs) {
}

// Graph optimization goes here

graph::EliminateCommonSubgraphs(graph_);

// Load the final graph into the executor
executor_->Build(graph_);
Expand Down

0 comments on commit cb256d3

Please sign in to comment.