From 7121d65a63f3c5bd5ae286598dc2b1833ba90b96 Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Sun, 13 Feb 2022 17:26:06 +0800 Subject: [PATCH] Fix docs and add docstrings for `YOLOv5` (#315) * Fix code block for sphinx * Fix docstrings * Fix docs --- docs/source/models.rst | 49 +++++++++++++++++++++++++++++++++----- yolort/models/yolov5.py | 52 +++++++++++++++++++++++------------------ 2 files changed, 72 insertions(+), 29 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index f268aa8d..650a4f17 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -1,7 +1,44 @@ -Models and pre-trained weights -============================== +.. _models: -.. automodule:: yolort.models - :imported-members: - :members: - :undoc-members: +yolort.models +############# + +Models structure +================ + +The models expect a list of ``Tensor[C, H, W]``, in the range ``0-1``. +The models internally resize the images but the behaviour varies depending +on the model. Check the constructor of the models for more information. + +.. autofunction:: yolort.models.YOLOv5 + +Pre-trained weights +=================== + +The pre-trained models return the predictions of the following classes: + + .. 
code-block:: python + + COCO_INSTANCE_CATEGORY_NAMES = [ + 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', + 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign', + 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', + 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella', + 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', + 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', + 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', + 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', + 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'dining table', + 'toilet', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', + 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase', + 'scissors', 'teddy bear', 'hair drier', 'toothbrush' + ] + +.. autofunction:: yolort.models.yolov5n +.. autofunction:: yolort.models.yolov5n6 +.. autofunction:: yolort.models.yolov5s +.. autofunction:: yolort.models.yolov5s6 +.. autofunction:: yolort.models.yolov5m +.. autofunction:: yolort.models.yolov5m6 +.. autofunction:: yolort.models.yolov5l +.. autofunction:: yolort.models.yolov5ts diff --git a/yolort/models/yolov5.py b/yolort/models/yolov5.py index d084c230..33f7ee38 100644 --- a/yolort/models/yolov5.py +++ b/yolort/models/yolov5.py @@ -20,11 +20,36 @@ class YOLOv5(nn.Module): """ Wrapping the pre-processing (`LetterBox`) into the YOLO models. + The input to the model is expected to be a list of tensors, each of shape [C, H, W], one for each + image, and should be in 0-1 range. Different images can have different sizes but they will be resized + to a fixed size that maintains the aspect ratio before passing them to the backbone. + + The behavior of the model changes depending on whether it is in training or evaluation mode. 
+ + During training, the model expects both the input tensors, as well as targets (a list of dictionaries), + containing: + - boxes (``FloatTensor[N, 4]``): the ground-truth boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the class label for each ground-truth box + + The model returns a Dict[Tensor] during training, containing the classification and regression + losses. + + During inference, the model requires only the input tensors, and returns the post-processed + predictions as a List[Dict[Tensor]], one for each input image. The fields of the Dict are as + follows, where ``N`` is the number of detections: + + - boxes (``FloatTensor[N, 4]``): the predicted boxes in ``[x1, y1, x2, y2]`` format, with + ``0 <= x1 < x2 <= W`` and ``0 <= y1 < y2 <= H``. + - labels (Int64Tensor[N]): the predicted labels for each detection + - scores (Tensor[N]): the scores for each detection + Example: Demo pipeline for YOLOv5 Inference. .. code-block:: python + from yolort.models import YOLOv5 # Load the yolov5s version 6.0 models @@ -40,6 +65,7 @@ class YOLOv5(nn.Module): We also support loading the custom checkpoints trained from ultralytics/yolov5 .. code-block:: python + from yolort.models import YOLOv5 # Your trained checkpoint from ultralytics @@ -106,22 +132,12 @@ def __init__( # used only on torchscript mode self._has_warned = False - def _forward_impl( + def forward( self, inputs: List[Tensor], targets: Optional[List[Dict[str, Tensor]]] = None, ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: - """ - Args: - inputs (list[Tensor]): images to be processed - targets (list[Dict[Tensor]]): ground-truth boxes present in the image (optional) - Returns: - result (list[BoxList] or dict[Tensor]): the output from the model. - During training, it returns a dict[Tensor] which contains the losses. - During testing, it returns list[BoxList] contains additional fields - like `scores`, `labels` and `boxes`. 
- """ # get the original image sizes original_image_sizes: List[Tuple[int, int]] = [] @@ -178,21 +194,11 @@ def eager_outputs( return detections - def forward( - self, - inputs: List[Tensor], - targets: Optional[List[Dict[str, Tensor]]] = None, - ) -> Tuple[Dict[str, Tensor], List[Dict[str, Tensor]]]: - """ - This exists since PyTorchLightning forward are used for inference only (separate from - ``training_step``). We keep ``targets`` here for Backward Compatible. - """ - return self._forward_impl(inputs, targets) - @torch.no_grad() def predict(self, x: Any, image_loader: Optional[Callable] = None) -> List[Dict[str, Tensor]]: """ Predict function for raw data or processed data + Args: x: Input to predict. Can be raw data or processed data. image_loader: Utility function to convert raw data to Tensor. @@ -263,7 +269,7 @@ def load_from_yolov5( **kwargs: Any, ): """ - Load model state from the checkpoint trained by YOLOv5. + Load custom checkpoints trained from YOLOv5. Args: checkpoint_path (str): Path of the YOLOv5 checkpoint model.