Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Enhance Substitute, python bindings for Substitute/PostOrderVisit #5400

Merged
merged 1 commit into from
Apr 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/api/python/tir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,10 @@ tvm.tir.analysis
:members:
:imported-members:
:autosummary:


tvm.tir.stmt_functor
--------------------
.. automodule:: tvm.tir.stmt_functor
:members:
:autosummary:
7 changes: 7 additions & 0 deletions include/tvm/runtime/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,10 @@ struct PackedFuncValueConverter<::tvm::runtime::String> {
}
};

/*! \brief Helper to represent nullptr for optional. */
struct NullOptType {
};

/*!
* \brief Optional container that to represent to a Nullable variant of T.
* \tparam T The original ObjectRef.
Expand Down Expand Up @@ -642,6 +646,8 @@ class Optional : public ObjectRef {
* \param ptr
*/
explicit Optional(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief Nullopt handling */
Optional(NullOptType) {} // NOLINT(*)
// nullptr handling.
// disallow implicit conversion as 0 can be implicitly converted to nullptr_t
explicit Optional(std::nullptr_t) {}
Expand Down Expand Up @@ -751,6 +757,7 @@ struct PackedFuncValueConverter<Optional<T>> {
// expose the functions to the root namespace.
using runtime::String;
using runtime::Optional;
constexpr runtime::NullOptType NullOpt{};
} // namespace tvm

namespace std {
Expand Down
34 changes: 0 additions & 34 deletions include/tvm/tir/ir_pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,40 +81,6 @@ bool ExprUseVar(const PrimExpr& e, const std::unordered_set<const VarNode*>& vse
*/
TVM_DLL Stmt ConvertSSA(Stmt stmt);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt,
const std::unordered_map<const VarNode*, PrimExpr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
PrimExpr Substitute(PrimExpr expr,
const std::unordered_map<const VarNode*, PrimExpr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param stmt The source statement to be substituted
* \param value_map The map of new values.
* \return The converted form.
*/
Stmt Substitute(Stmt stmt, const Map<Var, PrimExpr>& value_map);

/*!
* \brief Substitute the var specified in key->var to be value.
* \param expr The source expression to be substituted
* \param value_map The map of new values.
* \return The converted expression.
*/
PrimExpr Substitute(PrimExpr expr, const Map<Var, PrimExpr>& value_map);

/*!
* \brief Verify if there is any argument bound to compact buffer.
*
Expand Down
72 changes: 64 additions & 8 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@
/*!
* \file tvm/tir/stmt_functor.h
*
* \brief Functors for tir stmts.
* \brief Functors for tir stmts
* utility functions to call common functors.
*/
#ifndef TVM_TIR_STMT_FUNCTOR_H_
#define TVM_TIR_STMT_FUNCTOR_H_

#include <tvm/node/functor.h>
#include <tvm/node/container.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/expr_functor.h>

#include <utility>
#include <unordered_map>

namespace tvm {
namespace tir {
Expand Down Expand Up @@ -318,33 +321,86 @@ class StmtExprMutator :
};

/*!
* \brief recursively visit the ir in post DFS order node, and transform it
* \brief recursively visit the ir nodes in post DFS order, and transform it
*
* \param node The ir to be transformed.
* \param stmt The ir to be transformed.
* \param preorder The function called in before recursive mutation
* If preorder returns None, then the transform will proceed to recursive call.
* If preorder returns a not None Stmt/Expr, the transformer will simply return it and
* won't do further recursion.
* \param postorder The function called after recursive mutation.
* The recursive mutation result is passed to postorder for further mutation.
* \param only_enable List of runtime::String.
* If it is empty, all IRNode will call preorder/postorder
* If it is not empty, preorder/postorder will only be called
* If it is null, all IRNode will call preorder/postorder
* If it is not null, preorder/postorder will only be called
* when the IRNode's type key is in the list.
*/
TVM_DLL Stmt IRTransform(Stmt node,
TVM_DLL Stmt IRTransform(Stmt stmt,
const runtime::PackedFunc& preorder,
const runtime::PackedFunc& postorder,
const Array<runtime::String>& only_enable = {});
Optional<Array<String>> only_enable = NullOpt);

/*!
* \brief recursively visit the ir in post DFS order node, apply fvisit
* \brief Recursively visit the ir in post DFS order node, apply fvisit
* Each node is guaranteed to be visited only once.
* \param node The ir to be visited.
* \param fvisit The visitor function to be applied.
*/
TVM_DLL void PostOrderVisit(const ObjectRef& node, std::function<void(const ObjectRef&)> fvisit);

/*!
* \brief Substitute the var specified by vmap.
* \param stmt The source statement to be substituted
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
* \return The converted form.
*/
TVM_DLL Stmt Substitute(Stmt stmt,
std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
* \brief Substitute the var specified by vmap.
* \param expr The source statement to be substituted
* \param vmap returns a new value if re-mapping is needed, otherwise returns nullptr.
* \return The result.
*/
TVM_DLL PrimExpr Substitute(PrimExpr expr,
std::function<Optional<PrimExpr>(const Var& var)> vmap);

/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
* \param value_map The map of new values.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template<typename T>
inline T Substitute(T input, const Map<Var, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var);
if (it != value_map.end()) return (*it).second;
return Optional<PrimExpr>(nullptr);
};
return Substitute(std::move(input), vmap);
}

/*!
* \brief Sugar for substitute via a given map.
* \param input The input to be updated.
* \param value_map The map of new values.
* \return The result.
* \tparam T the input type, can be PrimExpr or Stmt.
*/
template<typename T>
inline T Substitute(T input,
const std::unordered_map<const VarNode*, PrimExpr>& value_map) {
auto vmap = [&](const Var& var) -> Optional<PrimExpr> {
auto it = value_map.find(var.get());
if (it != value_map.end()) return (*it).second;
return Optional<PrimExpr>(nullptr);
};
return Substitute(std::move(input), vmap);
}

} // namespace tir
} // namespace tvm

Expand Down
4 changes: 2 additions & 2 deletions python/tvm/te/hybrid/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def _pruned_source(func):
def replace_io(body, rmap):
"""Replacing tensors usage according to the dict given"""
# pylint: disable=import-outside-toplevel
from tvm.tir import ir_pass
from tvm.tir import stmt_functor

def replace(op):
if isinstance(op, _stmt.Provide) and op.func in rmap.keys():
Expand All @@ -84,7 +84,7 @@ def replace(op):
_expr.Call.Halide, buf.op, buf.value_index)
return None

return ir_pass.IRTransform(body, None, replace, ['Provide', 'Call'])
return stmt_functor.ir_transform(body, None, replace, ['Provide', 'Call'])


def _is_tvm_arg_types(args):
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,4 @@
from . import ir_pass
from . import transform
from . import analysis
from . import stmt_functor
77 changes: 77 additions & 0 deletions python/tvm/tir/stmt_functor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Statement functor utilities for IR transformations"""
from . import _ffi_api


def ir_transform(stmt, preorder, postorder, only_enable=None):
"""Recursively visit and transform ir nodes in post DFS order.
Parameters
----------
stmt : Stmt
The input to be transformed.
preorder: function
The function called in before recursive mutation
If preorder returns None, then the transform will proceed to recursive call.
If preorder returns a not None Stmt/Expr, the transformer will simply return it and
won't do further recursion.
postorder : function
The function called after recursive mutation.
only_enable : Optional[List[str]]
List of types that we only enable.
Returns
-------
result : Stmt
The result.
"""
return _ffi_api.IRTransform(stmt, preorder, postorder, only_enable)


def post_order_visit(stmt, fvisit):
"""Recursively visit the ir in post DFS order node, apply fvisit
Each node is guaranteed to be visited only once.
Parameters
----------
fvisit: function
The visitor function.
"""
return _ffi_api.PostOrderVisit(stmt, fvisit)


def substitute(node, vmap):
""" Substitute the var specified by vmap.
Parameters
----------
node: ObjectRef
The input.
vmap : Dict[Var, PrimExpr]
The variable mapping.
Returns
-------
result : Stmt
The result.
"""
return _ffi_api.Substitute(node, vmap)
31 changes: 16 additions & 15 deletions src/arith/solve_linear_equation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@
#include <tvm/arith/analyzer.h>
#include <tvm/arith/int_solver.h>
#include <tvm/arith/util.h>
#include <tvm/tir/op.h>
#include <tvm/arith/pattern.h>
#include <tvm/tir/ir_pass.h>

#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/runtime/data_type.h>

namespace tvm {
Expand Down Expand Up @@ -130,10 +131,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
(*S)[i][j] = new_i_j;
}
// We have to do the same with rhs
PrimExpr ea = te::make_const((*y)[index].dtype(), a);
PrimExpr eb = te::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = te::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*y)[index].dtype(), n_g);
PrimExpr ea = tir::make_const((*y)[index].dtype(), a);
PrimExpr eb = tir::make_const((*y)[i].dtype(), b);
PrimExpr e_m_g = tir::make_const((*y)[i].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*y)[index].dtype(), n_g);
PrimExpr new_index_rhs = ea*(*y)[index] + eb*(*y)[i];
PrimExpr new_i_rhs = e_n_g*(*y)[index] - e_m_g*(*y)[i];
(*y)[index] = new_index_rhs;
Expand Down Expand Up @@ -190,10 +191,10 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t> >* S,
(*V)[i][j] = new_i_j;
}
// And apply reverse transformations to new_to_old.
PrimExpr ea = te::make_const((*x)[j].dtype(), a);
PrimExpr eb = te::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = te::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = te::make_const((*x)[j].dtype(), n_g);
PrimExpr ea = tir::make_const((*x)[j].dtype(), a);
PrimExpr eb = tir::make_const((*x)[index].dtype(), b);
PrimExpr e_m_g = tir::make_const((*x)[index].dtype(), m_g);
PrimExpr e_n_g = tir::make_const((*x)[j].dtype(), n_g);
PrimExpr new_index = e_m_g*(*x)[index] + e_n_g*(*x)[j];
PrimExpr new_j = eb*(*x)[index] - ea*(*x)[j];
(*x)[index] = new_index;
Expand Down Expand Up @@ -369,7 +370,7 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
IntConstraints(
/*variables=*/{},
/*ranges=*/{},
/*relations=*/{te::make_zero(DataType::Bool())}),
/*relations=*/{tir::make_zero(DataType::Bool())}),
{}, {});
} else if (!tir::is_const_int(new_relation, 1)) {
new_relations.push_back(new_relation);
Expand Down Expand Up @@ -403,13 +404,13 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol
// The j-th variable is just a single value, don't create a tvm variable
// S^{-1}_{nxm} Uy_{mxn}
if (S[j][j] >= 0) {
PrimExpr a = te::make_const(Uy[j].dtype(), S[j][j]);
PrimExpr a = tir::make_const(Uy[j].dtype(), S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(Uy[j], a)));
} else {
// This is required because some simplifiers
// have problems with dividing by negative numbers
PrimExpr a = te::make_const(Uy[j].dtype(), -S[j][j]);
PrimExpr a = tir::make_const(Uy[j].dtype(), -S[j][j]);
solution_for_V_inv_x.push_back(
analyzer_problem.Simplify(floordiv(-Uy[j], a)));
}
Expand All @@ -418,9 +419,9 @@ IntConstraintsTransform SolveLinearEquations(const IntConstraints &system_to_sol

// V V^{-1} x = x
for (size_t i = 0; i < num_vars; ++i) {
PrimExpr e = te::make_zero(system_to_solve->variables[i].dtype());
PrimExpr e = tir::make_zero(system_to_solve->variables[i].dtype());
for (size_t j = 0; j < num_vars; ++j) {
e = e + te::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
e = e + tir::make_const(e.dtype(), V[i][j])*solution_for_V_inv_x[j];
}
e = analyzer_problem.Simplify(e);
old_to_new_map.Set(system_to_solve->variables[i], e);
Expand Down
2 changes: 1 addition & 1 deletion src/te/autodiff/ad_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
* \brief Utility for tensor-level auto-differentiation.
*/
#include <tvm/tir/expr.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/stmt_functor.h>
#include <string>
#include "ad_util.h"

Expand Down
1 change: 0 additions & 1 deletion src/te/operation/hybrid_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <tvm/arith/analyzer.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/ir_pass.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/op.h>
#include <unordered_set>
Expand Down
Loading