Skip to content

Commit

Permalink
Add registry for model builder functions (facebookresearch#153)
Browse files Browse the repository at this point in the history
* adding registry to hook custom building blocks

* adding customizable rpn head

* support customizable c2 weight loading
  • Loading branch information
wat3rBro authored and fmassa committed Nov 13, 2018
1 parent 34748d0 commit 0c17d64
Show file tree
Hide file tree
Showing 8 changed files with 97 additions and 42 deletions.
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,8 @@
# all FPN levels
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TRAIN = 2000
_C.MODEL.RPN.FPN_POST_NMS_TOP_N_TEST = 2000
# Custom rpn head, empty to use default conv or separable conv
_C.MODEL.RPN.RPN_HEAD = "SingleConvRPNHead"


# ---------------------------------------------------------------------------- #
Expand Down
20 changes: 10 additions & 10 deletions maskrcnn_benchmark/modeling/backbone/backbone.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,21 @@

from torch import nn

from maskrcnn_benchmark.modeling import registry

from . import fpn as fpn_module
from . import resnet


@registry.BACKBONES.register("R-50-C4")
def build_resnet_backbone(cfg):
body = resnet.ResNet(cfg)
model = nn.Sequential(OrderedDict([("body", body)]))
return model


@registry.BACKBONES.register("R-50-FPN")
@registry.BACKBONES.register("R-101-FPN")
def build_resnet_fpn_backbone(cfg):
body = resnet.ResNet(cfg)
in_channels_stage2 = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS
Expand All @@ -31,14 +36,9 @@ def build_resnet_fpn_backbone(cfg):
return model


_BACKBONES = {"resnet": build_resnet_backbone, "resnet-fpn": build_resnet_fpn_backbone}


def build_backbone(cfg):
assert cfg.MODEL.BACKBONE.CONV_BODY.startswith(
"R-"
), "Only ResNet and ResNeXt models are currently implemented"
# Models using FPN end with "-FPN"
if cfg.MODEL.BACKBONE.CONV_BODY.endswith("-FPN"):
return build_resnet_fpn_backbone(cfg)
return build_resnet_backbone(cfg)
assert cfg.MODEL.BACKBONE.CONV_BODY in registry.BACKBONES, \
"cfg.MODEL.BACKBONE.CONV_BODY: {} are not registered in registry".format(
cfg.MODEL.BACKBONE.CONV_BODY
)
return registry.BACKBONES[cfg.MODEL.BACKBONE.CONV_BODY](cfg)
28 changes: 7 additions & 21 deletions maskrcnn_benchmark/modeling/backbone/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from maskrcnn_benchmark.layers import FrozenBatchNorm2d
from maskrcnn_benchmark.layers import Conv2d
from maskrcnn_benchmark.utils.registry import Registry


# ResNet stage specification
Expand Down Expand Up @@ -290,30 +291,15 @@ def forward(self, x):
return x


_TRANSFORMATION_MODULES = {"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm}
_TRANSFORMATION_MODULES = Registry({
"BottleneckWithFixedBatchNorm": BottleneckWithFixedBatchNorm
})

_STEM_MODULES = {"StemWithFixedBatchNorm": StemWithFixedBatchNorm}
_STEM_MODULES = Registry({"StemWithFixedBatchNorm": StemWithFixedBatchNorm})

_STAGE_SPECS = {
_STAGE_SPECS = Registry({
"R-50-C4": ResNet50StagesTo4,
"R-50-C5": ResNet50StagesTo5,
"R-50-FPN": ResNet50FPNStagesTo5,
"R-101-FPN": ResNet101FPNStagesTo5,
}


def register_transformation_module(module_name, module):
_register_generic(_TRANSFORMATION_MODULES, module_name, module)


def register_stem_module(module_name, module):
_register_generic(_STEM_MODULES, module_name, module)


def register_stage_spec(stage_spec_name, stage_spec):
_register_generic(_STAGE_SPECS, stage_spec_name, stage_spec)


def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module
})
7 changes: 7 additions & 0 deletions maskrcnn_benchmark/modeling/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from maskrcnn_benchmark.utils.registry import Registry

BACKBONES = Registry()
ROI_BOX_FEATURE_EXTRACTORS = Registry()
RPN_HEADS = Registry()
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from torch import nn
from torch.nn import functional as F

from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.backbone import resnet
from maskrcnn_benchmark.modeling.poolers import Pooler


@registry.ROI_BOX_FEATURE_EXTRACTORS.register("ResNet50Conv5ROIFeatureExtractor")
class ResNet50Conv5ROIFeatureExtractor(nn.Module):
def __init__(self, config):
super(ResNet50Conv5ROIFeatureExtractor, self).__init__()
Expand Down Expand Up @@ -39,6 +41,7 @@ def forward(self, x, proposals):
return x


@registry.ROI_BOX_FEATURE_EXTRACTORS.register("FPN2MLPFeatureExtractor")
class FPN2MLPFeatureExtractor(nn.Module):
"""
Heads for FPN for classification
Expand Down Expand Up @@ -77,12 +80,8 @@ def forward(self, x, proposals):
return x


_ROI_BOX_FEATURE_EXTRACTORS = {
"ResNet50Conv5ROIFeatureExtractor": ResNet50Conv5ROIFeatureExtractor,
"FPN2MLPFeatureExtractor": FPN2MLPFeatureExtractor,
}


def make_roi_box_feature_extractor(cfg):
func = _ROI_BOX_FEATURE_EXTRACTORS[cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR]
func = registry.ROI_BOX_FEATURE_EXTRACTORS[
cfg.MODEL.ROI_BOX_HEAD.FEATURE_EXTRACTOR
]
return func(cfg)
10 changes: 8 additions & 2 deletions maskrcnn_benchmark/modeling/rpn/rpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,23 @@
import torch.nn.functional as F
from torch import nn

from maskrcnn_benchmark.modeling import registry
from maskrcnn_benchmark.modeling.box_coder import BoxCoder
from .loss import make_rpn_loss_evaluator
from .anchor_generator import make_anchor_generator
from .inference import make_rpn_postprocessor


@registry.RPN_HEADS.register("SingleConvRPNHead")
class RPNHead(nn.Module):
"""
Adds a simple RPN Head with classification and regression heads
"""

def __init__(self, in_channels, num_anchors):
def __init__(self, cfg, in_channels, num_anchors):
"""
Arguments:
cfg : config
in_channels (int): number of channels of the input feature
num_anchors (int): number of anchors to be predicted
"""
Expand Down Expand Up @@ -57,7 +60,10 @@ def __init__(self, cfg):
anchor_generator = make_anchor_generator(cfg)

in_channels = cfg.MODEL.BACKBONE.OUT_CHANNELS
head = RPNHead(in_channels, anchor_generator.num_anchors_per_location()[0])
rpn_head = registry.RPN_HEADS[cfg.MODEL.RPN.RPN_HEAD]
head = rpn_head(
cfg, in_channels, anchor_generator.num_anchors_per_location()[0]
)

rpn_box_coder = BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

Expand Down
14 changes: 12 additions & 2 deletions maskrcnn_benchmark/utils/c2_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch

from maskrcnn_benchmark.utils.model_serialization import load_state_dict
from maskrcnn_benchmark.utils.registry import Registry


def _rename_basic_resnet_weights(layer_keys):
Expand Down Expand Up @@ -135,11 +136,20 @@ def _load_c2_pickled_weights(file_path):
"R-101": ["1.2", "2.3", "3.22", "4.2"],
}

def load_c2_format(cfg, f):
# TODO make it support other architectures
C2_FORMAT_LOADER = Registry()


@C2_FORMAT_LOADER.register("R-50-C4")
@C2_FORMAT_LOADER.register("R-50-FPN")
@C2_FORMAT_LOADER.register("R-101-FPN")
def load_resnet_c2_format(cfg, f):
state_dict = _load_c2_pickled_weights(f)
conv_body = cfg.MODEL.BACKBONE.CONV_BODY
arch = conv_body.replace("-C4", "").replace("-FPN", "")
stages = _C2_STAGE_NAMES[arch]
state_dict = _rename_weights_for_resnet(state_dict, stages)
return dict(model=state_dict)


def load_c2_format(cfg, f):
return C2_FORMAT_LOADER[cfg.MODEL.BACKBONE.CONV_BODY](cfg, f)
45 changes: 45 additions & 0 deletions maskrcnn_benchmark/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.


def _register_generic(module_dict, module_name, module):
assert module_name not in module_dict
module_dict[module_name] = module


class Registry(dict):
'''
A helper class for managing registering modules, it extends a dictionary
and provides a register functions.
Eg. creeting a registry:
some_registry = Registry({"default": default_module})
There're two ways of registering new modules:
1): normal way is just calling register function:
def foo():
...
some_registry.register("foo_module", foo)
2): used as decorator when declaring the module:
@some_registry.register("foo_module")
@some_registry.register("foo_modeul_nickname")
def foo():
...
Access of module is just like using a dictionary, eg:
f = some_registry["foo_modeul"]
'''
def __init__(self, *args, **kwargs):
super(Registry, self).__init__(*args, **kwargs)

def register(self, module_name, module=None):
# used as function call
if module is not None:
_register_generic(self, module_name, module)
return

# used as decorator
def register_fn(fn):
_register_generic(self, module_name, fn)
return fn

return register_fn

0 comments on commit 0c17d64

Please sign in to comment.