-
Notifications
You must be signed in to change notification settings - Fork 0
/
DS1_model_retrieval_multi_task.py
103 lines (84 loc) · 2.79 KB
/
DS1_model_retrieval_multi_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import torch
from torch import nn
from torch.nn import functional as F
import math
class GELU(nn.Module):
    """Tanh-approximated Gaussian Error Linear Unit activation.

    Computes, element-wise,
    ``0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x**3)))``.
    The module defines no parameters or buffers of its own.
    """

    # sqrt(2 / pi): the scale applied inside the tanh. Evaluated once when
    # the class is created rather than on every forward pass.
    _SQRT_2_OVER_PI = math.sqrt(2 / math.pi)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the activation element-wise.

        Args:
            x: Input tensor.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # torch.tanh is used instead of F.tanh: older PyTorch releases flag
        # F.tanh as deprecated and emit a warning on each call.
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * x.pow(3))
        return 0.5 * x * (1 + torch.tanh(inner))
class ProjectionHead(nn.Module):
    """Residual projection block that returns L2-normalised embeddings.

    The input is projected by a first linear layer; that projection is then
    passed through GELU -> linear -> dropout -> GELU and added back to the
    first projection (skip connection). The sum goes through LayerNorm and
    GELU, and is finally scaled to unit L2 norm along dimension 1.

    Args:
        embedding_dim: Size of the input feature dimension.
        projection_dim: Size of the projected output dimension.
        dropout: Dropout probability used inside the residual path.
    """

    def __init__(self, embedding_dim, projection_dim, dropout=0.5):
        super().__init__()
        # The two linear layers are created first, in this order, so that
        # seeded weight initialisation is unaffected.
        self.fc_embed_1 = nn.Linear(embedding_dim, projection_dim)
        self.fc_embed_2 = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)
        self.gelu = GELU()
        self.layer_norm = nn.LayerNorm(projection_dim)

    def forward(self, x):
        """Project ``x`` and return the normalised embedding.

        Args:
            x: Input tensor whose last dimension is ``embedding_dim``.
                Normalisation runs over dimension 1, so a 2-D
                ``(batch, embedding_dim)`` input is presumably expected
                (TODO confirm against callers).

        Returns:
            Tensor with last dimension ``projection_dim`` and unit L2 norm
            along dimension 1.
        """
        skip = self.fc_embed_1(x)

        # Residual path: activation, second projection, dropout, activation.
        residual = self.fc_embed_2(self.gelu(skip))
        residual = self.gelu(self.dropout(residual))

        # Merge with the skip connection, then normalise and activate.
        merged = self.gelu(self.layer_norm(residual + skip))
        return F.normalize(merged, p=2, dim=1)
class video_branch(nn.Module):
    """Single linear layer that maps video features to a 64-dimensional vector.

    Args:
        video_dim: Size of the input video feature dimension.
    """

    def __init__(self, video_dim):
        super().__init__()
        self.video_fc = nn.Linear(video_dim, 64)

    def forward(self, x):
        """Return the linear projection of ``x``.

        Args:
            x: Tensor whose last dimension is ``video_dim``.

        Returns:
            Tensor whose last dimension is 64.
        """
        return self.video_fc(x)
class audio_branch(nn.Module):
    """Single linear layer that maps audio features to a 64-dimensional vector.

    Args:
        audio_dim: Size of the input audio feature dimension.
    """

    def __init__(self, audio_dim):
        super().__init__()
        self.audio_fc = nn.Linear(audio_dim, 64)

    def forward(self, x):
        """Return the linear projection of ``x``.

        Args:
            x: Tensor whose last dimension is ``audio_dim``.

        Returns:
            Tensor whose last dimension is 64.
        """
        return self.audio_fc(x)
class embedding_network(nn.Module):
    """Two-branch video/audio network with per-modality heads and a similarity output.

    Each modality is mapped to a 64-dimensional feature by its own branch.
    From that feature the network produces:

    * a 5-dimensional output per modality (GELU followed by a linear layer;
      presumably task logits -- confirm against the training code), and
    * an L2-normalised embedding per modality via a ``ProjectionHead``; the
      cosine similarity between the two embeddings is returned.

    Args:
        video_dim: Size of the input video feature dimension.
        audio_dim: Size of the input audio feature dimension.
    """

    def __init__(self, video_dim=2304, audio_dim=128):
        super().__init__()
        feature_dim = 64  # width produced by both modality branches
        output_dim = 5    # width of each per-modality output head

        # Sub-modules are created in a fixed order (video, audio, then the
        # projection heads) so seeded weight initialisation is unaffected.
        self.video_br = video_branch(video_dim)
        self.out_video_fc = nn.Linear(feature_dim, output_dim)

        self.audio_br = audio_branch(audio_dim)
        self.out_audio_fc = nn.Linear(feature_dim, output_dim)

        # Cross-modal distance learning components.
        self.video_projection = ProjectionHead(
            embedding_dim=feature_dim, projection_dim=feature_dim
        )
        self.audio_projection = ProjectionHead(
            embedding_dim=feature_dim, projection_dim=feature_dim
        )
        self.cosine_sim = nn.CosineSimilarity(dim=1)
        self.gelu = GELU()

    def forward(self, video_features, audio_features):
        """Run both branches and compare their embeddings.

        Args:
            video_features: Tensor whose last dimension is ``video_dim``.
            audio_features: Tensor whose last dimension is ``audio_dim``.

        Returns:
            A tuple ``(cosine_sim, out_video, out_audio)``: the cosine
            similarity (computed along dimension 1) between the projected
            video and audio embeddings, followed by the 5-dimensional video
            and audio head outputs.
        """
        # Shared per-modality features.
        video_hidden = self.video_br(video_features)
        audio_hidden = self.audio_br(audio_features)

        # Per-modality output heads operate on the activated features.
        out_video = self.out_video_fc(self.gelu(video_hidden))
        out_audio = self.out_audio_fc(self.gelu(audio_hidden))

        # Projection heads take the raw (pre-activation) branch features.
        # Video is projected before audio, matching the order in which the
        # dropout layers consume random numbers during training.
        video_embedding = self.video_projection(video_hidden)
        audio_embedding = self.audio_projection(audio_hidden)

        similarity = self.cosine_sim(video_embedding, audio_embedding)
        return similarity, out_video, out_audio