Skip to content

Commit

Permalink
Add redirecting operation to dataflow pattern graph
Browse files Browse the repository at this point in the history
  • Loading branch information
kfeng123 committed Jul 31, 2023
1 parent 0556653 commit 9875e8e
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 3 deletions.
6 changes: 6 additions & 0 deletions include/tvm/relay/dataflow_pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,10 @@ class WildcardPatternNode : public DFPatternNode {
public:
void VisitAttrs(tvm::AttrVisitor* v) {}

/*! \brief If the wildcard is redirected, then pattern is not nullptr, and the wildcard
* redirects to the pattern. */
Optional<DFPattern> pattern{nullptr};

static constexpr const char* _type_key = "relay.dataflow_pattern.WildcardPattern";
TVM_DECLARE_FINAL_OBJECT_INFO(WildcardPatternNode, DFPatternNode);
};
Expand All @@ -372,6 +376,8 @@ class WildcardPatternNode : public DFPatternNode {
class WildcardPattern : public DFPattern {
public:
TVM_DEFINE_OBJECT_REF_METHODS(WildcardPattern, DFPattern, WildcardPatternNode);

void redirect_to(DFPattern pat) const;
};

class TypePattern;
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/relay/dataflow_pattern/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,19 @@ class WildcardPattern(DFPattern):
def __init__(self):
self.__init_handle_by_constructor__(ffi.WildcardPattern)

def redirect_to(
self,
pat: "DFPattern",
):
"""Redirect the WildcardPattern to another pattern
Parameters
----------
pat: relay.dataflow_pattern.DFPattern
The pattern that wildcard is redirected to.
"""
ffi.WildcardPattern_redirect_to(self, pat);


@register_df_node
class TypePattern(DFPattern):
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/dataflow_matcher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,11 @@ bool DFPatternMatcher::VisitDFPattern_(const ConstantPatternNode* op, const Expr
}

bool DFPatternMatcher::VisitDFPattern_(const WildcardPatternNode* op, const Expr& expr) {
return true;
if (op->pattern) {
return VisitDFPattern(op->pattern.value(), expr);
} else {
return true;
}
}

bool MatchPattern(DFPattern pattern, Expr expr) {
Expand Down
10 changes: 10 additions & 0 deletions src/relay/ir/dataflow_pattern.cc
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,18 @@ TVM_STATIC_IR_FUNCTOR(DFPatternPrinter, vtable)
<< ")";
});

void WildcardPattern::redirect_to(DFPattern pat) const {
WildcardPatternNode* ptr = static_cast<WildcardPatternNode*>(get_mutable());
ptr->pattern = pat;
}

TVM_REGISTER_NODE_TYPE(WildcardPatternNode);

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern_redirect_to")
.set_body_typed([](WildcardPattern wildcard, DFPattern pat) {
return wildcard.redirect_to(pat);
});

TVM_REGISTER_GLOBAL("relay.dataflow_pattern.WildcardPattern").set_body_typed([]() {
auto w = WildcardPattern(make_object<WildcardPatternNode>());
return w;
Expand Down
6 changes: 5 additions & 1 deletion src/relay/ir/dataflow_pattern_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,11 @@ void DFPatternVisitor::VisitDFPattern_(const VarPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const ConstantPatternNode* op) {}

void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {}
void DFPatternVisitor::VisitDFPattern_(const WildcardPatternNode* op) {
if (op->pattern) {
VisitDFPattern(op->pattern.value());
}
}

} // namespace relay
} // namespace tvm
7 changes: 6 additions & 1 deletion src/relay/ir/indexed_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,12 @@ std::unique_ptr<IndexedGraph<DFPattern>> CreateIndexedGraph(const DFPattern& pat

void VisitDFPattern_(const VarPatternNode* op) override {}

void VisitDFPattern_(const WildcardPatternNode* op) override {}
void VisitDFPattern_(const WildcardPatternNode* op) override {
if (op->pattern) {
auto node = graph_->item_to_node(GetRef<WildcardPattern>(op));
AddOutput(op->pattern.value(), node);
}
}

std::unique_ptr<IndexedGraph<DFPattern>> graph_;
};
Expand Down
56 changes: 56 additions & 0 deletions tests/python/relay/test_dataflow_pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1964,5 +1964,61 @@ def test_partition_parallel_branch_with_same_input():
assert tvm.ir.structural_equal(partitioned, reference)


def test_rewrite_with_pattern_recursion():
data = relay.var("data", relay.TensorType((2, 8), "float32"))
dense_weight = relay.const(np.zeros((4, 8)))
feat = relay.nn.dense(data, dense_weight)
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
feat = relay.cast(feat, "float32")
oup = relay.cast(feat, "float32")

expected = relay.nn.relu(oup)

class TheRewrite(DFPatternCallback):
def __init__(self, pattern):
super(TheRewrite, self).__init__(rewrite_once=True)
self.pattern = pattern

def callback(self, pre, post, node_map):
return relay.nn.relu(post)

def test_reset_call_args():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
the_pattern = is_op("cast")(wildcard_redirect)
the_pattern2 = the_pattern | dense_pattern
wildcard_redirect.redirect_to(the_pattern2)

actual = rewrite(TheRewrite(the_pattern), oup)
tvm.ir.assert_structural_equal(actual, expected)

def test_reset_alt_left():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
or_pattern = wildcard_redirect | dense_pattern
the_pattern = is_op("cast")(or_pattern)
wildcard_redirect.redirect_to(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
tvm.ir.assert_structural_equal(actual, expected)

def test_reset_alt_right():
dense_pattern = is_op("nn.dense")(wildcard(), wildcard())
wildcard_redirect = wildcard()
or_pattern = dense_pattern | wildcard_redirect
the_pattern = is_op("cast")(or_pattern)
wildcard_redirect.redirect_to(the_pattern)

actual = rewrite(TheRewrite(the_pattern), oup)
tvm.ir.assert_structural_equal(actual, expected)

test_reset_call_args()
test_reset_alt_left()
test_reset_alt_right()


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 9875e8e

Please sign in to comment.