Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend AttrPattern to support CallNode and FunctionNode attributes #5637

Merged
merged 3 commits into from
May 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/langref/relay_pattern.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ There are quite a few properties that are worth matching of operators below we e
The next example is a dense operation with any operator that is marked element-wise::

def test_no_match_attr():
op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
op_pat = op(wildcard(), wildcard())
x = relay.var('x')
y = relay.var('y')
Expand Down Expand Up @@ -97,7 +97,7 @@ The high level design is to introduce a language of patterns for now we propose
| *
| pattern(pattern1, ... patternN)
| has_type(pattern, type)
| has_attr(pattern, attr, attr_value)
| has_attr(pattern, attrs)
| is_input(name)
| pattern1 `|` pattern2
| dominates(parent_pattern, path_pattern, child_pattern)
Expand Down
21 changes: 9 additions & 12 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,23 +61,20 @@ def __mul__(self, other):
def __truediv__(self, other):
return is_op("divide")(self, other)

def has_attr(self, attr_name: str, attr_value):
def has_attr(self, attrs):
"""
Add an attribute constraint to this pattern

Parameters
----------
attr_name: str
The name of the attribute to match
attr_value: Any
The value of the attribute to match
attrs: Dict[str, Object]

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting AttrPattern
"""
attrs = make_node("DictAttrs", **{attr_name: attr_value})
attrs = make_node("DictAttrs", **attrs)
return AttrPattern(self, attrs)

def has_type(self, ttype):
Expand Down Expand Up @@ -235,26 +232,26 @@ def has_type(ttype, pattern: DFPattern = None) -> DFPattern:
return TypePattern(pattern, ttype)


def has_attr(attr_name: DFPattern, attr_value, pattern=None) -> DFPattern:
def has_attr(attrs, pattern=None) -> DFPattern:
"""
Syntatic sugar for creating an AttrPattern

Parameters
----------
pattern: tvm.relay.dataflow_pattern.DFPattern
The input pattern.

attrs: tvm.Attrs
attrs: Dict[str, Object]
The attributes to match

pattern: Optional[tvm.relay.dataflow_pattern.DFPattern]
The input pattern.

Returns
-------
result: tvm.relay.dataflow_pattern.DFPattern
The resulting AttrPattern
"""
if pattern is None:
pattern = wildcard()
return pattern.has_attr(attr_name, attr_value)
return pattern.has_attr(attrs)


def dominates(parent: DFPattern, path: DFPattern, child: DFPattern) -> DFPattern:
Expand Down
76 changes: 55 additions & 21 deletions src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,39 +101,73 @@ bool DFPatternMatcher::VisitDFPattern_(const AltPatternNode* op, const Expr& exp
return VisitDFPattern(op->left, expr) || VisitDFPattern(op->right, expr);
}

bool MatchRetValue(const ObjectRef& lhs, const TVMRetValue& rhs) {
switch (rhs.type_code()) {
case kDLInt:
if (auto* val = lhs.as<IntImmNode>()) {
return val->value == rhs.operator int64_t();
}
break;
case kDLFloat:
if (auto* val = lhs.as<FloatImmNode>()) {
return val->value == rhs.operator double();
}
break;
case kTVMStr:
std::cout << lhs << std::endl;
if (auto* val = lhs.as<tir::StringImmNode>()) {
return val->value == rhs.operator std::string();
} else if (auto* val = lhs.as<StringObj>()) {
return val->data == rhs.operator std::string();
}
break;
default:
CHECK(false) << "Unsupported type code in Pattern Node " << rhs.type_code();
}
return false;
}

bool DFPatternMatcher::VisitDFPattern_(const AttrPatternNode* attr_pattern, const Expr& expr) {
bool matches = false;
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
if (const auto* op_node = expr.as<OpNode>()) {
Op op = GetRef<Op>(op_node);
auto attributes = attr_pattern->attrs.as<DictAttrsNode>()->dict;
for (auto kv : attributes) {
auto attr_name = kv.first;
auto attr_value = kv.second;
auto op_map = Op::GetAttrMap<TVMRetValue>(attr_name);
if (op_map.count(op)) {
switch (op_map[op].type_code()) {
case kDLInt:
if (auto* val = kv.second.as<IntImmNode>()) {
matches = val->value == op_map[op].operator int64_t();
}
break;
case kDLFloat:
if (auto* val = kv.second.as<FloatImmNode>()) {
matches = val->value == op_map[op].operator double();
}
break;
case kTVMStr:
if (auto* val = kv.second.as<tir::StringImmNode>()) {
matches = val->value == op_map[op].operator std::string();
}
break;
default:
CHECK(false) << "Unsupported type in Type Pattern Node";
}
matches = MatchRetValue(attr_value, op_map[op]);
}
}
} else if (auto* op = expr.as<CallNode>()) {
matches = true;
// TODO(mbrookhart): When OpNode Attrs move from TVMRetValue to the Object system, remove this
// and replace the whole thing with a Visitor-based approach
ReflectionVTable* reflection = ReflectionVTable::Global();
auto attrs_node = const_cast<Object*>(op->attrs.get());
auto attr_names = reflection->ListAttrNames(attrs_node);
for (auto kv : attributes) {
if (matches &&
std::find(attr_names.begin(), attr_names.end(), kv.first) != attr_names.end()) {
matches &= MatchRetValue(kv.second, reflection->GetAttr(attrs_node, kv.first));
} else {
matches = false;
break;
}
}
} else if (auto* op = expr.as<FunctionNode>()) {
matches = true;
for (auto kv : attributes) {
if (matches && op->attrs->dict.count(kv.first)) {
matches &= StructuralEqual()(kv.second, op->attrs->dict[kv.first]);
} else {
matches = false;
mbrookhart marked this conversation as resolved.
Show resolved Hide resolved
break;
}
}
}
return matches;
return matches && VisitDFPattern(attr_pattern->pattern, expr);
}

Array<DFPattern> reverse(const Array<DFPattern>& args) {
Expand Down
60 changes: 49 additions & 11 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def test_TypePattern():
assert ty_pat.type == ttype

def test_AttrPattern():
op = is_op('add').has_attr("TOpPattern", K_ELEMWISE)
op = is_op('add').has_attr({"TOpPattern": K_ELEMWISE})
assert isinstance(op, AttrPattern)
assert op.attrs["TOpPattern"] == K_ELEMWISE

Expand Down Expand Up @@ -225,19 +225,57 @@ def test_no_match_type():
ty_pat = has_type(relay.TensorType((10, 10), "float32"))
assert not ty_pat.match(x)

def test_match_attr():
op = is_op('add').has_attr("TOpPattern", K_BROADCAST)
def test_match_op_attr():
op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
op_pat = op(wildcard(), wildcard())
x = relay.var('x')
y = relay.var('y')
assert op_pat.match(x + y)

def test_no_match_attr():
op = is_op('nn.dense').has_attr("TOpPattern", K_ELEMWISE)
def test_no_match_op_attr():
op = is_op('nn.dense').has_attr({"TOpPattern": K_ELEMWISE})
op_pat = op(wildcard(), wildcard())
x = relay.var('x')
y = relay.var('y')
assert not op_pat.match(relay.op.nn.dense(x, y))
op = is_op('add').has_attr({"TOpPattern": K_BROADCAST})
op_pat = op(wildcard(), wildcard())
x = relay.var('x')
y = relay.var('y')
assert not op_pat.match(x - y)

def test_match_func_attr():
pattern = wildcard().has_attr({"Composite": "add"})
x = relay.var('x')
y = relay.var('y')
f = relay.Function([x, y], x + y).with_attr("Composite", "add")
assert pattern.match(f)

def test_no_match_func_attr():
pattern = wildcard().has_attr({"Composite": "add"})
x = relay.var('x')
y = relay.var('y')

f = relay.Function([x, y], x + y).with_attr("RandomTest", "add")
assert not pattern.match(f)
f = relay.Function([x, y], x + y).with_attr("Composite", "conv_bias")
assert not pattern.match(f)

def test_match_call_attr():
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NCHW"})
x = relay.var('x')
y = relay.var('y')
assert is_conv2d.match(relay.op.nn.conv2d(x, y))

def test_no_match_call_attr():
x = relay.var('x')
y = relay.var('y')

is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"data_layout": "NHWC"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard()).has_attr({"RandomAttr": "NCHW"})
assert not is_conv2d.match(relay.op.nn.conv2d(x, y))

def test_match_diamond():
# Pattern
Expand Down Expand Up @@ -301,7 +339,7 @@ def test_match_fake_diamond():
def test_match_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

Expand Down Expand Up @@ -344,7 +382,7 @@ def test_match_dominator():

# Fuzzy path/nested Diamond
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

Expand All @@ -361,7 +399,7 @@ def test_match_dominator():

def test_not_match_dominator():
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

Expand Down Expand Up @@ -578,7 +616,7 @@ def __init__(self):
self.weight = wildcard()

is_conv2d = is_op('nn.conv2d')(self.inp, self.weight)
is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
self.pattern = dominates(is_conv2d, is_unary_elemwise, reduction)

Expand Down Expand Up @@ -705,7 +743,7 @@ def test_algebraic_simplify():
def test_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

Expand All @@ -730,7 +768,7 @@ def generate_diamond(inp, weight):
def test_quadruple_partition_dominator():
# Pattern
is_conv2d = is_op('nn.conv2d')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr("TOpPattern", K_ELEMWISE))(wildcard()) | is_op('add')(wildcard(), wildcard())
is_unary_elemwise = (wildcard().has_attr({"TOpPattern": K_ELEMWISE}))(wildcard()) | is_op('add')(wildcard(), wildcard())
reduction = is_op('add')(wildcard(), wildcard())
diamond = dominates(is_conv2d, is_unary_elemwise, reduction)

Expand Down