Skip to content

Commit

Permalink
save
Browse files Browse the repository at this point in the history
save

save

save

upstream

lint

remove bad changes

fix build

save

save

please the ci god
  • Loading branch information
MarisaKirisame committed May 16, 2019
1 parent c4439a8 commit 680d141
Show file tree
Hide file tree
Showing 4 changed files with 492 additions and 149 deletions.
71 changes: 38 additions & 33 deletions python/tvm/relay/ir_pass.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ def well_formed(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand Down Expand Up @@ -483,7 +483,7 @@ def free_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -505,7 +505,7 @@ def bound_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -521,7 +521,7 @@ def all_vars(expr):
Parameters
----------
expr: tvm.relay.Expr
expr : tvm.relay.Expr
The input expression
Returns
Expand All @@ -537,9 +537,10 @@ def free_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -556,9 +557,10 @@ def bound_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -575,9 +577,9 @@ def all_type_vars(expr, mod=None):
Parameters
----------
expr: Union[tvm.relay.Expr,tvm.relay.Type]
expr : Union[tvm.relay.Expr,tvm.relay.Type]
The input expression/type
mod: tvm.relay.Module, optional
mod : Optional[tvm.relay.Module]
The global module
Returns
Expand All @@ -594,12 +596,12 @@ def simplify_inference(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with some simplification
"""
Expand All @@ -612,12 +614,12 @@ def canonicalize_ops(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression without bias_add
"""
return _ir_pass.canonicalize_ops(expr)
Expand All @@ -628,12 +630,12 @@ def dead_code_elimination(expr):
Parameters
----------
e: tvm.relay.Expr
expr : tvm.relay.Expr
The input Expression
Returns
-------
result: tvm.relay.Expr
result : tvm.relay.Expr
An expression which is semantically equal to the input expression,
but with dead code removed.
"""
Expand All @@ -645,15 +647,15 @@ def alpha_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is alpha equal to rhs.
"""
return bool(_make._alpha_equal(lhs, rhs))
Expand All @@ -667,15 +669,15 @@ def graph_equal(lhs, rhs):
Parameters
----------
lhs: tvm.relay.Expr
lhs : tvm.relay.Expr
One of the input Expression.
rhs: tvm.relay.Expr
rhs : tvm.relay.Expr
One of the input Expression.
Returns
-------
result: bool
result : bool
True iff lhs is data-flow equivalent to rhs.
"""
return bool(_make._graph_equal(lhs, rhs))
Expand All @@ -686,12 +688,12 @@ def structural_hash(value):
Parameters
----------
expr: tvm.relay.Expr or tvm.relay.Type
expr : Union[tvm.relay.Expr, tvm.relay.Type]
The expression to hash.
Returns
-------
result: int
result : int
The hash value
"""
if isinstance(value, Expr):
Expand Down Expand Up @@ -852,12 +854,12 @@ def to_a_normal_form(expr, mod=None):
expr : tvm.relay.Expr
The input expression.
mod: Optional[tvm.relay.Module]
mod : Optional[tvm.relay.Module]
The global module.
Returns
-------
expr: tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.to_a_normal_form(expr, mod)
Expand All @@ -871,7 +873,7 @@ def to_graph_normal_form(expr):
The input expression
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression
"""
return _ir_pass.to_graph_normal_form(expr)
Expand Down Expand Up @@ -920,7 +922,7 @@ def get_total_mac_number(expr):
Returns
-------
ret : int64
result : int64
The number of MACs (multiply-accumulate) of a model
"""
return _ir_pass.GetTotalMacNumber(expr)
Expand All @@ -935,17 +937,17 @@ def eliminate_common_subexpr(expr, fskip=None):
expr : tvm.relay.Expr
The input expression.
fskip: function
fskip : function
The callback function that decides whether an expression should be skipped.
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.eliminate_common_subexpr(expr, fskip)

def partial_evaluate(expr):
def partial_evaluate(expr, mod=None):
"""
Evaluate the static fragment of the code.
Expand All @@ -954,9 +956,12 @@ def partial_evaluate(expr):
expr : tvm.relay.Expr
The input expression.
mod : Optional[tvm.relay.Module]
The global module
Returns
-------
expr : tvm.relay.Expr
result : tvm.relay.Expr
The output expression.
"""
return _ir_pass.partial_evaluate(expr)
return _ir_pass.partial_evaluate(expr, mod)
6 changes: 3 additions & 3 deletions src/relay/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,8 @@ TVM_REGISTER_API("relay._make.Call")

TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)
.set_dispatch<CallNode>([](const CallNode* node, tvm::IRPrinter* p) {
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
p->stream << "CallNode(" << node->op << ", " << node->args << ", "
<< node->attrs << ", " << node->type_args << ")";
});

Let LetNode::make(Var var, Expr value, Expr body) {
Expand Down Expand Up @@ -324,7 +324,7 @@ TVM_STATIC_IR_FUNCTOR_REGISTER(IRPrinter, vtable)

TVM_REGISTER_API("relay._expr.TempExprRealize")
.set_body_typed<Expr(TempExpr)>([](TempExpr temp) {
return temp->Realize();
return temp->Realize();
});

} // namespace relay
Expand Down
Loading

0 comments on commit 680d141

Please sign in to comment.