-
Notifications
You must be signed in to change notification settings - Fork 629
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Common Subgraph Elimination #5752
Changes from all commits
bb78a3c
05050b0
1168e6d
ade0416
cb256d3
2c5b91a
4aaa045
327cfe8
cd937b2
a332acf
03bde04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#include "dali/pipeline/graph/cse.h" | ||
#include <functional> | ||
#include <map> | ||
#include <string> | ||
#include <utility> | ||
#include "dali/pipeline/dali.pb.h" | ||
|
||
namespace dali { | ||
namespace graph { | ||
|
||
namespace { | ||
|
||
/** Computes the CSE key by serializing the relevant subset of an OpSpec to protobuf */ | ||
std::string OpSpecCSEKey(const OpSpec &spec) { | ||
dali_proto::OpDef op; | ||
op.set_name(spec.SchemaName()); | ||
|
||
for (int i = 0; i < spec.NumInput(); ++i) { | ||
dali_proto::InputOutput *in = op.add_input(); | ||
in->set_name(spec.InputName(i)); | ||
in->set_device(spec.InputDevice(i)); | ||
if (spec.IsArgumentInput(i)) { | ||
in->set_arg_name(spec.ArgumentInputName(i)); | ||
} | ||
in->set_is_argument_input(spec.IsArgumentInput(i)); | ||
} | ||
|
||
for (int i = 0; i < spec.NumOutput(); ++i) { | ||
dali_proto::InputOutput *out = op.add_output(); | ||
// Use a placeholder instead of the real name | ||
out->set_name(std::to_string(i)); | ||
out->set_device(spec.OutputDevice(i)); | ||
} | ||
|
||
auto &schema = spec.GetSchemaOrDefault(); | ||
std::map<std::string_view, Argument *, std::less<>> sorted_args; | ||
for (auto &a : spec.Arguments()) { | ||
// Some arguments should be skipped when comparing operators | ||
auto arg_name = a->get_name(); | ||
if (schema.HasArgument(arg_name)) | ||
if (schema.GetArgument(arg_name).ignore_cmp) | ||
continue; | ||
|
||
sorted_args.emplace(arg_name, a.get()); | ||
} | ||
|
||
for (auto [name, a] : sorted_args) { | ||
dali_proto::Argument *arg = op.add_args(); | ||
DaliProtoPriv arg_wrap(arg); | ||
a->SerializeToProtobuf(&arg_wrap); | ||
} | ||
|
||
return op.SerializeAsString(); | ||
} | ||
|
||
/** The context for Common Subgraph Elimination */ | ||
class CSE { | ||
public: | ||
void Run(OpGraph &graph) { | ||
for (auto &node : graph.OpNodes()) | ||
Run(&node); | ||
for (auto output_name : graph.Outputs()) { | ||
auto it = renamed_full_.find(output_name); | ||
|
||
if (it != renamed_full_.end()) | ||
builder_.AddOutput(it->second); | ||
else | ||
builder_.AddOutput(std::string(output_name)); | ||
} | ||
graph = {}; | ||
graph = std::move(builder_).GetGraph(true); | ||
} | ||
|
||
bool IsFoldable(const OpSpec &spec) { | ||
return !spec.GetArgument<bool>("preserve") && | ||
!spec.GetArgument<bool>("preserve_name") && | ||
!spec.GetSchemaOrDefault().IsNoPrune(); | ||
} | ||
|
||
void Run(OpNode *node) { | ||
OpSpec new_spec = node->spec; | ||
for (int i = 0; i < new_spec.NumInput(); i++) { | ||
auto it = renamed_.find(new_spec.InputName(i)); | ||
if (it != renamed_.end()) | ||
new_spec.RenameInput(i, it->second); | ||
} | ||
std::string key = OpSpecCSEKey(new_spec); | ||
OpNode *&norm = normalized_nodes_[key]; | ||
bool foldable = IsFoldable(new_spec); | ||
|
||
if (!norm || !foldable) | ||
norm = node; | ||
|
||
if (norm != node) { | ||
for (int o = 0; o < node->spec.NumOutput(); o++) { | ||
renamed_.emplace(node->spec.OutputName(o), norm->spec.OutputName(o)); | ||
renamed_full_.emplace(node->spec.Output(o), norm->spec.Output(o)); | ||
} | ||
} else { | ||
builder_.Add(norm->instance_name, new_spec); | ||
} | ||
} | ||
|
||
std::map<std::string, OpNode *> normalized_nodes_; | ||
std::map<std::string, std::string, std::less<>> renamed_; | ||
std::map<std::string, std::string, std::less<>> renamed_full_; | ||
OpGraph::Builder builder_; | ||
}; | ||
|
||
} // namespace | ||
|
||
/** Runs a Common Subgraph Elimination pass over `graph`, which is rewritten in place. */
void EliminateCommonSubgraphs(OpGraph &graph) {
  CSE().Run(graph);
}
|
||
} // namespace graph | ||
} // namespace dali |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
#ifndef DALI_PIPELINE_GRAPH_CSE_H_ | ||
#define DALI_PIPELINE_GRAPH_CSE_H_ | ||
|
||
#include "dali/pipeline/graph/op_graph2.h" | ||
|
||
namespace dali { | ||
namespace graph { | ||
|
||
/** Eliminate Common Subgraphs | ||
* | ||
* Runs a common subexpression (subgraph) analysis on the graph. | ||
* The graph is completely rewritten in the process. | ||
* | ||
* The algorithm works by traversing the original graph in topological order. | ||
* Each OpSpec is first updated by renaming the inputs to match the previously merged nodes. | ||
* If the updated OpSpec was already seen, then it is replaced and the output names are added | ||
* to the renaming map. | ||
* | ||
* To identify matching operators, a key is computed which consists of the OpSpec's schema name, | ||
* arguments, inputs and output devices (but NOT output names!). | ||
* Some arguments are ignored - notably, the ones identifying the source location in Python | ||
* (that would make any kind of CSE pointless). | ||
* | ||
* If the key matches one previously seen, the operators are assumed equal and can be merged, | ||
* with several exceptions. | ||
* | ||
* The operators which are not merged: | ||
* - ExternalSource | ||
* - operators with explicitly given name | ||
* - operators with "preserve" argument set | ||
* - operators with NoPrune schema | ||
* | ||
* Example: | ||
* | ||
* ``` | ||
* op1(args1) --- out1_0_A --- op2(args2) -- out2_0 --> pipeline_output_0 | ||
* __op1_0 \ / __op2_0 | ||
* --- out1_0_B -- | ||
* | ||
* op1(args1) --- out1_1_A --- op2(args2) -- out2_1 --> pipeline_output_1 | ||
* __op1_1 \ / __op2_1 | ||
* --- out1_1_B -- | ||
* ``` | ||
* | ||
* In the example above, the two instances of op1 are identical so they're collapsed into one, | ||
* the __op1_0. The renaming map is: | ||
* out1_1_A : out1_0_A | ||
* out1_1_B : out1_0_B | ||
* | ||
* After renaming the inputs to __op2_1, we get: | ||
* | ||
* ``` | ||
* op1(args1) --+-- out1_0_A ----- op2(args2) -- out2_0 --> pipeline_output_0 | ||
* __op1_0 | \ / __op2_0 | ||
* +----(-- out1_0_B - | ||
* | \ | ||
* | --------- op2(args2) ------ out2_1 --> pipeline_output_1 | ||
* \ / __op2_1 | ||
* --------------- | ||
* ``` | ||
* At this point, __op2_1 is identical to __op2_0 and can be removed. The final graph: | ||
* | ||
* ``` | ||
* op1(args1) --- out1_0_A --- op2(args2) -- out2_0 -+---> pipeline_output_0 | ||
* __op1_0 \ / __op2_0 \ | ||
* --- out1_0_B -- --> pipeline_output_1 | ||
* ``` | ||
* | ||
*/ | ||
void EliminateCommonSubgraphs(OpGraph &graph); | ||
|
||
} // namespace graph | ||
} // namespace dali | ||
|
||
|
||
#endif // DALI_PIPELINE_GRAPH_CSE_H_ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -41,7 +41,8 @@ inline std::shared_ptr<Argument> DeserializeProtobufVectorImpl(const DaliProtoPr | |
auto args = arg.extra_args(); | ||
std::vector<T> ret_val; | ||
for (auto& a : args) { | ||
const T& elem = DeserializeProtobuf(a)->Get<T>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. After changing |
||
auto des = DeserializeProtobuf(a); | ||
const T& elem = des->Get<T>(); | ||
ret_val.push_back(elem); | ||
} | ||
return Argument::Store(arg.name(), ret_val); | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -62,9 +62,30 @@ const OpSchema &OpSchema::Default() { | |
return default_schema; | ||
} | ||
|
||
namespace { | ||
constexpr const char *default_module = "nvidia.dali.ops"; | ||
} // namespace | ||
|
||
OpSchema::OpSchema(std::string_view name) : name_(name) { | ||
// Process the module path and operator name | ||
InitNames(); | ||
|
||
std::string module = default_module; | ||
for (const auto &submodule : ModulePath()) { | ||
module += "." + submodule; | ||
} | ||
|
||
AddOptionalArg("_module", | ||
"String identifying the module in which the operator is defined. " | ||
"Most of the time it is `__module__` of the API function/class.", | ||
module); | ||
arguments_["_module"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_display_name", | ||
"Operator name as presented in the API it was instantiated in (without the module " | ||
"path), for example: cast_like or CastLike.", | ||
OperatorName()); | ||
arguments_["_display_name"].ignore_cmp = true; | ||
} | ||
|
||
OpSchema::OpSchema(DefaultSchemaTag) : name_(""), default_(true) { | ||
|
@@ -76,6 +97,10 @@ OpSchema::OpSchema(DefaultSchemaTag) : name_(""), default_(true) { | |
AddInternalArg("default_cuda_stream_priority", "Default cuda stream priority", 0); // deprecated | ||
AddInternalArg("checkpointing", "Setting to `true` enables checkpointing", false); | ||
|
||
AddInternalArg("preserve_name", R"(When true, the operator cannot be renamed. | ||
This disables merging this operator with another one with a different name.)", | ||
false); | ||
Comment on lines
+100
to
+102
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't this always be the case when an operator was explicitly named by the user? I understand that having two equal ops (Readers?) with different names is not a clever way to use DALI, but it is possible for the user to do that. Let's say that, for whatever reason, I have two readers doing the same work but named differently ("Reader1", "Reader2"). I wrapped the pipeline in something (an iterator?) that relies on the presence of "Reader2" in the pipeline. This is not good code, but it works. And after CSE it doesn't. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That's exactly how it works. If you specify a name manually, then There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Ah, I see:
Thank you! |
||
|
||
AddOptionalArg<int>("seed", R"code(Random seed. | ||
If not provided, it will be populated based on the global seed of the pipeline.)code", | ||
nullptr); | ||
|
@@ -89,6 +114,7 @@ to accommodate a batch of samples of this size.)code", | |
AddOptionalArg("preserve", R"code(Prevents the operator from being removed from the | ||
graph even if its outputs are not used.)code", | ||
false); | ||
arguments_["preserve"].ignore_cmp = true; | ||
|
||
// For simplicity we pass StackSummary as 4 separate arguments so we don't need to extend DALI | ||
// with support for special FrameSummary type. | ||
|
@@ -103,37 +129,39 @@ messages, pointing to the origin of the error in pipeline definition. | |
The list of FrameSummaries is split into four parameters: each is the list containing corresponding | ||
parameters of FrameSummary. This parameter represents the `filename` member.)code", | ||
std::vector<std::string>{}); | ||
arguments_["_origin_stack_filename"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_origin_stack_lineno", R"code(StackSummary - lineno member of FrameSummary, see | ||
_origin_stack_filename for more information.)code", | ||
std::vector<int>{}); | ||
arguments_["_origin_stack_lineno"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_origin_stack_name", R"code(StackSummary - name member of FrameSummary, see | ||
_origin_stack_filename for more information.)code", | ||
std::vector<std::string>{}); | ||
arguments_["_origin_stack_name"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_origin_stack_line", R"code(StackSummary - line member of FrameSummary, see | ||
_origin_stack_filename for more information.)code", | ||
std::vector<std::string>{}); | ||
arguments_["_origin_stack_line"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_pipeline_internal", R"code(Boolean specifying if this operator was defined within | ||
a pipeline scope. False if it was defined without pipeline being set as current.)code", | ||
true); | ||
|
||
std::string default_module = "nvidia.dali.ops"; | ||
for (const auto &submodule : ModulePath()) { | ||
default_module += "." + submodule; | ||
} | ||
arguments_["_pipeline_internal"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_module", | ||
"String identifying the module in which the operator is defined. " | ||
"Most of the time it is `__module__` of the API function/class.", | ||
default_module); | ||
arguments_["_module"].ignore_cmp = true; | ||
|
||
AddOptionalArg("_display_name", | ||
"Operator name as presented in the API it was instantiated in (without the module " | ||
"path), for example: cast_like or CastLike.", | ||
OperatorName()); | ||
"<empty>"); | ||
arguments_["_display_name"].ignore_cmp = true; | ||
|
||
DeprecateArg("seed", true, | ||
"The argument \"seed\" should not be used with operators that don't use " | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OpNodes returns the nodes in topological order?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, the graph is sorted by the Builder.