From b10449af243023cab048fcc79682130aad07d703 Mon Sep 17 00:00:00 2001 From: xiaoyufenfei <1274737326@qq.com> Date: Wed, 11 Dec 2019 16:38:59 +0800 Subject: [PATCH] update --- model/ESNet.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/model/ESNet.py b/model/ESNet.py index ff3b207..12dbaad 100644 --- a/model/ESNet.py +++ b/model/ESNet.py @@ -7,16 +7,26 @@ import torch.nn as nn import torch.nn.functional as F -class DownsamplerBlock(nn.Module): - def __init__(self, ninput, noutput): - super().__init__() - self.conv = nn.Conv2d(ninput, noutput-ninput, (3,3), stride=2, padding=1, bias=True) +class DownsamplerBlock (nn.Module): + def __init__(self, in_channel, out_channel): + super(DownsamplerBlock,self).__init__() + + self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True) self.pool = nn.MaxPool2d(2, stride=2) - self.bn = nn.BatchNorm2d(noutput, eps=1e-3) - self.relu = nn.ReLU(inplace = True) + self.bn = nn.BatchNorm2d(out_channel, eps=1e-3) + self.relu = nn.ReLU(inplace=True) def forward(self, input): - output = torch.cat([self.conv(input), self.pool(input)], 1) + x1 = self.pool(input) + x2 = self.conv(input) + + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + + output = torch.cat([x2, x1], 1) output = self.bn(output) output = self.relu(output) return output