From 535095a07156bb54b4805b059a301e0921ff5000 Mon Sep 17 00:00:00 2001 From: Tianqi Chen Date: Wed, 11 Mar 2020 08:56:03 -0700 Subject: [PATCH] Revert "Conditions updated to cover better user scenarios (#4951)" (#5032) This reverts commit fe74b37ab578e6d3c540b0f6ac187a220ccc028a. --- src/relay/ir/alpha_equal.cc | 10 +-- tests/cpp/relay_pass_alpha_equal.cc | 67 --------------------- tests/python/relay/test_pass_alpha_equal.py | 32 ---------- 3 files changed, 5 insertions(+), 104 deletions(-) delete mode 100644 tests/cpp/relay_pass_alpha_equal.cc diff --git a/src/relay/ir/alpha_equal.cc b/src/relay/ir/alpha_equal.cc index c622599dd89c..78688d7dc730 100644 --- a/src/relay/ir/alpha_equal.cc +++ b/src/relay/ir/alpha_equal.cc @@ -50,14 +50,14 @@ class AlphaEqualHandler: * \return The comparison result. */ bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) { - if (!lhs.defined() || !rhs.defined()) return false; if (lhs.same_as(rhs)) return true; - if (lhs->IsInstance() || rhs->IsInstance()) { - if (!rhs->IsInstance() || !lhs->IsInstance()) return false; + if (!lhs.defined() || !rhs.defined()) return false; + if (lhs->IsInstance()) { + if (!rhs->IsInstance()) return false; return TypeEqual(Downcast(lhs), Downcast(rhs)); } - if (lhs->IsInstance() || rhs->IsInstance()) { - if (!rhs->IsInstance() || !lhs->IsInstance()) return false; + if (lhs->IsInstance()) { + if (!rhs->IsInstance()) return false; return ExprEqual(Downcast(lhs), Downcast(rhs)); } if (const auto lhsm = lhs.as()) { diff --git a/tests/cpp/relay_pass_alpha_equal.cc b/tests/cpp/relay_pass_alpha_equal.cc deleted file mode 100644 index 0207fca00cf7..000000000000 --- a/tests/cpp/relay_pass_alpha_equal.cc +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include -#include -#include -#include -#include -#include - -using namespace tvm; - -class TestAlphaEquals { - runtime::PackedFunc *_packed_func; - public: - TestAlphaEquals(const char* func_name) { - _packed_func = new runtime::PackedFunc(); - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - void UpdatePackedFunc(const char* func_name) { - TVMFuncGetGlobal(func_name, reinterpret_cast(&_packed_func)); - } - - bool operator()(ObjectRef input_1, ObjectRef input_2) { - TVMRetValue rv; - std::vector values(2); - std::vector codes(2); - runtime::TVMArgsSetter setter(values.data(), codes.data()); - setter(0, input_1); - setter(1, input_2); - _packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv); - return bool(rv); - }; - -}; - -TEST(Relay, AlphaTestEmptyTypeNodes) { - auto x = TypeVar("x", kTypeData); - auto y = TypeVar(); - EXPECT_FALSE(relay::AlphaEqual(x, y)); - - TestAlphaEquals test_equals("relay._make._alpha_equal"); - EXPECT_FALSE(test_equals(x, y)); -} - -int main(int argc, char ** argv) { - testing::InitGoogleTest(&argc, argv); - testing::FLAGS_gtest_death_test_style = "threadsafe"; - return RUN_ALL_TESTS(); -} diff --git a/tests/python/relay/test_pass_alpha_equal.py b/tests/python/relay/test_pass_alpha_equal.py index ec026be69e63..7e34f48ec7e1 100644 --- a/tests/python/relay/test_pass_alpha_equal.py +++ b/tests/python/relay/test_pass_alpha_equal.py @@ -28,15 +28,6 @@ def alpha_equal(x, y): """ return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y) -def alpha_equal_commutative(x, y): - """ - Check for commutative property of equality - """ - xy = analysis.alpha_equal(x, y) - yx = analysis.alpha_equal(y, x) - assert xy == yx - return xy - def test_tensor_type_alpha_equal(): t1 = relay.TensorType((3, 4), "float32") t2 = relay.TensorType((3, 4), "float32") @@ -228,26 +219,6 @@ def test_constant_alpha_equal(): assert not alpha_equal(x, y) assert alpha_equal(x, relay.const(1)) -def test_type_node_alpha_equal(): - v1 = relay.TypeVar('v1', 6) - v2 = relay.TypeVar('v2', 6) - assert not alpha_equal(v1, v2) - - v1 = relay.TypeVar('v1', 0) - v2 = relay.TypeVar('v2', 6) - assert not alpha_equal(v1, v2) - - assert alpha_equal_commutative(v1, v1) - -def test_type_node_incompatible_alpha_equal(): - v1 = relay.TypeVar('v1', 6) - v2 = relay.Var("v2") - assert not alpha_equal_commutative(v1, v2) - -def test_expr_node_incompatible_alpha_equal(): - v1 = relay.Var("v1") - v2 = relay.PatternVar(relay.Var("v2")) - assert not alpha_equal_commutative(v1, v2) def test_var_alpha_equal(): v1 = relay.Var("v1") @@ -705,9 +676,6 @@ def test_fn_attribute(): test_tensor_type_alpha_equal() test_incomplete_type_alpha_equal() test_constant_alpha_equal() - test_type_node_alpha_equal() - test_type_node_incompatible_alpha_equal() - test_expr_node_incompatible_alpha_equal() test_func_type_alpha_equal() test_tuple_type_alpha_equal() test_type_relation_alpha_equal()