Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[relay][heterogeneous] annotate using visitor #3261

Merged
merged 2 commits into from
Jun 1, 2019
Merged

Conversation

zhiics
Copy link
Member

@zhiics zhiics commented May 30, 2019

Recently, there are some discussions about how to leverage an additional pass to annotate a Relay program for heterogeneous compilation. This PR adds a unit test to show it and it slightly changes the manual way for annotation. Now users need to connect nodes but don't need to pass the annotation exprs to the backend.

This following example has been executed locally:

import tvm
from tvm import relay
import tvm.relay.testing
from tvm.relay.expr_functor import ExprMutator

class ScheduleConv2d(ExprMutator):
    def __init__(self, device):
        self.device = device
        super().__init__()

    def visit_call(self, expr):
        visit = super().visit_call(expr)
        if expr.op == tvm.relay.op.get("nn.conv2d"):
            return relay.annotation.on_device(visit, self.device)
        else:
            return visit

def schedule_conv2d_on_gpu(expr):
    sched = ScheduleConv2d(tvm.gpu(0))
    return sched.visit(expr)

resnet, params = relay.testing.resnet.get_workload()
resnet = schedule_conv2d_on_gpu(resnet)
resnet = relay.ir_pass.infer_type(resnet)

target = {"gpu": "cuda", "cpu": "llvm"}
with relay.build_config(opt_level=3, fallback_device=tvm.cpu()):
    json, mod, params = relay.build(resnet, target=target)
    print(json)

cc' @jroesch this should be what you want.
@imorinaga @anijain2305 @jwfromm

@@ -176,7 +176,11 @@ class RewriteAnnotation : public ExprMutator {
}

Expr VisitExpr_(const CallNode* call_node) final {
if (IsOnDeviceNode(call_node) || IsDeviceCopyNode(call_node)) {
if (IsOnDeviceNode(call_node)) {
Copy link
Member

@jroesch jroesch May 31, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the only meaningful change right? this will just search deeper in the tree to rewrite until we hit another device annotation? just want to make sure I 100% understand this time.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, for example, on_device connects a and b expression, a -> on_device -> b, we have inserted device_copy when we visit b. Anytime when we visit on_device, if means we either have inserted device_copy op, or it is not necessary, therefore, we can safely delete on_device by returning 'a.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay cool, LGTM

@jroesch jroesch merged commit 887255a into apache:master Jun 1, 2019
@zhiics zhiics deleted the hetero branch June 4, 2019 18:07
wweic pushed a commit to wweic/tvm that referenced this pull request Jun 26, 2019
* annotate using visitor

* retrigger CI
wweic pushed a commit to neo-ai/tvm that referenced this pull request Jun 27, 2019
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants