move preprocessing to server side as a backend, and add http client (#…
CoinCheung authored Aug 31, 2022
1 parent befab8b commit 12c34e1
Showing 18 changed files with 1,840 additions and 184 deletions.
44 changes: 36 additions & 8 deletions tis/README.md
@@ -23,13 +23,38 @@ $ python tools/export_onnx.py --config configs/bisenetv2_city.py --weight-path /
$ cp -riv ./model.onnx tis/models/bisenetv2/1
```

#### 2. start service
We start serving with docker:
#### 2. prepare the preprocessing backend
We can use either the python backend or the cpp backend for preprocessing on the server side.
First, we pull the docker image and start a serving container:
```
$ docker pull nvcr.io/nvidia/tritonserver:21.10-py3
$ docker run --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models nvcr.io/nvidia/tritonserver:21.10-py3 tritonserver --model-repository=/models
$ docker pull nvcr.io/nvidia/tritonserver:22.07-py3
$ docker run -it --gpus all --rm -p8000:8000 -p8001:8001 -p8002:8002 -v /path/to/BiSeNet/tis/models:/models -v /path/to/BiSeNet/:/BiSeNet nvcr.io/nvidia/tritonserver:22.07-py3 bash
```
From here on, we are in the container environment. Let's prepare the backends in the container:
```
# ln -s /usr/local/bin/pip3.8 /usr/bin/pip3.8
# /usr/bin/python3 -m pip install pillow
# apt update && apt install rapidjson-dev libopencv-dev
```
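The python backend is implemented as a `model.py` exposing a `TritonPythonModel` class (the pillow package installed above is used by it). Below is a minimal sketch of what such a preprocessing model could look like; the tensor names follow `tis/client_backend.py`, while the target size and exact preprocessing steps are assumptions rather than the repository's actual implementation:
```
import io

import numpy as np
from PIL import Image
import triton_python_backend_utils as pb_utils


class TritonPythonModel:

    def execute(self, requests):
        responses = []
        for request in requests:
            # tensor names as used by tis/client_backend.py
            raw = pb_utils.get_input_tensor_by_name(request, 'raw_img_bytes')
            mean = pb_utils.get_input_tensor_by_name(request, 'channel_mean').as_numpy()
            std = pb_utils.get_input_tensor_by_name(request, 'channel_std').as_numpy()

            # decode the raw file bytes; the target size is an assumption
            im = Image.open(io.BytesIO(raw.as_numpy().tobytes())).convert('RGB')
            im = np.asarray(im.resize((2048, 1024)), dtype=np.float32) / 255.

            # normalize with the per-channel mean/std, then HWC -> NCHW
            im = ((im - mean) / std).transpose(2, 0, 1)[None, ...]

            out = pb_utils.Tensor('processed_img', im.astype(np.float32))
            responses.append(pb_utils.InferenceResponse(output_tensors=[out]))
        return responses
```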
Then we download cmake 3.22 and unzip it in the container; we use this cmake 3.22 in the following operations.
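For example (the exact release URL is an assumption; any recent cmake binary release works the same way):
```
# cd /tmp
# wget https://github.com/Kitware/CMake/releases/download/v3.22.0/cmake-3.22.0-linux-x86_64.tar.gz
# tar -xzf cmake-3.22.0-linux-x86_64.tar.gz
# export PATH=/tmp/cmake-3.22.0-linux-x86_64/bin:$PATH
```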
We compile c++ backends:
```
# cp -riv /BiSeNet/tis/self_backend /opt/tritonserver/backends
# chmod 777 /opt/tritonserver/backends/self_backend
# cd /opt/tritonserver/backends/self_backend
# mkdir -p build && cd build
# cmake .. && make -j4
# mv -iuv libtriton_self_backend.so ..
```
By now, we should have both backends prepared.



#### 3. start service
With the above steps done, we start the server inside the docker container:
```
# tritonserver --model-repository=/models
```
The service should now start. You can check whether it has started by:
```
$ curl -v localhost:8000/v2/health/ready
@@ -38,10 +63,12 @@ $ curl -v localhost:8000/v2/health/ready
By default, we use gpu 0 and gpu 1; you can change this configuration in the `config.pbtxt` file.
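For example, which gpus a model runs on is set through the `instance_group` stanza of its `config.pbtxt`; a sketch (adapt the count and gpu ids to your own setup) would be:
```
instance_group [
    {
        count: 1
        kind: KIND_GPU
        gpus: [ 0, 1 ]
    }
]
```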


### Client
### Request with a client

We can call the model service with both python and c++ methods.

From here on, we work on the client machine rather than in the server docker container.


#### 1. python method

@@ -50,10 +77,11 @@ First, we need to install the dependency package:
$ python -m pip install tritonclient[all]==2.15.0
```

Then we can run the script:
Then we can run the scripts for either http or grpc requests:
```
$ cd BiSeNet/tis
$ python client.py
$ python client_http.py # if you want to use http client
$ python client_grpc.py # if you want to use grpc client
```

This would generate a result file named `res.png` in the `BiSeNet/tis` directory.
@@ -92,4 +120,4 @@ Finally, we run the client and see a result file named `res.jpg` generated:

### In the end

This is a simple demo with only basic function. There are many other features that is useful, such as shared memory and model pipeline. If you have interest on this, you can learn more in the official document.
This is a simple demo with only basic functions. There are many other useful features, such as shared memory and dynamic batching. If you are interested in this, you can learn more from the official documentation.
81 changes: 81 additions & 0 deletions tis/client_backend.py
@@ -0,0 +1,81 @@


import argparse
import sys
import numpy as np
import cv2
import gevent.ssl

import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException


np.random.seed(123)
palette = np.random.randint(0, 256, (100, 3))


url = '10.128.61.8:8000'
# url = '127.0.0.1:8000'
model_name = 'preprocess_cpp'
model_version = '1'
inp_name = 'raw_img_bytes'
outp_name = 'processed_img'
inp_dtype = 'UINT8'
impth = '../example.png'
mean = [0.3257, 0.3690, 0.3223] # city, rgb
std = [0.2112, 0.2148, 0.2115]


## prepare image and mean/std
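# the image is sent as its raw encoded file bytes; decoding and normalization happen in the server-side preprocessing backend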
inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...]
mean = np.array(mean, dtype=np.float32)[None, ...]
std = np.array(std, dtype=np.float32)[None, ...]
inputs = []
inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype))
inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32'))
inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32'))
inputs[0].set_data_from_numpy(inp_data, binary_data=True)
inputs[1].set_data_from_numpy(mean, binary_data=True)
inputs[2].set_data_from_numpy(std, binary_data=True)

## client
triton_client = httpclient.InferenceServerClient(
        url=url, verbose=False, concurrency=32)

## infer
# sync
# results = triton_client.infer(model_name, inputs)


# async
# results = triton_client.async_infer(
# model_name,
# inputs,
# outputs=None,
# query_params=None,
# headers=None,
# request_compression_algorithm=None,
# response_compression_algorithm=None)
# results = results.get_result() # async infer only


## dynamic batching is not possible here, since different pictures have different raw sizes;
## we simply send several concurrent async requests
results = []
for i in range(10):
    r = triton_client.async_infer(
        model_name,
        inputs,
        outputs=None,
        query_params=None,
        headers=None,
        request_compression_algorithm=None,
        response_compression_algorithm=None)
    results.append(r)
# wait for all async requests to finish, then keep the last response
results = [r.get_result() for r in results]
results = results[-1]


# get output
outp = results.as_numpy(outp_name).squeeze()
print(outp.shape)
60 changes: 30 additions & 30 deletions tis/client.py → tis/client_grpc.py
@@ -13,21 +13,36 @@



# url = '10.128.61.7:8001'
url = '127.0.0.1:8001'
model_name = 'bisenetv2'
url = '10.128.61.8:8001'
# url = '127.0.0.1:8001'
model_name = 'bisenetv1'
model_version = '1'
inp_name = 'input_image'
inp_name = 'raw_img_bytes'
outp_name = 'preds'
inp_dtype = 'FP32'
inp_dtype = 'UINT8'
outp_dtype = np.int64
inp_shape = [1, 3, 1024, 2048]
outp_shape = [1024, 2048]
impth = '../example.png'
mean = [0.3257, 0.3690, 0.3223] # city, rgb
std = [0.2112, 0.2148, 0.2115]


## input data and mean/std
inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...]
mean = np.array(mean, dtype=np.float32)[None, ...]
std = np.array(std, dtype=np.float32)[None, ...]
inputs = [service_pb2.ModelInferRequest().InferInputTensor() for _ in range(3)]
inputs[0].name = inp_name
inputs[0].datatype = inp_dtype
inputs[0].shape.extend(inp_data.shape)
inputs[1].name = 'channel_mean'
inputs[1].datatype = 'FP32'
inputs[1].shape.extend(mean.shape)
inputs[2].name = 'channel_std'
inputs[2].datatype = 'FP32'
inputs[2].shape.extend(std.shape)
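# each input tensor's raw bytes are sent via raw_input_contents, in the same order as `inputs`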
inp_bytes = [inp_data.tobytes(), mean.tobytes(), std.tobytes()]


option = [
    ('grpc.max_receive_message_length', 1073741824),
    ('grpc.max_send_message_length', 1073741824),
@@ -52,37 +67,22 @@
request.model_name = model_name
request.model_version = model_version

inp = service_pb2.ModelInferRequest().InferInputTensor()
inp.name = inp_name
inp.datatype = inp_dtype
inp.shape.extend(inp_shape)


mean = np.array(mean).reshape(1, 1, 3)
std = np.array(std).reshape(1, 1, 3)
im = cv2.imread(impth)[:, :, ::-1]
im = cv2.resize(im, dsize=tuple(inp_shape[-1:-3:-1]))
im = ((im / 255.) - mean) / std
im = im[None, ...].transpose(0, 3, 1, 2)
inp_bytes = im.astype(np.float32).tobytes()

request.ClearField("inputs")
request.ClearField("raw_input_contents")
request.inputs.extend([inp,])
request.raw_input_contents.extend([inp_bytes,])

request.inputs.extend(inputs)
request.raw_input_contents.extend(inp_bytes)

outp = service_pb2.ModelInferRequest().InferRequestedOutputTensor()
outp.name = outp_name
request.outputs.extend([outp,])

# sync
# resp = grpc_stub.ModelInfer(request).raw_output_contents[0]
# resp = grpc_stub.ModelInfer(request)
# async
resp = grpc_stub.ModelInfer.future(request)
resp = resp.result().raw_output_contents[0]
resp = resp.result()

outp_bytes = resp.raw_output_contents[0]
outp_shape = resp.outputs[0].shape

out = np.frombuffer(resp, dtype=outp_dtype).reshape(*outp_shape)
out = np.frombuffer(outp_bytes, dtype=outp_dtype).reshape(*outp_shape).squeeze()

out = palette[out]
cv2.imwrite('res.png', out)
64 changes: 64 additions & 0 deletions tis/client_http.py
@@ -0,0 +1,64 @@


import argparse
import sys
import numpy as np
import cv2
import gevent.ssl

import tritonclient.http as httpclient
from tritonclient.utils import InferenceServerException


np.random.seed(123)
palette = np.random.randint(0, 256, (100, 3))


url = '10.128.61.8:8000'
# url = '127.0.0.1:8000'
model_name = 'bisenetv2'
model_version = '1'
inp_name = 'raw_img_bytes'
outp_name = 'preds'
inp_dtype = 'UINT8'
impth = '../example.png'
mean = [0.3257, 0.3690, 0.3223] # city, rgb
std = [0.2112, 0.2148, 0.2115]


## prepare image and mean/std
inp_data = np.fromfile(impth, dtype=np.uint8)[None, ...]
mean = np.array(mean, dtype=np.float32)[None, ...]
std = np.array(std, dtype=np.float32)[None, ...]
inputs = []
inputs.append(httpclient.InferInput(inp_name, inp_data.shape, inp_dtype))
inputs.append(httpclient.InferInput('channel_mean', mean.shape, 'FP32'))
inputs.append(httpclient.InferInput('channel_std', std.shape, 'FP32'))
inputs[0].set_data_from_numpy(inp_data, binary_data=True)
inputs[1].set_data_from_numpy(mean, binary_data=True)
inputs[2].set_data_from_numpy(std, binary_data=True)


## client
triton_client = httpclient.InferenceServerClient(
        url=url, verbose=False, concurrency=32)

## infer
# sync
# results = triton_client.infer(model_name, inputs)

# async
results = triton_client.async_infer(
        model_name,
        inputs,
        outputs=None,
        query_params=None,
        headers=None,
        request_compression_algorithm=None,
        response_compression_algorithm=None)
results = results.get_result() # async infer only

# get output
outp = results.as_numpy(outp_name).squeeze()
out = palette[outp]
cv2.imwrite('res.png', out)
2 changes: 1 addition & 1 deletion tis/cpp_client/CMakeLists.txt
@@ -2,7 +2,7 @@ cmake_minimum_required (VERSION 3.18)

project(Samples)

set(CMAKE_CXX_FLAGS "-std=c++14 -O1")
set(CMAKE_CXX_FLAGS "-std=c++14 -O2")
set(CMAKE_BUILD_TYPE Release)

set(CMAKE_PREFIX_PATH