Skip to content

Commit

Permalink
[Relay][VM]VM Profiler (apache#3727)
Browse files Browse the repository at this point in the history
* [Relay][VM]VM debugger

* Report mean/min/max for op duration

* Typos

* Lint

* Lint

* Lint

* Support build debug VM in CMake

* Lint

* Enable VM debug in unit test

* Disable debug vm test until new docker image is built

* Add device sync code

* Fix qnn unit test

* Disable vm debug by default

* Rename files

* Rename classes

* Fix comment

* Fix comment
  • Loading branch information
wweic authored and jroesch committed Aug 21, 2019
1 parent c87ace7 commit 95f12e3
Show file tree
Hide file tree
Showing 13 changed files with 690 additions and 170 deletions.
35 changes: 32 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,33 @@ file(GLOB COMPILER_SRCS
src/schedule/*.cc
)

file(GLOB_RECURSE RELAY_SRCS
src/relay/*.cc
file(GLOB_RECURSE RELAY_OP_SRCS
src/relay/op/*.cc
)
list(APPEND COMPILER_SRCS ${RELAY_SRCS})
file(GLOB_RECURSE RELAY_PASS_SRCS
src/relay/pass/*.cc
)
file(GLOB RELAY_BACKEND_SRCS
src/relay/backend/*.cc
src/relay/backend/vm/*.cc
)
file(GLOB_RECURSE RELAY_IR_SRCS
src/relay/ir/*.cc
)
file(GLOB_RECURSE RELAY_QNN_SRCS
src/relay/qnn/*.cc
)
list(APPEND COMPILER_SRCS ${RELAY_OP_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_PASS_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_BACKEND_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_IR_SRCS})
list(APPEND COMPILER_SRCS ${RELAY_QNN_SRCS})

if(USE_VM_PROFILER)
message(STATUS "Build compiler with Relay VM profiler support...")
file(GLOB BACKEND_VM_PROFILER_SRCS src/relay/backend/vm/profiler/*.cc)
list(APPEND COMPILER_SRCS ${BACKEND_VM_PROFILER_SRCS})
endif(USE_VM_PROFILER)

file(GLOB DATATYPE_SRCS src/codegen/datatype/*.cc)
list(APPEND COMPILER_SRCS ${DATATYPE_SRCS})
Expand Down Expand Up @@ -198,6 +221,12 @@ if(USE_GRAPH_RUNTIME)
endif(USE_GRAPH_RUNTIME_DEBUG)
endif(USE_GRAPH_RUNTIME)

if(USE_VM_PROFILER)
message(STATUS "Build with Relay VM profiler support...")
file(GLOB RUNTIME_VM_PROFILER_SRCS src/runtime/vm/profiler/*.cc)
list(APPEND RUNTIME_SRCS ${RUNTIME_VM_PROFILER_SRCS})
endif(USE_VM_PROFILER)

# Module rules
include(cmake/modules/VTA.cmake)
include(cmake/modules/CUDA.cmake)
Expand Down
4 changes: 4 additions & 0 deletions Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ stage('Build') {
echo set\\(USE_GRAPH_RUNTIME ON\\) >> config.cmake
echo set\\(USE_STACKVM_RUNTIME ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_VM_PROFILER ON\\) >> config.cmake
echo set\\(USE_ANTLR ON\\) >> config.cmake
echo set\\(USE_BLAS openblas\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
Expand All @@ -164,6 +165,7 @@ stage('Build') {
echo set\\(USE_VULKAN ON\\) >> config.cmake
echo set\\(USE_MICRO ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_VM_PROFILER ON\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER clang-7\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
"""
Expand All @@ -182,6 +184,7 @@ stage('Build') {
echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_MICRO ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_VM_PROFILER ON\\) >> config.cmake
echo set\\(USE_LLVM llvm-config-8\\) >> config.cmake
echo set\\(USE_NNPACK ON\\) >> config.cmake
echo set\\(NNPACK_PATH /NNPACK/build/\\) >> config.cmake
Expand Down Expand Up @@ -212,6 +215,7 @@ stage('Build') {
echo set\\(USE_SORT ON\\) >> config.cmake
echo set\\(USE_RPC ON\\) >> config.cmake
echo set\\(USE_GRAPH_RUNTIME_DEBUG ON\\) >> config.cmake
echo set\\(USE_VM_PROFILER ON\\) >> config.cmake
echo set\\(USE_LLVM llvm-config-4.0\\) >> config.cmake
echo set\\(CMAKE_CXX_COMPILER g++\\) >> config.cmake
echo set\\(CMAKE_CXX_FLAGS -Werror\\) >> config.cmake
Expand Down
3 changes: 3 additions & 0 deletions cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ set(USE_GRAPH_RUNTIME ON)
# Whether enable additional graph debug functions
set(USE_GRAPH_RUNTIME_DEBUG OFF)

# Whether enable additional vm profiler functions
set(USE_VM_PROFILER OFF)

# Whether build with LLVM support
# Requires LLVM version >= 4.0
#
Expand Down
44 changes: 39 additions & 5 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -375,8 +375,41 @@ struct VMFrame {
*/
class VirtualMachine : public runtime::ModuleNode {
public:
PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) final;
/*!
* \brief Get a PackedFunc from module.
*
* The PackedFunc may not be fully initialized,
* there might still be first time running overhead when
* executing the function on certain devices.
* For benchmarking, use prepare to eliminate
*
* \param name the name of the function.
* \param sptr_to_self The shared_ptr that points to this module node.
*
* \return PackedFunc(nullptr) when it is not available.
*
* \note The function will always remain valid.
* If the function needs resource from the module(e.g. late linking),
* it should capture sptr_to_self.
*/
virtual PackedFunc GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self);

/*!
* \brief Invoke a PackedFunction
*
* \param packed_index The offset of the PackedFunction in all functions.
* \param func The PackedFunction to be invoked.
* \param arg_count The number of arguments to the PackedFunction.
* \param output_size The number of outputs of the PackedFunction.
* \param args Arguments to the PackedFunction.
*
* \note The return value will be stored in the last output_size slots of args.
*/
virtual void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<Object>& args);

virtual ~VirtualMachine() {}

const char* type_key() const final {
return "VirtualMachine";
Expand Down Expand Up @@ -456,6 +489,10 @@ class VirtualMachine : public runtime::ModuleNode {
*/
void RunLoop();

/*! \brief Get device context for params.
*/
TVMContext GetParamsContext() const;

/*!
* \brief Load parameters from the parameter bytearray.
* \param params The binary file that contains parameters.
Expand All @@ -478,9 +515,6 @@ class VirtualMachine : public runtime::ModuleNode {
*/
void InvokeGlobal(const VMFunction& func, const std::vector<Object>& args);

/*! \brief Get device context for params.
*/
TVMContext GetParamsContext() const;

/*! \brief The parameter name to data mapping. */
std::unordered_map<std::string, Object> params_;
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from . import param_dict
from . import feature
from .backend import vm
from .backend import profiler_vm
from .backend import serializer
from .backend import deserializer
from .backend import vmobj
Expand Down
90 changes: 90 additions & 0 deletions python/tvm/relay/backend/profiler_vm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# License .to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-else-return, unidiomatic-typecheck, undefined-variable, invalid-name
"""
The Relay Virtual Machine profiler.
Provides extra APIs for profiling vm execution.
"""
import tvm
from . import vm, _vm

def _update_target(target):
target = target if target else tvm.target.current_target()
if target is None:
raise ValueError("Target is not set in env or passed as argument.")

tgts = {}
if isinstance(target, (str, tvm.target.Target)):
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(str(target)).device_type)
tgts[dev_type] = tvm.target.create(target)
elif isinstance(target, dict):
for dev, tgt in target.items():
dev_type = tvm.expr.IntImm("int32", tvm.nd.context(dev).device_type)
tgts[dev_type] = tvm.target.create(tgt)
else:
raise TypeError("target is expected to be str, tvm.target.Target, " +
"or dict of str to str/tvm.target.Target, but received " +
"{}".format(type(target)))
return tgts

class VMCompilerProfiler(vm.VMCompiler):
"""Build Relay module to run on VM runtime."""
def __init__(self):
super().__init__()
self.mod = _vm._VMCompilerProfiler()
self._compile = self.mod["compile"]
self._get_vm = self.mod["get_vm"]

def compile(self, mod, target=None, target_host=None):
"""
Parameters
----------
mod : relay.Module
The Relay module to build.
target : str, :any:`tvm.target.Target`, or dict of str(i.e.
device/context name) to str/tvm.target.Target, optional
For heterogeneous compilation, it is a dictionary indicating context
to target mapping. For homogeneous compilation, it is a build target.
target_host : str or :any:`tvm.target.Target`, optional
Host compilation target, if target is device.
When TVM compiles device specific program such as CUDA,
we also need host(CPU) side code to interact with the driver
to setup the dimensions and parameters correctly.
target_host is used to specify the host side codegen target.
By default, llvm is used if it is enabled,
otherwise a stackvm intepreter is used.
Returns
-------
vm : VirtualMachineProfiler
The profile VM runtime.
"""
target = _update_target(target)
self._compile(mod, target, target_host)
return VirtualMachineProfiler(self._get_vm())

class VirtualMachineProfiler(vm.VirtualMachine):
"""Relay profile VM runtime."""
def __init__(self, mod):
super().__init__(mod)
self._get_stat = self.mod["get_stat"]

def get_stat(self):
return self._get_stat()
Loading

0 comments on commit 95f12e3

Please sign in to comment.