Skip to content

Commit

Permalink
make model zoo usable internally
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #152

Reviewed By: zhanghang1989

Differential Revision: D31591900

fbshipit-source-id: 6ee8124419d535caf03532eda4f729e707b6dda7
  • Loading branch information
wat3rBro authored and facebook-github-bot committed Dec 30, 2021
1 parent 06f3f2e commit c12469c
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 24 deletions.
6 changes: 2 additions & 4 deletions d2go/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,12 @@ def reroute_config_path(path: str) -> str:

if path.startswith("d2go://"):
rel_path = path[len("d2go://") :]
config_in_resource = pkg_resources.resource_filename(
"d2go.model_zoo", os.path.join("configs", rel_path)
)
config_in_resource = pkg_resources.resource_filename("d2go", rel_path)
return config_in_resource
elif path.startswith("detectron2go://"):
rel_path = path[len("detectron2go://") :]
config_in_resource = pkg_resources.resource_filename(
"d2go.model_zoo", os.path.join("configs", rel_path)
"d2go", os.path.join("configs", rel_path)
)
return config_in_resource
elif path.startswith("detectron2://"):
Expand Down
6 changes: 3 additions & 3 deletions d2go/model_zoo/model_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pkg_resources
import torch
from d2go.runner import create_runner
from d2go.utils.launch_environment import MODEL_ZOO_STORAGE_PREFIX
from detectron2.checkpoint import DetectionCheckpointer


Expand All @@ -13,7 +14,6 @@ class _ModelZooUrls(object):
Mapping from names to officially released D2Go pre-trained models.
"""

S3_PREFIX = "https://mobile-cv.s3-us-west-2.amazonaws.com/d2go/models/"
CONFIG_PATH_TO_URL_SUFFIX = {
"faster_rcnn_fbnetv3a_C4.yaml": "268421013/model_final.pth",
"faster_rcnn_fbnetv3a_dsmask_C4.yaml": "268412271/model_0499999.pth",
Expand All @@ -37,7 +37,7 @@ def get_checkpoint_url(config_path):
name = config_path.replace(".yaml", "")
if config_path in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX:
suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[config_path]
return _ModelZooUrls.S3_PREFIX + suffix
return MODEL_ZOO_STORAGE_PREFIX + suffix
raise RuntimeError("{} not available in Model Zoo!".format(name))


Expand All @@ -51,7 +51,7 @@ def get_config_file(config_path):
str: the real path to the config file.
"""
cfg_file = pkg_resources.resource_filename(
"d2go.model_zoo", os.path.join("configs", config_path)
"d2go", os.path.join("configs", config_path)
)
if not os.path.exists(cfg_file):
raise RuntimeError("{} not available in Model Zoo!".format(config_path))
Expand Down
4 changes: 2 additions & 2 deletions d2go/utils/launch_environment.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

from d2go.utils.misc import fb_overwritable

MODEL_ZOO_STORAGE_PREFIX = "https://mobile-cv.s3-us-west-2.amazonaws.com/d2go/models/"


@fb_overwritable()
def get_launch_environment():
return "local"
12 changes: 1 addition & 11 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,6 @@ def d2go_gather_files(dst_module, file_path, extension="*") -> List[str]:
return config_paths


def get_model_zoo_configs() -> List[str]:
"""
Return a list of configs to include in package for model zoo. Copy over these configs inside
d2go/model_zoo.
"""
return d2go_gather_files(
os.path.join("model_zoo", "configs"), "configs", "**/*.yaml"
)


if __name__ == "__main__":
setup(
name="d2go",
Expand All @@ -88,7 +78,7 @@ def get_model_zoo_configs() -> List[str]:
"d2go": [
"LICENSE",
],
"d2go.model_zoo": get_model_zoo_configs(),
"d2go.configs": d2go_gather_files("configs", "configs", "**/*.yaml"),
"d2go.tools": d2go_gather_files("tools", "tools", "**/*.py"),
"d2go.tests": d2go_gather_files("tests", "tests", "**/*helper.py"),
},
Expand Down
9 changes: 5 additions & 4 deletions tests/modeling/test_model_zoo.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import os
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

import unittest

import torch.nn as nn
from d2go.model_zoo import model_zoo

OSSRUN = os.getenv("OSSRUN") == "1"


class TestD2GoModelZoo(unittest.TestCase):
@unittest.skipIf(not OSSRUN, "OSS test only")
def test_model_zoo_pretrained(self):
configs = list(model_zoo._ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX.keys())
for cfgfile in configs:
model = model_zoo.get(cfgfile, trained=True)
self.assertTrue(isinstance(model, nn.Module))


if __name__ == "__main__":
Expand Down

0 comments on commit c12469c

Please sign in to comment.