Commit

asdf
Bing Xu committed May 7, 2019
1 parent 6ea243a commit 3d5798e
Showing 7 changed files with 468 additions and 266 deletions.
19 changes: 14 additions & 5 deletions include/tvm/build_module.h
@@ -344,23 +344,32 @@ TVM_DLL Array<LoweredFunc> lower(Schedule sch,
const std::string& name,
const std::unordered_map<Tensor, Buffer>& binds,
const BuildConfig& config);
/*!
* \brief Split host/device functions and run the necessary passes before build.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \return An Array<Array<LoweredFunc>> with two elements: the first is the array of
*         host functions, the second is the array of device functions.
*/
TVM_DLL Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config);

/*!
* \brief Build a device and host module for a specific target from an array of lowered functions.
* \param funcs The functions to be built.
* \param target The target device to build for.
* \param target_host The target for building host code. To use the default, pass Target()
* \param config The build configuration.
* \param fhost_ret (optional) The returned host functions.
* \param devmod_ret (optional) The returned device modules.
* \return The built module.
*/
TVM_DLL runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config,
Array<LoweredFunc>* fhost_ret = nullptr,
std::vector<runtime::Module>* devmod_ret = nullptr);
const BuildConfig& config);

class GenericFuncNode;
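
A minimal caller-side sketch (not part of this commit) of how the new declaration pairs with the existing ones: lower a schedule, split the result into host and device functions with split_dev_host_funcs, then build with the simplified four-argument build(). The function name ExampleBuild, the kernel name "example_kernel", and the use of the target::cuda()/target::llvm() helpers are placeholders and assumptions, not code from this patch.

#include <string>
#include <unordered_map>
#include <tvm/tvm.h>
#include <tvm/build_module.h>

tvm::runtime::Module ExampleBuild(tvm::Schedule sch,
                                  const tvm::Array<tvm::Tensor>& args) {
  tvm::BuildConfig config = tvm::build_config();
  std::unordered_map<tvm::Tensor, tvm::Buffer> binds;
  // Lower the schedule into LoweredFuncs first.
  tvm::Array<tvm::LoweredFunc> funcs =
      tvm::lower(sch, args, "example_kernel", binds, config);
  tvm::Target target = tvm::target::cuda();
  tvm::Target target_host = tvm::target::llvm();
  // New in this commit: host and device functions can be inspected separately.
  auto split = tvm::split_dev_host_funcs(funcs, target, target_host, config);
  LOG(INFO) << split[0].size() << " host / " << split[1].size() << " device functions";
  // build() itself no longer takes the fhost_ret/devmod_ret out-parameters.
  return tvm::build(funcs, target, target_host, config);
}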

33 changes: 18 additions & 15 deletions src/codegen/build_module.cc
@@ -422,12 +422,10 @@ Array<LoweredFunc> lower(Schedule sch,
return Array<LoweredFunc>({ ir::MakeAPI(stmt, name, out_arg_list, 0, config->restricted_func) });
}

runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config,
Array<LoweredFunc>* fhost_ret,
std::vector<runtime::Module>* devmod_ret) {
Array<Array<LoweredFunc> > split_dev_host_funcs(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
std::unordered_set<std::string> all_names;
for (const auto &x : funcs) {
CHECK(all_names.count(x->name) == 0) << "Duplicate function name " << x->name;
@@ -466,12 +464,6 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
}
}

if (fhost_ret != nullptr) {
for (auto f : fhost) {
fhost_ret->push_back(f);
}
}

auto keys = target->keys();
bool target_is_gpu =
std::find(keys.begin(), keys.end(), "gpu") != keys.end();
@@ -500,14 +492,25 @@ runtime::Module build(const Array<LoweredFunc>& funcs,
func = ir::CombineContextCall(func);
fhost.Set(i, func);
}
Array<Array<LoweredFunc> > ret;
ret.push_back(fhost);
ret.push_back(fdevice);
return ret;
}

runtime::Module build(const Array<LoweredFunc>& funcs,
const Target& target,
const Target& target_host,
const BuildConfig& config) {
auto target_host_val = target_host.defined() ? target_host : DefaultTargetHost(target);
auto host_dev_funcs = split_dev_host_funcs(funcs, target, target_host, config);
auto& fhost = host_dev_funcs[0];
auto& fdevice = host_dev_funcs[1];

auto mhost = codegen::Build(fhost, target_host_val->str());

if (fdevice.size() > 0) {
auto mdev = codegen::Build(fdevice, target->str());
if (devmod_ret != nullptr) {
devmod_ret->push_back(mdev);
}
mhost.Import(mdev);
}
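
The dropped fhost_ret/devmod_ret out-parameters are not lost functionality: a caller that still needs the host functions and per-device modules can compose split_dev_host_funcs with codegen::Build itself. A hedged sketch under two assumptions: codegen::Build(Array<LoweredFunc>, std::string) is available from tvm/codegen.h as it is used above, and target_host is already defined (build() itself falls back to DefaultTargetHost when it is not). The helper name BuildAndCollect is hypothetical.

#include <vector>
#include <tvm/tvm.h>
#include <tvm/build_module.h>
#include <tvm/codegen.h>

tvm::runtime::Module BuildAndCollect(const tvm::Array<tvm::LoweredFunc>& funcs,
                                     const tvm::Target& target,
                                     const tvm::Target& target_host,
                                     const tvm::BuildConfig& config,
                                     tvm::Array<tvm::LoweredFunc>* fhost_out,
                                     std::vector<tvm::runtime::Module>* devmod_out) {
  auto split = tvm::split_dev_host_funcs(funcs, target, target_host, config);
  tvm::Array<tvm::LoweredFunc> fhost = split[0];   // what *fhost_ret used to receive
  tvm::Array<tvm::LoweredFunc> fdevice = split[1];
  if (fhost_out != nullptr) {
    *fhost_out = fhost;
  }
  tvm::runtime::Module mhost = tvm::codegen::Build(fhost, target_host->str());
  if (fdevice.size() > 0) {
    tvm::runtime::Module mdev = tvm::codegen::Build(fdevice, target->str());
    if (devmod_out != nullptr) {
      devmod_out->push_back(mdev);                 // what *devmod_ret used to receive
    }
    mhost.Import(mdev);
  }
  return mhost;
}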

448 changes: 249 additions & 199 deletions src/relay/backend/build_module.cc

Large diffs are not rendered by default.

4 changes: 3 additions & 1 deletion src/relay/backend/compile_engine.cc
@@ -369,7 +369,9 @@ class CompileEngineImpl : public CompileEngineNode {
cache_node->funcs = (*f)(
spair.first, all_args, cache_node->func_name, key->source_func);
} else {
LOG(FATAL) << "relay.backend.lower is not registred";
tvm::BuildConfig bcfg = tvm::build_config();
std::unordered_map<Tensor, Buffer> binds;
cache_node->funcs = tvm::lower(spair.first, all_args, cache_node->func_name, binds, bcfg);
}
value->cached_func = CachedFunc(cache_node);
return value;
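
For contrast with the new fallback, a hedged sketch of how the preferred path is supplied: any packed function registered as "relay.backend.lower" with the argument order used at the call site above (schedule, flattened args, function name, source relay::Function) and returning an Array<LoweredFunc> is picked up before the C++ fallback runs. The body below merely mirrors the fallback and is illustrative only; in practice this hook is registered from the Python frontend, not from C++.

#include <string>
#include <unordered_map>
#include <tvm/build_module.h>
#include <tvm/packed_func_ext.h>
#include <tvm/runtime/registry.h>

TVM_REGISTER_GLOBAL("relay.backend.lower")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) {
  tvm::Schedule sch = args[0];
  tvm::Array<tvm::Tensor> all_args = args[1];
  std::string func_name = args[2];
  // args[3] carries the source relay::Function; unused in this sketch.
  std::unordered_map<tvm::Tensor, tvm::Buffer> binds;
  *rv = tvm::lower(sch, all_args, func_name, binds, tvm::build_config());
});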
7 changes: 6 additions & 1 deletion src/relay/backend/graph_runtime_codegen.cc
@@ -416,7 +416,12 @@ class GraphRuntimeCodegen
} else {
// heterogeneous execution.
const auto call_dev_key = std::to_string(call_dev_type);
const auto call_dev_name = runtime::DeviceName(call_dev_type);
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_name) == 0 && targets_.count(call_dev_key) == 0) {
LOG(FATAL) << "No target is provided for device "
<< call_dev_name;
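
The added branch handles call_dev_type == 0, which is not a valid DLDeviceType and appears here when a call has no explicit device assignment; runtime::DeviceName() has no entry for 0, so the host target name "llvm" is used instead. The same rule written as a standalone helper, for illustration only (the helper name is an assumption, not part of the patch):

#include <string>
#include <tvm/runtime/device_api.h>

inline std::string CallDeviceName(int call_dev_type) {
  // Device type 0 means "unannotated"; resolve it to the host target.
  if (call_dev_type == 0) {
    return "llvm";
  }
  return tvm::runtime::DeviceName(call_dev_type);
}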
104 changes: 104 additions & 0 deletions tests/cpp/relay_build_module_test.cc
@@ -0,0 +1,104 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

#include <gtest/gtest.h>
#include <tvm/tvm.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/type.h>
#include <tvm/relay/pass.h>
#include <topi/generic/injective.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>


TVM_REGISTER_GLOBAL("test.sch")
.set_body([](tvm::TVMArgs args, tvm::TVMRetValue *rv) {
*rv = topi::generic::schedule_injective(args[0], args[1]);
});

TEST(Relay, BuildModule) {
using namespace tvm;
auto tensor_type = relay::TensorTypeNode::make({2, 3}, ::tvm::Float(32));
auto a = relay::VarNode::make("a", tensor_type);
auto b = relay::VarNode::make("b", tensor_type);
auto add_op = relay::Op::Get("add");
auto x = relay::CallNode::make(add_op, {a, b}, tvm::Attrs(), {});
auto c = relay::VarNode::make("c", tensor_type);
auto y = relay::CallNode::make(add_op, {x, c}, tvm::Attrs(), {});
auto func = relay::FunctionNode::make(relay::FreeVars(y), y, relay::Type(), {});
auto A = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto B = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});
auto C = tvm::runtime::NDArray::Empty({2, 3}, {kDLFloat, 32, 1}, {kDLCPU, 0});

auto pA = (float*)A.ToDLPack()->dl_tensor.data;
auto pB = (float*)B.ToDLPack()->dl_tensor.data;
auto pC = (float*)C.ToDLPack()->dl_tensor.data;

for (int i = 0; i < 6; ++i) {
pA[i] = i;
pB[i] = i + 1;
pC[i] = i + 2;
}
// get schedule
auto reg = tvm::runtime::Registry::Get("relay.op._Register");
auto s_i = tvm::runtime::Registry::Get("test.sch");
if (!reg) {
LOG(FATAL) << "no _Register";
}
if (!s_i) {
LOG(FATAL) << "no _Register";
}
(*reg)("add", "FTVMSchedule", *s_i, 10);
// build
auto pfb = tvm::runtime::Registry::Get("relay.build_module._BuildModule");
tvm::runtime::Module build_mod = (*pfb)();
auto build_f = build_mod.GetFunction("build", false);
auto json_f = build_mod.GetFunction("get_graph_json", false);
auto mod_f = build_mod.GetFunction("get_module", false);
Array<HalideIR::Expr> target_pair;
target_pair.push_back(ir::StringImm::make("cpu"));
target_pair.push_back(ir::StringImm::make("llvm"));
build_f(func, target_pair, "llvm");
std::string json = json_f();
tvm::runtime::Module mod = mod_f();
// run
auto ctx = A->ctx;
auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
tvm::runtime::Module run_mod = (*pfr)(json, mod, (int)ctx.device_type, (int)ctx.device_id);
auto set_input_f = run_mod.GetFunction("set_input", false);
auto run_f = run_mod.GetFunction("run", false);
auto get_output_f = run_mod.GetFunction("get_output", false);
set_input_f("a", A);
set_input_f("b", B);
set_input_f("c", C);
run_f();
tvm::runtime::NDArray Y = get_output_f(0);
auto pY = (float*)Y.ToDLPack()->dl_tensor.data;
for (int i = 0; i < 6; ++i) {
CHECK_LT(fabs(pY[i] - (i + (i + 1) + (i + 2))), 1e-4);
}
}

int main(int argc, char ** argv) {
testing::InitGoogleTest(&argc, argv);
testing::FLAGS_gtest_death_test_style = "threadsafe";
return RUN_ALL_TESTS();
}
119 changes: 74 additions & 45 deletions tests/python/relay/test_cpp_build_module.py
@@ -23,55 +23,84 @@
_init_api("tvm.relay.build_module")

class BuildModule(object):
def __init__(self):
self.mod = relay.build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
def __init__(self):
self.mod = relay.build_module._BuildModule()
self._get_graph_json = self.mod["get_graph_json"]
self._get_module = self.mod["get_module"]
self._build = self.mod["build"]
self._set_opt_level = self.mod["set_opt_level"]
self._set_params_func = self.mod["set_params"]
self._get_params_func = self.mod["get_params"]

def build(self, func, target, target_host):
tgts = []
for kv in target.items():
tgts.append(kv[0])
tgts.append(kv[1])
self._build(func, tgts, target_host)

def build(self, func, target, target_host, params):
tgts = []
for kv in target.items():
tgts.append(kv[0])
tgts.append(kv[1])
self._set_params(params)
self._build(func, tgts, target_host)

def get_json(self):
return self._get_graph_json()
def get_json(self):
return self._get_graph_json()

def get_module(self):
return self._get_module()

def set_opt_level(self, level):
self._set_opt_level(level)

def _set_params(self, params):
inputs = {}
for name, param in params.items():
inputs[name] = relay.Constant(param)
self._set_params_func(inputs)

def get_params(self):
params = self._get_params_func()
ret = {}
for key, value in params.items():
ret[key] = value.data
return ret

def get_module(self):
return self._get_module()

def test_build():
m_bld = BuildModule()
# func
a = relay.var("a", dtype="float32", shape=(16, 8))
b = relay.var("b", dtype="float32", shape=(8, 8))
c = relay.var("c", dtype="float32", shape=(16, 8))
x = relay.nn.dense(a, b)
y = relay.nn.relu(x)
z = y + c
func = relay.Function([a, b, c], z)
# build
targets = {
"cpu": "llvm -mcpu=sse3"
}
m_bld.build(func, targets, "llvm -mcpu=sse3")
g_json = m_bld.get_json()
mmod = m_bld.get_module()

m_bld = BuildModule()
tgt_name = "llvm"
tgt = "llvm"
ctx = tvm.cpu()
# func
a = relay.var("a", dtype="float32", shape=(16, 8))
b = relay.var("b", dtype="float32", shape=(8, 8))
c = relay.var("c", dtype="float32", shape=(16, 8))
x = relay.nn.dense(a, b)
y = relay.nn.relu(x)
z = y + c
func = relay.Function([a, b, c], z)
A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"), ctx=ctx)
C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"), ctx=ctx)
params = {
"b" : B,
"c" : C
}
# build
targets = {
tgt: tgt
}
m_bld.set_opt_level(3)
m_bld.build(func, targets, "llvm -mcpu=sse3", params=params)
g_json = m_bld.get_json()
mmod = m_bld.get_module()
params = m_bld.get_params()

# test
A = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"))
B = tvm.nd.array(np.random.uniform(-1, 1, (8, 8)).astype("float32"))
C = tvm.nd.array(np.random.uniform(-1, 1, (16, 8)).astype("float32"))

rt = tvm.contrib.graph_runtime.create(g_json, mmod, tvm.cpu())
rt.set_input("a", A)
rt.set_input("b", B)
rt.set_input("c", C)
rt.run()
out = rt.get_output(0)
# test
rt = tvm.contrib.graph_runtime.create(g_json, mmod, ctx)
rt.set_input("a", A)
rt.load_params(relay.save_param_dict(params))
rt.run()
out = rt.get_output(0)

np.testing.assert_allclose(out.asnumpy(),
np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
np.testing.assert_allclose(out.asnumpy(),
np.maximum(np.dot(A.asnumpy(), B.asnumpy().T), 0) + C.asnumpy(), atol=1e-5, rtol=1e-5)
