Skip to content

Commit

Permalink
[TIR] Enhance Substitute, python bindings for Substitute/PostOrderVis…
Browse files Browse the repository at this point in the history
…it/IRTransform. (apache#5400)

Substitute now takes a std::function to customize more replacing behaviors.

Co-authored-by: Siyuan Feng <[email protected]>

Co-authored-by: Siyuan Feng <[email protected]>
  • Loading branch information
2 people authored and Trevor Morris committed Jun 18, 2020
1 parent cc3852b commit 55a173a
Show file tree
Hide file tree
Showing 34 changed files with 419 additions and 317 deletions.
7 changes: 7 additions & 0 deletions docs/api/python/tir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ tvm.tir.analysis
:members:
:imported-members:
:autosummary:


tvm.tir.stmt_functor
--------------------
.. automodule:: tvm.tir.stmt_functor
:members:
:autosummary:
7 changes: 7 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
}
};

/*! \brief Helper to represent nullptr for optional. */
struct NullOptType {
};

/*!
* \brief Optional container that to represent to a Nullable variant of T.
* \tparam T The original ObjectRef.
Expand Down Expand Up @@ -642,6 +646,8 @@ class Optional : public ObjectRef {
* \param ptr
*/
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief Nullopt handling */
Optional(NullOptType) {} // NOLINT(*)
// nullptr handling.
// disallow implicit conversion as 0 can be implicitly converted to nullptr_t
explicit Optional(std::nullptr_t) {}
Expand Down Expand Up @@ -751,6 +757,7 @@ struct PackedFuncValueConverter<Optional<T>> {
// expose the functions to the root namespace.
using runtime::String;
using runtime::Optional;
constexpr runtime::NullOptType NullOpt{};
} // namespace tvm

namespace std {
Expand Down
34 changes: 0 additions & 34 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,40 +81,6 @@ bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vse
*/
TVM_DLL Stmt ConvertSSA(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const VarNode*, PrimExpr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
PrimExpr Substitute(PrimExpr expr,
const std::unordered_map<const VarNode*, PrimExpr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
Expand Down
72 changes: 64 additions & 8 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
/*!
* \file tvm/tir/stmt_functor.h
*
* \brief Functors for tir stmts.
* \brief Functors for tir stmts
* utility functions to call common functors.
*/
#ifndef TVM_TIR_STMT_FUNCTOR_H_
#define TVM_TIR_STMT_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/expr_functor.h>

#include <utility>
#include <unordered_map>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -318,33 +321,86 @@ class StmtExprMutator :
};

/*!
* \brief recursively visit the ir in post DFS order node, and transform it
* \brief recursively visit the ir nodes in post DFS order, and transform it
*
* \param node The ir to be transformed.
* \param stmt The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* If it is null, all IRNode will call preorder/postorder
* If it is not null, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
TVM_DLL Stmt IRTransform(Stmt stmt,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<runtime::String>& only_enable = {});
Optional<Array<String>> only_enable = NullOpt);

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* \brief Recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);

/*!
* \brief Substitute the var specified by vmap.
* \param stmt The source statement to be substituted
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
* \return The converted form.
*/
TVM_DLL Stmt Substitute(Stmt stmt,
std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
* \brief Substitute the var specified by vmap.
* \param expr The source statement to be substituted
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
* \return The result.
*/
TVM_DLL PrimExpr Substitute(PrimExpr expr,
std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
* \param value_map The map of new values.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template<typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
if (it != value_map.end()) return (*it).second;
return Optional<PrimExpr>(nullptr);
};
return Substitute(std::move(input), vmap);
}

/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
* \param value_map The map of new values.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template<typename T>
inline T Substitute(T input,
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var.get());
if (it != value_map.end()) return (*it).second;
return Optional<PrimExpr>(nullptr);
};
return Substitute(std::move(input), vmap);
}

} // namespace tir
} // namespace tvm

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/te/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _pruned_source(func):
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
# pylint: disable=import-outside-toplevel
from tvm.tir import ir_pass
from tvm.tir import stmt_functor

def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
Expand All @@ -84,7 +84,7 @@ def replace(op):
_expr.Call.Halide, buf.op, buf.value_index)
return None

return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])


def _is_tvm_arg_types(args):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@
from . import ir_pass
from . import transform
from . import analysis
from . import stmt_functor
77 changes: 77 additions & 0 deletions python/tvm/tir/stmt_functor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Statement functor utilities for IR transformations"""
from . import _ffi_api


def ir_transform(stmt, preorder, postorder, only_enable=None):
"""Recursively visit and transform ir nodes in post DFS order.
Parameters
----------
stmt : Stmt
The input to be transformed.
preorder: function
The function called in before recursive mutation
If preorder returns None, then the transform will proceed to recursive call.
If preorder returns a not None Stmt/Expr, the transformer will simply return it and
won't do further recursion.
postorder : function
The function called after recursive mutation.
only_enable : Optional[List[str]]
List of types that we only enable.
Returns
-------
result : Stmt
The result.
"""
return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable)


def post_order_visit(stmt, fvisit):
"""Recursively visit the ir in post DFS order node, apply fvisit
Each node is guaranteed to be visited only once.
Parameters
----------
fvisit: function
The visitor function.
"""
return _ffi_api.PostOrderVisit(stmt, fvisit)


def substitute(node, vmap):
""" Substitute the var specified by vmap.
Parameters
----------
node: ObjectRef
The input.
vmap : Dict[Var, PrimExpr]
The variable mapping.
Returns
-------
result : Stmt
The result.
"""
return _ffi_api.Substitute(node, vmap)
31 changes: 16 additions & 15 deletions src/arith/solve_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/tir/op.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/ir_pass.h>

#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/data_type.h>

namespace tvm {
Expand Down Expand Up @@ -130,10 +131,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
(*S)[i][j] = new_i_j;
}
// We have to do the same with rhs
PrimExpr ea = te::make_const((*y)[index].dtype(), a);
PrimExpr eb = te::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
(*y)[index] = new_index_rhs;
Expand Down Expand Up @@ -190,10 +191,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
(*V)[i][j] = new_i_j;
}
// And apply reverse transformations to new_to_old.
PrimExpr ea = te::make_const((*x)[j].dtype(), a);
PrimExpr eb = te::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
(*x)[index] = new_index;
Expand Down Expand Up @@ -369,7 +370,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
IntConstraints(
/*variables=*/{},
/*ranges=*/{},
/*relations=*/{te::make_zero(DataType::Bool())}),
/*relations=*/{tir::make_zero(DataType::Bool())}),
{}, {});
} else if (!tir::is_const_int(new_relation, 1)) {
new_relations.push_back(new_relation);
Expand Down Expand Up @@ -403,13 +404,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
// The j-th variable is just a single value, don't create a tvm variable
// S^{-1}_{nxm} Uy_{mxn}
if (S[j][j] >= 0) {
PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(Uy[j], a)));
} else {
// This is required because some simplifiers
// have problems with dividing by negative numbers
PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(-Uy[j], a)));
}
Expand All @@ -418,9 +419,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol

// V V^{-1} x = x
for (size_t i = 0; i < num_vars; ++i) {
PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
for (size_t j = 0; j < num_vars; ++j) {
e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
}
e = analyzer_problem.Simplify(e);
old_to_new_map.Set(system_to_solve->variables[i], e);
Expand Down
2 changes: 1 addition & 1 deletion src/te/autodiff/ad_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* \brief Utility for tensor-level auto-differentiation.
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include "ad_util.h"

Expand Down
1 change: 0 additions & 1 deletion src/te/operation/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <unordered_set>
Expand Down
Loading

0 comments on commit 55a173a

Please sign in to comment.