Skip to content

Commit

Permalink
[REFACTOR][TIR] Introduce PrimFuncPass. (apache#5139)
Browse files Browse the repository at this point in the history
* [REFACTOR][TIR] Introduce PrimFuncPass.

- Introduce PrimFuncPass
- Convert one pass to the unified Pass API.

* Address comments

* Fix comments
  • Loading branch information
tqchen authored and Trevor Morris committed Apr 16, 2020
1 parent da07aed commit 3a3a30d
Show file tree
Hide file tree
Showing 18 changed files with 570 additions and 6 deletions.
9 changes: 9 additions & 0 deletions docs/api/python/tir.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,12 @@ tvm.tir
:imported-members:
:exclude-members: PrimExpr, const
:autosummary:



tvm.tir.transform
-----------------
.. automodule:: tvm.tir.transform
:members:
:imported-members:
:autosummary:
2 changes: 1 addition & 1 deletion include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ class RelayExprNode : public BaseExprNode {
/*!
* \return The checked_type
*/
const Type& checked_type() const;
inline const Type& checked_type() const;
/*!
* \brief Check if the inferred(checked) type of the Expr
* is backed by a TTypeNode and return it.
Expand Down
4 changes: 4 additions & 0 deletions include/tvm/ir/type_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
virtual R VisitType_(const TypeCallNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const TypeDataNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PrimTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitType_(const PointerTypeNode* op, Args... args) TYPE_FUNCTOR_DEFAULT;
virtual R VisitTypeDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
throw; // unreachable, written to stop compiler warning
Expand All @@ -115,6 +116,7 @@ class TypeFunctor<R(const Type& n, Args...)> {
TVM_TYPE_FUNCTOR_DISPATCH(TypeCallNode);
TVM_TYPE_FUNCTOR_DISPATCH(TypeDataNode);
TVM_TYPE_FUNCTOR_DISPATCH(PrimTypeNode);
TVM_TYPE_FUNCTOR_DISPATCH(PointerTypeNode);
return vtable;
}
};
Expand All @@ -138,6 +140,7 @@ class TVM_DLL TypeVisitor :
void VisitType_(const TypeCallNode* op) override;
void VisitType_(const TypeDataNode* op) override;
void VisitType_(const PrimTypeNode* op) override;
void VisitType_(const PointerTypeNode* op) override;
};

/*!
Expand All @@ -158,6 +161,7 @@ class TVM_DLL TypeMutator :
Type VisitType_(const TypeCallNode* op) override;
Type VisitType_(const TypeDataNode* op) override;
Type VisitType_(const PrimTypeNode* op) override;
Type VisitType_(const PointerTypeNode* op) override;

private:
Array<Type> MutateArray(Array<Type> arr);
Expand Down
72 changes: 72 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/tir/transform.h
* \brief TIR specific transformation passes.
*/
#ifndef TVM_TIR_TRANSFORM_H_
#define TVM_TIR_TRANSFORM_H_

#include <tvm/ir/transform.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>

#include <string>

namespace tvm {
namespace tir {
namespace transform {

using tvm::transform::Pass;
using tvm::transform::PassNode;
using tvm::transform::PassInfo;
using tvm::transform::PassInfoNode;
using tvm::transform::PassContext;
using tvm::transform::PassContextNode;
using tvm::transform::Sequential;

/*
* \brief Create a function pass that optimizes PrimFuncs.
*
* \param pass_func The packed function that contains the optimization.
* \param opt_level The optimization level of the function pass.
* \param name The name of the function pass.
* \param required The list of the passes that the function pass is dependent on.
*
* \return The created function pass.
*/
TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc<
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func,
int opt_level,
const std::string& name,
const tvm::Array<tvm::PrimExpr>& required);

/*!
* \brief Create PrimFuncPass to combine context calls in the host function.
*
* \return The pass.
*/
Pass CombineContextCall();

} // namespace transform
} // namespace tir
} // namespace tvm

#endif // TVM_TIR_TRANSFORM_H_
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,4 @@

from . import ir_builder
from . import ir_pass
from . import transform
21 changes: 21 additions & 0 deletions python/tvm/tir/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Namespace of all TIR transformations"""
# pylint: disable=wildcard-import, invalid-name

from .function_pass import prim_func_pass, PrimFuncPass
from .transform import *
21 changes: 21 additions & 0 deletions python/tvm/tir/transform/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI APIs for tvm.tir.transform"""
import tvm._ffi


tvm._ffi._init_api("tir.transform", __name__)
149 changes: 149 additions & 0 deletions python/tvm/tir/transform/function_pass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""TIR specific function pass support."""
import inspect
import functools

import tvm._ffi
from tvm.ir.transform import Pass, PassInfo

from . import _ffi_api


@tvm._ffi.register_object("tir.PrimFuncPass")
class PrimFuncPass(Pass):
"""A pass that works on each :py:func:`tvm.tir.PrimFunc` in a module. A function
pass class should be created through py:func:`tvm.tir.transform.function_pass`.
"""


def _wrap_class_function_pass(pass_cls, pass_info):
"""Wrap a python class as function pass"""
class PyFunctionPass(PrimFuncPass):
"""Internal wrapper class to create a class instance."""
def __init__(self, *args, **kwargs):
# initialize handle in cass pass_cls creation failed.fg
self.handle = None
inst = pass_cls(*args, **kwargs)
# it is important not to capture self to
# avoid a cyclic dependency
def _pass_func(func, mod, ctx):
return inst.transform_function(func, mod, ctx)
self.__init_handle_by_constructor__(
_ffi_api.CreatePrimFuncPass, _pass_func, pass_info)
self._inst = inst

def __getattr__(self, name):
# fall back to instance attribute if there is not any
return self._inst.__getattribute__(name)

functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
PyFunctionPass.__name__ = pass_cls.__name__
PyFunctionPass.__doc__ = pass_cls.__doc__
PyFunctionPass.__module__ = pass_cls.__module__
return PyFunctionPass


def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None):
"""Decorate a function pass.
This function returns a callback when pass_func
is provided. Otherwise, it returns the created function pass using the
given optimization function.
Parameters
----------
pass_func : Optional[Callable[(PrimFunc, IRModule, PassContext) -> PrimFunc]]
The transformation function or class.
opt_level : int
The optimization level of this module pass.
name : Optional[str]
The name of the function pass. The name could be empty. In this case, the
name of the optimization function will be used as the pass name.
required : Optional[List[str]]
The list of passes that the function pass is dependent on.
Returns
-------
create_function_pass : Union[Callable, FunctionPass]
A decorator will be returned if pass_func is not provided,
otherwise return the decorated result.
The returned decorator has two behaviors depending on the input:
A new FunctionPass will be returned when we decorate a pass function.
A new FunctionPass class will be returned when we decorate a class type.
Examples
--------
The following code block decorates a function pass class.
.. code-block:: python
@tvm.tir.transform.prim_func_pass(opt_level=1)
class TestReplaceFunc:
def __init__(self, new_func):
self.new_func = new_func
def transform_function(self, func, mod, ctx):
# just for demo purposes
# transform func to new_func
return self.new_func
The following code creates a function pass by decorating
a user defined transform function.
.. code-block:: python
@tvm.tir.transform.prim_func_pass(opt_level=2)
def transform(func, mod, ctx):
# my transformations here.
return func
function_pass = transform
assert isinstance(function_pass, transform.FunctionPass)
assert function_pass.info.opt_level == 2
# Given a module m, the optimization could be invoked as the follwoing:
updated_mod = function_pass(m)
# Now constant folding should have been applied to every function in
# the provided module m. And the updated module will be returned.
"""

if opt_level is None:
raise ValueError("Please provide opt_level for the funtion pass.")

required = required if required else []
if not isinstance(required, (list, tuple)):
raise TypeError("Required is expected to be the type of " +
"list/tuple.")

def create_function_pass(pass_arg):
"""Internal function that creates a function pass"""
fname = name if name else pass_arg.__name__
info = PassInfo(opt_level, fname, required)
if inspect.isclass(pass_arg):
return _wrap_class_function_pass(pass_arg, info)
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
raise TypeError("pass_func must be a callable for Module pass")
return _ffi_api.MakeFunctionPass(pass_arg, info)

if pass_func:
return create_function_pass(pass_func)
return create_function_pass
31 changes: 31 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Wrapping existing transformations."""
# pylint: disable=invalid-name

from . import _ffi_api


def CombineContextCall():
"""Combine context calls in the host function.
Returns
-------
fpass : tvm.ir.transform.Pass
The result pass
"""
return _ffi_api.CombineContextCall()
4 changes: 2 additions & 2 deletions src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,13 @@ void IRModuleNode::Add(const GlobalVar& var,
GetRef<relay::Function>(ptr));
}

auto type = checked_func->checked_type();
Type type = checked_func->checked_type();
CHECK(type.as<relay::IncompleteTypeNode>() == nullptr);

if (functions.find(var) != functions.end()) {
CHECK(update)
<< "Already have definition for " << var->name_hint;
auto old_type = functions[var].as<relay::FunctionNode>()->checked_type();
auto old_type = functions[var]->checked_type();
CHECK(relay::AlphaEqual(type, old_type))
<< "Module#update changes type, not possible in this mode.";
}
Expand Down
Loading

0 comments on commit 3a3a30d

Please sign in to comment.