diff --git a/README.md b/README.md
index 96bf8ba4a..db09c8778 100644
--- a/README.md
+++ b/README.md
@@ -85,6 +85,7 @@ There are also tutorials:
 * [learn the config](docs/en/tutorials/customize_config.md)
 * [customize dataset](docs/en/tutorials/customize_dataset.md)
 * [customize model](docs/en/tutorials/customize_models.md)
+* [useful tools](docs/en/tutorials/useful_tools.md)

 ## Model Zoo

diff --git a/docker/serve/Dockerfile b/docker/serve/Dockerfile
new file mode 100644
index 000000000..e8d202ea9
--- /dev/null
+++ b/docker/serve/Dockerfile
@@ -0,0 +1,53 @@
+ARG PYTORCH="1.6.0"
+ARG CUDA="10.1"
+ARG CUDNN="7"
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ARG MMCV="1.4.5"
+ARG MMDET="2.19.0"
+ARG MMROTATE="0.1.1"
+ARG TORCHSERVE="0.2.0"
+
+ENV PYTHONUNBUFFERED TRUE
+
+RUN apt-get update && \
+    DEBIAN_FRONTEND=noninteractive apt-get install --no-install-recommends -y \
+    ca-certificates \
+    g++ \
+    openjdk-11-jre-headless \
+    # MMDet Requirements
+    ffmpeg git ninja-build libglib2.0-0 libsm6 libxrender-dev libxext6 \
+    && rm -rf /var/lib/apt/lists/*
+
+ENV PATH="/opt/conda/bin:$PATH"
+ENV FORCE_CUDA="1"
+
+# TORCHSERVE
+# torchserve>0.2.0 is compatible with pytorch>=1.8.1
+RUN pip install torchserve==${TORCHSERVE} torch-model-archiver
+
+# MMLAB
+ARG PYTORCH
+ARG CUDA
+RUN ["/bin/bash", "-c", "pip install mmcv-full==${MMCV} -f https://download.openmmlab.com/mmcv/dist/cu${CUDA//./}/torch${PYTORCH}/index.html"]
+RUN pip install mmdet==${MMDET}
+RUN pip install mmrotate==${MMROTATE}
+
+RUN useradd -m model-server \
+    && mkdir -p /home/model-server/tmp
+
+COPY entrypoint.sh /usr/local/bin/entrypoint.sh
+
+RUN chmod +x /usr/local/bin/entrypoint.sh \
+    && chown -R model-server /home/model-server
+
+COPY config.properties /home/model-server/config.properties
+RUN mkdir /home/model-server/model-store && chown -R model-server /home/model-server/model-store
+
+EXPOSE 8080 8081 8082
+
+USER model-server
+WORKDIR /home/model-server
+ENV TEMP=/home/model-server/tmp
+ENTRYPOINT ["/usr/local/bin/entrypoint.sh"]
+CMD ["serve"]
diff --git a/docker/serve/config.properties b/docker/serve/config.properties
new file mode 100644
index 000000000..efb9c47e4
--- /dev/null
+++ b/docker/serve/config.properties
@@ -0,0 +1,5 @@
+inference_address=http://0.0.0.0:8080
+management_address=http://0.0.0.0:8081
+metrics_address=http://0.0.0.0:8082
+model_store=/home/model-server/model-store
+load_models=all
diff --git a/docker/serve/entrypoint.sh b/docker/serve/entrypoint.sh
new file mode 100644
index 000000000..41ba00b04
--- /dev/null
+++ b/docker/serve/entrypoint.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+set -e
+
+if [[ "$1" = "serve" ]]; then
+    shift 1
+    torchserve --start --ts-config /home/model-server/config.properties
+else
+    eval "$@"
+fi
+
+# prevent docker exit
+tail -f /dev/null
diff --git a/docs/en/useful_tools.md b/docs/en/useful_tools.md
index 2b5b2300b..765ae1ee1 100644
--- a/docs/en/useful_tools.md
+++ b/docs/en/useful_tools.md
@@ -60,6 +60,131 @@ Examples:
 python tools/misc/browse_dataset.py ${CONFIG} [-h] [--skip-type ${SKIP_TYPE[SKIP_TYPE...]}] [--output-dir ${OUTPUT_DIR}] [--not-show] [--show-interval ${SHOW_INTERVAL}]
 ```

+## Model Serving
+
+In order to serve an `MMRotate` model with [`TorchServe`](https://pytorch.org/serve/), you can follow these steps:
+
+### 1. Convert model from MMRotate to TorchServe
+
+```shell
+python tools/deployment/mmrotate2torchserve.py ${CONFIG_FILE} ${CHECKPOINT_FILE} \
+--output-folder ${MODEL_STORE} \
+--model-name ${MODEL_NAME}
+```
+
+Example:
+
+```shell
+wget -P checkpoint \
+https://download.openmmlab.com/mmrotate/v0.1.0/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth
+
+python tools/deployment/mmrotate2torchserve.py configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py checkpoint/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth \
+--output-folder ${MODEL_STORE} \
+--model-name rotated_faster_rcnn
+```
+
+**Note**: `${MODEL_STORE}` needs to be an absolute path to a folder.
+
+### 2. Build `mmrotate-serve` Docker image
+
+```shell
+docker build -t mmrotate-serve:latest docker/serve/
+```
+
+### 3. Run `mmrotate-serve`
+
+Check the official docs for [running TorchServe with Docker](https://github.com/pytorch/serve/blob/master/docker/README.md#running-torchserve-in-a-production-docker-environment).
+
+In order to run on GPU, you need to install [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html). You can omit the `--gpus` argument to run on CPU instead.
+
+Example:
+
+```shell
+docker run --rm \
+--cpus 8 \
+--gpus device=0 \
+-p8080:8080 -p8081:8081 -p8082:8082 \
+--mount type=bind,source=$MODEL_STORE,target=/home/model-server/model-store \
+mmrotate-serve:latest
+```
+
+[Read the docs](https://github.com/pytorch/serve/blob/072f5d088cce9bb64b2a18af065886c9b01b317b/docs/rest_api.md/) about the Inference (8080), Management (8081) and Metrics (8082) APIs.
+
+### 4. Test deployment
+
+```shell
+curl -O https://raw.githubusercontent.com/open-mmlab/mmrotate/main/demo/demo.jpg
+curl http://127.0.0.1:8080/predictions/${MODEL_NAME} -T demo.jpg
+```
+
+You should obtain a response similar to:
+
+```json
+[
+  {
+    "class_name": "small-vehicle",
+    "bbox": [
+      584.9473266601562,
+      327.2749938964844,
+      38.45665740966797,
+      16.898427963256836,
+      -0.7229751944541931
+    ],
+    "score": 0.9766026139259338
+  },
+  {
+    "class_name": "small-vehicle",
+    "bbox": [
+      152.0239715576172,
+      305.92572021484375,
+      43.144744873046875,
+      18.85024642944336,
+      0.014928221702575684
+    ],
+    "score": 0.972826361656189
+  },
+  {
+    "class_name": "large-vehicle",
+    "bbox": [
+      160.58056640625,
+      437.3690185546875,
+      55.6795654296875,
+      19.31710433959961,
+      0.007036328315734863
+    ],
+    "score": 0.888836681842804
+  },
+  {
+    "class_name": "large-vehicle",
+    "bbox": [
+      666.2868041992188,
+      1011.3961181640625,
+      60.396209716796875,
+      21.821645736694336,
+      0.8549195528030396
+    ],
+    "score": 0.8240180015563965
+  }
+]
+```
+
+You can also use `tools/deployment/test_torchserver.py` to compare the results of TorchServe and PyTorch, and to visualize them.
+
+```shell
+python tools/deployment/test_torchserver.py ${IMAGE_FILE} ${CONFIG_FILE} ${CHECKPOINT_FILE} ${MODEL_NAME}
+[--inference-addr ${INFERENCE_ADDR}] [--device ${DEVICE}] [--score-thr ${SCORE_THR}]
+```
+
+Example:
+
+```shell
+python tools/deployment/test_torchserver.py \
+demo/demo.jpg \
+configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py \
+checkpoint/rotated_faster_rcnn_r50_fpn_1x_dota_le90-0393aa5c.pth \
+rotated_faster_rcnn
+```
+
 ## Model Complexity

 `tools/analysis_tools/get_flops.py` is a script adapted from [flops-counter.pytorch](https://github.com/sovrasov/flops-counter.pytorch) to compute the FLOPs and params of a given model.
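For reference, a typical invocation of the FLOPs script mentioned above looks like the sketch below. This is an illustration rather than part of this diff: the config path reuses the rotated Faster R-CNN example from the Model Serving section, the 1024x1024 input is only a DOTA-sized placeholder, and the exact flags should be confirmed with `python tools/analysis_tools/get_flops.py -h`.

```shell
# Hypothetical usage sketch: report FLOPs and parameter counts for the
# rotated Faster R-CNN config at a 1024x1024 input (flags follow the usual
# OpenMMLab get_flops.py interface; verify with --help).
python tools/analysis_tools/get_flops.py \
    configs/rotated_faster_rcnn/rotated_faster_rcnn_r50_fpn_1x_dota_le90.py \
    --shape 1024 1024
```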
diff --git a/setup.cfg b/setup.cfg
index 6971d587e..5b91d475f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -3,7 +3,7 @@
 line_length = 79
 multi_line_output = 0
 known_standard_library = setuptools
 known_first_party = mmrotate
-known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,numpy,pytest,pytorch_sphinx_theme,terminaltables,torch,yaml
+known_third_party = PIL,cv2,e2cnn,matplotlib,mmcv,mmdet,numpy,pytest,pytorch_sphinx_theme,terminaltables,torch,ts,yaml
 no_lines_before = STDLIB,LOCALFOLDER
 default_section = THIRDPARTY
diff --git a/tools/deployment/mmrotate2torchserve.py b/tools/deployment/mmrotate2torchserve.py
new file mode 100644
index 000000000..8e6b7c8cb
--- /dev/null
+++ b/tools/deployment/mmrotate2torchserve.py
@@ -0,0 +1,110 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import ArgumentParser, Namespace
+from pathlib import Path
+from tempfile import TemporaryDirectory
+
+import mmcv
+
+try:
+    from model_archiver.model_packaging import package_model
+    from model_archiver.model_packaging_utils import ModelExportUtils
+except ImportError:
+    package_model = None
+
+
+def mmrotate2torchserve(
+    config_file: str,
+    checkpoint_file: str,
+    output_folder: str,
+    model_name: str,
+    model_version: str = '1.0',
+    force: bool = False,
+):
+    """Converts an MMRotate model (config + checkpoint) to TorchServe `.mar`.
+
+    Args:
+        config_file:
+            In MMRotate config format.
+            The contents vary for each task repository.
+        checkpoint_file:
+            In MMRotate checkpoint format.
+            The contents vary for each task repository.
+        output_folder:
+            Folder where `{model_name}.mar` will be created.
+            The file created will be in TorchServe archive format.
+        model_name:
+            If not None, used for naming the `{model_name}.mar` file
+            that will be created under `output_folder`.
+            If None, `{Path(checkpoint_file).stem}` will be used.
+        model_version:
+            Model's version.
+        force:
+            If True, an existing `{model_name}.mar` file under
+            `output_folder` will be overwritten.
+    """
+    mmcv.mkdir_or_exist(output_folder)
+
+    config = mmcv.Config.fromfile(config_file)
+
+    with TemporaryDirectory() as tmpdir:
+        config.dump(f'{tmpdir}/config.py')
+
+        args = Namespace(
+            **{
+                'model_file': f'{tmpdir}/config.py',
+                'serialized_file': checkpoint_file,
+                'handler': f'{Path(__file__).parent}/mmrotate_handler.py',
+                'model_name': model_name or Path(checkpoint_file).stem,
+                'version': model_version,
+                'export_path': output_folder,
+                'force': force,
+                'requirements_file': None,
+                'extra_files': None,
+                'runtime': 'python',
+                'archive_format': 'default'
+            })
+        manifest = ModelExportUtils.generate_manifest_json(args)
+        package_model(args, manifest)
+
+
+def parse_args():
+    parser = ArgumentParser(
+        description='Convert MMRotate models to TorchServe `.mar` format.')
+    parser.add_argument('config', type=str, help='config file path')
+    parser.add_argument('checkpoint', type=str, help='checkpoint file path')
+    parser.add_argument(
+        '--output-folder',
+        type=str,
+        required=True,
+        help='Folder where `{model_name}.mar` will be created.')
+    parser.add_argument(
+        '--model-name',
+        type=str,
+        default=None,
+        help='If not None, used for naming the `{model_name}.mar` '
+        'file that will be created under `output_folder`. '
+        'If None, `{Path(checkpoint_file).stem}` will be used.')
+    parser.add_argument(
+        '--model-version',
+        type=str,
+        default='1.0',
+        help='Number used for versioning.')
+    parser.add_argument(
+        '-f',
+        '--force',
+        action='store_true',
+        help='overwrite the existing `{model_name}.mar`')
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    if package_model is None:
+        raise ImportError('`torch-model-archiver` is required. '
+                          'Try: pip install torch-model-archiver')
+
+    mmrotate2torchserve(args.config, args.checkpoint, args.output_folder,
+                        args.model_name, args.model_version, args.force)
diff --git a/tools/deployment/mmrotate_handler.py b/tools/deployment/mmrotate_handler.py
new file mode 100644
index 000000000..928b9de21
--- /dev/null
+++ b/tools/deployment/mmrotate_handler.py
@@ -0,0 +1,107 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import base64
+import os
+
+import mmcv
+import torch
+from mmdet.apis import inference_detector, init_detector
+from ts.torch_handler.base_handler import BaseHandler
+
+import mmrotate  # noqa: F401
+
+
+class MMRotateHandler(BaseHandler):
+    """MMRotate handler to load TorchScript or eager mode [state_dict]
+    models."""
+    threshold = 0.5
+
+    def initialize(self, context):
+        """Load the model.pt file and initialize the MMRotate model object.
+
+        Args:
+            context (context): JSON object containing information
+                pertaining to the model artifact parameters.
+        """
+        properties = context.system_properties
+        self.map_location = 'cuda' if torch.cuda.is_available() else 'cpu'
+        self.device = torch.device(self.map_location + ':' +
+                                   str(properties.get('gpu_id')) if torch.cuda.
+                                   is_available() else self.map_location)
+        self.manifest = context.manifest
+
+        model_dir = properties.get('model_dir')
+        serialized_file = self.manifest['model']['serializedFile']
+        checkpoint = os.path.join(model_dir, serialized_file)
+        self.config_file = os.path.join(model_dir, 'config.py')
+
+        self.model = init_detector(self.config_file, checkpoint, self.device)
+        self.initialized = True
+
+    def preprocess(self, data):
+        """Convert the request input to a list of ndarrays.
+
+        Args:
+            data (list): List of the data from the request input.
+
+        Returns:
+            list[ndarray]: The list of ndarray data of the input.
+        """
+        images = []
+
+        for row in data:
+            image = row.get('data') or row.get('body')
+            if isinstance(image, str):
+                image = base64.b64decode(image)
+            image = mmcv.imfrombytes(image)
+            images.append(image)
+
+        return images
+
+    def inference(self, data, *args, **kwargs):
+        """Predict the results for the given input request.
+
+        Args:
+            data (list[ndarray]): The list of ndarrays which are ready to
+                be processed.
+
+        Returns:
+            list[Tensor]: The list of results from the inference.
+        """
+        results = inference_detector(self.model, data)
+        return results
+
+    def postprocess(self, data):
+        """Convert the inference output into a TorchServe-supported response.
+
+        Args:
+            data (list[Tensor]): The list of results received from the
+                predicted output of the model.
+
+        Returns:
+            list[dict]: The list of the predicted output that can be converted
+                to json format.
+ """ + output = [] + for image_index, image_result in enumerate(data): + output.append([]) + if isinstance(image_result, tuple): + bbox_result, segm_result = image_result + if isinstance(segm_result, tuple): + segm_result = segm_result[0] # ms rcnn + else: + bbox_result, segm_result = image_result, None + + for class_index, class_result in enumerate(bbox_result): + class_name = self.model.CLASSES[class_index] + for bbox in class_result: + bbox_coords = bbox[:-1].tolist() + score = float(bbox[-1]) + if score >= self.threshold: + output[image_index].append({ + 'class_name': class_name, + 'bbox': bbox_coords, + 'score': score + }) + + return output