Skip to content

Commit

Permalink
Assorted fixes:
Browse files Browse the repository at this point in the history
- restore _module and _display_name in the default schema
- add preserve(_name) in Pipeline tests where CSE is unwanted
- fix and extend Python tests
- extend documentation

Signed-off-by: Michał Zientkiewicz <[email protected]>
  • Loading branch information
mzient committed Dec 16, 2024
1 parent 2c5b91a commit 4aaa045
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 18 deletions.
24 changes: 16 additions & 8 deletions dali/pipeline/graph/cse.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
namespace dali {
namespace graph {

std::string OpSpecKey(const OpSpec &spec) {
namespace {

/** Computes the CSE key by serializing the relevant subset of an OpSpec to protobuf */
std::string OpSpecCSEKey(const OpSpec &spec) {
dali_proto::OpDef op;
op.set_name(spec.SchemaName());

Expand All @@ -38,21 +41,21 @@ std::string OpSpecKey(const OpSpec &spec) {

for (int i = 0; i < spec.NumOutput(); ++i) {
dali_proto::InputOutput *out = op.add_output();
// clear output name!
// Use a placeholder instead of the real name
out->set_name(std::to_string(i));
out->set_device(spec.OutputDevice(i));
out->set_is_argument_input(false);
}

auto &schema = spec.GetSchemaOrDefault();
std::map<std::string_view, Argument *, std::less<>> sorted_args;
for (auto &a : spec.Arguments()) {
// Some arguments should be skipped when comparing operators
if (schema.HasArgument(a->get_name()))
if (schema.GetArgument(a->get_name()).ignore_cmp)
auto arg_name = a->get_name();
if (schema.HasArgument(arg_name))
if (schema.GetArgument(arg_name).ignore_cmp)
continue;

sorted_args.emplace(a->get_name(), a.get());
sorted_args.emplace(arg_name, a.get());
}

for (auto [name, a] : sorted_args) {
Expand All @@ -64,6 +67,7 @@ std::string OpSpecKey(const OpSpec &spec) {
return op.SerializeAsString();
}

/** The context for Common Subgraph Elimination */
class CSE {
public:
void Run(OpGraph &graph) {
Expand All @@ -82,7 +86,9 @@ class CSE {
}

bool IsFoldable(const OpSpec &spec) {
return !spec.GetArgument<bool>("preserve") && !spec.GetArgument<bool>("preserve_name");
return !spec.GetArgument<bool>("preserve") &&
!spec.GetArgument<bool>("preserve_name") &&
!spec.GetSchemaOrDefault().IsNoPrune();
}

void Run(OpNode *node) {
Expand All @@ -92,7 +98,7 @@ class CSE {
if (it != renamed_.end())
new_spec.RenameInput(i, it->second);
}
std::string key = OpSpecKey(new_spec);
std::string key = OpSpecCSEKey(new_spec);
OpNode *&norm = normalized_nodes_[key];
bool foldable = IsFoldable(new_spec);

Expand All @@ -115,6 +121,8 @@ class CSE {
OpGraph::Builder builder_;
};

} // namespace

void EliminateCommonSubgraphs(OpGraph &graph) {
CSE cse;
cse.Run(graph);
Expand Down
33 changes: 33 additions & 0 deletions dali/pipeline/graph/cse.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,39 @@
namespace dali {
namespace graph {

/** Eliminate Common Subgraphs
*
* Runs a common subexpression (subgraph) analysis on the graph.
* The graph is completely rewritten in the process.
*
* The algorithm works by traversing the original graph in topological order.
* Each OpSpec is first updated by renaming the inputs to match the previously merged nodes.
* If the updated OpSpec was already seen, then the OpSpec is replaced and the output names
* are added to the renaming map.
*
* op1(args1) --- out1_A -- op2(args2)
* \ /
* --- out1_B -
*
* op2(args1) --- out1_A -- op2(args2)
* \ /
* --- out1_B -
*
* To identify matching operators, a key is computed which consists of the OpSpec's schema name,
* arguments, inputs and output devices (but NOT output names!).
*
* If the key matches one previously seen, the operators are assumed equal.
*
* Some arguments are ignored - notably, the ones identifying the source location in Python
* (that would make any kind of CSE pointless).
*
* There are some operators which are never merged:
* - ExternalSource
* - operators with an explicitly given name
* - operators with the "preserve" argument set
* - operators with NoPrune schema
*
*/
void EliminateCommonSubgraphs(OpGraph &graph);

} // namespace graph
Expand Down
22 changes: 19 additions & 3 deletions dali/pipeline/operator/op_schema.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,23 @@ const OpSchema &OpSchema::Default() {
return default_schema;
}

namespace {
constexpr const char *default_module = "nvidia.dali.ops";
} // namespace

OpSchema::OpSchema(std::string_view name) : name_(name) {
// Process the module path and operator name
InitNames();

std::string default_module = "nvidia.dali.ops";
std::string module = default_module;
for (const auto &submodule : ModulePath()) {
default_module += "." + submodule;
module += "." + submodule;
}

AddOptionalArg("_module",
"String identifying the module in which the operator is defined. "
"Most of the time it is `__module__` of the API function/class.",
default_module);
module);
arguments_["_module"].ignore_cmp = true;

AddOptionalArg("_display_name",
Expand Down Expand Up @@ -151,6 +155,18 @@ a pipeline scope. False if it was defined without pipeline being set as current.
"The argument \"seed\" should not be used with operators that don't use "
"random numbers.");
arguments_["seed"].hidden = true;

AddOptionalArg("_module",
"String identifying the module in which the operator is defined. "
"Most of the time it is `__module__` of the API function/class.",
default_module);
arguments_["_module"].ignore_cmp = true;

AddOptionalArg("_display_name",
"Operator name as presented in the API it was instantiated in (without the module "
"path), for example: cast_like or CastLike.",
"<empty>");
arguments_["_display_name"].ignore_cmp = true;
}


Expand Down
4 changes: 4 additions & 0 deletions dali/pipeline/pipeline_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ TEST_F(PipelineTestOnce, TestPresize) {
.AddArg("device", "cpu")
.AddArg("bytes_per_sample_hint", presize_val_CPU)
.AddInput("raw_jpegs", "cpu")
.AddArg("preserve", true)
.AddOutput("out_2", "cpu"));

pipe.AddOperator(
Expand Down Expand Up @@ -867,18 +868,21 @@ TEST(PipelineTest, AutoName) {
int id = pipe.AddOperator(
OpSpec("Copy")
.AddArg("device", "gpu")
.AddArg("preserve_name", true) // suppress CSE
.AddInput("data", "gpu")
.AddOutput("copied1", "gpu"), 1);

EXPECT_NO_THROW(pipe.AddOperator(
OpSpec("Copy")
.AddArg("device", "gpu")
.AddArg("preserve_name", true)
.AddInput("data", "gpu")
.AddOutput("copied2", "gpu"), id));

EXPECT_NO_THROW(pipe.AddOperator(
OpSpec("Copy")
.AddArg("device", "gpu")
.AddArg("preserve_name", true)
.AddInput("data", "gpu")
.AddOutput("copied3", "gpu"), id));

Expand Down
22 changes: 15 additions & 7 deletions dali/test/python/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -2340,22 +2340,24 @@ def my_pipe():
a = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
b = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
c = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
i = fn.random.uniform(range=[0, 1], shape=(1,), seed=1234)
i = fn.random.uniform(range=[0, 1], shape=(1,), seed=1234) # different seed - must not CSE
j = fn.random.uniform(range=[0, 1], shape=(1,), seed=123, name="do_not_merge")

d = a[0]
e = a[0] # repeated a[0] should be ignored
f = c[0] # c -> a -> a[0]
f = c[0] # c -> a, so it follows that c[0] -> a[0]

g = a[0] + b[0] - c[0]
h = c[0] + a[0] - b[0]
return a, b, c, d, e, f, g, h, i
g = a[0] + b[0] - c[0] # a[0] + a[0] - a[0]
h = c[0] + a[0] - b[0] # likewise
return a, b, c, d, e, f, g, h, i, j

pipe = my_pipe()
pipe.build()
a, b, c, d, e, f, g, h, i = pipe.run()
a, b, c, d, e, f, g, h, i, j = pipe.run()
assert a.data_ptr() == b.data_ptr()
assert a.data_ptr() == c.data_ptr()
assert a.data_ptr() != i.data_ptr()
assert j.data_ptr() != a.data_ptr() # j has a manually specified name and should not be merged

assert d.data_ptr() == e.data_ptr()
assert d.data_ptr() == f.data_ptr()
Expand All @@ -2364,7 +2366,13 @@ def my_pipe():


def test_cse_cond():
@pipeline_def(batch_size=8, num_threads=4, device_id=0, enable_conditionals=True)
@pipeline_def(
batch_size=8,
num_threads=4,
device_id=0,
enable_conditionals=True,
exec_dynamic=True, # required for opportunistic MakeContiguous
)
def my_pipe():
a = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
b = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
Expand Down

0 comments on commit 4aaa045

Please sign in to comment.