diff --git a/docs/langref/relay_pattern.rst b/docs/langref/relay_pattern.rst index 7f81b9b48299..f56d49681cd0 100644 --- a/docs/langref/relay_pattern.rst +++ b/docs/langref/relay_pattern.rst @@ -41,7 +41,7 @@ There are quite a few properties that are worth matching of operators below we e The next example is a dense operation with any operator that is marked element-wise:: def test_no_match_attr(): - op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE) + op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE}) op_pat = op(wildcard(), wildcard()) x = relay.var('x') y = relay.var('y') @@ -97,7 +97,7 @@ The high level design is to introduce a language of patterns for now we propose | * | pattern(pattern1, ... patternN) | has_type(pattern, type) - | has_attr(pattern, attr, attr_value) + | has_attr(pattern, attrs) | is_input(name) | pattern1 `|` pattern2 | dominates(parent_pattern, path_pattern, child_pattern) diff --git a/python/tvm/relay/dataflow_pattern/__init__.py b/python/tvm/relay/dataflow_pattern/__init__.py index ca324bc444ec..bb1e5315bb96 100644 --- a/python/tvm/relay/dataflow_pattern/__init__.py +++ b/python/tvm/relay/dataflow_pattern/__init__.py @@ -61,23 +61,20 @@ def __mul__(self, other): def __truediv__(self, other): return is_op("divide")(self, other) - def has_attr(self, attr_name: str, attr_value): + def has_attr(self, attrs): """ Add an attribute constraint to this pattern Parameters ---------- - attr_name: str - The name of the attribute to match - attr_value: Any - The value of the attribute to match + attrs: Dict[str, Object] Returns ------- result: tvm.relay.dataflow_pattern.DFPattern The resulting AttrPattern """ - attrs = make_node("DictAttrs", **{attr_name: attr_value}) + attrs = make_node("DictAttrs", **attrs) return AttrPattern(self, attrs) def has_type(self, ttype): @@ -235,18 +232,18 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern: return TypePattern(pattern, ttype) -def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern: +def has_attr(attrs, pattern=None) -> DFPattern: """ Syntatic sugar for creating an AttrPattern Parameters ---------- - pattern: tvm.relay.dataflow_pattern.DFPattern - The input pattern. - - attrs: tvm.Attrs + attrs: Dict[str, Object] The attributes to match + pattern: Optional[tvm.relay.dataflow_pattern.DFPattern] + The input pattern. + Returns ------- result: tvm.relay.dataflow_pattern.DFPattern @@ -254,7 +251,7 @@ def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern: """ if pattern is None: pattern = wildcard() - return pattern.has_attr(attr_name, attr_value) + return pattern.has_attr(attrs) def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern: diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 7c70f324ebf3..eedd94bd8b17 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -101,39 +101,73 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr); } +bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) { + switch (rhs.type_code()) { + case kDLInt: + if (auto* val = lhs.as()) { + return val->value == rhs.operator int64_t(); + } + break; + case kDLFloat: + if (auto* val = lhs.as()) { + return val->value == rhs.operator double(); + } + break; + case kTVMStr: + std::cout << lhs << std::endl; + if (auto* val = lhs.as()) { + return val->value == rhs.operator std::string(); + } else if (auto* val = lhs.as()) { + return val->data == rhs.operator std::string(); + } + break; + default: + CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code(); + } + return false; +} + bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) { bool matches = false; + auto attributes = attr_pattern->attrs.as()->dict; if (const auto* op_node = expr.as()) { Op op = GetRef(op_node); - auto attributes = attr_pattern->attrs.as()->dict; for (auto kv : attributes) { auto attr_name = kv.first; auto attr_value = kv.second; auto op_map = Op::GetAttrMap(attr_name); if (op_map.count(op)) { - switch (op_map[op].type_code()) { - case kDLInt: - if (auto* val = kv.second.as()) { - matches = val->value == op_map[op].operator int64_t(); - } - break; - case kDLFloat: - if (auto* val = kv.second.as()) { - matches = val->value == op_map[op].operator double(); - } - break; - case kTVMStr: - if (auto* val = kv.second.as()) { - matches = val->value == op_map[op].operator std::string(); - } - break; - default: - CHECK(false) << "Unsupported type in Type Pattern Node"; - } + matches = MatchRetValue(attr_value, op_map[op]); + } + } + } else if (auto* op = expr.as()) { + matches = true; + // TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this + // and replace the whole thing with a Visitor-based approach + ReflectionVTable* reflection = ReflectionVTable::Global(); + auto attrs_node = const_cast(op->attrs.get()); + auto attr_names = reflection->ListAttrNames(attrs_node); + for (auto kv : attributes) { + if (matches && + std::find(attr_names.begin(), attr_names.end(), kv.first) != attr_names.end()) { + matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, kv.first)); + } else { + matches = false; + break; + } + } + } else if (auto* op = expr.as()) { + matches = true; + for (auto kv : attributes) { + if (matches && op->attrs->dict.count(kv.first)) { + matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]); + } else { + matches = false; + break; } } } - return matches; + return matches && VisitDFPattern(attr_pattern->pattern, expr); } Array reverse(const Array& args) { diff --git a/tests/python/relay/test_dataflow_pattern.py b/tests/python/relay/test_dataflow_pattern.py index 41b3d6d997e9..a92ef71176b0 100644 --- a/tests/python/relay/test_dataflow_pattern.py +++ b/tests/python/relay/test_dataflow_pattern.py @@ -77,7 +77,7 @@ def test_TypePattern(): assert ty_pat.type == ttype def test_AttrPattern(): - op = is_op('add').has_attr("TOpPattern", K_ELEMWISE) + op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE}) assert isinstance(op, AttrPattern) assert op.attrs["TOpPattern"] == K_ELEMWISE @@ -225,19 +225,57 @@ def test_no_match_type(): ty_pat = has_type(relay.TensorType((10, 10), "float32")) assert not ty_pat.match(x) -def test_match_attr(): - op = is_op('add').has_attr("TOpPattern", K_BROADCAST) +def test_match_op_attr(): + op = is_op('add').has_attr({"TOpPattern": K_BROADCAST}) op_pat = op(wildcard(), wildcard()) x = relay.var('x') y = relay.var('y') assert op_pat.match(x + y) -def test_no_match_attr(): - op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE) +def test_no_match_op_attr(): + op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE}) op_pat = op(wildcard(), wildcard()) x = relay.var('x') y = relay.var('y') assert not op_pat.match(relay.op.nn.dense(x, y)) + op = is_op('add').has_attr({"TOpPattern": K_BROADCAST}) + op_pat = op(wildcard(), wildcard()) + x = relay.var('x') + y = relay.var('y') + assert not op_pat.match(x - y) + +def test_match_func_attr(): + pattern = wildcard().has_attr({"Composite": "add"}) + x = relay.var('x') + y = relay.var('y') + f = relay.Function([x, y], x + y).with_attr("Composite", "add") + assert pattern.match(f) + +def test_no_match_func_attr(): + pattern = wildcard().has_attr({"Composite": "add"}) + x = relay.var('x') + y = relay.var('y') + + f = relay.Function([x, y], x + y).with_attr("RandomTest", "add") + assert not pattern.match(f) + f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias") + assert not pattern.match(f) + +def test_match_call_attr(): + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"}) + x = relay.var('x') + y = relay.var('y') + assert is_conv2d.match(relay.op.nn.conv2d(x, y)) + +def test_no_match_call_attr(): + x = relay.var('x') + y = relay.var('y') + + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"}) + assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) + + is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"}) + assert not is_conv2d.match(relay.op.nn.conv2d(x, y)) def test_match_diamond(): # Pattern @@ -301,7 +339,7 @@ def test_match_fake_diamond(): def test_match_dominator(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -344,7 +382,7 @@ def test_match_dominator(): # Fuzzy path/nested Diamond is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -361,7 +399,7 @@ def test_match_dominator(): def test_not_match_dominator(): is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -578,7 +616,7 @@ def __init__(self): self.weight = wildcard() is_conv2d = is_op('nn.conv2d')(self.inp, self.weight) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard()) reduction = is_op('add')(wildcard(), wildcard()) self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -705,7 +743,7 @@ def test_algebraic_simplify(): def test_partition_dominator(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction) @@ -730,7 +768,7 @@ def generate_diamond(inp, weight): def test_quadruple_partition_dominator(): # Pattern is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()) - is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard()) + is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard()) reduction = is_op('add')(wildcard(), wildcard()) diamond = dominates(is_conv2d, is_unary_elemwise, reduction)