Skip to content

Commit

Permalink
[Parser][Printer] Add class -> IRModule parsing, and extern func supp…
Browse files Browse the repository at this point in the history
…ort for call_dps (#15)

* update parser and printer for match_shape

* support parsing class to IRModule, and extern func in call_dps
  • Loading branch information
altanh authored and junrushao committed Feb 5, 2023
1 parent a8713cb commit 926c696
Show file tree
Hide file tree
Showing 5 changed files with 266 additions and 177 deletions.
3 changes: 3 additions & 0 deletions include/tvm/ir/op.h
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,9 @@ class OpNode : public RelayExprNode {
// Internal function to compute if it is primitive op
bool IsPrimitiveOp_() const {
const auto& fn_ty = this->op_type;
if (!fn_ty.get()) {
return false;
}
ICHECK(fn_ty.get() != nullptr) << "op_type of " << this->name << " is not registered";
if (fn_ty->type_constraints.size() != 1) return false;
const TypeRelationNode* rel = fn_ty->type_constraints[0].as<TypeRelationNode>();
Expand Down
296 changes: 166 additions & 130 deletions python/tvm/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,6 @@
import tvm.relax as rx


def pretty_print(node):
"""Prints the given Relax IR node in the Relax text format.
Parameters
----------
node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
The Relax IR node to print.
"""
print(tvm.script._ffi_api.AsRelaxScript(node))


def astext(node) -> str:
"""Returns the Relax text format representation of the given Relax IR node.
Parameters
----------
node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
The Relax IR node to print.
Returns
-------
str
The text format representation of the given Relax IR node.
"""
return tvm.script._ffi_api.AsRelaxScript(node)


def _is_registered(op_name: str, op_set=None) -> bool:
"""Returns whether or not the given operator is registered.
Expand Down Expand Up @@ -130,10 +103,9 @@ class ArithmeticOp(Enum):


class RelaxTransformer(Transformer):
def __init__(self, definition_scope):
def __init__(self):
super().__init__()
self.definition_scope = definition_scope
self.module = {}
self.module = tvm.IRModule()
self._scopes = [{}] # str -> Var
self._registered_ops = set(tvm.ir._ffi_api.ListOpNames()) # cached

Expand Down Expand Up @@ -415,7 +387,7 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr:
self.report_error(f"unsupported dimension expression: {expr}", expr.span)

def transform_module(self, mod: ast.Module) -> IRModule:
"""Transforms the given synr Module to a Relax IRModule.
"""Transforms the given synr Module to a Relax IRModule or Function.
Parameters
----------
Expand All @@ -424,13 +396,28 @@ def transform_module(self, mod: ast.Module) -> IRModule:
Returns
-------
IRModule
The parsed Relax IRModule
Union[IRModule, Function]
The parsed Relax IRModule or Function
"""
for func_name in mod.funcs:
func = mod.funcs[func_name]
self.module[func_name] = self.transform_function(func, is_global=True)
return self.module
if len(mod.funcs) != 1:
self.report_error(
"the input must be either a single function or a single class", mod.span
)

(root_func,) = mod.funcs.values()

if isinstance(root_func, ast.Function):
return self.transform_function(root_func, is_global=True)
elif isinstance(root_func, ast.Class):
# add global vars to the root scope for resolving global function calls
for func_name in root_func.funcs:
self.scope[func_name] = relay.GlobalVar(func_name)
for func_name, func in root_func.funcs.items():
global_var = self.scope[func_name]
self.module[global_var] = self.transform_function(func, is_global=True)
return self.module
else:
self.report_error(f"unsupported input class: {root_func}", root_func.span)

def _parse_attrs_to_str(self, expr: ast.Attr) -> str:
strs = []
Expand Down Expand Up @@ -804,6 +791,104 @@ def parse_dataflow(self, block: ast.Block) -> rx.DataflowBlock:

return rx.DataflowBlock(bindings, self.to_tvm_span(block.span))

def parse_attr(self, expr: ast.Attr) -> rx.Expr:
"""Parses the given synr Attr node to a Relax expression.
Parameters
----------
expr : ast.Attr
The synr Attr node to be parsed.
Returns
-------
rx.Expr
The parsed expression.
"""
if expr.field.name == "shape":
obj = self.transform_expr(expr.object)
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
return relay.Call(
relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
)
else:
# assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
op_name = self._parse_attrs_to_str(expr)
# NOTE: at least for now, all special operators are namespaced
try:
return SpecialOp(op_name)
except ValueError:
# TODO(@altanh): maybe diagnostics here in case this fails?
return relay.op.get(op_name)

def parse_call(self, expr: ast.Call) -> Union[tir.PrimExpr, rx.Expr]:
"""Parses the given synr Call node to a Relax expression or PrimExpr.
Parameters
----------
expr : ast.Call
The synr Call node to be parsed.
Returns
-------
Union[tir.PrimExpr, rx.Expr]
The parsed expression. It will be a PrimExpr if expr is an arithmetic operation on
PrimExprs.
"""
op = self.transform_expr(expr.func_name)

if op == SpecialOp.CALL_PACKED:
if len(expr.params) != 2:
self.report_error(
op.value + " takes an extern function name and a tuple of arguments",
expr.span,
)
extern_func = expr.params[0]
if not (isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)):
self.report_error(
"the first argument of " + op.value + " must be the extern function name",
extern_func.span,
)
op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
args = [self.transform_expr(expr.params[1])]

elif isinstance(op, ArithmeticOp):
args = [self.transform_expr(arg) for arg in expr.params]
if all([isinstance(arg, tir.PrimExpr) for arg in args]):
return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span))
# otherwise it's just a normal Relax operator call
op = RELAX_ARITHMETIC_OP_MAP[op]

elif isinstance(op, tvm.ir.Op):
args = [self.transform_expr(arg) for arg in expr.params]
# check call arity eagerly
if op.num_inputs != -1 and len(args) != op.num_inputs:
self.report_error(
f"{op.name} expects {op.num_input} arguments but got {len(args)}", expr.span
)
if op.name == "relax.call_dps" and isinstance(args[1], str):
# extern function call case: rewrite identifier to an ExternFunc
args[1] = rx.ExternFunc(args[1], self.to_tvm_span(expr.params[1].span))

elif isinstance(op, relay.Expr):
args = [self.transform_expr(arg) for arg in expr.params]

else:
self.report_error(f"unsupported function in call: {op}", expr.func_name.span)

# parse call attributes if applicable
if isinstance(op, rx.ExternFunc) or (isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""):
attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
kwargs = {}
for key, val in expr.keyword_params.items():
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
# TODO(@altanh): might need separate attribute parsing eventually
kwargs[key.value] = self.transform_expr(val)
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
else:
attrs = None

return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))

# Exprs:
# - ArrayLiteral: unsupported for now?
# - Attr: use for .shape, and intrinsic/special operator namespace
Expand All @@ -827,65 +912,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr:
The corresponding Relax expression
"""
if isinstance(expr, ast.Attr):
if expr.field.name == "shape":
obj = self.transform_expr(expr.object)
attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32")
return relay.Call(
relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span)
)
else:
# assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps)
op_name = self._parse_attrs_to_str(expr)
# NOTE: at least for now, all special operators are namespaced
try:
return SpecialOp(op_name)
except ValueError:
# TODO(@altanh): maybe diagnostics here in case this fails?
return relay.op.get(op_name)

if isinstance(expr, ast.Call):
# TODO(@altanh): support parsing kwargs as attributes?
op = self.transform_expr(expr.func_name)
if op == SpecialOp.CALL_PACKED:
if len(expr.params) != 2:
self.report_error(
op.value + " takes an extern function name and a tuple of arguments",
expr.span,
)
extern_func = expr.params[0]
if not (
isinstance(extern_func, ast.Constant) and isinstance(extern_func.value, str)
):
self.report_error(
"the first argument of " + op.value + " must be the extern function name",
extern_func.span,
)
op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span))
args = [self.transform_expr(expr.params[1])]
elif isinstance(op, ArithmeticOp):
args = [self.transform_expr(arg) for arg in expr.params]
if all([isinstance(arg, tir.PrimExpr) for arg in args]):
return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span))
# otherwise it's just a normal Relax operator call
op = RELAX_ARITHMETIC_OP_MAP[op]
elif isinstance(op, (tvm.ir.Op, relay.Expr)):
args = [self.transform_expr(arg) for arg in expr.params]
else:
self.report_error(f"unsupported function in call: {op}", expr.func_name.span)
return self.parse_attr(expr)

if isinstance(op, rx.ExternFunc) or (
isinstance(op, tvm.ir.Op) and op.attrs_type_key != ""
):
attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key
kwargs = {}
for key, val in expr.keyword_params.items():
assert isinstance(key, ast.Constant) and isinstance(key.value, str)
kwargs[key.value] = self.transform_expr(val)
attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs)
else:
attrs = None
# TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass?
return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span))
elif isinstance(expr, ast.Call):
return self.parse_call(expr)

elif isinstance(expr, ast.Tuple):
fields = [self.transform_expr(field) for field in expr.values]
Expand Down Expand Up @@ -1015,61 +1045,67 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr:
# self.tvm_diag_ctx.render()


# TODO(@altanh, @jroesch): revisit this?
class RelaxDecoratedFn:
def __init__(self, fn_name, relax_module, diag_ctx):
self.fn_name = fn_name
self.module = relax_module
self.diag_ctx = diag_ctx

def __call__(self, *args):
pretty_print(self.module[self.fn_name])
# compiler = Compiler(self.diag_ctx, self.module, self.fn_name)
# compiled_f = compiler.compile(execute=True)
# # Actually compute needed buffer sizes.
# out = tvm.nd.array(np.random.rand(10).astype('float32'))
# compiled_f(*(list(args) + [out]))
# return out


def script(f) -> RelaxDecoratedFn:
"""Parses the decorated Relax function (in Relax IR) to a Relax AST.
def script(f) -> Union[rx.Function, tvm.IRModule]:
"""Parses the decorated Relax function or module (in Relax IR) to a Relax AST.
Parameters
----------
f : function
The function to be parsed, written in the Relax IR
f : Union[function, class]
The function or class to be parsed, written in the Relax IR.
Returns
-------
RelaxDecoratedFn
The parsed Relax function
Union[rx.Function, IRModule]
The parsed Relax function or IRModule.
"""
# ir_module = tvm.IRModule({})
# diag_ctx = diagnostics.DiagnosticContext(ir_module, diagnostics.get_renderer())
diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
ast = synr.to_ast(f, diag_ctx)
definition_scope = inspect.getmodule(f)
module = RelaxTransformer(definition_scope).do_transform(ast, diag_ctx)
return RelaxDecoratedFn(f.__name__, module, diag_ctx)
return RelaxTransformer().do_transform(ast, diag_ctx)


def fromtext(source: str, source_name: str = "from_string"):
"""Parses the given input string (in the Relax text format) to a Relax AST.
def fromtext(source: str, source_name: str = "from_string") -> Union[rx.Function, tvm.IRModule]:
"""Parses the given input string (in the Relax text format) to a Relax function or IRModule.
Parameters
----------
source : str
The input source string.
The input source string. It should be either a decorated Python class or function.
source_name : str, optional
A descriptive name for error reporting, by default "from_string".
Returns
-------
Relax AST
The parsed Relax AST.
Union[rx.Function, IRModule]
The parsed Relax function or IRModule.
"""
# TODO(@altanh): actually use source_name somewhere?
diag_ctx = tvm.script.diagnostics.TVMDiagnosticCtx()
ast = synr.to_ast(source, diag_ctx)
module = RelaxTransformer(None).do_transform(ast, diag_ctx)
return module
return RelaxTransformer().do_transform(ast, diag_ctx)


def pretty_print(node):
"""Prints the given Relax IR node in the Relax text format.
Parameters
----------
node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
The Relax IR node to print.
"""
print(tvm.script._ffi_api.AsRelaxScript(node))


def astext(node) -> str:
"""Returns the Relax text format representation of the given Relax IR node.
Parameters
----------
node : Union[rx.Type, rx.Expr, rx.Binding, rx.BindingBlock]
The Relax IR node to print.
Returns
-------
str
The text format representation of the given Relax IR node.
"""
return tvm.script._ffi_api.AsRelaxScript(node)
Loading

0 comments on commit 926c696

Please sign in to comment.