# slide_inference.py (54 lines, 2.12 KB)
# NOTE: this header replaces GitHub page chrome and the gutter line-number
# column (1-54) that were captured when the file was scraped from the web UI.
import torch
import torch.nn as nn
import torch.nn.functional as F
def slide_inference(self, img, img_meta, rescale):
    """Inference by sliding-window with overlap.

    Splits ``img`` into overlapping 256x256 crops (stride 128), runs
    ``self.encode_decode`` on each crop, and averages the per-pixel
    logits over every crop that covers that pixel.

    If h_crop > h_img or w_crop > w_img, the small patch will be used to
    decode without padding.

    Args:
        img: input tensor of shape (batch, channels, height, width).
        img_meta: metadata forwarded unchanged to ``self.encode_decode``.
        rescale: currently ignored — resizing to the original image
            shape is not implemented in this version.

    Returns:
        Tensor of shape (batch, num_classes, height, width) holding the
        averaged segmentation logits.
    """
    h_stride, w_stride = 128, 128
    h_crop, w_crop = 256, 256
    # Derive dimensions from the input instead of hard-coding 1x512x512,
    # so any input size works (the previous constants were the 512 case).
    batch_size, _, h_img, w_img = img.size()
    # NOTE(review): assumes encode_decode emits a single-channel logit
    # map — confirm against the decode head.
    num_classes = 1
    # Ceil-divide so the last window is always placed, then clamped below.
    h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1
    w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1
    preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
    count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
    for h_idx in range(h_grids):
        for w_idx in range(w_grids):
            y1 = h_idx * h_stride
            x1 = w_idx * w_stride
            y2 = min(y1 + h_crop, h_img)
            x2 = min(x1 + w_crop, w_img)
            # Shift border windows back so every crop keeps full size
            # (when the image is at least crop-sized).
            y1 = max(y2 - h_crop, 0)
            x1 = max(x2 - w_crop, 0)
            crop_img = img[:, :, y1:y2, x1:x2]
            crop_seg_logit = self.encode_decode(crop_img, img_meta)
            # Pad the crop logits back to full-image coordinates and
            # accumulate; count_mat records per-pixel coverage.
            preds += F.pad(crop_seg_logit,
                           (int(x1), int(preds.shape[3] - x2), int(y1),
                            int(preds.shape[2] - y2)))
            count_mat[:, :, y1:y2, x1:x2] += 1
    assert (count_mat == 0).sum() == 0
    # Average overlapping predictions.
    preds = preds / count_mat
    return preds