Changes to resolve bugs in finetuning for retrieval on MSRVTT dataset #3

Open · wants to merge 1 commit into base: main
6 changes: 1 addition & 5 deletions OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json
@@ -16,12 +16,8 @@
"two_outputs": false,
"object_pseudo_label": false
},
"object_params": {
"model": "",
"input_objects": false
},
"text_params": {
"model": "pretrained/distilbert-base-uncased",
"model": "distilbert-base-uncased",
"pretrained": true,
"input": "text",
"two_outputs": false
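For context (not part of the diff): switching text_params.model from the local path "pretrained/distilbert-base-uncased" to the plain Hugging Face Hub id means the checkpoint can be resolved directly, without a pre-downloaded folder inside the repo, assuming the text backbone is loaded through transformers' from_pretrained as the "pretrained": true flag suggests. A minimal standalone sanity check of that id (illustrative only):

from transformers import AutoModel, AutoTokenizer
import torch

# Loads the public hub checkpoint (or the local cache) instead of requiring
# a "pretrained/distilbert-base-uncased" directory in the working tree.
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
model = AutoModel.from_pretrained("distilbert-base-uncased")

with torch.no_grad():
    tokens = tokenizer("a person is singing on stage", return_tensors="pt")
    out = model(**tokens)
print(out.last_hidden_state.shape)  # (1, seq_len, 768) for DistilBERT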
46 changes: 23 additions & 23 deletions OATrans/data_loader/data_loader.py
@@ -1,11 +1,11 @@
from OATrans.base import BaseDataLoaderExplicitSplit, DistBaseDataLoaderExplicitSplit, MultiDistBaseDataLoaderExplicitSplit, BaseMultiDataLoader
-from OATrans.data_loader.data_loader import init_transform_dict
-from OATrans.data_loader.data_loader import ConceptualCaptions3M
-from OATrans.data_loader.data_loader import MSRVTT
-from OATrans.data_loader.data_loader import LSMDC
-from OATrans.data_loader.data_loader import WebVidObject
-from OATrans.data_loader.data_loader import MSVD
-from OATrans.data_loader.data_loader import DiDeMo
+from OATrans.data_loader.transforms import init_transform_dict
+from OATrans.data_loader.ConceptualCaptions_dataset import ConceptualCaptions3M
+from OATrans.data_loader.MSRVTT_dataset import MSRVTT
+from OATrans.data_loader.LSMDC_dataset import LSMDC
+from OATrans.data_loader.WebVid_dataset import WebVidObject
+from OATrans.data_loader.MSVD_dataset import MSVD
+from OATrans.data_loader.DiDeMo_dataset import DiDeMo


def dataset_loader(dataset_name,
@@ -38,22 +38,22 @@ def dataset_loader(dataset_name,
# ...is this safe / or just lazy?
if dataset_name == "MSRVTT":
dataset = MSRVTT(**kwargs)
elif dataset_name == "SomethingSomethingV2":
dataset = SomethingSomethingV2(**kwargs)
elif dataset_name == "WebVid":
dataset = WebVid(**kwargs)
elif dataset_name == "ConceptualCaptions3M":
dataset = ConceptualCaptions3M(**kwargs)
elif dataset_name == "ConceptualCaptions12M":
dataset = ConceptualCaptions12M(**kwargs)
elif dataset_name == "LSMDC":
dataset = LSMDC(**kwargs)
elif dataset_name == "COCOCaptions":
dataset = COCOCaptions(**kwargs)
elif dataset_name == "MSVD":
dataset = MSVD(**kwargs)
else:
raise NotImplementedError(f"Dataset: {dataset_name} not found.")
# elif dataset_name == "SomethingSomethingV2":
# dataset = SomethingSomethingV2(**kwargs)
# elif dataset_name == "WebVid":
# dataset = WebVid(**kwargs)
# elif dataset_name == "ConceptualCaptions3M":
# dataset = ConceptualCaptions3M(**kwargs)
# elif dataset_name == "ConceptualCaptions12M":
# dataset = ConceptualCaptions12M(**kwargs)
# elif dataset_name == "LSMDC":
# dataset = LSMDC(**kwargs)
# elif dataset_name == "COCOCaptions":
# dataset = COCOCaptions(**kwargs)
# elif dataset_name == "MSVD":
# dataset = MSVD(**kwargs)
# else:
# raise NotImplementedError(f"Dataset: {dataset_name} not found.")

return dataset

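With the corrected imports above, only the MSRVTT branch of dataset_loader remains active; the other branches, including the else that raised NotImplementedError, are commented out rather than deleted. A standalone sketch of the resulting control flow (the MSRVTT class and the keyword arguments here are stand-ins, not the repo's real signatures):

class MSRVTT:  # stub standing in for OATrans.data_loader.MSRVTT_dataset.MSRVTT
    def __init__(self, **kwargs):
        self.kwargs = kwargs

def dataset_loader(dataset_name, **kwargs):
    # Only the MSRVTT branch survives in this patch.
    if dataset_name == "MSRVTT":
        dataset = MSRVTT(**kwargs)
    # Any other name now skips every branch, so the return below raises
    # UnboundLocalError instead of the old NotImplementedError.
    return dataset

loader = dataset_loader("MSRVTT", data_dir="data/MSRVTT", split="train")
print(type(loader).__name__)  # MSRVTT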
14 changes: 14 additions & 0 deletions OATrans/model/model.py
@@ -92,6 +92,8 @@ def forward(self, data, return_embeds=True):
text_data = data['text']
video_data = data['video']

+# print('video data:', type(video_data), video_data.shape)
+
text_embeddings = self.compute_text(text_data)
video_embeddings = self.compute_video(video_data)

@@ -112,7 +114,15 @@ def compute_text(self, text_data):
return text_embeddings

def compute_video(self, video_data):
+
+# print('In compute video, video data:', type(video_data), video_data.shape)
+
video_embeddings = self.video_model(video_data)
+
+# print('After passing to video model, video embeds:', type(video_embeddings), len(video_embeddings))
+# print(type(video_embeddings[0]), video_embeddings[0].shape, type(video_embeddings[1]), video_embeddings[1].shape)
+
+video_embeddings = video_embeddings[0]
video_embeddings = self.vid_proj(video_embeddings)
return video_embeddings

@@ -168,6 +178,10 @@ def sim_matrix(a, b, eps=1e-8):
a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+
+# print('In sim matrix, a norm:', type(a_norm), a.shape, 'b norm:', type(b_norm), b_norm.shape)
+# print('b norm transposed:', type(b_norm.transpose(0, 1)), b_norm.transpose(0, 1).shape)
+
sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
return sim_mt

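Two things are visible in the hunks above: compute_video now takes element [0] of the sequence returned by the video backbone before projecting it (the commented prints inspect both elements), and sim_matrix is a plain cosine-similarity matrix between row-normalized embeddings, with eps guarding against division by zero. A minimal standalone check of sim_matrix as written above (batch sizes and dimensions are arbitrary):

import torch

def sim_matrix(a, b, eps=1e-8):
    # Same computation as in OATrans/model/model.py: L2-normalize the rows
    # of a and b, then take all pairwise dot products.
    a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
    a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
    b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
    return torch.mm(a_norm, b_norm.transpose(0, 1))

text_embed = torch.randn(4, 256)  # e.g. 4 captions, 256-d projections
vid_embed = torch.randn(4, 256)   # e.g. 4 videos, 256-d projections
sims = sim_matrix(text_embed, vid_embed)
print(sims.shape)  # torch.Size([4, 4]); entry (i, j) is cos(text_i, video_j)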
9 changes: 8 additions & 1 deletion OATrans/parse_config_dist_multi.py
@@ -84,9 +84,13 @@ def initialize(self, name, module, *args, index=None, **kwargs):
module_name = self[name][index]['type']
module_args = dict(self[name][index]['args'])

+# print('module name:', module_name)
+# print('module args:', module_args)
+# print('module:', module)
+
# if parameter not in config subdict, then check if it's in global config.
signature = inspect.signature(getattr(module, module_name).__init__)
-print(module_name)
+# signature = inspect.signature(getattr(OATrans.model.model, 'FrozenInTime').__init__)
for param in signature.parameters.keys():
if param not in module_args and param in self.config:
module_args[param] = self[param]
@@ -97,6 +101,9 @@
if module_name == 'TextObjectVideoDataLoader' and param == 'args':
module_args[param] = self.args

+# print('args:', args)
+# print('module args:', module_args)
+
return getattr(module, module_name)(*args, **module_args)

def __getitem__(self, name):
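The initialize method shown above builds objects from the config by name: it looks up the class with getattr(module, module_name), and any constructor parameter missing from the sub-config but present at the top level of the config is filled in via inspect.signature. A standalone sketch of that pattern (the class, config, and function signature here are made up for illustration; the real method lives on ConfigParser and also handles *args and special cases):

import inspect
import sys

class DummyLoader:
    def __init__(self, dataset_name, batch_size, num_workers=0):
        self.dataset_name = dataset_name
        self.batch_size = batch_size
        self.num_workers = num_workers

def initialize(name, module, config):
    # Class name and its explicit arguments come from the sub-config.
    module_name = config[name]["type"]
    module_args = dict(config[name]["args"])

    # If a constructor parameter is not in the sub-config, fall back to a
    # value with the same name at the top level of the config.
    signature = inspect.signature(getattr(module, module_name).__init__)
    for param in signature.parameters.keys():
        if param not in module_args and param in config:
            module_args[param] = config[param]

    return getattr(module, module_name)(**module_args)

config = {
    "data_loader": {"type": "DummyLoader", "args": {"dataset_name": "MSRVTT"}},
    "batch_size": 16,  # picked up through the signature fallback
}
loader = initialize("data_loader", sys.modules[__name__], config)
print(loader.dataset_name, loader.batch_size)  # MSRVTT 16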
2 changes: 1 addition & 1 deletion OATrans/train.py
@@ -1,7 +1,7 @@
import argparse
import collections
from OATrans.data_loader import data_loader as module_data
-from OATrans import model as module_loss, model as module_metric, model as module_arch
+from OATrans.model import loss as module_loss, metric as module_metric, model as module_arch
import utils.visualizer as module_vis
from utils.util import replace_nested_dict_item
from parse_config_dist_multi import ConfigParser
5 changes: 5 additions & 0 deletions OATrans/trainer/trainer.py
@@ -147,7 +147,12 @@ def _valid_epoch(self, epoch):
# This avoids using `DataParallel` in this case, and supposes the entire batch fits in one GPU.
text_embed, vid_embed = self.model.module(data, return_embeds=True)
else:
+# print('data:', len(data))
text_embed, vid_embed = self.model(data, return_embeds=True)
+
+# print('In _valid_epoch, text embed:', type(text_embed), text_embed.shape,
+# 'video embed:', type(vid_embed), vid_embed.shape)
+
text_embed_arr[dl_idx].append(text_embed.cpu())
vid_embed_arr[dl_idx].append(vid_embed.cpu())
sims_batch = sim_matrix(text_embed, vid_embed)