diff --git a/mmedit/models/common/flow_warp.py b/mmedit/models/common/flow_warp.py index b38f4bcdb8..8fbc720b2c 100644 --- a/mmedit/models/common/flow_warp.py +++ b/mmedit/models/common/flow_warp.py @@ -30,7 +30,7 @@ def flow_warp(x, _, _, h, w = x.size() # create mesh grid grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) - grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) + grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (h, w, 2) grid.requires_grad = False grid_flow = grid + flow