Skip to content

Commit

Permalink
Parse relax TupleGetItem (apache#77)
Browse files Browse the repository at this point in the history
* Parse relax TupleGetItem

* Incorporate comments

* fix comment
  • Loading branch information
yongwww authored and junrushao committed Oct 14, 2022
1 parent 0a552da commit ea8cc43
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 8 deletions.
119 changes: 111 additions & 8 deletions python/tvm/script/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,23 @@
from __future__ import annotations

import inspect
from typing import Union, Dict, List, Tuple, Optional, Callable, Any
from enum import Enum
from typing import Union, Dict, List, Tuple, Optional, Callable, Any

import tvm
from tvm import relay, relax, tir
import tvm.script
from tvm.ir.module import IRModule
from tvm.ir import diagnostics
from tvm.ir.module import IRModule
from tvm.script.tir.node import BufferSlice
import tvm.script.tir as tir_namespace
import tvm.script.relax as relax_namespace

import synr
from synr import ast, Transformer

from tvm import relay, relax, tir

import tvm.script.relax as relax_namespace
import tvm.script.tir as tir_namespace

from ..parser import TVMScriptParser as _TIRScriptParser
from ..utils import tvm_span_from_synr, call_with_error_reporting


def _is_registered(op_name: str, op_set=None) -> bool:
Expand Down Expand Up @@ -68,6 +68,8 @@ class SpecialOp(Enum):
CALL_PACKED = "relax.call_packed"
DATAFLOW = "relax.dataflow"
DATAFLOW_OUTPUT = "relax.output"
TUPLE = "relax.Tuple"
TUPLE_GET_ITEM = "relax.TupleGetItem"


class ArithmeticOp(Enum):
Expand Down Expand Up @@ -520,7 +522,7 @@ def is_match_shape(self, stmt: ast.Stmt) -> bool:
if isinstance(stmt, ast.UnassignedCall):
call_op = self.transform_expr(stmt.call.func_name)
elif isinstance(stmt, ast.Assign) and isinstance(stmt.rhs, ast.Call):
call_op = self.transform_expr(stmt.rhs.func_name)
call_op = self.transform_expr(stmt.rhs)
return call_op == SpecialOp.MATCH_SHAPE

def parse_binding(self, stmt: ast.Stmt, is_dataflow: bool = False) -> relax.Binding:
Expand Down Expand Up @@ -865,6 +867,8 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
The parsed expression. It will be a PrimExpr if expr is an arithmetic operation on
PrimExprs.
"""
if isinstance(expr.func_name, ast.Op) and expr.func_name.name == ast.BuiltinOp.Subscript:
return self.transform_Subscript(expr)
op = self.transform_expr(expr.func_name)

if op == SpecialOp.CALL_PACKED:
Expand All @@ -877,6 +881,16 @@ def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, relax.Expr]:
op = relax.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
args = [self.transform_expr(arg) for arg in expr.params[1:]]

elif op == SpecialOp.TUPLE:
args = [self.transform_expr(arg) for arg in expr.params[0].values]
return relax.Tuple(args)

elif op == SpecialOp.TUPLE_GET_ITEM:
assert len(expr.params) == 2, "TupleGetItem expects to get two parameters."
args = [self.transform_expr(arg) for arg in expr.params]
# index of TupleGetItem only accepts int type intead of tir.expr.IntImm
return relax.TupleGetItem(args[0], args[1].value)

elif isinstance(op, ArithmeticOp):
args = [self.transform_expr(arg) for arg in expr.params]
if all([isinstance(arg, tir.PrimExpr) for arg in args]):
Expand Down Expand Up @@ -953,10 +967,13 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
relax.Expr
The corresponding Relax expression
"""

if isinstance(expr, ast.Attr):
return self.parse_attr(expr)

elif isinstance(expr, ast.Call):
if hasattr(expr.func_name, "field") and expr.func_name.field.name == "match_shape":
return self.transform_expr(expr.func_name)
return self.parse_call(expr)

elif isinstance(expr, ast.Tuple):
Expand Down Expand Up @@ -1009,6 +1026,92 @@ def transform_expr(self, expr: ast.Expr) -> relax.Expr:
else:
self.report_error(f"unsupported expression: {expr}", expr.span)

def transform_Subscript(self, expr):
"""Array access visitor."""

symbol = self.transform(expr.params[0])
if symbol is None:
self.report_error(
f"Variable {expr.params[0].id.name} is not defined.", expr.params[0].span
)
indexes = [self.transform(x) for x in expr.params[1].values]
if isinstance(symbol, relax.expr.Var):
if len(indexes) > 1:
self.report_error(
"Only a single index can be provided when indexing into a `var`.",
expr.params[1].span,
)
index = indexes[0].value
if not isinstance(index, (tvm.tir.PrimExpr, int)):
self.report_error(
"Var load index should be an int or PrimExpr, but it is a" + type(index),
expr.span,
)
return call_with_error_reporting(
self.report_error,
expr.span,
relax.TupleGetItem,
symbol,
index,
)
elif isinstance(symbol, tvm.tir.expr.Var):
if symbol.dtype == "handle":
self.report_error(
"Cannot read directly from a handle, use `T.match_buffer` "
"to create a buffer to read from.",
expr.params[0].span,
)
if len(indexes) > 1:
self.report_error(
"Only a single index can be provided when indexing into a `var`.",
expr.params[1].span,
)
index = indexes[0]
if not isinstance(index, (tvm.tir.PrimExpr, int)):
self.report_error(
"Var load index should be an int or PrimExpr, but it is a" + type(index),
expr.span,
)

return call_with_error_reporting(
self.report_error,
expr.span,
tvm.tir.Load,
"float32",
symbol,
index,
True,
span=tvm_span_from_synr(expr.span),
)
elif isinstance(symbol, tvm.tir.Buffer):
return BufferSlice(
symbol, indexes, self.report_error, span=tvm_span_from_synr(expr.span)
)
elif isinstance(symbol, tvm.container.Array):
if len(indexes) > 1:
self.report_error(
"Array access should be one-dimension access, but the indices are "
+ str(indexes),
expr.span,
)
index = indexes[0]
if not isinstance(index, (int, tvm.tir.expr.IntImm)):
self.report_error(
"Array access index expected int or IntImm, but got " + type(index),
expr.span,
)
if int(index) >= len(symbol):
self.report_error(
f"Array access out of bound, size: {len(symbol)}, got index {index}.",
expr.span,
)
return symbol[int(index)]
else:
self.report_error(
f"Cannot subscript from a {type(symbol).__name__}.",
expr.params[0].span,
)

def transform_block(self, block: ast.Block) -> relax.SeqExpr:
"""Transforms the given synr block to a Relax SeqExpr (sequence of Blocks with a final
expression).
Expand Down
29 changes: 29 additions & 0 deletions tests/python/relax/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,35 @@ def f(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
check_shape(tup.fields[1], (32,))


def test_tuplegetitem():
@R.function
def f(x: Tensor[_, _], y: Tensor[_, _]):
t1 = relax.Tuple((x, y))
t2 = (x, y)
a = t1[0]
b = relax.TupleGetItem(t2, 1)
c = add(a, b)
return c

x, y = f.params
bind_0 = f.body.blocks[0].bindings[0]
bind_1 = f.body.blocks[0].bindings[1]
bind_2 = f.body.blocks[0].bindings[2]
bind_3 = f.body.blocks[0].bindings[3]
bind_4 = f.body.blocks[0].bindings[4]
assert_structural_equal(bind_0.value.fields, [x, y])
assert_structural_equal(bind_1.value.fields, [x, y])
assert isinstance(bind_0.value, relax.expr.Tuple)
assert isinstance(bind_1.value, relax.expr.Tuple)
assert isinstance(bind_2.value, relax.TupleGetItem)
assert isinstance(bind_3.value, relax.TupleGetItem)
assert bind_2.value.index == 0
assert bind_3.value.index == 1
assert bind_2.var.name_hint == "a"
assert bind_3.var.name_hint == "b"
check_call(bind_4.value, "add", [bind_2.var, bind_3.var])


def test_local_func():
@R.function
def f(x: Tensor[_, _]):
Expand Down
14 changes: 14 additions & 0 deletions tests/python/relax/test_printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,20 @@ def foo(x: Tensor[_, _], y: Tensor[(32,), "float32"]):
check_roundtrip(foo)


def test_tuplegetitem():
@R.function
def foo(x: Tensor[_, _]):
y = add(x, x)
z = multiply(y, x)
t = relax.Tuple((y, z))
a = relax.TupleGetItem(t, 0)
b = relax.TupleGetItem(t, 1)
c = divide(a, b)
return c

check_roundtrip(foo)


def test_local_func():
@R.function
def foo(x: Tensor[_, _]):
Expand Down
22 changes: 22 additions & 0 deletions tests/python/relax/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,27 @@ def test_vm_tuple():
np.testing.assert_allclose(res3.numpy(), inp.numpy())


def test_vm_tuplegetitem():
@tvm.script.ir_module
class TestVMTupleGetItem:
@R.function
def tuple_get_item(x: Tensor[(_, _), "float32"], y: Tensor[(_, _), "float32"]):
t = relax.Tuple((x, y))
a = relax.TupleGetItem(t, 0)
b = relax.TupleGetItem(t, 1)
c = relax.call_packed("test.vm.add", a, b)
return c

mod = TestVMTupleGetItem
target = tvm.target.Target("llvm", host="llvm")
ex, lib = relax.vm.build(mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
x_inp = tvm.nd.array(np.random.rand(2, 3))
y_inp = tvm.nd.array(np.random.rand(2, 3))
res = vm["tuple_get_item"](x_inp, y_inp)
np.testing.assert_allclose(res.numpy(), x_inp.numpy() + y_inp.numpy())


if __name__ == "__main__":
test_vm_execute()
test_vm_multiple_func()
Expand All @@ -699,3 +720,4 @@ def test_vm_tuple():
test_vm_relax_symbolic_shape()
test_vm_relax_dyn_tir_shape()
test_vm_tuple()
test_vm_tuplegetitem()

0 comments on commit ea8cc43

Please sign in to comment.