-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[REFACTOR][TIR] Introduce PrimFuncPass. #5139
Merged
Merged
Changes from 1 commit
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one | ||
* or more contributor license agreements. See the NOTICE file | ||
* distributed with this work for additional information | ||
* regarding copyright ownership. The ASF licenses this file | ||
* to you under the Apache License, Version 2.0 (the | ||
* "License"); you may not use this file except in compliance | ||
* with the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, | ||
* software distributed under the License is distributed on an | ||
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
* KIND, either express or implied. See the License for the | ||
* specific language governing permissions and limitations | ||
* under the License. | ||
*/ | ||
|
||
/*! | ||
* \file tvm/tir/transform.h | ||
* \brief TIR specific transformation passes. | ||
*/ | ||
#ifndef TVM_TIR_TRANSFORM_H_ | ||
#define TVM_TIR_TRANSFORM_H_ | ||
|
||
#include <tvm/ir/transform.h> | ||
#include <tvm/tir/expr.h> | ||
#include <tvm/tir/function.h> | ||
|
||
#include <string> | ||
|
||
namespace tvm { | ||
namespace tir { | ||
namespace transform { | ||
|
||
using tvm::transform::Pass; | ||
using tvm::transform::PassNode; | ||
using tvm::transform::PassInfo; | ||
using tvm::transform::PassInfoNode; | ||
using tvm::transform::PassContext; | ||
using tvm::transform::PassContextNode; | ||
using tvm::transform::Sequential; | ||
|
||
/* | ||
* \brief Create a function pass that optimizes PrimFuncs. | ||
* | ||
* \param pass_func The packed function that contains the optimization. | ||
* \param opt_level The optimization level of the function pass. | ||
* \param name The name of the function pass. | ||
* \param required The list of the passes that the function pass is dependent on. | ||
* | ||
* \return The created function pass. | ||
*/ | ||
TVM_DLL Pass CreatePrimFuncPass(const runtime::TypedPackedFunc< | ||
PrimFunc(PrimFunc, IRModule, PassContext)>& pass_func, | ||
int opt_level, | ||
const std::string& name, | ||
const tvm::Array<tvm::PrimExpr>& required); | ||
|
||
/*! | ||
* \brief Create PrimFuncPass to combine context calls in the host function. | ||
* | ||
* \return The pass. | ||
*/ | ||
Pass CombineContextCall(); | ||
|
||
} // namespace transform | ||
} // namespace tir | ||
} // namespace tvm | ||
|
||
#endif // TVM_TIR_TRANSFORM_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,3 +45,4 @@ | |
|
||
from . import ir_builder | ||
from . import ir_pass | ||
from . import transform |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Namespace of all TIR transformations""" | ||
# pylint: disable=wildcard-import, invalid-name | ||
|
||
from .function_pass import prim_func_pass, PrimFuncPass | ||
from .transform import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""FFI APIs for tvm.tir.transform""" | ||
import tvm._ffi | ||
|
||
|
||
tvm._ffi._init_api("tir.transform", __name__) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""TIR specific function pass support.""" | ||
import inspect | ||
import functools | ||
|
||
import tvm._ffi | ||
from tvm.ir.transform import Pass, PassInfo | ||
|
||
from . import _ffi_api | ||
|
||
|
||
@tvm._ffi.register_object("tir.PrimFuncPass") | ||
class PrimFuncPass(Pass): | ||
"""A pass that works on each tvm.relay.Function in a module. A function | ||
pass class should be created through py:func:`tvm.tir.transform.function_pass`. | ||
""" | ||
|
||
|
||
def _wrap_class_function_pass(pass_cls, pass_info): | ||
"""Wrap a python class as function pass""" | ||
class PyFunctionPass(PrimFuncPass): | ||
"""Internal wrapper class to create a class instance.""" | ||
def __init__(self, *args, **kwargs): | ||
# initialize handle in cass pass_cls creation failed.fg | ||
self.handle = None | ||
inst = pass_cls(*args, **kwargs) | ||
# it is important not to capture self to | ||
# avoid a cyclic dependency | ||
def _pass_func(func, mod, ctx): | ||
return inst.transform_function(func, mod, ctx) | ||
self.__init_handle_by_constructor__( | ||
_ffi_api.CreatePrimFuncPass, _pass_func, pass_info) | ||
self._inst = inst | ||
|
||
def __getattr__(self, name): | ||
# fall back to instance attribute if there is not any | ||
return self._inst.__getattribute__(name) | ||
|
||
functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__) | ||
PyFunctionPass.__name__ = pass_cls.__name__ | ||
PyFunctionPass.__doc__ = pass_cls.__doc__ | ||
PyFunctionPass.__module__ = pass_cls.__module__ | ||
return PyFunctionPass | ||
|
||
|
||
def prim_func_pass(pass_func=None, opt_level=None, name=None, required=None): | ||
"""Decorate a function pass. | ||
|
||
This function returns a callback when pass_func | ||
is provided. Otherwise, it returns the created function pass using the | ||
given optimization function. | ||
|
||
Parameters | ||
---------- | ||
pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function->PrimFunc Module->IRModule |
||
The transformation function or class. | ||
|
||
opt_level : int | ||
The optimization level of this module pass. | ||
|
||
name : Optional[str] | ||
The name of the function pass. The name could be empty. In this case, the | ||
name of the optimization function will be used as the pass name. | ||
|
||
required : Optional[List[str]] | ||
The list of passes that the module pass is dependent on. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. function pass |
||
|
||
Returns | ||
------- | ||
create_function_pass : Union[Callable, FunctionPass] | ||
|
||
A decorator will be returned if pass_func is not provided, | ||
otherwise return the decorated result. | ||
The returned decorator has two behaviors depending on the input: | ||
A new FunctionPass will be returned when we decorate a pass function. | ||
A new FunctionPass class will be returned when we decorate a class type. | ||
|
||
Examples | ||
-------- | ||
The following code block decorates a function pass class. | ||
|
||
.. code-block:: python | ||
|
||
@tvm.tir.transform.prim_func_pass(opt_level=1) | ||
class TestReplaceFunc: | ||
def __init__(self, new_func): | ||
self.new_func = new_func | ||
|
||
def transform_function(self, func, mod, ctx): | ||
# just for demo purposes | ||
# transform func to new_func | ||
return self.new_func | ||
|
||
The following code creates a function pass by decorating | ||
a user defined transform function. | ||
|
||
.. code-block:: python | ||
|
||
@tvm.tir.transform.prim_func_pass(opt_level=2) | ||
def transform(func, mod, ctx): | ||
# my transformations here. | ||
return func | ||
|
||
function_pass = transform | ||
assert isinstance(function_pass, transform.FunctionPass) | ||
assert function_pass.info.opt_level == 2 | ||
|
||
# Given a module m, the optimization could be invoked as the follwoing: | ||
updated_mod = function_pass(m) | ||
# Now constant folding should have been applied to every function in | ||
# the provided module m. And the updated module will be returned. | ||
""" | ||
|
||
if opt_level is None: | ||
raise ValueError("Please provide opt_level for the funtion pass.") | ||
|
||
required = required if required else [] | ||
if not isinstance(required, (list, tuple)): | ||
raise TypeError("Required is expected to be the type of " + | ||
"list/tuple.") | ||
|
||
def create_function_pass(pass_arg): | ||
"""Internal function that creates a function pass""" | ||
fname = name if name else pass_arg.__name__ | ||
info = PassInfo(opt_level, fname, required) | ||
if inspect.isclass(pass_arg): | ||
return _wrap_class_function_pass(pass_arg, info) | ||
if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)): | ||
raise TypeError("pass_func must be a callable for Module pass") | ||
return _ffi_api.MakeFunctionPass(pass_arg, info) | ||
|
||
if pass_func: | ||
return create_function_pass(pass_func) | ||
return create_function_pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Licensed to the Apache Software Foundation (ASF) under one | ||
# or more contributor license agreements. See the NOTICE file | ||
# distributed with this work for additional information | ||
# regarding copyright ownership. The ASF licenses this file | ||
# to you under the Apache License, Version 2.0 (the | ||
# "License"); you may not use this file except in compliance | ||
# with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an | ||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
# KIND, either express or implied. See the License for the | ||
# specific language governing permissions and limitations | ||
# under the License. | ||
"""Wrapping existing transformations.""" | ||
# pylint: disable=invalid-name | ||
|
||
from . import _ffi_api | ||
|
||
|
||
def CombineContextCall(): | ||
"""Combine context calls in the host function. | ||
|
||
Returns | ||
------- | ||
fpass : tvm.ir.transform.Pass | ||
The result pass | ||
""" | ||
return _ffi_api.CombineContextCall() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
tvm.tir.PrimFunc