From 0dda395eda42e7c286e06c503f6215e30bd1a2f8 Mon Sep 17 00:00:00 2001 From: PaulPalomeroBernardo Date: Fri, 1 Apr 2022 11:30:54 +0200 Subject: [PATCH] Update 00xx_UMA_Unified_Modular_Accelerator_Interface.md * Add descriptions for all API functions * Clarify backend registration and add target hook explanation * Remove schedules from API and corresponding descriptions --- ...A_Unified_Modular_Accelerator_Interface.md | 196 ++++++++++++++---- 1 file changed, 156 insertions(+), 40 deletions(-) diff --git a/rfcs/00xx_UMA_Unified_Modular_Accelerator_Interface.md b/rfcs/00xx_UMA_Unified_Modular_Accelerator_Interface.md index 6e4e0e5b..bc531a89 100644 --- a/rfcs/00xx_UMA_Unified_Modular_Accelerator_Interface.md +++ b/rfcs/00xx_UMA_Unified_Modular_Accelerator_Interface.md @@ -81,7 +81,6 @@ UMA Partitioner: * Order: pre-partitioning passes, Graph partitioning, post-partitioning passes * API level: * UMA Partitioner creates a wrapper API to TVM core-compiler APIs - * *UMAPartitioner* baseclass (Python only) has to be inherited by accelerator-specific Partitioners (e.g. Accelerator A Partitioner, etc) The figure below described the *UMA Pipeline*. The blocks are described below: @@ -93,15 +92,13 @@ UMA Pipelining: * Input: Partitioned composite functions * Custom primitives can be registered * Lowering from Relay to S-TIR, using TOPI or custom primitives - * Interface for registering accelerator-specific schedules and passes - * Execution of UMA schedules and passes on S-TIR + * Interface for registering accelerator-specific passes + * Execution of UMA passes on S-TIR and NS-TIR * Output: NS-TIR(including tir.extern calls) - * UMALower baseclass (Python only) has to be inherited by accelerator-specific Lower classes (e.g. 
Accelerator A Lower, etc) * UMACodegen * Input: NS-TIR(including tir.extern calls) * Defaults to standard TVM codegen * Intend is to provide a Python interface to insert/emit target code - * UMACodegen baseclass has to be inherited by accelerator-specific Codegen classes (e.g. Accelerator A Codegen, etc) * Output: Target .c files The intention is to use TensorIR with MetaScheduler for optimization and Relax (a possible succesor of Relay [video link](https://www.youtube.com/watch?v=xVbkjJDMexo)) in later versions. @@ -113,16 +110,13 @@ NS-TIR: Non-Schedulable TIR ### Adding a New Custom Accelerator -A custom accelerator is added by inheriting the `UMABackend`. New elements (e.g., passes, schedules) are added using a registration machanism. +A custom accelerator is added by inheriting the `UMABackend`. New elements (e.g., passes) are added using a registration mechanism. Below example shows a backend that makes use of all available registration functions. ```python """UMA backend for the UltraTrail accelerator""" class UltraTrailBackend(UMABackend): def __init__(self): super(UltraTrailBackend, self).__init__() - - # Configuration parameters - self._register_config({"partitioning.enable_MergeCompilerRegions": False}) # Example config parameter # Relay to Relay function registration self._register_pattern("conv1d_relu", conv1d_relu_pattern()) @@ -131,15 +125,13 @@ class UltraTrailBackend(UMABackend): self._register_relay_pass(2, BufferScopeAnnotator()) # Relay to TIR function registration - self._register_operator_strategy("nn.conv1d", custom_conv1d_strategy, plevel=9) - - self._register_tir_schedule(insert_extern_calls) + self._register_operator_strategy("nn.conv1d", custom_conv1d_strategy) self._register_tir_pass(0, CodegenGenerateConfig()) - self._register_tir_pass(0, CodegenGenerateConstants()) + self._register_tir_pass(0, CodegenGenerateExternCalls()) # TIR to runtime function registration - self._register_codegen(format="c", includes=None, constants=None, 
replace_call_extern=None) +    self._register_codegen(fmt="c", includes=gen_includes, replace_call_extern=None) @property def target_name(self): @@ -148,17 +140,9 @@ ## Reference-level explanation - - ### File and class structure and Snippets as example for integration -UMA provides a mostly python-based API. On the C++ side, new targets are registered using target hooks (RFC #0010). A generic `codegen.cc` handles the calls to the python side. -``` -. -├── codegen.cc -└── targets.cc -``` -The python API is structured as shown below. The base class `UMABackend` functions as the core API. It uses the API points described in the [flow description](#flow-description) to register elements for the different stages. +UMA provides a fully Python-based API. The API is structured as shown below. The base class `UMABackend` functions as the core API. It uses the API points described in the [flow description](#flow-description) to register elements for the different stages. ``` . ├── backend.py @@ -172,7 +156,6 @@ The python API is structured as shown below. The base class `UMABackend` functio │ ├── codegen.py │ ├── passes.py │ ├── patterns.py -│ ├── schedules.py │ └── strategies.py └── accelerator_B └── ... @@ -180,14 +163,15 @@ The python API is structured as shown below. The base class `UMABackend` functio ### UMABackend class -Once a `UMABackend` is registered (as shown in Guide-level explanation), it hooks into the usual `relay.build` process to create the code for the target accelerator. +A new custom backend is created by implementing the `UMABackend` as shown in the [Guide-level explanation](#adding-a-new-custom-accelerator). To use the backend, it simply needs to be registered using `backend.register()`. Once a `UMABackend` is registered, it hooks into the usual `relay.build` process to create the code for the target accelerator. 
``` # Load model mod, params = relay.frontend.from_pytorch(scripted_model, [("input_data", input_shape)]) # Register a UMA backend -UltraTrailBackend().register() -mod = UltraTrailBackend().partition(mod) +ut_backend = UltraTrailBackend() +ut_backend.register() +mod = ut_backend.partition(mod) # Relay build (AOT C target) TARGET = tvm.target.Target("c") @@ -200,29 +184,161 @@ with tvm.transform.PassContext( module = relay.build(mod, target=TARGET, runtime=RUNTIME, executor=EXECUTOR, params=params) ``` -#### UMABackend References +### UMA Target Hooks + +UMA uses target hooks (RFC #0010) to perform the lowering of the partitioned, accelerator-specific functions from Relay to TIR and from TIR to runtime. To register the hooks from the Python side, a global function `RegisterTarget` is registered on the C++ side and used during the backend registration. + +```cpp +TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget") + .set_body_typed([](String target_name){ + ::tvm::TargetKindRegEntry::RegisterOrGet(target_name) + .set_name() + .set_device_type(kDLCPU) + .add_attr_option<Array<String>>("keys") + .add_attr_option<String>("tag") + .add_attr_option<String>("device") + .add_attr_option<String>("model") + .add_attr_option<Array<String>>("libs") + .add_attr_option<Target>("host") + .add_attr_option<Integer>("from_device") + .set_attr<FTVMRelayToTIR>("RelayToTIR", relay::contrib::uma::RelayToTIR(target_name)) + .set_attr<FTVMTIRToRuntime>("TIRToRuntime", relay::contrib::uma::TIRToRuntime); + }); +``` + +### UMABackend References + +UMA should function as an easy-to-use API that also helps new users gain orientation in the codebase. To this end, all API functions are typed and documented with examples of their use and required input. 
+ +#### `_register_relay_pass` +```python +_register_relay_pass(self, phase: int, relay_pass: tvm.transform.Pass) -> None +``` + +|Parameter|Description| +|---------|-----------| +|phase|The phase at which the pass is registered.| +|relay_pass|The relay pass to be registered.| + +Example usage: +```python +self._register_relay_pass(0, MyPassA) + +# Where a relay pass can look like this: +@tvm.ir.transform.module_pass(opt_level=0) +class MyPassA: + def transform_module(self, mod, ctx): + # My pass functionality... + return mod +``` -#### _register_config +#### `_register_pattern` ```python -UMABackend._register_config(parameters: dict) +_register_pattern(self, name: str, pattern: tvm.relay.dataflow_pattern.DFPattern) -> None ``` -The following ```parameters``` are allow to be passed via ```_register_config```. -Table of supported parameters: +|Parameter|Description| +|---------|-----------| +|name|The name of the pattern.| +|pattern|The dataflow pattern.| + +Example usage: +```python +self._register_pattern("conv1d", conv1d_pattern) + +# Where a dataflow pattern can look like this: +conv1d_pattern = is_op("nn.conv1d")(wildcard(), wildcard()) +optional_bias = lambda x: is_op("nn.bias_add")(x, wildcard()) +optional_relu = lambda x: is_op("nn.relu")(x) +conv1d_pattern = conv1d_pattern.optional(optional_bias).optional(optional_relu) +``` -|Parameter name|Type|Default|Description| |--------------|-----------|-----------|-----------| |partitioning.enable_MergeCompilerRegions|bool |True |MergeCompilerRegions pass is used for partitioning +#### `_register_operator_strategy` +```python +_register_operator_strategy(self, op: str, strategy: Callable[[tvm.ir.Attrs, tvm.ir.Array, tvm.ir.TensorType, tvm.target.Target], tvm.relay.op.op.OpStrategy], plevel: Optional[int] = 11) -> None +``` -This list is not complete and it is the intend of **THIS RFC** to find a small set of necessary parameters. 
+|Parameter|Description| +|---------|-----------| +|op|The name of the operator for which this strategy will be registered.| +|strategy|The strategy function.| +|plevel|The priority level of the strategy. Higher plevel equals higher prioritization. The TVM default for topi strategies is 10, so by default new UMA strategies are always used.| Example usage: ```python -UMABackend._register_config({"partitioning.enable_MergeCompilerRegions": False}) +self._register_operator_strategy("nn.conv1d", custom_conv1d_strategy) + +# Where a strategy function can look like this: +@relay.op.strategy.override_native_generic_func("custom_conv1d_strategy") +def custom_conv1d_strategy(attrs, inputs, out_type, target): + strategy = _op.OpStrategy() + strategy.add_implementation( + wrap_compute_conv1d(custom_conv1d_compute), + wrap_topi_schedule(custom_conv1d_schedule), + name="custom_conv1d.generic", +    ) + return strategy ``` -#### _register_pattern +#### `_register_tir_pass` ```python -UMABackend.self._register_pattern() +_register_tir_pass(self, phase: int, tir_pass: tvm.tir.transform.PrimFuncPass) -> None ``` -TODO \ No newline at end of file + +|Parameter|Description| +|---------|-----------| +|phase|The phase at which the pass is registered.| +|tir_pass|The TIR pass to be registered.| + +Example usage: +```python +self._register_tir_pass(0, MyPassA) + +# Where a TIR pass can look like this: +@tvm.tir.transform.prim_func_pass(opt_level=0) +class MyPassA: + def transform_function(self, func, mod, ctx): + # My pass functionality... + return func +``` + +#### `_register_codegen` +```python +_register_codegen(self, fmt: str = "c", **kwargs) -> None +``` + +|Parameter|Description| +|---------|-----------| +|fmt|The codegen format. 
For now, only C-codegen is supported by UMA.| +|**kwargs|Keyword arguments for the chosen codegen.| + +Example usage: +```python +self._register_codegen(fmt="c", includes=gen_includes, replace_call_extern=gen_replace_call_extern) + +# The C-codegen provides two hooks which allow the user to insert code through the Python API. +# - `includes` hooks into the include stream and allows insertion of custom includes. +# - `replace_call_extern` hooks into the expression visitor and allows the user to insert custom code for a given extern call. +# +# The code generation functions can look like this: + +def gen_includes() -> str: + includes = "#include <my_custom_api.h>\n" + return includes + +def gen_replace_call_extern(args: tvm.ir.container.Array) -> str: + return "my_custom_api_function({}, {}, {})".format(*args) +``` + +#### Configuration + +The `UMABackend` can be further configured through parameters during initialization. However, the amount of configurability should be kept to a minimum, since UMA tries to streamline and simplify the compilation pipeline. +```python +class UltraTrailBackend(UMABackend): + def __init__(self): + super(UltraTrailBackend, self).__init__(merge_compiler_regions=False) +``` + +Below is a list of currently planned configuration parameters for the initial version of UMA proposed in this RFC. +|Parameter|Type|Description| +|---------|----|-----------| +|merge_compiler_regions|bool (default: True)|Enables/disables the `MergeCompilerRegions` pass during partitioning.|
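All `_register_*_pass` functions in this patch share one phase-based mechanism: passes are collected per phase and later applied stage by stage. The dependency-free sketch below illustrates only that mechanism; `MockBackendBase`, the toy list-based "module", and the assumption that phases execute in ascending order are illustrative stand-ins, not the actual `UMABackend` implementation.

```python
# Hypothetical sketch of phase-based pass registration -- NOT the real
# TVM/UMA code. A backend base class collects callables per phase and
# applies them in ascending phase order (an assumed ordering).
from collections import defaultdict


class MockBackendBase:
    def __init__(self):
        # phase -> list of passes, in registration order within a phase
        self._relay_passes = defaultdict(list)

    def _register_relay_pass(self, phase, relay_pass):
        self._relay_passes[phase].append(relay_pass)

    def run_relay_passes(self, mod):
        # Lower phase numbers run first; each pass maps module -> module.
        for phase in sorted(self._relay_passes):
            for p in self._relay_passes[phase]:
                mod = p(mod)
        return mod


class MyBackend(MockBackendBase):
    def __init__(self):
        super().__init__()
        # Registration order deliberately differs from phase order.
        self._register_relay_pass(1, lambda mod: mod + ["ConfigGenerator"])
        self._register_relay_pass(2, lambda mod: mod + ["BufferScopeAnnotator"])
        self._register_relay_pass(0, lambda mod: mod + ["LegalizeA"])


mod = MyBackend().run_relay_passes([])
# -> ["LegalizeA", "ConfigGenerator", "BufferScopeAnnotator"]
```

The toy "module" is just a list of applied pass names, which makes the phase ordering directly observable.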