Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix warning with torch.meshgrid #8090

Open
wants to merge 3 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmdet/core/anchor/point_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import torch
from torch.nn.modules.utils import _pair

from mmdet.utils.misc import torch_meshgrid_ij
from .builder import PRIOR_GENERATORS


Expand Down Expand Up @@ -68,7 +69,7 @@ def num_base_priors(self):
return [1 for _ in range(len(self.strides))]

def _meshgrid(self, x, y, row_major=True):
yy, xx = torch.meshgrid(y, x)
yy, xx = torch_meshgrid_ij(y, x)
if row_major:
# warning .flatten() would cause error in ONNX exporting
# have to use reshape here
Expand Down
3 changes: 2 additions & 1 deletion mmdet/core/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
from six.moves import map, zip

from mmdet.utils.misc import torch_meshgrid_ij
from ..mask.structures import BitmapMasks, PolygonMasks


Expand Down Expand Up @@ -200,7 +201,7 @@ def generate_coordinate(featmap_sizes, device='cuda'):

x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
y, x = torch.meshgrid(y_range, x_range)
y, x = torch_meshgrid_ij(y_range, x_range)
y = y.expand([featmap_sizes[0], 1, -1, -1])
x = x.expand([featmap_sizes[0], 1, -1, -1])
coord_feat = torch.cat([x, y], 1)
Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/anchor_free_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from mmdet.core import build_bbox_coder, multi_apply
from mmdet.core.anchor.point_generator import MlvlPointGenerator
from mmdet.utils.misc import torch_meshgrid_ij
from ..builder import HEADS, build_loss
from .base_dense_head import BaseDenseHead
from .dense_test_mixins import BBoxTestMixin
Expand Down Expand Up @@ -301,7 +302,7 @@
# target `dtype` for onnx exporting.
x_range = torch.arange(w, device=device).to(dtype)
y_range = torch.arange(h, device=device).to(dtype)
y, x = torch.meshgrid(y_range, x_range)
y, x = torch_meshgrid_ij(y_range, x_range)

Check warning on line 305 in mmdet/models/dense_heads/anchor_free_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet/models/dense_heads/anchor_free_head.py#L305

Added line #L305 was not covered by tests
if flatten:
y = y.flatten()
x = x.flatten()
Expand Down
5 changes: 3 additions & 2 deletions mmdet/models/dense_heads/cascade_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
images_to_levels, multi_apply)
from mmdet.core.utils import select_single_mlvl
from mmdet.utils.misc import torch_meshgrid_ij
from ..builder import HEADS, build_head
from .base_dense_head import BaseDenseHead
from .rpn_head import RPNHead
Expand Down Expand Up @@ -344,7 +345,7 @@
assert ks == 3 and dilation == 1
pad = (ks - 1) // 2
idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
yy, xx = torch.meshgrid(idx, idx) # return order matters
yy, xx = torch_meshgrid_ij(idx, idx) # return order matters

Check warning on line 348 in mmdet/models/dense_heads/cascade_rpn_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet/models/dense_heads/cascade_rpn_head.py#L348

Added line #L348 was not covered by tests
xx = xx.reshape(-1)
yy = yy.reshape(-1)
w = (anchors[:, 2] - anchors[:, 0]) / stride
Expand All @@ -367,7 +368,7 @@
# compute predefine centers
xx = torch.arange(0, feat_w, device=anchors.device)
yy = torch.arange(0, feat_h, device=anchors.device)
yy, xx = torch.meshgrid(yy, xx)
yy, xx = torch_meshgrid_ij(yy, xx)

Check warning on line 371 in mmdet/models/dense_heads/cascade_rpn_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet/models/dense_heads/cascade_rpn_head.py#L371

Added line #L371 was not covered by tests
xx = xx.reshape(-1).type_as(x)
yy = yy.reshape(-1).type_as(y)

Expand Down
3 changes: 2 additions & 1 deletion mmdet/models/dense_heads/vfnet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mmdet.core import (MlvlPointGenerator, bbox_overlaps, build_assigner,
build_prior_generator, build_sampler, multi_apply,
reduce_mean)
from mmdet.utils.misc import torch_meshgrid_ij
from ..builder import HEADS, build_loss
from .atss_head import ATSSHead
from .fcos_head import FCOSHead
Expand Down Expand Up @@ -728,7 +729,7 @@
0, w * stride, stride, dtype=dtype, device=device)
y_range = torch.arange(
0, h * stride, stride, dtype=dtype, device=device)
y, x = torch.meshgrid(y_range, x_range)
y, x = torch_meshgrid_ij(y_range, x_range)

Check warning on line 732 in mmdet/models/dense_heads/vfnet_head.py

View check run for this annotation

Codecov / codecov/patch

mmdet/models/dense_heads/vfnet_head.py#L732

Added line #L732 was not covered by tests
# to be compatible with anchor points in ATSS
if self.use_atss:
points = torch.stack(
Expand Down
5 changes: 3 additions & 2 deletions mmdet/models/utils/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from torch.nn.init import normal_

from mmdet.models.utils.builder import TRANSFORMER
from mmdet.utils.misc import torch_meshgrid_ij

try:
from mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention
Expand Down Expand Up @@ -795,7 +796,7 @@
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)

grid_y, grid_x = torch.meshgrid(
grid_y, grid_x = torch_meshgrid_ij(

Check warning on line 799 in mmdet/models/utils/transformer.py

View check run for this annotation

Codecov / codecov/patch

mmdet/models/utils/transformer.py#L799

Added line #L799 was not covered by tests
torch.linspace(
0, H - 1, H, dtype=torch.float32, device=memory.device),
torch.linspace(
Expand Down Expand Up @@ -847,7 +848,7 @@
reference_points_list = []
for lvl, (H, W) in enumerate(spatial_shapes):
# TODO check this 0.5
ref_y, ref_x = torch.meshgrid(
ref_y, ref_x = torch_meshgrid_ij(

Check warning on line 851 in mmdet/models/utils/transformer.py

View check run for this annotation

Codecov / codecov/patch

mmdet/models/utils/transformer.py#L851

Added line #L851 was not covered by tests
torch.linspace(
0.5, H - 0.5, H, dtype=torch.float32, device=device),
torch.linspace(
Expand Down
15 changes: 14 additions & 1 deletion mmdet/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import warnings

import mmcv
from mmcv.utils import print_log
import torch
from mmcv.utils import TORCH_VERSION, digit_version, print_log


def find_latest_checkpoint(path, suffix='pth'):
Expand Down Expand Up @@ -74,3 +75,15 @@ def update(cfg, src_str, dst_str):

update(cfg.data, cfg.data_root, dst_root)
cfg.data_root = dst_root


_torch_version_meshgrid_indexing = (
'parrots' not in TORCH_VERSION
and digit_version(TORCH_VERSION) >= digit_version('1.10.0a0'))


def torch_meshgrid_ij(*tensors):
if _torch_version_meshgrid_indexing:
return torch.meshgrid(*tensors, indexing='ij')
else:
return torch.meshgrid(*tensors) # Uses indexing='ij' by default