-
Notifications
You must be signed in to change notification settings - Fork 25
/
PixelUnShuffle.py
32 lines (28 loc) · 1.2 KB
/
PixelUnShuffle.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
# This code is heavily inspired by https://github.com/fangwei123456/PixelUnshuffle-pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
def pixel_unshuffle(input, downscale_factor):
    """Move each k x k spatial block of ``input`` into the channel dimension.

    A tensor of shape ``(*, C, H*k, W*k)`` becomes ``(*, C*k*k, H, W)`` with
    ``k = downscale_factor``. Output channel ``c*k*k + dy*k + dx`` holds
    ``input[..., c, h*k + dy, w*k + dx]``. This is the same channel ordering
    as ``torch.nn.functional.pixel_unshuffle``.

    Args:
        input: tensor with at least 3 dims. The last three are read as
            (channels, height, width); any leading dims (e.g. batch) are kept.
        downscale_factor: positive integer ``k``.

    Returns:
        Tensor of shape ``(*, C*k*k, H_in // k, W_in // k)`` with the same
        dtype and device as ``input``. Trailing rows/columns that do not fill
        a whole k x k block are dropped, as the previous strided-convolution
        implementation did. If ``H_in < k`` or ``W_in < k`` the corresponding
        output dim is 0.

    Raises:
        ValueError: if ``downscale_factor < 1`` or ``input`` has fewer than
            3 dimensions.
    """
    k = downscale_factor
    if k < 1:
        raise ValueError(f"downscale_factor must be >= 1, got {k}")
    if input.dim() < 3:
        raise ValueError(
            f"expected input with at least 3 dims (C, H, W), got shape {tuple(input.shape)}"
        )

    *lead, channels, height, width = input.shape
    out_h, out_w = height // k, width // k

    # Drop the remainder rows/cols so the result matches a stride-k, kernel-k
    # convolution without padding (floor division of the spatial size).
    x = input[..., : out_h * k, : out_w * k]

    # Pure data movement instead of a convolution with a one-hot kernel: it is
    # exact for every dtype, is not subject to TF32 rounding on GPU, and does
    # not turn a lone inf/nan into nan for its neighbours (0 * inf = nan).
    x = x.reshape(*lead, channels, out_h, k, out_w, k)
    n = len(lead)
    # (..., C, out_h, dy, out_w, dx) -> (..., C, dy, dx, out_h, out_w)
    x = x.permute(*range(n), n, n + 2, n + 4, n + 1, n + 3)
    return x.reshape(*lead, channels * k * k, out_h, out_w)
class PixelUnShuffle(nn.Module):
    """Module wrapper around ``pixel_unshuffle`` with a fixed factor.

    Maps a tensor of shape ``(N, C, H*k, W*k)`` to ``(N, C*k*k, H, W)``,
    where ``k`` is ``downscale_factor``.

    Args:
        downscale_factor: the block size ``k``; passed unchanged to
            ``pixel_unshuffle`` on every forward call.
    """

    def __init__(self, downscale_factor):
        super().__init__()
        self.downscale_factor = downscale_factor

    def forward(self, input):
        """Apply ``pixel_unshuffle`` to ``input`` with the stored factor.

        ``input`` is expected as ``(N, C, H*k, W*k)``; the result is
        ``(N, C*k*k, H, W)``.
        """
        return pixel_unshuffle(input, self.downscale_factor)

    def extra_repr(self):
        # Included in repr(module), e.g. "PixelUnShuffle(downscale_factor=2)",
        # so printed model summaries show the configured factor.
        return f"downscale_factor={self.downscale_factor}"