From 2c5b91a6dc710a8589b1350c9d3f4bec423101e9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Zientkiewicz?= Date: Mon, 16 Dec 2024 14:33:52 +0100 Subject: [PATCH] Use protobuf for operator comparison. Fix bugs. Remove compare functions. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: MichaƂ Zientkiewicz --- dali/core/compare_test.cc | 98 -------------------- dali/operators/reader/parser/tf_feature.h | 39 -------- dali/pipeline/graph/cse.cc | 7 +- dali/pipeline/operator/argument.cc | 3 +- dali/pipeline/operator/argument.h | 22 ----- dali/pipeline/pipeline.cc | 3 + dali/test/python/test_pipeline.py | 61 +++++++++++-- include/dali/core/compare.h | 105 ---------------------- 8 files changed, 65 insertions(+), 273 deletions(-) delete mode 100644 dali/core/compare_test.cc delete mode 100644 include/dali/core/compare.h diff --git a/dali/core/compare_test.cc b/dali/core/compare_test.cc deleted file mode 100644 index 7cbb90a0d07..00000000000 --- a/dali/core/compare_test.cc +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "dali/core/compare.h" -#include -#include -#include -#include -#include "dali/core/int_literals.h" - -using namespace std::literals; // NOLINT - -namespace dali { - -TEST(CompareTest, Primitive) { - EXPECT_EQ(compare(1, 1_u8), 0); - EXPECT_LT(compare(-2, 1_u8), 0); - EXPECT_GT(compare(2, 1_i64), 0); - EXPECT_GT(compare(0x110000007f000000_i64, 0x100000007fffffff_i64), 0); - EXPECT_LT(compare(0x100000007f000000_i64, 0x100000007fffffff_i64), 0); - EXPECT_GT(compare(float16(2), 1.0f), 0); - EXPECT_EQ(compare(2.0f, float16(2)), 0); -} - -TEST(CompareTest, Enum) { - enum E { - A = 1, - B = 2, - C = 3 - }; - EXPECT_EQ(compare(A, 1.0f), 0); - EXPECT_LT(compare(A, B), 0); - EXPECT_EQ(compare(C, A + B), 0); -} - -TEST(CompareTest, String) { - EXPECT_LT(compare("abc", "abcd"), 0); - EXPECT_GT(compare(std::string("abcde"), std::string("abcd")), 0); - EXPECT_LT(compare(std::string_view("abcd"), std::string_view("abdd")), 0); -} - -TEST(CompareTest, CArray) { - int shorter[] = { 1, 2, 3 }; - int longer[] = { 1, 2, 3, 4 }; - EXPECT_LT(compare(shorter, longer), 0); - EXPECT_GT(compare(longer, shorter), 0); - EXPECT_EQ(compare(shorter, shorter), 0); - - int ints[] = { 1, 2, 3 }; - int16_t shorts[] = { 1, 2, 3 }; - EXPECT_EQ(compare(ints, shorts), 0); - - int16_t different[] = { 1, 3, 3 }; - EXPECT_LT(compare(ints, different), 0); -} - -TEST(CompareTest, MixedCollections) { - std::vector shorter = { 1, 2, 3 }; - int longer[] = { 1, 2, 3, 4 }; - EXPECT_LT(compare(shorter, longer), 0); - EXPECT_GT(compare(longer, shorter), 0); - EXPECT_EQ(compare(shorter, shorter), 0); - - int ints[] = { 1, 2, 3 }; - std::array shorts = {{ 1, 2, 3 }}; - EXPECT_EQ(compare(ints, shorts), 0); - - std::list different = { 1, 3, 3 }; - EXPECT_LT(compare(ints, different), 0); -} - -TEST(CompareTest, Tuple) { - EXPECT_EQ(compare(std::make_tuple(1, 2, 3), std::make_tuple(1_u8, 2.0f, 3.0)), 0); - EXPECT_GT(compare(std::make_tuple(1, 2, 3), std::make_tuple(1_u8, 2.0f)), 0); - EXPECT_LT(compare(std::make_tuple(1, 2), std::make_tuple(1_u8, 2.0f, 3.0)), 0); - - EXPECT_LT(compare(std::make_tuple(1, "Former"s, 3), std::make_tuple(1_u8, "Jesse"s, 3.0)), 0); - EXPECT_GT(compare(std::make_tuple(1.000001f, "a"sv, 3), std::make_tuple(1_u8, "b"sv, 3.0)), 0); -} - -TEST(CompareTest, Pair) { - EXPECT_EQ(compare(std::make_pair(1.0f, 42), std::make_pair(1, 42.0f)), 0); - EXPECT_GT(compare(std::make_pair(1.1f, 42), std::make_pair(1, 42.0f)), 0); - EXPECT_LT(compare(std::make_pair(1.0f, 42), std::make_pair(1, 42.1f)), 0); -} - -} // namespace dali diff --git a/dali/operators/reader/parser/tf_feature.h b/dali/operators/reader/parser/tf_feature.h index d0185d469b8..2915dbe81b4 100644 --- a/dali/operators/reader/parser/tf_feature.h +++ b/dali/operators/reader/parser/tf_feature.h @@ -22,7 +22,6 @@ #include #include "dali/core/common.h" -#include "dali/core/compare.h" #include "dali/pipeline/proto/dali_proto_utils.h" namespace dali { @@ -206,40 +205,6 @@ class Feature { } } - bool operator<(const Feature &rhs) const { - return Compare(rhs) < 0; - } - - int Compare(const Feature &rhs) const { - if (int cmp = has_shape_ - rhs.has_shape_) - return cmp; - if (int cmp = has_partial_shape_ - rhs.has_partial_shape_) - return cmp; - if (int cmp = type_ - rhs.type_) - return cmp; - switch (type_) { - case int64: - if (int64_t cmp = val_.int64 - rhs.val_.int64) - return cmp >> 32; - break; - case float32: - if (float cmp = val_.float32 - rhs.val_.float32) - return 1 - 2 * std::signbit(cmp); - break; - case string: - if (int cmp = val_.str.compare(rhs.val_.str)) - return cmp; - break; - } - if (has_shape_) - if (int cmp = compare(shape_, rhs.shape_)) - return cmp; - if (has_partial_shape_) - if (int cmp = compare(partial_shape_, rhs.partial_shape_)) - return cmp; - return 0; - } - private: bool has_shape_; std::vector shape_; @@ -249,10 +214,6 @@ class Feature { std::vector partial_shape_; }; -inline int compare(const TFUtil::Feature &a, const TFUtil::Feature &b) { - return a.Compare(b); -} - } // namespace TFUtil } // namespace dali diff --git a/dali/pipeline/graph/cse.cc b/dali/pipeline/graph/cse.cc index c104db881d2..d3d04eaeb55 100644 --- a/dali/pipeline/graph/cse.cc +++ b/dali/pipeline/graph/cse.cc @@ -13,10 +13,11 @@ // limitations under the License. #include "dali/pipeline/graph/cse.h" -#include "dali/pipeline/dali.pb.h" #include #include #include +#include +#include "dali/pipeline/dali.pb.h" namespace dali { namespace graph { @@ -93,7 +94,9 @@ class CSE { } std::string key = OpSpecKey(new_spec); OpNode *&norm = normalized_nodes_[key]; - if (!norm || !IsFoldable(new_spec)) + bool foldable = IsFoldable(new_spec); + + if (!norm || !foldable) norm = node; if (norm != node) { diff --git a/dali/pipeline/operator/argument.cc b/dali/pipeline/operator/argument.cc index b8c3b2f1fb0..ebb1b7ad646 100644 --- a/dali/pipeline/operator/argument.cc +++ b/dali/pipeline/operator/argument.cc @@ -41,7 +41,8 @@ inline std::shared_ptr DeserializeProtobufVectorImpl(const DaliProtoPr auto args = arg.extra_args(); std::vector ret_val; for (auto& a : args) { - const T& elem = DeserializeProtobuf(a)->Get(); + auto des = DeserializeProtobuf(a); + const T& elem = des->Get(); ret_val.push_back(elem); } return Argument::Store(arg.name(), ret_val); diff --git a/dali/pipeline/operator/argument.h b/dali/pipeline/operator/argument.h index 2c2cb2c9ab9..ccf9a1538d1 100644 --- a/dali/pipeline/operator/argument.h +++ b/dali/pipeline/operator/argument.h @@ -22,7 +22,6 @@ #include #include "dali/core/common.h" -#include "dali/core/compare.h" #include "dali/core/error_handling.h" #include "dali/pipeline/data/types.h" #include "dali/pipeline/proto/dali_proto_utils.h" @@ -145,8 +144,6 @@ class Argument { virtual ~Argument() = default; - virtual int Compare(const Argument &other) const = 0; - protected: Argument() : has_name_(false) {} @@ -157,10 +154,6 @@ class Argument { bool has_name_; }; -inline int compare(Argument &a, Argument &b) { - return a.Compare(b); -} - template class ArgumentInst : public Argument { public: @@ -186,13 +179,6 @@ class ArgumentInst : public Argument { dali::SerializeToProtobuf(val.Get(), arg); } - int Compare(const Argument &other) const override { - if (auto *pother = dynamic_cast *>(&other)) - return compare(Get(), pother->Get()); - else - return GetTypeId() - other.GetTypeId(); - } - private: ValueInst val; }; @@ -229,14 +215,6 @@ class ArgumentInst> : public Argument { } } - int Compare(const Argument &other) const override { - if (auto *pother = dynamic_cast> *>(&other)) { - return compare(Get(), pother->Get()); - } else { - return GetTypeId() - other.GetTypeId(); - } - } - private: ValueInst> val; }; diff --git a/dali/pipeline/pipeline.cc b/dali/pipeline/pipeline.cc index 59d0cb3d883..a902a6104ec 100644 --- a/dali/pipeline/pipeline.cc +++ b/dali/pipeline/pipeline.cc @@ -278,6 +278,9 @@ int Pipeline::AddOperatorImpl(const OpSpec &const_spec, const std::string &inst_ if (spec.GetSchema().IsNoPrune()) spec.SetArg("preserve", true); + if (spec.SchemaName() == "ExternalSource") + spec.SetArg("preserve_name", true); // ExternalSource must not be collapsed in CSE + // Take a copy of the passed OpSpec for serialization purposes before any modification this->op_specs_for_serialization_.push_back({inst_name, spec, logical_id}); diff --git a/dali/test/python/test_pipeline.py b/dali/test/python/test_pipeline.py index a64f481c15d..e3a4ebb890e 100644 --- a/dali/test/python/test_pipeline.py +++ b/dali/test/python/test_pipeline.py @@ -547,7 +547,7 @@ def define_graph(self): def test_seed_serialize(): - batch_size = 64 + batch_size = 32 class HybridPipe(Pipeline): def __init__(self, batch_size, num_threads, device_id): @@ -574,18 +574,17 @@ def define_graph(self): ) return (output, self.labels) - n = 30 orig_pipe = HybridPipe(batch_size=batch_size, num_threads=2, device_id=0) s = orig_pipe.serialize() - for i in range(50): + for i in range(10): pipe = Pipeline() pipe.deserialize_and_build(s) pipe_out = pipe.run() pipe_out_cpu = pipe_out[0].as_cpu() - img_chw_test = pipe_out_cpu.at(n) if i == 0: - img_chw = img_chw_test - assert np.sum(np.abs(img_chw - img_chw_test)) == 0 + ref = pipe_out_cpu + else: + check_batch(pipe_out_cpu, ref) def test_make_contiguous_serialize(): @@ -2333,3 +2332,53 @@ def def_ref(): (ref,) = ref_pipe.run() check_batch(cpu, ref, bs, 0, 0, "HWC") check_batch(gpu, ref, bs, 0, 0, "HWC") + + +def test_cse(): + @pipeline_def(batch_size=8, num_threads=4, device_id=0) + def my_pipe(): + a = fn.random.uniform(range=[0, 1], shape=(1,), seed=123) + b = fn.random.uniform(range=[0, 1], shape=(1,), seed=123) + c = fn.random.uniform(range=[0, 1], shape=(1,), seed=123) + i = fn.random.uniform(range=[0, 1], shape=(1,), seed=1234) + + d = a[0] + e = a[0] # repeated a[0] should be ignored + f = c[0] # c -> a -> a[0] + + g = a[0] + b[0] - c[0] + h = c[0] + a[0] - b[0] + return a, b, c, d, e, f, g, h, i + + pipe = my_pipe() + pipe.build() + a, b, c, d, e, f, g, h, i = pipe.run() + assert a.data_ptr() == b.data_ptr() + assert a.data_ptr() == c.data_ptr() + assert a.data_ptr() != i.data_ptr() + + assert d.data_ptr() == e.data_ptr() + assert d.data_ptr() == f.data_ptr() + + assert g.data_ptr() == h.data_ptr() + + +def test_cse_cond(): + @pipeline_def(batch_size=8, num_threads=4, device_id=0, enable_conditionals=True) + def my_pipe(): + a = fn.random.uniform(range=[0, 1], shape=(1,), seed=123) + b = fn.random.uniform(range=[0, 1], shape=(1,), seed=123) + + if a[0] > 0: + d = a + else: + d = b # this is the same as `a` + + return a, b, d + + pipe = my_pipe() + pipe.build() + a, b, d = pipe.run() + assert a.data_ptr() == b.data_ptr() + # `d` is opportunistically reassembled and gets the same first sample pointer as `a` + assert d.data_ptr() == a.data_ptr() diff --git a/include/dali/core/compare.h b/include/dali/core/compare.h deleted file mode 100644 index 367ae04cf7a..00000000000 --- a/include/dali/core/compare.h +++ /dev/null @@ -1,105 +0,0 @@ -// Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License.#ifndef DALI_MAKE_STRING_H - -#ifndef DALI_CORE_COMPARE_H_ -#define DALI_CORE_COMPARE_H_ - -#include -#include -#include -#include -#include -#include -#include "dali/core/float16.h" - -namespace dali { - -template -constexpr std::enable_if_t<(is_arithmetic_or_half::value || std::is_enum_v) && - (is_arithmetic_or_half::value || std::is_enum_v), int> -compare(const A &a, const B &b) { - return a < b ? -1 : b < a ? 1 : 0; -} - -constexpr int compare(const void *a, const void *b) { - return a < b ? -1 : a > b ? 1 : 0; -} - -inline int compare(const std::string &a, const std::string &b) { - return a.compare(b); -} - -inline int compare(std::string_view a, std::string_view b) { - return a.compare(b); -} - -/** Lexicographical 3-way comparison. - * - * Compares tuple elements and returns the sign of the first non-equal comparison. - * If the tuples have different lengths and the common part compares equal, the shorter tuple - * is ordered before the longer one. - */ -template -inline int compare(const std::tuple &a, const std::tuple &b) { - if constexpr (idx < sizeof...(Ts) && idx < sizeof...(Us)) { - if (int cmp = compare(std::get(a), std::get(b))) - return cmp; - return compare(a, b); - } else { - return compare(sizeof...(Ts), sizeof...(Us)); - } -} - -template -inline int compare(const std::pair &ab, const std::pair &cd) { - if (int cmp = compare(ab.first, cd.first)) - return cmp; - return compare(ab.second, cd.second); -} - -template -int compare_range(A &&a, B &&b) { - auto i = std::begin(a); - auto j = std::begin(b); - auto ae = std::end(a); - auto be = std::end(b); - for (; i != ae && j != be; ++i, ++j) { - if (int cmp = compare(*i, *j)) - return cmp; - } - if (i != ae) - return 1; - if (j != be) - return -1; - return 0; -} - -/** Lexicographical 3-way comparison. - * - * Compares range elements and returns the sign of the first non-equal comparison. - * If the ranges have different lengths and the common part compares equal, the shorter range - * is ordered before the longer one. - */ -template ())), - typename = decltype(std::end(std::declval())), - typename = decltype(std::begin(std::declval())), - typename = decltype(std::end(std::declval()))> -int compare(const A &a, const B &b) { - return compare_range(a, b); -} - -} // namespace dali - -#endif // DALI_CORE_COMPARE_H_