forked from apache/tvm
-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add end-to-end SGX ResNet inference example (apache#388)
- Loading branch information
Showing
13 changed files
with
520 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
lib/ | ||
bin/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Makefile for example to deploy TVM modules in SGX. | ||
|
||
PYTHON ?= python | ||
|
||
NNVM_ROOT := $(shell cd ../../; pwd) | ||
TVM_ROOT := $(NNVM_ROOT)/tvm | ||
DMLC_CORE_ROOT := $(NNVM_ROOT)/dmlc-core | ||
|
||
SGX_SDK ?= /opt/sgxsdk | ||
SGX_MODE ?= SIM | ||
SGX_ARCH ?= x64 | ||
SGX_DEBUG ?= 1 | ||
|
||
sgx_edger8r := $(SGX_SDK)/bin/x64/sgx_edger8r | ||
sgx_enclave_signer := $(SGX_SDK)/bin/x64/sgx_sign | ||
|
||
ifneq ($(SGX_MODE), HW) | ||
sgx_sim := _sim | ||
endif | ||
urts_library_name := sgx_urts$(sgx_sim) | ||
trts_library_name := sgx_trts$(sgx_sim) | ||
tservice_library_name := sgx_tservice$(sgx_sim) | ||
uservice_library_name := sgx_uae_service$(sgx_sim) | ||
|
||
pkg_cflags := -std=c++11 -O2 -fPIC\ | ||
-I$(NNVM_ROOT)/include\ | ||
-I$(NNVM_ROOT)\ | ||
-I$(TVM_ROOT)/include\ | ||
-I$(TVM_ROOT)/dlpack/include\ | ||
-I$(DMLC_CORE_ROOT)/include\ | ||
-DDMLC_LOG_STACK_TRACE=0\ | ||
|
||
pkg_ldflags := -L$(TVM_ROOT)/lib | ||
|
||
enclave_include_paths := -I$(SGX_SDK)/include\ | ||
-I$(SGX_SDK)/include/tlibc\ | ||
-I$(SGX_SDK)/include/libcxx\ | ||
-I$(SGX_SDK)/include/stdc++\ | ||
|
||
enclave_cflags := -static -nostdinc\ | ||
-fvisibility=hidden -fpie -fstack-protector-strong\ | ||
-ffunction-sections -fdata-sections\ | ||
-DDMLC_CXX11_THREAD_LOCAL=0\ | ||
$(enclave_include_paths)\ | ||
|
||
enclave_cxxflags := -nostdinc++ $(enclave_cflags) | ||
|
||
enclave_ldflags :=\ | ||
-Wl,--no-undefined -nostdlib -nodefaultlibs -nostartfiles -L$(SGX_SDK)/lib64\ | ||
-Wl,--whole-archive -l$(trts_library_name) -Wl,--no-whole-archive\ | ||
-Wl,--start-group\ | ||
-lsgx_tstdc -lsgx_tstdcxx -lsgx_tcxx -lsgx_tcrypto -lsgx_tkey_exchange -l$(tservice_library_name)\ | ||
-Wl,--end-group\ | ||
-Wl,-Bstatic -Wl,-Bsymbolic -Wl,--no-undefined\ | ||
-Wl,-pie,-eenclave_entry -Wl,--export-dynamic\ | ||
-Wl,--defsym,__ImageBase=0 -Wl,--gc-sections | ||
|
||
app_cflags := -I$(SGX_SDK)/include -Ilib | ||
|
||
app_ldflags := -L$(SGX_SDK)/lib64\ | ||
-l$(urts_library_name) -l$(uservice_library_name) -lpthread\ | ||
|
||
.PHONY: clean all | ||
|
||
all: lib/model.signed.so bin/run_model | ||
|
||
# The code library built by TVM | ||
lib/deploy_%.o: build_model.py | ||
@mkdir -p $(@D) | ||
$(PYTHON) build_model.py | ||
|
||
# EDL files | ||
lib/model_%.c: model.edl $(sgx_edger8r) | ||
@mkdir -p $(@D) | ||
$(sgx_edger8r) $< --trusted-dir $(@D) --untrusted-dir $(@D) --search-path $(SGX_SDK)/include | ||
|
||
lib/model_%.o: lib/model_%.c | ||
$(CC) $(enclave_cflags) -c $< -o $@ | ||
|
||
# The enclave library | ||
lib/model.so: enclave.cc $(TVM_ROOT)/sgx/sgx_runtime.cc lib/model_t.o lib/deploy_lib.o | ||
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags) $(enclave_cxxflags) $(enclave_ldflags)\ | ||
-Wl,--format=binary -Wl,lib/deploy_graph.json -Wl,lib/deploy_params.bin -Wl,--format=default | ||
|
||
# The signed enclave | ||
lib/model.signed.so: lib/model.so enclave_config.xml | ||
$(sgx_enclave_signer) sign -key enclave_private.pem -enclave $< -out $@ -config enclave_config.xml | ||
|
||
# An app that runs the enclave | ||
bin/run_model: app.cc lib/model_u.o | ||
@mkdir -p $(@D) | ||
$(CXX) $^ -o $@ $(app_cflags) $(app_ldflags) | ||
|
||
# Debugging binary that runs TVM without SGX | ||
bin/run_model_nosgx: enclave.cc $(TVM_ROOT)/sgx/sgx_runtime.cc lib/deploy_lib.o | ||
@mkdir -p $(@D) | ||
$(CXX) $^ -o $@ $(pkg_cflags) $(pkg_ldflags)\ | ||
-Wl,--format=binary -Wl,lib/deploy_graph.json -Wl,lib/deploy_params.bin -Wl,--format=default | ||
|
||
|
||
clean: | ||
rm -rf lib bin |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
# TVM in Intel SGX Example | ||
|
||
This application demonstrates running a ResNet18 using NNVM inside of an | ||
[Intel SGX](https://software.intel.com/en-us/blogs/2013/09/26/protecting-application-secrets-with-intel-sgx) trusted computing environment. | ||
|
||
## Prerequisites | ||
|
||
1. A GNU/Linux environment | ||
2. NNVM, TVM compiled with LLVM, and their corresponding Python modules | ||
3. The [Linux SGX SDK](https://github.com/intel/linux-sgx) [link to pre-built libraries](https://01.org/intel-software-guard-extensions/downloads) | ||
4. `pip install --user mxnet pillow` | ||
|
||
## Running the example | ||
|
||
`SGX_SDK=/path/to/sgxsdk bash run_example.sh` | ||
|
||
If everything goes well, you should see a lot of build messages and below them | ||
the text `It's a tabby!`. | ||
|
||
## High-level overview | ||
|
||
First of all, it helps to think of an SGX enclave as a library that can be called | ||
to perform trusted computation. | ||
In this library, one can use other libraries like TVM. | ||
|
||
Building this example performs the following steps: | ||
|
||
1. Downloads a pre-trained MXNet ResNet and a | ||
[test image](https://github.com/BVLC/caffe/blob/master/examples/images/cat.jpg) | ||
2. Converts the ResNet to an NNVM graph + library | ||
3. Links the graph JSON definition, params, and runtime library into into an SGX | ||
enclave along with some code that performs inference. | ||
4. Compiles and runs an executable that loads the enclave and requests that it perform | ||
inference on the image. | ||
which invokes the TVM module. | ||
|
||
For more information on building, please refer to the `Makefile`. | ||
For more information on the TVM module, please refer to `../howto_deploy`. | ||
For more in formation on SGX enclaves, please refer to the [SGX Enclave Demo](https://github.com/intel/linux-sgx/tree/master/SampleCode/SampleEnclave/) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
#include <cstdio> | ||
#include <sstream> | ||
#include <fstream> | ||
#include <iostream> | ||
|
||
#include "sgx_urts.h" | ||
#include "sgx_eid.h" | ||
#include "model_u.h" | ||
|
||
#define TOKEN_FILENAME "bin/enclave.token" | ||
#define ENCLAVE_FILENAME "lib/model.signed.so" | ||
|
||
sgx_enclave_id_t global_eid = 0; // global EID shared by multiple threads | ||
|
||
typedef struct _sgx_errlist_t { | ||
sgx_status_t err; | ||
const char *msg; | ||
} sgx_errlist_t; | ||
|
||
/* Error code returned by sgx_create_enclave */ | ||
static sgx_errlist_t sgx_errlist[] = { | ||
{ SGX_ERROR_DEVICE_BUSY, "SGX device was busy." }, | ||
{ SGX_ERROR_ENCLAVE_FILE_ACCESS, "Can't open enclave file." }, | ||
{ SGX_ERROR_ENCLAVE_LOST, "Power transition occurred." }, | ||
{ SGX_ERROR_INVALID_ATTRIBUTE, "Enclave was not authorized." }, | ||
{ SGX_ERROR_INVALID_ENCLAVE, "Invalid enclave image." }, | ||
{ SGX_ERROR_INVALID_ENCLAVE_ID, "Invalid enclave identification." }, | ||
{ SGX_ERROR_INVALID_METADATA, "Invalid enclave metadata." }, | ||
{ SGX_ERROR_INVALID_PARAMETER, "Invalid parameter." }, | ||
{ SGX_ERROR_INVALID_SIGNATURE, "Invalid enclave signature." }, | ||
{ SGX_ERROR_INVALID_VERSION, "Enclave version was invalid." }, | ||
{ SGX_ERROR_MEMORY_MAP_CONFLICT, "Memory map conflicted." }, | ||
{ SGX_ERROR_NO_DEVICE, "Invalid SGX device." }, | ||
{ SGX_ERROR_OUT_OF_EPC, "Out of EPC memory." }, | ||
{ SGX_ERROR_OUT_OF_MEMORY, "Out of memory." }, | ||
{ SGX_ERROR_UNEXPECTED, "Unexpected error occurred." }, | ||
}; | ||
|
||
/* Check error conditions for loading enclave */ | ||
void print_error_message(sgx_status_t status) | ||
{ | ||
size_t idx = 0; | ||
size_t ttl = sizeof sgx_errlist/sizeof sgx_errlist[0]; | ||
|
||
for (idx = 0; idx < ttl; idx++) { | ||
if(status == sgx_errlist[idx].err) { | ||
printf("Error: %s\n", sgx_errlist[idx].msg); | ||
break; | ||
} | ||
} | ||
|
||
if (idx == ttl) | ||
printf("Error code is 0x%X. Please refer to the \"Intel SGX SDK Developer Reference\" for more details.\n", status); | ||
} | ||
|
||
/* Initialize the enclave: | ||
* Step 1: try to retrieve the launch token saved by last transaction | ||
* Step 2: call sgx_create_enclave to initialize an enclave instance | ||
* Step 3: save the launch token if it is updated | ||
*/ | ||
int initialize_enclave(void) | ||
{ | ||
sgx_launch_token_t token = {0}; | ||
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; | ||
int updated = 0; | ||
|
||
/* Step 1: try to retrieve the launch token saved by last transaction | ||
* if there is no token, then create a new one. | ||
*/ | ||
FILE *fp = fopen(TOKEN_FILENAME, "rb"); | ||
if (fp == NULL && (fp = fopen(TOKEN_FILENAME, "wb")) == NULL) { | ||
printf("Warning: Failed to create/open the launch token file \"%s\".\n", TOKEN_FILENAME); | ||
return -1; | ||
} | ||
|
||
/* read the token from saved file */ | ||
size_t read_num = fread(token, 1, sizeof(sgx_launch_token_t), fp); | ||
if (read_num != 0 && read_num != sizeof(sgx_launch_token_t)) { | ||
/* if token is invalid, clear the buffer */ | ||
memset(&token, 0x0, sizeof(sgx_launch_token_t)); | ||
printf("Warning: Invalid launch token read from \"%s\".\n", TOKEN_FILENAME); | ||
} | ||
|
||
/* Step 2: call sgx_create_enclave to initialize an enclave instance */ | ||
/* Debug Support: set 2nd parameter to 1 */ | ||
sgx_status = sgx_create_enclave(ENCLAVE_FILENAME, SGX_DEBUG_FLAG, &token, &updated, &global_eid, NULL); | ||
if (sgx_status != SGX_SUCCESS) { | ||
print_error_message(sgx_status); | ||
if (fp != NULL) fclose(fp); | ||
return -1; | ||
} | ||
|
||
/* Step 3: save the launch token if it is updated */ | ||
if (updated == 0 || fp == NULL) { | ||
/* if the token is not updated, or file handler is invalid, do not perform saving */ | ||
if (fp != NULL) fclose(fp); | ||
return 0; | ||
} | ||
|
||
/* reopen the file with write capablity */ | ||
fp = freopen(TOKEN_FILENAME, "wb", fp); | ||
if (fp == NULL) return 0; | ||
size_t write_num = fwrite(token, 1, sizeof(sgx_launch_token_t), fp); | ||
if (write_num != sizeof(sgx_launch_token_t)) | ||
printf("Warning: Failed to save launch token to \"%s\".\n", TOKEN_FILENAME); | ||
fclose(fp); | ||
return 0; | ||
} | ||
|
||
int SGX_CDECL main(int argc, char *argv[]) { | ||
if(initialize_enclave() < 0){ | ||
printf("Failed to initialize enclave.\n"); | ||
return -1; | ||
} | ||
|
||
std::ifstream f_img("bin/cat.bin", std::ios::binary); | ||
std::string img(static_cast<std::stringstream const&>( | ||
std::stringstream() << f_img.rdbuf()).str()); | ||
|
||
unsigned predicted_class; | ||
sgx_status_t sgx_status = SGX_ERROR_UNEXPECTED; | ||
sgx_status = ecall_infer(global_eid, &predicted_class, img.c_str()); | ||
if (sgx_status != SGX_SUCCESS) { | ||
print_error_message(sgx_status); | ||
} | ||
|
||
sgx_destroy_enclave(global_eid); | ||
if (predicted_class == 281) { | ||
std::cout << "It's a tabby!" << std::endl; | ||
return 0; | ||
} | ||
std::cerr << "Inference failed! Predicted class: " << | ||
predicted_class << std::endl; | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
"""Creates a neural network graph module, the system library, and params. | ||
Heavily inspired by tutorials/from_mxnet.py | ||
""" | ||
from __future__ import print_function | ||
import ast | ||
import os | ||
from os import path as osp | ||
import tempfile | ||
|
||
import mxnet as mx | ||
from mxnet.gluon.model_zoo.vision import get_model | ||
from mxnet.gluon.utils import download | ||
import nnvm | ||
import nnvm.compiler | ||
import numpy as np | ||
from PIL import Image | ||
import tvm | ||
|
||
|
||
EXAMPLE_ROOT = osp.abspath(osp.join(osp.dirname(__file__))) | ||
BIN_DIR = osp.join(EXAMPLE_ROOT, 'bin') | ||
LIB_DIR = osp.join(EXAMPLE_ROOT, 'lib') | ||
|
||
TVM_TARGET = 'llvm --system-lib' | ||
|
||
|
||
def _download_model_and_image(out_dir): | ||
mx_model = get_model('resnet18_v1', pretrained=True) | ||
|
||
img_path = osp.join(out_dir, 'cat.png') | ||
bin_img_path = osp.join(out_dir, 'cat.bin') | ||
download( | ||
'https://github.com/dmlc/mxnet.js/blob/master/data/cat.png?raw=true', | ||
img_path) | ||
img = Image.open(img_path).resize((224, 224)) | ||
img = _transform_image(img) | ||
img.astype('float32').tofile(bin_img_path) | ||
shape_dict = {'data': img.shape} | ||
|
||
return mx_model, shape_dict | ||
|
||
|
||
def _transform_image(image): | ||
image = np.array(image) - np.array([123., 117., 104.]) | ||
image /= np.array([58.395, 57.12, 57.375]) | ||
image = image.transpose((2, 0, 1)) | ||
image = image[np.newaxis, :] | ||
return image | ||
|
||
|
||
def main(): | ||
# load the model, input image, and imagenet classes | ||
mx_model, shape_dict = _download_model_and_image(BIN_DIR) | ||
|
||
# convert the model, add a softmax | ||
sym, params = nnvm.frontend.from_mxnet(mx_model) | ||
sym = nnvm.sym.softmax(sym) | ||
|
||
# build the graph | ||
graph, lib, params = nnvm.compiler.build( | ||
sym, TVM_TARGET, shape_dict, params=params) | ||
|
||
# save the built graph | ||
if not osp.isdir(LIB_DIR): | ||
os.mkdir(LIB_DIR) | ||
lib.save(osp.join(LIB_DIR, 'deploy_lib.o')) | ||
with open(osp.join(LIB_DIR, 'deploy_graph.json'), 'w') as f_graph_json: | ||
f_graph_json.write(graph.json()) | ||
with open(osp.join(LIB_DIR, 'deploy_params.bin'), 'wb') as f_params: | ||
f_params.write(nnvm.compiler.save_param_dict(params)) | ||
|
||
|
||
if __name__ == '__main__': | ||
main() |
Oops, something went wrong.