Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][IR] alpha_equal to structural_equal #5161

Merged
merged 3 commits into from
Mar 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion include/tvm/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,9 @@ class IncompleteTypeNode : public TypeNode {
}

bool SEqualReduce(const IncompleteTypeNode* other, SEqualReducer equal) const {
return equal(kind, other->kind);
return
equal(kind, other->kind) &&
equal.FreeVarEqualImpl(this, other);
}

void SHashReduce(SHashReducer hash_reduce) const {
Expand Down
55 changes: 0 additions & 55 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,61 +64,6 @@ TVM_DLL Kind KindCheck(const Type& t, const IRModule& mod);
*/
TVM_DLL bool ConstantCheck(const Expr& e);

/*!
* \brief Compare two expressions for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `let x = 1 in x` is equal to `let y = 1 in y`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param e1 The left hand expression.
* \param e2 The right hand expression.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Expr& e1, const Expr& e2);

/*!
* \brief Compare two types for structural equivalence.
*
* This comparison operator respects scoping and compares
* expressions without regard to variable choice.
*
* For example: `forall s, Tensor[f32, s]` is equal to
* `forall w, Tensor[f32, w]`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand type.
* \param t2 The right hand type.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Type& t1, const Type& t2);

/*!
* \brief Compare two patterns for structural equivalence.
*
* This comparison operator respects scoping and compares
* patterns without regard to variable choice.
*
* For example: `A(x, _, y)` is equal to `A(z, _, a)`.
*
* See https://en.wikipedia.org/wiki/Lambda_calculus#Alpha_equivalence
* for more details.
*
* \param t1 The left hand pattern.
* \param t2 The right hand pattern.
*
* \return true if equal, otherwise false
*/
TVM_DLL bool AlphaEqual(const Pattern& t1, const Pattern& t2);

/*!
* \brief Check that each Var is only bound once.
*
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
"""Unified type system in the project."""
from enum import IntEnum
import tvm
import tvm._ffi

from .base import Node
Expand All @@ -26,7 +27,7 @@ class Type(Node):
"""The base class of all types."""
def __eq__(self, other):
"""Compare two types for structural equivalence."""
return bool(_ffi_api.type_alpha_equal(self, other))
return bool(tvm.ir.structural_equal(self, other))

def __ne__(self, other):
return not self.__eq__(other)
Expand Down
1 change: 0 additions & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@

from . import transform
from . import analysis
from .analysis import alpha_equal
from .build_module import build, create_executor, optimize
from .transform import build_config
from . import debug
Expand Down
72 changes: 0 additions & 72 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,78 +220,6 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)


def alpha_equal(lhs, rhs):
"""Compare two Relay expr for structural equivalence (alpha equivalence).

Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.

rhs : tvm.relay.Expr
One of the input Expression.

Returns
-------
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_ffi_api._alpha_equal(lhs, rhs))


def assert_alpha_equal(lhs, rhs):
"""Assert that two Relay expr is structurally equivalent. (alpha equivalence).

Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.

rhs : tvm.relay.Expr
One of the input Expression.
"""
_ffi_api._assert_alpha_equal(lhs, rhs)


def graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.

Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.

rhs : tvm.relay.Expr
One of the input Expression.

Returns
-------
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_ffi_api._graph_equal(lhs, rhs))


def assert_graph_equal(lhs, rhs):
"""Compare two Relay expr for data-flow equivalence.
The difference between this and alpha-equality is that
variables are not expected to match between lhs and rhs;
they are treated as sources and are mapped between each other.

Parameters
----------
lhs : tvm.relay.Expr
One of the input Expression.

rhs : tvm.relay.Expr
One of the input Expression.
"""
_ffi_api._assert_graph_equal(lhs, rhs)


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Expand Down
19 changes: 9 additions & 10 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
*/
#include <tvm/runtime/registry.h>
#include <tvm/ir/module.h>
#include <tvm/node/structural_equal.h>
// NOTE: reverse dependency on relay.
// These dependencies do not happen at the interface-level,
// and are only used in minimum cases where they are clearly marked.
Expand Down Expand Up @@ -194,12 +195,11 @@ relay::Function RunTypeCheck(const IRModule& mod,
<< AsText(func, false)
<< std::endl;
}
func =
relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
func = relay::Function(concat(func->params, fv),
func->body,
func->ret_type,
concat(func->type_params, ftv),
func->attrs);
// Type check the item before we add it to the module.
relay::Function checked_func = InferType(func, mod, var);
return checked_func;
Expand All @@ -222,7 +222,7 @@ void IRModuleNode::Add(const GlobalVar& var,
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var]->checked_type();
CHECK(relay::AlphaEqual(type, old_type))
CHECK(tvm::StructuralEqual()(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
var->checked_type_ = type;
Expand Down Expand Up @@ -353,9 +353,8 @@ IRModule IRModule::FromExpr(
if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
} else {
func = relay::Function(
relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
func = relay::Function(relay::FreeVars(expr), expr, Type(),
relay::FreeTypeVars(expr, mod), {});
}
auto main_gv = GlobalVar("main");
mod->Add(main_gv, func);
Expand Down
Loading