Skip to content

Commit

Permalink
Conditions updated to cover better user scenarios (apache#4951)
Browse files Browse the repository at this point in the history
* Conditions updated to cover better user scenarios

* [1] New test case added

* [2] New test case added

* [3] Proper variable name used

* [4] Review Comments handled

* [5] Review comments handled

* [6] Review comments handled
  • Loading branch information
tqchen authored Mar 5, 2020
1 parent 7a06bbe commit fe74b37
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 5 deletions.
10 changes: 5 additions & 5 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (lhs.same_as(rhs)) return true;
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>()) return false;
if (lhs.same_as(rhs)) return true;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>()) return false;
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
Expand Down
67 changes: 67 additions & 0 deletions tests/cpp/relay_pass_alpha_equal.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>
#include <tvm/te/operation.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/transform.h>

using namespace tvm;

class TestAlphaEquals {
runtime::PackedFunc *_packed_func;
public:
TestAlphaEquals(const char* func_name) {
_packed_func = new runtime::PackedFunc();
TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
}

void UpdatePackedFunc(const char* func_name) {
TVMFuncGetGlobal(func_name, reinterpret_cast<TVMFunctionHandle*>(&_packed_func));
}

bool operator()(ObjectRef input_1, ObjectRef input_2) {
TVMRetValue rv;
std::vector<TVMValue> values(2);
std::vector<int> codes(2);
runtime::TVMArgsSetter setter(values.data(), codes.data());
setter(0, input_1);
setter(1, input_2);
_packed_func->CallPacked(TVMArgs(values.data(), codes.data(), 2), &rv);
return bool(rv);
};

};

TEST(Relay, AlphaTestEmptyTypeNodes) {
auto x = TypeVar("x", kTypeData);
auto y = TypeVar();
EXPECT_FALSE(relay::AlphaEqual(x, y));

TestAlphaEquals test_equals("relay._make._alpha_equal");
EXPECT_FALSE(test_equals(x, y));
}

int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
32 changes: 32 additions & 0 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ def alpha_equal(x, y):
"""
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)

def alpha_equal_commutative(x, y):
"""
Check for commutative property of equality
"""
xy = analysis.alpha_equal(x, y)
yx = analysis.alpha_equal(y, x)
assert xy == yx
return xy

def test_tensor_type_alpha_equal():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
Expand Down Expand Up @@ -219,6 +228,26 @@ def test_constant_alpha_equal():
assert not alpha_equal(x, y)
assert alpha_equal(x, relay.const(1))

def test_type_node_alpha_equal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)

v1 = relay.TypeVar('v1', 0)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)

assert alpha_equal_commutative(v1, v1)

def test_type_node_incompatible_alpha_equal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.Var("v2")
assert not alpha_equal_commutative(v1, v2)

def test_expr_node_incompatible_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.PatternVar(relay.Var("v2"))
assert not alpha_equal_commutative(v1, v2)

def test_var_alpha_equal():
v1 = relay.Var("v1")
Expand Down Expand Up @@ -676,6 +705,9 @@ def test_fn_attribute():
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_node_alpha_equal()
test_type_node_incompatible_alpha_equal()
test_expr_node_incompatible_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
Expand Down

0 comments on commit fe74b37

Please sign in to comment.