capsulelayer.py (forked from bojone/Capsule)

#! -*- coding: utf-8 -*-
# Reference: https://kexue.fm/archives/5112
import tensorflow as tf
from tensorflow.keras import activations
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Layer

def squash(x, axis=-1):
    """Squashing non-linearity. This is the variant from the reference post:
    a vector with squared norm s is scaled by sqrt(s) / (0.5 + s), so the
    output norm is s / (0.5 + s) and approaches 1 for long vectors."""
    s_squared_norm = K.sum(K.square(x), axis, keepdims=True) + K.epsilon()
    scale = K.sqrt(s_squared_norm) / (0.5 + s_squared_norm)
    return scale * x
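
# Example (a hedged sketch, assuming TF 2.x eager execution):
#   v = tf.constant([[3.0, 4.0]])    # squared norm s = 25
#   tf.norm(squash(v), axis=-1)      # -> s / (0.5 + s) = 25 / 25.5 ≈ 0.980
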
# Define our own softmax instead of K.softmax, which historically did not
# accept an `axis` argument.
def softmax(x, axis=-1):
    ex = K.exp(x - K.max(x, axis=axis, keepdims=True))
    return ex / K.sum(ex, axis=axis, keepdims=True)
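
# Example (a hedged sketch): the routing below calls softmax(b, 1), i.e. it
# normalises the logits across the output capsules (axis 1), e.g.:
#   softmax(tf.zeros((2, 10, 6)), axis=1)   # uniform 1/10 weights along axis 1
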
def margin_loss(y_true, y_pred):
    """
    Margin loss for Eq. (4) of the CapsNet paper. It should also work when
    y_true[i, :] contains more than one `1`, but that case is untested.
    :param y_true: [None, n_classes]
    :param y_pred: [None, num_capsule]
    :return: element-wise loss; Keras reduces it to a scalar.
    """
    # Loss 2 (from an older version):
    # L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
    #     0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    # return K.mean(K.sum(L, 1))
    # Loss 1 (as in the paper, down-weighting factor 0.5):
    # return y_true*K.relu(0.9-y_pred)**2 + 0.5*(1-y_true)*K.relu(y_pred-0.1)**2
    # Loss 0 (used here, down-weighting factor 0.25):
    return y_true*K.relu(0.9-y_pred)**2 + 0.25*(1-y_true)*K.relu(y_pred-0.1)**2
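
# Worked example (hedged, assuming TF 2.x eager execution): with
#   y_true = tf.constant([[0., 1.]])
#   y_pred = tf.constant([[0.2, 0.8]])
# the true class contributes relu(0.9 - 0.8)**2 = 0.01 and the false class
# 0.25 * relu(0.2 - 0.1)**2 = 0.0025, so margin_loss returns [[0.0025, 0.01]].
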
# A Capsule layer implemented in pure Keras.
class Capsule(Layer):
    def __init__(self, num_capsule, dim_capsule, routings=3, share_weights=True, activation='squash', **kwargs):
        super(Capsule, self).__init__(**kwargs)
        self.num_capsule = num_capsule
        self.dim_capsule = dim_capsule
        self.routings = routings
        self.share_weights = share_weights
        self.activation_name = activation  # kept as a string for get_config()
        if activation == 'squash':
            self.activation = squash
        else:
            self.activation = activations.get(activation)
    def build(self, input_shape):
        input_dim_capsule = input_shape[-1]
        if self.share_weights:
            # One transformation matrix shared by all input capsules,
            # applied as a 1x1 convolution in call().
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(1, input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)
        else:
            # One transformation matrix per input capsule.
            input_num_capsule = input_shape[-2]
            self.W = self.add_weight(name='capsule_kernel',
                                     shape=(input_num_capsule,
                                            input_dim_capsule,
                                            self.num_capsule * self.dim_capsule),
                                     initializer='glorot_uniform',
                                     trainable=True)
        super(Capsule, self).build(input_shape)
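
    # Shape sketch (illustrative numbers, not from the original repo): with
    # input capsules of dim 8 and num_capsule=10, dim_capsule=16, the shared
    # kernel W has shape (1, 8, 160) and the unshared kernel has shape
    # (input_num_capsule, 8, 160).
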
    def call(self, u_vecs):
        # Predict u_hat for every (input capsule, output capsule) pair.
        if self.share_weights:
            u_hat_vecs = K.conv1d(u_vecs, self.W)
        else:
            u_hat_vecs = K.local_conv1d(u_vecs, self.W, [1], [1])
        batch_size = K.shape(u_vecs)[0]
        input_num_capsule = K.shape(u_vecs)[1]
        u_hat_vecs = K.reshape(u_hat_vecs, (batch_size, input_num_capsule,
                                            self.num_capsule, self.dim_capsule))
        u_hat_vecs = K.permute_dimensions(u_hat_vecs, (0, 2, 1, 3))
        # Final u_hat_vecs.shape = [None, num_capsule, input_num_capsule, dim_capsule].
        # Routing logits, shape = [None, num_capsule, input_num_capsule].
        b = K.zeros_like(u_hat_vecs[:, :, :, 0])
        for i in range(self.routings):
            # Coupling coefficients: normalise the logits over output capsules.
            c = softmax(b, 1)
            # Weighted sum of predictions, equivalent to K.batch_dot(c, u_hat_vecs, [2, 2]).
            o = tf.einsum('bin,binj->bij', c, u_hat_vecs)
            if i < self.routings - 1:
                o = K.l2_normalize(o, -1)
                # Agreement update, equivalent to K.batch_dot(o, u_hat_vecs, [2, 3]).
                b = tf.einsum('bij,binj->bin', o, u_hat_vecs)
        return self.activation(o)
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.num_capsule, self.dim_capsule)
    def get_config(self):
        config = super().get_config().copy()
        config.update({
            'num_capsule': self.num_capsule,
            'dim_capsule': self.dim_capsule,
            'routings': self.routings,
            'share_weights': self.share_weights,
            # Store the activation as a string, not a function object, so the
            # layer can be serialised and reloaded.
            'activation': self.activation_name,
        })
        return config
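

if __name__ == '__main__':
    # Minimal usage sketch (hedged): the MNIST-like 28x28x1 input, 10 classes,
    # conv stack, and capsule hyperparameters below are illustrative
    # assumptions, not settings from the original repo.
    from tensorflow.keras import layers, models

    inputs = layers.Input(shape=(28, 28, 1))
    x = layers.Conv2D(64, 3, activation='relu')(inputs)
    x = layers.Conv2D(64, 3, activation='relu')(x)
    # Flatten the feature maps into primary capsules of dimension 8.
    x = layers.Reshape((-1, 8))(x)
    capsules = Capsule(num_capsule=10, dim_capsule=16, routings=3)(x)
    # The predicted class score is the length of each output capsule.
    outputs = layers.Lambda(lambda z: K.sqrt(K.sum(K.square(z), axis=-1)))(capsules)

    model = models.Model(inputs, outputs)
    model.compile(optimizer='adam', loss=margin_loss, metrics=['accuracy'])
    model.summary()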