Skip to content

Commit

Permalink
[Relay][Compilation] replace relay.build_module with C++ BuildModule (a…
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and wweic committed Jun 27, 2019
1 parent 8671bb3 commit 87edf6f
Show file tree
Hide file tree
Showing 13 changed files with 534 additions and 541 deletions.
2 changes: 1 addition & 1 deletion python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from . import module
from . import adt
from . import ir_pass
from .build_module import build, build_config, create_executor, optimize
from .build_module import build, build_config, create_executor
from . import prelude
from . import parser
from . import debug
Expand Down
21 changes: 21 additions & 0 deletions python/tvm/relay/_build_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable
"""The interface for building Relay functions exposed from C++."""
from tvm._ffi.function import _init_api

_init_api("relay.build_module", __name__)
18 changes: 6 additions & 12 deletions python/tvm/relay/backend/graph_runtime_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@
from __future__ import absolute_import

from tvm.ndarray import empty
from tvm._ffi.function import _init_api

from tvm.relay import build_module
from tvm import target as _target

_init_api("tvm.relay.build_module")
from tvm import expr as _expr

class GraphRuntimeCodegen(object):
"""The compiler from Relay to the TVM runtime system."""
Expand All @@ -57,17 +54,14 @@ def __init__(self, mod, target):
self._setup(mod, target)

def _setup(self, mod, target):
tgts = []
tgts = {}
if isinstance(target, dict):
for kv in target.items():
tgts.append(kv[0])
if isinstance(kv[1], (str, _target.Target)):
tgts.append(str(kv[1]))
else:
for dev, tgt in target.items():
if not isinstance(tgt, (str, _target.Target)):
raise Exception("Unknown target type")
tgts[dev] = _target.create(tgt)
elif isinstance(target, (str, _target.Target)):
tgts.append("0")
tgts.append(str(target))
tgts[_expr.IntImm("int32", 0)] = _target.create(target)
self._init(mod, tgts)

def codegen(self, func):
Expand Down
Loading

0 comments on commit 87edf6f

Please sign in to comment.