From 067b2f10db5a29c52983cbf39618fb2735b81c0b Mon Sep 17 00:00:00 2001 From: zhiqwang Date: Sat, 10 Oct 2020 06:25:11 -0400 Subject: [PATCH] Add cpp test of libtorch tracing --- .github/workflows/nightly.yml | 56 ++++++++++++++++++++++++++++++++++ .github/workflows/stable.yml | 57 +++++++++++++++++++++++++++++++++++ requirements.txt | 9 ++++++ test/__init__.py | 0 test/tracing/CMakeLists.txt | 21 +++++++++++++ test/tracing/test_tracing.cpp | 56 ++++++++++++++++++++++++++++++++++ test/tracing/trace_model.py | 14 +++++++++ 7 files changed, 213 insertions(+) create mode 100644 .github/workflows/nightly.yml create mode 100644 .github/workflows/stable.yml create mode 100644 requirements.txt create mode 100644 test/__init__.py create mode 100644 test/tracing/CMakeLists.txt create mode 100644 test/tracing/test_tracing.cpp create mode 100644 test/tracing/trace_model.py diff --git a/.github/workflows/nightly.yml b/.github/workflows/nightly.yml new file mode 100644 index 00000000..d4d58792 --- /dev/null +++ b/.github/workflows/nightly.yml @@ -0,0 +1,56 @@ +# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. +# GH actions + +name: Nightly + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + Test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: [3.6, 3.7] + os: [ubuntu-latest] + + steps: + - name: Clone repository + uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + architecture: 'x64' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Install PyTorch Nightly + run: | + pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html + - name: Build TorchVision Cpp Nightly + run: | + export TORCH_PATH=$(dirname $(python -c "import torch; print(torch.__file__)")) + cd .. + git clone https://github.com/pytorch/vision.git vision + cd vision + mkdir build && cd build + cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch + make -j4 + sudo make install + - name: Test libtorch tracing + run: | + python -m test.tracing.trace_model + export TORCH_PATH=$(dirname $(python -c "import torch; print(torch.__file__)")) + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TORCH_PATH/lib/ + cd test/tracing + mkdir build && cd build + cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch + make + mv ../yolov5s.torchscript.pt ./ + echo ">> Test libtorch tracing" + ./test_tracing diff --git a/.github/workflows/stable.yml b/.github/workflows/stable.yml new file mode 100644 index 00000000..1cfc44a6 --- /dev/null +++ b/.github/workflows/stable.yml @@ -0,0 +1,57 @@ +# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved. +# GH actions + +name: Stable + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] + +jobs: + Test: + runs-on: ${{ matrix.os }} + strategy: + matrix: + python-version: [3.6, 3.7] + os: [ubuntu-latest] + + steps: + - name: Clone repository + uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + architecture: 'x64' + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + - name: Install PyTorch 1.6 + run: | + pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html + - name: Build TorchVision Cpp + run: | + export TORCH_PATH=$(dirname $(python -c "import torch; print(torch.__file__)")) + cd .. + git clone https://github.com/pytorch/vision.git vision + cd vision + git checkout release/0.7 + mkdir build && cd build + cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch + make -j4 + sudo make install + - name: Test libtorch tracing + run: | + python -m test.tracing.trace_model + export TORCH_PATH=$(dirname $(python -c "import torch; print(torch.__file__)")) + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$TORCH_PATH/lib/ + cd test/tracing + mkdir build && cd build + cmake .. -DTorch_DIR=$TORCH_PATH/share/cmake/Torch + make + mv ../yolov5s.torchscript.pt ./ + echo ">> Test libtorch tracing" + ./test_tracing diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..30ac0e3a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +Cython +matplotlib>=3.2.2 +numpy>=1.18.5 +opencv-python>=4.1.2 +pillow +PyYAML>=5.3 +scipy>=1.4.1 +tensorboard>=2.2 +tqdm>=4.41.0 diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/tracing/CMakeLists.txt b/test/tracing/CMakeLists.txt new file mode 100644 index 00000000..f57f69de --- /dev/null +++ b/test/tracing/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +project(test_tracing) + +find_package(Torch REQUIRED) +find_package(TorchVision REQUIRED) + +# This due to some headers importing Python.h +find_package(Python3 COMPONENTS Development) + +add_executable(${CMAKE_PROJECT_NAME} test_tracing.cpp) +target_compile_features(test_tracing PUBLIC cxx_range_for) + +target_link_libraries( + ${CMAKE_PROJECT_NAME} + ${TORCH_LIBRARIES} + TorchVision::TorchVision + Python3::Python +) + +# set C++14 to compile +set_property(TARGET test_tracing PROPERTY CXX_STANDARD 14) diff --git a/test/tracing/test_tracing.cpp b/test/tracing/test_tracing.cpp new file mode 100644 index 00000000..37b70e55 --- /dev/null +++ b/test/tracing/test_tracing.cpp @@ -0,0 +1,56 @@ +#include +#include +#include +#include +#include +#include + + +int main() { + torch::DeviceType device_type; + device_type = torch::kCPU; + + torch::jit::script::Module module; + try { + std::cout << "Loading model" << std::endl; + // Deserialize the ScriptModule from a file using torch::jit::load(). + module = torch::jit::load("yolov5s.torchscript.pt"); + std::cout << "Model loaded" << std::endl; + } catch (const torch::Error& e) { + std::cout << "error loading the model" << std::endl; + return -1; + } catch (const std::exception& e) { + std::cout << "Other error: " << e.what() << std::endl; + return -1; + } + + // TorchScript models require a List[IValue] as input + std::vector inputs; + + // Demonet accepts a List[Tensor] as main input + torch::Tensor images = torch::rand({1, 3, 416, 352}); + + inputs.push_back(images); + auto output = module.forward(inputs); + + std::cout << "ok" << std::endl; + std::cout << "output" << output << std::endl; + + if (torch::cuda::is_available()) { + // Move traced model to GPU + module.to(torch::kCUDA); + + // Add GPU inputs + inputs.clear(); + + torch::TensorOptions options = torch::TensorOptions{torch::kCUDA}; + images = images.to(torch::kCUDA); + + inputs.push_back(images); + auto output = module.forward(inputs); + + std::cout << "ok" << std::endl; + std::cout << "output" << output << std::endl; + } + return 0; +} diff --git a/test/tracing/trace_model.py b/test/tracing/trace_model.py new file mode 100644 index 00000000..c7d18a5b --- /dev/null +++ b/test/tracing/trace_model.py @@ -0,0 +1,14 @@ +import torch + +from hubconf import yolov5 + + +if __name__ == "__main__": + + model = yolov5( + cfg_path='./models/yolov5s.yaml', + ) + model.eval() + + traced_model = torch.jit.script(model) + traced_model.save("./test/tracing/yolov5s.torchscript.pt")