-
Notifications
You must be signed in to change notification settings - Fork 0
/
resblocktensorflow.py
37 lines (27 loc) · 1.29 KB
/
resblocktensorflow.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#imports
import tensorflow as tf
import tensorflow.keras.backend as K
#############################################################
class ResBlock(tf.keras.layers.Layer):
    """Residual block: two same-padded 2-D convolutions plus a skip connection.

    The forward pass computes ``relu(conv2(relu(conv1(x)))) + x``, where each
    convolution is followed by a bias add. Both kernels map ``filters``
    channels to ``filters`` channels, so the input is expected to already
    carry ``filters`` channels; otherwise the first convolution or the final
    addition cannot succeed.

    Args:
        filters: Number of input and output channels of both convolutions.
        kernel_size: Spatial kernel size, either a single int (square kernel)
            or a sequence of two ints.
        **kwargs: Forwarded unchanged to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, kernel_size, **kwargs):
        # Initialise the base layer first so that its attribute tracking is
        # fully set up before any attribute is assigned on the instance.
        super(ResBlock, self).__init__(**kwargs)
        self.filters = filters
        # Accept an int or a list (e.g. after a JSON config round-trip) and
        # store a tuple so it can be concatenated with the channel dims.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = tuple(kernel_size)

    def build(self, input_shape):
        """Create the two convolution kernels and their bias vectors."""
        # Kernel layout: spatial dims followed by (in_channels, out_channels).
        kernel_shape = tuple(self.kernel_size) + (self.filters, self.filters)
        bias_shape = (self.filters,)
        # Keyword arguments keep these calls valid for both the legacy
        # ``add_weight(name, shape, ...)`` signature and the newer
        # ``add_weight(shape, initializer, ..., name=...)`` one.
        self.conv2d_w1 = self.add_weight(
            name="conv2d_w1", shape=kernel_shape, initializer="glorot_uniform"
        )
        self.conv2d_w2 = self.add_weight(
            name="conv2d_w2", shape=kernel_shape, initializer="glorot_uniform"
        )
        self.conv2d_b1 = self.add_weight(
            name="conv2d_b1", shape=bias_shape, initializer="zeros"
        )
        self.conv2d_b2 = self.add_weight(
            name="conv2d_b2", shape=bias_shape, initializer="zeros"
        )
        super(ResBlock, self).build(input_shape)

    def call(self, x):
        """Apply conv -> bias -> relu twice, then add the input back."""
        y = K.conv2d(x, self.conv2d_w1, padding="same")
        y = K.bias_add(y, self.conv2d_b1)
        y = K.relu(y)
        y = K.conv2d(y, self.conv2d_w2, padding="same")
        y = K.bias_add(y, self.conv2d_b2)
        y = K.relu(y)
        # Skip connection: added after the second activation.
        y = y + x
        return y

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    def get_config(self):
        """Return the layer config, including the constructor arguments.

        Without ``filters`` and ``kernel_size`` in the config, the layer
        could not be re-created from it.
        """
        config = super(ResBlock, self).get_config()
        config.update(
            {
                "filters": self.filters,
                "kernel_size": tuple(self.kernel_size),
            }
        )
        return config
#############################################################