[Unity][nnModule] Dynamic shape support in nn Module (#16284)
CharlieFRuan authored Jan 12, 2024
1 parent 138cb65 commit 07d8e02
Showing 3 changed files with 37 additions and 10 deletions.
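In short, this commit lets layer dimensions and placeholder shapes be given as strings, which the nn frontend turns into symbolic `tir.Var`s. A minimal sketch of the intended usage, assuming illustrative names (`vocab_size`, `hidden_size`, `seq_len`) that are not taken from the commit itself:

from tvm.relax.frontend import nn

class TinyLM(nn.Module):
    def __init__(self):
        # String dims become tir.Var(name, "int64"), so the weights get symbolic shapes.
        self.embed_tokens = nn.Embedding("vocab_size", "hidden_size", dtype="float16")
        self.lm_head = nn.Linear("hidden_size", "vocab_size", bias=False, dtype="float16")

    def forward(self, x: nn.Tensor) -> nn.Tensor:
        return self.lm_head(self.embed_tokens(x))

# Export with a symbolic sequence length; nn.spec.Tensor accepts string dims as well.
mod, _ = TinyLM().export_tvm(
    spec={"forward": {"x": nn.spec.Tensor(["seq_len"], "int32")}}
)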
python/tvm/relax/frontend/nn/core.py (15 changes: 11 additions & 4 deletions)
@@ -128,13 +128,15 @@ def from_scalar(data: Union[int, float], dtype: str) -> "Tensor":
 
     @staticmethod
     def placeholder(
-        shape: Sequence[Union[int, tir.PrimExpr]],
+        shape: Sequence[Union[int, str, tir.PrimExpr]],
         dtype: str,
         name: str = "tensor",
     ) -> "Tensor":
         """Create a placeholder tensor with given shape and dtype. A placeholder tensor should
         never be created directly by users in usual cases, and the only exception is to indicate
         the shape/dtype of return values of an external function.
+
+        If shape is a string `name`, we create a symbolic shape `tvm.tir.Var(name, "int64")`.
         """
         new_shape = []
         for expr in shape:
@@ -143,6 +145,10 @@ def placeholder(
                 assert expr >= 0
                 new_shape.append(expr)
                 continue
+            if isinstance(expr, str):
+                expr = tir.Var(expr, "int64")
+                new_shape.append(expr)
+                continue
             if not isinstance(expr, tir.PrimExpr):
                 raise TypeError(f"Invalid shape: {shape}")
             assert expr.dtype == "int64"
@@ -214,16 +220,17 @@ class Parameter(Tensor):
 
     def __init__(
         self,
-        shape: Sequence[Union[int, tir.PrimExpr]],
+        shape: Sequence[Union[int, str, tir.PrimExpr]],
        dtype: Optional[str] = None,
     ) -> None:
         """Create a parameter with given shape and dtype. The parameter is not bound to any
         concrete values.
 
         Parameters
         ----------
-        shape : Sequence[Union[int, tir.PrimExpr]]
-            The shape of the parameter
+        shape : Sequence[Union[int, str, tir.PrimExpr]]
+            The shape of the parameter. If it is a string `name`, we create a symbolic shape
+            `tvm.tir.Var(name, "int64")`.
         dtype : Optional[str]
             The data type of the parameter. If not specified, the default dtype will be used.
         """
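A small sketch of the behavior these core.py changes describe, assuming the illustrative dimension names `batch`, `seq_len`, and `vocab_size`; a string entry in a shape is turned into `tir.Var(name, "int64")`:

from tvm.relax.frontend.nn import core

# String entries become symbolic int64 variables in the placeholder's shape.
x = core.Tensor.placeholder(["batch", "seq_len", 4096], "float16", name="x")

# Parameter accepts the same notation and remains unbound to concrete values.
w = core.Parameter(("vocab_size", 4096), dtype="float16")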
python/tvm/relax/frontend/nn/exporter.py (21 changes: 18 additions & 3 deletions)
@@ -111,8 +111,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
             return result
 
         # pylint: enable=protected-access
-
-        params = _params()
+        params = None
         effects = _effects()
         ext_mods = self.extern_mods
         with self:
@@ -122,6 +121,7 @@ def _effects() -> typing.List[typing.Tuple[str, core.Effect]]:
                         outputs = _emit_effect_init(self.builder, effects)
                     self.builder.emit_func_output(outputs, params=[])
             for method_name, method_spec in zip(spec.method_names, spec.method_specs):
+                params = _params()  # Re-initialize so symbolic shapes not shared across methods
                 len_args = len(method_spec.arg_specs)
                 len_effects = {
                     "packed": 1,
@@ -159,6 +159,9 @@ def _emit_method(  # pylint: disable=too-many-locals,too-many-branches,too-many-
     effects: typing.Optional[typing.List[typing.Tuple[str, core.Effect]]],
 ):
     # pylint: disable=protected-access
+    # symbolic shape's name mapping to its tir.Var for reuse
+    str2var_params: typing.Dict[str, tir.Var] = {}
+
     def _unwrap_ret(expr: typing.Any) -> typing.Any:
         if isinstance(expr, (core.Tensor, core.Object)):
             return expr._expr
@@ -184,8 +187,20 @@ def _convert_input(arg):
 
     def _params(mode: str) -> typing.List[rx.Var]:
         inputs: typing.List[rx.Var] = []
+
+        def _get_var(shape_var: tir.Var) -> tir.Var:
+            name = shape_var.name
+            if name in str2var_params:
+                return str2var_params[name]
+            var = tir.Var(name, "int64")
+            str2var_params[name] = var
+            return var
+
         for name, param in params:
-            var = core.Tensor.placeholder(param.shape, param.dtype, name)._expr
+            # Make sure a symbolic shape is not re-registered (same as _method_spec_to_inputs)
+            # e.g. we do not see `vocab_size` for `lm_head` and `vocab_size_1` for `embed_tokens`
+            new_shape = [_get_var(x) if isinstance(x, tir.Var) else x for x in param.shape]
+            var = core.Tensor.placeholder(new_shape, param.dtype, name)._expr
             inputs.append(var)
             param._expr = var
         if mode == "none":
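The `str2var_params` map above exists so that parameters naming the same symbolic dimension share one `tir.Var` within a method, rather than ending up with separate variables such as `vocab_size` and `vocab_size_1`. A standalone sketch of that bookkeeping, with assumed names, mirroring `_get_var`:

from tvm import tir

str2var = {}

def get_var(name: str) -> tir.Var:
    # Reuse the Var already registered for this name, otherwise create and remember it.
    if name not in str2var:
        str2var[name] = tir.Var(name, "int64")
    return str2var[name]

lm_head_dim = get_var("vocab_size")
embed_dim = get_var("vocab_size")
assert lm_head_dim.same_as(embed_dim)  # both parameters now share one symbolic dim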
python/tvm/relax/frontend/nn/modules.py (11 changes: 8 additions & 3 deletions)
@@ -99,8 +99,8 @@ class Linear(Module):
 
     def __init__(
         self,
-        in_features: int,
-        out_features: int,
+        in_features: Union[int, str, tir.PrimExpr],
+        out_features: Union[int, str, tir.PrimExpr],
         bias: bool = True,
         dtype: Optional[str] = None,
         out_dtype: Optional[str] = None,
@@ -617,7 +617,12 @@ class Embedding(Module):
     Module for embedding layer.
     """
 
-    def __init__(self, num: int, dim: int, dtype: Optional[str] = None):
+    def __init__(
+        self,
+        num: Union[int, str, tir.PrimExpr],
+        dim: Union[int, str, tir.PrimExpr],
+        dtype: Optional[str] = None,
+    ):
         self.num = num
         self.dim = dim
         self.weight = Parameter((num, dim), dtype=dtype)
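A hedged usage sketch for the widened Linear and Embedding signatures; the names `hidden_size` and `vocab_size` are illustrative, and the default dtype applies when none is given:

from tvm.relax.frontend import nn

embed = nn.Embedding(num="vocab_size", dim="hidden_size", dtype="float16")
proj = nn.Linear(in_features="hidden_size", out_features="vocab_size", bias=False, dtype="float16")

# Embedding stores its table as Parameter((num, dim)), so both dims stay symbolic here;
# Linear's weight is likewise a Parameter with symbolic in/out features.
print(embed.weight.shape)  # e.g. [vocab_size, hidden_size] as tir expressions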
