Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Common Subgraph Elimination #5752

Merged
merged 11 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 132 additions & 0 deletions dali/pipeline/graph/cse.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dali/pipeline/graph/cse.h"
#include <functional>
#include <map>
#include <string>
#include <utility>
#include "dali/pipeline/dali.pb.h"

namespace dali {
namespace graph {

namespace {

/**
 * Computes the CSE key by serializing the relevant subset of an OpSpec to protobuf.
 *
 * The key is built from:
 * - the schema name,
 * - every input: its name, device and (for argument inputs) the argument name,
 * - the device of every output; output names are replaced by the output index, so
 *   operators that differ only in output names yield the same key,
 * - all arguments except those that the schema marks with `ignore_cmp`; the arguments
 *   are serialized in the order of their names, so the key does not depend on the order
 *   in which they are stored in the spec.
 *
 * @param spec the operator specification for which the key is computed
 * @return the serialized protobuf message
 */
std::string OpSpecCSEKey(const OpSpec &spec) {
  dali_proto::OpDef op;
  op.set_name(spec.SchemaName());

  for (int i = 0; i < spec.NumInput(); ++i) {
    dali_proto::InputOutput *in = op.add_input();
    in->set_name(spec.InputName(i));
    in->set_device(spec.InputDevice(i));
    const bool is_arg_input = spec.IsArgumentInput(i);
    if (is_arg_input) {
      in->set_arg_name(spec.ArgumentInputName(i));
    }
    in->set_is_argument_input(is_arg_input);
  }

  for (int i = 0; i < spec.NumOutput(); ++i) {
    dali_proto::InputOutput *out = op.add_output();
    // Use a placeholder instead of the real name
    out->set_name(std::to_string(i));
    out->set_device(spec.OutputDevice(i));
  }

  auto &schema = spec.GetSchemaOrDefault();
  // The map owns its keys: a non-owning key (e.g. std::string_view) would dangle after
  // each loop iteration if get_name() yields a std::string that ends up in a local
  // temporary. Owning keys are valid regardless of what get_name() returns.
  std::map<std::string, Argument *, std::less<>> sorted_args;
  for (auto &a : spec.Arguments()) {
    const auto &arg_name = a->get_name();
    // Some arguments should be skipped when comparing operators
    if (schema.HasArgument(arg_name) && schema.GetArgument(arg_name).ignore_cmp)
      continue;

    sorted_args.emplace(std::string(arg_name), a.get());
  }

  // Only the iteration order (by name) matters here; the names themselves are serialized
  // by the arguments.
  for (const auto &name_and_arg : sorted_args) {
    dali_proto::Argument *arg = op.add_args();
    DaliProtoPriv arg_wrap(arg);
    name_and_arg.second->SerializeToProtobuf(&arg_wrap);
  }

  return op.SerializeAsString();
}

/** The context for Common Subgraph Elimination */
class CSE {
public:
/**
 * Rewrites `graph`, merging the operator nodes that have equal CSE keys.
 *
 * Each node is processed by Run(OpNode *); the nodes are expected to be visited in
 * topological order, so that the renames of a node's inputs are already known when the
 * node is processed. Afterwards the graph outputs are registered in the builder and
 * `graph` is replaced by the graph produced by the builder.
 *
 * @param graph the graph to process; overwritten with the rewritten graph
 */
void Run(OpGraph &graph) {
for (auto &node : graph.OpNodes())
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OpNodes returns the nodes in topological order?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the graph is sorted by the Builder.

Run(&node);
// Register the graph outputs; an output whose producer was merged away is replaced
// by the corresponding output of the node that was kept.
for (auto output_name : graph.Outputs()) {
auto it = renamed_full_.find(output_name);

if (it != renamed_full_.end())
builder_.AddOutput(it->second);
else
builder_.AddOutput(std::string(output_name));
}
// Reset the old graph first, then take the new one from the builder.
graph = {};
graph = std::move(builder_).GetGraph(true);
}

bool IsFoldable(const OpSpec &spec) {
return !spec.GetArgument<bool>("preserve") &&
!spec.GetArgument<bool>("preserve_name") &&
!spec.GetSchemaOrDefault().IsNoPrune();
}

/**
 * Processes a single node: it is either added to the new graph or folded into
 * a previously seen node with the same CSE key.
 *
 * The inputs of a copy of the node's spec are first renamed according to the renames
 * recorded so far. The key is computed from that updated spec. If there is no node
 * registered for the key yet, or the node is not foldable, the node becomes the one
 * registered for the key and is added to the builder. Otherwise, the node is dropped and
 * its outputs are recorded as renamed to the outputs of the registered node.
 *
 * @param node the node to process
 */
void Run(OpNode *node) {
  OpSpec updated_spec = node->spec;
  for (int idx = 0; idx < updated_spec.NumInput(); idx++) {
    auto found = renamed_.find(updated_spec.InputName(idx));
    if (found == renamed_.end())
      continue;
    updated_spec.RenameInput(idx, found->second);
  }

  // operator[] creates a null entry when the key is seen for the first time
  OpNode *&representative = normalized_nodes_[OpSpecCSEKey(updated_spec)];
  const bool can_fold = IsFoldable(updated_spec);

  if (representative == nullptr || !can_fold)
    representative = node;

  if (representative == node) {
    // The node stays in the graph (with its inputs renamed)
    builder_.Add(representative->instance_name, updated_spec);
    return;
  }

  // The node is folded: redirect its outputs to the outputs of the node that was kept
  for (int out = 0; out < node->spec.NumOutput(); out++) {
    renamed_.emplace(node->spec.OutputName(out), representative->spec.OutputName(out));
    renamed_full_.emplace(node->spec.Output(out), representative->spec.Output(out));
  }
}

// CSE key -> the node currently registered as the representative for that key
std::map<std::string, OpNode *> normalized_nodes_;
// OutputName(o) of a folded node -> OutputName(o) of the node it was folded into;
// used for renaming the inputs of the nodes processed later
std::map<std::string, std::string, std::less<>> renamed_;
// Output(o) of a folded node -> Output(o) of the node it was folded into;
// used for renaming the graph outputs
std::map<std::string, std::string, std::less<>> renamed_full_;
// Collects the nodes and outputs of the rewritten graph
OpGraph::Builder builder_;
};

} // namespace

/** Runs Common Subgraph Elimination on `graph`, rewriting it in place. */
void EliminateCommonSubgraphs(OpGraph &graph) {
  // A fresh context is used for each call
  CSE().Run(graph);
}

} // namespace graph
} // namespace dali
90 changes: 90 additions & 0 deletions dali/pipeline/graph/cse.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef DALI_PIPELINE_GRAPH_CSE_H_
#define DALI_PIPELINE_GRAPH_CSE_H_

#include "dali/pipeline/graph/op_graph2.h"

namespace dali {
namespace graph {

/** Eliminate Common Subgraphs
 *
 * Runs a common subexpression (subgraph) analysis on the graph.
 * The graph is completely rewritten in the process.
 *
 * The algorithm works by traversing the original graph in topological order.
 * Each OpSpec is first updated by renaming the inputs to match the previously merged nodes.
 * If the updated OpSpec was already seen, then it is replaced and the output names are added
 * to the renaming map.
 *
 * To identify matching operators, a key is computed which consists of the OpSpec's schema name,
 * arguments, inputs and output devices (but NOT output names!).
 * Some arguments are ignored - notably, the ones identifying the source location in Python
 * (that would make any kind of CSE pointless).
 *
 * If the key matches one previously seen, the operators are assumed equal and can be merged,
 * with several exceptions.
 *
 * The operators which are not merged:
 * - ExternalSource
 * - operators with explicitly given name
 * - operators with "preserve" argument set
 * - operators with NoPrune schema
 *
 * Example:
 *
 * ```
 * op1(args1) --- out1_0_A --- op2(args2) -- out2_0 --> pipeline_output_0
 * __op1_0    \              / __op2_0
 *             --- out1_0_B --
 *
 * op1(args1) --- out1_1_A --- op2(args2) -- out2_1 --> pipeline_output_1
 * __op1_1    \              / __op2_1
 *             --- out1_1_B --
 * ```
 *
 * In the example above, the two instances of op1 are identical so they're collapsed into one,
 * the __op1_0. The renaming map is:
 *   out1_1_A : out1_0_A
 *   out1_1_B : out1_0_B
 *
 * After renaming the inputs to __op2_1, we get:
 *
 * ```
 * op1(args1) --+-- out1_0_A ----- op2(args2) -- out2_0 --> pipeline_output_0
 * __op1_0      |  \              / __op2_0
 *              +----(-- out1_0_B -
 *              |     \
 *              |      --------- op2(args2) ------ out2_1 --> pipeline_output_1
 *               \              / __op2_1
 *                ---------------
 * ```
 * At this point, __op2_1 is identical to __op2_0 and can be removed. The final graph:
 *
 * ```
 * op1(args1) --- out1_0_A --- op2(args2) -- out2_0 -+---> pipeline_output_0
 * __op1_0    \              / __op2_0                \
 *             --- out1_0_B --                         --> pipeline_output_1
 * ```
 *
 * @param graph the graph to optimize; it is modified in place (replaced by the rewritten graph)
 */
void EliminateCommonSubgraphs(OpGraph &graph);

} // namespace graph
} // namespace dali


#endif // DALI_PIPELINE_GRAPH_CSE_H_
3 changes: 2 additions & 1 deletion dali/pipeline/operator/argument.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ inline std::shared_ptr<Argument> DeserializeProtobufVectorImpl(const DaliProtoPr
auto args = arg.extra_args();
std::vector<T> ret_val;
for (auto& a : args) {
const T& elem = DeserializeProtobuf(a)->Get<T>();
Copy link
Contributor Author

@mzient mzient Dec 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After changing Argument::Get to returning a reference, this became a bug: returning a reference to a field of the result of DeserializeProtobuf - since we get a temporary shared_ptr, it would be destroyed (along with the referenced field!) before the field reference could be used.

auto des = DeserializeProtobuf(a);
const T& elem = des->Get<T>();
ret_val.push_back(elem);
}
return Argument::Store(arg.name(), ret_val);
Expand Down
20 changes: 10 additions & 10 deletions dali/pipeline/operator/argument.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2017-2018, 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// Copyright (c) 2017-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -81,7 +81,7 @@ class ValueInst : public Value {
return to_string(val_);
}

const T &Get() const {
const T &Get() const & {
return val_;
}

Expand Down Expand Up @@ -134,10 +134,10 @@ class Argument {
virtual void SerializeToProtobuf(DaliProtoPriv* arg) = 0;

template <typename T>
T Get();
const T &Get() const &;

template <typename T>
bool IsType();
bool IsType() const;

template <typename T>
static std::shared_ptr<Argument> Store(const std::string& s, const T& val);
Expand All @@ -159,7 +159,7 @@ class ArgumentInst : public Argument {
public:
explicit ArgumentInst(const std::string& s, const T& v) : Argument(s), val(v) {}

T Get() {
const T &Get() const & {
return val.Get();
}

Expand Down Expand Up @@ -188,7 +188,7 @@ class ArgumentInst<std::vector<T>> : public Argument {
public:
explicit ArgumentInst(const std::string& s, const std::vector<T>& v) : Argument(s), val(v) {}

std::vector<T> Get() {
const std::vector<T> &Get() const & {
return val.Get();
}

Expand Down Expand Up @@ -222,13 +222,13 @@ class ArgumentInst<std::vector<T>> : public Argument {
DLL_PUBLIC std::shared_ptr<Argument> DeserializeProtobuf(const DaliProtoPriv &arg);

template <typename T>
bool Argument::IsType() {
return dynamic_cast<ArgumentInst<T>*>(this) != nullptr;
bool Argument::IsType() const {
return dynamic_cast<const ArgumentInst<T>*>(this) != nullptr;
}

template <typename T>
T Argument::Get() {
ArgumentInst<T>* self = dynamic_cast<ArgumentInst<T>*>(this);
const T &Argument::Get() const & {
auto *self = dynamic_cast<const ArgumentInst<T>*>(this);
if (self == nullptr) {
DALI_FAIL(make_string("Invalid type of argument \"", get_name(), "\". Expected ",
typeid(T).name()));
Expand Down
40 changes: 34 additions & 6 deletions dali/pipeline/operator/op_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,30 @@ const OpSchema &OpSchema::Default() {
return default_schema;
}

namespace {
// Base module name; the schema's module path components are appended to it
// (dot-separated) to form the default value of the `_module` argument.
constexpr const char *default_module = "nvidia.dali.ops";
} // namespace

/**
 * Constructs a schema with the given name.
 *
 * Apart from processing the name, the constructor registers two optional arguments,
 * `_module` and `_display_name`, whose defaults are derived from the module path and the
 * operator name. Both are flagged with `ignore_cmp`.
 *
 * @param name the name of the schema
 */
OpSchema::OpSchema(std::string_view name) : name_(name) {
  // Process the module path and operator name
  InitNames();

  // Default module: the base module followed by the dot-separated module path
  std::string full_module(default_module);
  for (const auto &part : ModulePath())
    full_module.append(".").append(part);

  AddOptionalArg("_module",
                 "String identifying the module in which the operator is defined. "
                 "Most of the time it is `__module__` of the API function/class.",
                 full_module);
  arguments_["_module"].ignore_cmp = true;

  AddOptionalArg("_display_name",
                 "Operator name as presented in the API it was instantiated in (without the module "
                 "path), for example: cast_like or CastLike.",
                 OperatorName());
  arguments_["_display_name"].ignore_cmp = true;
}

OpSchema::OpSchema(DefaultSchemaTag) : name_(""), default_(true) {
Expand All @@ -76,6 +97,10 @@ OpSchema::OpSchema(DefaultSchemaTag) : name_(""), default_(true) {
AddInternalArg("default_cuda_stream_priority", "Default cuda stream priority", 0); // deprecated
AddInternalArg("checkpointing", "Setting to `true` enables checkpointing", false);

AddInternalArg("preserve_name", R"(When true, the operator cannot be renamed.
This disables merging this operator with another one with a different name.)",
false);
Comment on lines +100 to +102
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be the case always when an operator was explicitly named by the user?

I understand that having two equal ops (Readers?) but with different names is not a clever way to use DALI, but it is possible for the user to do that. Let's say that for whatever reason I have two readers doing the same work, but named differently ("Reader1", "Reader2"). I wrapped the pipeline into something (an iterator?) that relies on a presence of "Reader2" in the pipeline. This is not a good code, but it works. And after CSE it doesn't

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's exactly how it works. If you specify a name manually, then preserve_name is set to True.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, I see:

self._spec.AddArg("preserve_name", not self._autoname)

Thank you!


AddOptionalArg<int>("seed", R"code(Random seed.
If not provided, it will be populated based on the global seed of the pipeline.)code",
nullptr);
Expand All @@ -89,6 +114,7 @@ to accommodate a batch of samples of this size.)code",
AddOptionalArg("preserve", R"code(Prevents the operator from being removed from the
graph even if its outputs are not used.)code",
false);
arguments_["preserve"].ignore_cmp = true;

// For simplicity we pass StackSummary as 4 separate arguments so we don't need to extend DALI
// with support for special FrameSummary type.
Expand All @@ -103,37 +129,39 @@ messages, pointing to the origin of the error in pipeline definition.
The list of FrameSummaries is split into four parameters: each is the list containing corresponding
parameters of FrameSummary. This parameter represents the `filename` member.)code",
std::vector<std::string>{});
arguments_["_origin_stack_filename"].ignore_cmp = true;

AddOptionalArg("_origin_stack_lineno", R"code(StackSummary - lineno member of FrameSummary, see
_origin_stack_filename for more information.)code",
std::vector<int>{});
arguments_["_origin_stack_lineno"].ignore_cmp = true;

AddOptionalArg("_origin_stack_name", R"code(StackSummary - name member of FrameSummary, see
_origin_stack_filename for more information.)code",
std::vector<std::string>{});
arguments_["_origin_stack_name"].ignore_cmp = true;

AddOptionalArg("_origin_stack_line", R"code(StackSummary - line member of FrameSummary, see
_origin_stack_filename for more information.)code",
std::vector<std::string>{});
arguments_["_origin_stack_line"].ignore_cmp = true;

AddOptionalArg("_pipeline_internal", R"code(Boolean specifying if this operator was defined within
a pipeline scope. False if it was defined without pipeline being set as current.)code",
true);

std::string default_module = "nvidia.dali.ops";
for (const auto &submodule : ModulePath()) {
default_module += "." + submodule;
}
arguments_["_pipeline_internal"].ignore_cmp = true;

AddOptionalArg("_module",
"String identifying the module in which the operator is defined. "
"Most of the time it is `__module__` of the API function/class.",
default_module);
arguments_["_module"].ignore_cmp = true;

AddOptionalArg("_display_name",
"Operator name as presented in the API it was instantiated in (without the module "
"path), for example: cast_like or CastLike.",
OperatorName());
"<empty>");
arguments_["_display_name"].ignore_cmp = true;

DeprecateArg("seed", true,
"The argument \"seed\" should not be used with operators that don't use "
Expand Down
Loading
Loading