-
Notifications
You must be signed in to change notification settings - Fork 0
/
se.py
28 lines (22 loc) · 910 Bytes
/
se.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
from keras.layers import GlobalAveragePooling2D, Reshape, Dense, multiply, Permute
from keras import backend as K
def squeeze_excite_block(input, ratio=16):
    """Create a squeeze-excite block.

    The input is global-average-pooled to one value per channel, passed
    through a two-layer Dense bottleneck (ReLU, then sigmoid) and the
    resulting per-channel gates are multiplied back onto the input.

    Args:
        input: input Keras tensor. The channel axis is taken from
            ``K.image_data_format()`` (axis 1 for "channels_first",
            the last axis otherwise).
        ratio: reduction ratio of the bottleneck. The hidden Dense layer
            has ``filters // ratio`` units, but never fewer than 1.

    Returns: a keras tensor, the input rescaled channel-wise by the gates.

    Raises:
        ValueError: if ``ratio`` is smaller than 1, or if the channel
            dimension of ``input`` is not statically known.
    """
    # NOTE: the parameter name `input` shadows the builtin; it is kept
    # because callers may pass it by keyword.
    if ratio < 1:
        raise ValueError("ratio must be a positive integer, got %r" % (ratio,))

    init = input
    channels_first = K.image_data_format() == "channels_first"
    channel_axis = 1 if channels_first else -1

    # K.int_shape is the public way to read the static shape; the private
    # `_keras_shape` attribute is not present on every Keras tensor type.
    filters = K.int_shape(init)[channel_axis]
    if filters is None:
        raise ValueError(
            "The channel dimension of the input must be defined, got None."
        )

    # With fewer channels than `ratio`, integer division would yield a
    # zero-unit Dense layer; keep at least one unit in the bottleneck.
    bottleneck_units = max(1, filters // ratio)

    # Squeeze: one descriptor per channel, reshaped so that the Dense layers
    # act on the last (channel) axis.
    se_shape = (1, 1, filters)
    se = GlobalAveragePooling2D()(init)
    se = Reshape(se_shape)(se)

    # Excite: bottleneck followed by sigmoid gates.
    se = Dense(bottleneck_units, activation='relu',
               kernel_initializer='he_normal', use_bias=False)(se)
    se = Dense(filters, activation='sigmoid',
               kernel_initializer='he_normal', use_bias=False)(se)

    # The gates were built as (1, 1, filters); move the channel axis to the
    # front so they line up with a channels_first input.
    if channels_first:
        se = Permute((3, 1, 2))(se)

    x = multiply([init, se])
    return x