This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Fbnet benchmark (#507)
* Added a timer to benchmark model inference time in addition to total runtime.

* Updated FBNet configs and included some baseline benchmark results.

* Added a unit test for detectors.

* Add links to the models
newstzpz authored and fmassa committed Mar 7, 2019
1 parent fd20472 commit 464b1af
Showing 14 changed files with 475 additions and 23 deletions.
21 changes: 21 additions & 0 deletions MODEL_ZOO.md
@@ -33,7 +33,28 @@ backbone | type | lr sched | im / gpu | train mem(GB) | train time (s/iter) | to
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
R-50-FPN | Keypoint | 1x | 2 | 5.7 | 0.3771 | 9.4 | 0.10941 | 53.7 | 64.3 | 9981060

### Light-weight Model baselines

We provide pre-trained models for selected FBNet architectures.
* All models are trained from scratch with BN, using the training schedule specified below.
* Evaluation is performed on a single NVIDIA V100 GPU with `MODEL.RPN.POST_NMS_TOP_N_TEST` set to `200`.
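
As a concrete illustration (not part of this PR), such an evaluation-time override can be applied on top of any of the configs below through the repository's yacs-based `cfg`; the config path and batch-size value here are placeholders:

```python
from maskrcnn_benchmark.config import cfg

# load one of the benchmark configs and apply the evaluation-time overrides
# described above; the reference R-50-C4 config otherwise keeps its own
# POST_NMS_TOP_N_TEST value
cfg.merge_from_file("configs/e2e_faster_rcnn_R_50_C4_1x.yaml")
cfg.merge_from_list(["MODEL.RPN.POST_NMS_TOP_N_TEST", 200, "TEST.IMS_PER_BATCH", 8])
cfg.freeze()
```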

The following inference times are reported:
* inference total batch=8: total inference time, including data loading, model inference, and pre/post-processing, using 8 images per batch.
* inference model batch=8: model inference time only, using 8 images per batch.
* inference model batch=1: model inference time only, using 1 image per batch.
* inference caffe2 batch=1: model inference time for the model in Caffe2 format, using 1 image per batch. The Caffe2 models have BN fused into Conv and run purely in C++/CUDA, using Caffe2 ops for RPN/detection post-processing.
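
The BN fusion mentioned in the last bullet folds the batch-norm statistics into the preceding convolution, so no separate BN op runs at inference time. A minimal PyTorch sketch of the arithmetic (illustrative only; the actual Caffe2 conversion behind these numbers is separate and not included in this PR):

```python
import torch
import torch.nn as nn

def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> nn.Conv2d:
    """Return a Conv2d whose weights and bias absorb the BatchNorm statistics."""
    fused = nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, bias=True,
    )
    scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros_like(bn.running_mean)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused

# quick check: the fused conv matches conv followed by bn in eval mode
conv, bn = nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8)
bn.eval()
x = torch.randn(1, 3, 32, 32)
assert torch.allclose(fuse_conv_bn(conv, bn)(x), bn(conv(x)), atol=1e-5)
```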

jario-jin (Contributor) commented on Mar 10, 2019:

@newstzpz May I ask a question: how do I reproduce the "inference caffe2 batch=1" measurement? Do I need extra code?

newstzpz (Author, Contributor) commented on Mar 11, 2019:

@jario-jin To create the Caffe2 model, we need a conversion script that converts the PyTorch model to Caffe2 through ONNX; that script is not included in this PR. This PR is also needed to reproduce the same result.

jario-jin (Contributor) commented on Mar 12, 2019:

Thanks
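
For readers following the thread above: the conversion script is not included here, but a generic PyTorch-to-ONNX export, which such a script would build on, looks roughly like the sketch below. The module, input size, and file name are placeholders; exporting the full detector additionally requires handling its custom ops and the Caffe2 post-processing described earlier.

```python
import torch
import torchvision

# placeholder module; the real conversion would export the trained FBNet detector
model = torchvision.models.resnet18()
model.eval()

dummy_input = torch.randn(1, 3, 600, 600)  # matches the 600-pixel test resolution
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",                 # hypothetical output path
    input_names=["input"],
    output_names=["output"],
)
# the resulting ONNX graph can then be run or converted with Caffe2's ONNX backend
```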


The pre-trained models are available via the links in the model id column.

backbone | type | resolution | lr sched | im / gpu | train mem(GB) | train time (s/iter) | total train time (hr) | inference total batch=8 (s/im) | inference model batch=8 (s/im) | inference model batch=1 (s/im) | inference caffe2 batch=1 (s/im) | box AP | mask AP | model id
-- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | --
[R-50-C4](configs/e2e_faster_rcnn_R_50_C4_1x.yaml) (reference) | Fast | 800 | 1x | 1 | 5.8 | 0.4036 | 20.2 | 0.0875 | **0.0793** | 0.0831 | **0.0625** | 34.4 | - | f35857197
[fbnet_chamv1a](configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml) | Fast | 600 | 0.75x | 12 | 13.6 | 0.5444 | 20.5 | 0.0315 | **0.0260** | 0.0376 | **0.0188** | 33.5 | - | [f100940543](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_chamv1a_600.pth)
[fbnet_default](configs/e2e_faster_rcnn_fbnet_600.yaml) | Fast | 600 | 0.5x | 16 | 11.1 | 0.4872 | 12.5 | 0.0316 | **0.0250** | 0.0297 | **0.0130** | 28.2 | - | [f101086388](https://download.pytorch.org/models/maskrcnn/e2e_faster_rcnn_fbnet_600.pth)
[R-50-C4](configs/e2e_mask_rcnn_R_50_C4_1x.yaml) (reference) | Mask | 800 | 1x | 1 | 5.8 | 0.452 | 22.6 | 0.0918 | **0.0848** | 0.0844 | - | 35.2 | 31.0 | f35858791
[fbnet_xirb16d](configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml) | Mask | 600 | 0.5x | 16 | 13.4 | 1.1732 | 29 | 0.0386 | **0.0319** | 0.0356 | - | 30.7 | 26.9 | [f101086394](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_xirb16d_dsmask.pth)
[fbnet_default](configs/e2e_mask_rcnn_fbnet_600.yaml) | Mask | 600 | 0.5x | 16 | 13.0 | 0.9036 | 23.0 | 0.0327 | **0.0269** | 0.0385 | - | 29.0 | 26.1 | [f101086385](https://download.pytorch.org/models/maskrcnn/e2e_mask_rcnn_fbnet_600.pth)

## Comparison with Detectron and mmdetection

2 changes: 1 addition & 1 deletion configs/e2e_faster_rcnn_fbnet.yaml
@@ -15,7 +15,7 @@ MODEL:
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 100
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 512
2 changes: 1 addition & 1 deletion configs/e2e_faster_rcnn_fbnet_600.yaml
@@ -15,7 +15,7 @@ MODEL:
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
44 changes: 44 additions & 0 deletions configs/e2e_faster_rcnn_fbnet_chamv1a_600.yaml
@@ -0,0 +1,44 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
CONV_BODY: FBNet
FBNET:
ARCH: "cham_v1a"
BN_TYPE: "bn"
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 128
ROI_BOX_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head
NUM_CLASSES: 81
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.045
WARMUP_FACTOR: 0.1
WEIGHT_DECAY: 0.0001
STEPS: (90000, 120000)
MAX_ITER: 135000
IMS_PER_BATCH: 96 # for 8GPUs
# TEST:
# IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: (600, )
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 600
MAX_SIZE_TEST: 1000
PIXEL_MEAN: [103.53, 116.28, 123.675]
PIXEL_STD: [57.375, 57.12, 58.395]
4 changes: 2 additions & 2 deletions configs/e2e_mask_rcnn_fbnet.yaml
@@ -8,15 +8,15 @@ MODEL:
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
DET_HEAD_LAST_SCALE: -1.0
DET_HEAD_LAST_SCALE: 0.0
RPN:
ANCHOR_SIZES: (16, 32, 64, 128, 256)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 100
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
52 changes: 52 additions & 0 deletions configs/e2e_mask_rcnn_fbnet_600.yaml
@@ -0,0 +1,52 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
CONV_BODY: FBNet
FBNET:
ARCH: "default"
BN_TYPE: "bn"
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
DET_HEAD_LAST_SCALE: 0.0
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head
NUM_CLASSES: 81
ROI_MASK_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head_mask
PREDICTOR: "MaskRCNNConv1x1Predictor"
RESOLUTION: 12
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.06
WARMUP_FACTOR: 0.1
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 128 # for 8GPUs
# TEST:
# IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: (600, )
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 600
MAX_SIZE_TEST: 1000
PIXEL_MEAN: [103.53, 116.28, 123.675]
PIXEL_STD: [57.375, 57.12, 58.395]
2 changes: 1 addition & 1 deletion configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask.yaml
@@ -16,7 +16,7 @@ MODEL:
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 1000
POST_NMS_TOP_N_TEST: 100
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 512
52 changes: 52 additions & 0 deletions configs/e2e_mask_rcnn_fbnet_xirb16d_dsmask_600.yaml
@@ -0,0 +1,52 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
BACKBONE:
CONV_BODY: FBNet
FBNET:
ARCH: "xirb16d_dsmask"
BN_TYPE: "bn"
WIDTH_DIVISOR: 8
DW_CONV_SKIP_BN: True
DW_CONV_SKIP_RELU: True
DET_HEAD_LAST_SCALE: 0.0
RPN:
ANCHOR_SIZES: (32, 64, 128, 256, 512)
ANCHOR_STRIDE: (16, )
BATCH_SIZE_PER_IMAGE: 256
PRE_NMS_TOP_N_TRAIN: 6000
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TRAIN: 2000
POST_NMS_TOP_N_TEST: 200
RPN_HEAD: FBNet.rpn_head
ROI_HEADS:
BATCH_SIZE_PER_IMAGE: 256
ROI_BOX_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head
NUM_CLASSES: 81
ROI_MASK_HEAD:
POOLER_RESOLUTION: 6
FEATURE_EXTRACTOR: FBNet.roi_head_mask
PREDICTOR: "MaskRCNNConv1x1Predictor"
RESOLUTION: 12
SHARE_BOX_FEATURE_EXTRACTOR: False
MASK_ON: True
DATASETS:
TRAIN: ("coco_2014_train", "coco_2014_valminusminival")
TEST: ("coco_2014_minival",)
SOLVER:
BASE_LR: 0.06
WARMUP_FACTOR: 0.1
WEIGHT_DECAY: 0.0001
STEPS: (60000, 80000)
MAX_ITER: 90000
IMS_PER_BATCH: 128 # for 8GPUs
# TEST:
# IMS_PER_BATCH: 8
INPUT:
MIN_SIZE_TRAIN: (600, )
MAX_SIZE_TRAIN: 1000
MIN_SIZE_TEST: 600
MAX_SIZE_TEST: 1000
PIXEL_MEAN: [103.53, 116.28, 123.675]
PIXEL_STD: [57.375, 57.12, 58.395]
31 changes: 23 additions & 8 deletions maskrcnn_benchmark/engine/inference.py
@@ -1,5 +1,4 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import datetime
import logging
import time
import os
@@ -11,17 +10,23 @@
from ..utils.comm import is_main_process, get_world_size
from ..utils.comm import all_gather
from ..utils.comm import synchronize
from ..utils.timer import Timer, get_time_str


def compute_on_dataset(model, data_loader, device):
def compute_on_dataset(model, data_loader, device, timer=None):
model.eval()
results_dict = {}
cpu_device = torch.device("cpu")
for i, batch in enumerate(tqdm(data_loader)):
for _, batch in enumerate(tqdm(data_loader)):
images, targets, image_ids = batch
images = images.to(device)
with torch.no_grad():
if timer:
timer.tic()
output = model(images)
if timer:
torch.cuda.synchronize()
timer.toc()
output = [o.to(cpu_device) for o in output]
results_dict.update(
{img_id: result for img_id, result in zip(image_ids, output)}
@@ -68,17 +73,27 @@ def inference(
logger = logging.getLogger("maskrcnn_benchmark.inference")
dataset = data_loader.dataset
logger.info("Start evaluation on {} dataset({} images).".format(dataset_name, len(dataset)))
start_time = time.time()
predictions = compute_on_dataset(model, data_loader, device)
total_timer = Timer()
inference_timer = Timer()
total_timer.tic()
predictions = compute_on_dataset(model, data_loader, device, inference_timer)
# wait for all processes to complete before measuring the time
synchronize()
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=total_time))
total_time = total_timer.toc()
total_time_str = get_time_str(total_time)
logger.info(
"Total inference time: {} ({} s / img per device, on {} devices)".format(
"Total run time: {} ({} s / img per device, on {} devices)".format(
total_time_str, total_time * num_devices / len(dataset), num_devices
)
)
total_infer_time = get_time_str(inference_timer.total_time)

yelantf (Contributor) commented on Apr 1, 2019:

Without tic and toc, where does inference_timer.total_time come from?

logger.info(
"Model inference time: {} ({} s / img per device, on {} devices)".format(
total_infer_time,
inference_timer.total_time * num_devices / len(dataset),
num_devices,
)
)

predictions = _accumulate_predictions_from_multiple_gpus(predictions)
if not is_main_process():
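
Regarding the reviewer's question above: `inference_timer` is the `timer` argument passed to `compute_on_dataset`, which calls `tic()`/`toc()` around every forward pass, so `total_time` is accumulated there. The `Timer` and `get_time_str` utilities come from `maskrcnn_benchmark/utils/timer.py`, one of the changed files not shown in this view; below is a minimal sketch of the interface this code relies on, not necessarily the exact implementation.

```python
import datetime
import time


class Timer:
    """Accumulates wall-clock time over repeated tic()/toc() calls."""

    def __init__(self):
        self.total_time = 0.0
        self.calls = 0
        self.start_time = 0.0

    def tic(self):
        self.start_time = time.time()

    def toc(self):
        self.total_time += time.time() - self.start_time
        self.calls += 1
        return self.total_time


def get_time_str(time_diff):
    return str(datetime.timedelta(seconds=time_diff))
```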
7 changes: 0 additions & 7 deletions maskrcnn_benchmark/modeling/backbone/fbnet.py
@@ -199,13 +199,6 @@ def __init__(
("last", last)
]))

# output_blob = builder.add_final_pool(
# # model, output_blob, kernel_size=cfg.FAST_RCNN.ROI_XFORM_RESOLUTION)
# model,
# output_blob,
# kernel_size=int(cfg.FAST_RCNN.ROI_XFORM_RESOLUTION / stride_init),
# )

self.out_channels = builder.last_depth

def forward(self, x, proposals):
3 changes: 3 additions & 0 deletions maskrcnn_benchmark/modeling/backbone/fbnet_builder.py
@@ -771,6 +771,9 @@ def add_last(self, stage_info):
last_channel = int(self.last_depth * (-channel_scale))
last_channel = self._get_divisible_width(last_channel)

if last_channel == 0:
return nn.Sequential()

dim_in = self.last_depth
ret = ConvBNRelu(
dim_in,
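
This guard appears to pair with the `DET_HEAD_LAST_SCALE` change from `-1.0` to `0.0` in the configs above: a scale of 0 makes the computed `last_channel` 0, and returning an empty `nn.Sequential()` skips the final ConvBNRelu block entirely, since an empty `Sequential` is a pass-through. A minimal check:

```python
import torch
import torch.nn as nn

last = nn.Sequential()           # what add_last returns when last_channel == 0
x = torch.randn(2, 96, 38, 50)   # hypothetical backbone feature map
assert torch.equal(last(x), x)   # with no modules, the input is returned unchanged
```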
