diff --git a/mmedit/models/common/downsample.py b/mmedit/models/common/downsample.py index 3be4324e55..58979f435d 100644 --- a/mmedit/models/common/downsample.py +++ b/mmedit/models/common/downsample.py @@ -18,8 +18,9 @@ def pixel_unshuffle(x, scale): raise AssertionError( f'Invalid scale ({scale}) of pixel unshuffle for tensor ' f'with shape: {x.shape}') - h = torch.div(h, scale, rounding_mode='floor') - w = torch.div(w, scale, rounding_mode='floor') - x = x.view(b, c, h, scale, w, scale) + h = h // scale + w = w // scale + size = torch.Size([b, c, h, scale, w, scale]) + x = x.view(size) x = x.permute(0, 1, 3, 5, 2, 4) return x.reshape(b, -1, h, w)