From f8519c1f0ca53f204b82afdd19d9ad0c79564346 Mon Sep 17 00:00:00 2001 From: Altan Haan Date: Mon, 27 Sep 2021 14:03:36 -0700 Subject: [PATCH] [Parser][Printer] More parser/printer improvements (#12) * Relax pretty printer initial prototype * call into TVMScriptPrinter for PrimFuncs * most round-trip tests pass * address comments * implement relax.output syntax for dataflow block outputs * remove leftover comments * fix Var constructor on ShapeExpr annotation * add printing and parsing for simple PrimExpr and Call Attrs --- python/tvm/relax/parser.py | 92 +++++++++++-- src/printer/relax_script_printer.cc | 205 ++++++++++++++++++++++------ tests/python/relax/test_parser.py | 21 ++- tests/python/relax/test_printer.py | 14 +- 4 files changed, 272 insertions(+), 60 deletions(-) diff --git a/python/tvm/relax/parser.py b/python/tvm/relax/parser.py index 026763cc2b5b..cf4c5c261ad5 100644 --- a/python/tvm/relax/parser.py +++ b/python/tvm/relax/parser.py @@ -94,7 +94,7 @@ def _tir_from_synr( # NOTE: call_dps is an actual registered operator class SpecialOp(Enum): - """Relax operator calls that have special semantics handled by the parser.""" + """Relax operators that have special semantics handled by the parser.""" MATCH_SHAPE = "relax.match_shape" CALL_PACKED = "relax.call_packed" @@ -102,6 +102,33 @@ class SpecialOp(Enum): DATAFLOW_OUTPUT = "relax.output" +class ArithmeticOp(Enum): + """Arithmetic operators that can desugar to either Relax or TIR PrimExpr operators.""" + + ADD = ast.BuiltinOp.Add + SUB = ast.BuiltinOp.Sub + MUL = ast.BuiltinOp.Mul + DIV = ast.BuiltinOp.Div + FLOOR_DIV = ast.BuiltinOp.FloorDiv + + +RELAX_ARITHMETIC_OP_MAP = { + ArithmeticOp.ADD: relay.op.get("add"), + ArithmeticOp.SUB: relay.op.get("subtract"), + ArithmeticOp.MUL: relay.op.get("multiply"), + ArithmeticOp.DIV: relay.op.get("divide"), + ArithmeticOp.FLOOR_DIV: relay.op.get("floor_divide"), +} + +PRIMEXPR_ARITHMETIC_OP_MAP = { + ArithmeticOp.ADD: tir.Add, + ArithmeticOp.SUB: tir.Sub, + ArithmeticOp.MUL: tir.Mul, + ArithmeticOp.DIV: tir.Div, + ArithmeticOp.FLOOR_DIV: tir.FloorDiv, +} + + class RelaxTransformer(Transformer): def __init__(self, definition_scope): super().__init__() @@ -367,16 +394,25 @@ def parse_primexpr(self, expr: ast.Expr, bind_free_vars: bool) -> tir.PrimExpr: "cannot introduce new dimension variables in this expression", expr.span, ) + elif isinstance(expr, ast.Constant): if not isinstance(expr.value, int): self.report_error("only integer constants are supported", expr.span) return tir.const(expr.value, "int32", self.to_tvm_span(expr.span)) + + elif isinstance(expr, ast.Call): + if not isinstance(expr.func_name, ast.Op): + self.report_error( + "only built-in operators can be used in dimension expressions", + expr.func_name.span, + ) + op = PRIMEXPR_ARITHMETIC_OP_MAP[self.transform_expr(expr.func_name)] + # TODO(@altanh): it might not make sense to bind free variables + args = [self.parse_primexpr(arg, bind_free_vars) for arg in expr.params] + return op(*args, span=self.to_tvm_span(expr.span)) + else: - # TODO(@altanh): parse (simple) PrimExprs - self.report_error( - "only dimension variable expressions are currently supported", - expr.span, - ) + self.report_error(f"unsupported dimension expression: {expr}", expr.span) def transform_module(self, mod: ast.Module) -> IRModule: """Transforms the given synr Module to a Relax IRModule. @@ -750,7 +786,10 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: if isinstance(expr, ast.Attr): if expr.field.name == "shape": obj = self.transform_expr(expr.object) - return relay.Call(relay.op.get("shape_of"), [obj], span=self.to_tvm_span(expr.span)) + attrs = tvm.ir.attrs.make_node("relay.attrs.ShapeOfAttrs", dtype="int32") + return relay.Call( + relay.op.get("shape_of"), [obj], attrs=attrs, span=self.to_tvm_span(expr.span) + ) else: # assume it's a hierarchical op identifier (e.g. nn.softmax, relax.call_dps) op_name = self._parse_attrs_to_str(expr) @@ -780,12 +819,30 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: ) op = rx.ExternFunc(extern_func.value, self.to_tvm_span(extern_func.span)) args = [self.transform_expr(expr.params[1])] + elif isinstance(op, ArithmeticOp): + args = [self.transform_expr(arg) for arg in expr.params] + if all([isinstance(arg, tir.PrimExpr) for arg in args]): + return PRIMEXPR_ARITHMETIC_OP_MAP[op](*args, span=self.to_tvm_span(expr.span)) + # otherwise it's just a normal Relax operator call + op = RELAX_ARITHMETIC_OP_MAP[op] elif isinstance(op, (tvm.ir.Op, relay.Expr)): args = [self.transform_expr(arg) for arg in expr.params] else: self.report_error(f"unsupported function in call: {op}", expr.func_name.span) + + if isinstance(op, rx.ExternFunc) or ( + isinstance(op, tvm.ir.Op) and op.attrs_type_key != "" + ): + attrs_type_key = "DictAttrs" if isinstance(op, rx.ExternFunc) else op.attrs_type_key + kwargs = {} + for key, val in expr.keyword_params.items(): + assert isinstance(key, ast.Constant) and isinstance(key.value, str) + kwargs[key.value] = self.transform_expr(val) + attrs = tvm.ir.attrs.make_node(attrs_type_key, **kwargs) + else: + attrs = None # TODO(@altanh): should we check for correct arity here eagerly, or defer to a pass? - return relay.Call(op, args, span=self.to_tvm_span(expr.span)) + return relay.Call(op, args, attrs=attrs, span=self.to_tvm_span(expr.span)) elif isinstance(expr, ast.Tuple): fields = [self.transform_expr(field) for field in expr.values] @@ -812,14 +869,27 @@ def transform_expr(self, expr: ast.Expr) -> rx.Expr: return tir.IntImm("int32", expr.value, self.to_tvm_span(expr.span)) elif isinstance(expr.value, float): return tir.FloatImm("float32", expr.value, self.to_tvm_span(expr.span)) + elif isinstance(expr.value, str): + # FIXME(@altanh): using StringImm seems to cause problems, but this loses span + return expr.value + elif expr.value is None: + return None else: self.report_error( - "unsupported constant expression (we currently only support int and float)", + f"unsupported constant expression: {expr}", expr.span, ) + elif isinstance(expr, ast.Op): + # TODO(@altanh): might need to generalize from ArithmeticOp if we decide to support + # array slicing syntax + try: + return ArithmeticOp(expr.name) + except ValueError: + self.report_error(f"unsupported built-in operator: {expr.name}", expr.span) + else: - self.report_error("unsupported expression", expr.span) + self.report_error(f"unsupported expression: {expr}", expr.span) def transform_block(self, block: ast.Block) -> rx.SeqExpr: """Transforms the given synr block to a Relax SeqExpr (sequence of Blocks with a final @@ -842,7 +912,7 @@ def transform_block(self, block: ast.Block) -> rx.SeqExpr: parsed_stmt = self.transform_stmt(stmt) if isinstance(parsed_stmt, rx.DataflowBlock): if current_block: - # FIXME: span + # FIXME(@altanh): need to manually construct span start & end blocks.append(rx.BindingBlock(current_block, self.to_tvm_span(stmt.span))) current_block = [] blocks.append(parsed_stmt) diff --git a/src/printer/relax_script_printer.cc b/src/printer/relax_script_printer.cc index ed825781c645..e50588ad1f18 100644 --- a/src/printer/relax_script_printer.cc +++ b/src/printer/relax_script_printer.cc @@ -35,7 +35,9 @@ namespace tvm { namespace relax { class RelaxScriptPrinter : public relax::IRFunctor, - public TypeFunctor { + public tir::ExprFunctor, + public TypeFunctor, + public AttrFunctor { public: TVM_DLL Doc Print(const ObjectRef& node); @@ -65,7 +67,15 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc VisitNode_(const relax::FunctionNode* op) override; Doc VisitNode_(const relax::ExternFuncNode* op) override; - Doc PrintDimVar(const tir::Var& var); + // PrimExpr nodes allowed in Relax + Doc VisitExpr_(const tir::VarNode* op) override; + Doc VisitExpr_(const tir::IntImmNode* op) override; + Doc VisitExpr_(const tir::AddNode* op) override; + Doc VisitExpr_(const tir::SubNode* op) override; + Doc VisitExpr_(const tir::MulNode* op) override; + Doc VisitExpr_(const tir::DivNode* op) override; + Doc VisitExpr_(const tir::FloorDivNode* op) override; + Doc PrintIfStmt(const relax::Var& var, const relay::If& ite); Doc PrintFunctionDef(const Doc& name, const relax::Function& func); @@ -75,12 +85,59 @@ class RelaxScriptPrinter : public relax::IRFunctor, Doc VisitType_(const relax::DynTensorTypeNode* node) override; Doc VisitType_(const relay::TupleTypeNode* node) override; + Doc PrintAttr(const ObjectRef& attr); + std::vector PrintAttrs(const Attrs& attrs); + Doc VisitAttrDefault_(const Object* op) override; + Doc VisitAttr_(const ArrayNode* op) override; + Doc VisitAttr_(const tir::IntImmNode* op) override; + Doc VisitAttr_(const tir::FloatImmNode* op) override; + Doc GetUniqueName(std::string prefix, std::string fallback); + + /*! + * \brief Attribute printer which prints the attributes as kwargs in a call. + */ + class AttrPrinter : public AttrVisitor { + public: + AttrPrinter(std::vector* docs, RelaxScriptPrinter* parent) : docs(docs), parent_(parent) {} + + template + void PrintKV(const char* key, const T& value) { + Doc doc; + doc << key << "=" << value; + docs->push_back(doc); + } + + void Visit(const char* key, double* value) final { PrintKV(key, *value); } + void Visit(const char* key, int64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, uint64_t* value) final { PrintKV(key, *value); } + void Visit(const char* key, int* value) final { PrintKV(key, *value); } + void Visit(const char* key, bool* value) final { PrintKV(key, Doc::PyBoolLiteral(*value)); } + void Visit(const char* key, std::string* value) final { PrintKV(key, Doc::StrLiteral(*value)); } + void Visit(const char* key, void** value) final { + LOG(FATAL) << "do not allow void as argument"; + } + void Visit(const char* key, DataType* value) final { + PrintKV(key, Doc::StrLiteral(runtime::DLDataType2String(*value))); + } + void Visit(const char* key, runtime::NDArray* value) final { + LOG(FATAL) << "do not allow NDarray as argument"; + } + void Visit(const char* key, runtime::ObjectRef* obj) final { + PrintKV(key, parent_->PrintAttr(*obj)); + } + + private: + std::vector* docs; + RelaxScriptPrinter* parent_; + }; }; Doc RelaxScriptPrinter::Print(const ObjectRef& node) { if (node->IsInstance()) { return VisitType(Downcast(node)); + } else if (node->IsInstance()) { + return VisitExpr(Downcast(node)); } else { return VisitNode(node); } @@ -96,8 +153,8 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::TupleNode* op) { Doc doc; std::vector fields; - for (size_t i = 0; i < num_fields; ++i) { - fields.push_back(Print(op->fields[i])); + for (const Expr& field : op->fields) { + fields.push_back(Print(field)); } doc << "(" << Doc::Concat(fields, Doc::Text(", ")); if (num_fields == 1) { @@ -113,27 +170,27 @@ Doc RelaxScriptPrinter::VisitNode_(const relay::GlobalVarNode* op) { } Doc RelaxScriptPrinter::VisitNode_(const relay::CallNode* op) { + // TODO(@altanh): how to support when func cannot be printed as Python expr? + // e.g. Function or If Doc doc; - if (const relax::ExternFuncNode* ext = op->op.as()) { + if (op->op.as()) { ICHECK_EQ(op->args.size(), 1) << "extern calls should only have one argument"; - doc << "relax.call_packed(" << Print(op->op) << ", " << Print(op->args[0]) << ")"; - return doc; + doc << "relax.call_packed(" << Print(op->op) << ", " << Print(op->args[0]); + } else { + std::vector args; + for (const Expr& arg : op->args) { + args.push_back(Print(arg)); + } + doc << Print(op->op) << "(" << Doc::Concat(args, Doc::Text(", ")); } - // TODO(@altanh): how to support when func cannot be printed as Python expr? - // e.g. Function or If - doc << Print(op->op); - if (op->args.empty()) { - doc << "()"; - return doc; + std::vector attrs = PrintAttrs(op->attrs); + if (!attrs.empty()) { + doc << ", " << Doc::Concat(attrs); } - std::vector args; - for (size_t i = 0; i < op->args.size(); ++i) { - args.push_back(Print(op->args[i])); - } - doc << "(" << Doc::Concat(args, Doc::Text(", ")) << ")"; + doc << ")"; return doc; } @@ -168,15 +225,8 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ShapeExprNode* op) { Doc doc; std::vector fields; - for (size_t i = 0; i < op->values.size(); ++i) { - auto val = op->values[i]; - if (const tir::VarNode* var = val.as()) { - fields.push_back(PrintDimVar(GetRef(var))); - } else if (const tir::IntImmNode* num = val.as()) { - fields.push_back(Doc::Text(std::to_string(num->value))); - } else { - LOG(FATAL) << "cannot print PrimExpr: " << val->GetTypeKey(); - } + for (const PrimExpr& field : op->values) { + fields.push_back(Print(field)); } doc << "(" << Doc::Concat(fields, Doc::Text(", ")); if (fields.size() == 1) { @@ -227,8 +277,8 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::VarBindingNode* op) { Doc RelaxScriptPrinter::VisitNode_(const relax::BindingBlockNode* op) { Doc doc; - for (size_t i = 0; i < op->bindings.size(); ++i) { - doc << Print(op->bindings[i]) << Doc::NewLine(); + for (const relax::Binding& binding : op->bindings) { + doc << Print(binding) << Doc::NewLine(); } return doc; } @@ -237,11 +287,11 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) { Doc block; Doc body; std::vector return_vars; - for (size_t i = 0; i < op->bindings.size(); ++i) { - body << Print(op->bindings[i]) << Doc::NewLine(); - if (const relax::VarBindingNode* binding = op->bindings[i].as()) { - if (!binding->var.as()) { - return_vars.push_back(Print(binding->var)); + for (const relax::Binding& binding : op->bindings) { + body << Print(binding) << Doc::NewLine(); + if (const relax::VarBindingNode* var_binding = binding.as()) { + if (!var_binding->var.as()) { + return_vars.push_back(Print(var_binding->var)); } } } @@ -254,8 +304,8 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::DataflowBlockNode* op) { Doc RelaxScriptPrinter::VisitNode_(const relax::SeqExprNode* op) { Doc doc; - for (size_t i = 0; i < op->blocks.size(); ++i) { - doc << Print(op->blocks[i]); + for (const relax::BindingBlock& block : op->blocks) { + doc << Print(block); } // NOTE: the body expression is printed in the parent, since SeqExprs are used for both Function // bodies and If expr bodies (which don't have a "return" statement but instead a binding) @@ -271,6 +321,32 @@ Doc RelaxScriptPrinter::VisitNode_(const relax::ExternFuncNode* op) { return Doc::StrLiteral(op->global_symbol); } +Doc RelaxScriptPrinter::VisitExpr_(const tir::VarNode* op) { + tir::Var var = GetRef(op); + if (!dim_var_map_.count(var)) { + dim_var_map_[var] = GetUniqueName(var->name_hint, "dim"); + } + return dim_var_map_[var]; +} + +Doc RelaxScriptPrinter::VisitExpr_(const tir::IntImmNode* op) { + return Doc::Text(std::to_string(op->value)); +} + +#define TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(OpName, OpString) \ + Doc RelaxScriptPrinter::VisitExpr_(const OpName* op) { \ + Doc doc; \ + doc << "(" << Print(op->a) << OpString; \ + doc << Print(op->b) << ")"; \ + return doc; \ + } + +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::AddNode, " + ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::SubNode, " - ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::MulNode, " * ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::DivNode, " / ") +TVM_DEFINE_RELAX_PRINTER_PRIMEXPR_BINOP(tir::FloorDivNode, " // "); + Doc RelaxScriptPrinter::VisitType_(const relax::ShapeTypeNode* node) { return Doc::Text("Shape"); } Doc RelaxScriptPrinter::VisitType_(const relax::DynTensorTypeNode* node) { @@ -286,20 +362,61 @@ Doc RelaxScriptPrinter::VisitType_(const relay::TupleTypeNode* node) { Doc doc; std::vector fields; - for (size_t i = 0; i < node->fields.size(); ++i) { - fields.push_back(Print(node->fields[i])); + for (Type ty : node->fields) { + fields.push_back(Print(ty)); } - doc << "Tuple[" << Doc::Concat(fields, Doc::Text(", ")) << "]"; + doc << "Tuple[" << Doc::Concat(fields) << "]"; return doc; } -Doc RelaxScriptPrinter::PrintDimVar(const tir::Var& var) { - if (!dim_var_map_.count(var)) { - dim_var_map_[var] = GetUniqueName(var->name_hint, "dim"); +Doc RelaxScriptPrinter::PrintAttr(const ObjectRef& attr) { + if (attr.defined()) { + if (const StringObj* str = attr.as()) { + return Doc::StrLiteral(GetRef(str)); + } else { + return VisitAttr(attr); + } + } else { + return Doc::Text("None"); } +} - return dim_var_map_[var]; +std::vector RelaxScriptPrinter::PrintAttrs(const Attrs& attrs) { + std::vector kwargs; + if (!attrs.defined()) { + return kwargs; + } else if (const DictAttrsNode* dict_attrs = attrs.as()) { + for (const auto& k : dict_attrs->dict) { + kwargs.push_back(Doc::Text(k.first) << "=" << Print(k.second)); + } + } else { + AttrPrinter attr_printer(&kwargs, this); + const_cast(attrs.operator->())->VisitNonDefaultAttrs(&attr_printer); + } + return kwargs; +} + +Doc RelaxScriptPrinter::VisitAttrDefault_(const Object* op) { + return PrintAttr(GetRef(op)); +} + +Doc RelaxScriptPrinter::VisitAttr_(const ArrayNode* op) { + Doc doc; + std::vector arr_vals; + for (ObjectRef val : *op) { + arr_vals.push_back(PrintAttr(val)); + } + doc << "[" << Doc::Concat(arr_vals) << "]"; + return doc; +} + +Doc RelaxScriptPrinter::VisitAttr_(const tir::IntImmNode* op) { + return Doc::Text(std::to_string(op->value)); +} + +Doc RelaxScriptPrinter::VisitAttr_(const tir::FloatImmNode* op) { + return Doc::Text(std::to_string(op->value)); } Doc RelaxScriptPrinter::PrintIfStmt(const relax::Var& var, const relay::If& ite) { @@ -357,8 +474,6 @@ Doc RelaxScriptPrinter::PrintFunctionDef(const Doc& name, const relax::Function& Doc RelaxScriptPrinter::PrintTensorAnnotation(const relax::DynTensorType& ty, const Optional& shape) { Doc doc; - // doc << "Tensor[" - // << (shape.defined() ? Print(Downcast(shape.value())) : Doc::Text("_")) << ", "; doc << "Tensor["; if (shape.defined()) { doc << Print(Downcast(shape.value())); diff --git a/tests/python/relax/test_parser.py b/tests/python/relax/test_parser.py index 86296bd3f561..0575129ea12a 100644 --- a/tests/python/relax/test_parser.py +++ b/tests/python/relax/test_parser.py @@ -53,7 +53,7 @@ def check_call(call, op, args): def test_annotations(): @rx.script def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor: - z: Tensor[(32, k), "float32"] = nn.matmul(x, y) + z: Tensor[(32, k), "float32"] = nn.matmul(x, y, units=None) w: Tensor[_, _] = multiply(z, z) q: Tensor[(_, _), _] = add(w, w) t = subtract(w, z) @@ -411,7 +411,7 @@ def test_call_packed(): @rx.script def foo(x: Tensor[(3, 4), "float32"]): # test that we can intro dim vars - z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x)) + z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False) return z f = rx_func(foo) @@ -421,4 +421,21 @@ def foo(x: Tensor[(3, 4), "float32"]): assert isinstance(z_bind.value.op, rx.ExternFunc) assert z_bind.value.op.global_symbol == "contrib.my_matmul" + assert "mp" in z_bind.value.attrs and z_bind.value.attrs["mp"] == False assert structural_equal(z_bind.value.args, [rx.Tuple([x, x])]) + + +def test_primexpr_arithmetic(): + @rx.script + def foo(x: Tensor[(n, m), "float32"]): + z: Tensor[(n * m,), "float32"] = relax.call_packed("my_flatten", (x,)) + sh: Shape = (n + m, n // m) + return z + + f = rx_func(foo) + x = f.params[0] + n, m = x.shape_ + z_bind, sh_bind = f.body.blocks[0].bindings + + assert structural_equal(z_bind.var.shape_.values, [tir.Mul(n, m)]) + assert structural_equal(sh_bind.value.values, [tir.Add(n, m), tir.FloorDiv(n, m)]) diff --git a/tests/python/relax/test_printer.py b/tests/python/relax/test_printer.py index 4df8ac3d0d02..4623ad01b52e 100644 --- a/tests/python/relax/test_printer.py +++ b/tests/python/relax/test_printer.py @@ -19,7 +19,7 @@ def check_roundtrip(fn): def test_annotations(): @rx.script def foo(x: Tensor[(32, m), "float32"], y: Tensor[(m, k), "float32"]) -> Tensor: - z: Tensor[(32, k), "float32"] = nn.matmul(x, y) + z: Tensor[(32, k), "float32"] = nn.matmul(x, y, units=None) w: Tensor[_, _] = multiply(z, z) t = subtract(w, z) sh: Shape = t.shape @@ -129,7 +129,17 @@ def test_call_packed(): @rx.script def foo(x: Tensor[(3, 4), "float32"]): # test that we can intro dim vars - z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x)) + z: Tensor[(n, m), "float32"] = relax.call_packed("contrib.my_matmul", (x, x), mp=False) + return z + + check_roundtrip(foo) + + +def test_primexpr_arithmetic(): + @rx.script + def foo(x: Tensor[(n, m), "float32"]): + z: Tensor[(n * m,), "float32"] = relax.call_packed("my_flatten", (x,)) + sh: Shape = (n + m, n // m) return z check_roundtrip(foo)