From 6848300636a055a108911da79d1800fff2b3e02e Mon Sep 17 00:00:00 2001 From: Andrew Reusch Date: Tue, 15 Sep 2020 09:10:26 -0700 Subject: [PATCH] =?UTF-8?q?=C2=B5TVM=20RPC=20server=20and=20Part=201=20of?= =?UTF-8?q?=20AutoTVM=20compilation=20infrastructure=20(#6334)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../TARGET_SDK_11/libraries/crc16/crc16.c | 55 ++ .../TARGET_SDK_11/libraries/crc16/crc16.h | 78 ++ CMakeLists.txt | 26 +- LICENSE | 1 + apps/bundle_deploy/Makefile | 18 +- apps/bundle_deploy/bundle.c | 17 +- apps/bundle_deploy/bundle_static.c | 17 +- apps/bundle_deploy/crt_config/crt_config.h | 26 +- apps/bundle_deploy/demo_static.c | 8 +- apps/bundle_deploy/test.cc | 3 +- apps/bundle_deploy/test_static.c | 2 +- cmake/modules/StandaloneCrt.cmake | 259 +++--- include/tvm/runtime/crt/crt.h | 7 +- include/tvm/runtime/crt/error_codes.h | 32 +- .../tvm/runtime/crt}/logging.h | 60 +- include/tvm/runtime/crt/platform.h | 4 +- .../tvm/runtime/crt/rpc_common/frame_buffer.h | 72 ++ include/tvm/runtime/crt/rpc_common/framing.h | 269 ++++++ include/tvm/runtime/crt/rpc_common/session.h | 241 ++++++ .../tvm/runtime/crt/rpc_common/write_stream.h | 38 +- include/tvm/runtime/crt/utvm_rpc_server.h | 92 +++ python/tvm/micro/__init__.py | 14 +- python/tvm/micro/artifact.py | 206 +++++ python/tvm/micro/base.py | 322 -------- python/tvm/micro/build.py | 148 ++++ python/tvm/micro/class_factory.py | 97 +++ python/tvm/micro/compiler.py | 318 ++++++++ python/tvm/micro/debugger.py | 188 +++++ python/tvm/micro/device/__init__.py | 23 - python/tvm/micro/device/arm/__init__.py | 19 - python/tvm/micro/device/arm/stm32f746xx.py | 137 ---- python/tvm/micro/device/base.py | 237 ------ python/tvm/micro/device/host.py | 127 --- python/tvm/micro/device/riscv_spike.py | 112 --- python/tvm/micro/func_registry.py | 2 +- python/tvm/micro/micro_binary.py | 51 ++ python/tvm/micro/micro_library.py | 80 ++ python/tvm/micro/session.py | 124 +++ python/tvm/micro/transport.py | 225 +++++ python/tvm/rpc/minrpc.py | 9 +- python/tvm/target/target.py | 19 + src/runtime/crt/Makefile | 50 +- src/runtime/crt/common/crt_runtime_api.c | 18 +- src/runtime/crt/common/memory.c | 39 +- src/runtime/crt/common/packed_func.c | 2 +- src/runtime/crt/crt_config-template.h | 54 ++ src/runtime/crt/graph_runtime/graph_runtime.c | 2 +- src/runtime/crt/host/crt_config.h | 34 +- src/runtime/crt/host/main.cc | 122 +++ .../tvm/runtime/crt/internal/common/memory.h | 33 +- .../crt/utvm_rpc_common/frame_buffer.cc | 64 ++ src/runtime/crt/utvm_rpc_common/framing.cc | 411 ++++++++++ src/runtime/crt/utvm_rpc_common/session.cc | 279 +++++++ .../crt/utvm_rpc_common/write_stream.cc | 55 ++ src/runtime/crt/utvm_rpc_server/rpc_server.cc | 261 ++++++ .../micro/device/arm/stm32f746xx/utvm_init.s | 39 - .../micro/device/arm/stm32f746xx/utvm_timer.c | 77 -- src/runtime/micro/device/host/utvm_timer.c | 36 - .../micro/device/riscv_spike/utvm_init.s | 23 - .../host_driven/utvm_device_dylib_redirect.c | 90 -- src/runtime/micro/host_driven/utvm_runtime.c | 185 ----- src/runtime/micro/host_driven/utvm_runtime.h | 95 --- .../micro/host_driven/utvm_runtime_enum.h | 51 -- src/runtime/micro/host_low_level_device.cc | 92 --- src/runtime/micro/low_level_device.h | 90 -- src/runtime/micro/micro_common.cc | 131 --- src/runtime/micro/micro_common.h | 359 -------- src/runtime/micro/micro_device_api.cc | 162 ---- src/runtime/micro/micro_module.cc | 110 --- src/runtime/micro/micro_section_allocator.h | 134 --- src/runtime/micro/micro_session.cc | 768 ++++-------------- src/runtime/micro/micro_session.h | 398 --------- src/runtime/micro/openocd_low_level_device.cc | 221 ----- .../micro/target_data_layout_encoder.cc | 73 -- .../micro/target_data_layout_encoder.h | 200 ----- src/runtime/micro/tcl_socket.cc | 71 -- src/runtime/micro/tcl_socket.h | 97 --- src/runtime/{rpc => }/minrpc/minrpc_server.h | 128 +-- .../posix_popen_server}/posix_popen_server.cc | 12 +- .../rpc_protocol.h => minrpc/rpc_reference.h} | 55 +- src/runtime/rpc/rpc_endpoint.cc | 8 + src/runtime/rpc/rpc_endpoint.h | 2 +- src/runtime/rpc/rpc_session.h | 2 +- src/support/arena.h | 138 +--- src/support/generic_arena.h | 183 +++++ src/target/target_kind.cc | 1 + tests/cpp/utvm_runtime_standalone_test.cc | 4 +- tests/crt/buffer_write_stream.h | 63 ++ tests/crt/framing_test.cc | 317 ++++++++ tests/crt/func_registry_test.cc | 2 + tests/crt/memory_test.cc | 5 +- .../host/utvm_init.c => tests/crt/platform.cc | 39 +- tests/crt/session_test.cc | 265 ++++++ tests/python/unittest/test_crt.py | 141 ++++ tests/python/unittest/test_runtime_micro.py | 361 -------- tests/scripts/task_config_build_cpu.sh | 2 +- tests/scripts/task_config_build_i386.sh | 1 + tests/scripts/task_cpp_unittest.sh | 10 +- tests/scripts/task_python_integration.sh | 6 - tutorials/micro/micro_tflite.py | 94 +-- 100 files changed, 5299 insertions(+), 5249 deletions(-) create mode 100644 3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.c create mode 100644 3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.h rename {src/runtime/crt/include/tvm/runtime/crt/internal/common => include/tvm/runtime/crt}/logging.h (52%) create mode 100644 include/tvm/runtime/crt/rpc_common/frame_buffer.h create mode 100644 include/tvm/runtime/crt/rpc_common/framing.h create mode 100644 include/tvm/runtime/crt/rpc_common/session.h rename src/runtime/micro/device/riscv_spike/utvm_timer.c => include/tvm/runtime/crt/rpc_common/write_stream.h (53%) create mode 100644 include/tvm/runtime/crt/utvm_rpc_server.h create mode 100644 python/tvm/micro/artifact.py create mode 100644 python/tvm/micro/build.py create mode 100644 python/tvm/micro/class_factory.py create mode 100644 python/tvm/micro/compiler.py create mode 100644 python/tvm/micro/debugger.py delete mode 100644 python/tvm/micro/device/__init__.py delete mode 100644 python/tvm/micro/device/arm/__init__.py delete mode 100644 python/tvm/micro/device/arm/stm32f746xx.py delete mode 100644 python/tvm/micro/device/base.py delete mode 100644 python/tvm/micro/device/host.py delete mode 100644 python/tvm/micro/device/riscv_spike.py create mode 100644 python/tvm/micro/micro_binary.py create mode 100644 python/tvm/micro/micro_library.py create mode 100644 python/tvm/micro/session.py create mode 100644 python/tvm/micro/transport.py create mode 100644 src/runtime/crt/crt_config-template.h create mode 100644 src/runtime/crt/host/main.cc create mode 100644 src/runtime/crt/utvm_rpc_common/frame_buffer.cc create mode 100644 src/runtime/crt/utvm_rpc_common/framing.cc create mode 100644 src/runtime/crt/utvm_rpc_common/session.cc create mode 100644 src/runtime/crt/utvm_rpc_common/write_stream.cc create mode 100644 src/runtime/crt/utvm_rpc_server/rpc_server.cc delete mode 100644 src/runtime/micro/device/arm/stm32f746xx/utvm_init.s delete mode 100644 src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c delete mode 100644 src/runtime/micro/device/host/utvm_timer.c delete mode 100644 src/runtime/micro/device/riscv_spike/utvm_init.s delete mode 100644 src/runtime/micro/host_driven/utvm_device_dylib_redirect.c delete mode 100644 src/runtime/micro/host_driven/utvm_runtime.c delete mode 100644 src/runtime/micro/host_driven/utvm_runtime.h delete mode 100644 src/runtime/micro/host_driven/utvm_runtime_enum.h delete mode 100644 src/runtime/micro/host_low_level_device.cc delete mode 100644 src/runtime/micro/low_level_device.h delete mode 100644 src/runtime/micro/micro_common.cc delete mode 100644 src/runtime/micro/micro_common.h delete mode 100644 src/runtime/micro/micro_device_api.cc delete mode 100644 src/runtime/micro/micro_module.cc delete mode 100644 src/runtime/micro/micro_section_allocator.h delete mode 100644 src/runtime/micro/openocd_low_level_device.cc delete mode 100644 src/runtime/micro/target_data_layout_encoder.cc delete mode 100644 src/runtime/micro/target_data_layout_encoder.h delete mode 100644 src/runtime/micro/tcl_socket.cc delete mode 100644 src/runtime/micro/tcl_socket.h rename src/runtime/{rpc => }/minrpc/minrpc_server.h (86%) rename src/runtime/{rpc/minrpc => minrpc/posix_popen_server}/posix_popen_server.cc (89%) rename src/runtime/{rpc/rpc_protocol.h => minrpc/rpc_reference.h} (90%) create mode 100644 src/support/generic_arena.h create mode 100644 tests/crt/buffer_write_stream.h create mode 100644 tests/crt/framing_test.cc rename src/runtime/micro/device/host/utvm_init.c => tests/crt/platform.cc (52%) create mode 100644 tests/crt/session_test.cc create mode 100644 tests/python/unittest/test_crt.py delete mode 100644 tests/python/unittest/test_runtime_micro.py diff --git a/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.c b/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.c new file mode 100644 index 000000000000..cf63a3c93bd2 --- /dev/null +++ b/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.c @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2013 Nordic Semiconductor ASA + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this list + * of conditions and the following disclaimer. + * + * 2. Redistributions in binary form, except as embedded into a Nordic Semiconductor ASA + * integrated circuit in a product or a software update for such product, must reproduce + * the above copyright notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of Nordic Semiconductor ASA nor the names of its contributors may be + * used to endorse or promote products derived from this software without specific prior + * written permission. + * + * 4. This software, with or without modification, must only be used with a + * Nordic Semiconductor ASA integrated circuit. + * + * 5. Any software provided in binary or object form under this license must not be reverse + * engineered, decompiled, modified and/or disassembled. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON + * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +#include "crc16.h" + +#include + +uint16_t crc16_compute(uint8_t const* p_data, uint32_t size, uint16_t const* p_crc) { + uint16_t crc = (p_crc == NULL) ? 0xFFFF : *p_crc; + + for (uint32_t i = 0; i < size; i++) { + crc = (uint8_t)(crc >> 8) | (crc << 8); + crc ^= p_data[i]; + crc ^= (uint8_t)(crc & 0xFF) >> 4; + crc ^= (crc << 8) << 4; + crc ^= ((crc & 0xFF) << 4) << 1; + } + + return crc; +} diff --git a/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.h b/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.h new file mode 100644 index 000000000000..d925880f6cca --- /dev/null +++ b/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16/crc16.h @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2013 Nordic Semiconductor ASA + * All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, + * are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this list + * of conditions and the following disclaimer. + * + * 2. Redistributions in binary form, except as embedded into a Nordic Semiconductor ASA + * integrated circuit in a product or a software update for such product, must reproduce + * the above copyright notice, this list of conditions and the following disclaimer in + * the documentation and/or other materials provided with the distribution. + * + * 3. Neither the name of Nordic Semiconductor ASA nor the names of its contributors may be + * used to endorse or promote products derived from this software without specific prior + * written permission. + * + * 4. This software, with or without modification, must only be used with a + * Nordic Semiconductor ASA integrated circuit. + * + * 5. Any software provided in binary or object form under this license must not be reverse + * engineered, decompiled, modified and/or disassembled. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON + * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + */ + +/** @file + * + * @defgroup crc_compute CRC compute + * @{ + * @ingroup hci_transport + * + * @brief This module implements CRC-16-CCITT (polynomial 0x1021) with 0xFFFF initial value. + * The data can be passed in multiple blocks. + */ + +#ifndef CRC16_H__ +#define CRC16_H__ + +#ifdef __cplusplus +extern "C" { +#endif + +#include + +/**@brief Function for calculating CRC-16 in blocks. + * + * Feed each consecutive data block into this function, along with the current value of p_crc as + * returned by the previous call of this function. The first call of this function should pass NULL + * as the initial value of the crc in p_crc. + * + * @param[in] p_data The input data block for computation. + * @param[in] size The size of the input data block in bytes. + * @param[in] p_crc The previous calculated CRC-16 value or NULL if first call. + * + * @return The updated CRC-16 value, based on the input supplied. + */ +uint16_t crc16_compute(uint8_t const* p_data, uint32_t size, uint16_t const* p_crc); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // CRC16_H__ + +/** @} */ diff --git a/CMakeLists.txt b/CMakeLists.txt index 171262b9512d..e33d2ef463a3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -40,7 +40,7 @@ tvm_option(USE_OPENMP "Build with OpenMP thread pool implementation" OFF) tvm_option(USE_RELAY_DEBUG "Building Relay in debug mode..." OFF) tvm_option(USE_RTTI "Build with RTTI" ON) tvm_option(USE_MSVC_MT "Build with MT" OFF) -tvm_option(USE_MICRO "Build with Micro" OFF) +tvm_option(USE_MICRO "Build with Micro TVM support" OFF) tvm_option(INSTALL_DEV "Install compiler infrastructure" OFF) tvm_option(HIDE_PRIVATE_SYMBOLS "Compile with -fvisibility=hidden." OFF) tvm_option(USE_TF_COMPILE_FLAGS "Build with TensorFlow's compile flags." OFF) @@ -79,7 +79,6 @@ tvm_option(USE_TARGET_ONNX "Build with ONNX Codegen support" OFF) tvm_option(USE_ARM_COMPUTE_LIB "Build with Arm Compute Library" OFF) tvm_option(USE_ARM_COMPUTE_LIB_GRAPH_RUNTIME "Build with Arm Compute Library graph runtime" OFF) - # include directories include_directories(${CMAKE_INCLUDE_PATH}) include_directories("include") @@ -421,12 +420,12 @@ if(USE_THREADS AND NOT BUILD_FOR_HEXAGON) set(CMAKE_THREAD_PREFER_PTHREAD TRUE) set(THREADS_PREFER_PTHREAD_FLAG TRUE) find_package(Threads REQUIRED) - target_link_libraries(tvm Threads::Threads) - target_link_libraries(tvm_runtime Threads::Threads) + target_link_libraries(tvm PUBLIC Threads::Threads) + target_link_libraries(tvm_runtime PUBLIC Threads::Threads) endif() -target_link_libraries(tvm ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) -target_link_libraries(tvm_runtime ${TVM_RUNTIME_LINKER_LIBS}) +target_link_libraries(tvm PRIVATE ${TVM_LINKER_LIBS} ${TVM_RUNTIME_LINKER_LIBS}) +target_link_libraries(tvm_runtime PRIVATE ${TVM_RUNTIME_LINKER_LIBS}) # Related headers target_include_directories( @@ -435,20 +434,27 @@ target_include_directories( target_include_directories( tvm_objs PUBLIC "topi/include") +set(CRC16_INCLUDE_PATH "3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16") +target_include_directorieS( + tvm_objs + PRIVATE "${CRC16_INCLUDE_PATH}") +target_include_directorieS( + tvm_runtime_objs + PRIVATE "${CRC16_INCLUDE_PATH}") set(TVM_TEST_LIBRARY_NAME tvm) if (HIDE_PRIVATE_SYMBOLS AND NOT ${CMAKE_SYSTEM_NAME} MATCHES "Darwin") add_library(tvm_allvisible SHARED $) target_include_directories(tvm_allvisible PUBLIC "$") - target_link_libraries(tvm_allvisible PUBLIC "$") + target_link_libraries(tvm_allvisible PRIVATE "$") set(TVM_TEST_LIBRARY_NAME tvm_allvisible) -set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL") + set(HIDE_SYMBOLS_LINKER_FLAGS "-Wl,--exclude-libs,ALL") # Note: 'target_link_options' with 'PRIVATE' keyword would be cleaner # but it's not available until CMake 3.13. Switch to 'target_link_options' # once minimum CMake version is bumped up to 3.13 or above. - target_link_libraries(tvm ${HIDE_SYMBOLS_LINKER_FLAGS}) - target_link_libraries(tvm_runtime ${HIDE_SYMBOLS_LINKER_FLAGS}) + target_link_libraries(tvm PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS}) + target_link_libraries(tvm_runtime PRIVATE ${HIDE_SYMBOLS_LINKER_FLAGS}) endif() # Tests diff --git a/LICENSE b/LICENSE index 49856917b215..1c7ab8205141 100644 --- a/LICENSE +++ b/LICENSE @@ -212,6 +212,7 @@ Apache Software Foundation License 2.0 3rdparty/bfloat16/bfloat16.cc 3rdparty/dlpack 3rdparty/dmlc-core +3rdparty/mbed-os BSD 2-clause License diff --git a/apps/bundle_deploy/Makefile b/apps/bundle_deploy/Makefile index adb8d3386bdf..d95dbe7018fe 100644 --- a/apps/bundle_deploy/Makefile +++ b/apps/bundle_deploy/Makefile @@ -19,7 +19,10 @@ # Setup build environment TVM_ROOT=$(shell cd ../..; pwd) -CRT_ROOT ?= ../../src/runtime/crt +CRT_ROOT ?= ../../build/standalone_crt +ifeq ($(shell ls -lhd $(CRT_ROOT)),) +$(error "CRT not found. Ensure you have built the standalone_crt target and try again") +endif ENABLE_TVM_PLATFORM_ABORT_BACKTRACE ?= 1 @@ -57,6 +60,7 @@ $(else) QUIET ?= @ $(endif) +CRT_SRCS = $(shell find $(CRT_ROOT)) demo_dynamic: $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/bundle_c.so $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/graph_c.json $(build_dir)/params_cpp.bin $(build_dir)/params_c.bin $(build_dir)/cat.bin $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/demo_dynamic $(build_dir)/bundle.so $(build_dir)/graph_cpp.json $(build_dir)/params_cpp.bin $(build_dir)/cat.bin @@ -72,10 +76,10 @@ demo_static: $(build_dir)/demo_static $(build_dir)/cat.bin test_static: $(build_dir)/test_static $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(QUIET)TVM_NUM_THREADS=1 $(build_dir)/test_static $(build_dir)/test_data_c.bin $(build_dir)/test_output_c.bin $(build_dir)/test_graph_c.json $(build_dir)/test_params_c.bin -$(build_dir)/crt/graph_runtime/libgraph_runtime.a: +$(build_dir)/crt/libgraph_runtime.a: $(CRT_SRCS) $(QUIET)cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) "EXTRA_CFLAGS=$(PKG_COMPILE_OPTS)" graph_runtime -$(build_dir)/crt/common/libcommon.a: +$(build_dir)/crt/libcommon.a: $(CRT_SRCS) $(QUIET)cd $(CRT_ROOT) && make QUIET= BUILD_DIR=$(abspath $(build_dir))/crt CRT_CONFIG=$(abspath crt_config/crt_config.h) "EXTRA_CFLAGS=$(PKG_COMPILE_OPTS)" common $(build_dir)/demo_dynamic: demo.cc @@ -86,11 +90,11 @@ $(build_dir)/test_dynamic: test.cc ${build_dir}/test_graph_c.json ${build_dir}/t $(QUIET)mkdir -p $(@D) $(QUIET)g++ $(PKG_CXXFLAGS) -o $@ test.cc $(BACKTRACE_OBJS) $(BACKTRACE_LDFLAGS) -$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) +$(build_dir)/demo_static: demo_static.c ${build_dir}/bundle_static.o ${build_dir}/model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a ${build_dir}/graph_c.json.c ${build_dir}/params_c.bin.c $(BACKTRACE_OBJS) $(QUIET)mkdir -p $(@D) $(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_CFLAGS) -$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) +$(build_dir)/test_static: test_static.c ${build_dir}/bundle_static.o ${build_dir}/test_model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS) $(QUIET)mkdir -p $(@D) $(QUIET)gcc $(PKG_CFLAGS) -o $@ $^ $(BACKTRACE_LDFLAGS) @@ -124,7 +128,7 @@ $(build_dir)/bundle.so: bundle.cc runtime.cc $(build_dir)/model_cpp.o $(QUIET)mkdir -p $(@D) $(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) -$(build_dir)/bundle_c.so: bundle.c $(build_dir)/model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) +$(build_dir)/bundle_c.so: bundle.c $(build_dir)/model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS) $(QUIET)mkdir -p $(@D) $(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) @@ -132,7 +136,7 @@ $(build_dir)/test_bundle.so: bundle.cc runtime.cc $(build_dir)/test_model_cpp.o $(QUIET)mkdir -p $(@D) $(QUIET)g++ -shared $(PKG_CXXFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) -$(build_dir)/test_bundle_c.so: bundle.c $(build_dir)/test_model_c.o ${build_dir}/crt/graph_runtime/libgraph_runtime.a ${build_dir}/crt/common/libcommon.a $(BACKTRACE_OBJS) +$(build_dir)/test_bundle_c.so: bundle.c $(build_dir)/test_model_c.o ${build_dir}/crt/libgraph_runtime.a ${build_dir}/crt/libcommon.a $(BACKTRACE_OBJS) $(QUIET)mkdir -p $(@D) $(QUIET)gcc -shared $(PKG_CFLAGS) -fvisibility=hidden -o $@ $^ $(PKG_LDFLAGS) $(BACKTRACE_LDFLAGS) $(BACKTRACE_CFLAGS) diff --git a/apps/bundle_deploy/bundle.c b/apps/bundle_deploy/bundle.c index 9ff67eaf3ed9..f11f91ac0531 100644 --- a/apps/bundle_deploy/bundle.c +++ b/apps/bundle_deploy/bundle.c @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -28,6 +29,11 @@ #include "backtrace.h" #endif +#define CRT_MEMORY_NUM_PAGES 16384 +#define CRT_MEMORY_PAGE_SIZE_LOG2 10 + +static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * (1 << CRT_MEMORY_PAGE_SIZE_LOG2)]; + /*! \brief macro to do C API call */ #define TVM_CCALL(func) \ do { \ @@ -56,7 +62,7 @@ TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, ctx.device_id = device_id; // declare pointers - TVM_CCALL(TVMInitializeRuntime()); + TVM_CCALL(TVMInitializeRuntime(g_crt_memory, sizeof(g_crt_memory), CRT_MEMORY_PAGE_SIZE_LOG2)); TVMPackedFunc pf; TVMArgs args = TVMArgs_Create(NULL, NULL, 0); TVM_CCALL(TVMPackedFunc_InitGlobalFunc(&pf, "runtime.SystemLib", &args)); @@ -90,7 +96,14 @@ TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tens TVMGraphRuntime_GetOutput(graph_runtime, index, tensor); } -void __attribute__((noreturn)) TVMPlatformAbort(int error_code) { +void TVMLogf(const char* msg, ...) { + va_list args; + va_start(args, msg); + vfprintf(stderr, msg, args); + va_end(args); +} + +void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error_code) { fprintf(stderr, "TVMPlatformAbort: %d\n", error_code); #ifdef ENABLE_TVM_ABORT_BACKTRACE tvm_platform_abort_backtrace(); diff --git a/apps/bundle_deploy/bundle_static.c b/apps/bundle_deploy/bundle_static.c index 6e189b663a16..d8b949cae8a5 100644 --- a/apps/bundle_deploy/bundle_static.c +++ b/apps/bundle_deploy/bundle_static.c @@ -17,6 +17,7 @@ * under the License. */ +#include #include #include #include @@ -29,6 +30,11 @@ #endif #include "bundle.h" +#define CRT_MEMORY_NUM_PAGES 16384 +#define CRT_MEMORY_PAGE_SIZE_LOG2 10 + +static uint8_t g_crt_memory[CRT_MEMORY_NUM_PAGES * (1 << CRT_MEMORY_PAGE_SIZE_LOG2)]; + /*! \brief macro to do C API call */ #define TVM_CCALL(func) \ do { \ @@ -56,7 +62,7 @@ TVM_DLL void* tvm_runtime_create(const char* json_data, const char* params_data, ctx.device_id = device_id; // get pointers - TVM_CCALL(TVMInitializeRuntime()); + TVM_CCALL(TVMInitializeRuntime(g_crt_memory, sizeof(g_crt_memory), CRT_MEMORY_PAGE_SIZE_LOG2)); TVMPackedFunc pf; TVMArgs args = TVMArgs_Create(NULL, NULL, 0); TVM_CCALL(TVMPackedFunc_InitGlobalFunc(&pf, "runtime.SystemLib", &args)); @@ -91,7 +97,14 @@ TVM_DLL void tvm_runtime_get_output(void* runtime, int32_t index, DLTensor* tens TVMGraphRuntime_GetOutput(graph_runtime, index, tensor); } -void __attribute__((noreturn)) TVMPlatformAbort(int error_code) { +void TVMLogf(const char* msg, ...) { + va_list args; + va_start(args, msg); + vfprintf(stderr, msg, args); + va_end(args); +} + +void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t error_code) { fprintf(stderr, "TVMPlatformAbort: %d\n", error_code); #ifdef ENABLE_TVM_PLATFORM_ABORT_BACKTRACE tvm_platform_abort_backtrace(); diff --git a/apps/bundle_deploy/crt_config/crt_config.h b/apps/bundle_deploy/crt_config/crt_config.h index ac06ecf41ca5..97b6c2103f4b 100644 --- a/apps/bundle_deploy/crt_config/crt_config.h +++ b/apps/bundle_deploy/crt_config/crt_config.h @@ -24,6 +24,9 @@ #ifndef TVM_RUNTIME_CRT_CONFIG_H_ #define TVM_RUNTIME_CRT_CONFIG_H_ +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG + /*! Support low-level debugging in MISRA-C runtime */ #define TVM_CRT_DEBUG 0 @@ -36,29 +39,6 @@ /*! Maximum supported string length in function names */ #define TVM_CRT_STRLEN_NAME 80 -/*! - * \brief Log memory pool size for virtual memory allocation - * - * Here is a list of possible choices: - * * use 16 for 64 KiB memory space - * * use 17 for 128 KiB memory space - * * use 18 for 256 KiB memory space - * * use 19 for 512 KiB memory space - * * use 20 for 1 MiB memory space - * * use 21 for 2 MiB memory space - * * use 22 for 4 MiB memory space - * * use 23 for 8 MiB memory space - * * use 24 for 16 MiB memory space - * * use 25 for 32 MiB memory space - * * use 26 for 64 MiB memory space - * * use 27 for 128 MiB memory space - * * use 28 for 256 MiB memory space - */ -#define TVM_CRT_LOG_VIRT_MEM_SIZE 24 - -/*! \brief Page size for virtual memory allocation */ -#define TVM_CRT_PAGE_BYTES_LOG 12 - /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 diff --git a/apps/bundle_deploy/demo_static.c b/apps/bundle_deploy/demo_static.c index 418ab8ef38a6..b25ad90a2388 100644 --- a/apps/bundle_deploy/demo_static.c +++ b/apps/bundle_deploy/demo_static.c @@ -24,10 +24,14 @@ #include #include -#include "build/graph_c.json.c" -#include "build/params_c.bin.c" #include "bundle.h" +extern const char build_graph_c_json[]; +extern unsigned int build_graph_c_json_len; + +extern const char build_params_c_bin[]; +extern unsigned int build_params_c_bin_len; + #define OUTPUT_LEN 1000 int main(int argc, char** argv) { diff --git a/apps/bundle_deploy/test.cc b/apps/bundle_deploy/test.cc index daadd7a57e2b..c1a7f5d45377 100644 --- a/apps/bundle_deploy/test.cc +++ b/apps/bundle_deploy/test.cc @@ -54,8 +54,9 @@ char* read_all_or_die(const char* name, const char* file_path, size_t* out_size) char* data = (char*)malloc(st.st_size); FILE* fp = fopen(file_path, "rb"); + size_t bytes_to_read = st.st_size; size_t bytes_read = 0; - while (bytes_read < st.st_size) { + while (bytes_read < bytes_to_read) { size_t this_round = fread(data, 1, st.st_size, fp); if (this_round == 0) { if (ferror(fp)) { diff --git a/apps/bundle_deploy/test_static.c b/apps/bundle_deploy/test_static.c index 773ba62140d4..11ca2c44952e 100644 --- a/apps/bundle_deploy/test_static.c +++ b/apps/bundle_deploy/test_static.c @@ -51,7 +51,7 @@ int main(int argc, char** argv) { struct timeval t0, t1, t2, t3, t4, t5; gettimeofday(&t0, 0); - auto* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]); + void* handle = tvm_runtime_create(json_data, params_data, params_size, argv[0]); gettimeofday(&t1, 0); float input_storage[10 * 5]; diff --git a/cmake/modules/StandaloneCrt.cmake b/cmake/modules/StandaloneCrt.cmake index 8783cd757fe1..770c6a82789d 100644 --- a/cmake/modules/StandaloneCrt.cmake +++ b/cmake/modules/StandaloneCrt.cmake @@ -15,10 +15,8 @@ # specific language governing permissions and limitations # under the License. -if(USE_STANDALONE_CRT) - include(ExternalProject) - - message(STATUS "Build with standalone CRT") +if(USE_MICRO) + message(STATUS "Build standalone CRT for micro TVM") file(GLOB crt_srcs src/runtime/crt/**) function(tvm_crt_add_copy_file var src dest) @@ -32,120 +30,145 @@ if(USE_STANDALONE_CRT) set("${var}" "${${var}}" PARENT_SCOPE) endfunction(tvm_crt_add_copy_file) - # Build an isolated build directory, separate from the TVM tree. - file(GLOB_RECURSE crt_srcs - RELATIVE "${CMAKE_SOURCE_DIR}/src/runtime/crt" - "${CMAKE_SOURCE_DIR}/src/runtime/crt/common/*.c" - "${CMAKE_SOURCE_DIR}/src/runtime/crt/graph_runtime/*.c" - "${CMAKE_SOURCE_DIR}/src/runtime/crt/include/*.h") - - foreach(src IN LISTS crt_srcs) - tvm_crt_add_copy_file(host_isolated_build_deps ${CMAKE_SOURCE_DIR}/src/runtime/crt/${src} standalone_crt/${src}) - endforeach() - - file(GLOB_RECURSE crt_headers RELATIVE "${CMAKE_SOURCE_DIR}/include" include/tvm/runtime/crt/*.h) - foreach(hdr IN LISTS crt_headers) - tvm_crt_add_copy_file(host_isolated_build_deps ${CMAKE_SOURCE_DIR}/include/${hdr} standalone_crt/include/${hdr}) - endforeach() - - tvm_crt_add_copy_file(host_isolated_build_deps - ${CMAKE_SOURCE_DIR}/include/tvm/runtime/c_runtime_api.h standalone_crt/include/tvm/runtime/c_runtime_api.h) - tvm_crt_add_copy_file(host_isolated_build_deps - ${CMAKE_SOURCE_DIR}/include/tvm/runtime/c_backend_api.h standalone_crt/include/tvm/runtime/c_backend_api.h) - tvm_crt_add_copy_file(host_isolated_build_deps - ${CMAKE_SOURCE_DIR}/src/runtime/crt/Makefile standalone_crt/Makefile) - - get_filename_component(crt_config_abspath src/runtime/crt/host/crt_config.h ABSOLUTE) - list(APPEND host_isolated_build_deps src/runtime/crt/host/crt_config.h) - add_custom_target(standalone_crt DEPENDS ${host_isolated_build_deps}) - - get_filename_component(host_build_dir_abspath "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt" ABSOLUTE) - - if(${VERBOSE}) - set(make_quiet QUIET=) - else(${VERBOSE}) - set(make_quiet ) - endif(${VERBOSE}) - - ExternalProject_Add(host_standalone_crt - DOWNLOAD_COMMAND "" - SOURCE_DIR standalone_crt - CONFIGURE_COMMAND "" - BUILD_COMMAND make - DLPACK_INCLUDE_DIR=${CMAKE_SOURCE_DIR}/3rdparty/dlpack/include - TVM_INCLUDE_DIR=${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include - CRT_CONFIG=${crt_config_abspath} - BUILD_DIR=${host_build_dir_abspath} all ${make_quiet} - BUILD_IN_SOURCE ON - WORKING_DIRECTORY standalone_crt - COMMENT "Building host CRT runtime" - BUILD_BYPRODUCTS host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a - DEPENDS standalone_crt - INSTALL_COMMAND "" - ) - ExternalProject_Add_StepDependencies(host_standalone_crt build ${host_isolated_build_deps}) -# add_custom_command( -# OUTPUT host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a -# COMMAND make -# DLPACK_INCLUDE_DIR=${CMAKE_SOURCE_DIR}/3rdparty/dlpack/include -# TVM_INCLUDE_DIR=${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include -# CRT_CONFIG=${crt_config_abspath} -# BUILD_DIR=${host_build_dir_abspath} all ${make_quiet} -# WORKING_DIRECTORY standalone_crt -# DEPENDS ${host_isolated_build_deps}) -# add_custom_target(host_standalone_crt DEPENDS host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a) - -# # add_custom_target(host_standalone_crt ALL -# # DEPENDS host_standalone_crt/common/libcommon.a host_standalone_crt/graph_runtime/libgraph_runtime.a) - add_library(host_standalone_crt_common STATIC IMPORTED GLOBAL) - add_dependencies(host_standalone_crt_common host_standalone_crt) - set_target_properties(host_standalone_crt_common PROPERTIES - IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common/libcommon.a" - IMPORTED_OBJECTS "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common/libcommon.a" - PUBLIC_HEADER "${crt_headers}") -# add_dependencies(host_standalone_crt_common host_standalone_crt) -# # ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common/libcommon.a) - - add_library(host_standalone_crt_graph_runtime STATIC IMPORTED GLOBAL) - add_dependencies(host_standalone_crt_graph_runtime host_standalone_crt) - set_target_properties(host_standalone_crt_graph_runtime PROPERTIES - IMPORTED_LOCATION "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime/libgraph_runtime.a" - IMPORTED_OBJECTS "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime/libgraph_runtime.a" - PUBLIC_HEADER "${crt_headers}") -# add_dependencies(host_standalone_crt_graph_runtime host_standalone_crt) -# # ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime/libgraph_runtime.a) - - # Standalone CRT tests - file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*.cc) - find_path(GTEST_INCLUDE_DIR gtest/gtest.h) - find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}") - - # Create the `crttest` target if we can find GTest. If not, we create dummy - # targets that give the user an informative error message. - if(GTEST_INCLUDE_DIR AND GTEST_LIB) - foreach(__srcpath ${TEST_SRCS}) - get_filename_component(__srcname ${__srcpath} NAME) - string(REPLACE ".cc" "" __execname ${__srcname}) - add_executable(${__execname} ${__srcpath}) - list(APPEND TEST_EXECS ${__execname}) - target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host) - target_compile_options(${__execname} PRIVATE -pthread) -# target_link_directories(${__execname} PRIVATE -# ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/common -# ${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/graph_runtime) - target_link_libraries(${__execname} host_standalone_crt_graph_runtime host_standalone_crt_common ${GTEST_LIB} pthread) - set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) - set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) + function(tvm_crt_define_targets) + # Build an isolated build directory, separate from the TVM tree. + set(CRC16_PATH "3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/libraries/crc16") + list(APPEND CRT_FILE_COPY_JOBS + "${CRC16_PATH} *.h -> include *.c -> src/runtime/crt/utvm_rpc_common" + "3rdparty/dlpack/include *.h -> include" + "3rdparty/dmlc-core/include *.h -> include" + "include/tvm/runtime c_*_api.h -> include/tvm/runtime" + "include/tvm/runtime/crt *.h -> include/tvm/runtime/crt" + "src/runtime/crt Makefile -> ." + "src/runtime/crt/include *.h -> include" + "src/runtime/crt/common *.c -> src/runtime/crt/common" + "src/runtime/crt/graph_runtime *.c -> src/runtime/crt/graph_runtime" + "src/runtime/crt/host crt_config.h -> src/runtime/crt/host" + "src/runtime/crt/utvm_rpc_common *.cc -> src/runtime/crt/utvm_rpc_common" + "src/runtime/crt/utvm_rpc_server *.cc -> src/runtime/crt/utvm_rpc_server" + "src/runtime/minrpc *.h -> src/runtime/minrpc" + "src/support generic_arena.h -> src/support" + ) + + set(standalone_crt_base "${CMAKE_CURRENT_BINARY_DIR}/standalone_crt") + + foreach(job_spec IN LISTS CRT_FILE_COPY_JOBS) + string(REPLACE " " ";" job_spec "${job_spec}") + list(LENGTH job_spec job_spec_length) + math(EXPR job_spec_length_mod "${job_spec_length} % 3") + if(NOT "${job_spec_length_mod}" EQUAL 1) + message(FATAL_ERROR "CRT copy job spec list length is ${job_spec_length}; parsed job spec is ${job_spec}") + endif() + math(EXPR job_spec_stop "${job_spec_length} - 3") + + list(GET job_spec 0 job_src_base) + set(job_src_base "${CMAKE_SOURCE_DIR}/${job_src_base}") + foreach(copy_pattern_index RANGE 1 "${job_spec_stop}" 3) + list(GET job_spec ${copy_pattern_index} copy_pattern) + math(EXPR copy_dest_index "${copy_pattern_index} + 2") + list(GET job_spec ${copy_dest_index} copy_dest) + + file(GLOB_RECURSE copy_files + RELATIVE "${job_src_base}" + "${job_src_base}/${copy_pattern}") + list(LENGTH copy_files copy_files_length) + if("${copy_files_length}" EQUAL 0) + message(FATAL_ERROR "CRT copy job matched 0 files: ${job_src_base}/${copy_pattern} -> ${copy_dest}") + endif() + foreach(copy_src IN LISTS copy_files) + get_filename_component(dest_path "${standalone_crt_base}/${copy_dest}/${copy_src}" ABSOLUTE) + tvm_crt_add_copy_file(host_isolated_build_deps ${job_src_base}/${copy_src} ${dest_path}) + endforeach() + endforeach() + endforeach() + + add_custom_target(standalone_crt DEPENDS ${host_isolated_build_deps}) + + get_filename_component(host_build_dir_abspath "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt" ABSOLUTE) + + if(${VERBOSE}) + set(make_quiet QUIET=) + else(${VERBOSE}) + set(make_quiet ) + endif(${VERBOSE}) + + list(APPEND crt_libraries graph_runtime utvm_rpc_server utvm_rpc_common common) # NOTE: listed in link order. + foreach(crt_lib_name IN LISTS crt_libraries) + list(APPEND crt_library_paths "host_standalone_crt/lib${crt_lib_name}.a") endforeach() - add_custom_target(crttest DEPENDS ${TEST_EXECS}) - elseif(NOT GTEST_INCLUDE_DIR) - add_custom_target(crttest - COMMAND echo "Missing Google Test headers in include path" - COMMAND exit 1) - elseif(NOT GTEST_LIB) - add_custom_target(crttest - COMMAND echo "Missing Google Test library" - COMMAND exit 1) + + set(make_common_args + "DLPACK_INCLUDE_DIR=${CMAKE_SOURCE_DIR}/3rdparty/dlpack/include" + "TVM_INCLUDE_DIR=${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include" + "CRT_CONFIG=src/runtime/crt/host/crt_config.h" + "BUILD_DIR=${host_build_dir_abspath}" + "EXTRA_CFLAGS=-fPIC" + "EXTRA_CXXFLAGS=-fPIC" + "EXTRA_LDFLAGS=-fPIC" + "${make_quiet}") + + add_custom_command( + OUTPUT ${crt_library_paths} + COMMAND make ARGS ${make_common_args} clean + COMMAND make ARGS ${make_common_args} all + WORKING_DIRECTORY "${standalone_crt_base}" + DEPENDS standalone_crt ${host_isolated_build_deps}) + + add_custom_target(host_standalone_crt DEPENDS ${crt_library_paths}) + + foreach(crt_lib IN LISTS crt_libraries) + set(cmake_crt_lib_name host_standalone_crt_${crt_lib}) + list(APPEND cmake_crt_libraries ${cmake_crt_lib_name}) + add_library(${cmake_crt_lib_name} STATIC IMPORTED GLOBAL) + set(cmake_crt_lib_path "${CMAKE_CURRENT_BINARY_DIR}/host_standalone_crt/lib${crt_lib}.a") + add_dependencies(${cmake_crt_lib_name} host_standalone_crt "${cmake_crt_lib_path}") + set_target_properties(${cmake_crt_lib_name} PROPERTIES + IMPORTED_LOCATION "${cmake_crt_lib_path}" + IMPORTED_OBJECTS "${cmake_crt_lib_path}" + PUBLIC_HEADER "${crt_headers}") + endforeach() + + # Standalone CRT tests + file(GLOB TEST_SRCS ${CMAKE_SOURCE_DIR}/tests/crt/*_test.cc) + find_path(GTEST_INCLUDE_DIR gtest/gtest.h) + find_library(GTEST_LIB gtest "$ENV{GTEST_LIB}") + + # Create the `crttest` target if we can find GTest. If not, we create dummy + # targets that give the user an informative error message. + if(GTEST_INCLUDE_DIR AND GTEST_LIB) + foreach(__srcpath ${TEST_SRCS}) + get_filename_component(__srcname ${__srcpath} NAME) + string(REPLACE ".cc" "" __execname ${__srcname}) + add_executable(${__execname} ${__srcpath}) + list(APPEND TEST_EXECS ${__execname}) + target_include_directories(${__execname} PUBLIC ${GTEST_INCLUDE_DIR} ${CMAKE_CURRENT_BINARY_DIR}/standalone_crt/include ${CMAKE_SOURCE_DIR}/src/runtime/crt/host) + target_compile_options(${__execname} PRIVATE -pthread) + target_link_libraries(${__execname} ${cmake_crt_libraries} ${GTEST_LIB} pthread) + set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_ALL 1) + set_target_properties(${__execname} PROPERTIES EXCLUDE_FROM_DEFAULT_BUILD 1) + endforeach() + add_custom_target(crttest DEPENDS ${TEST_EXECS}) + elseif(NOT GTEST_INCLUDE_DIR) + add_custom_target(crttest + COMMAND echo "Missing Google Test headers in include path" + COMMAND exit 1) + elseif(NOT GTEST_LIB) + add_custom_target(crttest + COMMAND echo "Missing Google Test library" + COMMAND exit 1) + endif() + + endfunction() + + tvm_crt_define_targets() + + set(TVM_CRT_LINKER_LIB host_standalone_crt_utvm_rpc_common) + if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") + list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,--whole-archive ${TVM_CRT_LINKER_LIB} -Wl,--no-whole-archive) + elseif("${CMAKE_CXX_COMPILER_ID}" MATCHES ".*Clang") + list(APPEND TVM_RUNTIME_LINKER_LIBS -Wl,-force_load $) + else() + list(APPEND TVM_RUNTIME_LINKER_LIBS ${TVM_CRT_LINKER_LIB}) endif() -endif(USE_STANDALONE_CRT) +endif(USE_MICRO) diff --git a/include/tvm/runtime/crt/crt.h b/include/tvm/runtime/crt/crt.h index c2e2af4ca5de..f0b8345c7b49 100644 --- a/include/tvm/runtime/crt/crt.h +++ b/include/tvm/runtime/crt/crt.h @@ -25,6 +25,7 @@ #ifndef TVM_RUNTIME_CRT_CRT_H_ #define TVM_RUNTIME_CRT_CRT_H_ +#include #include #ifdef __cplusplus @@ -33,10 +34,14 @@ extern "C" { /*! * \brief Initialize various data structures used by the rutnime. + * \param memory_pool Pointer to the global memory pool used by the CRT. + * \param memory_pool_size_bytes Size of `memory_pool`, in bytes. + * \param page_size_bytes_log2 log2 of the page size, in bytes. * \return An error code describing the outcome of intialization. Generally, initialization * is only expected to fail due to a misconfiguration. */ -tvm_crt_error_t TVMInitializeRuntime(void); +tvm_crt_error_t TVMInitializeRuntime(uint8_t* memory_pool, size_t memory_pool_size_bytes, + size_t page_size_bytes_log2); #ifdef __cplusplus } // extern "C" diff --git a/include/tvm/runtime/crt/error_codes.h b/include/tvm/runtime/crt/error_codes.h index aae4550a5792..e01304061313 100644 --- a/include/tvm/runtime/crt/error_codes.h +++ b/include/tvm/runtime/crt/error_codes.h @@ -35,7 +35,13 @@ extern "C" { #define DEFINE_TVM_CRT_ERROR(category, code) \ (((category) << TVM_CRT_ERROR_CATEGORY_Pos) | ((code) << TVM_CRT_ERROR_CODE_Pos)) -typedef enum { kTvmErrorCategoryFunctionRegistry = 1 } tvm_crt_error_category_t; +typedef enum { + kTvmErrorCategoryFunctionRegistry = 1, + kTvmErrorCategoryFraming = 2, + kTvmErrorCategoryWriteStream = 3, + kTvmErrorCategorySession = 4, + kTvmErrorCategoryPlatform = 5, +} tvm_crt_error_category_t; typedef enum { kTvmErrorNoError = 0, @@ -46,6 +52,30 @@ typedef enum { kTvmErrorFunctionRegistryFull = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionRegistry, 2), kTvmErrorFunctionAlreadyDefined = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionRegistry, 3), kTvmErrorBufferTooSmall = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFunctionRegistry, 4), + + // Framing + kTvmErrorFramingInvalidState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 0), + kTvmErrorFramingShortPacket = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 1), + kTvmErrorFramingInvalidEscape = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 2), + kTvmErrorFramingPayloadOverflow = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 3), + kTvmErrorFramingPayloadIncomplete = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryFraming, 4), + + // Write stream + kTvmErrorWriteStreamShortWrite = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryWriteStream, 0), + kTvmErrorWriteStreamLongWrite = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryWriteStream, 1), + + // Session + kTvmErrorSessionInvalidState = DEFINE_TVM_CRT_ERROR(kTvmErrorCategorySession, 0), + kTvmErrorSessionReceiveBufferBusy = DEFINE_TVM_CRT_ERROR(kTvmErrorCategorySession, 1), + kTvmErrorSessionReceiveBufferShortWrite = DEFINE_TVM_CRT_ERROR(kTvmErrorCategorySession, 2), + + // Platform + kTvmErrorPlatformCheckFailure = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 0), + kTvmErrorPlatformMemoryManagerInitialized = DEFINE_TVM_CRT_ERROR(kTvmErrorCategoryPlatform, 1), + + // System errors are always negative integers; this mask indicates presence of a system error. + // Cast tvm_crt_error_t to a signed integer to interpret the negative error code. + kTvmErrorSystemErrorMask = (1 << (sizeof(int) * 4 - 1)), } tvm_crt_error_t; #ifdef __cplusplus diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/common/logging.h b/include/tvm/runtime/crt/logging.h similarity index 52% rename from src/runtime/crt/include/tvm/runtime/crt/internal/common/logging.h rename to include/tvm/runtime/crt/logging.h index 17fbe32a1f2c..e955739ee80e 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/common/logging.h +++ b/include/tvm/runtime/crt/logging.h @@ -18,31 +18,55 @@ */ /*! - * \file runtime/crt/include/tvm/runtime/crt/internal/common/logging.h + * \file runtime/crt/logging.h * \brief A replacement of the dmlc logging system that avoids * the usage of GLOG and C++ headers */ -#ifndef TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_LOGGING_H_ -#define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_LOGGING_H_ +#ifndef TVM_RUNTIME_CRT_LOGGING_H_ +#define TVM_RUNTIME_CRT_LOGGING_H_ + +#include + +#define TVM_CRT_LOG_LEVEL_DEBUG 3 +#define TVM_CRT_LOG_LEVEL_INFO 2 +#define TVM_CRT_LOG_LEVEL_WARN 1 +#define TVM_CRT_LOG_LEVEL_ERROR 0 + +#ifdef __cplusplus +extern "C" { +#endif + +void __attribute__((format(printf, 1, 2))) TVMLogf(const char* fmt, ...); + +#define LOG(level, x, ...) \ + if (TVM_CRT_LOG_LEVEL >= level) { \ + TVMLogf(x, ##__VA_ARGS__); \ + } + +#define LOG_ERROR(x, ...) LOG(TVM_CRT_LOG_LEVEL_ERROR, x, ##__VA_ARGS__) +#define LOG_WARN(x, ...) LOG(TVM_CRT_LOG_LEVEL_WARN, x, ##__VA_ARGS__) +#define LOG_INFO(x, ...) LOG(TVM_CRT_LOG_LEVEL_INFO, x, ##__VA_ARGS__) +#define LOG_DEBUG(x, ...) LOG(TVM_CRT_LOG_LEVEL_DEBUG, x, ##__VA_ARGS__) #ifndef CHECK -#define CHECK(x) \ - do { \ - if (!(x)) { \ - fprintf(stderr, "Check failed: %s\n", #x); \ - exit(-1); \ - } \ +#define CHECK(x) \ + do { \ + if (!(x)) { \ + LOG_ERROR(__FILE__ ":%d: Check failed: %s\n", __LINE__, #x); \ + TVMPlatformAbort(kTvmErrorPlatformCheckFailure); \ + } \ } while (0) #endif #ifndef CHECK_BINARY_OP -#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ - do { \ - if (!(x op y)) { \ - fprintf(stderr, "Check failed: %s %s %s: " fmt "\n", #x, #op, #y, ##__VA_ARGS__); \ - exit(-1); \ - } \ +#define CHECK_BINARY_OP(op, x, y, fmt, ...) \ + do { \ + if (!(x op y)) { \ + LOG_ERROR(__FILE__ ":%d: Check failed: %s %s %s: " fmt "\n", __LINE__, #x, #op, #y, \ + ##__VA_ARGS__); \ + TVMPlatformAbort(kTvmErrorPlatformCheckFailure); \ + } \ } while (0) #endif @@ -70,4 +94,8 @@ #define CHECK_NE(x, y, fmt, ...) CHECK_BINARY_OP(!=, x, y, fmt, ##__VA_ARGS__) #endif -#endif // TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_LOGGING_H_ +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TVM_RUNTIME_CRT_LOGGING_H_ diff --git a/include/tvm/runtime/crt/platform.h b/include/tvm/runtime/crt/platform.h index 6897a53cfc1b..782060dfd000 100644 --- a/include/tvm/runtime/crt/platform.h +++ b/include/tvm/runtime/crt/platform.h @@ -25,6 +25,8 @@ #ifndef TVM_RUNTIME_CRT_PLATFORM_H_ #define TVM_RUNTIME_CRT_PLATFORM_H_ +#include + #ifdef __cplusplus extern "C" { #endif @@ -35,7 +37,7 @@ extern "C" { * * \param code An error code. */ -void __attribute__((noreturn)) TVMPlatformAbort(int code); +void __attribute__((noreturn)) TVMPlatformAbort(tvm_crt_error_t code); #ifdef __cplusplus } // extern "C" diff --git a/include/tvm/runtime/crt/rpc_common/frame_buffer.h b/include/tvm/runtime/crt/rpc_common/frame_buffer.h new file mode 100644 index 000000000000..0d264e313a1d --- /dev/null +++ b/include/tvm/runtime/crt/rpc_common/frame_buffer.h @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/crt/rpc_common/frame_buffer.h + * \brief Defines a buffer for use by the RPC framing layer. + */ + +#ifndef TVM_RUNTIME_CRT_RPC_COMMON_FRAME_BUFFER_H_ +#define TVM_RUNTIME_CRT_RPC_COMMON_FRAME_BUFFER_H_ + +#include +#include + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +class FrameBuffer { + public: + FrameBuffer(uint8_t* data, size_t data_size_bytes) + : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0}, read_cursor_{0} {} + + size_t Write(const uint8_t* data, size_t data_size_bytes); + + size_t Read(uint8_t* data, size_t data_size_bytes); + + size_t Peek(uint8_t* data, size_t data_size_bytes); + + void Clear(); + + size_t ReadAvailable() const { return num_valid_bytes_ - read_cursor_; } + + size_t Size() const { return num_valid_bytes_; } + + private: + /*! \brief pointer to data buffer. */ + uint8_t* data_; + + /*! \brief The total number of bytes available in data_. Always a power of 2. */ + size_t capacity_; + + /*! \brief index into data_ of the next potentially-available byte in the buffer. + * The byte is available when tail_ != data_ + capacity_. + */ + size_t num_valid_bytes_; + + /*! \brief Read cursor position. */ + size_t read_cursor_; +}; + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CRT_RPC_COMMON_FRAME_BUFFER_H_ diff --git a/include/tvm/runtime/crt/rpc_common/framing.h b/include/tvm/runtime/crt/rpc_common/framing.h new file mode 100644 index 000000000000..a6b9cd349088 --- /dev/null +++ b/include/tvm/runtime/crt/rpc_common/framing.h @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file framing.h + * \brief Framing for RPC. + */ + +#ifndef TVM_RUNTIME_CRT_RPC_COMMON_FRAMING_H_ +#define TVM_RUNTIME_CRT_RPC_COMMON_FRAMING_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +enum class Escape : uint8_t { kEscapeStart = 0xff, kEscapeNop = 0xfe, kPacketStart = 0xfd }; + +class PacketFieldSizeBytes { + public: + static constexpr const size_t kPayloadLength = sizeof(uint32_t); + static constexpr const size_t kCrc = sizeof(uint16_t); +}; + +class Unframer { + public: + explicit Unframer(WriteStream* stream) + : stream_{stream}, + state_{State::kFindPacketStart}, + saw_escape_start_{false}, + num_buffer_bytes_valid_{0} {} + + /*! + * \brief Push data into unframer and try to decode one packet. + * + * This function will return when exactly one packet has been decoded. It may not consume all of + * `data` in this case, and valid bytes may remain at the end of data. + * + * \param data The new data to unframe and send downstream. + * \param data_size_bytes The number of valid bytes in data. + * \param bytes_consumed Pointer written with the number of bytes consumed from data. + * \return + * - kTvmErrorNoError when successful -- continue writing data. + * - kTvmErrorFramingInvalidState when the Unframer was in or enters an invalid state + * (probably indicates memory corruption). + * - kTvmErrorFramingShortPacket when a new packet started before the current one ended. + * - kTvmErrorFramingInvalidEscape when an invalid escape sequence was seen + */ + tvm_crt_error_t Write(const uint8_t* data, size_t data_size_bytes, size_t* bytes_consumed); + + /*! \brief Reset unframer to initial state. */ + void Reset(); + + /*! \brief Return an underestimate of the number of bytes needed from the wire. */ + size_t BytesNeeded(); + + private: + tvm_crt_error_t FindPacketStart(); + tvm_crt_error_t FindPacketLength(); + tvm_crt_error_t FindPacketCrc(); + tvm_crt_error_t FindCrcEnd(); + + bool IsBufferFull(size_t buffer_full_bytes) { + return num_buffer_bytes_valid_ >= buffer_full_bytes; + } + + /*! \brief Consume input into buffer_ until buffer_ has buffer_full_bytes. */ + tvm_crt_error_t AddToBuffer(size_t buffer_full_bytes, bool update_crc); + + void ClearBuffer(); + + /*! \brief Unescape and consume input bytes, storing into buffer. + * + * \param buffer A buffer to fill with consumed, unescaped bytes. + * \param buffer_size_bytes Size of buffer, in bytes. + * \param bytes_filled A pointer to an accumulator to which is added the number of bytes written + * to `buffer`. + * \param update_crc true when the CRC should be updated with the escaped bytes. + * \return + * - kTvmErrorNoError if successful + * - kTvmErrorFramingShortPacket if a start-of-packet escape code was encountered. If so, + * *bytes_filled indicates the number of bytes before the Escape::kEscapeStart byte. + * - kTvmErrorFramingInvalidEscape if an invalid escape sequence was seen. + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns 0. + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns an invalid positive number. + * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the + * WriteStream's Write() function. + */ + tvm_crt_error_t ConsumeInput(uint8_t* buffer, size_t buffer_size_bytes, size_t* bytes_filled, + bool update_crc); + + WriteStream* stream_; + + enum class State : uint8_t { + kFindPacketStart = 0, + kFindPacketLength = 1, + kFindPacketCrc = 2, + kFindCrcEnd = 3, + }; + State state_; + + const uint8_t* input_; + size_t input_size_bytes_; + + bool saw_escape_start_; + + /*! \brief unframe buffer, sized to the longest framing field. */ + uint8_t buffer_[128]; + + /*! \brief number of bytes in buffer that are currently valid. */ + size_t num_buffer_bytes_valid_; + + /*! \brief number of payload bytes left to write before the CRC begins. */ + size_t num_payload_bytes_remaining_; + + /*! \brief Running CRC value. */ + uint16_t crc_; +}; + +class Framer { + public: + typedef ssize_t (*WriteFunc)(const uint8_t* data, size_t data_size_bytes); + + explicit Framer(WriteStream* stream) + : stream_{stream}, state_{State::kReset}, num_payload_bytes_remaining_{0} {} + + /*! \brief Frame and write a full packet. + * \param payload The entire packet payload. + * \param payload_size_bytes Number of bytes in the packet. + * \return + * - kTvmErrorNoError when no error occurs + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns 0. + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns an invalid positive number. + * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the + * WriteStream's Write() function. + */ + tvm_crt_error_t Write(const uint8_t* payload, size_t payload_size_bytes); + + /*! \brief Start framing and writing a new packet to the wire. + * + * When transmitting payloads that are too large to be buffered, call this function first to send + * the packet header and length fields. + * + * \param payload_size_bytes Number of payload bytes included as part of this packet. + * \return + * - kTvmErrorNoError when no error occurs + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns 0. + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns an invalid positive number. + * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the + * WriteStream's Write() function. + */ + tvm_crt_error_t StartPacket(size_t payload_size_bytes); + + /*! \brief Write payload data to the wire. + * + * When transmitting payloads that are too large to be buffered, call this function after calling + * StartPacket to escape and transmit framed payloads. This function can be called multiple times + * for a single packet. + * + * \param payload_chunk A piece of the packet payload. + * \param payload_chunk_size_bytes Number of valid bytes in payload_chunk. + * \return + * - kTvmErrorNoError when no error occurs + * - kTvmErrorFramingInvalidState when StartPacket() has not been called. + * - kTvmErrorFramingPayloadOverflow when more bytes were requested to be written than were + * declared in the payload_size_bytes parameter given to StartPacket(). + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns 0. + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns an invalid positive number. + * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the + * WriteStream's Write() function. + */ + tvm_crt_error_t WritePayloadChunk(const uint8_t* payload_chunk, size_t payload_chunk_size_bytes); + + /* \brief Finish writing one packet by sending the CRC. + * + * When transmitting paylaods that are too large to be buffered, call this function after sending + * the entire payload using WritePayloadChunk. + * + * \return + * - kTvmErrorNoError when no error occurs + * - kTvmErrorFramingInvalidState when StartPacket() has not been called. + * - kTvmErrorFramingPayloadIncomplete when less bytes were written using WritePayloadChunk() + * than were declared in the payload_size_bytes parameter given to StartPacket(). + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns 0. + * - kTvmErrorWriteStreamShortWrite if the WriteStream passed to constructor's Write() + * function returns an invalid positive number. + * - Any negative value (i.e. with bits in kTvmErrorSystemErrorMask set) returned by the + * WriteStream's Write() function. + */ + tvm_crt_error_t FinishPacket(); + + /* \brief Reset state of the Framer. */ + void Reset(); + + private: + /*! \brief Maximum size of stack-based buffer. */ + static constexpr const size_t kMaxStackBufferSizeBytes = 128; + + enum class State : uint8_t { + /*! \brief State entered at construction time or after write error, before first packet sent. */ + kReset = 0, + + /*! \brief State entered after a packet has successfully finished transmitting. */ + kIdle = 1, + + /*! \brief State entered when a packet payload or CRC needs to be transmitted. */ + kTransmitPacketPayload = 2, + }; + + /*! + * \brief Escape data and write the result to wire, and update crc_. + * + * \param data Unescaped data to write. + * \param data_size_bytes Number of valid bytes in data. + * \param escape true if escaping should be applied. + * \param update_crc true if escaping should be applied. + * \return kTvmErrorNoError on success, negative value on error. + */ + tvm_crt_error_t WriteAndCrc(const uint8_t* data, size_t data_size_bytes, bool escape, + bool update_crc); + + /*! \brief Called to write framed data to the transport. */ + WriteStream* stream_; + + /*! \brief State fo the Framer. */ + State state_; + + /*! \brief When state_ == kTransmitPacketPayload, number of payload bytes left to transmit. */ + size_t num_payload_bytes_remaining_; + + /*! \brief Running CRC value. */ + uint16_t crc_; +}; + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CRT_RPC_COMMON_FRAMING_H_ diff --git a/include/tvm/runtime/crt/rpc_common/session.h b/include/tvm/runtime/crt/rpc_common/session.h new file mode 100644 index 000000000000..9e6a9f380554 --- /dev/null +++ b/include/tvm/runtime/crt/rpc_common/session.h @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file session.h + * \brief RPC Session + */ + +#ifndef TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_ +#define TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_ + +#include +#include +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +enum class MessageType : uint8_t { + kStartSessionInit = 0x00, + kStartSessionReply = 0x01, + kTerminateSession = 0x02, + kLog = 0x03, + kNormal = 0x10, +}; + +typedef struct SessionHeader { + uint16_t session_id; + MessageType message_type; +} __attribute__((packed)) SessionHeader; + +/*! + * \brief CRT communication session management class. + * Assumes the following properties provided by the underlying transport: + * - in-order delivery. + * - reliable delivery. + * + * Specifically, designed for use with UARTs. Will probably work over semihosting, USB, and TCP; + * will probably not work reliably enough over UDP. + */ +class Session { + public: + /*! \brief Callback invoked when a full message is received. + * + * This function is called in the following situations: + * - When a new session is established (this typically indicates the remote end reset). + * In this case, buf is NULL. + * - When a log message or normal traffic is received. In this case, buf points to a + * valid buffer containing the message content. + * + * \param context The value of `message_received_func_context` passed to the constructor. + * \param message_type The type of session message received. Currently, this is always + * either kNormal or kLog. + * \param buf When message_type is not kStartSessionMessage, a FrameBuffer whose read cursor is + * at the first byte of the message payload. Otherwise, NULL. + */ + typedef void (*MessageReceivedFunc)(void* context, MessageType message_type, FrameBuffer* buf); + + /*! \brief An invalid nonce value that typically indicates an unknown nonce. */ + static constexpr const uint8_t kInvalidNonce = 0; + + Session(uint8_t initial_session_nonce, Framer* framer, FrameBuffer* receive_buffer, + MessageReceivedFunc message_received_func, void* message_received_func_context) + : local_nonce_{initial_session_nonce}, + session_id_{0}, + state_{State::kReset}, + receiver_{this}, + framer_{framer}, + receive_buffer_{receive_buffer}, + receive_buffer_has_complete_message_{false}, + message_received_func_{message_received_func}, + message_received_func_context_{message_received_func_context} { + // Session can be used for system startup logging, before the RPC server is instantiated. In + // this case, allow receive_buffer_ to be nullptr. The instantiator agrees not to use + // Receiver(). + if (receive_buffer_ != nullptr) { + receive_buffer_->Clear(); + } + } + + /*! + * \brief Send a session terminate message, usually done at startup to interrupt a hanging remote. + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t Initialize(); + + /*! + * \brief Terminate any previously-established session. + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t TerminateSession(); + + /*! + * \brief Start a new session regardless of state. Sends kStartSessionMessage. + * + * Generally speaking, this function should be called once per device reset by exactly one side + * in the system. No traffic can flow until this function is called. + * + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t StartSession(); + + /*! + * \brief Obtain a WriteStream implementation for use by the framing layer. + * \return A WriteStream to which received data should be written. Owned by this class. + */ + WriteStream* Receiver() { return &receiver_; } + + /*! + * \brief Send a full message including header, payload, and CRC footer. + * \param message_type One of MessageType; distinguishes the type of traffic at the session layer. + * \param message_data The data contained in the message. + * \param message_size_bytes The number of valid bytes in message_data. + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t SendMessage(MessageType message_type, const uint8_t* message_data, + size_t message_size_bytes); + + /*! + * \brief Send the framing and session layer headers. + * + * This function allows messages to be sent in pieces. + * + * \param message_type One of MessageType; distinguishes the type of traffic at the session layer. + * \param message_size_bytes The size of the message body, in bytes. Excludes the framing and + * session layer headers. \return 0 on success, negative error code on failure. + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t StartMessage(MessageType message_type, size_t message_size_bytes); + + /*! + * \brief Send a part of the message body. + * + * This function allows messages to be sent in pieces. + * + * \param chunk_data The data contained in this message body chunk. + * \param chunk_size_bytes The number of valid bytes in chunk_data. + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t SendBodyChunk(const uint8_t* chunk_data, size_t chunk_size_bytes); + + /*! + * \brief Finish sending the message by sending the framing layer footer. + * \return kTvmErrorNoError on success, or an error code otherwise. + */ + tvm_crt_error_t FinishMessage(); + + /*! \brief Returns true if the session is in the established state. */ + bool IsEstablished() const { return state_ == State::kSessionEstablished; } + + /*! + * \brief Clear the receive buffer and prepare to receive next message. + * + * Call this function after MessageReceivedFunc is invoked. Any SessionReceiver::Write() calls + * made will return errors until this function is called to prevent them from corrupting the + * valid message in the receive buffer. + */ + void ClearReceiveBuffer(); + + /*! \brief A version number used to check compatibility of the remote session implementation. */ + static const constexpr uint8_t kVersion = 0x01; + + private: + class SessionReceiver : public WriteStream { + public: + explicit SessionReceiver(Session* session) : session_{session} {} + virtual ~SessionReceiver() {} + + ssize_t Write(const uint8_t* data, size_t data_size_bytes) override; + void PacketDone(bool is_valid) override; + + private: + void operator delete(void*) noexcept {} // NOLINT(readability/casting) + Session* session_; + }; + + enum class State : uint8_t { + kReset = 0, + kNoSessionEstablished = 1, + kStartSessionSent = 2, + kSessionEstablished = 3, + }; + + void RegenerateNonce(); + + tvm_crt_error_t SendInternal(MessageType message_type, const uint8_t* message_data, + size_t message_size_bytes); + + void SendSessionStartReply(const SessionHeader& header); + + void ProcessStartSessionInit(const SessionHeader& header); + + void ProcessStartSessionReply(const SessionHeader& header); + + void OnSessionEstablishedMessage(); + + void OnSessionTerminatedMessage(); + + void SetSessionId(uint8_t initiator_nonce, uint8_t responder_nonce) { + session_id_ = initiator_nonce | (((uint16_t)responder_nonce) << 8); + } + + uint8_t InitiatorNonce(uint16_t session_id) { return session_id & 0xff; } + + uint8_t ResponderNonce(uint16_t session_id) { return (session_id >> 8) & 0xff; } + + uint8_t local_nonce_; + uint16_t session_id_; + State state_; + SessionReceiver receiver_; + Framer* framer_; + FrameBuffer* receive_buffer_; + bool receive_buffer_has_complete_message_; + MessageReceivedFunc message_received_func_; + void* message_received_func_context_; +}; + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CRT_RPC_COMMON_SESSION_H_ diff --git a/src/runtime/micro/device/riscv_spike/utvm_timer.c b/include/tvm/runtime/crt/rpc_common/write_stream.h similarity index 53% rename from src/runtime/micro/device/riscv_spike/utvm_timer.c rename to include/tvm/runtime/crt/rpc_common/write_stream.h index 78c811979d43..cdc579585993 100644 --- a/src/runtime/micro/device/riscv_spike/utvm_timer.c +++ b/include/tvm/runtime/crt/rpc_common/write_stream.h @@ -18,23 +18,33 @@ */ /*! - * \file utvm_timer.c - * \brief uTVM timer API stubs for Spike + * \file framing.h + * \brief Framing for RPC. */ -#ifdef __cplusplus -extern "C" { -#endif +#ifndef TVM_RUNTIME_CRT_RPC_COMMON_WRITE_STREAM_H_ +#define TVM_RUNTIME_CRT_RPC_COMMON_WRITE_STREAM_H_ -#include "utvm_runtime.h" +#include +#include +#include +#include -int32_t UTVMTimerStart() { return UTVM_ERR_OK; } +namespace tvm { +namespace runtime { +namespace micro_rpc { -uint32_t UTVMTimerStop(int32_t* err) { - *err = UTVM_ERR_OK; - return 0; -} +class WriteStream { + public: + virtual ~WriteStream(); + virtual ssize_t Write(const uint8_t* data, size_t data_size_bytes) = 0; + virtual void PacketDone(bool is_valid) = 0; -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif + tvm_crt_error_t WriteAll(uint8_t* data, size_t data_size_bytes, size_t* bytes_consumed); +}; + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_CRT_RPC_COMMON_WRITE_STREAM_H_ diff --git a/include/tvm/runtime/crt/utvm_rpc_server.h b/include/tvm/runtime/crt/utvm_rpc_server.h new file mode 100644 index 000000000000..314463ac8652 --- /dev/null +++ b/include/tvm/runtime/crt/utvm_rpc_server.h @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utvm_rpc_server.h + * \brief MicroTVM RPC Server + */ + +#ifndef TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_ +#define TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_ + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief TVM RPC channel write function. + * + * Tries to write `num_bytes` from `data` to the underlying channel. + * \param data Pointer to data to write. + * \param num_bytes Number of bytes avaiable in data. + * \return The number of bytes written. + */ +typedef ssize_t (*utvm_rpc_channel_write_t)(void* context, const uint8_t* data, size_t num_bytes); + +/*! \brief Opaque pointer type to TVM RPC Server. */ +typedef void* utvm_rpc_server_t; + +/*! \brief Initialize the TVM RPC Server. + * + * Call this on device startup before calling anyother utvm_rpc_server_ functions. + * + * \param memory A memory block used by the runtime as dynamic memory, primarily to allocate + * tensors. + * \param memory_size_bytes Size of the memory block, in bytes. Should be a multiple of + * (1 << page_size_bytes_log2) + * \param page_size_bytes_log2 Log2 of the size of each memory page. The internal allocator + * allocates one page at a time; more pages reduces waste but + * increases overhead. + * \param write_func A callback function invoked by the TVM RPC Server to write data back to the + * host. Internally, the TVM RPC Server will block until all data in a reply + * packet has been written. + * \param write_func_ctx An opaque pointer passed to write_func when it is called. + * \return A pointer to the TVM RPC Server. The pointer is allocated in the same memory space as + * the TVM workspace. + */ +utvm_rpc_server_t UTvmRpcServerInit(uint8_t* memory, size_t memory_size_bytes, + size_t page_size_bytes_log2, + utvm_rpc_channel_write_t write_func, void* write_func_ctx); + +/*! \brief Copy received data into an internal buffer for processing. + * + * Currently only handles 1 byte of data. In the future, the goal of this function is to be safe to + * invoke from an ISR. At that time, this function will just append to an internal buffer. + * + * \param server The TVM RPC Server pointer. + * \param byte The received byte of data. + * \return The number of bytes copied to the internal buffer. May be less than data_size_bytes when + * the internal buffer fills. + */ +size_t UTvmRpcServerReceiveByte(utvm_rpc_server_t server, uint8_t byte); + +/*! \brief Perform normal processing of received data. + * + * \param server The TVM RPC Server pointer. + * \return true while the server is still running. false when it shuts down gracefully. + */ +bool UTvmRpcServerLoop(utvm_rpc_server_t server); + +#ifdef __cplusplus +} +#endif + +#endif // TVM_RUNTIME_CRT_UTVM_RPC_SERVER_H_ diff --git a/python/tvm/micro/__init__.py b/python/tvm/micro/__init__.py index 7c1389cc4eef..30f81e76f697 100644 --- a/python/tvm/micro/__init__.py +++ b/python/tvm/micro/__init__.py @@ -16,8 +16,12 @@ # under the License. """MicroTVM module for bare-metal backends""" -from ..contrib import binutil -from .base import DEVICE_SECTIONS -from .base import Session, create_micro_mod, cross_compiler, LibType -from .base import get_micro_host_driven_dir, get_micro_device_dir -from . import device +from .artifact import Artifact +from .build import build_static_runtime, default_options, TVM_ROOT_DIR +from .build import CRT_ROOT_DIR, Workspace +from .compiler import Compiler, DefaultCompiler, Flasher +from .debugger import GdbRemoteDebugger +from .micro_library import MicroLibrary +from .micro_binary import MicroBinary +from .session import create_local_graph_runtime, Session +from .transport import TransportLogger, DebugWrapperTransport, SubprocessTransport diff --git a/python/tvm/micro/artifact.py b/python/tvm/micro/artifact.py new file mode 100644 index 000000000000..5f887db87dec --- /dev/null +++ b/python/tvm/micro/artifact.py @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""""Defines abstractions around compiler artifacts produced in compiling micro TVM binaries.""" + +import io +import os +import json +import shutil +import tarfile + + +class ArtifactFileNotFoundError(Exception): + """Raised when an artifact file cannot be found on disk.""" + + +class ArtifactBadSymlinkError(Exception): + """Raised when an artifact symlink points outside the base directory.""" + + +class ArtifactBadArchiveError(Exception): + """Raised when an artifact archive is malformed.""" + + +class Artifact: + """Describes a compiler artifact and defines common logic to archive it for transport.""" + + # A version number written to the archive. + ENCODING_VERSION = 1 + + # A unique string identifying the type of artifact in an archive. Subclasses must redefine this + # variable. + ARTIFACT_TYPE = None + + @classmethod + def unarchive(cls, archive_path, base_dir): + """Unarchive an artifact into base_dir. + + Parameters + ---------- + archive_path : str + Path to the archive file. + base_dir : str + Path to a non-existent, empty directory under which the artifact will live. + + Returns + ------- + Artifact : + The unarchived artifact. + """ + if os.path.exists(base_dir): + raise ValueError(f'base_dir exists: {base_dir}') + + base_dir_parent, base_dir_name = os.path.split(base_dir) + temp_dir = os.path.join(base_dir_parent, f'__tvm__{base_dir_name}') + os.mkdir(temp_dir) + try: + with tarfile.open(archive_path) as tar_f: + tar_f.extractall(temp_dir) + + temp_dir_contents = os.listdir(temp_dir) + if len(temp_dir_contents) != 1: + raise ArtifactBadArchiveError( + 'Expected exactly 1 subdirectory at root of archive, got ' + f'{temp_dir_contents!r}') + + metadata_path = os.path.join(temp_dir, temp_dir_contents[0], 'metadata.json') + if not metadata_path: + raise ArtifactBadArchiveError('No metadata.json found in archive') + + with open(metadata_path) as metadata_f: + metadata = json.load(metadata_f) + + version = metadata.get('version') + if version != cls.ENCODING_VERSION: + raise ArtifactBadArchiveError( + f'archive version: expect {cls.EXPECTED_VERSION}, found {version}') + + os.rename(os.path.join(temp_dir, temp_dir_contents[0]), base_dir) + + artifact_cls = cls + for sub_cls in cls.__subclasses__(): + if (sub_cls.ARTIFACT_TYPE is not None and + sub_cls.ARTIFACT_TYPE == metadata.get('artifact_type')): + artifact_cls = sub_cls + break + + return artifact_cls.from_unarchived( + base_dir, metadata['labelled_files'], metadata['metadata']) + finally: + shutil.rmtree(temp_dir) + + @classmethod + def from_unarchived(cls, base_dir, labelled_files, metadata): + return cls(base_dir, labelled_files, metadata) + + def __init__(self, base_dir, labelled_files, metadata): + """Create a new artifact. + + Parameters + ---------- + base_dir : str + The path to a directory on disk which contains all the files in this artifact. + labelled_files : Dict[str, str] + A dict mapping a file label to the relative paths of the files that carry that label. + metadata : Dict + A dict containing artitrary JSON-serializable key-value data describing the artifact. + """ + self.base_dir = os.path.realpath(base_dir) + self.labelled_files = labelled_files + self.metadata = metadata + + for label, files in labelled_files.items(): + for f in files: + f_path = os.path.join(self.base_dir, f) + if not os.path.lexists(f_path): + raise ArtifactFileNotFoundError(f'{f} (label {label}): not found at {f_path}') + + if os.path.islink(f_path): + link_path = os.path.readlink(f_path) + if os.path.isabs(link_path): + link_fullpath = link_path + else: + link_fullpath = os.path.join(os.path.dirname(f_path), link_path) + + link_fullpath = os.path.realpath(link_fullpath) + if not link_fullpath.startswith(self.base_dir): + raise ArtifactBadSymlinkError( + f'{f} (label {label}): symlink points outside artifact tree') + + def abspath(self, rel_path): + """Return absolute path to the member with the given relative path.""" + return os.path.join(self.base_dir, rel_path) + + def label(self, label): + """Return a list of relative paths to files with the given label.""" + return self.labelled_files[label] + + def label_abspath(self, label): + return [self.abspath(p) for p in self.labelled_files[label]] + + def archive(self, archive_path): + """Create a relocatable tar archive of the artifacts. + + Parameters + ---------- + archive_path : str + Path to the tar file to create. Or, path to a directory, under which a tar file will be + created named {base_dir}.tar. + + Returns + ------- + str : + The value of archive_path, after potentially making the computation describe above. + """ + if os.path.isdir(archive_path): + archive_path = os.path.join(archive_path, f'{os.path.basename(self.base_dir)}.tar') + + archive_name = os.path.splitext(os.path.basename(archive_path))[0] + with tarfile.open(archive_path, 'w') as tar_f: + def _add_file(name, data, f_type): + tar_info = tarfile.TarInfo(name=name) + tar_info.type = f_type + data_bytes = bytes(data, 'utf-8') + tar_info.size = len(data) + tar_f.addfile(tar_info, io.BytesIO(data_bytes)) + + _add_file(f'{archive_name}/metadata.json', + json.dumps({'version': self.ENCODING_VERSION, + 'labelled_files': self.labelled_files, + 'metadata': self.metadata}, + indent=2, + sort_keys=True), + tarfile.REGTYPE) + for dir_path, _, files in os.walk(self.base_dir): + for f in files: + file_path = os.path.join(dir_path, f) + archive_file_path = os.path.join( + archive_name, os.path.relpath(file_path, self.base_dir)) + if not os.path.islink(file_path): + tar_f.add(file_path, archive_file_path, recursive=False) + continue + + link_path = os.readlink(file_path) + if not os.path.isabs(link_path): + tar_f.add(file_path, archive_file_path, recursive=False) + continue + + relpath = os.path.relpath(link_path, os.path.dirname(file_path)) + _add_file(archive_file_path, relpath, tarfile.LNKTYPE) + + return archive_path diff --git a/python/tvm/micro/base.py b/python/tvm/micro/base.py index 57e175600990..86d3fc9fa195 100644 --- a/python/tvm/micro/base.py +++ b/python/tvm/micro/base.py @@ -16,329 +16,7 @@ # under the License. """Base definitions for MicroTVM""" -from __future__ import absolute_import - -import os -import re -import sys -from enum import Enum - import tvm import tvm._ffi -from tvm.contrib import util as _util -from tvm.contrib import cc as _cc - -# all sections that comprise a device's memory layout, in order from lowest -# starting address to highest -DEVICE_SECTIONS = [ - "text", - "rodata", - "data", - "bss", - "args", - "heap", - "workspace", - "stack", -] - - -class LibType(Enum): - """Enumeration of library types that can be compiled and loaded onto a device""" - - # library to be used as a MicroTVM runtime - RUNTIME = 0 - # library to be used as an operator - OPERATOR = 1 - - -class Session: - """MicroTVM Device Session - - Parameters - ---------- - config : dict - configuration for this session (as generated by - `tvm.micro.device.host.default_config()`, for example) - - Example - -------- - .. code-block:: python - - c_mod = ... # some module generated with "c" as the target - dev_config = micro.device.arm.stm32f746xx.default_config('127.0.0.1', 6666) - with tvm.micro.Session(dev_config) as sess: - micro_mod = sess.create_micro_mod(c_mod) - """ - - def __init__(self, config): - self._check_system() - # TODO(weberlo): add config validation - - # grab a binutil instance from the ID in the config - dev_funcs = tvm.micro.device.get_device_funcs(config["device_id"]) - self.toolchain_prefix = config["toolchain_prefix"] - self.mem_layout = config["mem_layout"] - self.word_size_bits = config["word_size_bits"] - self.thumb_mode = config["thumb_mode"] - self.use_device_timer = config["use_device_timer"] - self.comms_method = config["comms_method"] - - # First, find and compile runtime library. - runtime_src_path = os.path.join(get_micro_host_driven_dir(), "utvm_runtime.c") - tmp_dir = _util.tempdir() - runtime_obj_path = tmp_dir.relpath("utvm_runtime.obj") - options = ["-I{}".format(get_micro_host_driven_dir())] - dev_funcs["create_micro_lib"]( - runtime_obj_path, runtime_src_path, LibType.RUNTIME, options=options - ) - - comms_method = config["comms_method"] - if comms_method == "openocd": - server_addr = config["server_addr"] - server_port = config["server_port"] - elif comms_method == "host": - server_addr = "" - server_port = 0 - else: - raise RuntimeError(f"unknown communication method: f{self.comms_method}") - - assert all( - map(lambda sec: sec in self.mem_layout, DEVICE_SECTIONS) - ), "not all sections have an assigned memory layout" - self.module = _CreateSession( - comms_method, - runtime_obj_path, - self.toolchain_prefix, - self.mem_layout["text"].get("start", 0), - self.mem_layout["text"]["size"], - self.mem_layout["rodata"].get("start", 0), - self.mem_layout["rodata"]["size"], - self.mem_layout["data"].get("start", 0), - self.mem_layout["data"]["size"], - self.mem_layout["bss"].get("start", 0), - self.mem_layout["bss"]["size"], - self.mem_layout["args"].get("start", 0), - self.mem_layout["args"]["size"], - self.mem_layout["heap"].get("start", 0), - self.mem_layout["heap"]["size"], - self.mem_layout["workspace"].get("start", 0), - self.mem_layout["workspace"]["size"], - self.mem_layout["stack"].get("start", 0), - self.mem_layout["stack"]["size"], - self.word_size_bits, - self.thumb_mode, - self.use_device_timer, - server_addr, - server_port, - config.get("debug_func"), - ) - self._enter = self.module["enter"] - self._exit = self.module["exit"] - self.get_last_batch_time = self.module["get_last_batch_time"] - self.get_last_batch_cycles = self.module["get_last_batch_cycles"] - - def _check_system(self): - """Check if the user's system is supported by MicroTVM. - - Raises error if not supported. - """ - if not sys.platform.startswith("linux"): - raise RuntimeError("MicroTVM is currently only supported on Linux") - # TODO(weberlo): Add 32-bit support. - # It's primarily the compilation pipeline that isn't compatible. - if sys.maxsize <= 2 ** 32: - raise RuntimeError("MicroTVM is currently only supported on 64-bit host platforms") - - def __enter__(self): - self._enter() - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self._exit() - - -def _calc_max_workspace_usage(src): - # TODO factor in alignment to the calculation (alloc sizes will be aligned up to the word size) - alloc_re = re.compile( - r".*\* ?(.+) = (\(.+\))? TVMBackendAllocWorkspace\(.+, .+, \(uint64_t\)(.+), .+, .+\).*" - ) - free_re = re.compile(r".*if \(TVMBackendFreeWorkspace\(.+, .+, (\(void\*\))? (.+)\) != 0\) {.*") - max_usage = 0 - alloc_map = {} - for line in src.split("\n"): - if line.strip().startswith("//"): - continue - match = alloc_re.match(line) - if match is not None: - alloc_map[match.group(1)] = int(match.group(3)) - max_usage = max(max_usage, sum(alloc_map.values())) - else: - match = free_re.match(line) - if match is not None: - print(alloc_map) - del alloc_map[match.group(2)] - return max_usage - - -def create_micro_mod( - c_mod, dev_config, lib_src_paths=None, lib_headers=None, lib_include_paths=None -): - """Produces a micro module from a given module. - - Parameters - ---------- - c_mod : tvm.module.Module - module with "c" as its target backend - - lib_src_paths: TODO - TODO - - lib_headers: TODO - TODO - - lib_include_paths: TODO - TODO - - Return - ------ - micro_mod : tvm.module.Module - micro module for the target device - """ - temp_dir = _util.tempdir() - lib_obj_path = temp_dir.relpath("dev_lib.obj") - # TODO use dev config to dispatch on the type of C codegen to run through - # (e.g., CodeGenCArm, CodeGenCHost, CodeGenCRiscV) - c_mod.export_library( - lib_obj_path, - fcompile=cross_compiler( - dev_config, - LibType.OPERATOR, - lib_src_paths=lib_src_paths, - lib_headers=lib_headers, - lib_include_paths=lib_include_paths, - ), - ) - micro_mod = tvm.runtime.load_module(lib_obj_path) - return micro_mod - - -def cross_compiler( - dev_config, lib_type, lib_src_paths=None, lib_headers=None, lib_include_paths=None -): - """Create a cross compile function that wraps `create_lib` for a `Binutil` instance. - - For use in `tvm.runtime.Module.export_library`. - - Parameters - ---------- - create_micro_lib : func - function for creating MicroTVM libraries for a specific device (e.g., - `tvm.micro.device.get_device_funcs('arm.stm32f746xx')['create_micro_lib']`) - - lib_type : micro.LibType - whether to compile a MicroTVM runtime or operator library - - lib_src_paths: TODO - TODO - - lib_headers: TODO - e.g., `['cmsis_gcc.h', 'arm_math.h']` - - lib_include_paths: TODO - TODO - - Return - ------ - func : Callable[[str, str, Optional[str]], None] - cross compile function taking a destination path for the object file - and a path for the input source file. - - Example - -------- - .. code-block:: python - - c_mod = ... # some module generated with "c" as the target - fcompile = tvm.micro.cross_compiler(dev_config, LibType.OPERATOR) - c_mod.export_library('dev_lib.obj', fcompile=fcompile) - """ - assert (lib_headers is None) == ( - lib_include_paths is None - ), "must specify both `lib_headers` and `lib_include_paths` or neither" - - if lib_src_paths is None: - lib_src_paths = [] - if lib_include_paths is None: - lib_include_paths = [] - include_options = [] - for include_path in lib_include_paths: - include_options.append("-I") - include_options.append(include_path) - create_micro_lib = tvm.micro.device.get_device_funcs(dev_config["device_id"])[ - "create_micro_lib" - ] - mem_layout = dev_config["mem_layout"] - - def compile_func(obj_path, src_path, **kwargs): - if isinstance(obj_path, list): - obj_path = obj_path[0] - if isinstance(src_path, list): - src_path = src_path[0] - options = kwargs.get("options", []) - options += include_options - - # check that workspace allocations don't exceed available workspace memory - with open(src_path) as f: - src_contents = f.read() - max_ws_usage = _calc_max_workspace_usage(src_contents) - available_mem = mem_layout["workspace"]["size"] - if max_ws_usage > available_mem: - raise RuntimeError( - f"workspace allocations in library ({max_ws_usage}) " - f"exceed available memory ({available_mem})" - ) - # inject headers into new source path, if requested - if lib_headers: - headers_to_inject = "\n".join(map(lambda s: f"#include <{s}>", lib_headers)) + "\n" - new_src_contents = headers_to_inject + src_contents - tmp_dir = _util.tempdir() - src_path = tmp_dir.relpath(os.path.basename(src_path)) - with open(src_path, "w") as f: - f.write(new_src_contents) - - create_micro_lib(obj_path, src_path, lib_type, options, lib_src_paths=lib_src_paths) - - return _cc.cross_compiler(compile_func, output_format="obj") - - -def get_micro_host_driven_dir(): - """Get directory path for uTVM host-driven runtime source files. - - Return - ------ - micro_device_dir : str - directory path - """ - micro_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - micro_host_driven_dir = os.path.join( - micro_dir, "..", "..", "..", "src", "runtime", "micro", "host_driven" - ) - return micro_host_driven_dir - - -def get_micro_device_dir(): - """Get directory path for parent directory of device-specific source files - - Return - ------ - micro_device_dir : str - directory path - """ - micro_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) - micro_device_dir = os.path.join( - micro_dir, "..", "..", "..", "src", "runtime", "micro", "device" - ) - return micro_device_dir - - tvm._ffi._init_api("tvm.micro", "tvm.micro.base") diff --git a/python/tvm/micro/build.py b/python/tvm/micro/build.py new file mode 100644 index 000000000000..203b3968f2f3 --- /dev/null +++ b/python/tvm/micro/build.py @@ -0,0 +1,148 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines top-level glue functions for building microTVM artifacts.""" + +import copy +import logging +import os +import re +from tvm.contrib import util + + +_LOG = logging.getLogger(__name__) + + +class Workspace: + """Defines helper functions for manipulating temporary compilation workspaces.""" + + def __init__(self, root=None, debug=False): + if debug or root is not None: + with util.TempDirectory.set_keep_for_debug(): + self.tempdir = util.tempdir(custom_path=root) + _LOG.info('Created debug mode workspace at: %s', self.tempdir.temp_dir) + else: + self.tempdir = util.tempdir() + + def relpath(self, path): + return self.tempdir.relpath(path) + + def listdir(self): + return self.tempdir.listdir() + + @property + def path(self): + return self.tempdir.temp_dir + + +# Required C runtime libraries, in link order. +CRT_RUNTIME_LIB_NAMES = ['utvm_rpc_server', 'utvm_rpc_common', 'common'] + + +TVM_ROOT_DIR = os.path.realpath(os.path.join(os.path.dirname(__file__), '..', '..', '..')) + + +CRT_ROOT_DIR = os.path.join(TVM_ROOT_DIR, 'src', 'runtime', 'crt') + + +RUNTIME_LIB_SRC_DIRS = ( + [os.path.join(CRT_ROOT_DIR, n) for n in CRT_RUNTIME_LIB_NAMES] + + [os.path.join(TVM_ROOT_DIR, + '3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/TARGET_SDK_11/' + 'libraries/crc16')]) + + +RUNTIME_SRC_REGEX = re.compile(r'^.*\.cc?$', re.IGNORECASE) + + +_CRT_DEFAULT_OPTIONS = { + 'ccflags': ['-std=c++11'], + 'ldflags': ['-std=gnu++14'], + 'include_dirs': [ + f'{TVM_ROOT_DIR}/include', + f'{TVM_ROOT_DIR}/3rdparty/dlpack/include', + f'{TVM_ROOT_DIR}/3rdparty/mbed-os/targets/TARGET_NORDIC/TARGET_NRF5x/' + 'TARGET_SDK_11/libraries/crc16/', + f'{TVM_ROOT_DIR}/3rdparty/dmlc-core/include', + f'{CRT_ROOT_DIR}/include' + ], + 'profile': { + 'common': ['-Wno-unused-variable'] + } +} + + +def default_options(target_include_dir): + """Return default opts passed to Compile commands.""" + bin_opts = copy.deepcopy(_CRT_DEFAULT_OPTIONS) + bin_opts['include_dirs'].append(target_include_dir) + lib_opts = copy.deepcopy(bin_opts) + lib_opts['profile']['common'].append('-Werror') + lib_opts['cflags'] = ['-Wno-error=incompatible-pointer-types'] + return {'bin_opts': bin_opts, 'lib_opts': lib_opts} + + +def build_static_runtime(workspace, compiler, module, lib_opts=None, bin_opts=None): + """Build the on-device runtime, statically linking the given modules. + + Parameters + ---------- + compiler : tvm.micro.Compiler + Compiler instance used to build the runtime. + + module : IRModule + Module to statically link. + + lib_opts : dict + Extra kwargs passed to library(), + + bin_opts : dict + Extra kwargs passed to binary(), + + Returns + ------- + MicroBinary : + The compiled runtime. + """ + lib_opts = _CRT_DEFAULT_OPTIONS if lib_opts is None else lib_opts + bin_opts = _CRT_DEFAULT_OPTIONS if bin_opts is None else bin_opts + + mod_build_dir = workspace.relpath(os.path.join('build', 'module')) + os.makedirs(mod_build_dir) + mod_src_dir = workspace.relpath(os.path.join('src', 'module')) + os.makedirs(mod_src_dir) + mod_src_path = os.path.join(mod_src_dir, 'module.c') + module.save(mod_src_path, 'cc') + + libs = [] + for lib_src_dir in RUNTIME_LIB_SRC_DIRS: + lib_name = os.path.basename(lib_src_dir) + lib_build_dir = workspace.relpath(f'build/{lib_name}') + os.makedirs(lib_build_dir) + + lib_srcs = [] + for p in os.listdir(lib_src_dir): + if RUNTIME_SRC_REGEX.match(p): + lib_srcs.append(os.path.join(lib_src_dir, p)) + + libs.append(compiler.library(lib_build_dir, lib_srcs, lib_opts)) + + libs.append(compiler.library(mod_build_dir, [mod_src_path], lib_opts)) + + runtime_build_dir = workspace.relpath(f'build/runtime') + os.makedirs(runtime_build_dir) + return compiler.binary(runtime_build_dir, libs, bin_opts) diff --git a/python/tvm/micro/class_factory.py b/python/tvm/micro/class_factory.py new file mode 100644 index 000000000000..3d00636e4dc0 --- /dev/null +++ b/python/tvm/micro/class_factory.py @@ -0,0 +1,97 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines a utility for representing deferred class instatiations as JSON.""" + +import importlib +import json +import typing + + +JsonSerializable = typing.Union[int, float, str, None, bool] + + +class SerializedFactoryError(Exception): + """Raised when ClassFactory.from_json is invoked with an invalid JSON blob.""" + + +class ClassFactory: + """Describes a JSON-serializable class instantiation, for use with the RPC server.""" + + # When not None, the superclass from which all cls must derive. + SUPERCLASS = None + + def __init__(self, cls: typing.Callable, init_args: typing.List[JsonSerializable], + init_kw: typing.Dict[str, JsonSerializable]): + self.cls = cls + self.init_args = init_args + self.init_kw = init_kw + + def override_kw(self, **kw_overrides): + kwargs = self.init_kw + if kw_overrides: + kwargs = dict(kwargs) + for k, v in kw_overrides.items(): + kwargs[k] = v + + return self.__class__(self.cls, self.init_args, kwargs) + + def instantiate(self): + return self.cls(*self.init_args, **self.init_kw) + + @property + def to_json(self): + return json.dumps({ + 'cls': '.'.join([self.cls.__module__, self.cls.__name__]), + 'init_args': self.init_args, + 'init_kw': self.init_kw, + }) + + EXPECTED_KEYS = ('cls', 'init_args', 'init_kw') + + @classmethod + def from_json(cls, data): + """Reconstruct a ClassFactory instance from its JSON representation. + + Parameters + ---------- + data : str + The JSON representation of the ClassFactory. + + Returns + ------- + ClassFactory : + The reconstructed ClassFactory instance. + + Raises + ------ + SerializedFactoryError : + If the JSON object represented by `data` is malformed. + """ + obj = json.loads(data) + if not isinstance(obj, dict): + raise SerializedFactoryError(f'deserialized json payload: want dict, got: {obj!r}') + + for key in cls.EXPECTED_KEYS: + if key not in obj: + raise SerializedFactoryError( + f'deserialized json payload: expect key {key}, got: {obj!r}') + + cls_package_name, cls_name = obj['cls'].rsplit('.', 1) + cls_package = importlib.import_module(cls_package_name) + cls_obj = getattr(cls_package, cls_name) + return cls(cls_obj, obj['init_args'], obj['init_kw']) diff --git a/python/tvm/micro/compiler.py b/python/tvm/micro/compiler.py new file mode 100644 index 000000000000..f29925a61b0a --- /dev/null +++ b/python/tvm/micro/compiler.py @@ -0,0 +1,318 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines interfaces and default implementations for compiling and flashing code.""" + +import abc +import glob +import os +import re + +from tvm.contrib import binutil +import tvm.target +from . import build +from . import class_factory +from . import debugger +from . import transport + + +class DetectTargetError(Exception): + """Raised when no target comment was detected in the sources given.""" + + +class NoDefaultToolchainMatchedError(Exception): + """Raised when no default toolchain matches the target string.""" + + +class Compiler(metaclass=abc.ABCMeta): + """The compiler abstraction used with micro TVM.""" + + TVM_TARGET_RE = re.compile(r'^// tvm target: (.*)$') + + @classmethod + def _target_from_sources(cls, sources): + """Determine the target used to generate the given source files. + + Parameters + ---------- + sources : List[str] + The paths to source files to analyze. + + Returns + ------- + tvm.target.Target : + A Target instance reconstructed from the target string listed in the source files. + """ + target_strs = set() + + for obj in sources: + with open(obj) as obj_f: + for line in obj_f: + m = cls.TVM_TARGET_RE.match(line) + if m: + target_strs.add(m.group(1)) + + if len(target_strs) != 1: + raise DetectTargetError( + 'autodetecting cross-compiler: could not extract TVM target from C source; regex ' + f'{cls.TVM_TARGET_RE.pattern} does not match any line in sources: ' + f'{", ".join(sources)}') + + target_str = next(iter(target_strs)) + return tvm.target.create(target_str) + + # Maps regexes identifying CPUs to the default toolchain prefix for that CPU. + TOOLCHAIN_PREFIX_BY_CPU_REGEX = { + r'cortex-[am].*': 'arm-none-eabi-', + 'x86[_-]64': '', + 'native': '', + } + + def _autodetect_toolchain_prefix(self, target): + matches = [] + for regex, prefix in self.TOOLCHAIN_PREFIX_BY_CPU_REGEX.items(): + if re.match(regex, target.attrs['mcpu']): + matches.append(prefix) + + if matches: + if len(matches) != 1: + raise NoDefaultToolchainMatchedError( + f'{opt} matched more than 1 default toolchain prefix: {", ".join(matches)}. ' + 'Specify cc.cross_compiler to create_micro_library()') + + return matches[0] + + raise NoDefaultToolchainMatchedError( + f'target {str(target)} did not match any default toolchains') + + def _defaults_from_target(self, target): + """Determine the default compiler options from the target specified. + + Parameters + ---------- + target : tvm.target.Target + + Returns + ------- + List[str] : + Default options used the configure the compiler for that target. + """ + opts = [] + # TODO use march for arm(https://gcc.gnu.org/onlinedocs/gcc/ARM-Options.html)? + if target.attrs.get('mcpu'): + opts.append(f'-march={target.attrs["mcpu"]}') + if target.attrs.get('mfpu'): + opts.append(f'-mfpu={target.attrs["mfpu"]}') + + return opts + + @abc.abstractmethod + def library(self, output, sources, options=None): + """Build a library from the given source files. + + Parameters + ---------- + output : str + The path to the library that should be created. The containing directory + is guaranteed to be empty and should be the base_dir for the returned + Artifact. + sources : List[str] + A list of paths to source files that should be compiled. + options : Optional[List[str]] + If given, additional command-line flags to pass to the compiler. + + Returns + ------- + MicroLibrary : + The compiled library, as a MicroLibrary instance. + """ + raise NotImplementedError() + + @abc.abstractmethod + def binary(self, output, objects, options=None, link_main=True, main_options=None): + """Link a binary from the given object and/or source files. + + Parameters + ---------- + output : str + The path to the binary that should be created. The containing directory + is guaranteed to be empty and should be the base_dir for the returned + Artifact. + objects : List[MicroLibrary] + A list of paths to source files or libraries that should be compiled. The final binary + should be statically-linked. + options: Optional[List[str]] + If given, additional command-line flags to pass to the compiler. + link_main: Optional[bool] + True if the standard main entry point for this Compiler should be included in the + binary. False if a main entry point is provided in one of `objects`. + main_options: Optional[List[str]] + If given, additional command-line flags to pass to the compiler when compiling the + main() library. In some cases, the main() may be compiled directly into the final binary + along with `objects` for logistical reasons. In those cases, specifying main_options is + an error and ValueError will be raised. + + Returns + ------- + MicroBinary : + The compiled binary, as a MicroBinary instance. + """ + raise NotImplementedError() + + @property + def flasher_factory(self): + """Produce a FlasherFactory for a Flasher instance suitable for this Compiler.""" + raise NotImplementedError("The Compiler base class doesn't define a flasher.") + + def flasher(self, **kw): + """Return a Flasher that can be used to program a produced MicroBinary onto the target.""" + return self.flasher_factory.override_kw(**kw).instantiate() + + +class IncompatibleTargetError(Exception): + """Raised when source files specify a target that differs from the compiler target.""" + + +class DefaultCompiler(Compiler): + """A Compiler implementation that attempts to use the system-installed GCC.""" + + def __init__(self, target=None): + super(DefaultCompiler, self).__init__() + self.target = target + if isinstance(target, str): + self.target = tvm.target.create(target) + + def library(self, output, sources, options=None): + options = options if options is not None else {} + try: + target = self._target_from_sources(sources) + except DetectTargetError: + assert self.target is not None, ( + "Must specify target= to constructor when compiling sources which don't specify a " + "target") + + target = self.target + + if self.target is not None and str(self.target) != str(target): + raise IncompatibleTargetError( + f'auto-detected target {target} differs from configured {self.target}') + + prefix = self._autodetect_toolchain_prefix(target) + outputs = [] + for src in sources: + src_base, src_ext = os.path.splitext(os.path.basename(src)) + + compiler_name = {'.c': 'gcc', '.cc': 'g++', '.cpp': 'g++'}[src_ext] + args = [prefix + compiler_name, '-g'] + args.extend(self._defaults_from_target(target)) + + args.extend(options.get(f'{src_ext[1:]}flags', [])) + + for include_dir in options.get('include_dirs', []): + args.extend(['-I', include_dir]) + + output_filename = f'{src_base}.o' + output_abspath = os.path.join(output, output_filename) + binutil.run_cmd(args + ['-c', '-o', output_abspath, src]) + outputs.append(output_abspath) + + output_filename = f'{os.path.basename(output)}.a' + output_abspath = os.path.join(output, output_filename) + binutil.run_cmd([prefix + 'ar', '-r', output_abspath] + outputs) + binutil.run_cmd([prefix + 'ranlib', output_abspath]) + + return tvm.micro.MicroLibrary(output, [output_filename]) + + def binary(self, output, objects, options=None, link_main=True, main_options=None): + assert self.target is not None, ( + 'must specify target= to constructor, or compile sources which specify the target ' + 'first') + + args = [self._autodetect_toolchain_prefix(self.target) + 'g++'] + args.extend(self._defaults_from_target(self.target)) + if options is not None: + args.extend(options.get('ldflags', [])) + + for include_dir in options.get('include_dirs', []): + args.extend(['-I', include_dir]) + + output_filename = os.path.basename(output) + output_abspath = os.path.join(output, output_filename) + args.extend(['-g', '-o', output_abspath]) + + if link_main: + host_main_srcs = glob.glob(os.path.join(build.CRT_ROOT_DIR, 'host', '*.cc')) + if main_options: + main_lib = self.library(os.path.join(output, 'host'), host_main_srcs, main_options) + for lib_name in main_lib.library_files: + args.append(main_lib.abspath(lib_name)) + else: + args.extend(host_main_srcs) + + for obj in objects: + for lib_name in obj.library_files: + args.append(obj.abspath(lib_name)) + + binutil.run_cmd(args) + return tvm.micro.MicroBinary(output, output_filename, []) + + @property + def flasher_factory(self): + return FlasherFactory(HostFlasher, [], {}) + + +class Flasher(metaclass=abc.ABCMeta): + """An interface for flashing binaries and returning a transport factory.""" + + @abc.abstractmethod + def flash(self, micro_binary): + """Flash a binary onto the device. + + Parameters + ---------- + micro_binary : MicroBinary + A MicroBinary instance. + + Returns + ------- + transport.TransportContextManager : + A ContextManager that can be used to create and tear down an RPC transport layer between + this TVM instance and the newly-flashed binary. + """ + raise NotImplementedError() + + +class FlasherFactory(class_factory.ClassFactory): + """A ClassFactory for Flasher instances.""" + + SUPERCLASS = Flasher + + +class HostFlasher(Flasher): + """A Flasher implementation that spawns a subprocess on the host.""" + + def __init__(self, debug=False): + self.debug = debug + + def flash(self, micro_binary): + if self.debug: + gdb_wrapper = debugger.GdbTransportDebugger( + [micro_binary.abspath(micro_binary.binary_file)]) + return transport.DebugWrapperTransport( + debugger=gdb_wrapper, transport=gdb_wrapper.Transport()) + + return transport.SubprocessTransport([micro_binary.abspath(micro_binary.binary_file)]) diff --git a/python/tvm/micro/debugger.py b/python/tvm/micro/debugger.py new file mode 100644 index 000000000000..06e7c1c79ae9 --- /dev/null +++ b/python/tvm/micro/debugger.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines functions for controlling debuggers for micro TVM binaries.""" + +import abc +import os +import signal +import subprocess +import threading + +from . import transport as _transport + + +class Debugger(metaclass=abc.ABCMeta): + """An interface for controlling micro TVM debuggers.""" + + def __init__(self): + self.on_terminate_callbacks = [] + + @abc.abstractmethod + def start(self): + """Start the debugger, but do not block on it. + + The runtime will continue to be driven in the background. + """ + raise NotImplementedError() + + @abc.abstractmethod + def stop(self): + """Terminate the debugger.""" + raise NotImplementedError() + + +class GdbDebugger(Debugger): + """Handles launching, suspending signals, and potentially dealing with terminal issues.""" + + @abc.abstractmethod + def popen_kwargs(self): + raise NotImplementedError() + + def _wait_restore_signal(self): + self.popen.wait() + if not self.did_terminate.is_set(): + for callback in self.on_terminate_callbacks: + try: + callback() + except Exception: # pylint: disable=broad-except + logging.warn('on_terminate_callback raised exception', exc_info=True) + + def start(self): + kwargs = self.popen_kwargs() + self.did_terminate = threading.Event() + self.old_signal = signal.signal(signal.SIGINT, signal.SIG_IGN) + self.popen = subprocess.Popen(**kwargs) + threading.Thread(target=self._WaitRestoreSignal).start() + + def stop(self): + self.did_terminate.set() + self.popen.terminate() + signal.signal(signal.SIGINT, self.old_signal) + + +class GdbTransportDebugger(GdbDebugger): + """A debugger that uses a single GDB subprocess as both the transport and the debugger. + + Opens pipes for the target's stdin and stdout, launches GDB and configures GDB's target + arguments to read and write from the pipes using /dev/fd. + """ + + def __init__(self, args, **popen_kw): + super(GdbTransportDebugger, self).__init__() + self.args = args + self.popen_kw = popen_kw + + def popen_kwargs(self): + stdin_read, stdin_write = os.pipe() + stdout_read, stdout_write = os.pipe() + + os.set_inheritable(stdin_read, True) + os.set_inheritable(stdout_write, True) + + sysname = os.uname()[0] + if sysname == 'Darwin': + args = ['lldb', + '-O', f'target create {self.args[0]}', + '-O', f'settings set target.input-path /dev/fd/{stdin_read}', + '-O', f'settings set target.output-path /dev/fd/{stdout_write}'] + if len(self.args) > 1: + args.extend( + ['-O', 'settings set target.run-args {}'.format(' '.join(self.args[1:]))]) + elif sysname == 'Linux': + args = (['gdb', '--args'] + + self.args + + ['/dev/fd/{stdout_write}']) + else: + raise NotImplementedError(f'System {sysname} is not yet supported') + + self.stdin = os.fdopen(stdin_write, 'wb', buffering=0) + self.stdout = os.fdopen(stdout_read, 'rb', buffering=0) + + return { + 'args': args, + 'pass_fds': [stdin_read, stdout_write], + } + + def _wait_for_process_death(self): + self.popen.wait() + self.stdin.close() + self.stdout.close() + + def start(self): + to_return = super(GdbTransportDebugger, self).Start() + threading.Thread(target=self._wait_for_process_death, daemon=True).start() + return to_return + + def stop(self): + self.stdin.close() + self.stdout.close() + super(GdbTransportDebugger, self).Stop() + + class _Transport(_transport.Transport): + def __init__(self, gdb_transport_debugger): + self.gdb_transport_debugger = gdb_transport_debugger + + def open(self): + pass # Pipes opened by parent class. + + def write(self, data): + return self.gdb_transport_debugger.stdin.write(data) + + def read(self, n): + return self.gdb_transport_debugger.stdout.read(n) + + def close(self): + pass # Pipes closed by parent class. + + def transport(self): + return self._Transport(self) + + +class GdbRemoteDebugger(GdbDebugger): + """A Debugger that invokes GDB and attaches to a remote GDBserver-based target.""" + + def __init__(self, gdb_binary, remote_hostport, debug_binary, wrapping_context_manager=None, + **popen_kw): + super(GdbRemoteDebugger, self).__init__() + self.gdb_binary = gdb_binary + self.remote_hostport = remote_hostport + self.debug_binary = debug_binary + self.wrapping_context_manager = wrapping_context_manager + self.popen_kw = popen_kw + + def popen_kwargs(self): + kwargs = { + 'args': [self.gdb_binary, + '-iex', f'file {self.debug_binary}', + '-iex', f'target remote {self.remote_hostport}'], + } + kwargs.update(self.popen_kw) + + return kwargs + + def start(self): + if self.wrapping_context_manager is not None: + self.wrapping_context_manager.__enter__() + super(GdbRemoteDebugger, self).Start() + + def stop(self): + try: + super(GdbRemoteDebugger, self).Stop() + finally: + if self.wrapping_context_manager is not None: + self.wrapping_context_manager.__exit__(None, None, None) diff --git a/python/tvm/micro/device/__init__.py b/python/tvm/micro/device/__init__.py deleted file mode 100644 index 89731b9aa797..000000000000 --- a/python/tvm/micro/device/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Device-specific configuration for MicroTVM""" - -from .base import create_micro_lib_base, gen_mem_layout -from .base import MemConstraint, register_device, get_device_funcs -from . import host -from . import arm -from . import riscv_spike diff --git a/python/tvm/micro/device/arm/__init__.py b/python/tvm/micro/device/arm/__init__.py deleted file mode 100644 index be323b9e0a2b..000000000000 --- a/python/tvm/micro/device/arm/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Base module for ARM device configurations""" - -from . import stm32f746xx diff --git a/python/tvm/micro/device/arm/stm32f746xx.py b/python/tvm/micro/device/arm/stm32f746xx.py deleted file mode 100644 index bd666016444f..000000000000 --- a/python/tvm/micro/device/arm/stm32f746xx.py +++ /dev/null @@ -1,137 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Compilation and config definitions for Arm STM32F746XX devices""" -import os -from .. import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint - -DEVICE_ID = "arm.stm32f746xx" -TOOLCHAIN_PREFIX = "arm-none-eabi-" -WORD_SIZE_BITS = 32 -# -# [Device Memory Layout] -# RAM (rwx) : START = 0x20000000, LENGTH = 320K -# Flash (rx) : START = 0x8000000, LENGTH = 1024K -# -BASE_ADDR = 0x20000000 -AVAILABLE_MEM = 320000 -DEFAULT_SECTION_CONSTRAINTS = { - "text": (18000, MemConstraint.ABSOLUTE_BYTES), - "rodata": (512, MemConstraint.ABSOLUTE_BYTES), - "data": (100, MemConstraint.ABSOLUTE_BYTES), - "bss": (640, MemConstraint.ABSOLUTE_BYTES), - "args": (4096, MemConstraint.ABSOLUTE_BYTES), - "heap": (100.0, MemConstraint.WEIGHT), - "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), - "stack": (32, MemConstraint.ABSOLUTE_BYTES), -} - - -def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): - """Wrapper over `create_micro_lib_base` to add device-specific options - - Parameters - ---------- - obj_path : str - path to generated object file - - src_path : str - path to source file - - lib_type : micro.LibType - whether to compile a MicroTVM runtime or operator library - - options : Optional[List[str]] - additional options to pass to GCC - - lib_src_paths : Optional[List[str]] - TODO - """ - if options is None: - options = [] - else: - options = list(options) - - options += [ - # TODO(weberlo): make a debug flag - "-O2", - "-mcpu=cortex-m7", - "-mlittle-endian", - "-mfloat-abi=hard", - "-mfpu=fpv5-sp-d16", - "-mthumb", - "-ffast-math", - "-gdwarf-5", - "-DARM_MATH_CM7", - "-D__FPU_PRESENT=1U", - "-DARM_MATH_DSP", - "-Wno-unused-variable", - "-Wno-unused-parameter", - "-I{}".format(os.environ["CMSIS_ST_PATH"]), - "-I{}/Core/Include".format(os.environ["CMSIS_ST_PATH"]), - ] - create_micro_lib_base( - obj_path, - src_path, - TOOLCHAIN_PREFIX, - DEVICE_ID, - lib_type, - options=options, - lib_src_paths=lib_src_paths, - ) - - -def generate_config(server_addr, server_port, section_constraints=None): - """Generates a configuration for Arm STM32F746XX devices - - Parameters - ---------- - server_addr : str - address of OpenOCD server to connect to - - server_port : int - port of OpenOCD server to connect to - - section_constraints: Optional[Dict[str, [Number, MemConstraint]]] - maps section name to the quantity of available memory - - Return - ------ - config : Dict[str, Any] - MicroTVM config dict for this device - """ - if section_constraints is None: - section_constraints = DEFAULT_SECTION_CONSTRAINTS - return { - "device_id": DEVICE_ID, - "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": gen_mem_layout(BASE_ADDR, AVAILABLE_MEM, WORD_SIZE_BITS, section_constraints), - "word_size_bits": WORD_SIZE_BITS, - "thumb_mode": True, - "use_device_timer": True, - "comms_method": "openocd", - "server_addr": server_addr, - "server_port": server_port, - } - - -register_device( - DEVICE_ID, - { - "create_micro_lib": create_micro_lib, - "generate_config": generate_config, - }, -) diff --git a/python/tvm/micro/device/base.py b/python/tvm/micro/device/base.py deleted file mode 100644 index fef0d11b7145..000000000000 --- a/python/tvm/micro/device/base.py +++ /dev/null @@ -1,237 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Base definitions for MicroTVM config""" -import glob -import os -import enum -import pathlib - -from tvm.contrib import util as _util -from tvm.contrib.binutil import run_cmd -from tvm._ffi.libinfo import find_include_path -from tvm.micro import DEVICE_SECTIONS, LibType, get_micro_host_driven_dir, get_micro_device_dir - -_DEVICE_REGISTRY = {} - - -def register_device(device_id, device_funcs): - """Register a device and associated compilation/config functions - - Parameters - ---------- - device_id : str - unique identifier for the device - - device_funcs : Dict[str, func] - dictionary with compilation and config generation functions as values - """ - if device_id in _DEVICE_REGISTRY: - raise RuntimeError(f'"{device_id}" already exists in the device registry') - _DEVICE_REGISTRY[device_id] = device_funcs - - -def get_device_funcs(device_id): - """Get compilation and config generation functions for device - - Parameters - ---------- - device_id : str - unique identifier for the device - - Return - ------ - device_funcs : Dict[str, func] - dictionary with compilation and config generation functions as values - """ - if device_id not in _DEVICE_REGISTRY: - raise RuntimeError(f'"{device_id}" does not exist in the binutil registry') - device_funcs = _DEVICE_REGISTRY[device_id] - return device_funcs - - -def create_micro_lib_base( - out_obj_path, - in_src_path, - toolchain_prefix, - device_id, - lib_type, - options=None, - lib_src_paths=None, -): - """Compiles code into a binary for the target micro device. - - Parameters - ---------- - out_obj_path : str - path to generated object file - - in_src_path : str - path to source file - - toolchain_prefix : str - toolchain prefix to be used. For example, a prefix of - "riscv64-unknown-elf-" means "riscv64-unknown-elf-gcc" is used as - the compiler and "riscv64-unknown-elf-ld" is used as the linker, - etc. - - device_id : str - unique identifier for the target device - - lib_type : micro.LibType - whether to compile a MicroTVM runtime or operator library - - options : List[str] - additional options to pass to GCC - - lib_src_paths : Optional[List[str]] - paths to additional source files to be compiled into the library - """ - # look at these (specifically `strip`): - # https://stackoverflow.com/questions/15314581/g-compiler-flag-to-minimize-binary-size - base_compile_cmd = [ - f"{toolchain_prefix}gcc", - "-std=c11", - "-Wall", - "-Wextra", - "--pedantic", - "-c", - "-g", - "-nostartfiles", - "-nodefaultlibs", - "-nostdlib", - "-fdata-sections", - "-ffunction-sections", - ] - if options is not None: - base_compile_cmd += options - - src_paths = [] - include_paths = find_include_path() + [get_micro_host_driven_dir()] - tmp_dir = _util.tempdir() - # we need to create a new src file in the operator branch - new_in_src_path = in_src_path - if lib_type == LibType.RUNTIME: - dev_dir = _get_device_source_dir(device_id) - - dev_src_paths = glob.glob(f"{dev_dir}/*.[csS]") - # there needs to at least be a utvm_timer.c file - assert dev_src_paths - assert "utvm_timer.c" in map(os.path.basename, dev_src_paths) - - src_paths += dev_src_paths - elif lib_type == LibType.OPERATOR: - # create a temporary copy of the operator source, so we can inject the dev lib - # header without modifying the original. - temp_src_path = tmp_dir.relpath("temp.c") - with open(in_src_path, "r") as f: - src_lines = f.read().splitlines() - src_lines.insert(0, '#include "utvm_device_dylib_redirect.c"') - with open(temp_src_path, "w") as f: - f.write("\n".join(src_lines)) - new_in_src_path = temp_src_path - else: - raise RuntimeError("unknown lib type") - - src_paths += [new_in_src_path] - - # add any src paths required by the operator - if lib_src_paths is not None: - src_paths += lib_src_paths - - # print(f"include paths: {include_paths}") - for path in include_paths: - base_compile_cmd += ["-I", path] - - prereq_obj_paths = [] - # print(src_paths) - for src_path in src_paths: - curr_obj_path = tmp_dir.relpath(pathlib.Path(src_path).with_suffix(".o").name) - assert curr_obj_path not in prereq_obj_paths - prereq_obj_paths.append(curr_obj_path) - curr_compile_cmd = base_compile_cmd + [src_path, "-o", curr_obj_path] - # TODO(weberlo): make compilation fail if there are any warnings - run_cmd(curr_compile_cmd) - - ld_cmd = [f"{toolchain_prefix}ld", "-relocatable"] - ld_cmd += prereq_obj_paths - ld_cmd += ["-o", out_obj_path] - run_cmd(ld_cmd) - - -# TODO we shouldn't need an enum for this. too much bureaucracy. -class MemConstraint(enum.Enum): - """Represents a constraint on the device's memory layout""" - - ABSOLUTE_BYTES = 0 - WEIGHT = 1 - - -def gen_mem_layout(base_addr, available_mem, word_size_bits, section_constraints): - """Template function to generate memory layout for devices. - - Parameters - ---------- - base_addr: Number - The address where usable memory begins on this device. - - available_mem: Number - Available memory at base_addr, given in bytes. - - word_size_bits: Number - Number of bits in one word on this device. - - section_constraints: Optional[Dict[str, [Number, MemConstraint]]] - maps section name to the quantity of available memory - """ - assert word_size_bits in (32, 64), "only 32- or 64-bit devices are supported now" - word_size_bytes = word_size_bits // 8 - byte_sum = sum( - x[0] for x in section_constraints.values() if x[1] == MemConstraint.ABSOLUTE_BYTES - ) - weight_sum = sum(x[0] for x in section_constraints.values() if x[1] == MemConstraint.WEIGHT) - assert byte_sum <= available_mem - available_weight_mem = available_mem - byte_sum - - res = {} - curr_addr = base_addr - for section in DEVICE_SECTIONS: - (val, cons_type) = section_constraints[section] - if cons_type == MemConstraint.ABSOLUTE_BYTES: - assert ( - val % word_size_bytes == 0 - ), f"constraint {val} for {section} section is not word-aligned" - size = val - res[section] = { - "start": curr_addr, - "size": size, - } - else: - size = int((val / weight_sum) * available_weight_mem) - size = (size // word_size_bytes) * word_size_bytes - res[section] = { - "start": curr_addr, - "size": size, - } - curr_addr += size - - return res - - -def _get_device_source_dir(device_id): - """Grabs the source directory for device-specific uTVM files""" - dev_subdir = "/".join(device_id.split(".")) - return get_micro_device_dir() + "/" + dev_subdir diff --git a/python/tvm/micro/device/host.py b/python/tvm/micro/device/host.py deleted file mode 100644 index c5f0e15f83f9..000000000000 --- a/python/tvm/micro/device/host.py +++ /dev/null @@ -1,127 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Compilation and config definitions for the host emulated device""" -import sys - -from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint - -DEVICE_ID = "host" -TOOLCHAIN_PREFIX = "" -WORD_SIZE_BITS = 64 if sys.maxsize > 2 ** 32 else 32 - -# we pretend we only have 320kb in the default case, so we can use `gen_mem_layout` -DEFAULT_AVAILABLE_MEM = 3200000 -DEFAULT_SECTION_CONSTRAINTS = { - "text": (20480, MemConstraint.ABSOLUTE_BYTES), - "rodata": (20480, MemConstraint.ABSOLUTE_BYTES), - "data": (768, MemConstraint.ABSOLUTE_BYTES), - "bss": (4096, MemConstraint.ABSOLUTE_BYTES), - "args": (4096, MemConstraint.ABSOLUTE_BYTES), - "heap": (262144, MemConstraint.ABSOLUTE_BYTES), - "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), - "stack": (80, MemConstraint.ABSOLUTE_BYTES), -} - - -def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): - """Wrapper over `create_micro_lib_base` to add device-specific options - - Parameters - ---------- - obj_path : str - path to generated object file - - src_path : str - path to source file - - lib_type : micro.LibType - whether to compile a MicroTVM runtime or operator library - - options : Optional[List[str]] - additional options to pass to GCC - - lib_src_paths : Optional[List[str]] - paths to additional source files to be compiled into the library - """ - if options is None: - options = [] - else: - options = list(options) - # Cannot increase optimization level on host due to code loading method. - options.append("-O0") - if sys.maxsize > 2 ** 32 and sys.platform.startswith("linux"): - options += ["-mcmodel=large"] - options.append("-DUTVM_TARGET_HOST") - create_micro_lib_base( - obj_path, - src_path, - TOOLCHAIN_PREFIX, - DEVICE_ID, - lib_type, - options=options, - lib_src_paths=lib_src_paths, - ) - - -def generate_config(available_mem=None, section_constraints=None): - """Generates a configuration for the host emulated device - - Parameters - ---------- - available_mem: int - number of RW bytes available for use on device - - section_constraints: Optional[Dict[str, Dict[Number, MemConstraint]]] - maps section name to the quantity of available memory - - Return - ------ - config : Dict[str, Any] - MicroTVM config dict for this device - """ - if available_mem is None: - available_mem = DEFAULT_AVAILABLE_MEM - if section_constraints is None: - section_constraints = DEFAULT_SECTION_CONSTRAINTS - mem_layout = gen_mem_layout(0, available_mem, WORD_SIZE_BITS, section_constraints) - # TODO the host emulated device is an outlier, since we don't know how what - # its base address will be until we've created it in the C++. is there any - # way to change the infrastructure around this so it's not so much of an - # outlier? - - # need to zero out all start addresses, because they don't make sense for a - # host device (the memory region is allocated in the backend) - for section in mem_layout: - mem_layout[section]["start"] = 0 - return { - "device_id": DEVICE_ID, - "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": mem_layout, - "word_size_bits": WORD_SIZE_BITS, - "thumb_mode": False, - "use_device_timer": False, - "comms_method": "host", - } - - -register_device( - DEVICE_ID, - { - "create_micro_lib": create_micro_lib, - "generate_config": generate_config, - }, -) diff --git a/python/tvm/micro/device/riscv_spike.py b/python/tvm/micro/device/riscv_spike.py deleted file mode 100644 index 27815669742f..000000000000 --- a/python/tvm/micro/device/riscv_spike.py +++ /dev/null @@ -1,112 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -"""Compilation and config definitions for Spike, a RISC-V functional ISA simulator""" - -from . import create_micro_lib_base, register_device, gen_mem_layout, MemConstraint - -DEVICE_ID = "riscv_spike" -TOOLCHAIN_PREFIX = "riscv64-unknown-elf-" -WORD_SIZE_BITS = 64 - -DEFAULT_SECTION_CONSTRAINTS = { - "text": (18000, MemConstraint.ABSOLUTE_BYTES), - "rodata": (128, MemConstraint.ABSOLUTE_BYTES), - "data": (128, MemConstraint.ABSOLUTE_BYTES), - "bss": (2048, MemConstraint.ABSOLUTE_BYTES), - "args": (4096, MemConstraint.ABSOLUTE_BYTES), - "heap": (100.0, MemConstraint.WEIGHT), - "workspace": (64000, MemConstraint.ABSOLUTE_BYTES), - "stack": (32, MemConstraint.ABSOLUTE_BYTES), -} - - -def create_micro_lib(obj_path, src_path, lib_type, options=None, lib_src_paths=None): - """Wrapper over `create_micro_lib_base` to add device-specific options - - Parameters - ---------- - obj_path : str - path to generated object file - - src_path : str - path to source file - - lib_type : micro.LibType - whether to compile a MicroTVM runtime or operator library - - options : Optional[List[str]] - additional options to pass to GCC - - lib_src_paths : Optional[List[str]] - TODO - """ - create_micro_lib_base( - obj_path, - src_path, - TOOLCHAIN_PREFIX, - DEVICE_ID, - lib_type, - options=options, - lib_src_paths=lib_src_paths, - ) - - -def generate_config(base_addr, available_mem, server_addr, server_port, section_constraints=None): - """Generates a configuration for Spike - - Parameters - ---------- - base_addr : int - base address of the simulator (for calculating the memory layout) - - server_addr : str - address of OpenOCD server to connect to - - server_port : int - port of OpenOCD server to connect to - - TODO correct type annotation? - section_constraints: Optional[Dict[str, Tuple[Number, MemConstraint]]] - TODO - - Return - ------ - config : Dict[str, Any] - MicroTVM config dict for this device - """ - if section_constraints is None: - section_constraints = DEFAULT_SECTION_CONSTRAINTS - return { - "device_id": DEVICE_ID, - "toolchain_prefix": TOOLCHAIN_PREFIX, - "mem_layout": gen_mem_layout(base_addr, available_mem, WORD_SIZE_BITS, section_constraints), - "word_size_bits": WORD_SIZE_BITS, - "thumb_mode": False, - "use_device_timer": False, - "comms_method": "openocd", - "server_addr": server_addr, - "server_port": server_port, - } - - -register_device( - DEVICE_ID, - { - "create_micro_lib": create_micro_lib, - "generate_config": generate_config, - }, -) diff --git a/python/tvm/micro/func_registry.py b/python/tvm/micro/func_registry.py index e19f4af917f8..69c4bb1a29e5 100644 --- a/python/tvm/micro/func_registry.py +++ b/python/tvm/micro/func_registry.py @@ -58,7 +58,7 @@ def graph_json_to_c_func_registry(graph_path, func_registry_path): lines.append("static TVMBackendPackedCFunc funcs[] = {") for f in funcs: - lines.append(f" &{f},") + lines.append(f" (TVMBackendPackedCFunc) &{f},") lines += [ "};", diff --git a/python/tvm/micro/micro_binary.py b/python/tvm/micro/micro_binary.py new file mode 100644 index 000000000000..8de144e77303 --- /dev/null +++ b/python/tvm/micro/micro_binary.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines an Artifact implementation for representing compiled micro TVM binaries.""" + +from . import artifact + + +class MicroBinary(artifact.Artifact): + """An Artifact that describes a compiled binary.""" + + ARTIFACT_TYPE = 'micro_binary' + + @classmethod + def from_unarchived(cls, base_dir, labelled_files, metadata): + binary_file = labelled_files['binary_file'][0] + del labelled_files['binary_file'] + + debug_files = None + if 'debug_files' in labelled_files: + debug_files = labelled_files['debug_files'] + del labelled_files['debug_files'] + + return cls(base_dir, binary_file, debug_files=debug_files, labelled_files=labelled_files, + metadata=metadata) + + def __init__(self, base_dir, binary_file, debug_files=None, labelled_files=None, metadata=None): + labelled_files = {} if labelled_files is None else dict(labelled_files) + metadata = {} if metadata is None else dict(metadata) + labelled_files['binary_file'] = [binary_file] + if debug_files is not None: + labelled_files['debug_files'] = debug_files + + super(MicroBinary, self).__init__(base_dir, labelled_files, metadata) + + self.binary_file = binary_file + self.debug_files = debug_files diff --git a/python/tvm/micro/micro_library.py b/python/tvm/micro/micro_library.py new file mode 100644 index 000000000000..7ca82e8a5dbb --- /dev/null +++ b/python/tvm/micro/micro_library.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines an Artifact subclass that describes a compiled static library.""" + +from tvm.contrib import util +from . import artifact +from . import compiler + + +class MicroLibrary(artifact.Artifact): + """An Artifact that describes a compiled static library.""" + + ARTIFACT_TYPE = 'micro_library' + + @classmethod + def from_unarchived(cls, base_dir, labelled_files, metadata): + library_files = labelled_files['library_files'] + del labelled_files['library_files'] + + debug_files = None + if 'debug_files' in labelled_files: + debug_files = labelled_files['debug_files'] + del labelled_files['debug_files'] + + return cls(base_dir, library_files, debug_files=debug_files, labelled_files=labelled_files, + metadata=metadata) + + def __init__(self, base_dir, library_files, debug_files=None, labelled_files=None, + metadata=None): + labelled_files = {} if labelled_files is None else dict(labelled_files) + metadata = {} if metadata is None else dict(metadata) + labelled_files['library_files'] = library_files + if debug_files is not None: + labelled_files['debug_files'] = debug_files + + super(MicroLibrary, self).__init__(base_dir, labelled_files, metadata) + + self.library_files = library_files + self.debug_file = debug_files + + +def create_micro_library(output, objects, options=None): + """Create a MicroLibrary using the default compiler options. + + Parameters + ---------- + output : str + Path to the output file, expected to end in .tar. + objects : List[str] + Paths to the source files to include in the library. + options : Optional[List[str]] + If given, additional command-line flags for the compiler. + """ + temp_dir = util.tempdir() + comp = compiler.DefaultCompiler() + output = temp_dir.relpath('micro-library.o') + comp.library(output, objects, options=options) + + with open(output, 'rb') as output_f: + elf_data = output_f.read() + + # TODO(areusch): Define a mechanism to determine compiler and linker flags for each lib + # enabled by the target str, and embed here. + micro_lib = MicroLibrary('', elf_data, {'target': comp.target.str()}) + micro_lib.save(output) diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py new file mode 100644 index 000000000000..000e8e9b39ed --- /dev/null +++ b/python/tvm/micro/session.py @@ -0,0 +1,124 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines a top-level glue class that operates the Transport and Flasher classes.""" + +import logging +import time + +from .._ffi import get_global_func +from ..contrib import graph_runtime +from .base import _rpc_connect +from ..rpc import RPCSession +from .transport import TransportLogger + + +class Session: + """MicroTVM Device Session + + Parameters + ---------- + config : dict + configuration for this session (as generated by + `tvm.micro.device.host.default_config()`, for example) + + Example + -------- + .. code-block:: python + + c_mod = ... # some module generated with "c" as the target + dev_config = micro.device.arm.stm32f746xx.default_config('127.0.0.1', 6666) + with tvm.micro.Session(dev_config) as sess: + micro_mod = sess.create_micro_mod(c_mod) + """ + + def __init__(self, binary=None, flasher=None, transport_context_manager=None, + session_name='micro-rpc'): + """Configure a new session. + + Parameters + ---------- + binary : MicroBinary + If given, `flasher` must also be given. During session initialization, this binary will + be flashed to the device before the transport is created. + flasher : Flasher + If given, `binary` must also be given. Used to flash `binary` during session + initialization. + transport_context_manager : ContextManager[transport.Transport] + If given, `flasher` and `binary` should not be given. On entry, this context manager + should establish a tarnsport between this TVM instance and the device. + session_name : str + Name of the session, used for debugging. + """ + self.binary = binary + self.flasher = flasher + self.transport_context_manager = transport_context_manager + self.session_name = session_name + + self._rpc = None + self._graph_runtime = None + + def get_system_lib(self): + return self._rpc.get_function('runtime.SystemLib')() + + def __enter__(self): + """Initialize this session and establish an RPC session with the on-device RPC server. + + Returns + ------- + Session : + Returns self. + """ + if self.flasher is not None: + self.transport_context_manager = self.flasher.flash(self.binary) + time.sleep(3.0) + + self.transport = TransportLogger( + self.session_name, self.transport_context_manager, level=logging.INFO).__enter__() + self._rpc = RPCSession(_rpc_connect( + self.session_name, self.transport.write, self.transport.read)) + self.context = self._rpc.cpu(0) + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Tear down this session and associated RPC session resources.""" + self.transport.__exit__(exc_type, exc_value, exc_traceback) + + +def create_local_graph_runtime(graph_json_str, mod, ctx): + """Create a local graph runtime driving execution on the remote CPU context given. + + Parameters + ---------- + graph_json_str : str + A string containing the graph representation. + + mod : tvm.runtime.Module + The remote module containing functions in graph_json_str. + + ctx : tvm.Context + The remote CPU execution context. + + Returns + ------- + tvm.contrib.GraphRuntime : + A local graph runtime instance that executes on the remote device. + """ + device_type_id = [ctx.device_type, ctx.device_id] + fcreate = get_global_func("tvm.graph_runtime.create") + return graph_runtime.GraphModule(fcreate( + graph_json_str, mod, *device_type_id)) diff --git a/python/tvm/micro/transport.py b/python/tvm/micro/transport.py new file mode 100644 index 000000000000..52617abd6bbf --- /dev/null +++ b/python/tvm/micro/transport.py @@ -0,0 +1,225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Defines abstractions and implementations of the RPC transport used with micro TVM.""" + +import abc +import logging +import string +import subprocess +import typing + +import tvm + +_LOG = logging.getLogger(__name__) + + +@tvm.error.register_error +class SessionTerminatedError(Exception): + """Raised when a transport read operationd discovers that the remote session is terminated.""" + + +class Transport(metaclass=abc.ABCMeta): + """The abstract Transport class used for micro TVM.""" + + def __enter__(self): + self.open() + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + self.close() + + @abc.abstractmethod + def open(self): + """Open any resources needed to send and receive RPC protocol data for a single session.""" + raise NotImplementedError() + + @abc.abstractmethod + def close(self): + """Release resources associated with this transport.""" + raise NotImplementedError() + + @abc.abstractmethod + def read(self, n): + """Read up to n bytes from the transport. + + Parameters + ---------- + n : int + Maximum number of bytes to read from the transport. + + Returns + ------- + bytes : + Data read from the channel. Less than `n` bytes may be returned, but 0 bytes should + never be returned except in error. Note that if a transport error occurs, an Exception + should be raised rather than simply returning empty bytes. + + + Raises + ------ + SessionTerminatedError : + When the transport layer determines that the active session was terminated by the + remote side. Typically this indicates that the remote device has reset. + """ + raise NotImplementedError() + + @abc.abstractmethod + def write(self, data): + """Write data to the transport channel. + + Parameters + ---------- + data : bytes + The data to write over the channel. + + Returns + ------- + int : + The number of bytes written to the underlying channel. This can be less than the length + of `data`, but cannot be 0. + """ + raise NotImplementedError() + + +class TransportLogger(Transport): + """Wraps a Transport implementation and logs traffic to the Python logging infrastructure.""" + + def __init__(self, name, child, logger=None, level=logging.INFO): + self.name = name + self.child = child + self.logger = logger or _LOG + self.level = level + + # Construct PRINTABLE to exclude whitespace from string.printable. + PRINTABLE = (string.digits + string.ascii_letters + string.punctuation) + + @classmethod + def _to_hex(cls, data): + lines = [] + if not data: + lines.append('') + return lines + + for i in range(0, (len(data) + 15) // 16): + chunk = data[i * 16:(i + 1) * 16] + hex_chunk = ' '.join(f'{c:02x}' for c in chunk) + ascii_chunk = ''.join((chr(c) if chr(c) in cls.PRINTABLE else '.') for c in chunk) + lines.append(f'{i * 16:04x} {hex_chunk:47} {ascii_chunk}') + + if len(lines) == 1: + lines[0] = lines[0][6:] + + return lines + + def open(self): + self.logger.log(self.level, 'opening transport') + self.child.open() + + def close(self): + self.logger.log(self.level, 'closing transport') + return self.child.close() + + def read(self, n): + data = self.child.read(n) + hex_lines = self._to_hex(data) + if len(hex_lines) > 1: + self.logger.log(self.level, '%s read %4d B -> [%d B]:\n%s', + self.name, n, len(data), '\n'.join(hex_lines)) + else: + self.logger.log(self.level, '%s read %4d B -> [%d B]: %s', + self.name, n, len(data), hex_lines[0]) + + return data + + def write(self, data): + bytes_written = self.child.write(data) + hex_lines = self._to_hex(data[:bytes_written]) + if len(hex_lines) > 1: + self.logger.log(self.level, '%s write <- [%d B]:\n%s', + self.name, bytes_written, '\n'.join(hex_lines)) + else: + self.logger.log(self.level, '%s write <- [%d B]: %s', + self.name, bytes_written, hex_lines[0]) + + return bytes_written + + +class SubprocessTransport(Transport): + """A Transport implementation that uses a subprocess's stdin/stdout as the channel.""" + + def __init__(self, args, **kwargs): + self.args = args + self.kwargs = kwargs + self.popen = None + + def open(self): + self.kwargs['stdout'] = subprocess.PIPE + self.kwargs['stdin'] = subprocess.PIPE + self.kwargs['bufsize'] = 0 + self.popen = subprocess.Popen(self.args, **self.kwargs) + self.stdin = self.popen.stdin + self.stdout = self.popen.stdout + + def write(self, data): + to_return = self.stdin.write(data) + self.stdin.flush() + + return to_return + + def read(self, n): + return self.stdout.read(n) + + def close(self): + self.stdin.close() + self.stdout.close() + self.popen.terminate() + + +class DebugWrapperTransport(Transport): + """A Transport wrapper class that launches a debugger before opening the transport. + + This is primiarly useful when debugging the other end of a SubprocessTransport. It allows you + to pipe data through the GDB process to drive the subprocess with a debugger attached. + """ + + def __init__(self, debugger, transport): + self.debugger = debugger + self.transport = transport + self.debugger.on_terminate_callbacks.append(self.transport.close) + + def open(self): + self.debugger.Start() + + try: + self.transport.open() + except Exception: + self.debugger.Stop() + raise + + def write(self, data): + return self.transport.write(data) + + def read(self, n): + return self.transport.read(n) + + def close(self): + self.transport.close() + self.debugger.Stop() + + +TransportContextManager = typing.ContextManager[Transport] diff --git a/python/tvm/rpc/minrpc.py b/python/tvm/rpc/minrpc.py index 2c9dd294cc5a..eb4561f39525 100644 --- a/python/tvm/rpc/minrpc.py +++ b/python/tvm/rpc/minrpc.py @@ -35,13 +35,13 @@ def find_minrpc_server_libpath(server="posix_popen_server"): """ curr_dir = os.path.dirname(os.path.realpath(os.path.expanduser(__file__))) source_dir = os.path.abspath(os.path.join(curr_dir, "..", "..", "..")) - - path = os.path.join(source_dir, "src", "runtime", "rpc", "minrpc", ("%s.cc" % server)) + minrpc_dir = os.path.join(source_dir, "src", "runtime", "minrpc") + path = os.path.join(minrpc_dir, server, ("%s.cc" % server)) candidates = [path] if not os.path.isfile(path): raise RuntimeError("Cannot find minserver %s, in candidates %s" % (server, candidates)) - return path + return minrpc_dir, path def with_minrpc(compile_func, server="posix_popen_server", runtime="libtvm"): @@ -63,7 +63,7 @@ def with_minrpc(compile_func, server="posix_popen_server", runtime="libtvm"): fcompile : function The return compilation. """ - server_path = find_minrpc_server_libpath(server) + minrpc_dir, server_path = find_minrpc_server_libpath(server) runtime_path = libinfo.find_lib_path([runtime, runtime + ".so", runtime + ".dylib"])[0] runtime_dir = os.path.abspath(os.path.dirname(runtime_path)) @@ -73,6 +73,7 @@ def with_minrpc(compile_func, server="posix_popen_server", runtime="libtvm"): # Always recommend to to link statically. options += ["-Wl,-rpath=" + runtime_dir] options += ["-I" + path for path in libinfo.find_include_path()] + options += ["-I" + minrpc_dir] fcompile = cc.cross_compiler( compile_func, options=options, add_files=[server_path, runtime_path] ) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 1476f7bca44f..a9d4edec8fff 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -219,6 +219,25 @@ def intel_graphics(model="unknown", options=None): opts = _merge_opts(opts, options) return Target(" ".join(["opencl"] + opts)) +def micro(hardware="unknown", options=None): + """Returns a microTVM target. + + Parameters + ---------- + hardware : str + Canonically identifies the target device; typicaly one of cortex-mX, or a specific SoC model + when that model has been tested to work with microTVM. + options : str or list of str + Additional options + """ + trans_table = { + "host": ["-mcpu=native"], + } + opts = _merge_opts(trans_table[hardware] + ["-runtime=c", "--system-lib"], options) + + # NOTE: in the future, the default micro target will be LLVM except when + # external dependencies are present. + return Target(" ".join(["c"] + opts)) def arm_cpu(model="unknown", options=None): """Returns a ARM CPU target. diff --git a/src/runtime/crt/Makefile b/src/runtime/crt/Makefile index cf11507d70fa..8a24db4e8b2b 100644 --- a/src/runtime/crt/Makefile +++ b/src/runtime/crt/Makefile @@ -15,43 +15,63 @@ # specific language governing permissions and limitations # under the License. +# NOTE: Although this Makefile contains build commands for the C runtime, it isn't intended to be +# used directly in the TVM source tree. Instead, build the "standalone_crt" target, which produces a +# directory tree suitable for this Makefile. If this Makefile looks like it's the top-level of a +# source tree, you can probably ignore this message. + +# NOTE: If files appear to be missing in the generated standalone_crt target, consult the copy job +# specs listed in the TVM repo in cmake/modules/StandaloneCrt.cmake. + ifeq ($(CRT_CONFIG),) $(error "Must supply path to crt_config.h: CRT_CONFIG=...") endif -DLPACK_INCLUDE_DIR ?= ../../../3rdparty/dlpack/include -TVM_INCLUDE_DIR ?= ../../../include + +ifneq ($(wildcard .gitignore),) +$(error "detected building inside tvm source tree.") +$(error "build the standalone_crt target, and re-invoke makefile in build/standalone_crt") +endif BUILD_DIR ?= build PREFIX ?= AR ?= ${PREFIX}ar CC ?= ${PREFIX}gcc +CXX ?= ${PREFIX}g++ RANLIB ?= ${PREFIX}ranlib QUIET ?= @ -CFLAGS += -isystem "${TVM_INCLUDE_DIR}" -isystem "${DLPACK_INCLUDE_DIR}" -I include -I $(dir ${CRT_CONFIG}) -CFLAGS += -Werror -g $(EXTRA_CFLAGS) +CRT_PREFIX = $(wildcard src/crt) + +INCLUDES ?= -isystem include -iquote $(dir ${CRT_CONFIG}) +CFLAGS += ${INCLUDES} -Werror -g $(EXTRA_CFLAGS) +CXXFLAGS += ${INCLUDES} -std=c++11 -Werror -g $(EXTRA_CXXFLAGS) LDFLAGS += -Werror -g $(EXTRA_LDFLAGS) -${BUILD_DIR}/%.o: %.c +${BUILD_DIR}/%.o: src/%.c $(CRT_CONFIG) ${QUIET}mkdir -p $(dir $@) ${QUIET}${CC} ${CFLAGS} -c -o "$@" "$<" -${BUILD_DIR}/common/libcommon.a: $(patsubst %.c,${BUILD_DIR}/%.o,$(wildcard common/*.c)) - ${QUIET}${AR} -cr "$@" $^ - ${QUIET}${RANLIB} ${RANLIBFLAGS} "$@" +${BUILD_DIR}/%.o: src/%.cc $(CRT_CONFIG) + ${QUIET}mkdir -p $(dir $@) + ${QUIET}${CXX} ${CXXFLAGS} -c -o "$@" "$<" + +define LIB_template +$${BUILD_DIR}/lib$(notdir $(1)).a: $$(patsubst src/%.c,$${BUILD_DIR}/%.o,$$(wildcard src/$(1:src/%=%)/*.c)) $$(patsubst src/%.cc,${BUILD_DIR}/%.o,$$(wildcard src/$(1:src/%=%)/*.cc)) + $${QUIET}$${AR} -cr "$$@" $$^ + $${QUIET}$${RANLIB} $${RANLIBFLAGS} "$$@" +$(notdir $(1)): $${BUILD_DIR}/lib$(notdir $(1)).a + +endef -${BUILD_DIR}/graph_runtime/libgraph_runtime.a: $(patsubst %.c,${BUILD_DIR}/%.o,$(wildcard graph_runtime/*.c)) - ${QUIET}${AR} -cr "$@" $^ - ${QUIET}${RANLIB} ${RANLIBFLAGS} "$@" +LIBS = src/runtime/crt/common src/runtime/crt/graph_runtime src/runtime/crt/utvm_rpc_common src/runtime/crt/utvm_rpc_server -common: ${BUILD_DIR}/common/libcommon.a -graph_runtime: ${BUILD_DIR}/graph_runtime/libgraph_runtime.a +$(foreach lib,$(LIBS),$(eval $(call LIB_template,$(lib)))) -all: common graph_runtime +all: $(notdir $(LIBS)) clean: rm -rf "${BUILD_DIR}" -.PHONY: all common graph_runtime +.PHONY: all $(notdir $(LIBS)) .DEFAULT_GOAL: all diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c index c1e994ffb8c7..d6f78d9e3a03 100644 --- a/src/runtime/crt/common/crt_runtime_api.c +++ b/src/runtime/crt/common/crt_runtime_api.c @@ -28,6 +28,7 @@ #include #include #include +#include #include #include #include @@ -306,9 +307,16 @@ int TVMFuncFree(TVMFunctionHandle func) { return 0; } -tvm_crt_error_t TVMInitializeRuntime() { +tvm_crt_error_t TVMInitializeRuntime(uint8_t* memory_pool, size_t memory_pool_size_bytes, + size_t page_size_bytes_log2) { int idx; - int error; + tvm_crt_error_t error; + + error = + TVMInitializeGlobalMemoryManager(memory_pool, memory_pool_size_bytes, page_size_bytes_log2); + if (error != kTvmErrorNoError) { + return error; + } system_lib_handle = kTVMModuleHandleUninitialized; @@ -320,14 +328,14 @@ tvm_crt_error_t TVMInitializeRuntime() { } error = TVMFuncRegisterGlobal("runtime.SystemLib", &SystemLibraryCreate, 0); - if (error != 0) { + if (error != kTvmErrorNoError) { return error; } error = TVMFuncRegisterGlobal("tvm.rpc.server.ModuleGetFunction", &ModuleGetFunction, 0); - if (error != 0) { + if (error != kTvmErrorNoError) { return error; } - return 0; + return kTvmErrorNoError; } diff --git a/src/runtime/crt/common/memory.c b/src/runtime/crt/common/memory.c index 7a634b9a7033..68cad3645146 100644 --- a/src/runtime/crt/common/memory.c +++ b/src/runtime/crt/common/memory.c @@ -32,16 +32,12 @@ #include #include #include -#include +#include #include +#include #include #include -/** - * \brief Memory pool for virtual dynamic memory allocation - */ -static uint8_t g_memory_pool[TVM_CRT_VIRT_MEM_SIZE]; - // construct a new page Page PageCreate(uint8_t* memory_pool, size_t page_size_bytes, tvm_index_t ptable_begin, tvm_index_t num_pages) { @@ -144,8 +140,8 @@ void* MemoryManager_Alloc(MemoryManager* mgr, tvm_index_t size) { } else { start = ptable->num_pages; CHECK_LE((unsigned)(start + npage), ptable->max_pages, - "insufficient memory, start=%" PRId64 ", npage=%" PRId64 ", total=%" PRId64 "", start, - npage, start + npage); + "insufficient memory, start=%" PRId32 ", npage=%" PRId32 ", total=%" PRId32 " / %zu", + (int32_t)start, (int32_t)npage, (int32_t)(start + npage), mgr->pmap.max_pages); /* insert page entry */ Page p = PageCreate(ptable->memory_pool, ptable->page_size_bytes, start, npage); ptable->resize(ptable, start + npage, &p); @@ -262,6 +258,9 @@ void MemoryManager_Free(MemoryManager* mgr, void* ptr) { #define ROUND_UP(qty, modulo) (((qty) + ((modulo)-1)) / (modulo) * (modulo)) +static bool g_memory_manager_initialized = 0; +static MemoryManager g_memory_manager; + void MemoryManagerCreate(MemoryManager* manager, uint8_t* memory_pool, size_t memory_pool_size_bytes, size_t page_size_bytes_log2) { memset(manager, 0, sizeof(MemoryManager)); @@ -308,16 +307,22 @@ void MemoryManagerCreate(MemoryManager* manager, uint8_t* memory_pool, manager->free_map.insert = MultiMap_Insert; } -MemoryManager* TVMGetGlobalMemoryManager() { - /* initialize once */ - static uint32_t initialized = 0; - static MemoryManager mgr; - if (!initialized) { - memset(g_memory_pool, 0, sizeof(g_memory_pool)); - MemoryManagerCreate(&mgr, g_memory_pool, TVM_CRT_VIRT_MEM_SIZE, TVM_CRT_PAGE_BYTES_LOG); - initialized = 1; +tvm_crt_error_t TVMInitializeGlobalMemoryManager(uint8_t* memory_pool, + size_t memory_pool_size_bytes, + size_t page_size_bytes_log2) { + if (g_memory_manager_initialized) { + return kTvmErrorPlatformMemoryManagerInitialized; } - return &mgr; + + MemoryManagerCreate(&g_memory_manager, memory_pool, memory_pool_size_bytes, page_size_bytes_log2); + + g_memory_manager_initialized = true; + return kTvmErrorNoError; +} + +MemoryManager* TVMGetGlobalMemoryManager() { + CHECK(g_memory_manager_initialized); + return &g_memory_manager; } /** \brief Allocate memory from manager */ diff --git a/src/runtime/crt/common/packed_func.c b/src/runtime/crt/common/packed_func.c index b5a3121357d5..e946cda9d9ae 100644 --- a/src/runtime/crt/common/packed_func.c +++ b/src/runtime/crt/common/packed_func.c @@ -25,7 +25,7 @@ */ #include #include -#include +#include #include DLDataType String2DLDataType(const char* s) { diff --git a/src/runtime/crt/crt_config-template.h b/src/runtime/crt/crt_config-template.h new file mode 100644 index 000000000000..67e0608ab696 --- /dev/null +++ b/src/runtime/crt/crt_config-template.h @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/runtime/crt_config.h.template + * \brief Template for CRT configuration, to be modified on each target. + */ +#ifndef TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ +#define TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ + +/*! Maximum supported dimension in NDArray */ +#define TVM_CRT_MAX_NDIM 6 + +/*! Maximum supported arguments in generated functions */ +#define TVM_CRT_MAX_ARGS 10 + +/*! Size of the global function registry, in bytes. */ +#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 + +/*! Maximum number of registered modules. */ +#define TVM_CRT_MAX_REGISTERED_MODULES 2 + +/*! Maximum packet size, in bytes, including the length header. */ +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 2048 + +/*! Maximum supported string length in dltype, e.g. "int8", "int16", "float32" */ +#define TVM_CRT_MAX_STRLEN_DLTYPE 10 + +/*! Maximum supported string length in function names */ +#define TVM_CRT_MAX_STRLEN_FUNCTION_NAME 80 + +/*! \brief Maximum length of a PackedFunc function name. */ +#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 + +/*! \brief DLDataType for the return value from strlen */ +#define TVM_CRT_STRLEN_DLTYPE 10 + +#endif // TVM_RUNTIME_CRT_CRT_CONFIG_TEMPLATE_H_ diff --git a/src/runtime/crt/graph_runtime/graph_runtime.c b/src/runtime/crt/graph_runtime/graph_runtime.c index 01b07dcb6f1a..a6cd77ad6a22 100644 --- a/src/runtime/crt/graph_runtime/graph_runtime.c +++ b/src/runtime/crt/graph_runtime/graph_runtime.c @@ -25,8 +25,8 @@ */ #include -#include #include +#include #include #include #include diff --git a/src/runtime/crt/host/crt_config.h b/src/runtime/crt/host/crt_config.h index c0b02a69ba5b..689189629542 100644 --- a/src/runtime/crt/host/crt_config.h +++ b/src/runtime/crt/host/crt_config.h @@ -24,6 +24,9 @@ #ifndef TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ #define TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ +/*! Log level of the CRT runtime */ +#define TVM_CRT_LOG_LEVEL TVM_CRT_LOG_LEVEL_DEBUG + /*! Support low-level debugging in MISRA-C runtime */ #define TVM_CRT_DEBUG 0 @@ -36,33 +39,18 @@ /*! Maximum supported string length in function names */ #define TVM_CRT_STRLEN_NAME 80 -/*! - * \brief Log memory pool size for virtual memory allocation - * - * Here is a list of possible choices: - * * use 16 for 64 KiB memory space - * * use 17 for 128 KiB memory space - * * use 18 for 256 KiB memory space - * * use 19 for 512 KiB memory space - * * use 20 for 1 MiB memory space - * * use 21 for 2 MiB memory space - * * use 22 for 4 MiB memory space - * * use 23 for 8 MiB memory space - * * use 24 for 16 MiB memory space - * * use 25 for 32 MiB memory space - * * use 26 for 64 MiB memory space - * * use 27 for 128 MiB memory space - * * use 28 for 256 MiB memory space - */ -#define TVM_CRT_LOG_VIRT_MEM_SIZE 24 - -/*! \brief Log2 of page size for virtual memory allocation */ -#define TVM_CRT_PAGE_BYTES_LOG 12 - /*! Maximum number of registered modules. */ #define TVM_CRT_MAX_REGISTERED_MODULES 2 /*! Size of the global function registry, in bytes. */ #define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200 +/*! Maximum packet size, in bytes, including the length header. */ +#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000 + +/*! \brief Maximum length of a PackedFunc function name. */ +#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30 + +// #define TVM_CRT_FRAMER_ENABLE_LOGS + #endif // TVM_RUNTIME_CRT_HOST_CRT_CONFIG_H_ diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc new file mode 100644 index 000000000000..dcca305b8b65 --- /dev/null +++ b/src/runtime/crt/host/main.cc @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file main.cc + * \brief main entry point for host subprocess-based CRT + */ +#include +#include +#include +#include +#include + +#include +#include + +#include "crt_config.h" + +using namespace std::chrono; + +extern "C" { + +ssize_t UTvmWriteFunc(void* context, const uint8_t* data, size_t num_bytes) { + ssize_t to_return = write(STDOUT_FILENO, data, num_bytes); + fflush(stdout); + fsync(STDOUT_FILENO); + return to_return; +} + +void TVMPlatformAbort(tvm_crt_error_t error_code) { + std::cerr << "TVMPlatformAbort: " << error_code << std::endl; + throw "Aborted"; +} + +high_resolution_clock::time_point g_utvm_start_time; +int g_utvm_timer_running = 0; + +int TVMPlatformTimerStart() { + if (g_utvm_timer_running) { + std::cerr << "timer already running" << std::endl; + return -1; + } + g_utvm_start_time = high_resolution_clock::now(); + g_utvm_timer_running = 1; + return 0; +} + +int TVMPlatformTimerStop(double* res_us) { + if (!g_utvm_timer_running) { + std::cerr << "timer not running" << std::endl; + return -1; + } + auto utvm_stop_time = high_resolution_clock::now(); + duration time_span(utvm_stop_time - g_utvm_start_time); + *res_us = time_span.count(); + g_utvm_timer_running = 0; + return 0; +} +} + +uint8_t memory[512 * 1024]; + +static char** g_argv = NULL; + +int testonly_reset_server(TVMValue* args, int* type_codes, int num_args, TVMValue* out_ret_value, + int* out_ret_tcode, void* resource_handle) { + execvp(g_argv[0], g_argv); + perror("utvm runtime: error restarting"); + return -1; +} + +int main(int argc, char** argv) { + g_argv = argv; + utvm_rpc_server_t rpc_server = + UTvmRpcServerInit(memory, sizeof(memory), 8, &UTvmWriteFunc, nullptr); + + if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server, + 0)) { + fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n"); + return 2; + } + + setbuf(stdin, NULL); + setbuf(stdout, NULL); + + for (;;) { + uint8_t c; + int ret_code = read(STDIN_FILENO, &c, 1); + if (ret_code < 0) { + perror("utvm runtime: read failed"); + return 2; + } else if (ret_code == 0) { + fprintf(stderr, "utvm runtime: 0-length read, exiting!\n"); + return 2; + } + if (UTvmRpcServerReceiveByte(rpc_server, c) != 1) { + abort(); + } + if (!UTvmRpcServerLoop(rpc_server)) { + execvp(argv[0], argv); + perror("utvm runtime: error restarting"); + return 2; + } + } + return 0; +} diff --git a/src/runtime/crt/include/tvm/runtime/crt/internal/common/memory.h b/src/runtime/crt/include/tvm/runtime/crt/internal/common/memory.h index 8162fd7851a2..175d5e120df2 100644 --- a/src/runtime/crt/include/tvm/runtime/crt/internal/common/memory.h +++ b/src/runtime/crt/include/tvm/runtime/crt/internal/common/memory.h @@ -27,6 +27,7 @@ #define TVM_RUNTIME_CRT_INCLUDE_TVM_RUNTIME_CRT_INTERNAL_COMMON_MEMORY_H_ #include +#include #include "crt_config.h" @@ -34,15 +35,6 @@ extern "C" { #endif -/*! Number of bits in a page */ -#define TVM_CRT_PAGE_BITS ((1 << TVM_CRT_PAGE_BYTES_LOG) << 3) - -/*! \brief Translate log memory size into bytes */ -#define TVM_CRT_VIRT_MEM_SIZE (1 << TVM_CRT_LOG_VIRT_MEM_SIZE) - -/*! \brief Number of possible page entries in total */ -#define TVM_CRT_MAX_PAGES (TVM_CRT_VIRT_MEM_SIZE / TVM_CRT_PAGE_BYTES) - /*! \brief A page in the DRAM */ typedef struct Page { /*! \brief Start location in page table */ @@ -130,10 +122,31 @@ typedef struct MemoryManager { MultiMap free_map; } MemoryManager; -// Exposed for testing +/*! + * Exposed for testing. + * + * \param manager The memory manager to initialize. + * \param memory_pool Pointer to the global memory pool used by the CRT. + * \param memory_pool_size_bytes Size of `memory_pool`, in bytes. + * \param page_size_bytes_log2 log2 of the page size, in bytes. + */ void MemoryManagerCreate(MemoryManager* manager, uint8_t* memory_pool, size_t memory_pool_size_bytes, size_t page_size_bytes_log2); +/*! + * Initialize the global memory manager. + * + * Call this function once before invoking any other CRT functions beginning with `TVM`. + * Repeated calls will cause TVMPlatformAbort to be invoked. + * \param memory_pool Pointer to the global memory pool used by the CRT. + * \param memory_pool_size_bytes Size of `memory_pool`, in bytes. + * \param page_size_bytes_log2 log2 of the page size, in bytes. + * \return An error code indicating the status of the operation. + */ +tvm_crt_error_t TVMInitializeGlobalMemoryManager(uint8_t* memory_pool, + size_t memory_pool_size_bytes, + size_t page_size_bytes_log2); + #ifdef __cplusplus } // extern "C" #endif diff --git a/src/runtime/crt/utvm_rpc_common/frame_buffer.cc b/src/runtime/crt/utvm_rpc_common/frame_buffer.cc new file mode 100644 index 000000000000..37eb274eb944 --- /dev/null +++ b/src/runtime/crt/utvm_rpc_common/frame_buffer.cc @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file frame_buffer.cc + * \brief Defines a buffer for use by the RPC framing layer. + */ + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +size_t FrameBuffer::Write(const uint8_t* data, size_t data_size_bytes) { + size_t num_bytes_available = capacity_ - num_valid_bytes_; + size_t num_bytes_to_copy = data_size_bytes; + if (num_bytes_available < num_bytes_to_copy) { + num_bytes_to_copy = num_bytes_available; + } + + memcpy(&data_[num_valid_bytes_], data, num_bytes_to_copy); + num_valid_bytes_ += num_bytes_to_copy; + return num_bytes_to_copy; +} + +size_t FrameBuffer::Read(uint8_t* data, size_t data_size_bytes) { + size_t num_bytes_to_copy = data_size_bytes; + size_t num_bytes_available = num_valid_bytes_ - read_cursor_; + if (num_bytes_available < num_bytes_to_copy) { + num_bytes_to_copy = num_bytes_available; + } + + memcpy(data, &data_[read_cursor_], num_bytes_to_copy); + read_cursor_ += num_bytes_to_copy; + return num_bytes_to_copy; +} + +void FrameBuffer::Clear() { + num_valid_bytes_ = 0; + read_cursor_ = 0; +} + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/utvm_rpc_common/framing.cc b/src/runtime/crt/utvm_rpc_common/framing.cc new file mode 100644 index 000000000000..e40ea071f3dd --- /dev/null +++ b/src/runtime/crt/utvm_rpc_common/framing.cc @@ -0,0 +1,411 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file framing.cc + * \brief Framing for RPC. + */ + +#include +#include +#include + +#include "crt_config.h" + +// For debugging purposes, Framer logs can be enabled, but this should only be done when +// running from the host. This is done differently from TVMLogf() because TVMLogf() uses the +// framer in its implementation. +#ifdef TVM_CRT_FRAMER_ENABLE_LOGS +#include +#define TVM_FRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "utvm framer: " msg " \n", ##__VA_ARGS__) +#define TVM_UNFRAMER_DEBUG_LOG(msg, ...) fprintf(stderr, "utvm unframer: " msg " \n", ##__VA_ARGS__) +#else +#define TVM_FRAMER_DEBUG_LOG(msg, ...) +#define TVM_UNFRAMER_DEBUG_LOG(msg, ...) +#endif + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +template +static constexpr uint8_t to_integral(E e) { + return static_cast(e); +} + +void Unframer::Reset() { + state_ = State::kFindPacketStart; + saw_escape_start_ = false; + num_buffer_bytes_valid_ = 0; +} + +tvm_crt_error_t Unframer::Write(const uint8_t* data, size_t data_size_bytes, + size_t* bytes_consumed) { + tvm_crt_error_t return_code = kTvmErrorNoError; + input_ = data; + input_size_bytes_ = data_size_bytes; + + while (return_code == kTvmErrorNoError && input_size_bytes_ > 0) { + TVM_UNFRAMER_DEBUG_LOG("state: %02x size 0x%02zx", to_integral(state_), input_size_bytes_); + switch (state_) { + case State::kFindPacketStart: + return_code = FindPacketStart(); + break; + case State::kFindPacketLength: + return_code = FindPacketLength(); + break; + case State::kFindPacketCrc: + return_code = FindPacketCrc(); + break; + case State::kFindCrcEnd: + return_code = FindCrcEnd(); + break; + default: + return_code = kTvmErrorFramingInvalidState; + break; + } + } + + *bytes_consumed = data_size_bytes - input_size_bytes_; + input_ = nullptr; + input_size_bytes_ = 0; + + if (return_code != kTvmErrorNoError) { + state_ = State::kFindPacketStart; + ClearBuffer(); + } + + return return_code; +} + +tvm_crt_error_t Unframer::FindPacketStart() { + size_t i; + for (i = 0; i < input_size_bytes_; ++i) { + if (input_[i] == to_integral(Escape::kEscapeStart)) { + saw_escape_start_ = true; + } else if (input_[i] == to_integral(Escape::kPacketStart) && saw_escape_start_) { + uint8_t packet_start_sequence[2]{to_integral(Escape::kEscapeStart), + to_integral(Escape::kPacketStart)}; + crc_ = crc16_compute(packet_start_sequence, sizeof(packet_start_sequence), nullptr); + saw_escape_start_ = false; + state_ = State::kFindPacketLength; + i++; + break; + } else { + saw_escape_start_ = false; + } + } + + input_ += i; + input_size_bytes_ -= i; + return kTvmErrorNoError; +} + +tvm_crt_error_t Unframer::ConsumeInput(uint8_t* buffer, size_t buffer_size_bytes, + size_t* bytes_filled, bool update_crc) { + CHECK(*bytes_filled < buffer_size_bytes); + tvm_crt_error_t to_return = kTvmErrorNoError; + size_t i; + for (i = 0; i < input_size_bytes_; ++i) { + uint8_t c = input_[i]; + if (saw_escape_start_) { + saw_escape_start_ = false; + if (c == to_integral(Escape::kPacketStart)) { + // When the start packet sequence is seen, abort unframing the current packet. Since the + // escape byte has already been parsed, update the CRC include only the escape byte. This + // readies the unframer to consume the kPacketStart byte on the next Write() call. + uint8_t escape_start = to_integral(Escape::kEscapeStart); + crc_ = crc16_compute(&escape_start, 1, NULL); + to_return = kTvmErrorFramingShortPacket; + saw_escape_start_ = true; + + break; + } else if (c == to_integral(Escape::kEscapeNop)) { + continue; + } else if (c == to_integral(Escape::kEscapeStart)) { + // do nothing (allow character to be printed) + } else { + // Invalid escape sequence. + to_return = kTvmErrorFramingInvalidEscape; + i++; + break; + } + } else if (c == to_integral(Escape::kEscapeStart)) { + saw_escape_start_ = true; + continue; + } else { + saw_escape_start_ = false; + } + + buffer[*bytes_filled] = c; + (*bytes_filled)++; + if (*bytes_filled == buffer_size_bytes) { + i++; + break; + } + } + + if (update_crc) { + crc_ = crc16_compute(input_, i, &crc_); + } + + input_ += i; + input_size_bytes_ -= i; + return to_return; +} + +tvm_crt_error_t Unframer::AddToBuffer(size_t buffer_full_bytes, bool update_crc) { + CHECK(!IsBufferFull(buffer_full_bytes)); + return ConsumeInput(buffer_, buffer_full_bytes, &num_buffer_bytes_valid_, update_crc); +} + +void Unframer::ClearBuffer() { num_buffer_bytes_valid_ = 0; } + +tvm_crt_error_t Unframer::FindPacketLength() { + tvm_crt_error_t to_return = AddToBuffer(PacketFieldSizeBytes::kPayloadLength, true); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + if (!IsBufferFull(PacketFieldSizeBytes::kPayloadLength)) { + return to_return; + } + + num_payload_bytes_remaining_ = *reinterpret_cast(buffer_); + TVM_UNFRAMER_DEBUG_LOG("payload length: 0x%zx", num_payload_bytes_remaining_); + ClearBuffer(); + state_ = State::kFindPacketCrc; + return to_return; +} + +tvm_crt_error_t Unframer::FindPacketCrc() { + // CHECK(num_buffer_bytes_valid_ == 0); + while (num_payload_bytes_remaining_ > 0) { + size_t num_bytes_to_buffer = num_payload_bytes_remaining_; + if (num_bytes_to_buffer > sizeof(buffer_)) { + num_bytes_to_buffer = sizeof(buffer_); + } + + // remember in case we need to rewind due to WriteAll() error. + size_t prev_input_size_bytes = input_size_bytes_; + size_t prev_num_buffer_bytes_valid = num_buffer_bytes_valid_; + { + tvm_crt_error_t to_return = AddToBuffer(num_bytes_to_buffer, true); + if (to_return != kTvmErrorNoError) { + return to_return; + } + } + + if (prev_num_buffer_bytes_valid == num_buffer_bytes_valid_) { + // Return if no bytes were consumed from the input. + return kTvmErrorNoError; + } + + { + size_t bytes_consumed; + tvm_crt_error_t to_return = + stream_->WriteAll(buffer_, num_buffer_bytes_valid_, &bytes_consumed); + num_payload_bytes_remaining_ -= bytes_consumed; + if (to_return != kTvmErrorNoError) { + // rewind input, skipping escape bytes. + size_t buffer_bytes_consumed; + const uint8_t* input = input_ - (prev_input_size_bytes - input_size_bytes_); + for (buffer_bytes_consumed = 0; bytes_consumed > 0; ++buffer_bytes_consumed) { + if (input[buffer_bytes_consumed] != uint8_t(Escape::kEscapeStart)) { + bytes_consumed--; + } + } + + size_t bytes_to_rewind = prev_input_size_bytes - buffer_bytes_consumed; + input_ -= bytes_to_rewind; + input_size_bytes_ += bytes_to_rewind; + + // must not have seen escape, since AddToBuffer won't stop in the middle. + saw_escape_start_ = false; + + return to_return; + } + } + + ClearBuffer(); + } + + if (num_payload_bytes_remaining_ == 0) { + state_ = State::kFindCrcEnd; + } + + return kTvmErrorNoError; +} + +tvm_crt_error_t Unframer::FindCrcEnd() { + tvm_crt_error_t to_return = AddToBuffer(PacketFieldSizeBytes::kCrc, false); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + if (!IsBufferFull(PacketFieldSizeBytes::kCrc)) { + return kTvmErrorNoError; + } + + // TODO(areusch): Handle endianness. + stream_->PacketDone(crc_ == *reinterpret_cast(buffer_)); + ClearBuffer(); + state_ = State::kFindPacketStart; + return kTvmErrorNoError; +} + +void Framer::Reset() { state_ = State::kReset; } + +tvm_crt_error_t Framer::Write(const uint8_t* payload, size_t payload_size_bytes) { + tvm_crt_error_t to_return; + to_return = StartPacket(payload_size_bytes); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + to_return = WritePayloadChunk(payload, payload_size_bytes); + if (to_return != 0) { + return to_return; + } + + to_return = FinishPacket(); + return to_return; +} + +tvm_crt_error_t Framer::StartPacket(size_t payload_size_bytes) { + uint8_t packet_header[sizeof(uint32_t)]; + size_t ptr = 0; + if (state_ == State::kReset) { + packet_header[ptr] = to_integral(Escape::kEscapeNop); + ptr++; + tvm_crt_error_t to_return = + WriteAndCrc(packet_header, ptr, false /* escape */, false /* update_crc */); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + ptr = 0; + } + + packet_header[ptr] = to_integral(Escape::kEscapeStart); + ptr++; + packet_header[ptr] = to_integral(Escape::kPacketStart); + ptr++; + + crc_ = 0xffff; + tvm_crt_error_t to_return = + WriteAndCrc(packet_header, ptr, false /* escape */, true /* update_crc */); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + uint32_t payload_size_wire = payload_size_bytes; + to_return = WriteAndCrc(reinterpret_cast(&payload_size_wire), sizeof(payload_size_wire), + true /* escape */, true /* update_crc */); + if (to_return == kTvmErrorNoError) { + state_ = State::kTransmitPacketPayload; + num_payload_bytes_remaining_ = payload_size_bytes; + } + + return to_return; +} + +tvm_crt_error_t Framer::WriteAndCrc(const uint8_t* data, size_t data_size_bytes, bool escape, + bool update_crc) { + while (data_size_bytes > 0) { + uint8_t buffer[kMaxStackBufferSizeBytes]; + size_t buffer_ptr = 0; + size_t i; + for (i = 0; i < data_size_bytes && buffer_ptr != kMaxStackBufferSizeBytes; ++i) { + uint8_t c = data[i]; + if (!escape || c != to_integral(Escape::kEscapeStart)) { + buffer[buffer_ptr] = c; + buffer_ptr++; + continue; + } + + if (buffer_ptr == kMaxStackBufferSizeBytes - 1) { + break; + } + + buffer[buffer_ptr] = to_integral(Escape::kEscapeStart); + buffer_ptr++; + + buffer[buffer_ptr] = to_integral(Escape::kEscapeStart); + buffer_ptr++; + } + + size_t bytes_consumed; + tvm_crt_error_t to_return = stream_->WriteAll(buffer, buffer_ptr, &bytes_consumed); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + if (update_crc) { + crc_ = crc16_compute(buffer, buffer_ptr, &crc_); + } + + data_size_bytes -= i; + data += i; + } + + return kTvmErrorNoError; +} + +tvm_crt_error_t Framer::WritePayloadChunk(const uint8_t* payload_chunk, + size_t payload_chunk_size_bytes) { + if (state_ != State::kTransmitPacketPayload) { + return kTvmErrorFramingInvalidState; + } else if (payload_chunk_size_bytes > num_payload_bytes_remaining_) { + return kTvmErrorFramingPayloadOverflow; + } + + TVM_FRAMER_DEBUG_LOG("write payload chunk: %" PRIuMAX " bytes", payload_chunk_size_bytes); + tvm_crt_error_t to_return = WriteAndCrc(payload_chunk, payload_chunk_size_bytes, + true /* escape */, true /* update_crc */); + if (to_return != kTvmErrorNoError) { + state_ = State::kReset; + return to_return; + } + + num_payload_bytes_remaining_ -= payload_chunk_size_bytes; + return kTvmErrorNoError; +} + +tvm_crt_error_t Framer::FinishPacket() { + if (state_ != State::kTransmitPacketPayload) { + return kTvmErrorFramingInvalidState; + } else if (num_payload_bytes_remaining_ != 0) { + return kTvmErrorFramingPayloadIncomplete; + } + + tvm_crt_error_t to_return = WriteAndCrc(reinterpret_cast(&crc_), sizeof(crc_), + true /* escape */, false /* update_crc */); + if (to_return != kTvmErrorNoError) { + TVM_FRAMER_DEBUG_LOG("write and crc returned: %02x", to_return); + state_ = State::kReset; + } else { + state_ = State::kIdle; + } + return to_return; +} + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/utvm_rpc_common/session.cc b/src/runtime/crt/utvm_rpc_common/session.cc new file mode 100644 index 000000000000..5930863da37a --- /dev/null +++ b/src/runtime/crt/utvm_rpc_common/session.cc @@ -0,0 +1,279 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file session.h + * \brief RPC Session + */ + +#include +#include + +#include "crt_config.h" + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +struct utvm_session_start_payload_t { + uint8_t version; +}; + +void Session::RegenerateNonce() { + local_nonce_ = (((local_nonce_ << 5) | (local_nonce_ >> 5)) + 1); + + if (local_nonce_ == kInvalidNonce) { + local_nonce_++; + } +} + +tvm_crt_error_t Session::SendInternal(MessageType message_type, const uint8_t* message_data, + size_t message_size_bytes) { + tvm_crt_error_t to_return = StartMessage(message_type, message_size_bytes); + if (to_return != kTvmErrorNoError) { + return to_return; + } + + if (message_size_bytes > 0) { + to_return = SendBodyChunk(message_data, message_size_bytes); + if (to_return != kTvmErrorNoError) { + return to_return; + } + } + + return framer_->FinishPacket(); +} + +tvm_crt_error_t Session::StartMessage(MessageType message_type, size_t message_size_bytes) { + SessionHeader header{session_id_, message_type}; + if (message_type == MessageType::kLog) { + header.session_id = 0; + } + + tvm_crt_error_t to_return = framer_->StartPacket(message_size_bytes + sizeof(SessionHeader)); + if (to_return != 0) { + return to_return; + } + + return framer_->WritePayloadChunk(reinterpret_cast(&header), sizeof(SessionHeader)); +} + +tvm_crt_error_t Session::SendBodyChunk(const uint8_t* chunk, size_t chunk_size_bytes) { + return framer_->WritePayloadChunk(chunk, chunk_size_bytes); +} + +tvm_crt_error_t Session::FinishMessage() { return framer_->FinishPacket(); } + +tvm_crt_error_t Session::StartSession() { + CHECK_NE(state_, State::kReset, "must call Initialize"); + + RegenerateNonce(); + SetSessionId(local_nonce_, 0); + utvm_session_start_payload_t payload = {Session::kVersion}; + tvm_crt_error_t to_return = SendInternal(MessageType::kStartSessionInit, + reinterpret_cast(&payload), sizeof(payload)); + if (to_return == 0) { + state_ = State::kStartSessionSent; + } + + return to_return; +} + +tvm_crt_error_t Session::Initialize() { return TerminateSession(); } + +tvm_crt_error_t Session::TerminateSession() { + SetSessionId(0, 0); + state_ = State::kNoSessionEstablished; + return SendInternal(MessageType::kTerminateSession, nullptr, 0); +} + +tvm_crt_error_t Session::SendMessage(MessageType message_type, const uint8_t* message_data, + size_t message_size_bytes) { + if (state_ != State::kSessionEstablished && message_type != MessageType::kLog) { + return kTvmErrorSessionInvalidState; + } + + return SendInternal(message_type, message_data, message_size_bytes); +} + +ssize_t Session::SessionReceiver::Write(const uint8_t* data, size_t data_size_bytes) { + if (session_->receive_buffer_has_complete_message_) { + return kTvmErrorSessionReceiveBufferBusy; + } + + size_t bytes_written = session_->receive_buffer_->Write(data, data_size_bytes); + if (bytes_written != data_size_bytes) { + return kTvmErrorSessionReceiveBufferShortWrite; + } + + return bytes_written; +} + +void Session::SessionReceiver::PacketDone(bool is_valid) { + if (!is_valid) { + return; + } + + SessionHeader header; + int bytes_read = + session_->receive_buffer_->Read(reinterpret_cast(&header), sizeof(header)); + if (bytes_read != sizeof(header)) { + return; + } + session_->receive_buffer_has_complete_message_ = true; + + switch (header.message_type) { + case MessageType::kStartSessionInit: + session_->ProcessStartSessionInit(header); + session_->receive_buffer_has_complete_message_ = false; + break; + case MessageType::kStartSessionReply: + session_->ProcessStartSessionReply(header); + session_->receive_buffer_has_complete_message_ = false; + break; + case MessageType::kTerminateSession: + if (session_->state_ == State::kSessionEstablished) { + session_->state_ = State::kNoSessionEstablished; + session_->OnSessionTerminatedMessage(); + } + session_->receive_buffer_has_complete_message_ = false; + break; + case MessageType::kLog: + if (header.session_id == 0 || header.session_id == session_->session_id_) { + // Special case for log messages: session id can be 0. + session_->message_received_func_(session_->message_received_func_context_, + header.message_type, session_->receive_buffer_); + } + break; + default: + if (session_->state_ == State::kSessionEstablished && + header.session_id == session_->session_id_) { + session_->message_received_func_(session_->message_received_func_context_, + header.message_type, session_->receive_buffer_); + } + break; + } +} + +void Session::ClearReceiveBuffer() { + receive_buffer_has_complete_message_ = false; + receive_buffer_->Clear(); +} + +void Session::SendSessionStartReply(const SessionHeader& header) { + RegenerateNonce(); + SetSessionId(InitiatorNonce(header.session_id), local_nonce_); + utvm_session_start_payload_t payload = {Session::kVersion}; + tvm_crt_error_t to_return = SendInternal(MessageType::kStartSessionReply, + reinterpret_cast(&payload), sizeof(payload)); + state_ = State::kSessionEstablished; + CHECK_EQ(to_return, kTvmErrorNoError, "SendSessionStartReply"); + OnSessionEstablishedMessage(); +} + +void Session::ProcessStartSessionInit(const SessionHeader& header) { + if (InitiatorNonce(header.session_id) == kInvalidNonce) { + return; + } + + utvm_session_start_payload_t payload; + int bytes_read = receive_buffer_->Read(reinterpret_cast(&payload), sizeof(payload)); + if (bytes_read != sizeof(payload)) { + return; + } + + switch (state_) { + case State::kReset: + case State::kNoSessionEstablished: + // Normal case: received a StartSession packet from reset. + SendSessionStartReply(header); + break; + + case State::kStartSessionSent: + // When two StartSessionInit packets sent simultaneously: lowest nonce wins; ties retry. + if (InitiatorNonce(header.session_id) < local_nonce_) { + if (payload.version == Session::kVersion) { + SendSessionStartReply(header); + } + } else if (InitiatorNonce(header.session_id) == local_nonce_) { + StartSession(); + } + + break; + + case State::kSessionEstablished: + SendSessionStartReply(header); + OnSessionEstablishedMessage(); + break; + + default: + state_ = State::kReset; + } +} + +void Session::ProcessStartSessionReply(const SessionHeader& header) { + if (ResponderNonce(header.session_id) == kInvalidNonce) { + return; + } + + utvm_session_start_payload_t payload; + int bytes_read = receive_buffer_->Read(reinterpret_cast(&payload), sizeof(payload)); + if (bytes_read != sizeof(payload)) { + return; + } + + switch (state_) { + case State::kReset: + case State::kNoSessionEstablished: + break; + case State::kStartSessionSent: + if (InitiatorNonce(header.session_id) == local_nonce_ && + payload.version == Session::kVersion) { + SetSessionId(local_nonce_, ResponderNonce(header.session_id)); + state_ = State::kSessionEstablished; + OnSessionEstablishedMessage(); + } + break; + case State::kSessionEstablished: + if (InitiatorNonce(header.session_id) != kInvalidNonce && + ResponderNonce(header.session_id) == kInvalidNonce) { + if (payload.version == Session::kVersion) { + SendSessionStartReply(header); + } else { + SetSessionId(local_nonce_, 0); + state_ = State::kReset; + } + } else { + state_ = State::kReset; + } + break; + } +} + +void Session::OnSessionEstablishedMessage() { + message_received_func_(message_received_func_context_, MessageType::kStartSessionReply, NULL); +} + +void Session::OnSessionTerminatedMessage() { + message_received_func_(message_received_func_context_, MessageType::kTerminateSession, NULL); +} + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/utvm_rpc_common/write_stream.cc b/src/runtime/crt/utvm_rpc_common/write_stream.cc new file mode 100644 index 000000000000..034b25306a6b --- /dev/null +++ b/src/runtime/crt/utvm_rpc_common/write_stream.cc @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file framing.h + * \brief Framing for RPC. + */ +#include + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +WriteStream::~WriteStream() {} + +tvm_crt_error_t WriteStream::WriteAll(uint8_t* data, size_t data_size_bytes, + size_t* bytes_consumed) { + *bytes_consumed = 0; + while (data_size_bytes > 0) { + ssize_t to_return = Write(data, data_size_bytes); + if (to_return == 0) { + return kTvmErrorWriteStreamShortWrite; + } else if (to_return < 0) { + return (tvm_crt_error_t)to_return; + } else if (to_return > 0 && ((size_t)to_return) > data_size_bytes) { + return kTvmErrorWriteStreamLongWrite; + } + + data += to_return; + data_size_bytes -= to_return; + *bytes_consumed += to_return; + } + + return kTvmErrorNoError; +} + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/crt/utvm_rpc_server/rpc_server.cc b/src/runtime/crt/utvm_rpc_server/rpc_server.cc new file mode 100644 index 000000000000..f36e67223c98 --- /dev/null +++ b/src/runtime/crt/utvm_rpc_server/rpc_server.cc @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file utvm_rpc_server.cc + * \brief MicroTVM RPC Server + */ + +#include +#include +#include +#include +#include + +// NOTE: dmlc/base.h contains some declarations that are incompatible with some C embedded +// toolchains. Just pull the bits we need for this file. +#define DMLC_CMAKE_LITTLE_ENDIAN DMLC_IO_USE_LITTLE_ENDIAN +#define DMLC_LITTLE_ENDIAN true +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../minrpc/minrpc_server.h" +#include "crt_config.h" + +namespace tvm { +namespace runtime { +namespace micro_rpc { + +class MicroIOHandler { + public: + MicroIOHandler(Session* session, FrameBuffer* receive_buffer) + : session_{session}, receive_buffer_{receive_buffer} {} + + void MessageStart(size_t message_size_bytes) { + session_->StartMessage(MessageType::kNormal, message_size_bytes + 8); + } + + ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) { + int to_return = session_->SendBodyChunk(buf, buf_size_bytes); + if (to_return < 0) { + return to_return; + } + return buf_size_bytes; + } + + void MessageDone() { CHECK_EQ(session_->FinishMessage(), kTvmErrorNoError, "FinishMessage"); } + + ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) { + return receive_buffer_->Read(buf, buf_size_bytes); + } + + void Close() {} + + void Exit(int code) { + for (;;) { + } + } + + private: + Session* session_; + FrameBuffer* receive_buffer_; +}; + +namespace { +// Stored as globals so that they can be used to report initialization errors. +utvm_rpc_channel_write_t g_write_func = nullptr; +void* g_write_func_ctx = nullptr; +} // namespace + +class SerialWriteStream : public WriteStream { + public: + SerialWriteStream() {} + virtual ~SerialWriteStream() {} + + ssize_t Write(const uint8_t* data, size_t data_size_bytes) override { + return g_write_func(g_write_func_ctx, data, data_size_bytes); + } + + void PacketDone(bool is_valid) override {} + + private: + void operator delete(void*) noexcept {} // NOLINT(readability/casting) +}; + +class MicroRPCServer { + public: + MicroRPCServer(uint8_t* receive_storage, size_t receive_storage_size_bytes, + utvm_rpc_channel_write_t write_func, void* write_func_ctx) + : receive_buffer_{receive_storage, receive_storage_size_bytes}, + framer_{&send_stream_}, + session_{0xa5, &framer_, &receive_buffer_, &HandleCompleteMessageCb, this}, + io_{&session_, &receive_buffer_}, + unframer_{session_.Receiver()}, + rpc_server_{&io_}, + has_pending_byte_{false}, + is_running_{true} {} + + void* operator new(size_t count, void* ptr) { return ptr; } + + void Initialize() { CHECK_EQ(kTvmErrorNoError, session_.Initialize(), "rpc server init"); } + + /*! \brief Process one message from the receive buffer, if possible. + * + * \return true if additional messages could be processed. false if the server shutdown request + * has been received. + */ + bool Loop() { + if (has_pending_byte_) { + size_t bytes_consumed; + CHECK_EQ(unframer_.Write(&pending_byte_, 1, &bytes_consumed), kTvmErrorNoError, + "unframer_.Write"); + CHECK_EQ(bytes_consumed, 1, "bytes_consumed"); + has_pending_byte_ = false; + } + + return is_running_; + } + + void HandleReceivedByte(uint8_t byte) { + CHECK(!has_pending_byte_); + has_pending_byte_ = true; + pending_byte_ = byte; + } + + void Log(const uint8_t* message, size_t message_size_bytes) { + tvm_crt_error_t to_return = + session_.SendMessage(MessageType::kLog, message, message_size_bytes); + if (to_return != 0) { + TVMPlatformAbort(to_return); + } + } + + private: + FrameBuffer receive_buffer_; + SerialWriteStream send_stream_; + Framer framer_; + Session session_; + MicroIOHandler io_; + Unframer unframer_; + MinRPCServer rpc_server_; + + bool has_pending_byte_; + uint8_t pending_byte_; + bool is_running_; + + void HandleCompleteMessage(MessageType message_type, FrameBuffer* buf) { + if (message_type != MessageType::kNormal) { + return; + } + + is_running_ = rpc_server_.ProcessOnePacket(); + session_.ClearReceiveBuffer(); + } + + static void HandleCompleteMessageCb(void* context, MessageType message_type, FrameBuffer* buf) { + static_cast(context)->HandleCompleteMessage(message_type, buf); + } +}; + +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm + +void* operator new[](size_t count, void* ptr) noexcept { return ptr; } + +extern "C" { + +static utvm_rpc_server_t g_rpc_server = nullptr; + +utvm_rpc_server_t UTvmRpcServerInit(uint8_t* memory, size_t memory_size_bytes, + size_t page_size_bytes_log2, + utvm_rpc_channel_write_t write_func, void* write_func_ctx) { + tvm::runtime::micro_rpc::g_write_func = write_func; + tvm::runtime::micro_rpc::g_write_func_ctx = write_func_ctx; + + tvm_crt_error_t err = TVMInitializeRuntime(memory, memory_size_bytes, page_size_bytes_log2); + if (err != kTvmErrorNoError) { + TVMPlatformAbort(err); + } + + auto receive_buffer = + new (vmalloc(TVM_CRT_MAX_PACKET_SIZE_BYTES)) uint8_t[TVM_CRT_MAX_PACKET_SIZE_BYTES]; + auto rpc_server = new (vmalloc(sizeof(tvm::runtime::micro_rpc::MicroRPCServer))) + tvm::runtime::micro_rpc::MicroRPCServer(receive_buffer, TVM_CRT_MAX_PACKET_SIZE_BYTES, + write_func, write_func_ctx); + g_rpc_server = static_cast(rpc_server); + rpc_server->Initialize(); + return g_rpc_server; +} + +void TVMLogf(const char* format, ...) { + va_list args; + char log_buffer[256]; + va_start(args, format); + size_t num_bytes_logged = vsnprintf(log_buffer, sizeof(log_buffer), format, args); + va_end(args); + + // Most header-based logging frameworks tend to insert '\n' at the end of the log message. + // Remove that for remote logging, since the remote logger will do the same. + if (num_bytes_logged > 0 && log_buffer[num_bytes_logged - 1] == '\n') { + log_buffer[num_bytes_logged - 1] = 0; + num_bytes_logged--; + } + + if (g_rpc_server != nullptr) { + static_cast(g_rpc_server) + ->Log(reinterpret_cast(log_buffer), num_bytes_logged); + } else { + tvm::runtime::micro_rpc::SerialWriteStream write_stream; + tvm::runtime::micro_rpc::Framer framer{&write_stream}; + tvm::runtime::micro_rpc::Session session{0xa5, &framer, nullptr, nullptr, nullptr}; + tvm_crt_error_t err = + session.SendMessage(tvm::runtime::micro_rpc::MessageType::kLog, + reinterpret_cast(log_buffer), num_bytes_logged); + if (err != kTvmErrorNoError) { + TVMPlatformAbort(err); + } + } +} + +size_t UTvmRpcServerReceiveByte(utvm_rpc_server_t server_ptr, uint8_t byte) { + // NOTE(areusch): In the future, this function is intended to work from an IRQ context. That's not + // needed at present. + tvm::runtime::micro_rpc::MicroRPCServer* server = + static_cast(server_ptr); + server->HandleReceivedByte(byte); + return 1; +} + +bool UTvmRpcServerLoop(utvm_rpc_server_t server_ptr) { + tvm::runtime::micro_rpc::MicroRPCServer* server = + static_cast(server_ptr); + return server->Loop(); +} + +} // extern "C" diff --git a/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s b/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s deleted file mode 100644 index f5720f4d7b28..000000000000 --- a/src/runtime/micro/device/arm/stm32f746xx/utvm_init.s +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -.syntax unified -.cpu cortex-m7 -.fpu softvfp -.thumb - -.section .text.UTVMInit -.type UTVMInit, %function -UTVMInit: - /* enable fpu */ - ldr r0, =0xE000ED88 - ldr r1, [r0] - ldr r2, =0xF00000 - orr r1, r2 - str r1, [r0] - dsb - isb - /* set stack pointer */ - ldr sp, =_utvm_stack_pointer_init - bl UTVMMain -.size UTVMInit, .-UTVMInit diff --git a/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c b/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c deleted file mode 100644 index ae2b1994df12..000000000000 --- a/src/runtime/micro/device/arm/stm32f746xx/utvm_timer.c +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file utvm_timer.c - * \brief uTVM timer API definitions for STM32F746XX-series boards - */ - -#ifdef __cplusplus -extern "C" { -#endif - -#include - -#include "utvm_runtime.h" -// NOTE: This expects ST CMSIS to be in your include path. -// Download STM32CubeF7 here: -// https://www.st.com/content/st_com/en/products/embedded-software/mcu-mpu-embedded-software/stm32-embedded-software/stm32cube-mcu-mpu-packages/stm32cubef7.html -// and add Drivers/CMSIS to your C include path. -#include "Device/ST/STM32F7xx/Include/stm32f746xx.h" - -#define utvm_SystemCoreClock 216000000UL - -int32_t UTVMTimerStart() { - UTVMTimerReset(); - TIM2->CR1 = TIM_CR1_CEN; // Start counter - return UTVM_ERR_OK; -} - -uint32_t UTVMTimerStop(int32_t* err) { - TIM2->CR1 &= TIM_CR1_CEN; - if (TIM2->SR & TIM_SR_UIF_Msk) { - *err = UTVM_ERR_TIMER_OVERFLOW; - return 0; - } - *err = UTVM_ERR_OK; - uint32_t tim_cnt = TIM2->CNT; - uint32_t millis = tim_cnt / (utvm_SystemCoreClock / 1000); - uint32_t micros = - (tim_cnt - (millis * (utvm_SystemCoreClock / 1000))) / (utvm_SystemCoreClock / 1000000); - return millis * 1000 + micros; -} - -void UTVMTimerReset() { - RCC->APB1RSTR |= RCC_APB1RSTR_TIM2RST; // Hold TIM2 in reset - RCC->DCKCFGR1 = (RCC->DCKCFGR1 & ~RCC_DCKCFGR1_TIMPRE_Msk); // disable 2x clock boost to TIM2 - RCC->CFGR = (RCC->CFGR & ~RCC_CFGR_PPRE1_Msk); // No AHB clock division to APB1 (1:1). - RCC->APB1ENR |= RCC_APB1ENR_TIM2EN; // Enable TIM2 clock. - RCC->APB1RSTR &= ~RCC_APB1RSTR_TIM2RST; // Exit TIM2 reset. - - DBGMCU->APB1FZ |= DBGMCU_APB1_FZ_DBG_TIM2_STOP; // stop TIM2 clock during debug halt. - TIM2->ARR = 0xffffffff; - if (TIM2->SR & TIM_SR_UIF_Msk) { - for (;;) { - } - } -} - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif diff --git a/src/runtime/micro/device/host/utvm_timer.c b/src/runtime/micro/device/host/utvm_timer.c deleted file mode 100644 index 887b15c8b25a..000000000000 --- a/src/runtime/micro/device/host/utvm_timer.c +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file utvm_timer.c - * \brief uTVM timer API stubs for the host emulated device - */ - -#include - -#include "utvm_runtime.h" - -// TODO(weberlo): use this? https://stackoverflow.com/questions/5141960/get-the-current-time-in-c - -int32_t UTVMTimerStart() { return UTVM_ERR_OK; } - -uint32_t UTVMTimerStop(int32_t* err) { - *err = UTVM_ERR_OK; - return 0; -} diff --git a/src/runtime/micro/device/riscv_spike/utvm_init.s b/src/runtime/micro/device/riscv_spike/utvm_init.s deleted file mode 100644 index 68662cce97e7..000000000000 --- a/src/runtime/micro/device/riscv_spike/utvm_init.s +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -UTVMInit: - /* set stack pointer */ - la sp, _utvm_stack_pointer_init - call UTVMMain diff --git a/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c b/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c deleted file mode 100644 index 64b5908e6c1c..000000000000 --- a/src/runtime/micro/host_driven/utvm_device_dylib_redirect.c +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file utvm_device_dylib_redirect.cc - * \brief uTVM dynamic linking stubs - * - * This is a library that gets included in each uTVM library. We redirect - * each library call into a pre-defined global function pointer, and we patch - * the correct addresses of each function into the pointers when we load the - * library. - */ -#ifdef __cplusplus -extern "C" { -#endif -#include -#include - -// TODO(weberlo, areusch): compiler errors say volatile qualifier is discarded. -// should we just get rid of em? -void* (*volatile TVMBackendAllocWorkspace_)(int, int, uint64_t, int, int) = NULL; -int (*volatile TVMBackendFreeWorkspace_)(int, int, void*) = NULL; -void (*volatile TVMAPISetLastError_)(const char*) = NULL; - -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, - int dtype_bits_hint) { - return (*TVMBackendAllocWorkspace_)(device_type, device_id, size, dtype_code_hint, - dtype_bits_hint); -} - -int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - return (*TVMBackendFreeWorkspace_)(device_type, device_id, ptr); -} - -void TVMAPISetLastError(const char* msg) { (*TVMAPISetLastError_)(msg); } - -void* memset(void* s, int c, size_t n) { - char* p = (char*)s; // NOLINT(readability/casting): linter is configured for c++ - while (n > 0) { - *p = (char)c; // NOLINT(readability/casting): linter is configured for c++ - p++; - n--; - } - return s; -} - -void* memmove(void* to, const void* from, size_t n) { - // TODO(weberlo, areusch): will need to factor memmove calls into workspace size calculation - // NOLINTNEXTLINE(readability/casting): linter is configured for c++ - char* temp = (char*)TVMBackendAllocWorkspace(1, 1, (uint64_t)n, 2, 8); - if (temp == NULL) { - return NULL; - } - - const char* from_pp = (char*)from; // NOLINT(readability/casting): linter is configured for c++ - for (size_t i = 0; i < n; i++) { - temp[i] = from_pp[i]; - } - char* to_pp = (char*)to; // NOLINT(readability/casting): linter is configured for c++ - for (size_t i = 0; i < n; i++) { - to_pp[i] = temp[i]; - } - - // NOLINTNEXTLINE(readability/casting): linter is configured for c++ - if (TVMBackendFreeWorkspace(1, (uint64_t)1, (void*)temp) != 0) { - return NULL; - } - - return to; -} - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif diff --git a/src/runtime/micro/host_driven/utvm_runtime.c b/src/runtime/micro/host_driven/utvm_runtime.c deleted file mode 100644 index 398a08a014e0..000000000000 --- a/src/runtime/micro/host_driven/utvm_runtime.c +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file utvm_runtime.cc - * \brief uTVM runtime - * - * All function calls go through the externally defined `UTVMInit`, which - * performs device-specific setup, then calls `UTVMMain`. `UTVMMain` then - * calls the function in `utvm_task` with the arguments from the task. - * - * Additionally included in this file are definitions for some of the most - * common functions used in the C runtime API. - */ -#ifdef __cplusplus -extern "C" { -#endif - -#include "utvm_runtime.h" - -// TODO(weberlo, areusch): move defines into header -// TODO(weberlo, areusch): unify TASK_QUEUE_SIZE and MicroSession::kTaskQueueCapacity. -#define TASK_QUEUE_SIZE 20 -volatile UTVMTask utvm_tasks[TASK_QUEUE_SIZE] = {}; -volatile uint32_t utvm_num_tasks = 0; -volatile uint32_t utvm_task_times[TASK_QUEUE_SIZE] = {}; - -// These pointers are patched at load time to point to the workspace section. -volatile char* utvm_workspace_start = NULL; // NOLINT(*) -volatile char* utvm_workspace_end = NULL; // NOLINT(*) -volatile char* utvm_workspace_curr = NULL; // NOLINT(*) -#define MAX_WS_ALLOCS 10 -volatile char* utvm_alloc_ends[MAX_WS_ALLOCS] = {}; // NOLINT(*) -volatile uint32_t utvm_alloc_idx = 0; -// Keep track of how many active allocations there are on the workspace. -volatile uint32_t utvm_num_active_allocs = 0; - -volatile uint32_t utvm_word_size = 0; - -volatile int32_t utvm_last_error = 0; // NOLINT(*) - -volatile uint32_t utvm_done = 0; - -// Gets called by UTVMInit, after device-specific initialization is finished. -void UTVMMain() { - utvm_done = 0; - // loss of precision should be fine here, since we only care about the lower bits - if (((uint32_t)utvm_workspace_start) % utvm_word_size) { - utvm_last_error = UTVM_ERR_WS_UNALIGNED_START; - UTVMDone(); - return; - } - utvm_workspace_curr = utvm_workspace_start; - utvm_num_active_allocs = 0; - utvm_alloc_idx = 0; - utvm_last_error = UTVM_ERR_NOT_FINISHED; - for (uint32_t i = 0; i < utvm_num_tasks; i++) { - int32_t err = UTVM_ERR_OK; - utvm_task_times[i] = 0; - err = UTVMTimerStart(); - if (err < 0) { - utvm_last_error = err; - UTVMDone(); - return; - } - err = utvm_tasks[i].func((void*)utvm_tasks[i].arg_values, // NOLINT(*) - (void*)utvm_tasks[i].arg_type_codes, // NOLINT(*) - utvm_tasks[i].num_args); - if (err < 0) { - UTVMDone(); - return; - } - utvm_task_times[i] = UTVMTimerStop(&err); - if (err < 0) { - utvm_last_error = err; - UTVMDone(); - return; - } - } - if (utvm_last_error == UTVM_ERR_NOT_FINISHED) { - utvm_last_error = UTVM_ERR_OK; - } - UTVMDone(); -} - -// We use a dummy function to signal execution is finished for device -// backends which require breakpoints. -void __attribute__((noinline)) UTVMDone() { - utvm_done = 1; -#ifndef UTVM_TARGET_HOST - for (;;) { - } -#endif -} - -#define ALIGNED_UP(x, word_size) \ - ((((word_size) - (((uintptr_t)(x)) % (word_size))) % (word_size)) + (x)) - -void* TVMBackendAllocWorkspace(int device_type, int device_id, uint64_t size, int dtype_code_hint, - int dtype_bits_hint) { - if (size == 0) { - utvm_last_error = UTVM_ERR_WS_ZERO_SIZE_ALLOC; - return NULL; - } - size_t alloc_requested_bytes = size; - size_t alloc_size_words = (alloc_requested_bytes + utvm_word_size - 1) / utvm_word_size; - size_t alloc_size_bytes = alloc_size_words * utvm_word_size; - - // Align up to the target word size. - if (utvm_workspace_curr + alloc_size_bytes > utvm_workspace_end) { - // Out of space in workspace. - utvm_last_error = UTVM_ERR_WS_OUT_OF_SPACE; - return NULL; - } - if (utvm_alloc_idx == MAX_WS_ALLOCS - 1) { - // Exceeded number of allocs we can keep track of. - utvm_last_error = UTVM_ERR_WS_TOO_MANY_ALLOCS; - return NULL; - } - void* ret_ptr = (void*)utvm_workspace_curr; // NOLINT(*) - utvm_workspace_curr = utvm_workspace_curr + alloc_size_bytes; - // store the *end* of the alloc, so we can restore the WS pointer when freeing - utvm_alloc_ends[utvm_alloc_idx] = utvm_workspace_curr; - utvm_alloc_idx++; - utvm_num_active_allocs++; - return ret_ptr; -} - -int TVMBackendFreeWorkspace(int device_type, int device_id, void* ptr) { - // TODO(weberlo, areusch): add dev type check - if (utvm_num_active_allocs == 0) { - TVMAPISetLastError("free called with no active workspace allocations"); - // Reset allocations and workspace (for future task executions). - utvm_num_active_allocs = 0; - utvm_workspace_curr = utvm_workspace_start; - utvm_last_error = UTVM_ERR_WS_DOUBLE_FREE; - return -1; - } else { - utvm_num_active_allocs--; - if (ptr == utvm_workspace_start) { - // it's the first allocation - utvm_alloc_ends[0] = NULL; - } else { - for (uint32_t i = utvm_alloc_idx - 1; i >= 0; i--) { - if (utvm_alloc_ends[i] == ptr) { - utvm_alloc_ends[i + 1] = NULL; - break; - } - } - } - while (utvm_alloc_idx > 0 && utvm_alloc_ends[utvm_alloc_idx - 1] == NULL) { - utvm_alloc_idx--; - } - if (utvm_alloc_idx == 0) { - utvm_workspace_curr = utvm_workspace_start; - } else { - // TODO(weberlo, areusch): could you possibly have utvm_alloc_idx pointing to a NULL entry in - // this branch? - utvm_workspace_curr = utvm_alloc_ends[utvm_alloc_idx - 1]; - } - return 0; - } -} - -void TVMAPISetLastError(const char* msg) {} - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif diff --git a/src/runtime/micro/host_driven/utvm_runtime.h b/src/runtime/micro/host_driven/utvm_runtime.h deleted file mode 100644 index 8758c3ad89a1..000000000000 --- a/src/runtime/micro/host_driven/utvm_runtime.h +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file utvm_runtime.h - * \brief uTVM runtime headers - */ -#ifndef TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_H_ -#define TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_H_ - -#ifdef __cplusplus -extern "C" { -#endif - -#include -#include -#include - -#include "utvm_runtime_enum.h" - -/*! - * \brief Task structure for uTVM - */ -typedef struct { - /*! \brief Pointer to function to call for this task */ - int32_t (*func)(void*, void*, int32_t); - /*! \brief Array of argument values */ - TVMValue* arg_values; - /*! \brief Array of type codes for each argument value */ - int* arg_type_codes; - /*! \brief Number of arguments */ - int32_t num_args; -} UTVMTask; - -/*! - * \brief microTVM processor startup. - * Expected to reset the stack pointer, configure any hardware required to support the CRT - * (i.e. FPU), and then jump to UTVMMain. - */ -extern void UTVMInit(); - -/*! - * \brief Start the on-device timer. - * \return UTVMReturnCode indicating the outcome of the operation. - */ -extern int32_t UTVMTimerStart(); - -/*! - * \brief Stop the on-device timer. - * TODO(areusch): Use an SI specification of timer units here. - * \param err Receives a UTVMReturnCode indicating the outcome of the operation. - * \return elapsed time since UTVMTimerStart returned, in device timer ticks. - */ -extern uint32_t UTVMTimerStop(int32_t* err); - -/*! - * \brief Main entry point for UTVM runtime. - * Waits for "go" signal, then executes tasks and reports result. Should never return. - */ -void UTVMMain(); - -/*! - * \brief Function entered when UTVMMain is complete. - * Should never return. The host sets a breakpoint here to detect end of computation. - */ -void UTVMDone(); - -// GCC -O3 begins to inject memset and memmove calls, so we provide impls in -// the runtime for this case and for general usage. - -void* memset(void* s, int c, size_t n); - -void* memmove(void* to, const void* from, size_t n); - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif - -#endif // TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_H_ diff --git a/src/runtime/micro/host_driven/utvm_runtime_enum.h b/src/runtime/micro/host_driven/utvm_runtime_enum.h deleted file mode 100644 index 17f803612cb9..000000000000 --- a/src/runtime/micro/host_driven/utvm_runtime_enum.h +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file utvm_runtime_enum.h - * \brief Defines constants used both on the host and on device. - */ -#ifndef TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_ -#define TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_ - -#ifdef __cplusplus -extern "C" { -#endif - -/*! - * \brief TODO - */ -enum UTVMReturnCode { - UTVM_ERR_OK = 0, - UTVM_ERR_NOT_FINISHED = -1, - UTVM_ERR_TIMER_NOT_IMPLEMENTED = -2, - UTVM_ERR_TIMER_OVERFLOW = -3, - UTVM_ERR_WS_DOUBLE_FREE = -4, - UTVM_ERR_WS_OUT_OF_SPACE = -5, - UTVM_ERR_WS_TOO_MANY_ALLOCS = -6, - UTVM_ERR_WS_ZERO_SIZE_ALLOC = -7, - UTVM_ERR_WS_UNALIGNED_START = -8, - UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE = -9, -}; - -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif - -#endif // TVM_RUNTIME_MICRO_HOST_DRIVEN_UTVM_RUNTIME_ENUM_H_ diff --git a/src/runtime/micro/host_low_level_device.cc b/src/runtime/micro/host_low_level_device.cc deleted file mode 100644 index 7c3e7a2abad8..000000000000 --- a/src/runtime/micro/host_low_level_device.cc +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file host_low_level_device.cc - * \brief emulated low-level micro device implementation on host machine - */ - -#include - -#include -#include - -#include "low_level_device.h" -#include "micro_common.h" - -namespace tvm { -namespace runtime { - -/*! \brief number of bytes in each page */ -constexpr int kPageSize = 4096; - -/*! - * \brief emulated low-level device on host machine - */ -class HostLowLevelDevice final : public LowLevelDevice { - public: - /*! - * \brief constructor to initialize on-host memory region to act as device - * \param num_bytes size of the emulated on-device memory region - */ - explicit HostLowLevelDevice(size_t num_bytes, TargetPtr* base_addr) : size_(num_bytes) { - size_t size_in_pages = (num_bytes + kPageSize - 1) / kPageSize; - // TODO(weberlo): Set permissions per section (e.g., read-write perms for - // the heap, execute perms for text, etc.). - int mmap_prot = PROT_READ | PROT_WRITE | PROT_EXEC; - int mmap_flags = MAP_ANONYMOUS | MAP_PRIVATE; - base_addr_ = mmap(nullptr, size_in_pages * kPageSize, mmap_prot, mmap_flags, -1, 0); - *base_addr = - TargetPtr(TargetWordSize(sizeof(size_t) * 8), reinterpret_cast(base_addr_)); - } - - /*! - * \brief destructor to deallocate on-host device region - */ - virtual ~HostLowLevelDevice() { munmap(base_addr_, size_); } - - void Read(TargetPtr addr, void* buf, size_t num_bytes) { - std::memcpy(buf, addr.cast_to(), num_bytes); - } - - void Write(TargetPtr addr, const void* buf, size_t num_bytes) { - std::memcpy(addr.cast_to(), buf, num_bytes); - } - - void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) { - reinterpret_cast(func_addr.value().uint64())(); - } - - const char* device_type() const final { return "host"; } - - private: - /*! \brief base address of the micro device memory region */ - void* base_addr_; - /*! \brief size of memory region */ - size_t size_; -}; - -const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, - TargetPtr* base_addr) { - std::shared_ptr lld = std::make_shared(num_bytes, base_addr); - return lld; -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/low_level_device.h b/src/runtime/micro/low_level_device.h deleted file mode 100644 index 6cc0e1dc5af0..000000000000 --- a/src/runtime/micro/low_level_device.h +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file low_level_device.h - * \brief Abstract low-level micro device management - */ -#ifndef TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_ -#define TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_ - -#include -#include - -#include "micro_common.h" - -namespace tvm { -namespace runtime { -/*! - * \brief virtual interface for low-level micro device management - */ -class LowLevelDevice { - public: - /*! \brief virtual destructor */ - virtual ~LowLevelDevice() {} - - /*! - * \brief reads num_bytes from device memory at addr into buffer - * \param addr on-device memory address to read from - * \param buffer on-host buffer to be read into - * \param num_bytes number of bytes to read - */ - virtual void Read(TargetPtr addr, void* buffer, size_t num_bytes) = 0; - - /*! - * \brief writes num_bytes from buffer to device memory at addr - * \param addr on-device memory address to write into - * \param buffer host buffer to write from - * \param num_bytes number of bytes to write - */ - virtual void Write(TargetPtr addr, const void* buffer, size_t num_bytes) = 0; - - /*! - * \brief starts execution of device at func_addr - * \param func_addr offset of the init stub function - * \param breakpoint_addr address at which to stop function execution - */ - virtual void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) = 0; - - /*! - * \brief getter function for low-level device type - * \return string containing device type - */ - virtual const char* device_type() const = 0; -}; - -/*! - * \brief create a host low-level device - * \param num_bytes size of the memory region - * \param base_addr pointer to write the host device's resulting base address into - */ -const std::shared_ptr HostLowLevelDeviceCreate(size_t num_bytes, - TargetPtr* base_addr); - -/*! - * \brief connect to OpenOCD and create an OpenOCD low-level device - * \param addr address of the OpenOCD server to connect to - * \param port port of the OpenOCD server to connect to - */ -const std::shared_ptr OpenOCDLowLevelDeviceCreate(const std::string& addr, - int port); - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MICRO_LOW_LEVEL_DEVICE_H_ diff --git a/src/runtime/micro/micro_common.cc b/src/runtime/micro/micro_common.cc deleted file mode 100644 index eba77f3dadbc..000000000000 --- a/src/runtime/micro/micro_common.cc +++ /dev/null @@ -1,131 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file micro_common.cc - * \brief common utilties for uTVM - */ - -#include "micro_common.h" - -#include -#include - -#include -#include -#include -#include - -#include "low_level_device.h" -#include "micro_session.h" - -namespace tvm { -namespace runtime { - -const char* SectionToString(SectionKind section) { - switch (section) { - case SectionKind::kText: - return "text"; - case SectionKind::kRodata: - return "rodata"; - case SectionKind::kData: - return "data"; - case SectionKind::kBss: - return "bss"; - case SectionKind::kArgs: - return "args"; - case SectionKind::kHeap: - return "heap"; - case SectionKind::kWorkspace: - return "workspace"; - case SectionKind::kStack: - return "stack"; - default: - return ""; - } -} - -std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, - TargetPtr text_start, TargetPtr rodata_start, - TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, - const std::string& toolchain_prefix) { - const auto* f = Registry::Get("tvm_callback_relocate_binary"); - CHECK(f != nullptr) << "Require tvm_callback_relocate_binary to exist in registry"; - std::string relocated_bin = - (*f)(binary_path, word_size.bytes(), text_start.cast_to(), - rodata_start.cast_to(), data_start.cast_to(), - bss_start.cast_to(), stack_end.cast_to(), toolchain_prefix); - return relocated_bin; -} - -std::string ReadSection(const std::string& binary, SectionKind section, - const std::string& toolchain_prefix) { - CHECK(section == SectionKind::kText || section == SectionKind::kRodata || - section == SectionKind::kData || section == SectionKind::kBss) - << "ReadSection requires section to be one of text, rodata, data, or bss."; - const auto* f = Registry::Get("tvm_callback_read_binary_section"); - CHECK(f != nullptr) << "Require tvm_callback_read_binary_section to exist in registry"; - TVMByteArray arr; - arr.data = &binary[0]; - arr.size = binary.length(); - std::string section_contents = (*f)(arr, SectionToString(section), toolchain_prefix); - return section_contents; -} - -size_t GetSectionSize(const std::string& binary_path, SectionKind section, - const std::string& toolchain_prefix, TargetWordSize word_size) { - CHECK(section == SectionKind::kText || section == SectionKind::kRodata || - section == SectionKind::kData || section == SectionKind::kBss) - << "GetSectionSize requires section to be one of text, rodata, data, or bss."; - const auto* f = Registry::Get("tvm_callback_get_section_size"); - CHECK(f != nullptr) << "Require tvm_callback_get_section_size to exist in registry"; - int size = (*f)(binary_path, SectionToString(section), toolchain_prefix); - return UpperAlignValue(size, word_size.bytes()); -} - -std::ostream& operator<<(std::ostream& os, const TargetVal& v) { - std::ios_base::fmtflags f(os.flags()); - os << std::dec << "0x"; - switch (v.width_bits()) { - case 8: - os << uint8_t(v.uint32()); - break; - case 16: - os << uint16_t(v.uint32()); - break; - case 32: - os << v.uint32(); - break; - case 64: - os << v.uint64(); - break; - default: - os << (v.uint64() & ((1 << v.width_bits()) - 1)); - } - os.flags(f); - return os; -} - -std::ostream& operator<<(std::ostream& os, const TargetPtr& v) { - os << "*" << v.value_; - return os; -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/micro_common.h b/src/runtime/micro/micro_common.h deleted file mode 100644 index 2c4684b357a8..000000000000 --- a/src/runtime/micro/micro_common.h +++ /dev/null @@ -1,359 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file micro_common.h - */ -#ifndef TVM_RUNTIME_MICRO_MICRO_COMMON_H_ -#define TVM_RUNTIME_MICRO_MICRO_COMMON_H_ - -#include -#include - -#include -#include -#include -#include - -namespace tvm { -namespace runtime { - -/*! - * \brief enum of device memory region sections - * - * The order in which the enum variants are defined also defines the order of - * the sections in device memory. - */ -enum class SectionKind : size_t { - kText = 0, - kRodata, - kData, - kBss, - kArgs, - kHeap, - kWorkspace, - kStack, - kNumKinds, -}; - -/*! \brief data type for word sizes */ -class TargetWordSize { - public: - explicit TargetWordSize(size_t word_size_bits) : word_size_bits_{word_size_bits} { - CHECK(word_size_bits == 32 || word_size_bits == 64) - << "only 32-bit and 64-bit are supported now"; - } - - size_t bytes() const { return word_size_bits_ / 8; } - - size_t bits() const { return word_size_bits_; } - - private: - size_t word_size_bits_; -}; - -/*! \brief class for storing values on varying target word sizes */ -class TargetVal { - private: - size_t width_bits_; - uint64_t value_; - - public: - /*! \brief construct a TargetVal matching the size of the given integral argument */ - template ::value, T>::type> - explicit constexpr TargetVal(T value) : TargetVal(sizeof(T) * 8, value) {} - - /*! \brief construct an uninitialized value */ - TargetVal() : width_bits_{0}, value_{0} {} - - /*! \brief construct a TargetVal with explicit size and value */ - TargetVal(size_t width_bits, uint64_t value) : width_bits_{width_bits} { - CHECK(width_bits >= 8 && width_bits <= 64 && (width_bits & (width_bits - 1)) == 0) - << "width_bits must be a power of 2 in [8, 64], got " << width_bits; - value_ = value & Bitmask(); - } - - bool IsInitialized() const { return width_bits_ != 0; } - - size_t width_bits() const { - CHECK(IsInitialized()) << "TargetVal is not initialized"; - return width_bits_; - } - - uint64_t Bitmask() const { - CHECK(IsInitialized()) << "TargetVal is not initialized"; - - if (width_bits_ == 64) { - return ~0UL; - } else { - return (1UL << width_bits_) - 1; - } - } - - uint32_t uint32() const { - CHECK(IsInitialized()) << "TargetVal is not initialized"; - CHECK(width_bits_ <= 32) << "TargetVal: requested 32-bit value, actual width is " - << width_bits_; - return uint32_t(value_ & Bitmask()); - } - - uint64_t uint64() const { - CHECK(IsInitialized()) << "TargetVal is not initialized"; - return value_; - } - - TargetVal& operator=(const TargetVal& other) { - CHECK(other.IsInitialized()) << "Cannot assign an uninitialized TargetVal"; - - if (!IsInitialized()) { - width_bits_ = other.width_bits_; - } - - CHECK(width_bits_ >= other.width_bits_) - << "Cannot assign TargetVal with width " << other.width_bits_ - << "bits to TargetVal with width " << width_bits_ << "bits"; - - value_ = other.value_ & Bitmask(); - return *this; - } - - private: - friend std::ostream& operator<<(std::ostream& os, const TargetVal& v); -}; - -// TODO(weberlo, areusch): just get rid of `TargetPtr`. -/*! \brief absolute device address */ -class TargetPtr { - public: - /*! \brief construct a device address with variable-length value `value` */ - TargetPtr(TargetWordSize word_size, std::uint64_t value) - : value_(TargetVal(word_size.bits(), value)) {} - - /*! \brief construct a null address */ - TargetPtr(TargetWordSize word_size, std::nullptr_t value) - : value_{TargetVal(word_size.bits(), 0)} {} - - /*! \brief construct an uninitialized pointer whose word_size can be changed once */ - TargetPtr() = default; - - /*! \brief construct a device address using the given TargetVal */ - explicit TargetPtr(const TargetVal& value) : value_{value} {} - - /*! \brief destructor */ - ~TargetPtr() {} - - /*! - * \brief get value of pointer - * \return value of pointer - */ - TargetVal value() const { return value_; } - - /*! - * \brief cast location to type `T` - * \return casted result - */ - template - T cast_to() const { - return reinterpret_cast(value_.uint64()); - } - - /*! \brief check if location is null */ - bool operator==(std::nullptr_t) const { return value_.uint64() == 0; } - - /*! \brief check if location is not null */ - bool operator!=(std::nullptr_t) const { return value_.uint64() != 0; } - - /*! \brief add an integer to this absolute address to get a larger absolute address */ - TargetPtr operator+(size_t n) const { - return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() + n); - } - - /*! \brief mutably add an integer to this absolute address */ - TargetPtr& operator+=(size_t n) { - value_ = TargetVal(value_.width_bits(), value_.uint64() + n); - return *this; - } - - /*! \brief subtract an integer from this absolute address to get a smaller absolute address */ - TargetPtr operator-(size_t n) const { - return TargetPtr(TargetWordSize(value_.width_bits()), value_.uint64() - n); - } - - /*! \brief mutably subtract an integer from this absolute address */ - TargetPtr& operator-=(size_t n) { - value_ = TargetVal(value_.width_bits(), value_.uint64() - n); - return *this; - } - - private: - /*! \brief raw value storing the pointer */ - TargetVal value_; - - friend std::ostream& operator<<(std::ostream& os, const TargetPtr& v); -}; - -/*! - * \brief map from symbols to their on-device offsets - */ -class SymbolMap { - public: - /*! - * \brief default constructor - */ - SymbolMap() {} - - /*! - * \brief constructor that builds the mapping - * \param binary contents of binary object file - * \param toolchain_prefix prefix of compiler toolchain to use - */ - SymbolMap(const std::string& binary, const std::string& toolchain_prefix, - TargetWordSize word_size) { - const auto* f = Registry::Get("tvm_callback_get_symbol_map"); - CHECK(f != nullptr) << "require tvm_callback_get_symbol_map to exist in registry"; - TVMByteArray arr; - arr.data = &binary[0]; - arr.size = binary.length(); - std::string map_str = (*f)(arr, toolchain_prefix); - // Parse symbols and addresses from returned string. - std::stringstream stream; - stream << map_str; - std::string name; - std::uintptr_t addr; - stream >> name; - stream >> std::hex >> addr; - while (stream) { - map_.emplace(std::make_pair(name, TargetPtr(word_size, addr))); - stream >> name; - stream >> std::hex >> addr; - } - } - - /*! - * \brief retrieve on-device offset for a symbol name - * \param name name of the symbol - * \return on-device offset of the symbol - */ - TargetPtr operator[](const std::string& name) const { - auto result = map_.find(name); - CHECK(result != map_.end()) << "\"" << name << "\" not in symbol map"; - return result->second; - } - - bool HasSymbol(const std::string& name) const { return map_.find(name) != map_.end(); } - - void Dump(std::ostream& stream) const { - for (auto e : map_) { - stream << "Entry:" << e.first << std::endl; - } - } - - private: - /*! \brief backing map */ - std::unordered_map map_; -}; - -/*! \brief struct containing start and size of a device memory region */ -struct DevMemRegion { - /*! \brief section start offset */ - TargetPtr start; - /*! \brief size of section */ - size_t size; -}; - -/*! \brief struct containing section locations and symbol mappings */ -struct BinaryInfo { - /*! \brief text section region */ - DevMemRegion text_section; - /*! \brief rodata section region */ - DevMemRegion rodata_section; - /*! \brief data section region */ - DevMemRegion data_section; - /*! \brief bss section region */ - DevMemRegion bss_section; - /*! \brief symbol map to offsets */ - SymbolMap symbol_map; -}; - -struct BinaryContents { - BinaryInfo binary_info; - std::string text_contents; - std::string rodata_contents; - std::string data_contents; - std::string bss_contents; -}; - -/*! - * \brief upper-aligns value according to specified alignment - * \param value value to be aligned - * \param align alignment - * \return upper-aligned value - */ -inline size_t UpperAlignValue(size_t value, size_t align) { - return value + (align - (value % align)) % align; -} - -/*! - * \brief maps section enums to text - * \param section section type - * \return text form of the specified section - */ -const char* SectionToString(SectionKind section); - -/*! - * \brief links binary by repositioning section addresses - * \param binary_name input binary filename - * \param word_size word size on the target machine - * \param text_start text section address - * \param rodata_start rodata section address - * \param data_start data section address - * \param bss_start bss section address - * \param stack_end stack section end address - * \param toolchain_prefix prefix of compiler toolchain to use - * \return relocated binary file contents - */ -std::string RelocateBinarySections(const std::string& binary_path, TargetWordSize word_size, - TargetPtr text_start, TargetPtr rodata_start, - TargetPtr data_start, TargetPtr bss_start, TargetPtr stack_end, - const std::string& toolchain_prefix); - -/*! - * \brief reads section from binary - * \param binary input binary contents - * \param section section type to be read - * \param toolchain_prefix prefix of compiler toolchain to use - * \return contents of the section - */ -std::string ReadSection(const std::string& binary, SectionKind section, - const std::string& toolchain_prefix); - -/*! - * \brief finds size of the section in the binary - * \param binary input binary contents - * \param section section type - * \param toolchain_prefix prefix of compiler toolchain to use - * \param word_size word size of the target, for alignment - * \return size of the section if it exists, 0 otherwise - */ -size_t GetSectionSize(const std::string& binary_name, SectionKind section, - const std::string& toolchain_prefix, TargetWordSize word_size); - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MICRO_MICRO_COMMON_H_ diff --git a/src/runtime/micro/micro_device_api.cc b/src/runtime/micro/micro_device_api.cc deleted file mode 100644 index 3812ec072cd8..000000000000 --- a/src/runtime/micro/micro_device_api.cc +++ /dev/null @@ -1,162 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file micro_device_api.cc - */ - -#include -#include -#include - -#include "../workspace_pool.h" -#include "micro_session.h" - -namespace tvm { -namespace runtime { -/*! - * \brief device API for uTVM micro devices - */ -class MicroDeviceAPI final : public DeviceAPI { - public: - /*! \brief constructor */ - MicroDeviceAPI() {} - - void SetDevice(TVMContext ctx) final {} - - void GetAttr(TVMContext ctx, DeviceAttrKind kind, TVMRetValue* rv) final { - if (kind == kExist) { - *rv = 1; - } - } - - void* AllocDataSpace(TVMContext ctx, size_t nbytes, size_t alignment, - DLDataType type_hint) final { - ObjectPtr& session = MicroSession::Current(); - TargetPtr data = session->AllocateInSection(SectionKind::kHeap, nbytes); - CHECK(data != nullptr) << "unable to allocate " << nbytes << " bytes on device heap"; - return reinterpret_cast(new MicroDevSpace{data, session}); - } - - void FreeDataSpace(TVMContext ctx, void* ptr) final { - MicroDevSpace* dev_space = static_cast(ptr); - dev_space->session->FreeInSection(SectionKind::kHeap, dev_space->data); - delete dev_space; - } - - void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, - TVMContext ctx_from, TVMContext ctx_to, DLDataType type_hint, - TVMStreamHandle stream) final { - std::tuple type_from_to(ctx_from.device_type, ctx_to.device_type); - if (type_from_to == std::make_tuple(kDLMicroDev, kDLMicroDev)) { - // Copying from the device to the device. - MicroDevSpace* from_space = static_cast(const_cast(from)); - MicroDevSpace* to_space = static_cast(const_cast(to)); - CHECK(from_space->session == to_space->session) - << "attempt to copy data between different micro sessions (" << from_space->session.get() - << " != " << to_space->session.get() << ")"; - CHECK(ctx_from.device_id == ctx_to.device_id) - << "can only copy between the same micro device"; - ObjectPtr& session = from_space->session; - // flush all pending tasks to ensure data is consistent - session->FlushTaskQueue(); - const std::shared_ptr& lld = session->low_level_device(); - - TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset); - TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset); - - std::vector buffer(size); - lld->Read(from_dev_addr, static_cast(buffer.data()), size); - lld->Write(to_dev_addr, static_cast(buffer.data()), size); - - } else if (type_from_to == std::make_tuple(kDLMicroDev, kDLCPU)) { - // Reading from the device. - MicroDevSpace* from_space = static_cast(const_cast(from)); - ObjectPtr& session = from_space->session; - // flush all pending tasks to ensure data is consistent - session->FlushTaskQueue(); - const std::shared_ptr& lld = session->low_level_device(); - - TargetPtr from_dev_addr = GetDevLoc(from_space, from_offset); - void* to_host_ptr = GetHostLoc(to, to_offset); - lld->Read(from_dev_addr, to_host_ptr, size); - - } else if (type_from_to == std::make_tuple(kDLCPU, kDLMicroDev)) { - // Writing to the device. - MicroDevSpace* to_space = static_cast(const_cast(to)); - ObjectPtr& session = to_space->session; - // flush all pending tasks to ensure data is consistent - session->FlushTaskQueue(); - const std::shared_ptr& lld = session->low_level_device(); - - void* from_host_ptr = GetHostLoc(from, from_offset); - TargetPtr to_dev_addr = GetDevLoc(to_space, to_offset); - lld->Write(to_dev_addr, from_host_ptr, size); - - } else { - LOG(FATAL) << "Expect copy from/to micro device or between micro device\n"; - } - } - - void StreamSync(TVMContext ctx, TVMStreamHandle stream) final { - MicroSession::Current()->FlushTaskQueue(); - } - - void* AllocWorkspace(TVMContext ctx, size_t size, DLDataType type_hint) final { - CHECK(false) << "the on-device workspace allocator isn't aware of this function"; - ObjectPtr& session = MicroSession::Current(); - - TargetPtr data = session->AllocateInSection(SectionKind::kWorkspace, size); - CHECK(data.value().uint64() != 0) - << "unable to allocate " << size << " bytes on device workspace"; - return static_cast(new MicroDevSpace{data, session}); - } - - void FreeWorkspace(TVMContext ctx, void* data) final { - CHECK(false) << "the on-device workspace allocator isn't aware of this function"; - MicroDevSpace* dev_space = static_cast(data); - ObjectPtr& session = dev_space->session; - session->FreeInSection(SectionKind::kWorkspace, dev_space->data); - delete dev_space; - } - - /*! - * \brief obtain a global singleton of MicroDeviceAPI - * \return global shared pointer to MicroDeviceAPI - */ - static MicroDeviceAPI* Global() { - static MicroDeviceAPI* inst = new MicroDeviceAPI(); - return inst; - } - - private: - TargetPtr GetDevLoc(MicroDevSpace* dev_space, size_t offset) { return dev_space->data + offset; } - - void* GetHostLoc(const void* ptr, size_t offset) { - return reinterpret_cast(reinterpret_cast(ptr) + offset); - } -}; - -// register device that can be obtained from Python frontend -TVM_REGISTER_GLOBAL("device_api.micro_dev").set_body([](TVMArgs args, TVMRetValue* rv) { - DeviceAPI* ptr = MicroDeviceAPI::Global(); - *rv = static_cast(ptr); -}); -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/micro_module.cc b/src/runtime/micro/micro_module.cc deleted file mode 100644 index b4770ec6f934..000000000000 --- a/src/runtime/micro/micro_module.cc +++ /dev/null @@ -1,110 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file micro_module.cc - */ - -#include -#include -#include - -#include -#include - -#include "../pack_args.h" -#include "low_level_device.h" -#include "micro_common.h" -#include "micro_session.h" - -namespace tvm { -namespace runtime { -/*! - * \brief module for uTVM micro devices - */ -class MicroModuleNode final : public ModuleNode { - public: - MicroModuleNode() {} - - ~MicroModuleNode() {} - - const char* type_key() const final { return "micro"; } - - PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self) final; - - /*! - * \brief initializes module by establishing device connection and loads binary - * \param binary_path path of the binary to be loaded - */ - void InitMicroModule(const std::string& binary_path) { - // std::cout << "[MicroModuleNode::InitMicroModule]" << std::endl; - // std::cout << " start" << std::endl; - session_ = MicroSession::Current(); - symbol_map_ = session_->LoadBinary(binary_path, true).symbol_map; - } - - private: - SymbolMap symbol_map_; - /*! \brief global session pointer */ - ObjectPtr session_; -}; - -class MicroWrappedFunc { - public: - MicroWrappedFunc(ObjectPtr session, TargetPtr func_ptr) { - session_ = session; - func_ptr_ = func_ptr; - } - - void operator()(TVMArgs args, TVMRetValue* rv) const { - session_->PushToTaskQueue(func_ptr_, args); - } - - private: - /*! \brief reference to the session for this function (to keep the session alive) */ - ObjectPtr session_; - /*! \brief offset of the function to be called */ - TargetPtr func_ptr_; -}; - -PackedFunc MicroModuleNode::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - TargetPtr func_ptr; - if (name == tvm::runtime::symbol::tvm_module_main) { - if (symbol_map_.HasSymbol(tvm::runtime::symbol::tvm_module_main)) { - func_ptr = symbol_map_[tvm::runtime::symbol::tvm_module_main]; - } else { - func_ptr = symbol_map_["default_function"]; - } - } else { - func_ptr = symbol_map_[name]; - } - MicroWrappedFunc f(session_, func_ptr); - return PackedFunc(f); -} - -// register loadfile function to load module from Python frontend -TVM_REGISTER_GLOBAL("runtime.module.loadfile_micro_dev") - .set_body([](TVMArgs args, TVMRetValue* rv) { - auto n = make_object(); - n->InitMicroModule(args[0]); - *rv = runtime::Module(n); - }); -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/micro_section_allocator.h b/src/runtime/micro/micro_section_allocator.h deleted file mode 100644 index 5cafb41bbc4b..000000000000 --- a/src/runtime/micro/micro_section_allocator.h +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file micro_section_allocator.h - */ -#ifndef TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ -#define TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ - -#include -#include - -#include "micro_common.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief allocator for an on-device memory section - */ -class MicroSectionAllocator { - public: - /*! - * \brief constructor that specifies section boundaries - * \param region location and size of the section on the device - */ - explicit MicroSectionAllocator(std::string section_name, DevMemRegion region, - TargetWordSize word_size) - : section_name_(section_name), - start_addr_(region.start), - size_(0), - capacity_(region.size), - word_size_(word_size) { - CHECK_EQ(start_addr_.value().uint64() % word_size.bytes(), 0) - << "micro section start not aligned to " << word_size.bytes() << " bytes"; - CHECK_EQ(capacity_ % word_size.bytes(), 0) - << "micro section end not aligned to " << word_size.bytes() << " bytes"; - } - - /*! - * \brief destructor - */ - ~MicroSectionAllocator() {} - - /*! - * \brief memory allocator - * \param alloc_size size of allocated memory in bytes - * \return pointer to allocated memory region in section, nullptr if out of space - */ - TargetPtr Allocate(size_t size) { - size_ = UpperAlignValue(size_, word_size_.bytes()); - CHECK(size_ + size < capacity_) - << "cannot alloc " << size << " bytes in section \"" << section_name_ - << "\" (start_addr=" << start_addr_.cast_to() << ", used=" << size_ - << ", capacity=" << capacity_ << ")"; - TargetPtr alloc_addr = start_addr_ + size_; - size_ += size; - alloc_map_[alloc_addr.value().uint64()] = size; - return alloc_addr; - } - - /*! - * \brief free prior allocation from section - * \param offs offset to allocated memory - * \note simple allocator scheme, more complex versions will be implemented later - */ - void Free(TargetPtr addr) { - CHECK(alloc_map_.find(addr.value().uint64()) != alloc_map_.end()) - << "freed pointer was never allocated"; - alloc_map_.erase(addr.value().uint64()); - if (alloc_map_.empty()) { - size_ = 0; - } - } - - /*! - * \brief start offset of the memory region managed by this allocator - */ - TargetPtr start_addr() const { return start_addr_; } - - /*! - * \brief current end addr of the space being used in this memory region - */ - TargetPtr curr_end_addr() const { return start_addr_ + size_; } - - /*! - * \brief end addr of the memory region managed by this allocator - */ - TargetPtr max_addr() const { return start_addr_ + capacity_; } - - /*! - * \brief size of the section - */ - size_t size() const { return size_; } - - /*! - * \brief capacity of the section - */ - size_t capacity() const { return capacity_; } - - private: - /*! \brief name of the section (for debugging) */ - std::string section_name_; - /*! \brief start address of the section */ - TargetPtr start_addr_; - /*! \brief current size of the section */ - size_t size_; - /*! \brief total storage capacity of the section */ - size_t capacity_; - /*! \brief number of bytes in a word on the target device */ - TargetWordSize word_size_; - /*! \brief allocation map for allocation sizes */ - std::unordered_map alloc_map_; -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MICRO_MICRO_SECTION_ALLOCATOR_H_ diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc index f458872bfeb0..0ac2a014d858 100644 --- a/src/runtime/micro/micro_session.cc +++ b/src/runtime/micro/micro_session.cc @@ -23,664 +23,206 @@ #include "micro_session.h" -#include -#include +#include +#include +#include #include -#include -#include +#include #include -#include -#include -#include +#include +#include -#include "low_level_device.h" -#include "target_data_layout_encoder.h" +#include "../../support/str_escape.h" +#include "../crt/host/crt_config.h" +#include "../rpc/rpc_channel.h" +#include "../rpc/rpc_endpoint.h" +#include "../rpc/rpc_session.h" namespace tvm { namespace runtime { +namespace micro_rpc { -struct TVMMicroSessionThreadLocalEntry { - std::stack> session_stack; -}; - -typedef dmlc::ThreadLocalStore TVMMicroSessionThreadLocalStore; - -ObjectPtr& MicroSession::Current() { - TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); - CHECK_GT(entry->session_stack.size(), 0) << "No current session"; - return entry->session_stack.top(); -} - -void MicroSession::EnterWithScope(ObjectPtr session) { - TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); - entry->session_stack.push(session); -} +class CallbackWriteStream : public WriteStream { + public: + explicit CallbackWriteStream(PackedFunc fsend) : fsend_{fsend} {} -void MicroSession::ExitWithScope() { - TVMMicroSessionThreadLocalEntry* entry = TVMMicroSessionThreadLocalStore::Get(); - CHECK(!entry->session_stack.empty()); - entry->session_stack.pop(); -} - -MicroSession::MicroSession(const std::string& comms_method, const std::string& binary_path, - const std::string& toolchain_prefix, uint64_t text_start, - size_t text_size, uint64_t rodata_start, size_t rodata_size, - uint64_t data_start, size_t data_size, uint64_t bss_start, - size_t bss_size, uint64_t args_start, size_t args_size, - uint64_t heap_start, size_t heap_size, uint64_t workspace_start, - size_t workspace_size, uint64_t stack_start, size_t stack_size, - TargetWordSize word_size, bool thumb_mode, bool use_device_timer, - const std::string& server_addr, int port, PackedFunc debug_func) - : toolchain_prefix_(toolchain_prefix), - word_size_(word_size), - thumb_mode_(thumb_mode), - use_device_timer_(use_device_timer), - batch_args_encoder_(args_size, word_size), - debug_func_{debug_func} { - if (comms_method == "host") { - // TODO(weberlo): move checks to python - CHECK(text_start == 0 && rodata_start == 0 && data_start == 0 && bss_start == 0 && - args_start == 0 && heap_start == 0 && workspace_start == 0 && stack_start == 0) - << "unable to specify section addresses for host device"; - size_t memory_size = text_size + rodata_size + data_size + bss_size + args_size + heap_size + - workspace_size + stack_size; - TargetPtr base_addr; - low_level_device_ = HostLowLevelDeviceCreate(memory_size, &base_addr); - CHECK_EQ(base_addr.value().uint64() % word_size.bytes(), 0) - << "base address not aligned to " << word_size.bytes() << " bytes"; - TargetPtr curr_addr = base_addr; - - section_allocators_[0] = std::make_shared("text", - DevMemRegion{ - .start = curr_addr, - .size = text_size, - }, - word_size_); - curr_addr += text_size; - section_allocators_[1] = std::make_shared("rodata", - DevMemRegion{ - .start = curr_addr, - .size = rodata_size, - }, - word_size_); - curr_addr += rodata_size; - section_allocators_[2] = std::make_shared("data", - DevMemRegion{ - .start = curr_addr, - .size = data_size, - }, - word_size_); - curr_addr += data_size; - section_allocators_[3] = std::make_shared("bss", - DevMemRegion{ - .start = curr_addr, - .size = bss_size, - }, - word_size_); - curr_addr += bss_size; - section_allocators_[4] = std::make_shared("args", - DevMemRegion{ - .start = curr_addr, - .size = args_size, - }, - word_size_); - curr_addr += args_size; - section_allocators_[5] = std::make_shared("heap", - DevMemRegion{ - .start = curr_addr, - .size = heap_size, - }, - word_size_); - curr_addr += heap_size; - section_allocators_[6] = std::make_shared("workspace", - DevMemRegion{ - .start = curr_addr, - .size = workspace_size, - }, - word_size_); - curr_addr += workspace_size; - section_allocators_[7] = std::make_shared("stack", - DevMemRegion{ - .start = curr_addr, - .size = stack_size, - }, - word_size_); - curr_addr += stack_size; - } else if (comms_method == "openocd") { - low_level_device_ = OpenOCDLowLevelDeviceCreate(server_addr, port); - section_allocators_[0] = - std::make_shared("text", - DevMemRegion{ - .start = TargetPtr(word_size_, text_start), - .size = text_size, - }, - word_size_); - section_allocators_[1] = - std::make_shared("rodata", - DevMemRegion{ - .start = TargetPtr(word_size_, rodata_start), - .size = rodata_size, - }, - word_size_); - section_allocators_[2] = - std::make_shared("data", - DevMemRegion{ - .start = TargetPtr(word_size_, data_start), - .size = data_size, - }, - word_size_); - section_allocators_[3] = - std::make_shared("bss", - DevMemRegion{ - .start = TargetPtr(word_size_, bss_start), - .size = bss_size, - }, - word_size_); - section_allocators_[4] = - std::make_shared("args", - DevMemRegion{ - .start = TargetPtr(word_size_, args_start), - .size = args_size, - }, - word_size_); - section_allocators_[5] = - std::make_shared("heap", - DevMemRegion{ - .start = TargetPtr(word_size_, heap_start), - .size = heap_size, - }, - word_size_); - section_allocators_[6] = - std::make_shared("workspace", - DevMemRegion{ - .start = TargetPtr(word_size_, workspace_start), - .size = workspace_size, - }, - word_size_); - section_allocators_[7] = - std::make_shared("stack", - DevMemRegion{ - .start = TargetPtr(word_size_, stack_start), - .size = stack_size, - }, - word_size_); - } else { - LOG(FATAL) << "unsupported micro low-level device"; + ssize_t Write(const uint8_t* data, size_t data_size_bytes) override { + TVMByteArray bytes; + bytes.data = (const char*)data; + bytes.size = data_size_bytes; + int64_t n = fsend_(bytes); + return n; } - TargetPtr args_start_addr = GetAllocator(SectionKind::kArgs)->start_addr(); - batch_args_encoder_.set_start_addr(args_start_addr); - - runtime_symbol_map_ = LoadBinary(binary_path, false).symbol_map; - - // Patch pointers to define the bounds of the workspace section and the word - // size (for allocation alignment). - std::shared_ptr ws_allocator = GetAllocator(SectionKind::kWorkspace); - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_start", ws_allocator->start_addr()); - DevSymbolWrite(runtime_symbol_map_, "utvm_workspace_end", ws_allocator->max_addr()); - if (word_size.bytes() == 4) { - DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint32_t(word_size.bytes())); - } else if (word_size.bytes() == 8) { - DevSymbolWrite(runtime_symbol_map_, "utvm_word_size", uint64_t(word_size.bytes())); - } else { - CHECK(false) << "Unsupported word size unexpectedly here"; - } -} + void PacketDone(bool is_valid) override {} -MicroSession::~MicroSession() { - for (size_t i = 0; i < static_cast(SectionKind::kNumKinds); i++) { - section_allocators_[i] = nullptr; - } - low_level_device_ = nullptr; -} - -void MicroSession::PushToTaskQueue(TargetPtr func_ptr, const TVMArgs& args) { - if (thumb_mode_) { - // TODO(areusch): should be |= - func_ptr += 1; - } - TargetVal func_dev_addr = func_ptr.value(); + private: + PackedFunc fsend_; +}; - std::tuple arg_field_addrs = EncoderAppend(&batch_args_encoder_, args); - TargetVal arg_values_dev_addr{std::get<0>(arg_field_addrs).value()}; - TargetVal arg_type_codes_dev_addr{std::get<1>(arg_field_addrs).value()}; +class MicroTransportChannel : public RPCChannel { + public: + MicroTransportChannel(PackedFunc fsend, PackedFunc frecv) + : write_stream_{fsend}, + framer_{&write_stream_}, + receive_buffer_{new uint8_t[TVM_CRT_MAX_PACKET_SIZE_BYTES], TVM_CRT_MAX_PACKET_SIZE_BYTES}, + session_{0x5b, &framer_, &receive_buffer_, &HandleMessageReceivedCb, this}, + unframer_{session_.Receiver()}, + did_receive_message_{false}, + frecv_{frecv}, + message_buffer_{nullptr} {} + + size_t ReceiveUntil(TypedPackedFunc pf) { + size_t bytes_received = 0; + if (pf()) { + return 0; + } - task_queue_.push_back(DevTask{.func = func_dev_addr, - .arg_values = arg_values_dev_addr, - .arg_type_codes = arg_type_codes_dev_addr, - .num_args = args.num_args}); + for (;;) { + while (pending_chunk_.size() > 0) { + size_t bytes_consumed = 0; + int unframer_error = unframer_.Write((const uint8_t*)pending_chunk_.data(), + pending_chunk_.size(), &bytes_consumed); + + CHECK(bytes_consumed <= pending_chunk_.size()); + pending_chunk_ = pending_chunk_.substr(bytes_consumed); + bytes_received += bytes_consumed; + if (unframer_error < 0) { + LOG(ERROR) << "unframer got error code: " << unframer_error; + } else { + if (pf()) { + return bytes_received; + } + } + } - if (task_queue_.size() == MicroSession::kTaskQueueCapacity) { - FlushTaskQueue(); + std::string chunk = frecv_(128); + pending_chunk_ = chunk; + CHECK(pending_chunk_.size() != 0) << "zero-size chunk encountered"; + CHECK_GT(pending_chunk_.size(), 0); + } } -} -void MicroSession::FlushTaskQueue() { - if (task_queue_.size() == 0) { - // nothing to run - return; - } - if (word_size_.bytes() == 4) { - FlushTaskQueuePriv(); - } else if (word_size_.bytes() == 8) { - FlushTaskQueuePriv(); + void StartSession() { + CHECK_EQ(kTvmErrorNoError, session_.Initialize()); + CHECK_EQ(kTvmErrorNoError, session_.StartSession()); + ReceiveUntil([this]() -> bool { return session_.IsEstablished(); }); } -} -template -void MicroSession::FlushTaskQueuePriv() { - std::vector prepped_tasks; - for (const auto& task : task_queue_) { - prepped_tasks.push_back(T(task)); - } + size_t Send(const void* data, size_t size) override { + const uint8_t* data_bytes = static_cast(data); + ssize_t ret = session_.SendMessage(MessageType::kNormal, data_bytes, size); + CHECK(ret == 0) << "SendMessage returned " << ret; - // Flush `args` to device memory. - low_level_device()->Write(batch_args_encoder_.start_addr(), - reinterpret_cast(batch_args_encoder_.data()), - batch_args_encoder_.buf_size()); - - // Flush `tasks` to device memory. - TargetPtr dev_tasks_addr = runtime_symbol_map_["utvm_tasks"]; - low_level_device()->Write(dev_tasks_addr, reinterpret_cast(prepped_tasks.data()), - prepped_tasks.size() * sizeof(T)); - DevSymbolWrite(runtime_symbol_map_, "utvm_num_tasks", prepped_tasks.size()); - - TargetPtr utvm_init_addr = runtime_symbol_map_["UTVMInit"]; - TargetPtr utvm_done_addr = runtime_symbol_map_["UTVMDone"]; - if (thumb_mode_) { - // TODO(areusch): should be |= - utvm_init_addr += 1; + return size; } - bool did_debug = false; - if (debug_func_ != nullptr) { - TVMRetValue rv = debug_func_(); - if (rv.type_code() == kTVMNullptr) { - did_debug = true; - } else { - did_debug = static_cast(rv); - } + size_t Recv(void* data, size_t size) override { + size_t num_bytes_recv = 0; + while (num_bytes_recv < size) { + if (message_buffer_ != nullptr) { + num_bytes_recv += message_buffer_->Read(static_cast(data), size); + if (message_buffer_->ReadAvailable() == 0) { + message_buffer_ = nullptr; + session_.ClearReceiveBuffer(); + } + if (num_bytes_recv == size) { + CHECK(message_buffer_ == nullptr || message_buffer_->ReadAvailable() > 0); + return num_bytes_recv; + } + } - if (did_debug && !use_device_timer_) { - LOG(INFO) << "NOTE: when debugging and use_device_timer == false, reported execution time " - << "will be inaccurate!"; + did_receive_message_ = false; + ReceiveUntil([this]() -> bool { return did_receive_message_; }); } - } - if (!did_debug) { - std::chrono::time_point tbegin, - tend; - tbegin = std::chrono::high_resolution_clock::now(); - low_level_device()->Execute(utvm_init_addr, utvm_done_addr); - tend = std::chrono::high_resolution_clock::now(); - if (!use_device_timer_) { - last_batch_time_ += - std::chrono::duration_cast>(tend - tbegin).count() * 1000; - } + return num_bytes_recv; } - // Check if there was an error during execution. If so, log it. - CheckDeviceError(); - - if (use_device_timer_) { - uint64_t sum = 0; - std::vector times; - times.resize(task_queue_.size()); - low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), - task_queue_.size() * sizeof(uint32_t)); - int i = 0; - for (uint32_t time : times) { - LOG(INFO) << "Time " << i++ << ": " << time; - sum += time; - } - last_batch_time_ += static_cast(sum) / 1e3; - } else { - // TODO(weberlo): Reading internal data structure is hacky. - uint64_t sum = 0; - std::vector times; - times.resize(task_queue_.size()); - low_level_device()->Read(runtime_symbol_map_["utvm_task_times"], times.data(), - task_queue_.size() * sizeof(uint32_t)); - for (uint32_t time : times) { - sum += time; + FrameBuffer* GetReceivedMessage() { + if (did_receive_message_) { + did_receive_message_ = false; + return message_buffer_; } - last_batch_cycles_ += static_cast(sum); - } - batch_args_encoder_.Clear(); - task_queue_.clear(); -} - -BinaryInfo MicroSession::LoadBinary(const std::string& binary_path, bool patch_dylib_pointers) { - DevMemRegion text_section; - DevMemRegion rodata_section; - DevMemRegion data_section; - DevMemRegion bss_section; - - text_section.size = - GetSectionSize(binary_path, SectionKind::kText, toolchain_prefix_, word_size_); - rodata_section.size = - GetSectionSize(binary_path, SectionKind::kRodata, toolchain_prefix_, word_size_); - data_section.size = - GetSectionSize(binary_path, SectionKind::kData, toolchain_prefix_, word_size_); - bss_section.size = GetSectionSize(binary_path, SectionKind::kBss, toolchain_prefix_, word_size_); - - text_section.start = AllocateInSection(SectionKind::kText, text_section.size); - rodata_section.start = AllocateInSection(SectionKind::kRodata, rodata_section.size); - data_section.start = AllocateInSection(SectionKind::kData, data_section.size); - bss_section.start = AllocateInSection(SectionKind::kBss, bss_section.size); - - std::string relocated_bin = RelocateBinarySections( - binary_path, word_size_, text_section.start, rodata_section.start, data_section.start, - bss_section.start, GetAllocator(SectionKind::kStack)->max_addr(), toolchain_prefix_); - std::string text_contents = ReadSection(relocated_bin, SectionKind::kText, toolchain_prefix_); - std::string rodata_contents = ReadSection(relocated_bin, SectionKind::kRodata, toolchain_prefix_); - std::string data_contents = ReadSection(relocated_bin, SectionKind::kData, toolchain_prefix_); - std::string bss_contents = ReadSection(relocated_bin, SectionKind::kBss, toolchain_prefix_); - - low_level_device_->Write(text_section.start, &text_contents[0], text_section.size); - low_level_device_->Write(rodata_section.start, &rodata_contents[0], rodata_section.size); - low_level_device_->Write(data_section.start, &data_contents[0], data_section.size); - low_level_device_->Write(bss_section.start, &bss_contents[0], bss_section.size); - SymbolMap symbol_map{relocated_bin, toolchain_prefix_, word_size_}; - - if (patch_dylib_pointers) { - // Patch device lib pointers. - PatchImplHole(symbol_map, "TVMBackendAllocWorkspace"); - PatchImplHole(symbol_map, "TVMBackendFreeWorkspace"); - PatchImplHole(symbol_map, "TVMAPISetLastError"); + return nullptr; } - return BinaryInfo{ - .text_section = text_section, - .rodata_section = rodata_section, - .data_section = data_section, - .bss_section = bss_section, - .symbol_map = symbol_map, - }; -} - -std::tuple MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, - const TVMArgs& args) { - const int* type_codes = args.type_codes; - int num_args = args.num_args; - - auto tvm_vals_alloc = encoder->Alloc(num_args); - auto type_codes_alloc = encoder->Alloc(num_args); - - for (int i = 0; i < num_args; i++) { - switch (type_codes[i]) { - case kTVMNDArrayHandle: - case kTVMDLTensorHandle: { - DLTensor* base_arr_handle = args[i]; - // All uTVM arrays store a `MicroDevSpace` struct in their `data` field, - // which wraps the actual data and stores a reference to the session, in - // order to prevent premature session destruction. - void* old_data = base_arr_handle->data; - // Mutate the array to unwrap the `data` field. - MicroDevSpace* dev_arr_ptr = reinterpret_cast(old_data); - base_arr_handle->data = reinterpret_cast(dev_arr_ptr->data.value().uint64()); - // Now, encode the unwrapped version. - void* arr_ptr = nullptr; - if (word_size_.bytes() == 4) { - arr_ptr = EncoderAppend(encoder, *base_arr_handle).cast_to(); - } else if (word_size_.bytes() == 8) { - arr_ptr = EncoderAppend(encoder, *base_arr_handle).cast_to(); - } - // And restore the original wrapped version. - base_arr_handle->data = old_data; + private: + static void HandleMessageReceivedCb(void* context, MessageType message_type, FrameBuffer* buf) { + static_cast(context)->HandleMessageReceived(message_type, buf); + } - TVMValue val; - val.v_handle = arr_ptr; - tvm_vals_alloc->WriteValue(val); + void HandleMessageReceived(MessageType message_type, FrameBuffer* buf) { + size_t message_size_bytes; + switch (message_type) { + case MessageType::kStartSessionInit: + case MessageType::kStartSessionReply: break; - } - // TODO(weberlo): Implement `double` and `int64` case. - case kDLFloat: - case kDLInt: - case kDLUInt: - default: - LOG(FATAL) << "unsupported type code for writing args: " << type_codes[i]; - break; - } - } - type_codes_alloc->WriteArray(type_codes, num_args); - encoder->CheckUnfilledAllocs(); - return std::make_tuple(tvm_vals_alloc->start_addr(), type_codes_alloc->start_addr()); -} -template -TargetPtr MicroSession::EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr) { - // `shape` and `strides` are stored on the host, so we need to write them to - // the device first. The `data` field is already allocated on the device and - // is a device pointer, so we don't need to write it. - auto shape_alloc = encoder->Alloc(arr.ndim); - shape_alloc->WriteArray(arr.shape, arr.ndim); - TargetPtr shape_dev_addr = shape_alloc->start_addr(); - TargetPtr strides_dev_addr = TargetPtr(word_size_, nullptr); - if (arr.strides != nullptr) { - auto stride_alloc = encoder->Alloc(arr.ndim); - stride_alloc->WriteArray(arr.strides, arr.ndim); - strides_dev_addr = stride_alloc->start_addr(); - } + case MessageType::kTerminateSession: + LOG(FATAL) << "SessionTerminatedError: remote side has probably reset"; + break; - T dev_arr(TargetVal{word_size_.bits(), reinterpret_cast(arr.data)}, arr.ctx, arr.ndim, - arr.dtype, shape_dev_addr.value(), strides_dev_addr.value(), - TargetVal{word_size_.bits(), arr.byte_offset}); - CHECK(dev_arr.ctx.device_type == static_cast(kDLMicroDev)) - << "attempt to write DLTensor with non-micro device type"; - // Update the device type to CPU, because from the microcontroller's - // perspective, it is. - dev_arr.ctx.device_type = DLDeviceType::kDLCPU; - - auto tvm_arr_alloc = encoder->Alloc(); - tvm_arr_alloc->WriteValue(dev_arr); - return tvm_arr_alloc->start_addr(); -} + case MessageType::kLog: + uint8_t message[1024]; + message_size_bytes = buf->ReadAvailable(); + if (message_size_bytes == 0) { + return; + } else if (message_size_bytes > sizeof(message) - 1) { + LOG(ERROR) << "Remote log message is too long to display: " << message_size_bytes + << " bytes"; + return; + } -// TODO(weberlo): switch over entirely to error codes that expand to error -// messages on the host side. -void MicroSession::CheckDeviceError() { - int32_t last_error = DevSymbolRead(runtime_symbol_map_, "utvm_last_error"); + CHECK_EQ(buf->Read(message, sizeof(message) - 1), message_size_bytes); + message[message_size_bytes] = 0; + LOG(INFO) << "remote: " << message; + session_.ClearReceiveBuffer(); + return; - if (last_error) { - if (!use_device_timer_ && - (last_error == UTVM_ERR_TIMER_OVERFLOW || last_error == UTVM_ERR_TIMER_NOT_IMPLEMENTED)) { - // these errors don't matter if we're not using the on-device timer - return; - } - std::string err_msg; - switch (last_error) { - case UTVM_ERR_NOT_FINISHED: - err_msg = "execution timed out"; - break; - case UTVM_ERR_TIMER_NOT_IMPLEMENTED: - err_msg = "timer is not implemented for the target device"; - break; - case UTVM_ERR_TIMER_OVERFLOW: - // TODO(weberlo): this should be remedied by using interrupts to accumulate the - // timer into a larger datatype (ARM timers are only 24 bits) - err_msg = "timer overflowed during execution"; - break; - case UTVM_ERR_WS_DOUBLE_FREE: - err_msg = "free called with no active workspace allocations"; - break; - case UTVM_ERR_WS_OUT_OF_SPACE: - err_msg = "ran out of space in workspace section"; - break; - case UTVM_ERR_WS_TOO_MANY_ALLOCS: - err_msg = "exceeded number of allocs the runtime can keep track of"; - break; - case UTVM_ERR_WS_ZERO_SIZE_ALLOC: - err_msg = "attempt to allocate scratchpad of size zero"; - break; - case UTVM_ERR_WS_UNALIGNED_START: - err_msg = "start of workspace section is not word-aligned"; - break; - case UTVM_ERR_WS_UNALIGNED_ALLOC_SIZE: - err_msg = "scratchpad allocation size is not a multiple of the word size"; - break; - default: - err_msg = "unknown error code"; + case MessageType::kNormal: + did_receive_message_ = true; + message_buffer_ = buf; break; } - LOG(FATAL) << "error during micro function execution:\n" - << " error ID: " << std::dec << last_error << std::endl - << " error message: " << err_msg; - } -} - -void MicroSession::PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name) { - TargetPtr runtime_impl_addr = runtime_symbol_map_[func_name]; - if (thumb_mode_) { - runtime_impl_addr += 1; } - std::ostringstream func_name_underscore; - func_name_underscore << func_name << "_"; - DevSymbolWrite(symbol_map, func_name_underscore.str(), runtime_impl_addr); -} -std::string MicroSession::ReadString(TargetPtr str_addr) { - std::ostringstream result; - const size_t buf_size = 256; - std::vector buf(buf_size, 0); - size_t i = buf_size; - while (i == buf_size) { - low_level_device()->Read(str_addr, buf.data(), buf_size); - i = 0; - while (i < buf_size) { - if (buf[i] == 0) break; - result << buf[i]; - i++; - } - str_addr = str_addr + i; - } - return result.str(); -} - -TargetPtr MicroSession::AllocateInSection(SectionKind type, size_t size) { - return GetAllocator(type)->Allocate(size); -} + CallbackWriteStream write_stream_; + Framer framer_; + FrameBuffer receive_buffer_; + Session session_; + Unframer unframer_; + bool did_receive_message_; + PackedFunc frecv_; + FrameBuffer* message_buffer_; + std::string pending_chunk_; +}; -void MicroSession::FreeInSection(SectionKind type, TargetPtr addr) { - return GetAllocator(type)->Free(addr); -} +TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* rv) { + MicroTransportChannel* micro_channel = new MicroTransportChannel(args[1], args[2]); + micro_channel->StartSession(); + std::unique_ptr channel(micro_channel); + auto ep = RPCEndpoint::Create(std::move(channel), args[0], ""); + auto sess = CreateClientSession(ep); + *rv = CreateRPCSessionModule(sess); +}); -template -T MicroSession::DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol) { - TargetPtr sym_addr = symbol_map[symbol]; - T result; - low_level_device()->Read(sym_addr, &result, sizeof(T)); - return result; -} +} // namespace micro_rpc +} // namespace runtime +} // namespace tvm -void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, - const TargetPtr& ptr) { - if (word_size_.bytes() == 4) { - DevSymbolWrite(symbol_map, symbol, ptr.value().uint32()); - } else if (word_size_.bytes() == 8) { - DevSymbolWrite(symbol_map, symbol, ptr.value().uint64()); - } else { - CHECK(false) << "Unsupported word size unexpectedly here"; - } -} +extern "C" { -template -void MicroSession::DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, - const T& value) { - TargetPtr sym_addr = symbol_map[symbol]; - low_level_device()->Write(sym_addr, &value, sizeof(T)); +void TVMLogf(const char* fmt, ...) { + va_list args; + char msg_buf[256]; + va_start(args, fmt); + vsnprintf(msg_buf, sizeof(msg_buf), fmt, args); + va_end(args); + LOG(INFO) << msg_buf; } -PackedFunc MicroSession::GetFunction(const std::string& name, - const ObjectPtr& sptr_to_self) { - if (name == "enter") { - return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - MicroSession::EnterWithScope(GetObjectPtr(this)); - }); - } else if (name == "exit") { - return PackedFunc( - [sptr_to_self](TVMArgs args, TVMRetValue* rv) { MicroSession::ExitWithScope(); }); - // TODO(weberlo): add a `clear_batch_timer` func - } else if (name == "get_last_batch_time") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchTime(); }); - // TODO(weberlo): remove this func - } else if (name == "get_last_batch_cycles") { - return PackedFunc( - [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetLastBatchCycles(); }); - } else { - return PackedFunc(); - } +void TVMPlatformAbort(int error_code) { CHECK(false) << "TVMPlatformAbort: " << error_code; } } - -TVM_REGISTER_GLOBAL("micro._GetMicroTimeEvaluator").set_body([](TVMArgs args, TVMRetValue* rv) { - PackedFunc pf = args[0]; - TVMContext ctx = args[1]; - uint64_t number = args[2]; - uint64_t repeat = args[3]; - - auto ftimer = [pf, ctx, number, repeat](TVMArgs args, TVMRetValue* rv) mutable { - TVMRetValue temp; - std::ostringstream os; - - for (unsigned int i = 0; i < repeat; ++i) { - // start timing - CHECK(number < MicroSession::kTaskQueueCapacity) - << "`number` must be less than uTVM task queue capacity"; - for (unsigned int j = 0; j < number; ++j) { - pf.CallPacked(args, &temp); - } - ObjectPtr session = MicroSession::Current(); - DeviceAPI::Get(ctx)->StreamSync(ctx, nullptr); - double time_per_batch = session->GetLastBatchTime() / number; - os.write(reinterpret_cast(&time_per_batch), sizeof(time_per_batch)); - } - std::string blob = os.str(); - TVMByteArray arr; - arr.size = blob.length(); - arr.data = blob.data(); - // return the time. - *rv = arr; - }; - *rv = PackedFunc(ftimer); -}); - -// create micro session and low-level device from Python frontend -TVM_REGISTER_GLOBAL("micro._CreateSession").set_body([](TVMArgs args, TVMRetValue* rv) { - const std::string& comms_method = args[0]; - const std::string& binary_path = args[1]; - const std::string& toolchain_prefix = args[2]; - uint64_t text_start = args[3]; - size_t text_size = uint64_t(args[4]); - uint64_t rodata_start = args[5]; - size_t rodata_size = uint64_t(args[6]); - uint64_t data_start = args[7]; - size_t data_size = uint64_t(args[8]); - uint64_t bss_start = args[9]; - size_t bss_size = uint64_t(args[10]); - uint64_t args_start = args[11]; - size_t args_size = uint64_t(args[12]); - uint64_t heap_start = args[13]; - size_t heap_size = uint64_t(args[14]); - uint64_t workspace_start = args[15]; - size_t workspace_size = uint64_t(args[16]); - uint64_t stack_start = args[17]; - size_t stack_size = uint64_t(args[18]); - TargetWordSize word_size{uint64_t(args[19])}; - bool thumb_mode = args[20]; - bool use_device_timer = args[21]; - const std::string& server_addr = args[22]; - int port = args[23]; - PackedFunc debug_func = args[24]; - ObjectPtr session = make_object( - comms_method, binary_path, toolchain_prefix, text_start, text_size, rodata_start, rodata_size, - data_start, data_size, bss_start, bss_size, args_start, args_size, heap_start, heap_size, - workspace_start, workspace_size, stack_start, stack_size, word_size, thumb_mode, - use_device_timer, server_addr, port, debug_func); - *rv = Module(session); -}); - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/micro_session.h b/src/runtime/micro/micro_session.h index f911cf7dde43..50018a4bb0c0 100644 --- a/src/runtime/micro/micro_session.h +++ b/src/runtime/micro/micro_session.h @@ -34,402 +34,4 @@ #ifndef TVM_RUNTIME_MICRO_MICRO_SESSION_H_ #define TVM_RUNTIME_MICRO_MICRO_SESSION_H_ -#include -#include - -#include -#include -#include -#include -#include - -#include "low_level_device.h" -#include "micro_common.h" -#include "micro_section_allocator.h" -#include "target_data_layout_encoder.h" - -namespace tvm { -namespace runtime { - -struct DevTask; - -/*! - * \brief session for facilitating micro device interaction - */ -class MicroSession : public ModuleNode { - public: - /*! - * \brief Get member function to front-end - * \param name The name of the function. - * \param sptr_to_self The pointer to the module node. - * \return The corresponding member function. - */ - virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr& sptr_to_self); - - // todo having this decoupled from the value in utvm_runtime.c gives me stress dreams - static const size_t kTaskQueueCapacity = 20; - - /*! - * \return The type key of the executor. - */ - const char* type_key() const final { return "MicroSession"; } - - /*! - * \brief creates session by setting up a low-level device and initting allocators for it - * \param comms_method method of communication with the device (e.g., "openocd") - * \param binary_path file system path to the runtime binary - * \param toolchain_prefix GCC toolchain prefix - * \param text_start text section start address - * \param text_size text section size - * \param rodata_start text section start address - * \param rodata_size rodata section size - * \param data_start data section start address - * \param data_size data section size - * \param bss_start bss section start address - * \param bss_size bss section size - * \param args_start args section start address - * \param args_size args section size - * \param heap_start heap section start address - * \param heap_size heap section size - * \param workspace_start workspace section start address - * \param workspace_size workspace section size - * \param stack_start stack section start address - * \param stack_size stack section size - * \param word_size_bytes number of bytes in a word on the target device - * \param thumb_mode whether the target device requires a thumb-mode bit on function addresses - * \param server_addr address of the OpenOCD server to connect to (if `comms_method == "openocd"`) - * \param port port of the OpenOCD server to connect to (if `comms_method == "openocd"`) - */ - MicroSession(const std::string& comms_method, const std::string& binary_path, - const std::string& toolchain_prefix, uint64_t text_start, size_t text_size, - uint64_t rodata_start, size_t rodata_size, uint64_t data_start, size_t data_size, - uint64_t bss_start, size_t bss_size, uint64_t args_start, size_t args_size, - uint64_t heap_start, size_t heap_size, uint64_t workspace_start, - size_t workspace_size, uint64_t stack_start, size_t stack_size, - TargetWordSize word_size, bool thumb_mode, bool use_device_timer, - const std::string& server_addr, int port, PackedFunc debug_func); - - /*! - * \brief destructor - */ - ~MicroSession(); - - static ObjectPtr& Current(); - - /*! - * \brief sets up runtime metadata for `func` and copies arguments for on-device execution - * \param func address of the function to be executed - * \param args args to the packed function - * \return elapsed time during function execution on the device - */ - void PushToTaskQueue(TargetPtr func, const TVMArgs& args); - - /*! - * \brief serialize runtime metadata to the device for enqueued tasks and execute - * \return elapsed time during function execution on the device - */ - void FlushTaskQueue(); - - /*! - * \brief TODO - */ - template - void FlushTaskQueuePriv(); - - /*! - * \brief loads binary onto device - * \param binary_path path to binary object file - * \param patch_dylib_pointers whether to patch runtime API function pointers - * \return info about loaded binary - */ - BinaryInfo LoadBinary(const std::string& binary_path, bool patch_dylib_pointers); - - /*! - * \brief allocate memory in section - * \param type type of section to allocate in - * \param size size of allocated memory in bytes - * \return pointer to allocated memory region in section, nullptr if out of space - */ - TargetPtr AllocateInSection(SectionKind type, size_t size); - - /*! - * \brief free prior allocation from section - * \param type type of section to allocate in - * \param addr device address of allocated memory - */ - void FreeInSection(SectionKind type, TargetPtr addr); - - /*! - * \brief read string from device to host - * \param str_addr device address of first character of string - * \return host copy of device string that was read - */ - std::string ReadString(TargetPtr str_addr); - - /*! - * \brief read value of symbol from device memory - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being read from - * \return value at symbol in memory - */ - template - T DevSymbolRead(const SymbolMap& symbol_map, const std::string& symbol); - - /*! - * \brief write pointer value into device memory corresponding to symbol - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being written to - * \param ptr pointer value to write into symbol - */ - void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const TargetPtr& ptr); - - /*! - * \brief write value into device memory corresponding to symbol - * \param symbol_map symbol map to read location of symbol from - * \param symbol name of symbol being written to - * \param value value being written into symbol - */ - template - void DevSymbolWrite(const SymbolMap& symbol_map, const std::string& symbol, const T& value); - - /*! - * \brief returns low-level device pointer - * \note assumes low-level device has been initialized - */ - const std::shared_ptr& low_level_device() const { - CHECK(low_level_device_ != nullptr) << "attempt to get uninitialized low-level device"; - return low_level_device_; - } - - const double GetLastBatchTime() { - double result = last_batch_time_; - last_batch_time_ = 0.0; - return result; - } - - const double GetLastBatchCycles() { - double result = last_batch_cycles_; - last_batch_cycles_ = 0.0; - return result; - } - - private: - /*! \brief low-level device pointer */ - std::shared_ptr low_level_device_; - /*! \brief prefix for binary names in target compiler toolchain */ - std::string toolchain_prefix_; - /*! \brief array of memory allocators for each on-device section */ - std::shared_ptr - section_allocators_[static_cast(SectionKind::kNumKinds)]; - /*! \brief number of bytes in a word on the target device */ - TargetWordSize word_size_; - /*! \brief whether the target device requires a thumb-mode bit on function addresses - * - * ARM and other manufacturers use the lowest bit of a function address to determine - * whether it's a "thumb mode" function. The Thumb ISA is more restricted, but - * results in more compact binaries. - */ - bool thumb_mode_; - /*! \brief TODO */ - bool use_device_timer_; - /*! \brief symbol map for the device runtime */ - SymbolMap runtime_symbol_map_; - /*! \brief TODO */ - std::vector task_queue_; - // TODO(weberlo): we don't even need an allocator mechanism for the args - // section. there's only ever one allocation. - /*! \brief TODO hack */ - TargetDataLayoutEncoder batch_args_encoder_; - /*! \brief TODO hack */ - double last_batch_time_; - /*! \brief TODO hack */ - double last_batch_cycles_; - /*! \brief the debug function invoked to launch gdb */ - PackedFunc debug_func_; - - /*! - * \brief patches a function pointer in this module to an implementation - * \param func_name name of the function pointer being patched - */ - void PatchImplHole(const SymbolMap& symbol_map, const std::string& func_name); - - /*! - * \brief appends arguments to the host-side buffer of `encoder` - * \param encoder encoder being used to append `args` - * \param args args to be appended - * \return device address of the allocated args - */ - std::tuple EncoderAppend(TargetDataLayoutEncoder* encoder, - const TVMArgs& args); - - /*! - * \brief appends a `DLTensor` to the host-side buffer of `encoder` - * \param encoder encoder being used to append `arr` - * \param arr DLTensor to be appended - * \return device address of the allocated `DLTensor` - */ - template - TargetPtr EncoderAppend(TargetDataLayoutEncoder* encoder, const DLTensor& arr); - - /*! - * \brief checks and logs if there was an error during the device's most recent execution - */ - void CheckDeviceError(); - - /*! - * \brief returns section allocator corresponding to the given section kind - * \param kind kind of target section - * \return shared pointer to section allocator - */ - std::shared_ptr GetAllocator(SectionKind kind) { - return section_allocators_[static_cast(kind)]; - } - - /*! - * \brief Push a new session context onto the thread-local stack. - * The session on top of the stack is used as the current global session. - */ - static void EnterWithScope(ObjectPtr session); - - /*! - * \brief Pop a session off the thread-local context stack, - * restoring the previous session as the current context. - */ - static void ExitWithScope(); -}; - -/*! - * \brief a device memory region associated with the session that allocated it - * - * We use this to store a reference to the session in each allocated object and - * only deallocate the session once there are no more references to it. - */ -struct MicroDevSpace { - /*! \brief data being wrapped */ - TargetPtr data; - /*! \brief shared ptr to session where this data is valid */ - ObjectPtr session; -}; - -// TODO(weberlo): maybe templatize serialization to reduce redundancy - -/*! \brief TVM array for serialization to 32-bit devices */ -struct TVMArray32 { - TVMArray32(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, - TargetVal strides, TargetVal byte_offset) - : data{data.uint32()}, - ctx{ctx}, - ndim{ndim}, - dtype{dtype}, - shape{shape.uint32()}, - strides{strides.uint32()}, - byte_offset{byte_offset.uint32()} {} - - /*! - * \brief The opaque data pointer points to the allocated data. - * This will be CUDA device pointer or cl_mem handle in OpenCL. - * This pointer is always aligns to 256 bytes as in CUDA. - */ - uint32_t data; - /*! \brief The device context of the tensor */ - DLContext ctx; - /*! \brief Number of dimensions */ - int32_t ndim; - /*! \brief The data type of the pointer */ - DLDataType dtype; - /*! \brief The shape of the tensor */ - uint32_t shape; - /*! - * \brief strides of the tensor, - * can be NULL, indicating tensor is compact. - */ - uint32_t strides; - /*! \brief The offset in bytes to the beginning pointer to data */ - uint32_t byte_offset; -}; - -/*! \brief TVM array for serialization to 64-bit devices */ -struct TVMArray64 { - TVMArray64(TargetVal data, DLContext ctx, int32_t ndim, DLDataType dtype, TargetVal shape, - TargetVal strides, TargetVal byte_offset) - : data(data.uint64()), - ctx(ctx), - ndim(ndim), - dtype(dtype), - shape(shape.uint64()), - strides(strides.uint64()), - byte_offset(byte_offset.uint64()) {} - /*! - * \brief The opaque data pointer points to the allocated data. - * This will be CUDA device pointer or cl_mem handle in OpenCL. - * This pointer is always aligns to 256 bytes as in CUDA. - */ - uint64_t data; - /*! \brief The device context of the tensor */ - DLContext ctx; - /*! \brief Number of dimensions */ - int32_t ndim; - /*! \brief The data type of the pointer */ - DLDataType dtype; - /*! \brief The shape of the tensor */ - uint64_t shape; - /*! - * \brief strides of the tensor, - * can be NULL, indicating tensor is compact. - */ - uint64_t strides; - /*! \brief The offset in bytes to the beginning pointer to data */ - uint64_t byte_offset; -}; - -/*! \brief MicroTVM task to store in task queue before specializing to word size */ -struct DevTask { - /*! \brief Pointer to function to call for this task */ - TargetVal func; - /*! \brief Array of argument values */ - TargetVal arg_values; - /*! \brief Array of type codes for each argument value */ - TargetVal arg_type_codes; - /*! \brief Number of arguments */ - int32_t num_args; -}; - -/*! \brief MicroTVM task for serialization to 32-bit devices */ -typedef struct StructUTVMTask32 { - StructUTVMTask32(DevTask task) - : func(task.func.uint32()), - arg_values(task.arg_values.uint32()), - arg_type_codes(task.arg_type_codes.uint32()), - num_args(task.num_args) {} - - /*! \brief Pointer to function to call for this task */ - uint32_t func; - /*! \brief Array of argument values */ - uint32_t arg_values; - /*! \brief Array of type codes for each argument value */ - uint32_t arg_type_codes; - /*! \brief Number of arguments */ - int32_t num_args; -} StructUTVMTask32; - -/*! \brief MicroTVM task for serialization to 64-bit devices */ -typedef struct StructUTVMTask64 { - StructUTVMTask64(DevTask task) - : func(task.func.uint64()), - arg_values(task.arg_values.uint64()), - arg_type_codes(task.arg_type_codes.uint64()), - num_args(task.num_args) {} - - /*! \brief Pointer to function to call for this task */ - uint64_t func; - /*! \brief Array of argument values */ - uint64_t arg_values; - /*! \brief Array of type codes for each argument value */ - uint64_t arg_type_codes; - /*! \brief Number of arguments */ - int32_t num_args; -} StructUTVMTask64; - -} // namespace runtime -} // namespace tvm #endif // TVM_RUNTIME_MICRO_MICRO_SESSION_H_ diff --git a/src/runtime/micro/openocd_low_level_device.cc b/src/runtime/micro/openocd_low_level_device.cc deleted file mode 100644 index 610ca8590dd1..000000000000 --- a/src/runtime/micro/openocd_low_level_device.cc +++ /dev/null @@ -1,221 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file openocd_low_level_device.cc - */ -#include -#include - -#include "low_level_device.h" -#include "micro_common.h" -#include "tcl_socket.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief OpenOCD low-level device for uTVM micro devices connected over JTAG - */ -class OpenOCDLowLevelDevice final : public LowLevelDevice { - public: - /*! - * \brief constructor to initialize connection to openocd device - * \param server_addr address of the OpenOCD server to connect to - * \param port port of the OpenOCD server to connect to - */ - explicit OpenOCDLowLevelDevice(const std::string& server_addr, int port) : socket_() { - server_addr_ = server_addr; - port_ = port; - - socket_.Connect(tvm::support::SockAddr(server_addr_.c_str(), port_)); - socket_.cmd_builder() << "reset run"; - socket_.SendCommand(); - - socket_.cmd_builder() << "halt 500"; - socket_.SendCommand(); - } - - void Read(TargetPtr addr, void* buf, size_t num_bytes) override { - if (num_bytes == 0) { - return; - } - - // TODO(weberlo): Refactor between read and write. - // Check if we need to chunk this write request. - if (num_bytes > kMemTransferLimit) { - char* curr_buf_ptr = reinterpret_cast(buf); - while (num_bytes != 0) { - size_t amount_to_read; - if (num_bytes > kMemTransferLimit) { - amount_to_read = kMemTransferLimit; - } else { - amount_to_read = num_bytes; - } - Read(addr, reinterpret_cast(curr_buf_ptr), amount_to_read); - addr += amount_to_read; - curr_buf_ptr += amount_to_read; - num_bytes -= amount_to_read; - } - return; - } - { - socket_.cmd_builder() << "array unset output"; - socket_.SendCommand(); - - socket_.cmd_builder() << "mem2array output" - << " " << std::dec << kWordSize << " " - << addr.cast_to() - // Round up any request sizes under a byte, since OpenOCD doesn't - // support sub-byte-sized transfers. - << " " << std::dec << (num_bytes < 8 ? 8 : num_bytes); - socket_.SendCommand(); - } - - { - socket_.cmd_builder() << "return $output"; - socket_.SendCommand(); - const std::string& reply = socket_.last_reply(); - - std::istringstream values(reply); - char* char_buf = reinterpret_cast(buf); - ssize_t req_bytes_remaining = num_bytes; - uint32_t index; - uint32_t val; - while (req_bytes_remaining > 0) { - // The response from this command pairs indices with the contents of the - // memory at that index. - values >> index; - CHECK(index < num_bytes) << "index " << index << " out of bounds (length " << num_bytes - << ")"; - // Read the value into `curr_val`, instead of reading directly into - // `buf_iter`, because otherwise it's interpreted as the ASCII value and - // not the integral value. - values >> val; - char_buf[index] = static_cast(val); - req_bytes_remaining--; - } - if (num_bytes >= 8) { - uint32_t check_index; - values >> check_index; - CHECK(check_index != index) << "more data in response than requested"; - } - } - } - - void Write(TargetPtr addr, const void* buf, size_t num_bytes) override { - if (num_bytes == 0) { - return; - } - - // Check if we need to chunk this write request. - if (num_bytes > kMemTransferLimit) { - const char* curr_buf_ptr = reinterpret_cast(buf); - while (num_bytes != 0) { - size_t amount_to_write; - if (num_bytes > kMemTransferLimit) { - amount_to_write = kMemTransferLimit; - } else { - amount_to_write = num_bytes; - } - Write(addr, reinterpret_cast(curr_buf_ptr), amount_to_write); - addr += amount_to_write; - curr_buf_ptr += amount_to_write; - num_bytes -= amount_to_write; - } - return; - } - - // Clear `input` array. - socket_.cmd_builder() << "array unset input"; - socket_.SendCommand(); - // Build a command to set the value of `input`. - { - std::ostringstream& cmd_builder = socket_.cmd_builder(); - cmd_builder << "array set input {"; - const char* char_buf = reinterpret_cast(buf); - for (size_t i = 0; i < num_bytes; i++) { - // In a Tcl `array set` commmand, we need to pair the array indices with - // their values. - cmd_builder << i << " "; - // Need to cast to uint, so the number representation of `buf[i]` is - // printed, and not the ASCII representation. - cmd_builder << static_cast(char_buf[i]) << " "; - } - cmd_builder << "}"; - socket_.SendCommand(); - } - { - socket_.cmd_builder() << "array2mem input" - << " " << std::dec << kWordSize << " " << addr.cast_to() << " " - << std::dec << num_bytes; - socket_.SendCommand(); - } - } - - void Execute(TargetPtr func_addr, TargetPtr breakpoint_addr) override { - socket_.cmd_builder() << "halt 0"; - socket_.SendCommand(); - - // Set a breakpoint at the beginning of `UTVMDone`. - socket_.cmd_builder() << "bp " << breakpoint_addr.cast_to() << " 2"; - socket_.SendCommand(); - - socket_.cmd_builder() << "resume " << func_addr.cast_to(); - socket_.SendCommand(); - - socket_.cmd_builder() << "wait_halt " << kWaitTime; - socket_.SendCommand(); - - socket_.cmd_builder() << "halt 0"; - socket_.SendCommand(); - - // Remove the breakpoint. - socket_.cmd_builder() << "rbp " << breakpoint_addr.cast_to(); - socket_.SendCommand(); - } - - const char* device_type() const final { return "openocd"; } - - private: - /*! \brief socket used to communicate with the device through Tcl */ - TclSocket socket_; - /*! \brief address of OpenOCD server */ - std::string server_addr_; - /*! \brief port of OpenOCD server */ - int port_; - - /*! \brief number of bytes in a word on the target device (64-bit) */ - static const constexpr ssize_t kWordSize = 8; - // NOTE: The OS pipe buffer must be able to handle a line long enough to - // print this transfer request. - /*! \brief maximum number of bytes allowed in a single memory transfer */ - static const constexpr ssize_t kMemTransferLimit = 8000; - /*! \brief number of milliseconds to wait for function execution to halt */ - static const constexpr int kWaitTime = 30000; -}; - -const std::shared_ptr OpenOCDLowLevelDeviceCreate(const std::string& server_addr, - int port) { - std::shared_ptr lld = std::make_shared(server_addr, port); - return lld; -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/target_data_layout_encoder.cc b/src/runtime/micro/target_data_layout_encoder.cc deleted file mode 100644 index 4a87a8f35721..000000000000 --- a/src/runtime/micro/target_data_layout_encoder.cc +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -#include "target_data_layout_encoder.h" - -namespace tvm { -namespace runtime { - -TargetDataLayoutEncoder::Alloc::Alloc(TargetDataLayoutEncoder* parent, size_t start_offset, - size_t size, TargetPtr start_addr) - : parent_(parent), - start_offset_(start_offset), - curr_offset_(0), - size_(size), - start_addr_(start_addr) { - parent_->live_unchecked_allocs_.insert(this); -} - -TargetDataLayoutEncoder::Alloc::~Alloc() { - auto it = parent_->live_unchecked_allocs_.find(this); - if (it != parent_->live_unchecked_allocs_.end()) { - // alloc was not already checked - parent_->live_unchecked_allocs_.erase(it); - if (curr_offset_ != size_) { - parent_->unchecked_alloc_start_offsets_.push_back(start_addr_.value().uint64()); - } - } -} - -void TargetDataLayoutEncoder::Alloc::CheckUnfilled() { - CHECK(curr_offset_ == size_) << "unwritten space in alloc 0x" << std::hex - << start_addr_.value().uint64() << "; curr_offset=0x" << curr_offset_ - << ", size=0x" << size_; -} - -TargetPtr TargetDataLayoutEncoder::Alloc::start_addr() { return start_addr_; } - -size_t TargetDataLayoutEncoder::Alloc::size() { return size_; } - -void TargetDataLayoutEncoder::CheckUnfilledAllocs() { - CHECK(live_unchecked_allocs_.size() > 0) << "No allocs to check"; - if (unchecked_alloc_start_offsets_.size() > 0) { - LOG(ERROR) << "Unchecked allocs were found:"; - for (size_t alloc_start_addr : unchecked_alloc_start_offsets_) { - LOG(ERROR) << " * 0x" << std::hex << alloc_start_addr; - } - CHECK(false) << "Unchecked allocs found during CheckUnfilledAllocs"; - } - - for (class Alloc* s : live_unchecked_allocs_) { - s->CheckUnfilled(); - } - live_unchecked_allocs_.clear(); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/target_data_layout_encoder.h b/src/runtime/micro/target_data_layout_encoder.h deleted file mode 100644 index 81587755e3b3..000000000000 --- a/src/runtime/micro/target_data_layout_encoder.h +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file target_data_layout_encoder.h - * \brief uTVM data layout encoder - */ -#ifndef TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ -#define TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ - -#include -#include -#include - -#include "host_driven/utvm_runtime_enum.h" -#include "micro_common.h" - -namespace tvm { -namespace runtime { - -// TODO(weberlo, areusch): Handle endianness. - -/*! - * \brief data encoder for uTVM that builds a host-side buffer - */ -class TargetDataLayoutEncoder { - public: - /*! - * \brief helper class for writing into `TargetDataLayoutEncoder` - */ - class Alloc { - public: - /*! - * \brief constructor - * \param parent pointer to parent encoder - * \param start_offset start byte offset of the alloc in the backing buffer - * \param size size (in bytes) of the memory region allocated for this alloc - * \param start_addr start address of the alloc in the device's memory - */ - Alloc(TargetDataLayoutEncoder* parent, size_t start_offset, size_t size, TargetPtr start_addr); - - ~Alloc(); - - /*! - * \brief writes `sizeof(T) * num_elems` bytes of data from `arr` - * \param arr array to be read from - * \param num_elems number of elements in array - */ - template - void WriteArray(const T* arr, size_t num_elems); - - /*! - * \brief writes `val` - * \param val value to be written - */ - template - void WriteValue(const T& val); - - /*! - * \brief returns start address of the alloc in device memory - * \return device start address - */ - TargetPtr start_addr(); - - /*! - * \brief returns number of bytes allocated for this alloc - * \return size of this alloc - */ - size_t size(); - - size_t curr_offset() const { return curr_offset_; } - - void CheckUnfilled(); - - private: - /*! \brief pointer to parent encoder */ - TargetDataLayoutEncoder* parent_; - /*! \brief start offset of the alloc in the parent's backing parent_buffer */ - size_t start_offset_; - /*! \brief current offset relative to the start offset of this alloc */ - size_t curr_offset_; - /*! \brief size (in bytes) of the memory region allocated for this alloc */ - size_t size_; - /*! \brief start address of the alloc in the device's memory */ - TargetPtr start_addr_; - }; - - /*! - * \brief constructor - * \param start_addr start address of the encoder in device memory - */ - explicit TargetDataLayoutEncoder(size_t capacity, TargetWordSize word_size) - : buf_(std::vector()), - curr_offset_(0), - start_addr_(word_size, nullptr), - capacity_(capacity), - word_size_(word_size) {} - - /*! - * \brief allocates a alloc for `sizeof(T) * num_elems` bytes of data - * \param num_elems number of elements of type `T` being allocated (defaults to 1) - * \return alloc of size `sizeof(T) * num_elems` bytes - */ - template - std::unique_ptr Alloc(size_t num_elems = 1) { - curr_offset_ = UpperAlignValue(curr_offset_, word_size_.bytes()); - size_t size = sizeof(T) * num_elems; - if (curr_offset_ + size > buf_.size()) { - buf_.resize(curr_offset_ + size); - } - CHECK(buf_.size() < capacity_) << "out of space in data encoder"; - size_t alloc_start_offset = curr_offset_; - curr_offset_ += size; - class Alloc* alloc = - new class Alloc(this, alloc_start_offset, size, start_addr() + alloc_start_offset); - return std::unique_ptr(alloc); - } - - void Clear() { - buf_.clear(); - curr_offset_ = 0; - } - - /*! - * \brief returns the array backing the encoder's buffer - * \return array backing the encoder's buffer - */ - uint8_t* data() { return buf_.data(); } - - /*! - * \brief returns current size of the encoder's buffer - * \return buffer size - */ - size_t buf_size() const { return buf_.size(); } - - TargetPtr start_addr() const { - CHECK_NE(start_addr_.value().uint64(), 0) << "start addr uninitialized"; - return start_addr_; - } - - void set_start_addr(TargetPtr start_addr) { - CHECK_EQ(buf_.size(), 0) << "cannot change encoder start addr unless empty"; - start_addr_ = - TargetPtr(word_size_, UpperAlignValue(start_addr.value().uint64(), word_size_.bytes())); - } - - void CheckUnfilledAllocs(); - - private: - /*! \brief in-memory backing buffer */ - std::vector buf_; - /*! \brief current offset */ - size_t curr_offset_; - /*! \brief start address of the encoder in device memory */ - TargetPtr start_addr_; - /*! \brief number of bytes available in device memory */ - size_t capacity_; - /*! \brief number of bytes in a word on the target device */ - TargetWordSize word_size_; - /*! \brief Alloc instances allocated now but not yet checked by CheckUnfilledAllocs */ - std::set live_unchecked_allocs_; - /*! \brief start offsets Alloc instances that were dealloated before CheckUnfilledAllocs ran */ - std::vector unchecked_alloc_start_offsets_; - friend Alloc::~Alloc(); -}; - -template -void TargetDataLayoutEncoder::Alloc::WriteArray(const T* arr, size_t num_elems) { - if (num_elems == 0) return; - size_t size = sizeof(T) * num_elems; - CHECK(curr_offset_ + size <= size_) << "not enough space in alloc"; - uint8_t* curr_ptr = &(parent_->data())[start_offset_ + curr_offset_]; - std::memcpy(curr_ptr, arr, size); - curr_offset_ += size; -} - -template -void TargetDataLayoutEncoder::Alloc::WriteValue(const T& val) { - WriteArray(&val, 1); -} - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MICRO_TARGET_DATA_LAYOUT_ENCODER_H_ diff --git a/src/runtime/micro/tcl_socket.cc b/src/runtime/micro/tcl_socket.cc deleted file mode 100644 index 8f482b874260..000000000000 --- a/src/runtime/micro/tcl_socket.cc +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tcl_socket.cc - */ -#include "tcl_socket.h" - -#include - -namespace tvm { -namespace runtime { - -TclSocket::TclSocket() { - tcp_socket_.Create(); - tcp_socket_.SetKeepAlive(true); - reply_buf_.reserve(kReplyBufSize); -} - -TclSocket::~TclSocket() { tcp_socket_.Close(); } - -void TclSocket::Connect(tvm::support::SockAddr addr) { - CHECK(tcp_socket_.Connect(addr)) << "failed to connect"; -} - -void TclSocket::SendCommand() { - const char terminate_token = kCommandTerminateToken; - cmd_builder_ << terminate_token; - std::string full_cmd = cmd_builder_.str(); - - CHECK(tcp_socket_.Send(full_cmd.data(), full_cmd.length()) != -1) << "failed to send command"; - cmd_builder_.str(std::string()); - - reply_builder_.str(std::string()); - char last_read = '\0'; - // Receive from the socket until we reach a command terminator. - do { - ssize_t bytes_read; - // Recieve from the socket until it's drained. - do { - // Leave room at the end of `reply_buf` to tack on a null terminator. - bytes_read = tcp_socket_.Recv(reply_buf_.data(), kReplyBufSize - 1); - reply_buf_[bytes_read] = '\0'; - reply_builder_ << reply_buf_.data(); - // Update last read character. - last_read = reply_buf_[bytes_read - 1]; - } while (bytes_read == kReplyBufSize - 1); - CHECK(bytes_read != -1) << "failed to read command reply"; - } while (last_read != terminate_token); - last_reply_ = reply_builder_.str(); - CHECK_EQ(last_reply_[last_reply_.length() - 1], terminate_token) << "missing command terminator"; -} - -} // namespace runtime -} // namespace tvm diff --git a/src/runtime/micro/tcl_socket.h b/src/runtime/micro/tcl_socket.h deleted file mode 100644 index 4aef2aef36e2..000000000000 --- a/src/runtime/micro/tcl_socket.h +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file tcl_socket.h - * \brief TCP socket wrapper for communicating using Tcl commands - */ -#ifndef TVM_RUNTIME_MICRO_TCL_SOCKET_H_ -#define TVM_RUNTIME_MICRO_TCL_SOCKET_H_ - -#include -#include - -#include "../../support/socket.h" - -namespace tvm { -namespace runtime { - -/*! - * \brief TCP socket wrapper for communicating using Tcl commands - * - * Usage generally involves building a command using the `cmd_builder` stream - * interface, then sending the command with `SendCommand`, and if necessary, - * reading the reply. - */ -class TclSocket { - public: - /*! - * \brief constructor to create the socket - */ - TclSocket(); - - /*! - * \brief destructor to close the socket connection - */ - ~TclSocket(); - - /*! - * \brief open connection with server - * \param addr server address - */ - void Connect(tvm::support::SockAddr addr); - - /* - * \brief send the built command to the server and await a reply - * - * \return the reply - */ - void SendCommand(); - - /* - * \return string stream for current command being built - */ - std::ostringstream& cmd_builder() { return cmd_builder_; } - - /* - * \return reply from most recently sent command - */ - const std::string& last_reply() { return last_reply_; } - - private: - /*! \brief underlying TCP socket being wrapped */ - tvm::support::TCPSocket tcp_socket_; - /*! \brief buffer used to receive messages from the socket */ - std::vector reply_buf_; - /*! \brief string stream used to build current command */ - std::ostringstream cmd_builder_; - /*! \brief string stream used to receive replies from sent commands */ - std::ostringstream reply_builder_; - /*! \brief reply from most recently sent command */ - std::string last_reply_; - - /*! \brief character denoting the end of a Tcl command */ - static const constexpr char kCommandTerminateToken = '\x1a'; - /*! \brief size of the buffer used to receive messages (in bytes) */ - static const constexpr size_t kReplyBufSize = 4096; -}; - -} // namespace runtime -} // namespace tvm -#endif // TVM_RUNTIME_MICRO_TCL_SOCKET_H_ diff --git a/src/runtime/rpc/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h similarity index 86% rename from src/runtime/rpc/minrpc/minrpc_server.h rename to src/runtime/minrpc/minrpc_server.h index 91a900afd900..565f92ad59be 100644 --- a/src/runtime/rpc/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -25,14 +25,15 @@ * \note This file do not depend on c++ std or c std, * and only depends on TVM's C runtime API. */ -#ifndef TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ -#define TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ +#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ +#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ -#include +#define DMLC_LITTLE_ENDIAN true +#include #include -#include "../../../support/arena.h" -#include "../rpc_protocol.h" +#include "../../support/generic_arena.h" +#include "rpc_reference.h" /*! \brief Whether or not to enable glog style DLOG */ #ifndef TVM_MINRPC_ENABLE_LOGGING @@ -59,6 +60,7 @@ namespace runtime { * \tparam TIOHandler IO provider to provide io handling. * An IOHandler needs to provide the following functions: * - PosixWrite, PosixRead, Close: posix style, read, write, close API. + * - MessageStart(num_bytes), MessageDone(): framing APIs. * - Exit: exit with status code. */ template @@ -68,59 +70,63 @@ class MinRPCServer { * \brief Constructor. * \param io The IO handler. */ - explicit MinRPCServer(TIOHandler io) : io_(io), arena_(PageAllocator(io)) {} + explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io)) {} - /*! \brief Run the server loop until shutdown signal is received. */ - void ServerLoop() { + /*! \brief Process a single request. + * + * \return true when the server should continue processing requests. false when it should be + * shutdown. + */ + bool ProcessOnePacket() { RPCCode code; uint64_t packet_len; - while (true) { - arena_.RecycleAll(); - allow_clean_shutdown_ = true; + arena_.RecycleAll(); + allow_clean_shutdown_ = true; - this->Read(&packet_len); - if (packet_len == 0) continue; - this->Read(&code); + this->Read(&packet_len); + if (packet_len == 0) return true; + this->Read(&code); - allow_clean_shutdown_ = false; + allow_clean_shutdown_ = false; - if (code >= RPCCode::kSyscallCodeStart) { - this->HandleSyscallFunc(code); - } else { - switch (code) { - case RPCCode::kCallFunc: { - HandleNormalCallFunc(); - break; - } - case RPCCode::kInitServer: { - HandleInitServer(); - break; - } - case RPCCode::kCopyFromRemote: { - HandleCopyFromRemote(); - break; - } - case RPCCode::kCopyToRemote: { - HandleCopyToRemote(); - break; - } - case RPCCode::kShutdown: { - this->Shutdown(); - return; - } - default: { - this->ThrowError(RPCServerStatus::kUnknownRPCCode); - break; - } + if (code >= RPCCode::kSyscallCodeStart) { + this->HandleSyscallFunc(code); + } else { + switch (code) { + case RPCCode::kCallFunc: { + HandleNormalCallFunc(); + break; + } + case RPCCode::kInitServer: { + HandleInitServer(); + break; + } + case RPCCode::kCopyFromRemote: { + HandleCopyFromRemote(); + break; + } + case RPCCode::kCopyToRemote: { + HandleCopyToRemote(); + break; + } + case RPCCode::kShutdown: { + this->Shutdown(); + return false; + } + default: { + this->ThrowError(RPCServerStatus::kUnknownRPCCode); + break; } } } + + return true; } void Shutdown() { arena_.FreeAll(); - io_.Close(); + io_->Close(); } void HandleNormalCallFunc() { @@ -147,6 +153,9 @@ class MinRPCServer { ret_value[2].v_handle = ret_value[1].v_handle; ret_tcode[2] = kTVMOpaqueHandle; this->ReturnPackedSeq(ret_value, ret_tcode, 3); + } else if (rv_tcode == kTVMBytes) { + ret_tcode[1] = kTVMBytes; + this->ReturnPackedSeq(ret_value, ret_tcode, 2); } else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) { ret_tcode[1] = kTVMOpaqueHandle; this->ReturnPackedSeq(ret_value, ret_tcode, 2); @@ -188,9 +197,11 @@ class MinRPCServer { RPCCode code = RPCCode::kCopyAck; uint64_t packet_nbytes = sizeof(code) + num_bytes; + io_->MessageStart(packet_nbytes); this->Write(packet_nbytes); this->Write(code); this->WriteArray(data_ptr, num_bytes); + io_->MessageDone(); } else { this->ReturnLastTVMError(); } @@ -423,7 +434,7 @@ class MinRPCServer { } void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) { - io_.Exit(static_cast(code)); + io_->Exit(static_cast(code)); } template @@ -456,13 +467,17 @@ class MinRPCServer { return this->WriteRawBytes(data, sizeof(T) * count); } + void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); } + + void MessageDone() { io_->MessageDone(); } + private: // Internal allocator that redirects alloc to TVM's C API. class PageAllocator { public: using ArenaPageHeader = tvm::support::ArenaPageHeader; - explicit PageAllocator(TIOHandler io) : io_(io) {} + explicit PageAllocator(TIOHandler* io) : io_(io) {} ArenaPageHeader* allocate(size_t min_size) { size_t npages = ((min_size + kPageSize - 1) / kPageSize); @@ -470,7 +485,7 @@ class MinRPCServer { if (TVMDeviceAllocDataSpace(DLContext{kDLCPU, 0}, npages * kPageSize, kPageAlign, DLDataType{kDLInt, 1, 1}, &data) != 0) { - io_.Exit(static_cast(RPCServerStatus::kAllocError)); + io_->Exit(static_cast(RPCServerStatus::kAllocError)); } ArenaPageHeader* header = static_cast(data); @@ -481,7 +496,7 @@ class MinRPCServer { void deallocate(ArenaPageHeader* page) { if (TVMDeviceFreeDataSpace(DLContext{kDLCPU, 0}, page) != 0) { - io_.Exit(static_cast(RPCServerStatus::kAllocError)); + io_->Exit(static_cast(RPCServerStatus::kAllocError)); } } @@ -489,7 +504,7 @@ class MinRPCServer { static const constexpr int kPageAlign = 8; private: - TIOHandler io_; + TIOHandler* io_; }; void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { @@ -503,10 +518,12 @@ class MinRPCServer { uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); + io_->MessageStart(packet_nbytes); this->Write(packet_nbytes); this->Write(code); this->Write(num_args); this->Write(tcode); + io_->MessageDone(); } void ReturnHandle(void* handle) { @@ -514,15 +531,16 @@ class MinRPCServer { int32_t tcode = kTVMOpaqueHandle; RPCCode code = RPCCode::kReturn; uint64_t encode_handle = reinterpret_cast(handle); - uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle); + io_->MessageStart(packet_nbytes); this->Write(packet_nbytes); this->Write(code); this->Write(num_args); this->Write(tcode); this->Write(encode_handle); + io_->MessageDone(); } void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); } @@ -537,11 +555,11 @@ class MinRPCServer { uint8_t* buf = reinterpret_cast(data); size_t ndone = 0; while (ndone < size) { - ssize_t ret = io_.PosixRead(buf, size - ndone); + ssize_t ret = io_->PosixRead(buf, size - ndone); if (ret == 0) { if (allow_clean_shutdown_) { this->Shutdown(); - io_.Exit(0); + io_->Exit(0); } else { this->ThrowError(RPCServerStatus::kReadError); } @@ -558,7 +576,7 @@ class MinRPCServer { const uint8_t* buf = reinterpret_cast(data); size_t ndone = 0; while (ndone < size) { - ssize_t ret = io_.PosixWrite(buf, size - ndone); + ssize_t ret = io_->PosixWrite(buf, size - ndone); if (ret == 0 || ret == -1) { this->ThrowError(RPCServerStatus::kWriteError); } @@ -568,7 +586,7 @@ class MinRPCServer { } /*! \brief IO handler. */ - TIOHandler io_; + TIOHandler* io_; /*! \brief internal arena. */ support::GenericArena arena_; /*! \brief Whether we are in a state that allows clean shutdown. */ @@ -578,4 +596,4 @@ class MinRPCServer { } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_RPC_MINRPC_MINRPC_SERVER_H_ +#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_ diff --git a/src/runtime/rpc/minrpc/posix_popen_server.cc b/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc similarity index 89% rename from src/runtime/rpc/minrpc/posix_popen_server.cc rename to src/runtime/minrpc/posix_popen_server/posix_popen_server.cc index 9784780fea18..b513d4b7cc1b 100644 --- a/src/runtime/rpc/minrpc/posix_popen_server.cc +++ b/src/runtime/minrpc/posix_popen_server/posix_popen_server.cc @@ -37,6 +37,10 @@ class PosixIOHandler { explicit PosixIOHandler(int read_fd = 0, int write_fd = 1) : read_fd_(read_fd), write_fd_(write_fd) {} + void MessageStart(uint64_t packet_nbytes) {} + + void MessageDone() {} + ssize_t PosixRead(void* data, size_t size) { return read(read_fd_, data, size); } ssize_t PosixWrite(const void* data, size_t size) { return write(write_fd_, data, size); } @@ -63,7 +67,11 @@ int main(int argc, char* argv[]) { if (argc != 3) return -1; // pass the descriptor via arguments. tvm::runtime::PosixIOHandler handler(atoi(argv[1]), atoi(argv[2])); - tvm::runtime::PosixMinRPCServer server(handler); - server.ServerLoop(); + tvm::runtime::PosixMinRPCServer server(&handler); + bool is_running = true; + while (is_running) { + is_running = server.ProcessOnePacket(); + } + return 0; } diff --git a/src/runtime/rpc/rpc_protocol.h b/src/runtime/minrpc/rpc_reference.h similarity index 90% rename from src/runtime/rpc/rpc_protocol.h rename to src/runtime/minrpc/rpc_reference.h index 3a0555d0cc6d..e195b9ca9e89 100644 --- a/src/runtime/rpc/rpc_protocol.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -18,11 +18,11 @@ */ /*! - * \file rpc_procotol.h + * \file rpc_reference.h * \brief Common header defining the communication code used in the RPC protocol. */ -#ifndef TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ -#define TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ +#ifndef TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ +#define TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ namespace tvm { namespace runtime { @@ -72,6 +72,46 @@ enum class RPCServerStatus : int { kAllocError }; +inline const char* RPCCodeToString(RPCCode code) { + switch (code) { + case RPCCode::kShutdown: + return "kShutdown"; + case RPCCode::kInitServer: + return "kInitServer"; + case RPCCode::kCallFunc: + return "kCallFunc"; + case RPCCode::kReturn: + return "kReturn"; + case RPCCode::kException: + return "kException"; + case RPCCode::kCopyFromRemote: + return "kCopyFromRemote"; + case RPCCode::kCopyToRemote: + return "kCopyToRemote"; + case RPCCode::kCopyAck: + return "kCopyAck"; + // The following are syscall code that can send over CallRemote + case RPCCode::kGetGlobalFunc: + return "kGetGlobalFunc"; + case RPCCode::kFreeHandle: + return "kFreeHandle"; + case RPCCode::kDevSetDevice: + return "kDevSetDevice"; + case RPCCode::kDevGetAttr: + return "kDevGetAttr"; + case RPCCode::kDevAllocData: + return "kDevAllocData"; + case RPCCode::kDevFreeData: + return "kDevFreeData"; + case RPCCode::kDevStreamSync: + return "kDevStreamSync"; + case RPCCode::kCopyAmongRemote: + return "kCopyAmongRemote"; + default: + return ""; + } +} + /*! * \brief Convert RPC server status to string. * \param status The status. @@ -421,12 +461,14 @@ struct RPCReference { uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(len) + len; + channel->MessageStart(packet_nbytes); channel->Write(packet_nbytes); channel->Write(code); channel->Write(num_args); channel->Write(tcode); channel->Write(len); channel->WriteArray(msg, len); + channel->MessageDone(); } /*! @@ -444,9 +486,11 @@ struct RPCReference { uint64_t packet_nbytes = sizeof(code) + PackedSeqGetNumBytes(arg_values, type_codes, num_args, false, channel); + channel->MessageStart(packet_nbytes); channel->Write(packet_nbytes); channel->Write(code); SendPackedSeq(arg_values, type_codes, num_args, false, channel); + channel->MessageDone(); } /*! @@ -463,13 +507,16 @@ struct RPCReference { uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode); + channel->MessageStart(packet_nbytes); channel->Write(packet_nbytes); channel->Write(code); channel->Write(num_args); channel->Write(tcode); + channel->MessageDone(); } }; } // namespace runtime } // namespace tvm -#endif // TVM_RUNTIME_RPC_RPC_PROTOCOL_H_ + +#endif // TVM_RUNTIME_MINRPC_RPC_REFERENCE_H_ diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index ca8c3260c156..2deae07b0315 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -204,6 +204,10 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { using Stream::Write; using Stream::WriteArray; + void MessageStart(uint64_t packet_nbytes) { + // Unused here, implemented for uTVM framing layer. + } + bool Read(RPCCode* code) { int32_t cdata; if (!this->Read(&cdata)) return false; @@ -215,6 +219,10 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Write(cdata); } + void MessageDone() { + // Unused here, implemented for uTVM framing layer. + } + template T* ArenaAlloc(int count) { static_assert(std::is_pod::value, "need to be trival"); diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h index 2b88cee15c01..031435fc8ef9 100644 --- a/src/runtime/rpc/rpc_endpoint.h +++ b/src/runtime/rpc/rpc_endpoint.h @@ -32,8 +32,8 @@ #include #include "../../support/ring_buffer.h" +#include "../minrpc/rpc_reference.h" #include "rpc_channel.h" -#include "rpc_protocol.h" #include "rpc_session.h" namespace tvm { diff --git a/src/runtime/rpc/rpc_session.h b/src/runtime/rpc/rpc_session.h index 954c5b4ead22..4ea937acc6ef 100644 --- a/src/runtime/rpc/rpc_session.h +++ b/src/runtime/rpc/rpc_session.h @@ -31,7 +31,7 @@ #include #include -#include "rpc_protocol.h" +#include "../minrpc/rpc_reference.h" namespace tvm { namespace runtime { diff --git a/src/support/arena.h b/src/support/arena.h index cb08db93641d..832f828778b7 100644 --- a/src/support/arena.h +++ b/src/support/arena.h @@ -26,31 +26,15 @@ #ifndef TVM_SUPPORT_ARENA_H_ #define TVM_SUPPORT_ARENA_H_ -#ifndef TVM_ARENA_HAS_DESTRUCTOR -#define TVM_ARENA_HAS_DESTRUCTOR 1 -#endif - #include #include #include +#include "generic_arena.h" + namespace tvm { namespace support { -/*! - * \brief An arena page header. - */ -struct ArenaPageHeader { - /*! \brief points to the next page. */ - ArenaPageHeader* next; - /*! - * \brief Total size of the page. - */ - size_t size; - /*! \brief memory allocator offset inside page. */ - size_t offset; -}; - /*! * \brief Simple page allocator that uses new and delete. */ @@ -84,124 +68,6 @@ class SimplePageAllocator { using Page = std::aligned_storage::type; }; -/*! - * \brief Arena allocator that allocates memory from continuous - * chunk and frees them all only during destruction. - */ -template -class GenericArena { - public: - explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) { - // eagerly allocate the first page. - head_ = tail_ = alloc_.allocate(1); - head_->next = nullptr; - } - -#if TVM_ARENA_HAS_DESTRUCTOR - ~GenericArena() { this->FreeAll(); } -#endif - - /*! \brief Free all pages. */ - void FreeAll() { - FreePageList(&head_); - FreePageList(&free_list_); - } - /*! \brief Recycle all the pages in the arena */ - void RecycleAll() { - // put all the current list to the free list. - tail_->next = free_list_; - // allocate the first in the free list to head - free_list_ = head_->next; - head_->next = nullptr; - // Reset the head. - head_->offset = sizeof(ArenaPageHeader); - tail_ = head_; - } - /*! - * \brief Allocate a space from Arena for type T - * \param T the data type to be allocated - * \param count Numberof elements - * \note The space of T is not initialized. - */ - template - T* allocate_(int count = 1) { - static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment"); - return static_cast(Alloc(sizeof(T) * count, alignof(T))); - } - /*! - * \brief Create a new instance of type T. - * \param args The constructor argument. - * \tparam T the type to be created. - * \tparam Args Arguments to the constructor. - * - * \return The allocated object. - * \note The type T must be simple type, or only contain - * memory allocated from the same arena. - * Otherwise the destructor needs to be called explicitly. - */ - template - T* make(Args&&... args) { - T* ptr = allocate_(); - new (ptr) T(std::forward(args)...); - return ptr; - } - - private: - /*! \brief internal page allocator. */ - PageAllocator alloc_; - /* \brief The the head of the allocated list. */ - ArenaPageHeader* head_{nullptr}; - /*! \brief The tail of the allocated list. */ - ArenaPageHeader* tail_{nullptr}; - /* \brief List of free pages. */ - ArenaPageHeader* free_list_{nullptr}; - /*! - * \brief Align ptr by upper bound. - * \param offset The offset value. - * \param align The alignment requirement. - */ - size_t UpperAlign(size_t offset, size_t align) { - return offset + (align - (offset % align)) % align; - } - /*! - * \brief Internal aligned alloc function. - * \param size The size of the memory. - * \param align The alignment requirement. - */ - void* Alloc(size_t size, size_t align) { - size_t offset = UpperAlign(head_->offset, align); - if (offset + size <= head_->size) { - head_->offset = offset + size; - return reinterpret_cast(head_) + offset; - } else { - ArenaPageHeader* new_head; - offset = UpperAlign(sizeof(ArenaPageHeader), align); - if (free_list_ != nullptr && offset + size <= free_list_->size) { - new_head = free_list_; - free_list_ = free_list_->next; - } else { - new_head = alloc_.allocate(offset + size); - } - new_head->next = head_; - new_head->offset = offset + size; - head_ = new_head; - return reinterpret_cast(head_) + offset; - } - } - /*! - * \brief Free all the pages in the list. - * \param ptr The head ptr. - */ - void FreePageList(ArenaPageHeader** ptr) { - // delete all the allocated pages. - while (ptr[0] != nullptr) { - ArenaPageHeader* temp = ptr[0]; - ptr[0] = ptr[0]->next; - alloc_.deallocate(temp); - } - } -}; - using Arena = GenericArena; /*! diff --git a/src/support/generic_arena.h b/src/support/generic_arena.h new file mode 100644 index 000000000000..46915431595c --- /dev/null +++ b/src/support/generic_arena.h @@ -0,0 +1,183 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file arena.h + * \brief Arena allocator that allocates memory chunks and frees them all during destruction time. + * + * NOTE: This file is portable to bare-metal embedded devices. Don't use operator new (without + * placement parameters) or malloc. + */ +#ifndef TVM_SUPPORT_GENERIC_ARENA_H_ +#define TVM_SUPPORT_GENERIC_ARENA_H_ + +#ifndef TVM_ARENA_HAS_DESTRUCTOR +#define TVM_ARENA_HAS_DESTRUCTOR 1 +#endif + +#include + +#include + +namespace tvm { +namespace support { + +namespace { +template // For lvalues (T is T&), +T&& forward(T&& param) { // take/return lvalue refs. + return static_cast(param); // For rvalues (T is T), +} // take/return rvalue refs. +} // namespace + +/*! + * \brief An arena page header. + */ +struct ArenaPageHeader { + /*! \brief points to the next page. */ + ArenaPageHeader* next; + /*! + * \brief Total size of the page. + */ + size_t size; + /*! \brief memory allocator offset inside page. */ + size_t offset; +}; + +/*! + * \brief Arena allocator that allocates memory from continuous + * chunk and frees them all only during destruction. + */ +template +class GenericArena { + public: + explicit GenericArena(PageAllocator alloc = PageAllocator()) : alloc_(alloc) { + // eagerly allocate the first page. + head_ = tail_ = alloc_.allocate(1); + head_->next = nullptr; + } + +#if TVM_ARENA_HAS_DESTRUCTOR + ~GenericArena() { this->FreeAll(); } +#endif + + /*! \brief Free all pages. */ + void FreeAll() { + FreePageList(&head_); + FreePageList(&free_list_); + } + /*! \brief Recycle all the pages in the arena */ + void RecycleAll() { + // put all the current list to the free list. + tail_->next = free_list_; + // allocate the first in the free list to head + free_list_ = head_->next; + head_->next = nullptr; + // Reset the head. + head_->offset = sizeof(ArenaPageHeader); + tail_ = head_; + } + /*! + * \brief Allocate a space from Arena for type T + * \param T the data type to be allocated + * \param count Numberof elements + * \note The space of T is not initialized. + */ + template + T* allocate_(int count = 1) { + static_assert(PageAllocator::kPageAlign % alignof(T) == 0, "To large alignment"); + return static_cast(Alloc(sizeof(T) * count, alignof(T))); + } + /*! + * \brief Create a new instance of type T. + * \param args The constructor argument. + * \tparam T the type to be created. + * \tparam Args Arguments to the constructor. + * + * \return The allocated object. + * \note The type T must be simple type, or only contain + * memory allocated from the same arena. + * Otherwise the destructor needs to be called explicitly. + */ + template + T* make(Args&&... args) { + T* ptr = allocate_(); + new (ptr) T(forward(args)...); + return ptr; + } + + private: + /*! \brief internal page allocator. */ + PageAllocator alloc_; + /* \brief The the head of the allocated list. */ + ArenaPageHeader* head_{nullptr}; + /*! \brief The tail of the allocated list. */ + ArenaPageHeader* tail_{nullptr}; + /* \brief List of free pages. */ + ArenaPageHeader* free_list_{nullptr}; + /*! + * \brief Align ptr by upper bound. + * \param offset The offset value. + * \param align The alignment requirement. + */ + size_t UpperAlign(size_t offset, size_t align) { + return offset + (align - (offset % align)) % align; + } + /*! + * \brief Internal aligned alloc function. + * \param size The size of the memory. + * \param align The alignment requirement. + */ + void* Alloc(size_t size, size_t align) { + size_t offset = UpperAlign(head_->offset, align); + if (offset + size <= head_->size) { + head_->offset = offset + size; + return reinterpret_cast(head_) + offset; + } else { + ArenaPageHeader* new_head; + offset = UpperAlign(sizeof(ArenaPageHeader), align); + if (free_list_ != nullptr && offset + size <= free_list_->size) { + new_head = free_list_; + free_list_ = free_list_->next; + } else { + new_head = alloc_.allocate(offset + size); + } + new_head->next = head_; + new_head->offset = offset + size; + head_ = new_head; + return reinterpret_cast(head_) + offset; + } + } + /*! + * \brief Free all the pages in the list. + * \param ptr The head ptr. + */ + void FreePageList(ArenaPageHeader** ptr) { + // delete all the allocated pages. + while (ptr[0] != nullptr) { + ArenaPageHeader* temp = ptr[0]; + ptr[0] = ptr[0]->next; + alloc_.deallocate(temp); + } + } +}; + +} // namespace support +} // namespace tvm +#endif // TVM_SUPPORT_GENERIC_ARENA_H_ diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 8cb678614a68..b5d2bf7ceb85 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -218,6 +218,7 @@ TVM_REGISTER_TARGET_KIND("llvm", kDLCPU) TVM_REGISTER_TARGET_KIND("c", kDLCPU) .add_attr_option("system-lib") .add_attr_option("runtime") + .add_attr_option("mcpu") .set_default_keys({"cpu"}); TVM_REGISTER_TARGET_KIND("cuda", kDLGPU) diff --git a/tests/cpp/utvm_runtime_standalone_test.cc b/tests/cpp/utvm_runtime_standalone_test.cc index 39449ee215f2..e55431fe2413 100644 --- a/tests/cpp/utvm_runtime_standalone_test.cc +++ b/tests/cpp/utvm_runtime_standalone_test.cc @@ -45,12 +45,12 @@ #include TVM_REGISTER_GLOBAL("test.sch").set_body([](tvm::TVMArgs args, tvm::TVMRetValue* rv) { - *rv = topi::generic::schedule_injective(args[0], args[1]); + *rv = ::tvm::topi::generic::schedule_injective(args[0], args[1]); }); TEST(MicroStandaloneRuntime, BuildModule) { using namespace tvm; - auto tensor_type = relay::TensorType({2, 3}, ::tvm::Float(32)); + auto tensor_type = relay::TensorType({2, 3}, ::tvm::runtime::DataType::Float(32)); auto a = relay::Var("a", tensor_type); auto b = relay::Var("b", tensor_type); auto add_op = relay::Op::Get("add"); diff --git a/tests/crt/buffer_write_stream.h b/tests/crt/buffer_write_stream.h new file mode 100644 index 000000000000..66ef044e6ba1 --- /dev/null +++ b/tests/crt/buffer_write_stream.h @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TESTS_CRT_BUFFER_WRITE_STREAM_H_ +#define TESTS_CRT_BUFFER_WRITE_STREAM_H_ + +#include +#include +#include + +using ::tvm::runtime::micro_rpc::FrameBuffer; +using ::tvm::runtime::micro_rpc::WriteStream; + +template +class BufferWriteStream : public WriteStream { + public: + ssize_t Write(const uint8_t* data, size_t data_size_bytes) override { + return buffer_.Write(data, data_size_bytes); + } + + void Reset() { + buffer_.Clear(); + packet_done_ = false; + } + + inline bool packet_done() { return packet_done_; } + + inline bool is_valid() { return is_valid_; } + + void PacketDone(bool is_valid) override { + EXPECT_FALSE(packet_done_); + packet_done_ = true; + is_valid_ = is_valid; + } + + std::string BufferContents() { return std::string((const char*)buffer_data_, buffer_.Size()); } + + static constexpr unsigned int capacity() { return N; }; + + private: + bool packet_done_{false}; + bool is_valid_{false}; + uint8_t buffer_data_[N]; + FrameBuffer buffer_{buffer_data_, N}; +}; + +#endif // TESTS_CRT_BUFFER_WRITE_STREAM_H_ diff --git a/tests/crt/framing_test.cc b/tests/crt/framing_test.cc new file mode 100644 index 000000000000..ed3587b497cc --- /dev/null +++ b/tests/crt/framing_test.cc @@ -0,0 +1,317 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include + +#include "buffer_write_stream.h" +#include "crt_config.h" +#include "platform.cc" + +using ::tvm::runtime::micro_rpc::Escape; +using ::tvm::runtime::micro_rpc::FrameBuffer; +using ::tvm::runtime::micro_rpc::Framer; +using ::tvm::runtime::micro_rpc::Unframer; + +class FramerTest : public ::testing::Test { + protected: + BufferWriteStream<300> write_stream_; + Framer framer_{&write_stream_}; +}; + +class TestPacket { + public: + static std::vector instances; + + // NOTE: take payload and wire as arrays to avoid clipping at \0 + template + TestPacket(const std::string name, const char (&payload)[N], const char (&wire)[M]) + : name{name}, payload{payload, N - 1}, wire{wire, M - 1} { // omit trailing \0 + instances.emplace_back(this); + } + + inline const uint8_t* payload_data() const { + return reinterpret_cast(payload.data()); + } + + inline const uint8_t* wire_data() const { return reinterpret_cast(wire.data()); } + + std::string name; + std::string payload; + std::string wire; +}; + +void PrintTo(const TestPacket* p, std::ostream* os) { + *os << "TestPacket(\"" << p->name << "\", ...)"; +} + +void PrintTo(tvm_crt_error_t p, std::ostream* os) { + std::ios_base::fmtflags f(os->flags()); + *os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")"; + os->flags(f); +} + +std::vector TestPacket::instances; + +#define TEST_PACKET(name, payload, wire) \ + static const TestPacket k##name { #name, payload, wire } + +// NOTE: golden packet CRCs are generated with this python: +// import binascii +// import struct +// struct.pack('t@\"hr", + "\xff\xfd\x13\0\0\0es_\xff\xff_capeir/^>t@\"hr\xb4\xff\xff"); + +TEST_F(FramerTest, ValidPacketTrain) { + EXPECT_EQ(kTvmErrorNoError, framer_.Write(kPacket1.payload_data(), kPacket1.payload.size())); + EXPECT_EQ(kTvmErrorNoError, framer_.Write(kPacket2.payload_data(), kPacket2.payload.size())); + framer_.Reset(); + EXPECT_EQ(kTvmErrorNoError, framer_.Write(kPacket3.payload_data(), kPacket3.payload.size())); + + EXPECT_EQ("\xfe" + kPacket1.wire + // packet1 plus nop prefix. + kPacket2.wire + // packet2, no prefix. + "\xfe" + kPacket3.wire, // packet3 plus nop prefix. + write_stream_.BufferContents()); +} + +TEST_F(FramerTest, ZeroLengthPacket) { + EXPECT_EQ(kTvmErrorNoError, + framer_.Write(kZeroLengthPacket.payload_data(), kZeroLengthPacket.payload.size())); + EXPECT_EQ("\xfe" + kZeroLengthPacket.wire, write_stream_.BufferContents()); +} + +TEST_F(FramerTest, Escapes) { + EXPECT_EQ(kTvmErrorNoError, + framer_.Write(kEscapePacket.payload_data(), kEscapePacket.payload.size())); + EXPECT_EQ("\xfe" + kEscapePacket.wire, write_stream_.BufferContents()); +} + +class UnframerTest : public ::testing::Test { + protected: + BufferWriteStream<300> write_stream_; + Unframer unframer_{&write_stream_}; +}; + +TEST_F(UnframerTest, PacketTooLong) { + const uint8_t escape[2] = {uint8_t(Escape::kEscapeStart), uint8_t(Escape::kPacketStart)}; + uint16_t crc = crc16_compute(escape, sizeof(escape), nullptr); + size_t bytes_consumed; + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(escape, sizeof(escape), &bytes_consumed)); + EXPECT_EQ(sizeof(escape), bytes_consumed); + + uint32_t packet_length = write_stream_.capacity() + 1; + uint8_t* packet_length_bytes = reinterpret_cast(&packet_length); + for (size_t i = 0; i < sizeof(packet_length); i++) { + ASSERT_NE('\xff', packet_length_bytes[i]); + } + crc = crc16_compute(packet_length_bytes, sizeof(packet_length), &crc); + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(packet_length_bytes, sizeof(packet_length), &bytes_consumed)); + EXPECT_EQ(sizeof(packet_length), bytes_consumed); + + uint8_t long_payload[decltype(write_stream_)::capacity() + 1]; + for (size_t i = 0; i < sizeof(long_payload); i++) { + long_payload[i] = i & 0xff; + if (long_payload[i] == uint8_t(Escape::kEscapeStart)) { + long_payload[i] = 0; + } + } + crc = crc16_compute(long_payload, sizeof(long_payload), &crc); + EXPECT_EQ(kTvmErrorWriteStreamShortWrite, + unframer_.Write(long_payload, sizeof(long_payload), &bytes_consumed)); + EXPECT_EQ(write_stream_.capacity(), bytes_consumed); + + EXPECT_EQ(kTvmErrorNoError, unframer_.Write((uint8_t*)&crc, sizeof(crc), &bytes_consumed)); + EXPECT_EQ(2, bytes_consumed); // 2, because framer is now in kFindPacketStart. + EXPECT_FALSE(write_stream_.packet_done()); + EXPECT_FALSE(write_stream_.is_valid()); + EXPECT_EQ(std::string((char*)long_payload, write_stream_.capacity()), + write_stream_.BufferContents()); + + // Writing a smaller packet directly afterward should work. + write_stream_.Reset(); + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(kPacket1.wire_data(), kPacket1.wire.size(), &bytes_consumed)); + EXPECT_EQ(kPacket1.wire.size(), bytes_consumed); + EXPECT_TRUE(write_stream_.packet_done()); + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(kPacket1.payload, write_stream_.BufferContents()); +}; + +class UnframerTestParameterized : public UnframerTest, + public ::testing::WithParamInterface {}; + +TEST_P(UnframerTestParameterized, TestFullPacket) { + size_t bytes_consumed; + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(GetParam()->wire_data(), GetParam()->wire.size(), &bytes_consumed)); + EXPECT_EQ(GetParam()->wire.size(), bytes_consumed); + EXPECT_TRUE(write_stream_.packet_done()); + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents()); +} + +TEST_P(UnframerTestParameterized, TestByteAtATime) { + size_t bytes_consumed; + size_t wire_size = GetParam()->wire.size(); + for (size_t i = 0; i < wire_size; i++) { + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(reinterpret_cast(&GetParam()->wire[i]), 1, + &bytes_consumed)); + EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(i == wire_size - 1, write_stream_.packet_done()); + } + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents()); +} + +TEST_P(UnframerTestParameterized, TestArbitraryBoundary) { + size_t bytes_consumed; + size_t wire_size = GetParam()->wire.size(); + for (size_t i = 1; i < wire_size; i++) { + unframer_.Reset(); + write_stream_.Reset(); + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), i, &bytes_consumed)); + EXPECT_EQ(i, bytes_consumed); + EXPECT_FALSE(write_stream_.packet_done()); + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(&GetParam()->wire_data()[i], wire_size - i, &bytes_consumed)); + EXPECT_EQ(wire_size - i, bytes_consumed); + EXPECT_TRUE(write_stream_.packet_done()); + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents()); + } +} + +TEST_P(UnframerTestParameterized, TestArbitraryPacketReset) { + size_t bytes_consumed; + size_t wire_size = GetParam()->wire.size(); + + // This test interrupts packet transmission at an arbitrary point in the packet and restarts from + // the beginning. It simulates handling a device reset in the protocol. The behavior of the framer + // depends on how much of the packet had been transmitted, so the test is split into parts: + + // Part 1. Restarting during the initial escape sequence. + unframer_.Reset(); + write_stream_.Reset(); + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed)); + EXPECT_EQ(1, bytes_consumed); + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed)); + EXPECT_EQ(wire_size, bytes_consumed); + EXPECT_TRUE(write_stream_.packet_done()); + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents()); + + // Part 2. Restarting after the initial escape sequence. + for (size_t i = 2; i < wire_size; i++) { + unframer_.Reset(); + write_stream_.Reset(); + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), i, &bytes_consumed)); + EXPECT_EQ(i, bytes_consumed); + + // First test byte-by-byte interruption. + // Interrupt the packet transmission. The first byte will return no error as it is the escape + // byte. + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), 1, &bytes_consumed)); + EXPECT_EQ(1, bytes_consumed); + EXPECT_FALSE(write_stream_.packet_done()); + + // Secondt byte will return a short packet error. + EXPECT_EQ(kTvmErrorFramingShortPacket, + unframer_.Write(&GetParam()->wire_data()[1], 1, &bytes_consumed)); + EXPECT_EQ(0, bytes_consumed); + EXPECT_FALSE(write_stream_.packet_done()); + + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(&GetParam()->wire_data()[1], wire_size - 1, &bytes_consumed)); + EXPECT_EQ(wire_size - 1, bytes_consumed); + EXPECT_TRUE(write_stream_.packet_done()); + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents()); + + // Next, test interruption just by sending the whole payload at once. + unframer_.Reset(); + write_stream_.Reset(); + EXPECT_EQ(kTvmErrorNoError, unframer_.Write(GetParam()->wire_data(), i, &bytes_consumed)); + EXPECT_EQ(i, bytes_consumed); + + // Interrupt the packet transmission. The first Write() call will just consume 1 byte to reset + // the internal state. + EXPECT_EQ(kTvmErrorFramingShortPacket, + unframer_.Write(GetParam()->wire_data(), wire_size, &bytes_consumed)); + EXPECT_EQ(1, bytes_consumed); + EXPECT_FALSE(write_stream_.packet_done()); + EXPECT_EQ(kTvmErrorNoError, + unframer_.Write(&GetParam()->wire_data()[1], wire_size - 1, &bytes_consumed)); + EXPECT_EQ(wire_size - 1, bytes_consumed); + EXPECT_TRUE(write_stream_.packet_done()); + EXPECT_TRUE(write_stream_.is_valid()); + EXPECT_EQ(GetParam()->payload, write_stream_.BufferContents()); + + break; + } +} + +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Wdeprecated-declarations" +INSTANTIATE_TEST_CASE_P(UnframerTests, UnframerTestParameterized, + ::testing::ValuesIn(TestPacket::instances)); +#pragma GCC diagnostic pop + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/crt/func_registry_test.cc b/tests/crt/func_registry_test.cc index 2eca2a3dcd6b..2889f7b899a7 100644 --- a/tests/crt/func_registry_test.cc +++ b/tests/crt/func_registry_test.cc @@ -22,6 +22,8 @@ #include #include +#include "platform.cc" + typedef struct { const char* a; const char* b; diff --git a/tests/crt/memory_test.cc b/tests/crt/memory_test.cc index 3b1f7fa560fe..af597f970810 100644 --- a/tests/crt/memory_test.cc +++ b/tests/crt/memory_test.cc @@ -22,6 +22,7 @@ #include #include "crt_config.h" +#include "platform.cc" #define ROUND_UP(qty, modulo) (((qty) + ((modulo)-1)) / (modulo) * (modulo)) @@ -119,10 +120,6 @@ TEST_F(MemoryManagerTest, Realloc) { EXPECT_EQ(vleak_size, 0); } -extern "C" { -void TVMPlatformAbort(int error_code) { FAIL() << "TVMPlatformAbort(" << error_code << ")"; } -} - int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/src/runtime/micro/device/host/utvm_init.c b/tests/crt/platform.cc similarity index 52% rename from src/runtime/micro/device/host/utvm_init.c rename to tests/crt/platform.cc index 4fb43c11d20e..3d8906d0605f 100644 --- a/src/runtime/micro/device/host/utvm_init.c +++ b/tests/crt/platform.cc @@ -17,22 +17,31 @@ * under the License. */ -/*! - * \file utvm_init.c - * \brief uTVM init definition for the host emulated device - */ +#include +#include +#include +#include -#ifdef __cplusplus extern "C" { -#endif - -#include "utvm_runtime.h" - -void UTVMInit() { - // no init required for the host - UTVMMain(); +void InternalTVMPlatformAbort(tvm_crt_error_t error_code) { + FAIL() << "TVMPlatformAbort(" << error_code << ")"; } +void TVMPlatformAbort(tvm_crt_error_t error_code) { + InternalTVMPlatformAbort(error_code); + exit(2); // for __attribute__((noreturn)) +} +void* TVMSystemLibEntryPoint() { return NULL; } +void TVMLogf(const char* fmt, ...) { + va_list args; + char log_buf[1024]; + va_start(args, fmt); + int ret = vsnprintf(log_buf, sizeof(log_buf), fmt, args); + va_end(args); -#ifdef __cplusplus -} // TVM_EXTERN_C -#endif + if (ret < 0) { + LOG(ERROR) << "TVMLogf: error formatting: " << fmt; + } else { + LOG(INFO) << "TVMLogf: " << std::string(log_buf, ret); + } +} +} diff --git a/tests/crt/session_test.cc b/tests/crt/session_test.cc new file mode 100644 index 000000000000..a1d57fcb5436 --- /dev/null +++ b/tests/crt/session_test.cc @@ -0,0 +1,265 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include + +#include +#include + +#include "buffer_write_stream.h" +#include "crt_config.h" +#include "platform.cc" + +using ::tvm::runtime::micro_rpc::Framer; +using ::tvm::runtime::micro_rpc::MessageType; +using ::tvm::runtime::micro_rpc::Session; +using ::tvm::runtime::micro_rpc::Unframer; + +extern "C" { +void TestSessionMessageReceivedThunk(void* context, MessageType message_type, FrameBuffer* buf); +} + +class ReceivedMessage { + public: + ReceivedMessage(MessageType type, std::string message) : type{type}, message{message} {} + + bool operator==(const ReceivedMessage& other) const { + return other.type == type && other.message == message; + } + + MessageType type; + std::string message; +}; + +class TestSession { + public: + TestSession(uint8_t initial_nonce) + : framer{&framer_write_stream}, + receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)}, + sess{initial_nonce, &framer, &receive_buffer, TestSessionMessageReceivedThunk, this}, + unframer{sess.Receiver()} {} + + void WriteTo(TestSession* other) { + auto framer_buffer = framer_write_stream.BufferContents(); + size_t bytes_to_write = framer_buffer.size(); + const uint8_t* write_cursor = reinterpret_cast(framer_buffer.data()); + while (bytes_to_write > 0) { + size_t bytes_consumed; + auto to_return = other->unframer.Write(write_cursor, bytes_to_write, &bytes_consumed); + EXPECT_EQ(to_return, kTvmErrorNoError); + bytes_to_write -= bytes_consumed; + write_cursor += bytes_consumed; + } + } + + void ClearBuffers() { + framer_write_stream.Reset(); + messages_received.clear(); + sess.ClearReceiveBuffer(); + } + + std::vector messages_received; + BufferWriteStream<300> framer_write_stream; + Framer framer; + uint8_t receive_buffer_array[300]; + FrameBuffer receive_buffer; + Session sess; + Unframer unframer; +}; + +#define EXPECT_FRAMED_PACKET(session, expected) \ + EXPECT_EQ(std::string(expected, sizeof(expected) - 1), \ + (session).framer_write_stream.BufferContents()); + +extern "C" { +void TestSessionMessageReceivedThunk(void* context, MessageType message_type, FrameBuffer* buf) { + std::string message; + if (message_type != MessageType::kStartSessionReply) { + uint8_t message_buf[300]; + EXPECT_LE(buf->ReadAvailable(), sizeof(message_buf)); + size_t message_size_bytes = buf->Read(message_buf, sizeof(message_buf)); + message = std::string(reinterpret_cast(message_buf), message_size_bytes); + } + + static_cast(context)->messages_received.emplace_back( + ReceivedMessage(message_type, message)); +} +} + +void PrintTo(tvm_crt_error_t p, std::ostream* os) { + std::ios_base::fmtflags f(os->flags()); + *os << "tvm_crt_error_t(0x" << std::hex << std::setw(8) << std::setfill('0') << p << ")"; + os->flags(f); +} + +void PrintTo(ReceivedMessage msg, std::ostream* os) { + *os << "ReceivedMessage(" << int(msg.type) << ", \"" << msg.message << "\")"; +} + +class SessionTest : public ::testing::Test { + public: + static constexpr const uint8_t kAliceNonce = 0x3c; + static constexpr const uint8_t kBobNonce = 0xab; + + TestSession alice_{kAliceNonce}; + TestSession bob_{kBobNonce}; +}; + +TEST_F(SessionTest, NormalExchange) { + tvm_crt_error_t err; + err = alice_.sess.Initialize(); + EXPECT_EQ(kTvmErrorNoError, err); + EXPECT_FRAMED_PACKET(alice_, + "\xfe\xff\xfd\x03\0\0\0\0\0\x02" + "fw"); + alice_.WriteTo(&bob_); + + err = bob_.sess.Initialize(); + EXPECT_EQ(kTvmErrorNoError, err); + EXPECT_FRAMED_PACKET(bob_, + "\xfe\xff\xfd\x03\0\0\0\0\0\x02" + "fw"); + alice_.WriteTo(&alice_); + + bob_.ClearBuffers(); + alice_.ClearBuffers(); + + err = alice_.sess.StartSession(); + EXPECT_EQ(err, kTvmErrorNoError); + EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\x04\0\0\0\x82\0\0\x01{\xE9"); + + bob_.ClearBuffers(); + alice_.WriteTo(&bob_); + EXPECT_FRAMED_PACKET(bob_, + "\xff\xfd\x4\0\0\0\x82" + "f\x01\x01\x81\xf3"); + EXPECT_TRUE(bob_.sess.IsEstablished()); + + bob_.WriteTo(&alice_); + EXPECT_TRUE(alice_.sess.IsEstablished()); + ASSERT_EQ(alice_.messages_received.size(), 1); + EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, "")); + + alice_.ClearBuffers(); + alice_.sess.SendMessage(MessageType::kNormal, reinterpret_cast("hello"), 5); + EXPECT_FRAMED_PACKET(alice_, + "\xFF\xFD\b\0\0\0\x82" + "f\x10hello\x90("); + alice_.WriteTo(&bob_); + ASSERT_EQ(bob_.messages_received.size(), 2); + EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kStartSessionReply, "")); + EXPECT_EQ(bob_.messages_received[1], ReceivedMessage(MessageType::kNormal, "hello")); + + bob_.ClearBuffers(); + bob_.sess.SendMessage(MessageType::kNormal, reinterpret_cast("olleh"), 5); + EXPECT_FRAMED_PACKET(bob_, + "\xff\xfd\b\0\0\0\x82" + "f\x10ollehLv"); + bob_.WriteTo(&alice_); + ASSERT_EQ(alice_.messages_received.size(), 1); + EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kNormal, "olleh")); + + alice_.ClearBuffers(); + bob_.ClearBuffers(); + + alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast("log1"), 4); + EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4"); + alice_.WriteTo(&bob_); + ASSERT_EQ(bob_.messages_received.size(), 1); + EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1")); + + bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast("zero"), 4); + EXPECT_FRAMED_PACKET(bob_, "\xff\xfd\a\0\0\0\0\0\x03zero\xb2h"); + bob_.WriteTo(&alice_); + ASSERT_EQ(alice_.messages_received.size(), 1); + EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero")); +} + +TEST_F(SessionTest, LogBeforeSessionStart) { + alice_.sess.SendMessage(MessageType::kLog, reinterpret_cast("log1"), 4); + EXPECT_FRAMED_PACKET(alice_, "\xfe\xff\xfd\a\0\0\0\0\0\x03log1\xf0\xd4"); + alice_.WriteTo(&bob_); + ASSERT_EQ(bob_.messages_received.size(), 1); + EXPECT_EQ(bob_.messages_received[0], ReceivedMessage(MessageType::kLog, "log1")); + + bob_.sess.SendMessage(MessageType::kLog, reinterpret_cast("zero"), 4); + EXPECT_FRAMED_PACKET(bob_, "\xfe\xff\xfd\a\0\0\0\0\0\x03zero\xb2h"); + bob_.WriteTo(&alice_); + ASSERT_EQ(alice_.messages_received.size(), 1); + EXPECT_EQ(alice_.messages_received[0], ReceivedMessage(MessageType::kLog, "zero")); +} + +static constexpr const char kBobStartPacket[] = "\xff\xfd\x04\0\0\0f\0\0\x01`\xa7"; + +TEST_F(SessionTest, DoubleStart) { + tvm_crt_error_t err; + err = alice_.sess.Initialize(); + EXPECT_EQ(kTvmErrorNoError, err); + EXPECT_FRAMED_PACKET(alice_, + "\xfe\xff\xfd\x03\0\0\0\0\0\x02" + "fw"); + alice_.WriteTo(&bob_); + + err = bob_.sess.Initialize(); + EXPECT_EQ(kTvmErrorNoError, err); + EXPECT_FRAMED_PACKET(bob_, + "\xfe\xff\xfd\x03\0\0\0\0\0\x02" + "fw"); + alice_.WriteTo(&alice_); + + bob_.ClearBuffers(); + alice_.ClearBuffers(); + + EXPECT_EQ(kTvmErrorNoError, alice_.sess.StartSession()); + EXPECT_FRAMED_PACKET(alice_, "\xff\xfd\x04\0\0\0\x82\0\0\x01{\xe9"); + EXPECT_FALSE(alice_.sess.IsEstablished()); + + EXPECT_EQ(kTvmErrorNoError, bob_.sess.StartSession()); + EXPECT_FRAMED_PACKET(bob_, kBobStartPacket); + EXPECT_FALSE(bob_.sess.IsEstablished()); + + // Sending Alice -> Bob should have no effect (regenerated Bob nonce > regenerated Alice nonce). + bob_.framer_write_stream.Reset(); + alice_.WriteTo(&bob_); + EXPECT_FRAMED_PACKET(bob_, ""); + EXPECT_FALSE(bob_.sess.IsEstablished()); + + // Sending Bob -> Alice should start the session. + alice_.ClearBuffers(); + size_t bytes_consumed; + EXPECT_EQ(kTvmErrorNoError, + alice_.unframer.Write(reinterpret_cast(kBobStartPacket), + sizeof(kBobStartPacket), &bytes_consumed)); + EXPECT_EQ(bytes_consumed, sizeof(kBobStartPacket)); + EXPECT_FRAMED_PACKET(alice_, "\xFF\xFD\x4\0\0\0fE\x01\x01\fb"); + EXPECT_TRUE(alice_.sess.IsEstablished()); + + bob_.ClearBuffers(); + alice_.WriteTo(&bob_); + EXPECT_TRUE(bob_.sess.IsEstablished()); +} + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + testing::FLAGS_gtest_death_test_style = "threadsafe"; + return RUN_ALL_TESTS(); +} diff --git a/tests/python/unittest/test_crt.py b/tests/python/unittest/test_crt.py new file mode 100644 index 000000000000..fe6b03ba546e --- /dev/null +++ b/tests/python/unittest/test_crt.py @@ -0,0 +1,141 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import contextlib +import copy +import glob +import os +import pty +import sys +import subprocess +import textwrap + +import numpy as np + +import tvm +import tvm.relay +import tvm.micro +from tvm.micro import transport + +from tvm.topi.util import get_const_tuple +from tvm.topi.testing import conv2d_nchw_python + +BUILD = True +DEBUG = False + +TARGET = tvm.target.target.micro('host') + +def _make_sess_from_op(workspace, op_name, sched, arg_bufs): + with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}): + mod = tvm.build(sched, arg_bufs, TARGET, target_host=TARGET, name=op_name) + + return _make_session(workspace, mod) + + +def _make_session(workspace, mod): + compiler = tvm.micro.DefaultCompiler(target=TARGET) + opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, 'host')) + + micro_binary = tvm.micro.build_static_runtime( + # the x86 compiler *expects* you to give the exact same dictionary for both + # lib_opts and bin_opts. so the library compiler is mutating lib_opts and + # the binary compiler is expecting those mutations to be in bin_opts. + # TODO(weberlo) fix this very bizarre behavior + workspace, compiler, mod, lib_opts=opts['bin_opts'], bin_opts=opts['bin_opts']) + + flasher_kw = { + 'debug': DEBUG, + } + flasher = compiler.flasher(**flasher_kw) + return tvm.micro.Session(binary=micro_binary, flasher=flasher) + + +def _make_add_sess(workspace): + A = tvm.te.placeholder((2,), dtype='int8') + B = tvm.te.placeholder((1,), dtype='int8') + C = tvm.te.compute(A.shape, lambda i: A[i] + B[0], name='C') + sched = tvm.te.create_schedule(C.op) + return _make_sess_from_op(workspace, 'add', sched, [A, B, C]) + + +def _make_ident_sess(workspace): + A = tvm.te.placeholder((2,), dtype='int8') + B = tvm.te.compute(A.shape, lambda i: A[i], name='B') + sched = tvm.te.create_schedule(B.op) + return _make_sess_from_op(workspace, 'ident', sched, [A, B]) + + +def test_compile_runtime(): + """Test compiling the on-device runtime.""" + workspace = tvm.micro.Workspace() + + with _make_add_sess(workspace) as sess: + A_data = tvm.nd.array(np.array([2, 3], dtype='int8'), ctx=sess.context) + assert (A_data.asnumpy() == np.array([2, 3])).all() + B_data = tvm.nd.array(np.array([4], dtype='int8'), ctx=sess.context) + assert (B_data.asnumpy() == np.array([4])).all() + C_data = tvm.nd.array(np.array([0, 0], dtype='int8'), ctx=sess.context) + assert (C_data.asnumpy() == np.array([0, 0])).all() + + system_lib = sess.get_system_lib() + system_lib.get_function('add')(A_data, B_data, C_data) + assert (C_data.asnumpy() == np.array([6, 7])).all() + + +def test_reset(): + """Test when the remote end resets during a session.""" + workspace = tvm.micro.Workspace() + + with _make_add_sess(workspace) as sess: + try: + sess._rpc.get_function('tvm.testing.reset_server')() + assert False, 'expected to raise SessionTerminatedError; did not raise' + except transport.SessionTerminatedError: + pass + + +def test_graph_runtime(): + """Test use of the graph runtime with microTVM.""" + workspace = tvm.micro.Workspace() + relay_mod = tvm.parser.fromtext( + """ + #[version = "0.0.5"] + def @main(%a : Tensor[(1, 2), uint8], %b : Tensor[(1, 2), uint8]) { + %0 = %a + %b; + %0 + }""") + + with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True}): + factory = tvm.relay.build(relay_mod, target=TARGET) + + with _make_session(workspace, factory.get_lib()) as sess: + graph_mod = tvm.micro.create_local_graph_runtime(factory.get_json(), sess.get_system_lib(), sess.context) + A_data = tvm.nd.array(np.array([2, 3], dtype='uint8'), ctx=sess.context) + assert (A_data.asnumpy() == np.array([2, 3])).all() + B_data = tvm.nd.array(np.array([4, 7], dtype='uint8'), ctx=sess.context) + assert (B_data.asnumpy() == np.array([4, 7])).all() + + graph_mod.run(a=A_data, b=B_data) + + out = graph_mod.get_output(0) + assert (out.asnumpy() == np.array([6, 10])).all() + + +if __name__ == '__main__': + test_compile_runtime() + test_reset() + test_graph_runtime() diff --git a/tests/python/unittest/test_runtime_micro.py b/tests/python/unittest/test_runtime_micro.py deleted file mode 100644 index 45ec9bce9fb1..000000000000 --- a/tests/python/unittest/test_runtime_micro.py +++ /dev/null @@ -1,361 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -import os - -import numpy as np -import tvm -from tvm import te -from tvm.contrib import graph_runtime, util -from tvm import relay -import tvm.micro as micro -from tvm.micro import create_micro_mod - -# # Use the host emulated micro device. -DEV_CONFIG_A = micro.device.host.generate_config() -DEV_CONFIG_B = micro.device.host.generate_config() -TARGET = "c --runtime=c" - - -def relay_micro_build(func, dev_config, params=None): - """Create a graph runtime module with a micro device context from a Relay function. - - Parameters - ---------- - func : relay.Function - function to compile - - dev_config : Dict[str, Any] - MicroTVM config dict for the target device - - params : dict - input parameters that do not change during inference - - Return - ------ - mod : tvm.runtime.Module - graph runtime module for the target device - """ - with tvm.transform.PassContext( - disabled_pass={"FuseOps"}, config={"tir.disable_vectorize": True} - ): - graph, c_mod, params = relay.build(func, target=TARGET, params=params) - micro_mod = micro.create_micro_mod(c_mod, dev_config) - ctx = tvm.micro_dev(0) - mod = graph_runtime.create(graph, micro_mod, ctx) - mod.set_input(**params) - return mod - - -GDB_INIT_TEMPLATE = """ -layout asm -target remote localhost:{gdb_port} -set $pc = UTVMInit -break UTVMDone -""" - - -def reset_gdbinit(): - if "server_port" not in DEV_CONFIG_A: - return - gdb_init_dir = os.environ["MICRO_GDB_INIT_DIR"] - with open(f"{gdb_init_dir}/.gdbinit", "w") as f: - gdb_port = DEV_CONFIG_A["server_port"] - 3333 - f.write(GDB_INIT_TEMPLATE.format(gdb_port=gdb_port)) - - -def test_alloc(): - """Test tensor allocation on the device.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - with micro.Session(DEV_CONFIG_A): - ctx = tvm.micro_dev(0) - np_tensor = np.random.uniform(size=shape).astype(dtype) - micro_tensor = tvm.nd.array(np_tensor, ctx) - tvm.testing.assert_allclose(np_tensor, micro_tensor.asnumpy()) - - -def test_add(): - """Test a module which performs addition.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - reset_gdbinit() - - # Construct TVM expression. - tvm_shape = tvm.runtime.convert(shape) - A = te.placeholder(tvm_shape, name="A", dtype=dtype) - B = te.placeholder(tvm_shape, name="B", dtype=dtype) - C = te.compute(A.shape, lambda *i: A(*i) + B(*i), name="C") - s = te.create_schedule(C.op) - - func_name = "fadd" - c_mod = tvm.build(s, [A, B, C], target="c", name=func_name) - - with micro.Session(DEV_CONFIG_A) as sess: - micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) - micro_func = micro_mod[func_name] - ctx = tvm.micro_dev(0) - - a_np = np.random.uniform(size=shape).astype(dtype) - a = tvm.nd.array(a_np, ctx) - b_np = np.random.uniform(size=shape).astype(dtype) - b = tvm.nd.array(b_np, ctx) - c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) - micro_func(a, b, c) - - # ensure inputs weren't corrupted - tvm.testing.assert_allclose(a.asnumpy(), a_np) - tvm.testing.assert_allclose(b.asnumpy(), b_np) - # ensure output is correct - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + b.asnumpy()) - - -def test_workspace_add(): - """Test a module which uses a workspace to compute an intermediate value.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - reset_gdbinit() - - # Construct TVM expression. - tvm_shape = tvm.runtime.convert(shape) - A = te.placeholder(tvm_shape, name="A", dtype=dtype) - B = te.placeholder(tvm_shape, name="B", dtype=dtype) - B = te.compute(A.shape, lambda *i: A(*i) + 1, name="B") - C = te.compute(A.shape, lambda *i: B(*i) + 1, name="C") - s = te.create_schedule(C.op) - - func_name = "fadd_two_workspace" - c_mod = tvm.build(s, [A, C], target="c", name=func_name) - - with micro.Session(DEV_CONFIG_A) as sess: - micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) - micro_func = micro_mod[func_name] - ctx = tvm.micro_dev(0) - a_np = np.random.uniform(size=shape).astype(dtype) - a = tvm.nd.array(a_np, ctx) - c = tvm.nd.array(np.zeros(shape, dtype=dtype), ctx) - micro_func(a, c) - - # ensure input wasn't corrupted - tvm.testing.assert_allclose(a.asnumpy(), a_np) - # ensure output is correct - tvm.testing.assert_allclose(c.asnumpy(), a.asnumpy() + 2.0) - - -def test_graph_runtime(): - """Test a program which uses the graph runtime.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - xx = relay.multiply(x, x) - z = relay.add(xx, relay.const(1.0)) - func = relay.Function([x], z) - - with micro.Session(DEV_CONFIG_A): - mod = relay_micro_build(func, DEV_CONFIG_A) - - x_in = np.random.uniform(size=shape[0]).astype(dtype) - mod.run(x=x_in) - result = mod.get_output(0).asnumpy() - - tvm.testing.assert_allclose(mod.get_input(0).asnumpy(), x_in) - tvm.testing.assert_allclose(result, x_in * x_in + 1.0) - - -def test_conv2d(): - if not tvm.runtime.enabled("micro_dev"): - return - - from tvm.relay import create_executor - from tvm.relay import transform - - dshape = (1, 4, 16, 16) - dtype = "int8" - func_name = "fused_nn_conv2d" - - reset_gdbinit() - - # Construct Relay program. - x = relay.var("x", shape=dshape, dtype=dtype) - conv_expr = relay.nn.conv2d(x, relay.var("w"), kernel_size=(3, 3), padding=(1, 1), channels=4) - func = relay.Function(relay.analysis.free_vars(conv_expr), conv_expr) - mod = tvm.IRModule.from_expr(func) - mod = transform.InferType()(mod) - - x_shape = list(map(lambda x: x.value, mod["main"].params[0].checked_type.shape)) - w_shape = list(map(lambda x: x.value, mod["main"].params[1].checked_type.shape)) - out_shape = list(map(lambda x: x.value, mod["main"].ret_type.shape)) - - with tvm.transform.PassContext(config={"tir.disable_vectorize": True}): - graph, c_mod, params = relay.build(mod, target="c") - - with micro.Session(DEV_CONFIG_A): - micro_mod = micro.create_micro_mod(c_mod, DEV_CONFIG_A) - candidate_func_name = func_name - for i in range(100): - try: - micro_func = micro_mod[candidate_func_name] - break - except tvm.TVMError as e: - candidate_func_name = f"{func_name}_{i}" - else: - assert False - ctx = tvm.micro_dev(0) - - x_data = tvm.nd.array(np.random.uniform(size=x_shape).astype(dtype), ctx) - w_data = tvm.nd.array(np.random.uniform(size=w_shape).astype(dtype), ctx) - result = tvm.nd.array(np.zeros(shape=out_shape, dtype=dtype), ctx) - micro_func(x_data, w_data, result) - - out_data = np.zeros(out_shape, dtype=dtype) - params = {"x": x_data.asnumpy(), "w": w_data.asnumpy()} - intrp = create_executor("debug") - expected_result = intrp.evaluate(mod["main"])(x_data, w_data) - - tvm.testing.assert_allclose(result.asnumpy(), expected_result.asnumpy()) - - -def test_interleave_sessions(): - """Test closing and reopening sessions.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - - sess_a = micro.Session(DEV_CONFIG_A) - sess_b = micro.Session(DEV_CONFIG_B) - with sess_a: - np_tensor_a = np.random.uniform(size=shape).astype(dtype) - micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - with sess_b: - np_tensor_b = np.random.uniform(size=shape).astype(dtype) - micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) - with sess_a: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) - add_const_mod.run(x=micro_tensor_a) - add_result = add_const_mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0) - with sess_b: - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_B) - add_const_mod.run(x=micro_tensor_b) - add_result = add_const_mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(add_result, np_tensor_b + 1.0) - - -def test_nested_sessions(): - """Test entering and exiting nested session contexts.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - - sess_a = micro.Session(DEV_CONFIG_A) - sess_b = micro.Session(DEV_CONFIG_B) - with sess_a: - np_tensor_a = np.random.uniform(size=shape).astype(dtype) - micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - with sess_b: - np_tensor_b = np.random.uniform(size=shape).astype(dtype) - micro_tensor_b = tvm.nd.array(np_tensor_b, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) - add_const_mod.run(x=micro_tensor_a) - add_result = add_const_mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0) - - -def test_inactive_session_use(): - """Test the use of objects allocated in a session that is no longer active.""" - if not tvm.runtime.enabled("micro_dev"): - return - shape = (1024,) - dtype = "float32" - - # Construct Relay add program. - x = relay.var("x", relay.TensorType(shape=shape, dtype=dtype)) - ret = relay.add(x, relay.const(1.0)) - add_const_func = relay.Function([x], ret) - - sess_a = micro.Session(DEV_CONFIG_A) - sess_b = micro.Session(DEV_CONFIG_B) - with sess_a: - np_tensor_a = np.random.uniform(size=shape).astype(dtype) - micro_tensor_a = tvm.nd.array(np_tensor_a, tvm.micro_dev(0)) - add_const_mod = relay_micro_build(add_const_func, DEV_CONFIG_A) - - with sess_b: - # These objects belong to `sess_a`. - add_const_mod.run(x=micro_tensor_a) - add_result = add_const_mod.get_output(0).asnumpy() - tvm.testing.assert_allclose(add_result, np_tensor_a + 1.0) - - -# TODO add workspace alloc/free stress test - -if __name__ == "__main__": - test_alloc() - print() - print("finished alloc test") - input("[press enter to continue]") - test_add() - print() - print("finished add test") - input("[press enter to continue]") - test_workspace_add() - print() - print("finished workspace add test") - input("[press enter to continue]") - test_graph_runtime() - print() - print("finished graph runtime test") - input("[press enter to continue]") - test_conv2d() - print() - print("finished conv2d test") - input("[press enter to continue]") - test_interleave_sessions() - print() - print("finished interleaved sessions test") - input("[press enter to continue]") - test_nested_sessions() - print() - print("finished nested sessions test") - input("[press enter to continue]") - test_inactive_session_use() - print() - print("finished use inactive session test") - input("[press enter to continue]") diff --git a/tests/scripts/task_config_build_cpu.sh b/tests/scripts/task_config_build_cpu.sh index 77b28e66fbb7..521ab9b8ccdc 100755 --- a/tests/scripts/task_config_build_cpu.sh +++ b/tests/scripts/task_config_build_cpu.sh @@ -44,4 +44,4 @@ echo set\(USE_TFLITE ON\) >> config.cmake echo set\(USE_TENSORFLOW_PATH \"/tensorflow\"\) >> config.cmake echo set\(USE_FLATBUFFERS_PATH \"/flatbuffers\"\) >> config.cmake echo set\(USE_ETHOSN /opt/arm/ethosn-driver\) >> config.cmake -echo set\(USE_ETHOSN_HW OFF\) >> config.cmake \ No newline at end of file +echo set\(USE_ETHOSN_HW OFF\) >> config.cmake diff --git a/tests/scripts/task_config_build_i386.sh b/tests/scripts/task_config_build_i386.sh index e8eb6685832a..d773985277aa 100755 --- a/tests/scripts/task_config_build_i386.sh +++ b/tests/scripts/task_config_build_i386.sh @@ -26,6 +26,7 @@ cp ../cmake/config.cmake . echo set\(USE_SORT ON\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_GRAPH_RUNTIME_DEBUG ON\) >> config.cmake +echo set\(USE_MICRO ON\) >> config.cmake echo set\(USE_MICRO_STANDALONE_RUNTIME ON\) >> config.cmake echo set\(USE_STANDALONE_CRT ON\) >> config.cmake echo set\(USE_VM_PROFILER ON\) >> config.cmake diff --git a/tests/scripts/task_cpp_unittest.sh b/tests/scripts/task_cpp_unittest.sh index 25a6bf06aec7..db68d9f9af6b 100755 --- a/tests/scripts/task_cpp_unittest.sh +++ b/tests/scripts/task_cpp_unittest.sh @@ -30,8 +30,14 @@ export OMP_NUM_THREADS=1 # Remove existing testcases rm -f build/*_test -make cpptest -j8 -make crttest -j8 +make cpptest -j2 +make crttest # NOTE: don't parallelize, due to issue with build deps. for test in build/*_test; do ./$test done + +# Test MISRA-C runtime +cd apps/bundle_deploy +rm -rf build +make test_dynamic test_static +cd ../.. diff --git a/tests/scripts/task_python_integration.sh b/tests/scripts/task_python_integration.sh index 741f15ba4a94..35a81e508643 100755 --- a/tests/scripts/task_python_integration.sh +++ b/tests/scripts/task_python_integration.sh @@ -33,12 +33,6 @@ find . -type f -path "*.pyc" | xargs rm -f # Test TVM make cython3 -# Test MISRA-C runtime -cd apps/bundle_deploy -rm -rf build -make test_dynamic test_static -cd ../.. - # Test extern package cd apps/extension rm -rf lib diff --git a/tutorials/micro/micro_tflite.py b/tutorials/micro/micro_tflite.py index ce30c0ace81d..0cd6a4fb4738 100644 --- a/tutorials/micro/micro_tflite.py +++ b/tutorials/micro/micro_tflite.py @@ -19,7 +19,7 @@ ============================ **Author**: `Tom Gall `_ -This tutorial is an introduction to working with MicroTVM and a TFLite +This tutorial is an introduction to working with MicroTVM and a TFLite model with Relay. """ @@ -148,24 +148,6 @@ tflite_model, shape_dict={input_tensor: input_shape}, dtype_dict={input_tensor: input_dtype} ) -# %% -# Running on device -# ---------------------------------------------- -# -# Setup the device config which is what will be used to communicate -# with the microcontroller (a STM32F746 Discovery board) -TARGET = "c --system-lib --runtime=c" -dev_config = micro.device.arm.stm32f746xx.generate_config("127.0.0.1", 6666) - -###################################################################### -# Next with the dev_config, we establish a micro session and create -# a context -# -# .. code-block:: python -# -# with micro.Session(dev_config) as sess: -# ctx = tvm.micro_dev(0) - ###################################################################### # Now we create a build config for relay. turning off two options # and then calling relay.build which will result in a C source @@ -173,48 +155,52 @@ # # .. code-block:: python # -# with tvm.transform.PassContext(opt_level=3, config={'tir.disable_vectorize': True},disabled_pass=['FuseOps']): -# graph, c_mod, params = relay.build(mod, target=TARGET, params=params) +TARGET = tvm.target.target.micro("host") -###################################################################### -# With the c_mod that is the handle to our C source code, we create a -# micro module, followed by a compiled object which behind the scenes -# is linked to the microTVM runtime for running on the target board -# -# .. code-block:: python -# -# micro_mod = micro.create_micro_mod(c_mod, dev_config) -# mod = graph_runtime.create(graph, micro_mod, ctx) +with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True},disabled_pass=["FuseOps"]): + graph, c_mod, c_params = relay.build(mod, target=TARGET, params=params) -###################################################################### -# Pass the weights to get ready to perform inference -# -# .. code-block:: python -# -# mod.set_input(**params) -###################################################################### -# The model consumes a single float32 value and returns a predicted -# sine value. -# To pass the input value we construct a tvm.nd.array object -# with a single contrived number as input. For this model values of -# 0 to 2Pi are acceptable. -# -# .. code-block:: python +# %% +# Running on simulated device +# ---------------------------------------------- # -# mod.set_input(input_tensor, tvm.nd.array(np.array([0.5], dtype="float32"))) +# First, compile a static microTVM runtime for the targeted device. In this case, the host simulated +# device is used. +workspace = tvm.micro.Workspace() + +compiler = tvm.micro.DefaultCompiler(target=TARGET) +opts = tvm.micro.default_options(os.path.join(tvm.micro.CRT_ROOT_DIR, "host")) + +micro_binary = tvm.micro.build_static_runtime( + # the x86 compiler *expects* you to give the exact same dictionary for both + # lib_opts and bin_opts. so the library compiler is mutating lib_opts and + # the binary compiler is expecting those mutations to be in bin_opts. + # TODO(weberlo) fix this very bizarre behavior + workspace, compiler, c_mod, lib_opts=opts["bin_opts"], bin_opts=opts["bin_opts"]) -###################################################################### -# Run the model on device -# -# .. code-block:: python -# -# mod.run() ###################################################################### -# Get output from the run and print +# Next, establish a session with the simulated device and run the +# computation. The `with session` line would typically flash an attached +# microcontroller, but in this tutorial, it simply launches a subprocess +# to stand in for an attached microcontroller. # # .. code-block:: python # -# tvm_output = mod.get_output(0).asnumpy() -# print("result is: "+str(tvm_output)) +flasher = compiler.flasher() +with tvm.micro.Session(binary=micro_binary, flasher=flasher) as session: + graph_mod = tvm.micro.create_local_graph_runtime( + graph, session.get_system_lib(), session.context) + + # Set the model parameters using the lowered parameters produced by `relay.build`. + graph_mod.set_input(**c_params) + + # The model consumes a single float32 value and returns a predicted sine value. To pass the + # input value we construct a tvm.nd.array object with a single contrived number as input. For + # this model values of 0 to 2Pi are acceptable. + graph_mod.set_input(input_tensor, tvm.nd.array(np.array([0.5], dtype="float32"))) + graph_mod.run() + + tvm_output = graph_mod.get_output(0).asnumpy() + print("result is: "+str(tvm_output))