-
Notifications
You must be signed in to change notification settings - Fork 0
/
zdc_utils.py
67 lines (47 loc) · 2.61 KB
/
zdc_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
import numpy as np
def sum_channels_parallel_numpy(data):
    """Reduce a batch of 2-D grids to five checkerboard/quadrant channel sums.

    Cells are classified by the parity of ``row + col``:

    * odd cells are summed over the whole grid into ``ch5``;
    * even cells are summed per quadrant into ``ch1``..``ch4``. The grid is
      cut at ``n_rows // 2`` and ``n_cols // 2`` (so for odd sizes the
      extra row/column goes to the high-index half):
      ``ch1`` = low rows / low cols, ``ch2`` = low rows / high cols,
      ``ch3`` = high rows / low cols, ``ch4`` = high rows / high cols.

    NOTE(review): the module name suggests these are ZDC calorimeter
    channels (four towers plus a common channel) - confirm with the caller.

    Args:
        data: 3-D array indexed ``[sample, row, col]``; axes 1 and 2 are
            reduced and axis 0 is iterated.

    Returns:
        A one-shot ``zip`` iterator yielding ``(ch1, ch2, ch3, ch4, ch5)``
        per sample as NumPy scalars. ``ch1``-``ch4`` are computed with float64
        weights and therefore have ``data``'s dtype promoted with float64;
        ``ch5`` keeps ``data``'s own dtype.
    """
    n_rows, n_cols = data.shape[1], data.shape[2]
    half_x = n_rows // 2
    half_y = n_cols // 2

    rows, cols = np.ogrid[0:n_rows, 0:n_cols]
    parity = (rows + cols) % 2
    # Leading singleton axis so the (1, H, W) masks broadcast over the batch.
    # (np.newaxis instead of reshape(-1, H, W), which fails when H or W is 0.)
    odd_cells = (parity != 0)[np.newaxis]
    # Float64 weights (not a bool mask) on purpose: ch1-ch4 have always been
    # float64-promoted, and callers may rely on that dtype.
    even_weights = (parity == 0).astype(np.float64)[np.newaxis]

    ch5 = (data * odd_cells).sum(axis=(1, 2))

    # One masked product sliced per quadrant, instead of four full-size zero
    # masks each multiplied against the whole batch.
    even = data * even_weights
    ch1 = even[:, :half_x, :half_y].sum(axis=(1, 2))
    ch2 = even[:, :half_x, half_y:].sum(axis=(1, 2))
    ch3 = even[:, half_x:, :half_y].sum(axis=(1, 2))
    ch4 = even[:, half_x:, half_y:].sum(axis=(1, 2))
    return zip(ch1, ch2, ch3, ch4, ch5)
def sum_channels_parallel_pytorch(data, args):
    """Torch counterpart of ``sum_channels_parallel_numpy``: five channel sums per sample.

    Cells are classified by the parity of ``row + col``:

    * odd cells are summed over the whole grid into ``ch5``;
    * even cells are summed per quadrant into ``ch1``..``ch4``. The grid is
      cut at ``n_rows // 2`` and ``n_cols // 2`` (for odd sizes the extra
      row/column goes to the high-index half):
      ``ch1`` = low rows / low cols, ``ch2`` = low rows / high cols,
      ``ch3`` = high rows / low cols, ``ch4`` = high rows / high cols.

    NOTE(review): the module name suggests ZDC calorimeter channels (four
    towers plus a common channel) - confirm with the caller.

    Args:
        data: 3-D tensor indexed ``[sample, row, col]``. It is assumed to
            live on ``args["device"]`` (the masks are built there) - TODO
            confirm at call sites.
        args: mapping; only ``args["device"]`` is read.

    Returns:
        A one-shot ``zip`` iterator yielding ``(ch1, ch2, ch3, ch4, ch5)``
        per sample as 0-d tensors. ``ch1``-``ch4`` are computed with weights
        of ``torch.get_default_dtype()`` (what ``torch.zeros`` produces), so
        they are promoted with that dtype; ``ch5`` keeps ``data``'s dtype.
    """
    device = args["device"]
    n_rows, n_cols = data.shape[1], data.shape[2]
    half_x = n_rows // 2
    half_y = n_cols // 2

    # Allocate the index grids directly on the target device instead of
    # building them on the CPU and copying them over with .to(device).
    rows = torch.arange(n_rows, device=device).reshape(-1, 1)
    cols = torch.arange(n_cols, device=device).reshape(1, -1)
    parity = (rows + cols) % 2
    # unsqueeze(0) gives (1, H, W) masks that broadcast over the batch and,
    # unlike reshape(-1, H, W), also work when H or W is 0.
    odd_cells = (parity != 0).unsqueeze(0)
    # Default-dtype weights keep ch1-ch4's dtype identical to multiplying by a
    # torch.zeros(...) mask, which callers may rely on.
    even_weights = (parity == 0).to(torch.get_default_dtype()).unsqueeze(0)

    ch5 = (data * odd_cells).sum(dim=(1, 2))

    # One masked product sliced per quadrant, instead of four full-size zero
    # masks each multiplied against the whole batch.
    even = data * even_weights
    ch1 = even[:, :half_x, :half_y].sum(dim=(1, 2))
    ch2 = even[:, :half_x, half_y:].sum(dim=(1, 2))
    ch3 = even[:, half_x:, :half_y].sum(dim=(1, 2))
    ch4 = even[:, half_x:, half_y:].sum(dim=(1, 2))
    return zip(ch1, ch2, ch3, ch4, ch5)