Skip to content

Commit

Permalink
[Relay] External codegen (#4482)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiics authored and tqchen committed Dec 18, 2019
1 parent 603280b commit c44b7bf
Show file tree
Hide file tree
Showing 24 changed files with 1,599 additions and 51 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,8 @@ include(cmake/modules/LLVM.cmake)
include(cmake/modules/Micro.cmake)
include(cmake/modules/ANTLR.cmake)
include(cmake/modules/contrib/BLAS.cmake)
include(cmake/modules/contrib/CODEGENC.cmake)
include(cmake/modules/contrib/DNNL.cmake)
include(cmake/modules/contrib/Random.cmake)
include(cmake/modules/contrib/MicroStandaloneRuntime.cmake)
include(cmake/modules/contrib/Sort.cmake)
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,9 @@ set(USE_ROCBLAS OFF)
# Whether use contrib sort
set(USE_SORT ON)

# Whether use MKL-DNN (DNNL) codegen
set(USE_DNNL_CODEGEN OFF)

# Build ANTLR parser for Relay text format
# Possible values:
# - ON: enable ANTLR by searching default locations (cmake find_program for antlr4 and /usr/local for jar)
Expand Down
20 changes: 20 additions & 0 deletions cmake/modules/contrib/CODEGENC.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

file(GLOB CSOURCE_RELAY_CONTRIB_SRC src/relay/backend/contrib/codegen_c/codegen.cc)
list(APPEND COMPILER_SRCS ${CSOURCE_RELAY_CONTRIB_SRC})

28 changes: 28 additions & 0 deletions cmake/modules/contrib/DNNL.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_DNNL_CODEGEN STREQUAL "ON")
file(GLOB DNNL_RELAY_CONTRIB_SRC src/relay/backend/contrib/dnnl/codegen.cc)
list(APPEND COMPILER_SRCS ${DNNL_RELAY_CONTRIB_SRC})

find_library(EXTERN_LIBRARY_DNNL dnnl)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${EXTERN_LIBRARY_DNNL})
file(GLOB DNNL_CONTRIB_SRC src/runtime/contrib/dnnl/*)
list(APPEND RUNTIME_SRCS ${DNNL_CONTRIB_SRC})
message(STATUS "Build with DNNL codegen: " ${EXTERN_LIBRARY_DNNL})
endif()

3 changes: 3 additions & 0 deletions include/tvm/build_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ TVM_DLL Target intel_graphics(const std::vector<std::string>& options =
TVM_DLL Target stackvm(const std::vector<std::string>& options =
std::vector<std::string>());

/*! \return A target for external device */
TVM_DLL Target ext_dev(const std::vector<std::string>& options =
std::vector<std::string>());
} // namespace target

/*!
Expand Down
28 changes: 28 additions & 0 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,15 @@ class FunctionNode : public ExprNode {
*/
bool IsPrimitive() const;

/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
Expand Down Expand Up @@ -588,6 +597,25 @@ std::string AsText(const NodeRef& node,
bool show_meta_data = true,
runtime::TypedPackedFunc<std::string(Expr)> annotate = nullptr);

/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
} // namespace attr

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
11 changes: 10 additions & 1 deletion python/tvm/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,16 @@ def export_library(self,
self.save(path_obj)
files = [path_obj]
is_system_lib = self.type_key == "llvm" and self.get_function("__tvm_is_system_module")()
has_imported_c_file = False
if self.imported_modules:
for i, m in enumerate(self.imported_modules):
if m.type_key == "c":
has_imported_c_file = True
c_file_name = "tmp_" + str(i) + ".cc"
path_cc = temp.relpath(c_file_name)
with open(path_cc, "w") as f:
f.write(m.get_source())
files.append(path_cc)
path_cc = temp.relpath("devc.cc")
with open(path_cc, "w") as f:
f.write(_PackImportsToC(self, is_system_lib))
Expand All @@ -143,7 +152,7 @@ def export_library(self,
fcompile = _tar.tar
else:
fcompile = _cc.create_shared
if self.type_key == "c":
if self.type_key == "c" or has_imported_c_file:
options = []
if "options" in kwargs:
opts = kwargs["options"]
Expand Down
4 changes: 4 additions & 0 deletions src/codegen/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,10 @@ Target intel_graphics(const std::vector<std::string>& options) {
Target stackvm(const std::vector<std::string>& options) {
return CreateTarget("stackvm", options);
}

Target ext_dev(const std::vector<std::string>& options) {
return CreateTarget("ext_dev", options);
}
} // namespace target

bool LLVMEnabled() {
Expand Down
1 change: 1 addition & 0 deletions src/codegen/codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ std::string PackImportsToC(const runtime::Module& mod, bool system_lib) {
<< "Only support simply one-level hierarchy";
std::string tkey = im->type_key();
stream->Write(tkey);
if (tkey == "c") continue;
im->SaveToBinary(stream);
}
// translate to C program
Expand Down
22 changes: 22 additions & 0 deletions src/relay/backend/build_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ struct GraphCodegen {
return CallFunc<std::string>("get_graph_json", nullptr);
}

Array<tvm::runtime::Module> GetExternalModules() {
return CallFunc<Array<tvm::runtime::Module> >("get_external_modules", nullptr);
}

Map<std::string, Array<LoweredFunc> > GetLoweredFunc() {
return CallFunc<Map<std::string, Array<LoweredFunc> > >("get_lowered_funcs", nullptr);
}
Expand Down Expand Up @@ -148,6 +152,10 @@ class RelayBuildModule : public runtime::ModuleNode {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetLoweredFunc();
});
} else if (name == "get_external_modules") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->graph_codegen_->GetExternalModules();
});
} else if (name == "optimize") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.num_args, 2);
Expand Down Expand Up @@ -474,6 +482,20 @@ class RelayBuildModule : public runtime::ModuleNode {
target_host_,
BuildConfig::Current());
}
Array<tvm::runtime::Module> ext_mods = graph_codegen_->GetExternalModules();
if (!ext_mods.empty()) {
CHECK(lowered_funcs.size() > 0 || ext_mods.size() == 1)
<< "Expect to have a TVM DSOModule when multiple external runtime modules exist";
if (lowered_funcs.size() == 0) {
// Execute the whole module using external runtime.
ret_.mod = ext_mods[0];
} else {
// Import all external runtime modules.
for (const auto& it : ext_mods) {
ret_.mod.Import(it);
}
}
}
}

protected:
Expand Down
83 changes: 70 additions & 13 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op_attr_types.h>
#include <topi/tags.h>
Expand Down Expand Up @@ -608,6 +609,46 @@ class CompileEngineImpl : public CompileEngineNode {
return LowerShapeFuncInternal(key)->cached_func;
}

Array<tvm::runtime::Module> LowerExternalFunctions() {
std::unordered_map<std::string, relay::Module> ext_mods;
std::vector<CCacheKey> cached_ext_funcs;
for (const auto& it : cache_) {
auto src_func = it.first->source_func;
CHECK(src_func.defined());
if (!src_func->UseDefaultCompiler()) {
auto compiler = FunctionGetAttr(src_func, attr::kCompiler);
const tvm::ir::StringImm* code_gen = compiler.as<tvm::ir::StringImm>();
CHECK(code_gen) << "No external codegen is set";
if (ext_mods.find(code_gen->value) == ext_mods.end()) {
ext_mods[code_gen->value] = relay::ModuleNode::make({}, {});
}
auto ext_symbol = FunctionGetAttr(src_func, attr::kExternalSymbol);
const tvm::ir::StringImm* symbol_name = ext_symbol.as<tvm::ir::StringImm>();
CHECK(symbol_name) << "No external symbol is set for:\n" << AsText(src_func, false);
auto gv = GlobalVarNode::make(symbol_name->value);
ext_mods[code_gen->value]->Add(gv, src_func);
cached_ext_funcs.push_back(it.first);
}
}

Array<tvm::runtime::Module> ret;
for (const auto& it : ext_mods) {
std::string ext_name = "relay.ext." + it.first;
auto pf = tvm::runtime::Registry::Get(ext_name);
CHECK(pf) << "Failed to find the codegen tool for " << ext_name << "\n";
runtime::Module ext_mod = (*pf)(it.second);
CHECK(ext_mod.defined()) << "No external runtime is generated.";
ret.push_back(ext_mod);
}

// No need to cache external functions as we collected them all to create
// external runtime modules.
for (const auto& it : cached_ext_funcs) {
cache_.erase(it);
}
return ret;
}

void Clear() final {
cache_.clear();
}
Expand Down Expand Up @@ -648,6 +689,18 @@ class CompileEngineImpl : public CompileEngineNode {
value->use_count = 0;
cache_[key] = value;
}
// No need to lower external functions for now. We will invoke the external
// codegen tool once and lower all functions together.
if (!key->source_func->UseDefaultCompiler()) {
auto cache_node = make_node<CachedFuncNode>();
const auto name_node =
FunctionGetAttr(key->source_func, attr::kExternalSymbol).as<tvm::ir::StringImm>();
CHECK(name_node != nullptr) << "External function has not been attached a name yet.";
cache_node->func_name = name_node->value;
cache_node->target = tvm::target::ext_dev();
value->cached_func = CachedFunc(cache_node);
return value;
}
// Enforce use the target.
With<Target> target_scope(key->target);

Expand Down Expand Up @@ -759,42 +812,46 @@ const CompileEngine& CompileEngine::Global() {
return *inst;
}


TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed<CCacheKey(Function, Target)>(CCacheKeyNode::make);

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal")
.set_body_typed<CompileEngine()>([]() {
return CompileEngine::Global();
});
return CompileEngine::Global();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
self->Clear();
});
self->Clear();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->Lower(key);
});
return self->Lower(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed<CachedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->LowerShapeFunc(key);
});
return self->LowerShapeFunc(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileLowerExternalFunctions")
.set_body_typed<void(const CompileEngine&)>([](CompileEngine self) {
return self->LowerExternalFunctions();
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineJIT")
.set_body_typed<PackedFunc(CompileEngine, CCacheKey)>(
[](CompileEngine self, CCacheKey key) {
return self->JIT(key);
});
return self->JIT(key);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineListItems")
.set_body_typed<Array<NodeRef>(CompileEngine)>(
[](CompileEngine self){
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
return static_cast<CompileEngineImpl*>(self.operator->())->ListItems();
});
} // namespace relay
} // namespace tvm
7 changes: 7 additions & 0 deletions src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#define TVM_RELAY_BACKEND_COMPILE_ENGINE_H_

#include <tvm/lowered_func.h>
#include <tvm/runtime/module.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/transform.h>
Expand Down Expand Up @@ -186,6 +187,12 @@ class CompileEngineNode : public Node {
* \return The result.
*/
virtual CachedFunc LowerShapeFunc(const CCacheKey& key) = 0;
/*!
* \brief Lower the external function using external codegen tools.
* \return The runtime moduels for each needed external codegen tool.
*/
virtual tvm::Array<tvm::runtime::Module> LowerExternalFunctions() = 0;

/*! \brief clear the cache. */
virtual void Clear() = 0;

Expand Down
Loading

0 comments on commit c44b7bf

Please sign in to comment.