# slot_attn.py
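# Multi-head Slot Attention and a thin encoder wrapper around it.
# The star import below is assumed to re-export torch, nn, F, and the
# project's `linear` / `gru_cell` factory helpers from utils.py.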
from utils import *


class SlotAttention(nn.Module):

    def __init__(self, num_iterations, num_slots,
                 input_size, slot_size, mlp_hidden_size, heads,
                 epsilon=1e-8):
        super().__init__()

        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.input_size = input_size
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.epsilon = epsilon
        self.num_heads = heads

        self.norm_inputs = nn.LayerNorm(input_size)
        self.norm_slots = nn.LayerNorm(slot_size)
        self.norm_mlp = nn.LayerNorm(slot_size)

        # Linear maps for the attention module.
        self.project_q = linear(slot_size, slot_size, bias=False)
        self.project_k = linear(input_size, slot_size, bias=False)
        self.project_v = linear(input_size, slot_size, bias=False)

        # Slot update functions.
        self.gru = gru_cell(slot_size, slot_size)
        self.mlp = nn.Sequential(
            linear(slot_size, mlp_hidden_size, weight_init='kaiming'),
            nn.ReLU(),
            linear(mlp_hidden_size, slot_size))

    def forward(self, inputs, slots):
        # `inputs` has shape [batch_size, num_inputs, input_size].
        # `slots` has shape [batch_size, num_slots, slot_size].
        B, N_kv, D_inp = inputs.size()
        B, N_q, D_slot = slots.size()

        inputs = self.norm_inputs(inputs)
        k = self.project_k(inputs).view(B, N_kv, self.num_heads, -1).transpose(1, 2)  # Shape: [batch_size, num_heads, num_inputs, slot_size // num_heads].
        v = self.project_v(inputs).view(B, N_kv, self.num_heads, -1).transpose(1, 2)  # Shape: [batch_size, num_heads, num_inputs, slot_size // num_heads].
        k = ((self.slot_size // self.num_heads) ** (-0.5)) * k
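
        # Keys and values are computed once from the inputs; only the queries
        # are recomputed from the (normalized) slots inside the loop below.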
        # Multiple rounds of attention.
        for _ in range(self.num_iterations):
            slots_prev = slots
            slots = self.norm_slots(slots)

            # Attention.
            q = self.project_q(slots).view(B, N_q, self.num_heads, -1).transpose(1, 2)  # Shape: [batch_size, num_heads, num_slots, slot_size // num_heads].
            attn_logits = torch.matmul(k, q.transpose(-1, -2))  # Shape: [batch_size, num_heads, num_inputs, num_slots].
            attn = F.softmax(
                attn_logits.transpose(1, 2).reshape(B, N_kv, self.num_heads * N_q),
                dim=-1).view(B, N_kv, self.num_heads, N_q).transpose(1, 2)  # Shape: [batch_size, num_heads, num_inputs, num_slots].
            attn_vis = attn.sum(1)  # Shape: [batch_size, num_inputs, num_slots].
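
            # Note that the softmax above normalizes over slots (and heads) rather
            # than over inputs, so the slots compete to explain each input location.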
            # Weighted mean.
            attn = attn + self.epsilon
            attn = attn / torch.sum(attn, dim=-2, keepdim=True)
            updates = torch.matmul(attn.transpose(-1, -2), v)  # Shape: [batch_size, num_heads, num_slots, slot_size // num_heads].
            updates = updates.transpose(1, 2).reshape(B, N_q, -1)  # Shape: [batch_size, num_slots, slot_size].

            # Slot update.
            slots = self.gru(updates.view(-1, self.slot_size),
                             slots_prev.view(-1, self.slot_size))
            slots = slots.view(-1, self.num_slots, self.slot_size)
            slots = slots + self.mlp(self.norm_mlp(slots))

        return slots, attn_vis
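

# A minimal sketch (not in the original file) of driving SlotAttention directly;
# the shapes below are illustrative and the initial slots would normally come
# from a learned Gaussian, as in SlotAttentionEncoder:
#
#     slot_attn = SlotAttention(num_iterations=3, num_slots=4,
#                               input_size=64, slot_size=64,
#                               mlp_hidden_size=128, heads=1)
#     inputs = torch.randn(2, 16, 64)             # [batch_size, num_inputs, input_size]
#     slots = torch.randn(2, 4, 64)               # [batch_size, num_slots, slot_size]
#     slots, attn_vis = slot_attn(inputs, slots)  # attn_vis: [2, 16, 4]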


class SlotAttentionEncoder(nn.Module):

    def __init__(self, num_iterations, num_slots,
                 input_channels, slot_size, mlp_hidden_size, pos_channels, num_heads):
        super().__init__()

        self.num_iterations = num_iterations
        self.num_slots = num_slots
        self.input_channels = input_channels
        self.slot_size = slot_size
        self.mlp_hidden_size = mlp_hidden_size
        self.pos_channels = pos_channels

        self.layer_norm = nn.LayerNorm(input_channels)
        self.mlp = nn.Sequential(
            linear(input_channels, input_channels, weight_init='kaiming'),
            nn.ReLU(),
            linear(input_channels, input_channels))

        # Parameters for Gaussian init (shared by all slots).
        self.slot_mu = nn.Parameter(torch.zeros(1, 1, slot_size))
        self.slot_log_sigma = nn.Parameter(torch.zeros(1, 1, slot_size))
        nn.init.xavier_uniform_(self.slot_mu)
        nn.init.xavier_uniform_(self.slot_log_sigma)

        self.slot_attention = SlotAttention(
            num_iterations, num_slots,
            input_channels, slot_size, mlp_hidden_size, num_heads)

    def forward(self, x):
        # `x` has shape: [batch_size, enc_height * enc_width, input_channels],
        # i.e. a flattened grid of encoder features.
        B, *_ = x.size()

        x = self.mlp(self.layer_norm(x))
        # `x` still has shape: [batch_size, enc_height * enc_width, input_channels].

        # Slot Attention module.
        slots = x.new_empty(B, self.num_slots, self.slot_size).normal_()
        slots = self.slot_mu + torch.exp(self.slot_log_sigma) * slots
        slots, attn = self.slot_attention(x, slots)
        # `slots` has shape: [batch_size, num_slots, slot_size].
        # `attn` has shape: [batch_size, enc_height * enc_width, num_slots].

        return slots, attn
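

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes `utils`
# re-exports torch as used above; all sizes below are illustrative placeholders.
if __name__ == "__main__":
    B, H, W, C = 2, 8, 8, 64  # hypothetical batch size and feature-map shape
    encoder = SlotAttentionEncoder(
        num_iterations=3, num_slots=4,
        input_channels=C, slot_size=64, mlp_hidden_size=128,
        pos_channels=4, num_heads=1)

    feats = torch.randn(B, H * W, C)  # flattened CNN features, [B, N, C]
    slots, attn = encoder(feats)
    print(slots.shape)  # torch.Size([2, 4, 64])
    print(attn.shape)   # torch.Size([2, 64, 4])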