From fe74b37ab578e6d3c540b0f6ac187a220ccc028a Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 4 Mar 2020 18:35:38 -0600 Subject: [PATCH] Conditions updated to cover better user scenarios (#4951) * Conditions updated to cover better user scenarios * [1] New test case added * [2] New test case added * [3] Proper variable name used * [4] Review Comments handled * [5] Review comments handled * [6] Review comments handled --- src/relay/ir/alpha_equal.cc | 10 +-- tests/cpp/relay_pass_alpha_equal.cc | 67 +++++++++++++++++++++ tests/python/relay/test_pass_alpha_equal.py | 32 ++++++++++ 3 files changed, 104 insertions(+), 5 deletions(-) create mode 100644 tests/cpp/relay_pass_alpha_equal.cc diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index 78688d7dc730..c622599dd89c 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -50,14 +50,14 @@ class AlphaEqualHandler: * \return The comparison result. */ bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { - if (lhs.same_as(rhs)) return true; if (!lhs.defined() || !rhs.defined()) return false; - if (lhs->IsInstance()) { - if (!rhs->IsInstance()) return false; + if (lhs.same_as(rhs)) return true; + if (lhs->IsInstance() || rhs->IsInstance()) { + if (!rhs->IsInstance() || !lhs->IsInstance()) return false; return TypeEqual(Downcast(lhs), Downcast(rhs)); } - if (lhs->IsInstance()) { - if (!rhs->IsInstance()) return false; + if (lhs->IsInstance() || rhs->IsInstance()) { + if (!rhs->IsInstance() || !lhs->IsInstance()) return false; return ExprEqual(Downcast(lhs), Downcast(rhs)); } if (const auto lhsm = lhs.as()) { diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc new file mode 100644 index 000000000000..0207fca00cf7 --- /dev/null +++ b/tests/cpp/relay_pass_alpha_equal.cc @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include +#include + +using namespace tvm; + +class TestAlphaEquals { + runtime::PackedFunc *_packed_func; + public: + TestAlphaEquals(const char* func_name) { + _packed_func = new runtime::PackedFunc(); + TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); + } + + void UpdatePackedFunc(const char* func_name) { + TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); + } + + bool operator()(ObjectRef input_1, ObjectRef input_2) { + TVMRetValue rv; + std::vector values(2); + std::vector codes(2); + runtime::TVMArgsSetter setter(values.data(), codes.data()); + setter(0, input_1); + setter(1, input_2); + _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); + return bool(rv); + }; + +}; + +TEST(Relay, AlphaTestEmptyTypeNodes) { + auto x = TypeVar("x", kTypeData); + auto y = TypeVar(); + EXPECT_FALSE(relay::AlphaEqual(x, y)); + + TestAlphaEquals test_equals("relay._make._alpha_equal"); + EXPECT_FALSE(test_equals(x, y)); +} + +int main(int argc, char ** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index 7e34f48ec7e1..ec026be69e63 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -28,6 +28,15 @@ def alpha_equal(x, y): """ return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) +def alpha_equal_commutative(x, y): + """ + Check for commutative property of equality + """ + xy = analysis.alpha_equal(x, y) + yx = analysis.alpha_equal(y, x) + assert xy == yx + return xy + def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") t2 = relay.TensorType((3, 4), "float32") @@ -219,6 +228,26 @@ def test_constant_alpha_equal(): assert not alpha_equal(x, y) assert alpha_equal(x, relay.const(1)) +def test_type_node_alpha_equal(): + v1 = relay.TypeVar('v1', 6) + v2 = relay.TypeVar('v2', 6) + assert not alpha_equal(v1, v2) + + v1 = relay.TypeVar('v1', 0) + v2 = relay.TypeVar('v2', 6) + assert not alpha_equal(v1, v2) + + assert alpha_equal_commutative(v1, v1) + +def test_type_node_incompatible_alpha_equal(): + v1 = relay.TypeVar('v1', 6) + v2 = relay.Var("v2") + assert not alpha_equal_commutative(v1, v2) + +def test_expr_node_incompatible_alpha_equal(): + v1 = relay.Var("v1") + v2 = relay.PatternVar(relay.Var("v2")) + assert not alpha_equal_commutative(v1, v2) def test_var_alpha_equal(): v1 = relay.Var("v1") @@ -676,6 +705,9 @@ def test_fn_attribute(): test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() test_constant_alpha_equal() + test_type_node_alpha_equal() + test_type_node_incompatible_alpha_equal() + test_expr_node_incompatible_alpha_equal() test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal()