-
Notifications
You must be signed in to change notification settings - Fork 9
/
model.py
50 lines (41 loc) · 2.16 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import torch
import torch.nn.functional as F
from torch import nn
class Combiner(nn.Module):
    """
    Combiner module which, once trained, fuses textual and visual information.

    The text and image CLIP features are each projected, concatenated and passed through a
    small MLP. The MLP output is added to a learned convex mix of the raw text and image
    features, and the sum is L2-normalized along the feature (last) dimension.
    """

    def __init__(self, clip_feature_dim: int, projection_dim: int, hidden_dim: int,
                 dropout: float = 0.5):
        """
        :param clip_feature_dim: CLIP input feature dimension (also the output dimension)
        :param projection_dim: projection dimension
        :param hidden_dim: hidden dimension
        :param dropout: dropout probability shared by every dropout layer
            (defaults to 0.5, the value previously hard-coded)
        """
        super().__init__()
        self.text_projection_layer = nn.Linear(clip_feature_dim, projection_dim)
        self.image_projection_layer = nn.Linear(clip_feature_dim, projection_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.combiner_layer = nn.Linear(projection_dim * 2, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, clip_feature_dim)
        self.dropout3 = nn.Dropout(dropout)
        # Gate mapping the concatenated projections to a single weight in (0, 1) (Sigmoid),
        # used to mix the raw text and image features. Keep the layer order unchanged:
        # nn.Sequential state_dict keys are index-based (dynamic_scalar.0 / dynamic_scalar.3),
        # so reordering would break loading previously saved weights.
        self.dynamic_scalar = nn.Sequential(
            nn.Linear(projection_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
        # NOTE(review): not read anywhere in this class; presumably used by the training code
        # as a logit scale for similarity scores -- confirm.
        self.logit_scale = 100

    @torch.jit.export
    def combine_features(self, image_features: torch.Tensor,
                         text_features: torch.Tensor) -> torch.Tensor:
        """
        Combine the reference image features and the caption features. It outputs the predicted features.

        :param image_features: CLIP reference image features, last dimension = clip_feature_dim
        :param text_features: CLIP relative caption features, same shape as image_features
        :return: predicted features, L2-normalized along the last (feature) dimension
        """
        text_projected = self.dropout1(F.relu(self.text_projection_layer(text_features)))
        image_projected = self.dropout2(F.relu(self.image_projection_layer(image_features)))
        raw_combined = torch.cat((text_projected, image_projected), -1)
        combined = self.dropout3(F.relu(self.combiner_layer(raw_combined)))
        # alpha in (0, 1): convex weight between the raw text and raw image features.
        alpha = self.dynamic_scalar(raw_combined)
        output = self.output_layer(combined) + alpha * text_features + (1 - alpha) * image_features
        # Normalize over the feature axis (the same axis the projections are concatenated on),
        # so 1-D and batched N-D inputs are handled correctly; identical to dim=1 for 2-D input.
        return F.normalize(output, dim=-1)