From a01ce2cb64191a5b6f21d565ad263b4a64dda997 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Thu, 27 Feb 2020 12:21:48 +0530 Subject: [PATCH 1/7] Conditions updated to cover better user scenarios --- src/relay/ir/alpha_equal.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 78688d7dc730..c622599dd89c 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -50,14 +50,14 @@ class AlphaEqualHandler: * \return The comparison result. */ bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { - if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; - if (lhs->IsInstance()) { - if (!rhs->IsInstance()) return false; + if (lhs.same_as(rhs)) return true; + if (lhs->IsInstance() || rhs->IsInstance()) { + if (!rhs->IsInstance() || !lhs->IsInstance()) return false; return TypeEqual(Downcast(lhs), Downcast(rhs)); } - if (lhs->IsInstance()) { - if (!rhs->IsInstance()) return false; + if (lhs->IsInstance() || rhs->IsInstance()) { + if (!rhs->IsInstance() || !lhs->IsInstance()) return false; return ExprEqual(Downcast(lhs), Downcast(rhs)); } if (const auto lhsm = lhs.as()) { From f93a688a6b5df1e47ebf6827305b5101249139b5 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 29 Feb 2020 21:16:12 +0530 Subject: [PATCH 2/7] [1] New test case added --- tests/python/relay/test_pass_alpha_equal.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 7e34f48ec7e1..b267c51e7fcb 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -219,6 +219,14 @@ def test_constant_alpha_equal(): assert not alpha_equal(x, y) assert alpha_equal(x, relay.const(1)) +def test_type_node_alpha_equal(): + v1 = relay.TypeVar('v1', 6) + v2 = relay.TypeVar('v2', 6) + assert not alpha_equal(v1, v2) + + v1 = relay.TypeVar('v1', 0) + v2 = relay.TypeVar('v2', 6) + assert not alpha_equal(v1, v2) def test_var_alpha_equal(): v1 = relay.Var("v1") @@ -676,6 +684,7 @@ def test_fn_attribute(): test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() test_constant_alpha_equal() + test_type_node_alpha_equal() test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal() From af9c68cc7c039102afc24b2051767ca8c314f48f Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 29 Feb 2020 21:17:50 +0530 Subject: [PATCH 3/7] [2] New test case added --- tests/cpp/relay_pass_alpha_equal.cc | 87 +++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) create mode 100644 tests/cpp/relay_pass_alpha_equal.cc diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc new file mode 100644 index 000000000000..b9586424f509 --- /dev/null +++ b/tests/cpp/relay_pass_alpha_equal.cc @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +TEST(Relay, AlphaTestEmptyTypeNodes) { + using namespace tvm; + auto x = TypeVar("x", kTypeData); + auto y = TypeVar(); + EXPECT_FALSE(relay::AlphaEqual(x, y)); + + runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); + TVMRetValue rv; + (void)TVMFuncGetGlobal("relay._make._alpha_equal", (void**)&packed_func); + std::vector values(2); + std::vector codes(2); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, x); + setter(1, y); + packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); + EXPECT_FALSE(bool(rv)); +} + +TEST(Relay, AlphaTestSameTypeNodes) { + using namespace tvm; + auto x = TypeVar("x", kTypeData); + EXPECT_TRUE(relay::AlphaEqual(x, x)); + + runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); + TVMRetValue rv; + (void)TVMFuncGetGlobal("relay._make._alpha_equal", (void**)&packed_func); + std::vector values(2); + std::vector codes(2); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, x); + setter(1, x); + packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); + EXPECT_TRUE(bool(rv)); +} + +TEST(Relay, AlphaTestIncompatibleTypeNodes) { + using namespace tvm; + auto x = TypeVar("x", kTypeData); + auto y = relay::VarNode::make("y", relay::Type()); + runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); + TVMRetValue rv; + (void)TVMFuncGetGlobal("relay._make._alpha_equal", (void**)&packed_func); + std::vector values(2); + std::vector codes(2); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, x); + setter(1, y); + packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); + EXPECT_FALSE(bool(rv)); + + setter(0, y); + setter(1, x); + packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); + EXPECT_FALSE(bool(rv)); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} From 742bec52fa40d131e529178261814229bea16855 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sat, 29 Feb 2020 22:31:42 +0530 Subject: [PATCH 4/7] [3] Proper variable name used --- tests/cpp/relay_pass_alpha_equal.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc index b9586424f509..460287a82fac 100644 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ b/tests/cpp/relay_pass_alpha_equal.cc @@ -32,7 +32,7 @@ TEST(Relay, AlphaTestEmptyTypeNodes) { runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); TVMRetValue rv; - (void)TVMFuncGetGlobal("relay._make._alpha_equal", (void**)&packed_func); + TVMFuncGetGlobal("relay._make._alpha_equal", reinterpret_cast(&packed_func)); std::vector values(2); std::vector codes(2); runtime::TVMArgsSetter setter(values.data(), codes.data()); @@ -49,7 +49,7 @@ TEST(Relay, AlphaTestSameTypeNodes) { runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); TVMRetValue rv; - (void)TVMFuncGetGlobal("relay._make._alpha_equal", (void**)&packed_func); + TVMFuncGetGlobal("relay._make._alpha_equal", reinterpret_cast(&packed_func)); std::vector values(2); std::vector codes(2); runtime::TVMArgsSetter setter(values.data(), codes.data()); @@ -65,7 +65,7 @@ TEST(Relay, AlphaTestIncompatibleTypeNodes) { auto y = relay::VarNode::make("y", relay::Type()); runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); TVMRetValue rv; - (void)TVMFuncGetGlobal("relay._make._alpha_equal", (void**)&packed_func); + TVMFuncGetGlobal("relay._make._alpha_equal", reinterpret_cast(&packed_func)); std::vector values(2); std::vector codes(2); runtime::TVMArgsSetter setter(values.data(), codes.data()); From d0c53b06baeae43df0334829bf483118989e7b5f Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Sun, 1 Mar 2020 15:22:29 +0530 Subject: [PATCH 5/7] [4] Review Comments handled --- tests/cpp/relay_pass_alpha_equal.cc | 84 ++++++++++++++++------------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc index 460287a82fac..a32f4f5bcd96 100644 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ b/tests/cpp/relay_pass_alpha_equal.cc @@ -24,60 +24,68 @@ #include #include +using namespace tvm; + +class TestAlphaEquals { + runtime::PackedFunc *_packed_func; + public: + TestAlphaEquals(const char* func_name) { + _packed_func = new runtime::PackedFunc(); + TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); + } + + void UpdatePackedFunc(const char* func_name) { + TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); + } + + bool operator()(ObjectRef input_1, ObjectRef input_2) { + TVMRetValue rv; + std::vector values(2); + std::vector codes(2); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, input_1); + setter(1, input_2); + _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); + return bool(rv); + }; + +}; + TEST(Relay, AlphaTestEmptyTypeNodes) { - using namespace tvm; auto x = TypeVar("x", kTypeData); auto y = TypeVar(); EXPECT_FALSE(relay::AlphaEqual(x, y)); - runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); - TVMRetValue rv; - TVMFuncGetGlobal("relay._make._alpha_equal", reinterpret_cast(&packed_func)); - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, x); - setter(1, y); - packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - EXPECT_FALSE(bool(rv)); + TestAlphaEquals test_equals("relay._make._alpha_equal"); + EXPECT_FALSE(test_equals(x, y)); } TEST(Relay, AlphaTestSameTypeNodes) { - using namespace tvm; auto x = TypeVar("x", kTypeData); EXPECT_TRUE(relay::AlphaEqual(x, x)); - runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); - TVMRetValue rv; - TVMFuncGetGlobal("relay._make._alpha_equal", reinterpret_cast(&packed_func)); - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, x); - setter(1, x); - packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - EXPECT_TRUE(bool(rv)); + TestAlphaEquals test_equals("relay._make._alpha_equal"); + EXPECT_TRUE(test_equals(x, x)); } TEST(Relay, AlphaTestIncompatibleTypeNodes) { - using namespace tvm; auto x = TypeVar("x", kTypeData); auto y = relay::VarNode::make("y", relay::Type()); - runtime::PackedFunc *packed_func = new tvm::runtime::PackedFunc(); - TVMRetValue rv; - TVMFuncGetGlobal("relay._make._alpha_equal", reinterpret_cast(&packed_func)); - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, x); - setter(1, y); - packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - EXPECT_FALSE(bool(rv)); - - setter(0, y); - setter(1, x); - packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - EXPECT_FALSE(bool(rv)); + + TestAlphaEquals test_equals("relay._make._alpha_equal"); + EXPECT_FALSE(test_equals(x, y)); + EXPECT_TRUE(test_equals(x, y) == test_equals(y, x)); + +} + +TEST(Relay, AlphaTestIncompatibleExprNodes) { + auto x = relay::VarNode::make("x", relay::Type()); + auto y = ObjectRef(make_object()); + + TestAlphaEquals test_equals("relay._make._alpha_equal"); + EXPECT_FALSE(test_equals(x, y)); + EXPECT_TRUE(test_equals(x, y) == test_equals(y, x)); + } int main(int argc, char ** argv) { From e5522cf8fb1d915315b815fc8b800cb64f3907a4 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Mon, 2 Mar 2020 11:14:47 +0530 Subject: [PATCH 6/7] [5] Review comments handled --- tests/cpp/relay_pass_alpha_equal.cc | 28 --------------------- tests/python/relay/test_pass_alpha_equal.py | 23 +++++++++++++++++ 2 files changed, 23 insertions(+), 28 deletions(-) diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc index a32f4f5bcd96..0207fca00cf7 100644 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ b/tests/cpp/relay_pass_alpha_equal.cc @@ -60,34 +60,6 @@ TEST(Relay, AlphaTestEmptyTypeNodes) { EXPECT_FALSE(test_equals(x, y)); } -TEST(Relay, AlphaTestSameTypeNodes) { - auto x = TypeVar("x", kTypeData); - EXPECT_TRUE(relay::AlphaEqual(x, x)); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_TRUE(test_equals(x, x)); -} - -TEST(Relay, AlphaTestIncompatibleTypeNodes) { - auto x = TypeVar("x", kTypeData); - auto y = relay::VarNode::make("y", relay::Type()); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_FALSE(test_equals(x, y)); - EXPECT_TRUE(test_equals(x, y) == test_equals(y, x)); - -} - -TEST(Relay, AlphaTestIncompatibleExprNodes) { - auto x = relay::VarNode::make("x", relay::Type()); - auto y = ObjectRef(make_object()); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_FALSE(test_equals(x, y)); - EXPECT_TRUE(test_equals(x, y) == test_equals(y, x)); - -} - int main(int argc, char ** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index b267c51e7fcb..163abbd837a1 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -28,6 +28,15 @@ def alpha_equal(x, y): """ return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) +def alpha_equal_commutative(x, y): + """ + Check for commutative property of equality + """ + xy = analysis.alpha_equal(x, y) + yx = analysis.alpha_equal(y, x) + assert xy == yx + return analysis.alpha_equal(x, y) + def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") t2 = relay.TensorType((3, 4), "float32") @@ -228,6 +237,18 @@ def test_type_node_alpha_equal(): v2 = relay.TypeVar('v2', 6) assert not alpha_equal(v1, v2) + assert alpha_equal_commutative(v1, v1) + +def test_type_node_incompatible_alpha_equal(): + v1 = relay.TypeVar('v1', 6) + v2 = relay.Var("v2") + assert not alpha_equal_commutative(v1, v2) + +def test_expr_node_incompatible_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.PatternVar(relay.Var("v2")) + assert not alpha_equal_commutative(v1, v2) + def test_var_alpha_equal(): v1 = relay.Var("v1") v2 = relay.Var("v2") @@ -685,6 +706,8 @@ def test_fn_attribute(): test_incomplete_type_alpha_equal() test_constant_alpha_equal() test_type_node_alpha_equal() + test_type_node_incompatible_alpha_equal() + test_expr_node_incompatible_alpha_equal() test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal() From 4b7909a435ee1b603be6c030624a994894578bb5 Mon Sep 17 00:00:00 2001 From: ANSHUMAN TRIPATHY Date: Mon, 2 Mar 2020 12:53:04 +0530 Subject: [PATCH 7/7] [6] Review comments handled --- tests/python/relay/test_pass_alpha_equal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 163abbd837a1..ec026be69e63 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -35,7 +35,7 @@ def alpha_equal_commutative(x, y): xy = analysis.alpha_equal(x, y) yx = analysis.alpha_equal(y, x) assert xy == yx - return analysis.alpha_equal(x, y) + return xy def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32")