Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaoyufenfei committed Dec 11, 2019
1 parent e4bde5e commit b10449a
Showing 1 changed file with 17 additions and 7 deletions.
24 changes: 17 additions & 7 deletions model/ESNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,26 @@
import torch.nn as nn
import torch.nn.functional as F

class DownsamplerBlock(nn.Module):
def __init__(self, ninput, noutput):
super().__init__()
self.conv = nn.Conv2d(ninput, noutput-ninput, (3,3), stride=2, padding=1, bias=True)
class DownsamplerBlock (nn.Module):
def __init__(self, in_channel, out_channel):
super(DownsamplerBlock,self).__init__()

self.conv = nn.Conv2d(in_channel, out_channel-in_channel, (3, 3), stride=2, padding=1, bias=True)
self.pool = nn.MaxPool2d(2, stride=2)
self.bn = nn.BatchNorm2d(noutput, eps=1e-3)
self.relu = nn.ReLU(inplace = True)
self.bn = nn.BatchNorm2d(out_channel, eps=1e-3)
self.relu = nn.ReLU(inplace=True)

def forward(self, input):
output = torch.cat([self.conv(input), self.pool(input)], 1)
x1 = self.pool(input)
x2 = self.conv(input)

diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]

x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])

output = torch.cat([x2, x1], 1)
output = self.bn(output)
output = self.relu(output)
return output
Expand Down

0 comments on commit b10449a

Please sign in to comment.