-
Notifications
You must be signed in to change notification settings - Fork 793
/
segnet.py
73 lines (57 loc) · 2.58 KB
/
segnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch.nn as nn
from ptsemseg.models.utils import segnetDown2, segnetDown3, segnetUp2, segnetUp3
class segnet(nn.Module):
    """Five-stage encoder/decoder network assembled from the ``segnetDown*``
    and ``segnetUp*`` blocks of ``ptsemseg.models.utils``.

    Every encoder (down) block returns ``(output, indices, unpool_shape)``;
    the decoder (up) block of the same stage is later called with those
    ``indices`` and ``unpool_shape``, in reverse stage order.

    Args:
        n_classes: Output channel count of the last decoder block (``up1``).
        in_channels: Input channel count of the first encoder block
            (``down1``).
        is_unpooling: Stored as ``self.is_unpooling``. Nothing in this class
            reads it; ``forward`` always passes the indices to the decoder.
    """

    def __init__(self, n_classes=21, in_channels=3, is_unpooling=True):
        super(segnet, self).__init__()

        self.in_channels = in_channels
        self.is_unpooling = is_unpooling

        # Encoder: two 2-conv blocks followed by three 3-conv blocks
        # (init_vgg16_params relies on exactly this conv1/conv2[/conv3] layout).
        self.down1 = segnetDown2(self.in_channels, 64)
        self.down2 = segnetDown2(64, 128)
        self.down3 = segnetDown3(128, 256)
        self.down4 = segnetDown3(256, 512)
        self.down5 = segnetDown3(512, 512)

        # Decoder: channel widths mirror the encoder, ending at n_classes.
        self.up5 = segnetUp3(512, 512)
        self.up4 = segnetUp3(512, 256)
        self.up3 = segnetUp3(256, 128)
        self.up2 = segnetUp2(128, 64)
        self.up1 = segnetUp2(64, n_classes)

    def forward(self, inputs):
        """Run ``inputs`` through the five encoder blocks and then the five
        decoder blocks, returning the output of ``up1``.

        The ``indices_k`` / ``unpool_shape_k`` produced by ``down<k>`` are
        handed to ``up<k>``.
        """
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        down3, indices_3, unpool_shape3 = self.down3(down2)
        down4, indices_4, unpool_shape4 = self.down4(down3)
        down5, indices_5, unpool_shape5 = self.down5(down4)

        up5 = self.up5(down5, indices_5, unpool_shape5)
        up4 = self.up4(up5, indices_4, unpool_shape4)
        up3 = self.up3(up4, indices_3, unpool_shape3)
        up2 = self.up2(up3, indices_2, unpool_shape2)
        up1 = self.up1(up2, indices_1, unpool_shape1)

        return up1

    def init_vgg16_params(self, vgg16):
        """Copy convolution weights and biases from ``vgg16`` into the encoder.

        The ``nn.Conv2d`` children of ``vgg16.features`` are paired, in order,
        with the ``nn.Conv2d`` layers found inside the ``cbr_unit`` of each
        encoder conv (``down1`` .. ``down5``; ``conv3`` only from ``down3``
        onwards).

        Values are copied into the existing encoder parameters instead of
        rebinding ``.data`` to the source tensors, so the two models never
        share storage and the encoder parameters keep their own device and
        dtype.

        Args:
            vgg16: Object exposing ``features`` whose children include the
                source ``nn.Conv2d`` layers (presumably a torchvision VGG-16;
                confirm against the caller).

        Raises:
            ValueError: If the number of convolutions differs, or if any
                paired weight/bias shapes differ. Nothing is copied in that
                case.
        """
        vgg_convs = [
            layer
            for layer in vgg16.features.children()
            if isinstance(layer, nn.Conv2d)
        ]

        encoder_convs = []
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]
        for idx, conv_block in enumerate(blocks):
            units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            if idx >= 2:
                # down3..down5 are the three-convolution blocks.
                units.append(conv_block.conv3.cbr_unit)
            for unit in units:
                encoder_convs.extend(
                    layer for layer in unit if isinstance(layer, nn.Conv2d)
                )

        # Explicit raises instead of `assert`: asserts vanish under `python -O`.
        if len(vgg_convs) != len(encoder_convs):
            raise ValueError(
                "conv layer count mismatch: vgg16 has {}, encoder has {}".format(
                    len(vgg_convs), len(encoder_convs)
                )
            )

        # Validate every pair before touching any parameter so a mismatch
        # cannot leave the encoder partially overwritten.
        for pos, (src, dst) in enumerate(zip(vgg_convs, encoder_convs)):
            if src.weight.size() != dst.weight.size():
                raise ValueError(
                    "weight shape mismatch at conv {}: {} vs {}".format(
                        pos, tuple(src.weight.size()), tuple(dst.weight.size())
                    )
                )
            if src.bias.size() != dst.bias.size():
                raise ValueError(
                    "bias shape mismatch at conv {}: {} vs {}".format(
                        pos, tuple(src.bias.size()), tuple(dst.bias.size())
                    )
                )

        for src, dst in zip(vgg_convs, encoder_convs):
            # In-place copy: no aliasing with vgg16's tensors.
            dst.weight.data.copy_(src.weight.data)
            dst.bias.data.copy_(src.bias.data)