From 55b877b09015306b2a86952eba4f2469adaf4353 Mon Sep 17 00:00:00 2001 From: coincheung <867153576@qq.com> Date: Fri, 7 Oct 2022 02:42:37 +0000 Subject: [PATCH] add int8 inference for trt --- README.md | 6 +- tensorrt/CMakeLists.txt | 4 +- tensorrt/README.md | 12 ++- tensorrt/batch_stream.hpp | 148 +++++++++++++++++++++++++++++ tensorrt/entropy_calibrator.hpp | 160 ++++++++++++++++++++++++++++++++ tensorrt/read_img.cpp | 56 +++++++++++ tensorrt/read_img.hpp | 26 ++++++ tensorrt/segment.cpp | 61 +++++------- tensorrt/trt_dep.cpp | 24 ++++- tensorrt/trt_dep.hpp | 3 +- tools/evaluate.py | 3 +- tools/export_onnx.py | 2 +- tools/train_amp.py | 4 +- 13 files changed, 459 insertions(+), 50 deletions(-) create mode 100644 tensorrt/batch_stream.hpp create mode 100644 tensorrt/entropy_calibrator.hpp create mode 100644 tensorrt/read_img.cpp create mode 100644 tensorrt/read_img.hpp diff --git a/README.md b/README.md index 8aac4d9..6029258 100644 --- a/README.md +++ b/README.md @@ -4,10 +4,10 @@ My implementation of [BiSeNetV1](https://arxiv.org/abs/1808.00897) and [BiSeNetV mIOUs and fps on cityscapes val set: -| none | ss | ssc | msf | mscf | fps(fp16/fp32) | link | +| none | ss | ssc | msf | mscf | fps(fp32/fp16/int8) | link | |------|:--:|:---:|:---:|:----:|:---:|:----:| -| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 78/25 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) | -| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 67/26 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) | +| bisenetv1 | 75.44 | 76.94 | 77.45 | 78.86 | 25/78/141 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v1_city_new.pth) | +| bisenetv2 | 74.95 | 75.58 | 76.53 | 77.08 | 26/67/95 | [download](https://github.com/CoinCheung/BiSeNet/releases/download/0.0.0/model_final_v2_city.pth) | mIOUs on cocostuff val2017 set: | none | ss | ssc | msf | mscf | link | diff --git a/tensorrt/CMakeLists.txt b/tensorrt/CMakeLists.txt index 069b18e..1e3b2a9 100644 --- a/tensorrt/CMakeLists.txt +++ b/tensorrt/CMakeLists.txt @@ -10,6 +10,8 @@ link_directories(/usr/local/cuda/lib64) link_directories(${PROJECT_SOURCE_DIR}/build) # include_directories(/root/build/TensorRT-8.2.5.1/include) # link_directories(/root/build/TensorRT-8.2.5.1/lib) +include_directories(/data/zzy/trt_install/TensorRT-8.2.5.1/include) +link_directories(/data/zzy/trt_install/TensorRT-8.2.5.1/lib) find_package(CUDA REQUIRED) @@ -17,7 +19,7 @@ find_package(OpenCV REQUIRED) cuda_add_library(kernels STATIC kernels.cu) -add_executable(segment segment.cpp trt_dep.cpp) +add_executable(segment segment.cpp trt_dep.cpp read_img.cpp) target_include_directories( segment PUBLIC ${CUDA_INCLUDE_DIRS} ${CUDNN_INCLUDE_DIRS} ${OpenCV_INCLUDE_DIRS}) target_link_libraries( diff --git a/tensorrt/README.md b/tensorrt/README.md index a0227db..5bba5bb 100644 --- a/tensorrt/README.md +++ b/tensorrt/README.md @@ -38,7 +38,8 @@ This would generate a `./segment` in the `tensorrt/build` directory. #### 3. Convert onnx to tensorrt model -If you can successfully compile the source code, you can parse the onnx model to tensorrt model like this: +If you can successfully compile the source code, you can parse the onnx model to tensorrt model with one of the following commands. +For fp32, command is: ``` $ ./segment compile /path/to/onnx.model /path/to/saved_model.trt ``` @@ -46,6 +47,13 @@ If your gpu support acceleration with fp16 inferenece, you can add a `--fp16` op ``` $ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --fp16 ``` +Building an int8 engine is also supported. Firstly, you should make sure your gpu support int8 inference, or you model will not be faster than fp16/fp32. Then you should prepare certain amount of images for int8 calibration. In this example, I use train set of cityscapes for calibration. The command is like this: +``` +$ calibrate_int8 # delete this if exists +$ ./segment compile /path/to/onnx.model /path/to/saved_model.trt --int8 /path/to/BiSeNet/datasets/cityscapes /path/to/BiSeNet/datasets/cityscapes/train.txt +``` +With the above commands, we will have an tensorrt engine named `saved_model.trt` generated. + Note that I use the simplest method to parse the command line args, so please do **Not** change the order of the args in above command. @@ -74,6 +82,8 @@ Likewise, you do not need to worry about this anymore with version newer than 7. 4. On my platform, after compiling with tensorrt, the model size of bisenetv1 is 29Mb(fp16) and 128Mb(fp32), and the size of bisenetv2 is 16Mb(fp16) and 42Mb(fp32). However, the fps of bisenetv1 is 68(fp16) and 23(fp32), while the fps of bisenetv2 is 59(fp16) and 21(fp32). It is obvious that bisenetv2 has fewer parameters than bisenetv1, but the speed is otherwise. I am not sure whether it is because tensorrt has worse optimization strategy in some ops used in bisenetv2(such as depthwise convolution) or because of the limitation of the gpu on different ops. Please tell me if you have better idea on this. +5. int8 mode is not always greatly faster than fp16 mode. For example, I tested with bisenetv1-cityscapes and tensorrt 8.2.5.1. With v100 gpu and driver 515.65, the fp16/int8 fps is 185.89/186.85, while with t4 gpu and driver 450.80, it is 78.77/142.31. + ### Using python diff --git a/tensorrt/batch_stream.hpp b/tensorrt/batch_stream.hpp new file mode 100644 index 0000000..09d0262 --- /dev/null +++ b/tensorrt/batch_stream.hpp @@ -0,0 +1,148 @@ + +#ifndef BATCH_STREAM_HPP +#define BATCH_STREAM_HPP + + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "NvInfer.h" +#include "read_img.hpp" + +using nvinfer1::Dims; +using nvinfer1::Dims3; +using nvinfer1::Dims4; + + +class IBatchStream +{ +public: + virtual void reset(int firstBatch) = 0; + virtual bool next() = 0; + virtual void skip(int skipCount) = 0; + virtual float* getBatch() = 0; + virtual int getBatchesRead() const = 0; + virtual int getBatchSize() const = 0; + virtual nvinfer1::Dims4 getDims() const = 0; +}; + + +class BatchStream : public IBatchStream +{ +public: + BatchStream(int batchSize, int maxBatches, Dims indim, + const std::string& dataRoot, + const std::string& dataFile) + : mBatchSize{batchSize} + , mMaxBatches{maxBatches} + { + mDims = Dims3(indim.d[1], indim.d[2], indim.d[3]); + + readDataFile(dataFile, dataRoot); + mSampleSize = std::accumulate( + mDims.d, mDims.d + mDims.nbDims, 1, std::multiplies()) * sizeof(float); + mData.resize(mSampleSize * mBatchSize); + } + + void reset(int firstBatch) override + { + cout << "mBatchCount: " << mBatchCount << endl; + mBatchCount = firstBatch; + } + + bool next() override + { + if (mBatchCount >= mMaxBatches) + { + return false; + } + ++mBatchCount; + return true; + } + + void skip(int skipCount) override + { + mBatchCount += skipCount; + } + + float* getBatch() override + { + int offset = mBatchCount * mBatchSize; + for (int i{0}; i < mBatchSize; ++i) { + int ind = offset + i; + read_data(mPaths[ind], &mData[i * mSampleSize], mDims.d[1], mDims.d[2]); + } + return mData.data(); + } + + int getBatchesRead() const override + { + return mBatchCount; + } + + int getBatchSize() const override + { + return mBatchSize; + } + + nvinfer1::Dims4 getDims() const override + { + return Dims4{mBatchSize, mDims.d[0], mDims.d[1], mDims.d[2]}; + } + +private: + void readDataFile(const std::string& dataFilePath, const std::string& dataRootPath) + { + std::ifstream file(dataFilePath, std::ios::in); + if (!file.is_open()) { + cout << "file open failed: " << dataFilePath << endl; + std::abort(); + } + std::stringstream ss; + file >> ss.rdbuf(); + file.close(); + + std::string impth; + int n_imgs = 0; + while (std::getline(ss, impth)) ++n_imgs; + ss.clear(); ss.seekg(0, std::ios::beg); + if (n_imgs <= 0) { + cout << "ann file is empty, cannot read image paths for int8 calibration: " + << dataFilePath << endl; + std::abort(); + } + + mPaths.resize(n_imgs); + for (int i{0}; i < n_imgs; ++i) { + std::getline(ss, impth, ','); + mPaths[i] = dataRootPath + "/" + impth; + std::getline(ss, impth); + } + if (mMaxBatches < 0) { + mMaxBatches = n_imgs / mBatchSize - 1; + } + if (mMaxBatches <= 0) { + cout << "must have at least 1 batch for calibration\n"; + std::abort(); + } + cout << "mMaxBatches = " << mMaxBatches << endl; + } + + + int mBatchSize{0}; + int mBatchCount{0}; + int mMaxBatches{0}; + Dims3 mDims{}; + std::vector mPaths; + std::vector mData; + int mSampleSize{0}; +}; + + +#endif diff --git a/tensorrt/entropy_calibrator.hpp b/tensorrt/entropy_calibrator.hpp new file mode 100644 index 0000000..2490986 --- /dev/null +++ b/tensorrt/entropy_calibrator.hpp @@ -0,0 +1,160 @@ +/* + * Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef ENTROPY_CALIBRATOR_HPP +#define ENTROPY_CALIBRATOR_HPP + +#include +#include +#include +#include "NvInfer.h" + +//! \class EntropyCalibratorImpl +//! +//! \brief Implements common functionality for Entropy calibrators. +//! +template +class EntropyCalibratorImpl +{ +public: + EntropyCalibratorImpl( + TBatchStream stream, int firstBatch, std::string cal_table_name, const char* inputBlobName, bool readCache = true) + : mStream{stream} + , mCalibrationTableName(cal_table_name) + , mInputBlobName(inputBlobName) + , mReadCache(readCache) + { + nvinfer1::Dims4 dims = mStream.getDims(); + mInputCount = std::accumulate( + dims.d, dims.d + dims.nbDims, 1, std::multiplies()); + cout << "dims.nbDims: " << dims.nbDims << endl; + for (int i{0}; i < dims.nbDims; ++i) { + cout << dims.d[i] << ", "; + } + cout << endl; + + cudaError_t state; + state = cudaMalloc(&mDeviceInput, mInputCount * sizeof(float)); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + cout << "mInputCount: " << mInputCount << endl; + mStream.reset(firstBatch); + } + + virtual ~EntropyCalibratorImpl() + { + cudaError_t state; + state = cudaFree(mDeviceInput); + if (state) { + cout << "free memory failed\n"; + std::abort(); + } + } + + int getBatchSize() const + { + return mStream.getBatchSize(); + } + + bool getBatch(void* bindings[], const char* names[], int nbBindings) + { + if (!mStream.next()) + { + return false; + } + cudaError_t state; + state = cudaMemcpy(mDeviceInput, mStream.getBatch(), mInputCount * sizeof(float), cudaMemcpyHostToDevice); + if (state) { + cout << "allocate memory failed\n"; + std::abort(); + } + assert(!strcmp(names[0], mInputBlobName)); + bindings[0] = mDeviceInput; + return true; + } + + const void* readCalibrationCache(size_t& length) + { + mCalibrationCache.clear(); + std::ifstream input(mCalibrationTableName, std::ios::binary); + input >> std::noskipws; + if (mReadCache && input.good()) + { + std::copy(std::istream_iterator(input), std::istream_iterator(), + std::back_inserter(mCalibrationCache)); + } + length = mCalibrationCache.size(); + return length ? mCalibrationCache.data() : nullptr; + } + + void writeCalibrationCache(const void* cache, size_t length) + { + std::ofstream output(mCalibrationTableName, std::ios::binary); + output.write(reinterpret_cast(cache), length); + } + +private: + TBatchStream mStream; + size_t mInputCount; + std::string mCalibrationTableName; + const char* mInputBlobName; + bool mReadCache{true}; + void* mDeviceInput{nullptr}; + std::vector mCalibrationCache; +}; + +//! \class Int8EntropyCalibrator2 +//! +//! \brief Implements Entropy calibrator 2. +//! CalibrationAlgoType is kENTROPY_CALIBRATION_2. +//! +template +class Int8EntropyCalibrator2 : public nvinfer1::IInt8EntropyCalibrator2 +{ +public: + Int8EntropyCalibrator2( + TBatchStream stream, int firstBatch, const char* networkName, const char* inputBlobName, bool readCache = true) + : mImpl(stream, firstBatch, networkName, inputBlobName, readCache) + { + } + + int getBatchSize() const noexcept override + { + return mImpl.getBatchSize(); + } + + bool getBatch(void* bindings[], const char* names[], int nbBindings) noexcept override + { + return mImpl.getBatch(bindings, names, nbBindings); + } + + const void* readCalibrationCache(size_t& length) noexcept override + { + return mImpl.readCalibrationCache(length); + } + + void writeCalibrationCache(const void* cache, size_t length) noexcept override + { + mImpl.writeCalibrationCache(cache, length); + } + +private: + EntropyCalibratorImpl mImpl; +}; + +#endif // ENTROPY_CALIBRATOR_H diff --git a/tensorrt/read_img.cpp b/tensorrt/read_img.cpp new file mode 100644 index 0000000..60657da --- /dev/null +++ b/tensorrt/read_img.cpp @@ -0,0 +1,56 @@ + +#include +#include +#include +#include +#include +#include + + +using std::cout; +using std::endl; +using std::vector; +using std::string; +using cv::Mat; + + +void read_data(std::string impth, float *data, int iH, int iW, + int& orgH, int& orgW) { + vector mean{0.485f, 0.456f, 0.406f}; // rgb order + vector variance{0.229f, 0.224f, 0.225f}; + + Mat im = cv::imread(impth); + if (im.empty()) { + cout << "cannot read image \n"; + std::abort(); + } + + orgH = im.rows; orgW = im.cols; + if ((orgH != iH) || orgW != iW) { + cout << "resize orignal image of (" << orgH << "," << orgW + << ") to (" << iH << ", " << iW << ") according to model require\n"; + cv::resize(im, im, cv::Size(iW, iH), cv::INTER_CUBIC); + } + + // normalize and convert to rgb + float scale = 1.f / 255.f; + for (int i{0}; i < variance.size(); ++ i) { + variance[i] = 1.f / variance[i]; + } + for (int h{0}; h < iH; ++h) { + cv::Vec3b *p = im.ptr(h); + for (int w{0}; w < iW; ++w) { + for (int c{0}; c < 3; ++c) { + int idx = c * iH * iW + h * iW + w; + data[idx] = (p[w][2 - c] * scale - mean[c]) * variance[c]; + } + } + } +} + + +void read_data(std::string impth, float *data, int iH, int iW) { + int tmp1, tmp2; + read_data(impth, data, iH, iW, tmp1, tmp2); +} + diff --git a/tensorrt/read_img.hpp b/tensorrt/read_img.hpp new file mode 100644 index 0000000..08d930b --- /dev/null +++ b/tensorrt/read_img.hpp @@ -0,0 +1,26 @@ + +#ifndef _READ_IMAGE_HPP_ +#define _READ_IMAGE_HPP_ + + +#include +#include +#include +#include +#include +#include + + +using std::cout; +using std::endl; +using std::vector; +using std::string; +using cv::Mat; + + +void read_data(std::string impth, float *data, + int iH, int iW, int& orgH, int& orgW); +void read_data(std::string impth, float *data, int iH, int iW); + + +#endif diff --git a/tensorrt/segment.cpp b/tensorrt/segment.cpp index 2d00ce7..a71b4ad 100644 --- a/tensorrt/segment.cpp +++ b/tensorrt/segment.cpp @@ -15,6 +15,7 @@ #include #include "trt_dep.hpp" +#include "read_img.hpp" using nvinfer1::IHostMemory; @@ -62,7 +63,8 @@ int main(int argc, char* argv[]) { if (args[0] == "compile") { if (argc < 4) { - cout << "usage is: ./segment compile input.onnx output.trt [--fp16]\n"; + cout << "usage is: ./segment compile input.onnx output.trt [--fp16|--fp32]\n"; + cout << "or ./segment compile input.onnx output.trt --int8 /path/to/data_root /path/to/ann_file\n"; std::abort(); } compile_onnx(args); @@ -85,10 +87,25 @@ int main(int argc, char* argv[]) { void compile_onnx(vector args) { - bool use_fp16{false}; - if ((args.size() >= 4) && args[3] == "--fp16") use_fp16 = true; + string quant("fp32"); + string data_root("none"); + string data_file("none"); + if ((args.size() >= 4)) { + if (args[3] == "--fp32") { + quant = "fp32"; + } else if (args[3] == "--fp16") { + quant = "fp16"; + } else if (args[3] == "--int8") { + quant = "int8"; + data_root = args[4]; + data_file = args[5]; + } else { + cout << "invalid args of quantization: " << args[3] << endl; + std::abort(); + } + } - TrtSharedEnginePtr engine = parse_to_engine(args[1], use_fp16); + TrtSharedEnginePtr engine = parse_to_engine(args[1], quant, data_root, data_file); serialize(engine, args[2]); } @@ -105,36 +122,9 @@ void run_with_trt(vector args) { const int oH{o_dims.d[2]}, oW{o_dims.d[3]}; // prepare image and resize - Mat im = cv::imread(args[2]); - if (im.empty()) { - cout << "cannot read image \n"; - std::abort(); - } - // CHECK (!im.empty()) << "cannot read image \n"; - int orgH{im.rows}, orgW{im.cols}; - if ((orgH != iH) || orgW != iW) { - cout << "resize orignal image of (" << orgH << "," << orgW - << ") to (" << iH << ", " << iW << ") according to model require\n"; - cv::resize(im, im, cv::Size(iW, iH), cv::INTER_CUBIC); - } - - // normalize and convert to rgb - array mean{0.485f, 0.456f, 0.406f}; - array variance{0.229f, 0.224f, 0.225f}; - float scale = 1.f / 255.f; - for (int i{0}; i < 3; ++ i) { - variance[i] = 1.f / variance[i]; - } - vector data(iH * iW * 3); - for (int h{0}; h < iH; ++h) { - cv::Vec3b *p = im.ptr(h); - for (int w{0}; w < iW; ++w) { - for (int c{0}; c < 3; ++c) { - int idx = (2 - c) * iH * iW + h * iW + w; // to rgb order - data[idx] = (p[w][c] * scale - mean[c]) * variance[c]; - } - } - } + vector data; data.resize(iH * iW * 3); + int orgH, orgW; + read_data(args[2], &data[0], iH, iW, orgH, orgW); // call engine vector res = infer_with_engine(engine, data); @@ -155,11 +145,10 @@ void run_with_trt(vector args) { } // resize back and save - if ((orgH != oH) || orgW != oW) { + if ((orgH != oH) || (orgW != oW)) { cv::resize(pred, pred, cv::Size(orgW, orgH), cv::INTER_CUBIC); } cv::imwrite(args[3], pred); - } diff --git a/tensorrt/trt_dep.cpp b/tensorrt/trt_dep.cpp index 905bed9..71f105c 100644 --- a/tensorrt/trt_dep.cpp +++ b/tensorrt/trt_dep.cpp @@ -9,6 +9,8 @@ #include #include "trt_dep.hpp" +#include "batch_stream.hpp" +#include "entropy_calibrator.hpp" #include "kernels.hpp" @@ -43,7 +45,8 @@ TrtSharedEnginePtr shared_engine_ptr(ICudaEngine* ptr) { } -TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) { +TrtSharedEnginePtr parse_to_engine(string onnx_pth, + string quant, string data_root, string data_file) { unsigned int maxBatchSize{1}; long memory_limit = 1UL << 32; // 4G @@ -81,11 +84,26 @@ TrtSharedEnginePtr parse_to_engine(string onnx_pth, bool use_fp16) { std::abort(); } - config->setMaxWorkspaceSize(memory_limit); - if (use_fp16 && builder->platformHasFastFp16()) { + if ((quant == "fp16" or quant == "int8") && builder->platformHasFastFp16()) { config->setFlag(nvinfer1::BuilderFlag::kFP16); // fp16 } + std::unique_ptr calibrator; + if (quant == "int8" && builder->platformHasFastInt8()) { + config->setFlag(nvinfer1::BuilderFlag::kINT8); //int8 + int batchsize = 32; + int n_cal_batches = -1; + string cal_table_name = "calibrate_int8"; + string input_name = "input_image"; + + Dims indim = network->getInput(0)->getDimensions(); + BatchStream calibrationStream( + batchsize, n_cal_batches, indim, + data_root, data_file); + calibrator.reset(new Int8EntropyCalibrator2( + calibrationStream, 0, cal_table_name.c_str(), input_name.c_str())); + config->setInt8Calibrator(calibrator.get()); + } auto output = network->getOutput(0); // output->setType(nvinfer1::DataType::kINT32); diff --git a/tensorrt/trt_dep.hpp b/tensorrt/trt_dep.hpp index 57b8d9c..2b794dc 100644 --- a/tensorrt/trt_dep.hpp +++ b/tensorrt/trt_dep.hpp @@ -54,7 +54,8 @@ extern Logger gLogger; TrtSharedEnginePtr shared_engine_ptr(ICudaEngine* ptr); -TrtSharedEnginePtr parse_to_engine(string onnx_path, bool use_fp16); +TrtSharedEnginePtr parse_to_engine(string onnx_path, string quant, + string data_root, string data_file); void serialize(TrtSharedEnginePtr engine, string save_path); TrtSharedEnginePtr deserialize(string serpth); vector infer_with_engine(TrtSharedEnginePtr engine, vector& data); diff --git a/tools/evaluate.py b/tools/evaluate.py index 3c8e9e3..435affc 100644 --- a/tools/evaluate.py +++ b/tools/evaluate.py @@ -89,7 +89,8 @@ def compute_metrics(self,): fns = confusion.sum(dim=1) - tps # iou and fw miou - ious = confusion.diag() / (confusion.sum(dim=0) + confusion.sum(dim=1) - confusion.diag() + 1) + # ious = confusion.diag() / (confusion.sum(dim=0) + confusion.sum(dim=1) - confusion.diag() + 1) + ious = tps / (tps + fps + fns + 1) miou = ious.nanmean() fw_miou = torch.sum(weights * ious) diff --git a/tools/export_onnx.py b/tools/export_onnx.py index 8a637f7..7acd98e 100644 --- a/tools/export_onnx.py +++ b/tools/export_onnx.py @@ -38,5 +38,5 @@ torch.onnx.export(net, dummy_input, args.out_pth, input_names=input_names, output_names=output_names, - verbose=False, opset_version=11) + verbose=False, opset_version=11, ) diff --git a/tools/train_amp.py b/tools/train_amp.py index c661c2a..fa9c7d2 100644 --- a/tools/train_amp.py +++ b/tools/train_amp.py @@ -28,6 +28,7 @@ from lib.logger import setup_logger, print_log_msg + ## fix all random seeds # torch.manual_seed(123) # torch.cuda.manual_seed(123) @@ -38,8 +39,6 @@ # torch.multiprocessing.set_sharing_strategy('file_system') - - def parse_args(): parse = argparse.ArgumentParser() parse.add_argument('--config', dest='config', type=str, @@ -51,7 +50,6 @@ def parse_args(): cfg = set_cfg_from_file(args.config) - def set_model(lb_ignore=255): logger = logging.getLogger() net = model_factory[cfg.model_type](cfg.n_cats)