From 8236173842ff9e72e8962f1e118e72cb47162f26 Mon Sep 17 00:00:00 2001 From: Philipp Allgeuer Date: Mon, 30 May 2022 11:29:12 +0200 Subject: [PATCH 1/3] Fix warning with torch.meshgrid --- mmdet/core/anchor/point_generator.py | 2 +- mmdet/core/utils/misc.py | 2 +- mmdet/models/dense_heads/anchor_free_head.py | 2 +- mmdet/models/dense_heads/cascade_rpn_head.py | 5 +++-- mmdet/models/dense_heads/vfnet_head.py | 2 +- mmdet/models/utils/transformer.py | 6 ++++-- 6 files changed, 11 insertions(+), 8 deletions(-) diff --git a/mmdet/core/anchor/point_generator.py b/mmdet/core/anchor/point_generator.py index cc9c3887dd7..376c564f68e 100644 --- a/mmdet/core/anchor/point_generator.py +++ b/mmdet/core/anchor/point_generator.py @@ -68,7 +68,7 @@ def num_base_priors(self): return [1 for _ in range(len(self.strides))] def _meshgrid(self, x, y, row_major=True): - yy, xx = torch.meshgrid(y, x) + yy, xx = torch.meshgrid(y, x, indexing='ij') if row_major: # warning .flatten() would cause error in ONNX exporting # have to use reshape here diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py index 14cb745e38e..83b4ffff6ca 100644 --- a/mmdet/core/utils/misc.py +++ b/mmdet/core/utils/misc.py @@ -200,7 +200,7 @@ def generate_coordinate(featmap_sizes, device='cuda'): x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device) y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device) - y, x = torch.meshgrid(y_range, x_range) + y, x = torch.meshgrid(y_range, x_range, indexing='ij') y = y.expand([featmap_sizes[0], 1, -1, -1]) x = x.expand([featmap_sizes[0], 1, -1, -1]) coord_feat = torch.cat([x, y], 1) diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py index b0460b945ca..45df7e91487 100644 --- a/mmdet/models/dense_heads/anchor_free_head.py +++ b/mmdet/models/dense_heads/anchor_free_head.py @@ -301,7 +301,7 @@ def _get_points_single(self, # target `dtype` for onnx exporting. x_range = torch.arange(w, device=device).to(dtype) y_range = torch.arange(h, device=device).to(dtype) - y, x = torch.meshgrid(y_range, x_range) + y, x = torch.meshgrid(y_range, x_range, indexing='ij') if flatten: y = y.flatten() x = x.flatten() diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py index 69347e00c43..8193383b25d 100644 --- a/mmdet/models/dense_heads/cascade_rpn_head.py +++ b/mmdet/models/dense_heads/cascade_rpn_head.py @@ -344,7 +344,8 @@ def _shape_offset(anchors, stride, ks=3, dilation=1): assert ks == 3 and dilation == 1 pad = (ks - 1) // 2 idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device) - yy, xx = torch.meshgrid(idx, idx) # return order matters + yy, xx = torch.meshgrid( + idx, idx, indexing='ij') # return order matters xx = xx.reshape(-1) yy = yy.reshape(-1) w = (anchors[:, 2] - anchors[:, 0]) / stride @@ -367,7 +368,7 @@ def _ctr_offset(anchors, stride, featmap_size): # compute predefine centers xx = torch.arange(0, feat_w, device=anchors.device) yy = torch.arange(0, feat_h, device=anchors.device) - yy, xx = torch.meshgrid(yy, xx) + yy, xx = torch.meshgrid(yy, xx, indexing='ij') xx = xx.reshape(-1).type_as(x) yy = yy.reshape(-1).type_as(y) diff --git a/mmdet/models/dense_heads/vfnet_head.py b/mmdet/models/dense_heads/vfnet_head.py index ba285e22e32..8c946848e72 100644 --- a/mmdet/models/dense_heads/vfnet_head.py +++ b/mmdet/models/dense_heads/vfnet_head.py @@ -728,7 +728,7 @@ def _get_points_single(self, 0, w * stride, stride, dtype=dtype, device=device) y_range = torch.arange( 0, h * stride, stride, dtype=dtype, device=device) - y, x = torch.meshgrid(y_range, x_range) + y, x = torch.meshgrid(y_range, x_range, indexing='ij') # to be compatible with anchor points in ATSS if self.use_atss: points = torch.stack( diff --git a/mmdet/models/utils/transformer.py b/mmdet/models/utils/transformer.py index 3c390c83a1a..32f7e2b6bb7 100644 --- a/mmdet/models/utils/transformer.py +++ b/mmdet/models/utils/transformer.py @@ -799,7 +799,8 @@ def gen_encoder_output_proposals(self, memory, memory_padding_mask, torch.linspace( 0, H - 1, H, dtype=torch.float32, device=memory.device), torch.linspace( - 0, W - 1, W, dtype=torch.float32, device=memory.device)) + 0, W - 1, W, dtype=torch.float32, device=memory.device), + indexing='ij') grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), @@ -851,7 +852,8 @@ def get_reference_points(spatial_shapes, valid_ratios, device): torch.linspace( 0.5, H - 0.5, H, dtype=torch.float32, device=device), torch.linspace( - 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + 0.5, W - 0.5, W, dtype=torch.float32, device=device), + indexing='ij') ref_y = ref_y.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 1] * H) ref_x = ref_x.reshape(-1)[None] / ( From 31de3749ffeb21a82577094d7d727e1cd7cd2a5d Mon Sep 17 00:00:00 2001 From: Philipp Allgeuer Date: Tue, 7 Jun 2022 11:38:53 +0200 Subject: [PATCH 2/3] Add torch_meshgrid_ij utility --- mmdet/core/anchor/point_generator.py | 3 ++- mmdet/core/utils/misc.py | 3 ++- mmdet/models/dense_heads/anchor_free_head.py | 3 ++- mmdet/models/dense_heads/cascade_rpn_head.py | 6 +++--- mmdet/models/dense_heads/vfnet_head.py | 3 ++- mmdet/models/utils/transformer.py | 11 +++++------ mmdet/utils/misc.py | 13 +++++++++++++ 7 files changed, 29 insertions(+), 13 deletions(-) diff --git a/mmdet/core/anchor/point_generator.py b/mmdet/core/anchor/point_generator.py index 376c564f68e..2dbc33daead 100644 --- a/mmdet/core/anchor/point_generator.py +++ b/mmdet/core/anchor/point_generator.py @@ -3,6 +3,7 @@ import torch from torch.nn.modules.utils import _pair +from mmdet.utils.misc import torch_meshgrid_ij from .builder import PRIOR_GENERATORS @@ -68,7 +69,7 @@ def num_base_priors(self): return [1 for _ in range(len(self.strides))] def _meshgrid(self, x, y, row_major=True): - yy, xx = torch.meshgrid(y, x, indexing='ij') + yy, xx = torch_meshgrid_ij(y, x) if row_major: # warning .flatten() would cause error in ONNX exporting # have to use reshape here diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py index 83b4ffff6ca..4b5c11ab7db 100644 --- a/mmdet/core/utils/misc.py +++ b/mmdet/core/utils/misc.py @@ -5,6 +5,7 @@ import torch from six.moves import map, zip +from mmdet.utils.misc import torch_meshgrid_ij from ..mask.structures import BitmapMasks, PolygonMasks @@ -200,7 +201,7 @@ def generate_coordinate(featmap_sizes, device='cuda'): x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device) y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device) - y, x = torch.meshgrid(y_range, x_range, indexing='ij') + y, x = torch_meshgrid_ij(y_range, x_range) y = y.expand([featmap_sizes[0], 1, -1, -1]) x = x.expand([featmap_sizes[0], 1, -1, -1]) coord_feat = torch.cat([x, y], 1) diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py index 45df7e91487..2e38f481401 100644 --- a/mmdet/models/dense_heads/anchor_free_head.py +++ b/mmdet/models/dense_heads/anchor_free_head.py @@ -9,6 +9,7 @@ from mmdet.core import build_bbox_coder, multi_apply from mmdet.core.anchor.point_generator import MlvlPointGenerator +from mmdet.utils.misc import torch_meshgrid_ij from ..builder import HEADS, build_loss from .base_dense_head import BaseDenseHead from .dense_test_mixins import BBoxTestMixin @@ -301,7 +302,7 @@ def _get_points_single(self, # target `dtype` for onnx exporting. x_range = torch.arange(w, device=device).to(dtype) y_range = torch.arange(h, device=device).to(dtype) - y, x = torch.meshgrid(y_range, x_range, indexing='ij') + y, x = torch_meshgrid_ij(y_range, x_range) if flatten: y = y.flatten() x = x.flatten() diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py index 8193383b25d..b6f130f6393 100644 --- a/mmdet/models/dense_heads/cascade_rpn_head.py +++ b/mmdet/models/dense_heads/cascade_rpn_head.py @@ -12,6 +12,7 @@ from mmdet.core import (RegionAssigner, build_assigner, build_sampler, images_to_levels, multi_apply) from mmdet.core.utils import select_single_mlvl +from mmdet.utils.misc import torch_meshgrid_ij from ..builder import HEADS, build_head from .base_dense_head import BaseDenseHead from .rpn_head import RPNHead @@ -344,8 +345,7 @@ def _shape_offset(anchors, stride, ks=3, dilation=1): assert ks == 3 and dilation == 1 pad = (ks - 1) // 2 idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device) - yy, xx = torch.meshgrid( - idx, idx, indexing='ij') # return order matters + yy, xx = torch_meshgrid_ij(idx, idx) # return order matters xx = xx.reshape(-1) yy = yy.reshape(-1) w = (anchors[:, 2] - anchors[:, 0]) / stride @@ -368,7 +368,7 @@ def _ctr_offset(anchors, stride, featmap_size): # compute predefine centers xx = torch.arange(0, feat_w, device=anchors.device) yy = torch.arange(0, feat_h, device=anchors.device) - yy, xx = torch.meshgrid(yy, xx, indexing='ij') + yy, xx = torch_meshgrid_ij(yy, xx) xx = xx.reshape(-1).type_as(x) yy = yy.reshape(-1).type_as(y) diff --git a/mmdet/models/dense_heads/vfnet_head.py b/mmdet/models/dense_heads/vfnet_head.py index 8c946848e72..946c08c1b6b 100644 --- a/mmdet/models/dense_heads/vfnet_head.py +++ b/mmdet/models/dense_heads/vfnet_head.py @@ -11,6 +11,7 @@ from mmdet.core import (MlvlPointGenerator, bbox_overlaps, build_assigner, build_prior_generator, build_sampler, multi_apply, reduce_mean) +from mmdet.utils.misc import torch_meshgrid_ij from ..builder import HEADS, build_loss from .atss_head import ATSSHead from .fcos_head import FCOSHead @@ -728,7 +729,7 @@ def _get_points_single(self, 0, w * stride, stride, dtype=dtype, device=device) y_range = torch.arange( 0, h * stride, stride, dtype=dtype, device=device) - y, x = torch.meshgrid(y_range, x_range, indexing='ij') + y, x = torch_meshgrid_ij(y_range, x_range) # to be compatible with anchor points in ATSS if self.use_atss: points = torch.stack( diff --git a/mmdet/models/utils/transformer.py b/mmdet/models/utils/transformer.py index 32f7e2b6bb7..7f1363de915 100644 --- a/mmdet/models/utils/transformer.py +++ b/mmdet/models/utils/transformer.py @@ -18,6 +18,7 @@ from torch.nn.init import normal_ from mmdet.models.utils.builder import TRANSFORMER +from mmdet.utils.misc import torch_meshgrid_ij try: from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention @@ -795,12 +796,11 @@ def gen_encoder_output_proposals(self, memory, memory_padding_mask, valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) - grid_y, grid_x = torch.meshgrid( + grid_y, grid_x = torch_meshgrid_ij( torch.linspace( 0, H - 1, H, dtype=torch.float32, device=memory.device), torch.linspace( - 0, W - 1, W, dtype=torch.float32, device=memory.device), - indexing='ij') + 0, W - 1, W, dtype=torch.float32, device=memory.device)) grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1) scale = torch.cat([valid_W.unsqueeze(-1), @@ -848,12 +848,11 @@ def get_reference_points(spatial_shapes, valid_ratios, device): reference_points_list = [] for lvl, (H, W) in enumerate(spatial_shapes): # TODO check this 0.5 - ref_y, ref_x = torch.meshgrid( + ref_y, ref_x = torch_meshgrid_ij( torch.linspace( 0.5, H - 0.5, H, dtype=torch.float32, device=device), torch.linspace( - 0.5, W - 0.5, W, dtype=torch.float32, device=device), - indexing='ij') + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) ref_y = ref_y.reshape(-1)[None] / ( valid_ratios[:, None, lvl, 1] * H) ref_x = ref_x.reshape(-1)[None] / ( diff --git a/mmdet/utils/misc.py b/mmdet/utils/misc.py index 4113672acfb..87898de2f2d 100644 --- a/mmdet/utils/misc.py +++ b/mmdet/utils/misc.py @@ -5,7 +5,9 @@ import warnings import mmcv +import torch from mmcv.utils import print_log +from packaging import version def find_latest_checkpoint(path, suffix='pth'): @@ -74,3 +76,14 @@ def update(cfg, src_str, dst_str): update(cfg.data, cfg.data_root, dst_root) cfg.data_root = dst_root + + +_torch_version_meshgrid_indexing = version.parse( + torch.__version__) >= version.parse('1.10.0a0') + + +def torch_meshgrid_ij(*tensors): + if _torch_version_meshgrid_indexing: + return torch.meshgrid(*tensors, indexing='ij') + else: + return torch.meshgrid(*tensors) # Uses indexing='ij' by default From ea5c9c2da0fe2037fdf76ae4fd6f2d4439f3d558 Mon Sep 17 00:00:00 2001 From: Philipp Allgeuer Date: Wed, 8 Jun 2022 14:59:13 +0200 Subject: [PATCH 3/3] Use digit_version instead of version.parse --- mmdet/utils/misc.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mmdet/utils/misc.py b/mmdet/utils/misc.py index 87898de2f2d..d74968d4cff 100644 --- a/mmdet/utils/misc.py +++ b/mmdet/utils/misc.py @@ -6,8 +6,7 @@ import mmcv import torch -from mmcv.utils import print_log -from packaging import version +from mmcv.utils import TORCH_VERSION, digit_version, print_log def find_latest_checkpoint(path, suffix='pth'): @@ -78,8 +77,9 @@ def update(cfg, src_str, dst_str): cfg.data_root = dst_root -_torch_version_meshgrid_indexing = version.parse( - torch.__version__) >= version.parse('1.10.0a0') +_torch_version_meshgrid_indexing = ( + 'parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0')) def torch_meshgrid_ij(*tensors):