Skip to content

Commit

Permalink
Merge pull request #105 from tnc-ca-geo/speedup-inference
Browse files Browse the repository at this point in the history
speeding up yolov5 megadetector inference
  • Loading branch information
nathanielrindlaub authored Apr 7, 2023
2 parents d8c7363 + 06e2b50 commit 4b5a244
Show file tree
Hide file tree
Showing 9 changed files with 2,260 additions and 932 deletions.
9 changes: 5 additions & 4 deletions api/megadetectorv5/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,17 @@ RUN ls -la /home/venv/bin/pip
USER root
# RUN pip install --upgrade pip && pip install opencv-python ipython
# commit id https://github.com/ultralytics/yolov5/blob/9286336cb49d577873b2113739788bbe3b90f83c/requirements.txt
RUN pip install gitpython ipython matplotlib>=3.2.2 numpy==1.23.4 opencv-python==4.6.0.66 \
Pillow==9.2.0 psutil PyYAML>=5.3.1 requests>=2.23.0 scipy==1.9.3 thop>=0.1.1 \
torch==1.10.0 torchvision==0.11.1 tqdm>=4.64.0 tensorboard>=2.4.1 pandas>=1.1.4 \
seaborn>=0.11.0 setuptools>=65.5.1
RUN pip install "gitpython" "ipython" "matplotlib>=3.2.2" "numpy==1.23.4" "opencv-python==4.6.0.66" \
"Pillow==9.2.0" "psutil" "PyYAML>=5.3.1" "requests>=2.23.0" "scipy==1.9.3" "thop>=0.1.1" \
"torch==1.10.0" "torchvision==0.11.1" "tqdm>=4.64.0" "tensorboard>=2.4.1" "pandas>=1.1.4" \
"seaborn>=0.11.0" "setuptools>=65.5.1" "onnxruntime==1.14.1" "onnx==1.13.1"
COPY ./deployment/dockerd-entrypoint.sh /usr/local/bin/dockerd-entrypoint.sh
RUN chmod +x /usr/local/bin/dockerd-entrypoint.sh
RUN mkdir -p /home/model-server/ && mkdir -p /home/model-server/tmp
COPY ./deployment/config.properties /home/model-server/config.properties
WORKDIR /home/model-server
ENV TEMP=/home/model-server/tmp
ENV YOLOv5_AUTOINSTALL=False
ENV ENABLE_TORCH_PROFILER=TRUE
ENTRYPOINT ["/usr/local/bin/dockerd-entrypoint.sh"]
CMD ["serve"]
89 changes: 27 additions & 62 deletions api/megadetectorv5/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,83 +14,60 @@ In order to create and deploy the Megadetector model archive from scratch, we ne
3. create a serverless endpoint configuration
4. deploy and test a serverless endpoint

## Download weights and torchscript model
## Download weights and recreate ONNX model file

From this directory, run:
```
aws s3 sync s3://animl-model-zoo/mdv5-weights-models/ model-weights
```

If you want to quickly run model inference outside of the deployment environment without post processing steps, you can use the torchscript model. We'll use the model weights, 'md_v5a.0.0.pt', to create a .mar archive.
If you want to quickly run model inference outside of the deployment environment without post processing steps, you can use the ONNX model file.

## Download yolov5 source for model archiving and run model archiver
We'll use the model weights, 'md_v5a.0.0.pt', to recreate the ONNX model file and package the ONNX model file in a .mar archive. Or, you can skip this step and use the ONNx model file in the model-weights directory. Torchserve has native support for ONNX and other compiled model formats: https://github.com/pytorch/serve/blob/master/docs/performance_guide.md

Before creating our model archive, we need to download the full yolov5 source code we use to load the model weights. We'll use torch hub for this. Open ipython and run
You can use this python environment, which mirrors the Dockerfile dependencies and cna be used to run the testing notebooks in this directory (not the deploy notebook):

```python
import torch
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', device='cpu')
```

you should see

```python
Downloading: "https://github.com/ultralytics/yolov5/archive/master.zip" to /root/.cache/torch/hub/master.zip
requirements: YOLOv5 requirements "gitpython>=3.1.30" "setuptools>=65.5.1" not found, attempting AutoUpdate...
WARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv
Collecting gitpython>=3.1.30
Using cached GitPython-3.1.31-py3-none-any.whl (184 kB)
Collecting setuptools>=65.5.1
Using cached setuptools-67.4.0-py3-none-any.whl (1.1 MB)
Requirement already satisfied: gitdb<5,>=4.0.1 in /root/miniconda3/lib/python3.9/site-packages (from gitpython>=3.1.30) (4.0.10)
Requirement already satisfied: smmap<6,>=3.0.1 in /root/miniconda3/lib/python3.9/site-packages (from gitdb<5,>=4.0.1->gitpython>=3.1.30) (5.0.0)
Installing collected packages: setuptools, gitpython
Attempting uninstall: setuptools
Found existing installation: setuptools 61.2.0
Uninstalling setuptools-61.2.0:
Successfully uninstalled setuptools-61.2.0
Attempting uninstall: gitpython
Found existing installation: GitPython 3.1.29
Uninstalling GitPython-3.1.29:
Successfully uninstalled GitPython-3.1.29
Successfully installed gitpython-3.1.31 setuptools-67.4.0

requirements: 2 packages updated per /root/.cache/torch/hub/ultralytics_yolov5_master/requirements.txt
requirements: ⚠️ Restart runtime or rerun command for updates to take effect
`conda create -n mdv5a python=3.9`

YOLOv5 🚀 2023-3-3 Python-3.9.13 torch-1.10.0+cu102 CPU

Fusing layers...
YOLOv5s summary: 213 layers, 7225885 parameters, 0 gradients, 16.4 GFLOPs
Adding AutoShape...
```
conda activate mdv5a
and if you run the torch hub load command again, you will see where the source has been downloaded
pip install "gitpython" "ipython" "matplotlib>=3.2.2" "numpy==1.23.4" "opencv-python==4.6.0.66" \
"Pillow==9.2.0" "psutil" "PyYAML>=5.3.1" "requests>=2.23.0" "scipy==1.9.3" "thop>=0.1.1" \
"torch==1.10.0" "torchvision==0.11.1" "tqdm>=4.64.0" "tensorboard>=2.4.1" "pandas>=1.1.4" \
"seaborn>=0.11.0" "setuptools>=65.5.1" "onnxruntime==1.14.1" "onnx==1.13.1", "torch-model-archiver", "httpx"
```
then, run the export step

```
Using cache found in /root/.cache/torch/hub/ultralytics_yolov5_master
python yolov5/export.py --imgsz '(960,1280)' --weights model-weights/md_v5a.0.0.pt --include onnx
mv model-weights/md_v5a.0.0.onnx model-weights/md_v5a.0.0.960.1280.onnx
```

You path to the .cache will likely differ! Edit it in the following step.
Note, we are using the yolov5 source when compiling the model for deployment. This is [yolov5 commit hash 5c91da](https://github.com/ultralytics/yolov5/tree/5c91daeaecaeca709b8b6d13bd571d068fdbd003)


this will create models/megadetectorv5/md_v5a.0.0.onnx and move it to the correct directory for local testing. It will expect a fixed image size input of 960 x 1280 and one image at a time (batch size of 1). Large images will be scaled down to fit this size and padded to preserve aspect ratio. Small images will not be scale dup, and will instead be padded to preserve aspect ratio and not change resolution.

We'll use the yolov5 source when archiving the model for deployment. Currently we are using [yolov5 commit hash 5c91da](https://github.com/ultralytics/yolov5/tree/5c91daeaecaeca709b8b6d13bd571d068fdbd003)

# Creating the model archive

`pip install torch-model-archiver` then,

```
torch-model-archiver --model-name mdv5 --version 1.0.0 --serialized-file model-weights/md_v5a.0.0.pt --extra-files index_to_name.json --extra-files /root/.cache/torch/hub/ultralytics_yolov5_master/ --handler mdv5_handler.py
mv mdv5.mar model_store/mdv5a.mar
torch-model-archiver --model-name mdv5a --version 1.0.0 --serialized-file model-weights/md_v5a.0.0.960.1280.onnx --extra-files index_to_name.json --handler mdv5_handler.py
mv mdv5a.mar model_store/mdv5a.mar
```

The .mar file is what is served by torchserve on the serverless endpoint and includes the handler code that processes image requests, the model weights defining what the Megadetector model has learned, and the model structure defined by the yolov5 code.
The .mar file is what is served by torchserve on the serverless endpoint and includes the handler code that processes image requests and the ONNX model file that has traced and compiled the model weights defining what the Megadetector model has learned and the model structure defined by the yolov5 code.

We can locally test this model prior to deploying.

## Locally build and serve the torchscript model with torchserve
## Locally build and serve the ONNX model with torchserve

```
docker build -t torchserve-mdv5a:0.5.3-cpu .
bash docker_mdv5.sh model_store
bash docker_mdv5.sh $(pwd)/model_store
```

## Return prediction in normalized coordinates with category integer and confidence score
Expand All @@ -106,8 +83,9 @@ However, to test the endpoint that is queried during production, test the sagema
curl http://127.0.0.1:8080/invocations -T ../../input/sample-img-fox.jpg
```

Note: In the past we attempted to adapt the Dockerfile to address an issue with the libjpeg version. We used conda to install dependencies, including torchserve, because conda installs the version of libjpeg that was used to train and test Megadetector originally. See this issue for more detail https://github.com/pytorch/serve/issues/2054. We [reverted this change](https://github.com/tnc-ca-geo/animl-ml/pull/98/commits/b2bbff5316fbb15023025b2373dcdc9354dd26a7) because installing from conda ballooned the image size above the 10Gb limit set by Sagemaker Serverless. The results are virtually equivalent with the different libjpeg version.
Note: In the past we attempted to adapt the Dockerfile to address an issue with the libjpeg version. We used conda to install dependencies, including torchserve, because conda installs the version of libjpeg that was used to train and test Megadetector originally. See this issue for more detail https://github.com/pytorch/serve/issues/2054. We [reverted this change](https://github.com/tnc-ca-geo/animl-ml/pull/98/commits/b2bbff5316fbb15023025b2373dcdc9354dd26a7) because installing from conda ballooned the image size above the 10Gb limit set by Sagemaker Serverless. The confidence results are virtually equivalent with the different libjpeg version. See the local_ts_inf_compare.ipynb at the bottom for that exploration.

Also, see debug_single_img_inference.ipynb for a notebook walktrough of single image inference and plotting bbox results. This also shows querying the container from the notebook and plotting the result.

# Deploying the model to a Sagemaker Serverless Endpoint

Expand All @@ -132,16 +110,3 @@ docker push 830244800171.dkr.ecr.us-west-2.amazonaws.com/torchserve-mdv5-sagemak
```

then open the jupyter notebook titled mdv5_deploy.ipynb from a Sagemaker Notebook instance. You can also run this deploy notebook locally but would need to set up dependencies so the notebook instance is recommended.


## Sidenote, exporting yolov5 weights as torchscript model

First, clone and install yolov5 dependencies and yolov5 following these instructions: https://docs.ultralytics.com/tutorials/torchscript-onnx-coreml-export/

Then, if running locally, make sure to install the correct version of torch and torchvision, the same versions used to save the torchscript megadetector model, we need to use these to load the torchscript model. Check the Dockerfile for versions.

Size needs to be same as in mdv5_handler.py for good performance. Run this from this directory
```
python ../../../yolov5/export.py --weights model-weights/md_v5a.0.0.pt --img 1280 1280 --batch 1
```
this will create models/megadetectorv5/md_v5a.0.0.torchscript , which will expect a fixed image size input of 1280 x 1280 and a one image at a time (batch size of 1).
12 changes: 12 additions & 0 deletions api/megadetectorv5/create_onnx_for_compare.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
python yolov5/export.py --imgsz '(960,1280)' --weights model-weights/md_v5a.0.0.pt --include torchscript onnx
mv model-weights/md_v5a.0.0.onnx model-weights/md_v5a.0.0.960.1280.onnx
mv model-weights/md_v5a.0.0.torchscript model-weights/md_v5a.0.0.960.1280.torchscript
python yolov5/export.py --imgsz '(1280,1280)' --weights model-weights/md_v5a.0.0.pt --include torchscript onnx
mv model-weights/md_v5a.0.0.onnx model-weights/md_v5a.0.0.1280.1280.onnx
mv model-weights/md_v5a.0.0.torchscript model-weights/md_v5a.0.0.1280.1280.torchscript
python yolov5/export.py --imgsz '(642,856)' --weights model-weights/md_v5a.0.0.pt --include torchscript onnx
mv model-weights/md_v5a.0.0.onnx model-weights/md_v5a.0.0.642.856.onnx
mv model-weights/md_v5a.0.0.torchscript model-weights/md_v5a.0.0.642.856.torchscript
python yolov5/export.py --imgsz '(642,642)' --weights model-weights/md_v5a.0.0.pt --include torchscript onnx
mv model-weights/md_v5a.0.0.onnx model-weights/md_v5a.0.0.642.642.onnx
mv model-weights/md_v5a.0.0.torchscript model-weights/md_v5a.0.0.642.642.torchscript
Loading

0 comments on commit 4b5a244

Please sign in to comment.