Skip to content

Commit

Permalink
Revert "Conditions updated to cover better user scenarios (#4951)"
Browse files Browse the repository at this point in the history
This reverts commit fe74b37.
  • Loading branch information
tqchen authored Mar 10, 2020
1 parent 6026af5 commit 07bac24
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 104 deletions.
10 changes: 5 additions & 5 deletions src/relay/ir/alpha_equal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,14 @@ class AlphaEqualHandler:
* \return The comparison result.
*/
bool Equal(const ObjectRef& lhs, const ObjectRef& rhs) {
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs.same_as(rhs)) return true;
if (lhs->IsInstance<TypeNode>() || rhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>() || !lhs->IsInstance<TypeNode>()) return false;
if (!lhs.defined() || !rhs.defined()) return false;
if (lhs->IsInstance<TypeNode>()) {
if (!rhs->IsInstance<TypeNode>()) return false;
return TypeEqual(Downcast<Type>(lhs), Downcast<Type>(rhs));
}
if (lhs->IsInstance<ExprNode>() || rhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>() || !lhs->IsInstance<ExprNode>()) return false;
if (lhs->IsInstance<ExprNode>()) {
if (!rhs->IsInstance<ExprNode>()) return false;
return ExprEqual(Downcast<Expr>(lhs), Downcast<Expr>(rhs));
}
if (const auto lhsm = lhs.as<IRModuleNode>()) {
Expand Down
67 changes: 0 additions & 67 deletions tests/cpp/relay_pass_alpha_equal.cc

This file was deleted.

32 changes: 0 additions & 32 deletions tests/python/relay/test_pass_alpha_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,6 @@ def alpha_equal(x, y):
"""
return analysis.alpha_equal(x, y) and analysis.structural_hash(x) == analysis.structural_hash(y)

def alpha_equal_commutative(x, y):
"""
Check for commutative property of equality
"""
xy = analysis.alpha_equal(x, y)
yx = analysis.alpha_equal(y, x)
assert xy == yx
return xy

def test_tensor_type_alpha_equal():
t1 = relay.TensorType((3, 4), "float32")
t2 = relay.TensorType((3, 4), "float32")
Expand Down Expand Up @@ -228,26 +219,6 @@ def test_constant_alpha_equal():
assert not alpha_equal(x, y)
assert alpha_equal(x, relay.const(1))

def test_type_node_alpha_equal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)

v1 = relay.TypeVar('v1', 0)
v2 = relay.TypeVar('v2', 6)
assert not alpha_equal(v1, v2)

assert alpha_equal_commutative(v1, v1)

def test_type_node_incompatible_alpha_equal():
v1 = relay.TypeVar('v1', 6)
v2 = relay.Var("v2")
assert not alpha_equal_commutative(v1, v2)

def test_expr_node_incompatible_alpha_equal():
v1 = relay.Var("v1")
v2 = relay.PatternVar(relay.Var("v2"))
assert not alpha_equal_commutative(v1, v2)

def test_var_alpha_equal():
v1 = relay.Var("v1")
Expand Down Expand Up @@ -705,9 +676,6 @@ def test_fn_attribute():
test_tensor_type_alpha_equal()
test_incomplete_type_alpha_equal()
test_constant_alpha_equal()
test_type_node_alpha_equal()
test_type_node_incompatible_alpha_equal()
test_expr_node_incompatible_alpha_equal()
test_func_type_alpha_equal()
test_tuple_type_alpha_equal()
test_type_relation_alpha_equal()
Expand Down

0 comments on commit 07bac24

Please sign in to comment.