
Commit 2c5b91a

Use protobuf for operator comparison. Fix bugs. Remove compare functions.

Signed-off-by: Michał Zientkiewicz <[email protected]>
mzient committed Dec 16, 2024
1 parent cb256d3 commit 2c5b91a
Showing 8 changed files with 65 additions and 273 deletions.
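
The gist of the change: instead of comparing operators with hand-written `Compare`/`operator<` chains (deleted below), the graph's common-subexpression-elimination (CSE) pass serializes each normalized `OpSpec` to its protobuf form and uses the resulting string as a map key, so equality falls out of string comparison. A minimal sketch of the pattern — `SpecKey`, `Deduplicate`, and the `ProtoMsg` template are illustrative stand-ins for DALI's actual `OpSpecKey` in dali/pipeline/graph/cse.cc, and the spec message is assumed to contain no protobuf map fields (whose serialization order is unspecified):

#include <map>
#include <string>

struct OpNode { /* graph-node payload elided */ };

// Serialize the spec; identical content yields identical bytes here, so the
// byte string doubles as a lookup key.
template <typename ProtoMsg>
std::string SpecKey(const ProtoMsg &spec) {
  std::string key;
  spec.SerializeToString(&key);
  return key;
}

std::map<std::string, OpNode *> normalized_nodes;

// The first node seen with a given key becomes canonical; later nodes with an
// identical serialized spec are redirected to it. Non-foldable operators
// always register themselves as their own canonical node.
template <typename ProtoMsg>
OpNode *Deduplicate(OpNode *node, const ProtoMsg &spec, bool foldable) {
  OpNode *&norm = normalized_nodes[SpecKey(spec)];
  if (!norm || !foldable)
    norm = node;
  return norm;
}

This mirrors the cse.cc hunk further down, where `normalized_nodes_[key]` plays the role of the map above.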
98 changes: 0 additions & 98 deletions dali/core/compare_test.cc

This file was deleted.

39 changes: 0 additions & 39 deletions dali/operators/reader/parser/tf_feature.h
@@ -22,7 +22,6 @@
 #include <vector>
 
 #include "dali/core/common.h"
-#include "dali/core/compare.h"
 #include "dali/pipeline/proto/dali_proto_utils.h"
 
 namespace dali {
@@ -206,40 +205,6 @@ class Feature {
     }
   }
 
-  bool operator<(const Feature &rhs) const {
-    return Compare(rhs) < 0;
-  }
-
-  int Compare(const Feature &rhs) const {
-    if (int cmp = has_shape_ - rhs.has_shape_)
-      return cmp;
-    if (int cmp = has_partial_shape_ - rhs.has_partial_shape_)
-      return cmp;
-    if (int cmp = type_ - rhs.type_)
-      return cmp;
-    switch (type_) {
-      case int64:
-        if (int64_t cmp = val_.int64 - rhs.val_.int64)
-          return cmp >> 32;
-        break;
-      case float32:
-        if (float cmp = val_.float32 - rhs.val_.float32)
-          return 1 - 2 * std::signbit(cmp);
-        break;
-      case string:
-        if (int cmp = val_.str.compare(rhs.val_.str))
-          return cmp;
-        break;
-    }
-    if (has_shape_)
-      if (int cmp = compare(shape_, rhs.shape_))
-        return cmp;
-    if (has_partial_shape_)
-      if (int cmp = compare(partial_shape_, rhs.partial_shape_))
-        return cmp;
-    return 0;
-  }
-
  private:
   bool has_shape_;
   std::vector<Index> shape_;
@@ -249,10 +214,6 @@ class Feature {
   std::vector<Index> partial_shape_;
 };
 
-inline int compare(const TFUtil::Feature &a, const TFUtil::Feature &b) {
-  return a.Compare(b);
-}
-
 }  // namespace TFUtil
 
 }  // namespace dali
7 changes: 5 additions & 2 deletions dali/pipeline/graph/cse.cc
@@ -13,10 +13,11 @@
 // limitations under the License.
 
 #include "dali/pipeline/graph/cse.h"
-#include "dali/pipeline/dali.pb.h"
 #include <functional>
 #include <map>
 #include <string>
+#include <utility>
+#include "dali/pipeline/dali.pb.h"
 
 namespace dali {
 namespace graph {
@@ -93,7 +94,9 @@
     }
     std::string key = OpSpecKey(new_spec);
     OpNode *&norm = normalized_nodes_[key];
-    if (!norm || !IsFoldable(new_spec))
+    bool foldable = IsFoldable(new_spec);
+
+    if (!norm || !foldable)
       norm = node;
 
     if (norm != node) {
3 changes: 2 additions & 1 deletion dali/pipeline/operator/argument.cc
@@ -41,7 +41,8 @@ inline std::shared_ptr<Argument> DeserializeProtobufVectorImpl(const DaliProtoPr
   auto args = arg.extra_args();
   std::vector<T> ret_val;
   for (auto& a : args) {
-    const T& elem = DeserializeProtobuf(a)->Get<T>();
+    auto des = DeserializeProtobuf(a);
+    const T& elem = des->Get<T>();
     ret_val.push_back(elem);
   }
   return Argument::Store(arg.name(), ret_val);
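
The `argument.cc` change above is a lifetime fix rather than a functional one: if `DeserializeProtobuf` returns a `std::shared_ptr<Argument>` by value (which its use here suggests), the old one-liner bound `elem` to storage owned by a temporary that was destroyed at the end of the declaration statement, leaving a dangling reference by the time `push_back` ran. Naming the pointer (`des`) keeps the owner alive while the value is read. A distilled sketch of the pattern, with hypothetical `Holder`/`Make` names:

#include <memory>
#include <vector>

struct Holder {
  std::vector<int> v{1, 2, 3};
  const std::vector<int> &Get() const { return v; }
};

std::shared_ptr<Holder> Make() { return std::make_shared<Holder>(); }

int main() {
  // BUG: the temporary shared_ptr dies at the semicolon; `ref` would dangle.
  // const std::vector<int> &ref = Make()->Get();

  // FIX: name the owner so it outlives the reference.
  auto owner = Make();
  const std::vector<int> &ref = owner->Get();
  return static_cast<int>(ref.size());  // safe: `owner` is still alive
}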
22 changes: 0 additions & 22 deletions dali/pipeline/operator/argument.h
@@ -22,7 +22,6 @@
 #include <vector>
 
 #include "dali/core/common.h"
-#include "dali/core/compare.h"
 #include "dali/core/error_handling.h"
 #include "dali/pipeline/data/types.h"
 #include "dali/pipeline/proto/dali_proto_utils.h"
@@ -145,8 +144,6 @@ class Argument {
 
   virtual ~Argument() = default;
 
-  virtual int Compare(const Argument &other) const = 0;
-
  protected:
   Argument() : has_name_(false) {}
 
@@ -157,10 +154,6 @@
   bool has_name_;
 };
 
-inline int compare(Argument &a, Argument &b) {
-  return a.Compare(b);
-}
-
 template <typename T>
 class ArgumentInst : public Argument {
  public:
@@ -186,13 +179,6 @@ class ArgumentInst : public Argument {
     dali::SerializeToProtobuf(val.Get(), arg);
   }
 
-  int Compare(const Argument &other) const override {
-    if (auto *pother = dynamic_cast<const ArgumentInst<T> *>(&other))
-      return compare(Get(), pother->Get());
-    else
-      return GetTypeId() - other.GetTypeId();
-  }
-
  private:
   ValueInst<T> val;
 };
@@ -229,14 +215,6 @@ class ArgumentInst<std::vector<T>> : public Argument {
     }
   }
 
-  int Compare(const Argument &other) const override {
-    if (auto *pother = dynamic_cast<const ArgumentInst<std::vector<T>> *>(&other)) {
-      return compare(Get(), pother->Get());
-    } else {
-      return GetTypeId() - other.GetTypeId();
-    }
-  }
-
  private:
   ValueInst<std::vector<T>> val;
 };
3 changes: 3 additions & 0 deletions dali/pipeline/pipeline.cc
@@ -278,6 +278,9 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_
   if (spec.GetSchema().IsNoPrune())
     spec.SetArg("preserve", true);
 
+  if (spec.SchemaName() == "ExternalSource")
+    spec.SetArg("preserve_name", true);  // ExternalSource must not be collapsed in CSE
+
   // Take a copy of the passed OpSpec for serialization purposes before any modification
   this->op_specs_for_serialization_.push_back({inst_name, spec, logical_id});
 
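
Why the new `preserve_name` argument: with CSE now folding operators whose serialized specs match, two `ExternalSource` nodes defined with identical arguments could otherwise be merged into one — yet external sources are addressed by instance name when data is fed into the pipeline, so each must remain a distinct, individually named node. This reading follows from the inline comment above; the broader semantics of `preserve_name` are not shown in this diff.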
61 changes: 55 additions & 6 deletions dali/test/python/test_pipeline.py
@@ -547,7 +547,7 @@ def define_graph(self):
 
 
 def test_seed_serialize():
-    batch_size = 64
+    batch_size = 32
 
     class HybridPipe(Pipeline):
        def __init__(self, batch_size, num_threads, device_id):
@@ -574,18 +574,17 @@ def define_graph(self):
            )
            return (output, self.labels)
 
-    n = 30
     orig_pipe = HybridPipe(batch_size=batch_size, num_threads=2, device_id=0)
     s = orig_pipe.serialize()
-    for i in range(50):
+    for i in range(10):
        pipe = Pipeline()
        pipe.deserialize_and_build(s)
        pipe_out = pipe.run()
        pipe_out_cpu = pipe_out[0].as_cpu()
-        img_chw_test = pipe_out_cpu.at(n)
        if i == 0:
-            img_chw = img_chw_test
-        assert np.sum(np.abs(img_chw - img_chw_test)) == 0
+            ref = pipe_out_cpu
+        else:
+            check_batch(pipe_out_cpu, ref)
 
 
 def test_make_contiguous_serialize():
@@ -2333,3 +2332,53 @@ def def_ref():
     (ref,) = ref_pipe.run()
     check_batch(cpu, ref, bs, 0, 0, "HWC")
     check_batch(gpu, ref, bs, 0, 0, "HWC")
+
+
+def test_cse():
+    @pipeline_def(batch_size=8, num_threads=4, device_id=0)
+    def my_pipe():
+        a = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
+        b = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
+        c = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
+        i = fn.random.uniform(range=[0, 1], shape=(1,), seed=1234)
+
+        d = a[0]
+        e = a[0]  # repeated a[0] should be folded into one node
+        f = c[0]  # c folds into a, so c[0] folds into a[0]
+
+        g = a[0] + b[0] - c[0]
+        h = c[0] + a[0] - b[0]
+        return a, b, c, d, e, f, g, h, i
+
+    pipe = my_pipe()
+    pipe.build()
+    a, b, c, d, e, f, g, h, i = pipe.run()
+    assert a.data_ptr() == b.data_ptr()
+    assert a.data_ptr() == c.data_ptr()
+    assert a.data_ptr() != i.data_ptr()
+
+    assert d.data_ptr() == e.data_ptr()
+    assert d.data_ptr() == f.data_ptr()
+
+    assert g.data_ptr() == h.data_ptr()
+
+
+def test_cse_cond():
+    @pipeline_def(batch_size=8, num_threads=4, device_id=0, enable_conditionals=True)
+    def my_pipe():
+        a = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
+        b = fn.random.uniform(range=[0, 1], shape=(1,), seed=123)
+
+        if a[0] > 0:
+            d = a
+        else:
+            d = b  # this is the same as `a`
+
+        return a, b, d
+
+    pipe = my_pipe()
+    pipe.build()
+    a, b, d = pipe.run()
+    assert a.data_ptr() == b.data_ptr()
+    # `d` is opportunistically reassembled and gets the same first sample pointer as `a`
+    assert d.data_ptr() == a.data_ptr()
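
A note on reading these tests: `data_ptr()` exposes the address of an output's first sample, so pointer equality between two outputs indicates the duplicated operators were actually folded into a single node producing a shared buffer, while the distinct seed in `i` guards against a false positive where everything compares equal.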
