diff --git a/include/tvm/relay/dataflow_pattern.h b/include/tvm/relay/dataflow_pattern.h index 80a5d6f52617..11ac7e39f4a3 100644 --- a/include/tvm/relay/dataflow_pattern.h +++ b/include/tvm/relay/dataflow_pattern.h @@ -309,6 +309,64 @@ class TypePattern : public DFPattern { TVM_DEFINE_OBJECT_REF_METHODS(TypePattern, DFPattern, TypePatternNode); }; +class ShapePattern; +/*! + * \brief Pattern for Shapes. + */ +class ShapePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The type to match */ + Array shape; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("shape", &shape); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.ShapePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(ShapePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class ShapePattern : public DFPattern { + public: + TVM_DLL ShapePattern(DFPattern pattern, Array type); + TVM_DEFINE_OBJECT_REF_METHODS(ShapePattern, DFPattern, ShapePatternNode); +}; + +class DataTypePattern; +/*! + * \brief Pattern for Types. + */ +class DataTypePatternNode : public DFPatternNode { + public: + /*! \brief The pattern. */ + DFPattern pattern; + /*! \brief The type to match */ + DataType dtype; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pattern", &pattern); + v->Visit("dtype", &dtype); + } + + static constexpr const char* _type_key = "relay.dataflow_pattern.DataTypePattern"; + TVM_DECLARE_FINAL_OBJECT_INFO(DataTypePatternNode, DFPatternNode); +}; + +/*! + * \brief A pattern which matches a type in another pattern + */ +class DataTypePattern : public DFPattern { + public: + TVM_DLL DataTypePattern(DFPattern pattern, DataType dtype); + TVM_DEFINE_OBJECT_REF_METHODS(DataTypePattern, DFPattern, DataTypePatternNode); +}; + class AttrPattern; /*! * \brief Pattern for Attributes. diff --git a/include/tvm/relay/dataflow_pattern_functor.h b/include/tvm/relay/dataflow_pattern_functor.h index a1140ae4f54e..98c81c929409 100644 --- a/include/tvm/relay/dataflow_pattern_functor.h +++ b/include/tvm/relay/dataflow_pattern_functor.h @@ -84,8 +84,10 @@ class DFPatternFunctor { virtual R VisitDFPattern_(const AltPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const AttrPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const CallPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const DataTypePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const DominatorPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const ExprPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; + virtual R VisitDFPattern_(const ShapePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TupleGetItemPatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; virtual R VisitDFPattern_(const TuplePatternNode* op, Args... args) DFPATTERN_FUNCTOR_DEFAULT; @@ -106,13 +108,15 @@ class DFPatternFunctor { RELAY_DFPATTERN_FUNCTOR_DISPATCH(AltPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(AttrPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(CallPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(DataTypePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(DominatorPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(ExprPatternNode); + RELAY_DFPATTERN_FUNCTOR_DISPATCH(ShapePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TupleGetItemPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TuplePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(TypePatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(VarPatternNode); - RELAY_DFPATTERN_FUNCTOR_DISPATCH(ConstantPatternNode); RELAY_DFPATTERN_FUNCTOR_DISPATCH(WildcardPatternNode); return vtable; } @@ -130,13 +134,15 @@ class DFPatternVisitor : public DFPatternFunctor { void VisitDFPattern_(const AltPatternNode* op) override; void VisitDFPattern_(const AttrPatternNode* op) override; void VisitDFPattern_(const CallPatternNode* op) override; + void VisitDFPattern_(const ConstantPatternNode* op) override; + void VisitDFPattern_(const DataTypePatternNode* op) override; void VisitDFPattern_(const DominatorPatternNode* op) override; void VisitDFPattern_(const ExprPatternNode* op) override; + void VisitDFPattern_(const ShapePatternNode* op) override; void VisitDFPattern_(const TupleGetItemPatternNode* op) override; void VisitDFPattern_(const TuplePatternNode* op) override; void VisitDFPattern_(const TypePatternNode* op) override; void VisitDFPattern_(const VarPatternNode* op) override; - void VisitDFPattern_(const ConstantPatternNode* op) override; void VisitDFPattern_(const WildcardPatternNode* op) override; protected: diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index e6a1a5e658e9..915842c8e5fa 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -97,6 +97,38 @@ def has_type(self, ttype: tvm.ir.type.Type): """ return has_type(ttype, self) + def has_dtype(self, dtype: str): + """ + Add a type constraint to this pattern + + Parameters + ---------- + dtype: str + The dtype to match + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting DataTypePattern + """ + return has_dtype(dtype, self) + + def has_shape(self, shape: List[tvm.ir.PrimExpr]): + """ + Add a type constraint to this pattern + + Parameters + ---------- + shape: List[tvm.ir.PrimExpr] + The shape to match + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting ShapePattern + """ + return has_shape(shape, self) + def match(self, expr: Expr) -> bool: """ Match this pattern to an expression @@ -293,18 +325,18 @@ def wildcard() -> "DFPattern": return WildcardPattern() -def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern": +def has_type(ttype: tvm.ir.type.Type, pattern: "DFPattern" = None) -> "DFPattern": """ Syntatic sugar for creating a TypePattern Parameters ---------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The pattern that needs type annotation - ttype: tvm.ir.type.Type The type to match + pattern: tvm.relay.dataflow_pattern.DFPattern + The pattern that needs type annotation + Returns ------- result: tvm.relay.dataflow_pattern.DFPattern @@ -315,6 +347,50 @@ def has_type(ttype, pattern: "DFPattern" = None) -> "DFPattern": return TypePattern(pattern, ttype) +def has_dtype(dtype: str, pattern: "DFPattern" = None) -> "DFPattern": + """ + Syntatic sugar for creating a DataTypePattern + + Parameters + ---------- + dtype: str + The dtype to match + + pattern: tvm.relay.dataflow_pattern.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting DataTypePattern + """ + if pattern is None: + pattern = wildcard() + return DataTypePattern(pattern, dtype) + + +def has_shape(shape: List[tvm.ir.PrimExpr], pattern: "DFPattern" = None) -> "DFPattern": + """ + Syntatic sugar for creating a ShapePattern + + Parameters + ---------- + shape: List[tvm.ir.PrimExpr] + The shape to match + + pattern: tvm.relay.dataflow_pattern.DFPattern + The pattern that needs type annotation + + Returns + ------- + result: tvm.relay.dataflow_pattern.DFPattern + The resulting ShapePattern + """ + if pattern is None: + pattern = wildcard() + return ShapePattern(pattern, shape) + + def has_attr(attrs, pattern=None) -> "DFPattern": """ Syntatic sugar for creating an AttrPattern @@ -514,7 +590,7 @@ def __init__(self): @register_df_node class TypePattern(DFPattern): - """Get index-th item from a TuplePattern. + """A pattern that matches another pattern with a certain type annotation. Parameters ---------- @@ -529,6 +605,40 @@ def __init__(self, pattern: "DFPattern", ttype: tvm.ir.type.Type): self.__init_handle_by_constructor__(ffi.TypePattern, pattern, ttype) +@register_df_node +class DataTypePattern(DFPattern): + """A pattern that matches another pattern with certain data type + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern that needs type annotation. + + dtype: str + The dtype to match. + """ + + def __init__(self, pattern: "DFPattern", dtype: str): + self.__init_handle_by_constructor__(ffi.DataTypePattern, pattern, dtype) + + +@register_df_node +class ShapePattern(DFPattern): + """A pattern that matches another pattern with a certain tensor shape + + Parameters + ---------- + pattern: tvm.relay.dataflow_pattern.DFPattern + The input pattern that needs type annotation. + + shape: List[tvm.ir.PrimExpr] + The shape to match. + """ + + def __init__(self, pattern: "DFPattern", shape: List[tvm.ir.PrimExpr]): + self.__init_handle_by_constructor__(ffi.ShapePattern, pattern, shape) + + @register_df_node class AttrPattern(DFPattern): """Get match an expression with a certain attributes. diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index e9543e354bd1..d01a1e707d03 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -50,13 +50,15 @@ class DFPatternMatcher : public DFPatternFunctortype, expr_type)) && VisitDFPattern(op->pattern, expr); } +bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + if (const TensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->shape, tensor_type->shape)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + +bool DFPatternMatcher::VisitDFPattern_(const DataTypePatternNode* op, const Expr& expr) { + auto expr_type = InferType(expr).as()->checked_type(); + if (const TensorTypeNode* tensor_type = expr_type.as()) { + return (StructuralEqual()(op->dtype, tensor_type->dtype)) && VisitDFPattern(op->pattern, expr); + } + return false; +} + bool DFPatternMatcher::VisitDFPattern_(const VarPatternNode* op, const Expr& expr) { bool matches = false; if (const auto* var_node = expr.as()) { diff --git a/src/relay/ir/dataflow_pattern.cc b/src/relay/ir/dataflow_pattern.cc index 280913164fd5..4664e5fc8168 100644 --- a/src/relay/ir/dataflow_pattern.cc +++ b/src/relay/ir/dataflow_pattern.cc @@ -187,6 +187,46 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "TypePattern(" << node->pattern << " has type " << node->type << ")"; }); +ShapePattern::ShapePattern(DFPattern pattern, Array shape) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->shape = std::move(shape); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ShapePatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.ShapePattern") + .set_body_typed([](DFPattern pattern, Array shape) { + return ShapePattern(pattern, shape); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "ShapePattern(" << node->pattern << " has shape " << node->shape << ")"; + }); + +DataTypePattern::DataTypePattern(DFPattern pattern, DataType dtype) { + ObjectPtr n = make_object(); + n->pattern = std::move(pattern); + n->dtype = std::move(dtype); + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(DataTypePatternNode); + +TVM_REGISTER_GLOBAL("relay.dataflow_pattern.DataTypePattern") + .set_body_typed([](DFPattern pattern, DataType dtype) { + return DataTypePattern(pattern, dtype); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "TypePattern(" << node->pattern << " has dtype " << node->dtype << ")"; + }); + AttrPattern::AttrPattern(DFPattern pattern, Attrs attrs) { ObjectPtr n = make_object(); n->pattern = std::move(pattern); diff --git a/src/relay/ir/dataflow_pattern_functor.cc b/src/relay/ir/dataflow_pattern_functor.cc index ee44bcb43c8b..7e9f828c8aa8 100644 --- a/src/relay/ir/dataflow_pattern_functor.cc +++ b/src/relay/ir/dataflow_pattern_functor.cc @@ -49,6 +49,11 @@ void DFPatternVisitor::VisitDFPattern_(const CallPatternNode* op) { VisitDFPattern(arg); } } + +void DFPatternVisitor::VisitDFPattern_(const DataTypePatternNode* op) { + VisitDFPattern(op->pattern); +} + void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) { VisitDFPattern(op->parent); VisitDFPattern(op->path); @@ -57,6 +62,8 @@ void DFPatternVisitor::VisitDFPattern_(const DominatorPatternNode* op) { void DFPatternVisitor::VisitDFPattern_(const ExprPatternNode* op) {} +void DFPatternVisitor::VisitDFPattern_(const ShapePatternNode* op) { VisitDFPattern(op->pattern); } + void DFPatternVisitor::VisitDFPattern_(const TupleGetItemPatternNode* op) { VisitDFPattern(op->tuple); } diff --git a/src/relay/ir/indexed_graph.cc b/src/relay/ir/indexed_graph.cc index 0d4b90da0293..456bf02a0611 100644 --- a/src/relay/ir/indexed_graph.cc +++ b/src/relay/ir/indexed_graph.cc @@ -246,6 +246,13 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { VisitDFPattern(arg, graph_.node_map_[GetRef(op)]); } } + + void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} + + void VisitDFPattern_(const DataTypePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + void VisitDFPattern_(const DominatorPatternNode* op, NodePtr parent) override { VisitDFPattern(op->parent, graph_.node_map_[GetRef(op)]); VisitDFPattern(op->path, graph_.node_map_[GetRef(op)]); @@ -254,6 +261,10 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const ExprPatternNode* op, NodePtr parent) override {} + void VisitDFPattern_(const ShapePatternNode* op, NodePtr parent) override { + VisitDFPattern(op->pattern, graph_.node_map_[GetRef(op)]); + } + void VisitDFPattern_(const TupleGetItemPatternNode* op, NodePtr parent) override { VisitDFPattern(op->tuple, graph_.node_map_[GetRef(op)]); } @@ -270,8 +281,6 @@ IndexedGraph CreateIndexedGraph(const DFPattern& pattern) { void VisitDFPattern_(const VarPatternNode* op, NodePtr parent) override {} - void VisitDFPattern_(const ConstantPatternNode* op, NodePtr parent) override {} - void VisitDFPattern_(const WildcardPatternNode* op, NodePtr parent) override {} }; return Annotator(Creator().CreateGraph(pattern)).Annotate(); diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 8d67db5c0c21..9727e53bab0a 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -94,6 +94,20 @@ def test_TypePattern(): assert ty_pat.type == ttype +def test_DataTypePattern(): + dtype = "float16" + pattern = has_dtype(dtype) + assert isinstance(pattern, DataTypePattern) + assert pattern.dtype == dtype + + +def test_ShapePattern(): + shape = [10, 10] + pattern = has_shape(shape) + assert isinstance(pattern, ShapePattern) + assert tvm.ir.structural_equal(pattern.shape, shape) + + def test_AttrPattern(): op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE}) assert isinstance(op, AttrPattern) @@ -265,6 +279,30 @@ def test_no_match_type(): assert not ty_pat.match(x) +def test_match_dtype(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_dtype("float32") + assert ty_pat.match(x) + + +def test_no_match_dtype(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_dtype("float32") + assert not ty_pat.match(x) + + +def test_match_shape(): + x = relay.var('x', shape=(10, 10), dtype="float32") + ty_pat = has_shape((10, 10)) + assert ty_pat.match(x) + + +def test_no_match_shape(): + x = relay.var('x', shape=(10, 10), dtype="int32") + ty_pat = has_shape((10, 5)) + assert not ty_pat.match(x) + + def test_match_op_attr(): op = is_op('add').has_attr({"TOpPattern": K_BROADCAST}) op_pat = op(wildcard(), wildcard()) @@ -500,6 +538,54 @@ def test_not_match_dominator(): assert not diamond.match(out) +def test_match_typed_dominator(): + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32") + reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10]) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Classic Diamond + inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32")) + weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32")) + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Check + assert diamond.match(out) + +def test_no_match_typed_dominator(): + # Classic Diamond + inp = relay.var('input',relay.TensorType((1, 3, 12, 12), "float32")) + weight = relay.var('weight', relay.TensorType((3, 3, 3, 3), "float32")) + conv2d = relay.op.nn.conv2d(inp, weight) + relu = relay.op.nn.relu(conv2d) + relu = relay.op.nn.relu(relu) + leaky_relu = relay.op.nn.leaky_relu(conv2d, alpha=0) + out = relu + leaky_relu + + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float32") + reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 1, 10, 10]) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Check + assert not diamond.match(out) + + # Pattern + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()).has_dtype("float16") + reduction = is_op('add')(wildcard(), wildcard()).has_shape([1, 3, 10, 10]) + diamond = dominates(is_conv2d, is_unary_elemwise, reduction) + + # Check + assert not diamond.match(out) + + def test_rewrite(): x = relay.var('x') y = relay.var('y') @@ -1222,6 +1308,8 @@ def test_partition_constant_embedding(): test_TupleGetItemPattern() test_AltPattern() test_TypePattern() + test_DataTypePattern() + test_ShapePattern() test_AttrPattern() test_match_op() test_no_match_op() @@ -1237,6 +1325,10 @@ def test_partition_constant_embedding(): test_no_match_tuple() test_match_type() test_no_match_type() + test_match_dtype() + test_no_match_dtype() + test_match_shape() + test_no_match_shape() test_match_op_attr() test_no_match_op_attr() test_match_func_attr()