From 6cdfbab7f5976d2a348f7a4b3c3fded44954c931 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Wed, 25 Sep 2024 23:32:08 +0900 Subject: [PATCH 1/2] introduce ExportedProgramImporter --- python/tvm/relax/frontend/torch/__init__.py | 1 + .../torch/base_fx_graph_translator.py | 228 ++++++++ .../torch/exportedprogram_translator.py | 243 ++++++++ .../tvm/relax/frontend/torch/fx_translator.py | 209 +------ .../test_frontend_from_exportedprogram.py | 535 ++++++++++++++++++ 5 files changed, 1029 insertions(+), 187 deletions(-) create mode 100644 python/tvm/relax/frontend/torch/base_fx_graph_translator.py create mode 100644 python/tvm/relax/frontend/torch/exportedprogram_translator.py create mode 100644 tests/python/relax/test_frontend_from_exportedprogram.py diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py index 55da5a456d6a..26d9bfe156df 100644 --- a/python/tvm/relax/frontend/torch/__init__.py +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -17,5 +17,6 @@ """ PyTorch Frontends for constructing Relax programs, with the model importers """ +from .exportedprogram_translator import from_exportedprogram from .fx_translator import from_fx from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py new file mode 100644 index 000000000000..6a001b5a047c --- /dev/null +++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""Base class for PyTorch FX Graph importer.""" +import abc +from typing import Callable, Dict, Optional, Tuple, Union + +from tvm import relax + + +class BaseFXGraphImporter(metaclass=abc.ABCMeta): + """Base class for FX Graph Importer.""" + + import torch # type: ignore + from torch import fx + + def __init__(self) -> None: + import torch # type: ignore + from torch import fx + + self.env: Dict[fx.Node, relax.Expr] = {} + self.params: Dict[torch.Tensor, relax.Expr] = {} + self.block_builder: relax.BlockBuilder = None + self.convert_map: Dict[ + Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var] + ] = self.create_convert_map() + + ########## Utilities ########## + + @staticmethod + def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): + """converts the PyTorch scalar type input_type to a TVM dtype.""" + import torch # type: ignore + + if env is not None and input_type in env: + input_type = env[input_type] + + input_type = input_type.lower() if isinstance(input_type, str) else input_type + if input_type in ["float", "float32", "torch.float32", torch.float32]: + return "float32" + elif input_type in ["float16", "torch.float16", torch.float16]: + return "float16" + elif input_type in ["int64", "torch.int64", torch.int64]: + return "int64" + elif input_type in ["int32", "torch.int32", torch.int32]: + return "int32" + elif input_type in ["bool", "torch.bool", torch.bool]: + return "bool" + else: + raise NotImplementedError("input_type {} is not handled yet".format(input_type)) + + @staticmethod + def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: + tensor = tensor.detach().cpu() + dtype = BaseFXGraphImporter._convert_data_type(str(tensor.data.dtype)) + return relax.const(tensor.data.numpy(), dtype) + + @staticmethod + def shape_of(tensor): + """Get the shape of a tensor.""" + import torch # type: ignore + + if isinstance(tensor, relax.Expr): + if not isinstance(tensor.struct_info, relax.TensorStructInfo): + raise TypeError("The input Expr of shape_of should be a Tensor") + return tensor.struct_info.shape + elif isinstance(tensor, torch.Tensor): + return tensor.shape + raise ValueError("Unsupported type: {}".format(type(tensor))) + + def retrieve_args(self, node: fx.Node): + return self._retrieve_args(node.args) + + def _retrieve_args(self, node): + from torch import fx + + if isinstance(node, fx.Node): + return self.env[node] + elif isinstance(node, tuple): + return tuple(self._retrieve_args(x) for x in node) + elif isinstance(node, list): + return [self._retrieve_args(x) for x in node] + elif isinstance(node, dict): + return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} + else: + return node + + ########## Unary Ops ########## + + def _unary_op(self, op: Callable) -> Callable: + from torch import fx + + def convert(node: fx.Node) -> relax.Var: + return self.block_builder.emit(op(self.env[node.args[0]])) + + return convert + + ########## Neural Network ########## + + def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: + x = self.env[node.args[0]] + output_size = node.args[1] + return self.block_builder.emit( + relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") + ) + + def _conv2d_impl( + self, + x: relax.Expr, + weight: relax.Expr, + bias: Optional[relax.Expr], + strides: Optional[Tuple], + padding: Optional[Tuple], + dilation: Optional[Tuple], + groups: Optional[Tuple], + ): + conv2d = self.block_builder.emit( + relax.op.nn.conv2d( + x, + weight, + strides=strides, + padding=padding, + dilation=dilation, + groups=groups, + data_layout="NCHW", + kernel_layout="OIHW", + out_dtype="float32", + ) + ) + + if bias is None: + return conv2d + assert len(self.shape_of(bias)) == 1 + bias = relax.op.reshape(bias, (1, -1, 1, 1)) + return self.block_builder.emit(relax.op.add(conv2d, bias)) + + def _conv2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + stride = args[3] if len(args) > 3 else 1 + padding = args[4] if len(args) > 4 else 0 + dilation = args[5] if len(args) > 5 else 1 + groups = args[6] if len(args) > 6 else 1 + return self._conv2d_impl( + x, + weight, + bias=bias, + strides=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + + def _linear(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + weight = args[1] + bias = args[2] if len(args) > 2 else None + return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) + + def _max_pool2d_impl( + self, + x: relax.Expr, + kernel_size: Union[int, Tuple[int, int]] = (1, 1), + stride: Optional[Union[int, Tuple[int, int]]] = None, + padding: Optional[int] = 0, + dilation: Optional[int] = 1, + ceil_mode: Optional[bool] = False, + ) -> relax.Var: + stride = kernel_size if stride is None else stride + return self.block_builder.emit( + relax.op.nn.max_pool2d( + x, + pool_size=kernel_size, + strides=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + layout="NCHW", + ) + ) + + def _max_pool2d(self, node: fx.Node) -> relax.Var: + args = self.retrieve_args(node) + x = args[0] + kernel_size = args[1] + stride = args[2] if len(args) > 2 else None + padding = args[3] if len(args) > 3 else 0 + dilation = args[4] if len(args) > 4 else 1 + ceil_mode = args[5] if len(args) > 5 else False + + return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) + + ########## Manipulation ########## + + def _reshape(self, node: fx.Node) -> relax.Var: + import torch # type: ignore + + args = self.retrieve_args(node) + x = args[0] + dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] + return self.block_builder.emit(relax.op.reshape(x, dims)) + + ########## Others ########## + + @abc.abstractmethod + def create_convert_map( + self, + ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]: + """Create convert map""" diff --git a/python/tvm/relax/frontend/torch/exportedprogram_translator.py b/python/tvm/relax/frontend/torch/exportedprogram_translator.py new file mode 100644 index 000000000000..45dd742258b2 --- /dev/null +++ b/python/tvm/relax/frontend/torch/exportedprogram_translator.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=invalid-name, inconsistent-return-statements, unidiomatic-typecheck +# pylint: disable=import-outside-toplevel +"""PyTorch ExportedProgram of Relax.""" +from collections import ChainMap, OrderedDict +from typing import Callable, Dict, List, Tuple + +import torch +import tvm +from tvm import relax + +from .base_fx_graph_translator import BaseFXGraphImporter + + +class ExportedProgramImporter(BaseFXGraphImporter): + """An importer from ExportedProgram to Relax.""" + + from torch import fx + + def create_input_vars( + self, exported_program: torch.export.ExportedProgram + ) -> Tuple[List[relax.Var], List[relax.Var]]: + """Create relax input vars.""" + parameters_buffers_constants = [] + user_inputs = [] + for spec in exported_program.graph_signature.input_specs: + name_hint = spec.arg.name + if spec.kind is torch.export.graph_signature.InputKind.CONSTANT_TENSOR: + shape = exported_program.tensor_constants[spec.target].shape + torch_dtype = exported_program.tensor_constants[spec.target].dtype + elif spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + for node in exported_program.graph.find_nodes(op="placeholder", target=spec.target): + if node.name == name_hint: + shape = node.meta["tensor_meta"].shape + torch_dtype = node.meta["tensor_meta"].dtype + break + else: + # PARAMETER or BUFFER + shape = exported_program.state_dict[spec.target].shape + torch_dtype = exported_program.state_dict[spec.target].dtype + + dtype = self._convert_data_type(torch_dtype) + relax_var = relax.Var(name_hint, relax.TensorStructInfo(shape, dtype)) + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + user_inputs.append(relax_var) + else: + parameters_buffers_constants.append(relax_var) + + return parameters_buffers_constants, user_inputs + + def create_convert_map( + self, + ) -> Dict[str, Callable[[fx.Node], relax.Var]]: + return { + # unary + "dropout.default": lambda node: self.env[node.args[0]], + "relu.default": self._unary_op(relax.op.nn.relu), + # neural network + "adaptive_avg_pool2d.default": self._adaptive_avg_pool2d, + "conv2d.default": self._conv2d, + "linear.default": self._linear, + "max_pool2d.default": self._max_pool2d, + # tensor manipulation + "view.default": self._reshape, + } + + def from_exported_program( + self, + exported_program: torch.export.ExportedProgram, + keep_params_as_input: bool, + unwrap_unit_return_tuple: bool, + no_bind_return_tuple: bool, + ) -> tvm.IRModule: + """Convert a PyTorch ExportedProgram to a Relax program.""" + from torch import fx # type: ignore + + # Create input variables. + parameter_buffer_constant_vars, user_input_vars = self.create_input_vars(exported_program) + inputs_vars = parameter_buffer_constant_vars + user_input_vars + + # Initialize the block builder with a function and a dataflow block. + self.block_builder = relax.BlockBuilder() + func_name = "main" + func_attrs = {"num_input": len(user_input_vars)} if keep_params_as_input else None + + nodes: List[fx.Node] = exported_program.graph.nodes + with self.block_builder.function( + name=func_name, params=inputs_vars.copy(), attrs=func_attrs + ): + output = None + with self.block_builder.dataflow(): + # Translate the model. + for node in nodes: + if node.op == "placeholder": + if "grapharg" in node.meta and node.meta["grapharg"].fake_tensor is None: + # Ignore sym input + continue + + self.env[node] = inputs_vars.pop(0) + elif node.op == "output": + args = self.retrieve_args(node) + assert len(args) == 1 + assert isinstance(args[0], (tuple, relax.Tuple)) + + if unwrap_unit_return_tuple and len(args[0]) == 1: + output = self.block_builder.emit_output(args[0][0]) + elif no_bind_return_tuple: + output = [] + for ret in args[0]: + output.append(self.block_builder.emit_output(ret)) + else: + output = self.block_builder.emit_output(args[0]) + break + elif node.op == "get_attr": + self.env[node] = getattr(exported_program.graph_module, node.target) + elif node.op == "call_function": + func_name = node.target.__name__ + assert ( + func_name in self.convert_map + ), f"Unsupported function type {func_name}" + self.env[node] = self.convert_map[func_name](node) + else: + raise ValueError(f"Unsupported op {node.op}") + assert output is not None + self.block_builder.emit_func_output(output) + + to_bind_parameters = ChainMap( + OrderedDict(exported_program.named_buffers()), exported_program.constants + ) + if not keep_params_as_input: + to_bind_parameters = to_bind_parameters.new_child( + OrderedDict(exported_program.named_parameters()) + ) + + binding = {} + for tensor_name, tensor_value in to_bind_parameters.items(): + # find relax var name from graph signature + for spec in exported_program.graph_signature.input_specs: + if tensor_name == spec.target: + bind_name = spec.arg.name + break + binding[bind_name] = tvm.nd.from_dlpack(tensor_value.detach()) + + mod = self.block_builder.get() + mod = relax.transform.BindParams("main", binding)(mod) + + if keep_params_as_input: + parameters = dict(exported_program.named_parameters()) + params = [tvm.nd.from_dlpack(p.detach()) for p in parameters.values()] + mod["main"] = mod["main"].with_attr("params", params) + + return mod + + +def from_exportedprogram( + exported_program: torch.export.ExportedProgram, + *, + keep_params_as_input: bool = False, + unwrap_unit_return_tuple: bool = False, + no_bind_return_tuple: bool = False, +) -> tvm.IRModule: + """Convert a PyTorch ExportedProgram to a Relax program + + Parameters + ---------- + exported_program : torch.export.ExportedProgram + The PyTorch ExportedProgram to convert. + + keep_params_as_input : bool + Whether to keep model parameters as input variables. + + unwrap_unit_return_tuple : bool + A boolean flag indicating if to the return value when it is an unit tuple. + When the return value is not a unit tuple, no unwrap will take place. + + no_bind_return_tuple : bool + A boolean flag indicating whether to bind the return tuple as a relax var. + If the flag is true and the return value is a tuple, it will not bind it to a var. + + Returns + ------- + output : tvm.IRModule + The import result IRModule, with the function "main" containing the + translated logic. + + Examples + -------- + Users can use the torch.export.export() to extract a torch.export.ExportedProgram + from a PyTorch model. The following codes show how to convert a PyTorch model to + a Relax program. + + .. code-block:: python + + # Import the importer. + import tvm + from tvm.relax.frontend.torch import from_exportedprogram + import torch + from torch.export import export + + # Define the module + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(in_features=10, out_features=7, bias=True) + + def forward(self, input): + return self.linear(input) + + # Instantiate the model and create the input info dict. + torch_model = MyModule() + + # Use torch.export.export() to convert the PyTorch model into ExportedProgram. + example_args = (torch.rand(128, 10, dtype=torch.float32),) + exported_program = export(torch_model, args=example_args) + + # Use the importer to import the ExportedProgram to Relax. + mod: tvm.IRModule = from_exportedprogram(exported_program) + """ + # decompose into Core ATen operators + exported_program.run_decompositions() + + return ExportedProgramImporter().from_exported_program( + exported_program, + keep_params_as_input, + unwrap_unit_return_tuple, + no_bind_return_tuple, + ) diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py index 27da69dbb182..ec53cf23edc5 100644 --- a/python/tvm/relax/frontend/torch/fx_translator.py +++ b/python/tvm/relax/frontend/torch/fx_translator.py @@ -24,8 +24,10 @@ import tvm from tvm import relax +from .base_fx_graph_translator import BaseFXGraphImporter -class TorchFXImporter: + +class TorchFXImporter(BaseFXGraphImporter): """An importer from PyTorch FX to Relax.""" import torch # type: ignore @@ -33,15 +35,12 @@ class TorchFXImporter: def __init__(self) -> None: import torch # type: ignore - from torch import fx - self.env: Dict[fx.Node, relax.Expr] = {} - self.params: Dict[torch.Tensor, relax.Expr] = {} + super().__init__() self.named_modules: Dict[str, torch.Module] = None - self.block_builder: relax.BlockBuilder = None - self.create_convert_map() ########## Utilities ########## + def _fetch_attr(self, model, target: str): import torch # type: ignore @@ -58,77 +57,11 @@ def _fetch_attr(self, model, target: str): # If so, return the parameter instead. if attr_itr in self.params: return self.params[attr_itr] - return TorchFXImporter._convert_torch_tensor_to_relax(attr_itr) + return self._convert_torch_tensor_to_relax(attr_itr) return attr_itr - @staticmethod - def _convert_data_type(input_type: Union[str, torch.dtype], env: Optional[Dict] = None): - """converts the PyTorch scalar type input_type to a TVM dtype.""" - import torch # type: ignore - - if env is not None and input_type in env: - input_type = env[input_type] - - input_type = input_type.lower() if isinstance(input_type, str) else input_type - if input_type in ["float", "float32", "torch.float32", torch.float32]: - return "float32" - elif input_type in ["float16", "torch.float16", torch.float16]: - return "float16" - elif input_type in ["int64", "torch.int64", torch.int64]: - return "int64" - elif input_type in ["int32", "torch.int32", torch.int32]: - return "int32" - elif input_type in ["bool", "torch.bool", torch.bool]: - return "bool" - else: - raise NotImplementedError("input_type {} is not handled yet".format(input_type)) - - @staticmethod - def _convert_torch_tensor_to_relax(tensor: torch.Tensor) -> relax.Var: - tensor = tensor.detach().cpu() - dtype = TorchFXImporter._convert_data_type(str(tensor.data.dtype)) - return relax.const(tensor.data.numpy(), dtype) - - @staticmethod - def shape_of(tensor): - """Get the shape of a tensor.""" - import torch # type: ignore - - if isinstance(tensor, relax.Expr): - if not isinstance(tensor.struct_info, relax.TensorStructInfo): - raise TypeError("The input Expr of shape_of should be a Tensor") - return tensor.struct_info.shape - elif isinstance(tensor, torch.Tensor): - return tensor.shape - raise ValueError("Unsupported type: {}".format(type(tensor))) - - def retrieve_args(self, node): - return self._retrieve_args(node.args) - - def _retrieve_args(self, node): - from torch import fx - - if isinstance(node, fx.Node): - return self.env[node] - elif isinstance(node, tuple): - return tuple(self._retrieve_args(x) for x in node) - elif isinstance(node, list): - return [self._retrieve_args(x) for x in node] - elif isinstance(node, dict): - return {self._retrieve_args(k): self._retrieve_args(v) for k, v in node.items()} - else: - return node - ########## Unary Ops ########## - def _unary_op(self, op: Callable) -> Callable: - from torch import fx - - def convert(node: fx.Node) -> relax.Var: - return self.block_builder.emit(op(self.env[node.args[0]])) - - return convert - def _clamp(self, node: fx.Node) -> relax.Expr: args = self.retrieve_args(node) a_min = args[1] if len(args) > 1 else node.kwargs["min"] @@ -272,13 +205,6 @@ def call_binary_op(op, lhs, rhs): ########## Neural Network ########## - def _adaptive_avg_pool2d(self, node: fx.Node) -> relax.Var: - x = self.env[node.args[0]] - output_size = node.args[1] - return self.block_builder.emit( - relax.op.nn.adaptive_avg_pool2d(x, output_size, layout="NCHW") - ) - def _adaptive_avg_pool2d_module(self, node: fx.Node) -> relax.Var: module = self.named_modules[node.target] @@ -590,55 +516,6 @@ def _conv1d_module(self, node: fx.Node) -> relax.Var: groups=module.groups, ) - def _conv2d_impl( - self, - x: relax.Expr, - weight: relax.Expr, - bias: Optional[relax.Expr], - strides: Optional[Tuple], - padding: Optional[Tuple], - dilation: Optional[Tuple], - groups: Optional[Tuple], - ): - conv2d = self.block_builder.emit( - relax.op.nn.conv2d( - x, - weight, - strides=strides, - padding=padding, - dilation=dilation, - groups=groups, - data_layout="NCHW", - kernel_layout="OIHW", - out_dtype="float32", - ) - ) - - if bias is None: - return conv2d - assert len(self.shape_of(bias)) == 1 - bias = relax.op.reshape(bias, (1, -1, 1, 1)) - return self.block_builder.emit(relax.op.add(conv2d, bias)) - - def _conv2d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - stride = args[3] if len(args) > 3 else 1 - padding = args[4] if len(args) > 4 else 0 - dilation = args[5] if len(args) > 5 else 1 - groups = args[6] if len(args) > 6 else 1 - return self._conv2d_impl( - x, - weight, - bias=bias, - strides=stride, - padding=padding, - dilation=dilation, - groups=groups, - ) - def _conv2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -940,13 +817,6 @@ def _layer_norm_module(self, node: fx.Node) -> relax.Var: eps = module.eps return self._layer_norm_impl(x, gamma, beta, eps, normalized_shape) - def _linear(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - weight = args[1] - bias = args[2] if len(args) > 2 else None - return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _linear_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -954,39 +824,6 @@ def _linear_module(self, node: fx.Node) -> relax.Var: bias = self.params.get(module.bias, None) return self.block_builder.emit(relax.op.linear(x, weight, bias, "float32")) - def _max_pool2d_impl( - self, - x: relax.Expr, - kernel_size: Union[int, Tuple[int, int]] = (1, 1), - stride: Optional[Union[int, Tuple[int, int]]] = None, - padding: Optional[int] = 0, - dilation: Optional[int] = 1, - ceil_mode: Optional[bool] = False, - ) -> relax.Var: - stride = kernel_size if stride is None else stride - return self.block_builder.emit( - relax.op.nn.max_pool2d( - x, - pool_size=kernel_size, - strides=stride, - padding=padding, - dilation=dilation, - ceil_mode=ceil_mode, - layout="NCHW", - ) - ) - - def _max_pool2d(self, node: fx.Node) -> relax.Var: - args = self.retrieve_args(node) - x = args[0] - kernel_size = args[1] - stride = args[2] if len(args) > 2 else None - padding = args[3] if len(args) > 3 else 0 - dilation = args[4] if len(args) > 4 else 1 - ceil_mode = args[5] if len(args) > 5 else False - - return self._max_pool2d_impl(x, kernel_size, stride, padding, dilation, ceil_mode) - def _max_pool2d_module(self, node: fx.Node) -> relax.Var: x = self.env[node.args[0]] module = self.named_modules[node.target] @@ -1138,14 +975,6 @@ def _repeat(self, node: fx.Node) -> relax.Var: dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] return self.block_builder.emit(relax.op.tile(x, dims)) - def _reshape(self, node: fx.Node) -> relax.Var: - import torch # type: ignore - - args = self.retrieve_args(node) - x = args[0] - dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:] - return self.block_builder.emit(relax.op.reshape(x, dims)) - def _size(self, node: fx.Node) -> relax.Expr: x = self.env[node.args[0]] shape = self.shape_of(x) @@ -1448,12 +1277,23 @@ def _sym_size_int(self, node: fx.Node) -> relax.Expr: idx = node.args[1] return self.block_builder.emit(relax.const(shape[idx].value, "int32")) - def create_convert_map(self): + def create_input_vars(self, input_info: List[Tuple[Tuple[int], str]]) -> List[relax.Var]: + inputs = list() + for idx, (shape, dtype) in enumerate(input_info): + inputs.append( + relax.Var( + f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) + ) + ) + return inputs + + def create_convert_map( + self, + ) -> Dict[Union[torch.nn.Module, str], Callable[[fx.Node], relax.Var]]: import operator from torch import nn - from torch import fx - self.convert_map: Dict[Union[nn.Module, str], Callable[[fx.Node], relax.Var]] = { + return { ## call_module # unary nn.Dropout: lambda node: self.env[node.args[0]], @@ -1638,14 +1478,9 @@ def from_fx( self.named_modules = dict(model.named_modules()) graph: fx.Graph = model.graph + # Create input variables. - inputs = list() - for idx, (shape, dtype) in enumerate(input_info): - inputs.append( - relax.Var( - f"inp_{idx}", relax.TensorStructInfo(shape, self._convert_data_type(dtype)) - ) - ) + inputs = self.create_input_vars(input_info) # Initialize the block builder with a function and a dataflow block. func_name = "main" diff --git a/tests/python/relax/test_frontend_from_exportedprogram.py b/tests/python/relax/test_frontend_from_exportedprogram.py new file mode 100644 index 000000000000..31feb25fa456 --- /dev/null +++ b/tests/python/relax/test_frontend_from_exportedprogram.py @@ -0,0 +1,535 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import torch +from torch.nn import Module +from torch.export import export + +import tvm +from tvm import relax +import tvm.testing +from tvm.script import ir as I +from tvm.script import relax as R +from tvm.script import tir as T +from tvm.relax.frontend.torch import from_exportedprogram + + +def verify_model(torch_model, example_args, binding, expected): + exported_program = export(torch_model, args=example_args) + mod = from_exportedprogram(exported_program) + + binding = {k: tvm.nd.array(v) for k, v in binding.items()} + expected = relax.transform.BindParams("main", binding)(expected) + tvm.ir.assert_structural_equal(mod, expected) + + +def test_unary(): + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + # dropout + class Dropout1(Module): + def __init__(self): + super().__init__() + self.dropout = torch.nn.Dropout(0.5) + + def forward(self, input): + return self.dropout(input) + + class Dropout2(Module): + def forward(self, input): + return torch.dropout(input, 0.5, train=True) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (input_1,) + R.output(gv) + return gv + + verify_model(Dropout1(), example_args, {}, expected1) + verify_model(Dropout2(), example_args, {}, expected1) + + # relu + class ReLU0(Module): + def __init__(self): + super().__init__() + self.relu = torch.nn.ReLU() + + def forward(self, input): + return self.relu(input) + + class ReLU1(Module): + def forward(self, input): + return torch.nn.functional.relu(input) + + @tvm.script.ir_module + class expected: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.relu(input_1) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + verify_model(ReLU0(), example_args, {}, expected) + verify_model(ReLU1(), example_args, {}, expected) + + +def test_adaptive_avgpool2d(): + class AdaptiveAvgPool2d0(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.AdaptiveAvgPool2d([10, 10]) + + def forward(self, input): + return self.pool(input) + + class AdaptiveAvgPool2d1(Module): + def forward(self, input): + return torch.nn.functional.adaptive_avg_pool2d(input, [10, 10]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.adaptive_avg_pool2d( + input_1, output_size=[10, 10], layout="NCHW", out_layout="NCHW" + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(AdaptiveAvgPool2d0(), example_args, {}, expected1) + verify_model(AdaptiveAvgPool2d1(), example_args, {}, expected1) + + +def test_conv2d(): + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + class Conv2D1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[6, 3, 7, 7]) + self.bias = torch.randn(size=[6]) + + def forward(self, input): + return torch.nn.functional.conv2d(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + w2: R.Tensor((6,), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1)) = R.reshape(w2, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + class Conv2D2(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=False) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + w1: R.Tensor((6, 3, 7, 7), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + w1, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = Conv2D1() + binding = {"w1": model.conv.weight.detach().numpy(), "w2": model.conv.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv2D1Func() + binding = {"w1": model.weight.numpy(), "w2": model.bias.numpy()} + verify_model(model, example_args, binding, expected1) + + model = Conv2D2() + binding = {"w1": model.conv.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_linear(): + class Dense1(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=True) + + def forward(self, input): + return self.linear(input) + + class Dense1Func(Module): + def __init__(self): + super().__init__() + self.weight = torch.randn(size=[7, 10]) + self.bias = torch.randn(size=[7]) + + def forward(self, input): + return torch.nn.functional.linear(input, self.weight, self.bias) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + w1: R.Tensor((7, 10), dtype="float32"), + w2: R.Tensor((7,), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + lv2: R.Tensor((1, 3, 10, 7), dtype="float32") = R.add(lv1, w2) + gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv2,) + R.output(gv) + return gv + + class Dense2(Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(10, 7, bias=False) + + def forward(self, input): + return self.linear(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + w1: R.Tensor((7, 10), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((10, 7), dtype="float32") = R.permute_dims(w1, axes=None) + lv1: R.Tensor((1, 3, 10, 7), dtype="float32") = R.matmul( + input_1, lv, out_dtype="float32" + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 7), dtype="float32")) = (lv1,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + + model = Dense1() + binding = {"w1": model.linear.weight.detach().numpy(), "w2": model.linear.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Dense1Func() + binding = {"w1": model.weight.detach().numpy(), "w2": model.bias.detach().numpy()} + verify_model(model, example_args, binding, expected1) + + model = Dense2() + binding = {"w1": model.linear.weight.detach().numpy()} + verify_model(model, example_args, binding, expected2) + + +def test_maxpool2d(): + class MaxPool2d(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[1, 1]) + + def forward(self, input): + return self.pool(input) + + class MaxPool2d_functional(Module): + def __init__(self): + super().__init__() + + def forward(self, input): + return torch.nn.functional.max_pool2d(input, kernel_size=[1, 1]) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[1, 1], + strides=[1, 1], + dilation=[1, 1], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class MaxPool2d2(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[2, 2], dilation=[2, 3]) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected2: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 4, 4), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[2, 2], + strides=[2, 2], + dilation=[2, 3], + padding=[0, 0, 0, 0], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 4, 4), dtype="float32")) = (lv,) + R.output(gv) + return gv + + class MaxPool2d3(Module): + def __init__(self): + super().__init__() + self.pool = torch.nn.MaxPool2d(kernel_size=[4, 4], padding=2, stride=2) + + def forward(self, input): + return self.pool(input) + + @tvm.script.ir_module + class expected3: + @R.function + def main( + input_1: R.Tensor((1, 3, 10, 10), dtype="float32") + ) -> R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((1, 3, 6, 6), dtype="float32") = R.nn.max_pool2d( + input_1, + pool_size=[4, 4], + strides=[2, 2], + dilation=[1, 1], + padding=[2, 2, 2, 2], + layout="NCHW", + out_layout="NCHW", + ) + gv: R.Tuple(R.Tensor((1, 3, 6, 6), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + verify_model(MaxPool2d(), example_args, {}, expected1) + verify_model(MaxPool2d_functional(), example_args, {}, expected1) + verify_model(MaxPool2d2(), example_args, {}, expected2) + verify_model(MaxPool2d3(), example_args, {}, expected3) + + +def test_view(): + class View(Module): + def forward(self, x): + return x.view(2, 12) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + x: R.Tensor((1, 2, 3, 4), dtype="float32") + ) -> R.Tuple(R.Tensor((2, 12), dtype="float32")): + # block 0 + with R.dataflow(): + lv: R.Tensor((2, 12), dtype="float32") = R.reshape(x, (2, 12)) + gv: R.Tuple(R.Tensor((2, 12), dtype="float32")) = (lv,) + R.output(gv) + return gv + + example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),) + verify_model(View(), example_args, {}, expected1) + + +def test_keep_params(): + class Conv2D1(Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 6, 7, bias=True) + + def forward(self, input): + return self.conv(input) + + @tvm.script.ir_module + class expected1: + @R.function + def main( + conv_weight: R.Tensor((6, 3, 7, 7), dtype="float32"), + conv_bias: R.Tensor((6,), dtype="float32"), + input_1: R.Tensor((1, 3, 10, 10), dtype="float32"), + ) -> R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")): + R.func_attr({"num_input": 1}) + # block 0 + with R.dataflow(): + lv1: R.Tensor((1, 6, 4, 4), dtype="float32") = R.nn.conv2d( + input_1, + conv_weight, + strides=[1, 1], + padding=[0, 0, 0, 0], + dilation=[1, 1], + data_layout="NCHW", + kernel_layout="OIHW", + out_layout="NCHW", + out_dtype="float32", + ) + lv2: R.Tensor((1, 6, 1, 1), dtype="float32") = R.reshape(conv_bias, [1, 6, 1, 1]) + lv3: R.Tensor((1, 6, 4, 4), dtype="float32") = R.add(lv1, lv2) + gv: R.Tuple(R.Tensor((1, 6, 4, 4), dtype="float32")) = (lv3,) + R.output(gv) + return gv + + from tvm.relax.frontend import detach_params + + example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) + model = Conv2D1() + exported_program = torch.export.export(model, example_args) + mod = from_exportedprogram(exported_program, keep_params_as_input=True) + mod, params = detach_params(mod) + tvm.ir.assert_structural_equal(mod, expected1) + func = mod["main"] + params = params["main"] + + assert len(params) == len(func.params) - 1 + for param_var, param_ndarray in zip(func.params[:-1], params): + assert tuple(x.value for x in param_var.struct_info.shape.values) == param_ndarray.shape + assert param_var.struct_info.dtype == param_ndarray.dtype + + tvm.testing.assert_allclose(params[0].numpy(), model.conv.weight.detach().detach().numpy()) + tvm.testing.assert_allclose(params[1].numpy(), model.conv.bias.detach().detach().numpy()) + + +def test_unwrap_unit_return_tuple(): + class Identity(Module): + def __init__(self): + super().__init__() + + def forward(self, x): + return (x,) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32") + ) -> R.Tensor((256, 256), dtype="float32"): + with R.dataflow(): + gv: R.Tensor((256, 256), dtype="float32") = inp_0 + R.output(gv) + return gv + + example_args = (torch.randn(256, 256, dtype=torch.float32),) + exported_program = export(Identity(), args=example_args) + mod = from_exportedprogram(exported_program, unwrap_unit_return_tuple=True) + tvm.ir.assert_structural_equal(mod, Expected) + + +def test_no_bind_return_tuple(): + class Identity(Module): + def __init__(self): + super().__init__() + + def forward(self, x, y): + return (x, y) + + @tvm.script.ir_module + class Expected: + @R.function + def main( + inp_0: R.Tensor((256, 256), dtype="float32"), + inp_1: R.Tensor((256, 256), dtype="float32"), + ) -> R.Tuple(R.Tensor((256, 256), dtype="float32"), R.Tensor((256, 256), dtype="float32")): + with R.dataflow(): + gv: R.Tensor((256, 256), dtype="float32") = inp_0 + gv1: R.Tensor((256, 256), dtype="float32") = inp_1 + R.output(gv, gv1) + return (gv, gv1) + + example_args = ( + torch.randn(256, 256, dtype=torch.float32), + torch.randn(256, 256, dtype=torch.float32), + ) + exported_program = export(Identity(), args=example_args) + mod = from_exportedprogram(exported_program, no_bind_return_tuple=True) + tvm.ir.assert_structural_equal(mod, Expected) From 893e5b37243d0fd1184025c6aff88c471cb25381 Mon Sep 17 00:00:00 2001 From: Masahiro Hiramori Date: Thu, 26 Sep 2024 22:37:43 +0900 Subject: [PATCH 2/2] address review comments --- python/tvm/relax/frontend/torch/__init__.py | 2 +- ...am_translator.py => exported_program_translator.py} | 6 +++--- ...ogram.py => test_frontend_from_exported_program.py} | 10 +++++----- 3 files changed, 9 insertions(+), 9 deletions(-) rename python/tvm/relax/frontend/torch/{exportedprogram_translator.py => exported_program_translator.py} (98%) rename tests/python/relax/{test_frontend_from_exportedprogram.py => test_frontend_from_exported_program.py} (98%) diff --git a/python/tvm/relax/frontend/torch/__init__.py b/python/tvm/relax/frontend/torch/__init__.py index 26d9bfe156df..36eac975dfc7 100644 --- a/python/tvm/relax/frontend/torch/__init__.py +++ b/python/tvm/relax/frontend/torch/__init__.py @@ -17,6 +17,6 @@ """ PyTorch Frontends for constructing Relax programs, with the model importers """ -from .exportedprogram_translator import from_exportedprogram +from .exported_program_translator import from_exported_program from .fx_translator import from_fx from .dynamo import relax_dynamo, dynamo_capture_subgraphs diff --git a/python/tvm/relax/frontend/torch/exportedprogram_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py similarity index 98% rename from python/tvm/relax/frontend/torch/exportedprogram_translator.py rename to python/tvm/relax/frontend/torch/exported_program_translator.py index 45dd742258b2..9af422d1c3ca 100644 --- a/python/tvm/relax/frontend/torch/exportedprogram_translator.py +++ b/python/tvm/relax/frontend/torch/exported_program_translator.py @@ -168,7 +168,7 @@ def from_exported_program( return mod -def from_exportedprogram( +def from_exported_program( exported_program: torch.export.ExportedProgram, *, keep_params_as_input: bool = False, @@ -209,7 +209,7 @@ def from_exportedprogram( # Import the importer. import tvm - from tvm.relax.frontend.torch import from_exportedprogram + from tvm.relax.frontend.torch import from_exported_program import torch from torch.export import export @@ -230,7 +230,7 @@ def forward(self, input): exported_program = export(torch_model, args=example_args) # Use the importer to import the ExportedProgram to Relax. - mod: tvm.IRModule = from_exportedprogram(exported_program) + mod: tvm.IRModule = from_exported_program(exported_program) """ # decompose into Core ATen operators exported_program.run_decompositions() diff --git a/tests/python/relax/test_frontend_from_exportedprogram.py b/tests/python/relax/test_frontend_from_exported_program.py similarity index 98% rename from tests/python/relax/test_frontend_from_exportedprogram.py rename to tests/python/relax/test_frontend_from_exported_program.py index 31feb25fa456..112390fe6094 100644 --- a/tests/python/relax/test_frontend_from_exportedprogram.py +++ b/tests/python/relax/test_frontend_from_exported_program.py @@ -24,12 +24,12 @@ from tvm.script import ir as I from tvm.script import relax as R from tvm.script import tir as T -from tvm.relax.frontend.torch import from_exportedprogram +from tvm.relax.frontend.torch import from_exported_program def verify_model(torch_model, example_args, binding, expected): exported_program = export(torch_model, args=example_args) - mod = from_exportedprogram(exported_program) + mod = from_exported_program(exported_program) binding = {k: tvm.nd.array(v) for k, v in binding.items()} expected = relax.transform.BindParams("main", binding)(expected) @@ -465,7 +465,7 @@ def main( example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),) model = Conv2D1() exported_program = torch.export.export(model, example_args) - mod = from_exportedprogram(exported_program, keep_params_as_input=True) + mod = from_exported_program(exported_program, keep_params_as_input=True) mod, params = detach_params(mod) tvm.ir.assert_structural_equal(mod, expected1) func = mod["main"] @@ -501,7 +501,7 @@ def main( example_args = (torch.randn(256, 256, dtype=torch.float32),) exported_program = export(Identity(), args=example_args) - mod = from_exportedprogram(exported_program, unwrap_unit_return_tuple=True) + mod = from_exported_program(exported_program, unwrap_unit_return_tuple=True) tvm.ir.assert_structural_equal(mod, Expected) @@ -531,5 +531,5 @@ def main( torch.randn(256, 256, dtype=torch.float32), ) exported_program = export(Identity(), args=example_args) - mod = from_exportedprogram(exported_program, no_bind_return_tuple=True) + mod = from_exported_program(exported_program, no_bind_return_tuple=True) tvm.ir.assert_structural_equal(mod, Expected)