-
Notifications
You must be signed in to change notification settings - Fork 5.4k
Tutorial: Deploy Face Recognition Model via TVM
InsightFace is a CNN-based face recognition project with a series of training pipelines. Beyond that, we also reproduce several modern face-related papers that serve face recognition well.
This tutorial shows how to deploy an InsightFace model in production with the TVM stack.
TVM is an open deep learning compiler stack for CPUs, GPUs, and specialized accelerators. It aims to close the gap between the productivity-focused deep learning frameworks, and the performance- or efficiency-oriented hardware backends.
Installing TVM only requires the TVM source code and an LLVM + clang compiler. You can follow the official installation tutorial.
- The newest LLVM release (7.0) may cause compile errors; we recommend LLVM 6.0.1 as a stable version.
- If a model has many parameters that are very close to 0, it may suffer a performance drop (likely due to slow denormal floating-point arithmetic).
TVM uses a series of optimization passes on the computation graph to make inference fast. We use the widely used MobileFaceNet (from the Model Zoo) as an example; the following Python code compiles the model.
import numpy as np
import nnvm.compiler
import nnvm.testing
import tvm
from tvm.contrib import graph_runtime
import mxnet as mx
from mxnet import ndarray as nd
# Compile the MXNet MobileFaceNet checkpoint with NNVM/TVM and export the
# three deployment artifacts: deploy_lib.so (compiled operators),
# deploy_graph.json (graph) and deploy_param.params (weights).
import nnvm.frontend  # provides nnvm.frontend.from_mxnet used below

prefix, epoch = "emore1", 0
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
image_size = (112, 112)
opt_level = 3
# Single NCHW input named "data" with batch size 1; the deploy code must use
# the same input name and shape.
shape_dict = {'data': (1, 3, *image_size)}
# "target" is the platform you want to compile for; choose the -mcpu that
# matches the deployment CPU.
target = tvm.target.create("llvm -mcpu=haswell")
#target = tvm.target.create("llvm -mcpu=broadwell")
nnvm_sym, nnvm_params = nnvm.frontend.from_mxnet(sym, arg_params, aux_params)
with nnvm.compiler.build_config(opt_level=opt_level):
    graph, lib, params = nnvm.compiler.build(nnvm_sym, target, shape_dict, params=nnvm_params)
lib.export_library("./deploy_lib.so")
print('lib exported successfully')
with open("./deploy_graph.json", "w") as fo:
    fo.write(graph.json())
with open("./deploy_param.params", "wb") as fo:
    fo.write(nnvm.compiler.save_param_dict(params))
The script generates the compiled library deploy_lib.so together with the optimized graph (deploy_graph.json) and the parameters (deploy_param.params).
There is no need to build TVM with LLVM to deploy a compiled model;
we only need to build the TVM runtime (with its Python bindings) on the target platform.
import numpy as np
import nnvm.compiler
import nnvm.testing
import tvm
from tvm.contrib import graph_runtime
import mxnet as mx
from mxnet import ndarray as nd
# Load the artifacts exported by the compile script and run a tiny CPU benchmark.
import time

ctx = tvm.cpu()
# Must match shape_dict used at compile time: NCHW, batch 1, 112x112.
data_shape = (1, 3, 112, 112)
# load the module back; `with` closes the files once they are read.
with open("./deploy_graph.json") as f:
    loaded_json = f.read()
loaded_lib = tvm.module.load("./deploy_lib.so")
with open("./deploy_param.params", "rb") as f:
    loaded_params = bytearray(f.read())
input_data = tvm.nd.array(np.random.uniform(size=data_shape).astype("float32"))
module = graph_runtime.create(loaded_json, loaded_lib, ctx)
module.load_params(loaded_params)
# Tiny benchmark test: print the wall time of each of 100 runs.
# perf_counter is monotonic and high-resolution, unlike time.time.
for i in range(100):
    t0 = time.perf_counter()
    module.run(data=input_data)
    print(time.perf_counter() - t0)
The C++ TVM runtime only needs tvm_runtime_pack.cc compiled together with the exported shared library. The following C++ demo shows how to use the TVM runtime:
#include <stdio.h>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <opencv2/opencv.hpp>
#include <tvm/runtime/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/packed_func.h>
class FR_MFN_Deploy{
private:
void * handle;
public:
FR_MFN_Deploy(std::string modelFolder)
{
tvm::runtime::Module mod_syslib = tvm::runtime::Module::LoadFromFile(modelFolder + "/deploy_lib.so");
//load graph
std::ifstream json_in(modelFolder + "/deploy_graph.json");
std::string json_data((std::istreambuf_iterator<char>(json_in)), std::istreambuf_iterator<char>());
json_in.close();
int device_type = kDLCPU;
int device_id = 0;
// get global function module for graph runtime
tvm::runtime::Module mod = (*tvm::runtime::Registry::Get("tvm.graph_runtime.create"))(json_data, mod_syslib, device_type, device_id);
this->handle = new tvm::runtime::Module(mod);
//load param
std::ifstream params_in(modelFolder + "/deploy_param.params", std::ios::binary);
std::string params_data((std::istreambuf_iterator<char>(params_in)), std::istreambuf_iterator<char>());
params_in.close();
TVMByteArray params_arr;
params_arr.data = params_data.c_str();
params_arr.size = params_data.length();
tvm::runtime::PackedFunc load_params = mod.GetFunction("load_params");
load_params(params_arr);
}
cv::Mat forward(cv::Mat inputImageAligned)
{
//mobilefacnet preprocess has been written in graph.
cv::Mat tensor = cv::dnn::blobFromImage(inputImageAligned,1.0,cv::Size(112,112),cv::Scalar(0,0,0),true);
//convert uint8 to float32 and convert to RGB via opencv dnn function
DLTensor* input;
constexpr int dtype_code = kDLFloat;
constexpr int dtype_bits = 32;
constexpr int dtype_lanes = 1;
constexpr int device_type = kDLCPU;
constexpr int device_id = 0;
constexpr int in_ndim = 4;
const int64_t in_shape[in_ndim] = {1, 3, 112, 112};
TVMArrayAlloc(in_shape, in_ndim, dtype_code, dtype_bits, dtype_lanes, device_type, device_id, &input);//
TVMArrayCopyFromBytes(input,tensor.data,112*3*112*4);
tvm::runtime::Module* mod = (tvm::runtime::Module*)handle;
tvm::runtime::PackedFunc set_input = mod->GetFunction("set_input");
set_input("data", input);
tvm::runtime::PackedFunc run = mod->GetFunction("run");
run();
tvm::runtime::PackedFunc get_output = mod->GetFunction("get_output");
tvm::runtime::NDArray res = get_output(0);
cv::Mat vector(128,1,CV_32F);
memcpy(vector.data,res->data,128*4);
cv::Mat _l2;
// normlize
cv::multiply(vector,vector,_l2);
float l2 = cv::sqrt(cv::sum(_l2).val[0]);
vector = vector / l2;
TVMArrayFree(input);
return vector;
}
};
Test with an aligned face pair:
// Load the aligned face pair. These absolute paths are specific to the
// author's machine; replace them with your own aligned face images.
// cv::imread returns an empty cv::Mat (it does not throw) if a file cannot be read.
cv::Mat A = cv::imread("/Users/yujinke/Desktop/align_id/aligned/20171231115821836_face.jpg");
cv::Mat B = cv::imread("/Users/yujinke/Desktop/align_id/aligned/20171231115821836_idcard.jpg");
// Loads deploy_lib.so, deploy_graph.json and deploy_param.params from ./models.
FR_MFN_Deploy deploy("./models");
// One L2-normalized embedding per face.
cv::Mat v2 = deploy.forward(B);
cv::Mat v1 = deploy.forward(A);
Measure the cosine similarity of the face pair (forward() L2-normalizes its embeddings, so their dot product is the cosine):
/// Cosine similarity between two embedding vectors of the same size and type.
///
/// For L2-normalized inputs (as produced by FR_MFN_Deploy::forward) this
/// equals their dot product; dividing by the norms keeps the result correct
/// for vectors that were not normalized. Returns 0 if either vector has zero
/// length. Despite the name, larger values mean more similar faces.
inline float CosineDistance(const cv::Mat &v1,const cv::Mat &v2){
    const double norms = cv::norm(v1) * cv::norm(v2);
    if (norms <= 0.0) {
        return 0.0f;
    }
    return static_cast<float>(v1.dot(v2) / norms);
}
// Print the cosine similarity of the pair; a higher value means the two faces
// are more alike. In real code this statement must be inside a function
// (e.g. main()); expression statements are not valid at namespace scope.
std::cout<<CosineDistance(v1,v2)<<std::endl;