Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TV MaskRCNN Tile Recipe #3655

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/otx/algo/instance_segmentation/maskrcnn_tv.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
tile_config=tile_config,
)
self.image_size = (1, 3, 1024, 1024)
self.tile_image_size = (1, 3, 512, 512)
self.mean = (123.675, 116.28, 103.53)
self.std = (58.395, 57.12, 57.375)

Expand Down
5 changes: 3 additions & 2 deletions src/otx/algo/instance_segmentation/torchvision/maskrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,12 @@ def forward(
entity.masks,
entity.polygons,
):
# NOTE: shift labels by 1 as 0 is reserved for background
_labels = labels + 1 if len(labels) else labels
targets.append(
{
"boxes": bboxes,
# NOTE: shift labels by 1 as 0 is reserved for background
"labels": labels + 1,
"labels": _labels,
"masks": masks,
"polygons": polygons,
},
Expand Down
121 changes: 121 additions & 0 deletions src/otx/recipe/instance_segmentation/maskrcnn_r50_tv_tile.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
# OTX recipe: TorchVision MaskRCNN (ResNet-50) with tiling enabled.
# NOTE(review): indentation below was reconstructed from a whitespace-mangled
# copy of src/otx/recipe/instance_segmentation/maskrcnn_r50_tv_tile.yaml —
# verify nesting against the original file / sibling recipes.
model:
  class_path: otx.algo.instance_segmentation.maskrcnn_tv.TVMaskRCNNR50
  init_args:
    label_info: 80

optimizer:
  class_path: torch.optim.SGD
  init_args:
    lr: 0.007
    momentum: 0.9
    weight_decay: 0.001

scheduler:
  class_path: otx.core.schedulers.LinearWarmupSchedulerCallable
  init_args:
    num_warmup_steps: 100
    main_scheduler_callable:
      class_path: lightning.pytorch.cli.ReduceLROnPlateau
      init_args:
        mode: max
        factor: 0.1
        patience: 4
        monitor: val/map_50

engine:
  task: INSTANCE_SEGMENTATION
  device: auto

callback_monitor: val/map_50

data: ../_base_/data/torchvision_base.yaml
overrides:
  max_epochs: 100
  gradient_clip_val: 35.0
  data:
    task: INSTANCE_SEGMENTATION
    config:
      stack_images: true
      # Tiling: large images are split into tiles; adaptive tiling picks the
      # tile size from object statistics.
      tile_config:
        enable_tiler: true
        enable_adaptive_tiling: true
      data_format: coco_instances
      include_polygons: true
      train_subset:
        batch_size: 4
        num_workers: 8
        transforms:
          # 512x512 matches the model's tile_image_size.
          - class_path: otx.core.data.transform_libs.torchvision.Resize
            init_args:
              keep_ratio: false
              transform_bbox: true
              transform_mask: true
              scale:
                - 512
                - 512
          - class_path: otx.core.data.transform_libs.torchvision.Pad
            init_args:
              size_divisor: 32
              transform_mask: true
          - class_path: otx.core.data.transform_libs.torchvision.RandomFlip
            init_args:
              prob: 0.5
              is_numpy_to_tvtensor: true
          - class_path: torchvision.transforms.v2.ToDtype
            init_args:
              dtype: ${as_torch_dtype:torch.float32}
              scale: False
          - class_path: torchvision.transforms.v2.Normalize
            init_args:
              mean: [123.675, 116.28, 103.53]
              std: [58.395, 57.12, 57.375]
      val_subset:
        batch_size: 1
        num_workers: 4
        transforms:
          # Eval path: geometry of annotations is left untouched
          # (transform_bbox/transform_mask false).
          - class_path: otx.core.data.transform_libs.torchvision.Resize
            init_args:
              keep_ratio: false
              transform_bbox: false
              transform_mask: false
              scale:
                - 512
                - 512
          - class_path: otx.core.data.transform_libs.torchvision.Pad
            init_args:
              size_divisor: 32
              transform_mask: false
              is_numpy_to_tvtensor: true
          - class_path: torchvision.transforms.v2.ToDtype
            init_args:
              dtype: ${as_torch_dtype:torch.float32}
              scale: False
          - class_path: torchvision.transforms.v2.Normalize
            init_args:
              mean: [123.675, 116.28, 103.53]
              std: [58.395, 57.12, 57.375]
      test_subset:
        batch_size: 1
        num_workers: 4
        transforms:
          # Mirrors val_subset so test metrics are comparable.
          - class_path: otx.core.data.transform_libs.torchvision.Resize
            init_args:
              keep_ratio: false
              transform_bbox: false
              transform_mask: false
              scale:
                - 512
                - 512
          - class_path: otx.core.data.transform_libs.torchvision.Pad
            init_args:
              size_divisor: 32
              transform_mask: false
              is_numpy_to_tvtensor: true
          - class_path: torchvision.transforms.v2.ToDtype
            init_args:
              dtype: ${as_torch_dtype:torch.float32}
              scale: False
          - class_path: torchvision.transforms.v2.Normalize
            init_args:
              mean: [123.675, 116.28, 103.53]
              std: [58.395, 57.12, 57.375]
24 changes: 16 additions & 8 deletions tests/unit/algo/instance_segmentation/test_maskrcnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
import torch
from otx.algo.instance_segmentation.maskrcnn import MaskRCNNEfficientNet, MaskRCNNResNet50, MaskRCNNSwinT
from otx.algo.instance_segmentation.maskrcnn_tv import TVMaskRCNNR50
from otx.algo.utils.support_otx_v1 import OTXv1Helper
from otx.core.data.entity.instance_segmentation import InstanceSegBatchPredEntity
from otx.core.types.export import TaskLevelExportParameters
Expand All @@ -21,22 +22,29 @@ def test_load_weights(self, mocker) -> None:

@pytest.mark.parametrize(
"model",
[MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)],
[MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3), TVMaskRCNNR50(3)],
)
def test_loss(self, model, fxt_data_module):
data = next(iter(fxt_data_module.train_dataloader()))
data.images = torch.randn([2, 3, 32, 32])

output = model(data)
assert "loss_cls" in output
assert "loss_bbox" in output
assert "loss_mask" in output
assert "loss_rpn_cls" in output
assert "loss_rpn_bbox" in output
if isinstance(model, TVMaskRCNNR50):
assert "loss_classifier" in output
assert "loss_box_reg" in output
assert "loss_mask" in output
assert "loss_objectness" in output
assert "loss_rpn_box_reg" in output
else:
assert "loss_cls" in output
assert "loss_bbox" in output
assert "loss_mask" in output
assert "loss_rpn_cls" in output
assert "loss_rpn_bbox" in output

@pytest.mark.parametrize(
"model",
[MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)],
[MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3), TVMaskRCNNR50(3)],
)
def test_predict(self, model, fxt_data_module):
data = next(iter(fxt_data_module.train_dataloader()))
Expand All @@ -47,7 +55,7 @@ def test_predict(self, model, fxt_data_module):

@pytest.mark.parametrize(
"model",
[MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3)],
[MaskRCNNResNet50(3), MaskRCNNEfficientNet(3), MaskRCNNSwinT(3), TVMaskRCNNR50(3)],
)
def test_export(self, model):
model.eval()
Expand Down
Loading