From c130e7d89176b669fd438f7c6e6e632e52defa45 Mon Sep 17 00:00:00 2001
From: sagarsj42
Date: Sun, 24 Apr 2022 00:29:02 +0530
Subject: [PATCH] Changes to resolve bugs in finetuning for retrieval on MSRVTT dataset

---
 .../ft/msrvtt/fine_tune/normal_1_cl.json |  6 +--
 OATrans/data_loader/data_loader.py       | 46 +++++++++----------
 OATrans/model/model.py                   | 14 ++++++
 OATrans/parse_config_dist_multi.py       |  9 +++-
 OATrans/train.py                         |  2 +-
 OATrans/trainer/trainer.py               |  5 ++
 6 files changed, 52 insertions(+), 30 deletions(-)

diff --git a/OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json b/OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json
index 18e1dcb..aa53702 100644
--- a/OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json
+++ b/OATrans/configs/ft/msrvtt/fine_tune/normal_1_cl.json
@@ -16,12 +16,8 @@
             "two_outputs": false,
             "object_pseudo_label": false
         },
-        "object_params": {
-            "model": "",
-            "input_objects": false
-        },
         "text_params": {
-            "model": "pretrained/distilbert-base-uncased",
+            "model": "distilbert-base-uncased",
             "pretrained": true,
             "input": "text",
             "two_outputs": false
diff --git a/OATrans/data_loader/data_loader.py b/OATrans/data_loader/data_loader.py
index 8f63ac0..1b176db 100644
--- a/OATrans/data_loader/data_loader.py
+++ b/OATrans/data_loader/data_loader.py
@@ -1,11 +1,11 @@
 from OATrans.base import BaseDataLoaderExplicitSplit, DistBaseDataLoaderExplicitSplit, MultiDistBaseDataLoaderExplicitSplit, BaseMultiDataLoader
-from OATrans.data_loader.data_loader import init_transform_dict
-from OATrans.data_loader.data_loader import ConceptualCaptions3M
-from OATrans.data_loader.data_loader import MSRVTT
-from OATrans.data_loader.data_loader import LSMDC
-from OATrans.data_loader.data_loader import WebVidObject
-from OATrans.data_loader.data_loader import MSVD
-from OATrans.data_loader.data_loader import DiDeMo
+from OATrans.data_loader.transforms import init_transform_dict
+from OATrans.data_loader.ConceptualCaptions_dataset import ConceptualCaptions3M
+from OATrans.data_loader.MSRVTT_dataset import MSRVTT
+from OATrans.data_loader.LSMDC_dataset import LSMDC
+from OATrans.data_loader.WebVid_dataset import WebVidObject
+from OATrans.data_loader.MSVD_dataset import MSVD
+from OATrans.data_loader.DiDeMo_dataset import DiDeMo
 
 
 def dataset_loader(dataset_name,
@@ -38,22 +38,22 @@ def dataset_loader(dataset_name,
     # ...is this safe / or just lazy?
     if dataset_name == "MSRVTT":
         dataset = MSRVTT(**kwargs)
-    elif dataset_name == "SomethingSomethingV2":
-        dataset = SomethingSomethingV2(**kwargs)
-    elif dataset_name == "WebVid":
-        dataset = WebVid(**kwargs)
-    elif dataset_name == "ConceptualCaptions3M":
-        dataset = ConceptualCaptions3M(**kwargs)
-    elif dataset_name == "ConceptualCaptions12M":
-        dataset = ConceptualCaptions12M(**kwargs)
-    elif dataset_name == "LSMDC":
-        dataset = LSMDC(**kwargs)
-    elif dataset_name == "COCOCaptions":
-        dataset = COCOCaptions(**kwargs)
-    elif dataset_name == "MSVD":
-        dataset = MSVD(**kwargs)
-    else:
-        raise NotImplementedError(f"Dataset: {dataset_name} not found.")
+    # elif dataset_name == "SomethingSomethingV2":
+    #     dataset = SomethingSomethingV2(**kwargs)
+    # elif dataset_name == "WebVid":
+    #     dataset = WebVid(**kwargs)
+    # elif dataset_name == "ConceptualCaptions3M":
+    #     dataset = ConceptualCaptions3M(**kwargs)
+    # elif dataset_name == "ConceptualCaptions12M":
+    #     dataset = ConceptualCaptions12M(**kwargs)
+    # elif dataset_name == "LSMDC":
+    #     dataset = LSMDC(**kwargs)
+    # elif dataset_name == "COCOCaptions":
+    #     dataset = COCOCaptions(**kwargs)
+    # elif dataset_name == "MSVD":
+    #     dataset = MSVD(**kwargs)
+    # else:
+    #     raise NotImplementedError(f"Dataset: {dataset_name} not found.")
 
     return dataset
 
diff --git a/OATrans/model/model.py b/OATrans/model/model.py
index 84b19a5..47ea6d9 100644
--- a/OATrans/model/model.py
+++ b/OATrans/model/model.py
@@ -92,6 +92,8 @@ def forward(self, data, return_embeds=True):
         text_data = data['text']
         video_data = data['video']
 
+        # print('video data:', type(video_data), video_data.shape)
+
         text_embeddings = self.compute_text(text_data)
         video_embeddings = self.compute_video(video_data)
 
@@ -112,7 +114,15 @@ def compute_text(self, text_data):
         return text_embeddings
 
     def compute_video(self, video_data):
+
+        # print('In compute video, video data:', type(video_data), video_data.shape)
+
         video_embeddings = self.video_model(video_data)
+
+        # print('After passing to video model, video embeds:', type(video_embeddings), len(video_embeddings))
+        # print(type(video_embeddings[0]), video_embeddings[0].shape, type(video_embeddings[1]), video_embeddings[1].shape)
+
+        video_embeddings = video_embeddings[0]
         video_embeddings = self.vid_proj(video_embeddings)
         return video_embeddings
 
@@ -168,6 +178,10 @@ def sim_matrix(a, b, eps=1e-8):
     a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
     a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
     b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
+
+    # print('In sim matrix, a norm:', type(a_norm), a.shape, 'b norm:', type(b_norm), b_norm.shape)
+    # print('b norm transposed:', type(b_norm.transpose(0, 1)), b_norm.transpose(0, 1).shape)
+
     sim_mt = torch.mm(a_norm, b_norm.transpose(0, 1))
     return sim_mt
 
diff --git a/OATrans/parse_config_dist_multi.py b/OATrans/parse_config_dist_multi.py
index 4b69f90..9b9cee5 100644
--- a/OATrans/parse_config_dist_multi.py
+++ b/OATrans/parse_config_dist_multi.py
@@ -84,9 +84,13 @@ def initialize(self, name, module, *args, index=None, **kwargs):
         module_name = self[name][index]['type']
         module_args = dict(self[name][index]['args'])
 
+        # print('module name:', module_name)
+        # print('module args:', module_args)
+        # print('module:', module)
+
         # if parameter not in config subdict, then check if it's in global config.
         signature = inspect.signature(getattr(module, module_name).__init__)
-        print(module_name)
+        # signature = inspect.signature(getattr(OATrans.model.model, 'FrozenInTime').__init__)
         for param in signature.parameters.keys():
             if param not in module_args and param in self.config:
                 module_args[param] = self[param]
@@ -97,6 +101,9 @@ def initialize(self, name, module, *args, index=None, **kwargs):
         if module_name == 'TextObjectVideoDataLoader' and param == 'args':
             module_args[param] = self.args
 
+        # print('args:', args)
+        # print('module args:', module_args)
+
         return getattr(module, module_name)(*args, **module_args)
 
     def __getitem__(self, name):
diff --git a/OATrans/train.py b/OATrans/train.py
index d383168..37d9413 100755
--- a/OATrans/train.py
+++ b/OATrans/train.py
@@ -1,7 +1,7 @@
 import argparse
 import collections
 from OATrans.data_loader import data_loader as module_data
-from OATrans import model as module_loss, model as module_metric, model as module_arch
+from OATrans.model import loss as module_loss, metric as module_metric, model as module_arch
 import utils.visualizer as module_vis
 from utils.util import replace_nested_dict_item
 from parse_config_dist_multi import ConfigParser
diff --git a/OATrans/trainer/trainer.py b/OATrans/trainer/trainer.py
index 4484174..dd0126a 100755
--- a/OATrans/trainer/trainer.py
+++ b/OATrans/trainer/trainer.py
@@ -147,7 +147,12 @@ def _valid_epoch(self, epoch):
                     # This avoids using `DataParallel` in this case, and supposes the entire batch fits in one GPU.
                     text_embed, vid_embed = self.model.module(data, return_embeds=True)
                 else:
+                    # print('data:', len(data))
                     text_embed, vid_embed = self.model(data, return_embeds=True)
+
+                # print('In _valid_epoch, text embed:', type(text_embed), text_embed.shape,
+                #       'video embed:', type(vid_embed), vid_embed.shape)
+
                 text_embed_arr[dl_idx].append(text_embed.cpu())
                 vid_embed_arr[dl_idx].append(vid_embed.cpu())
                 sims_batch = sim_matrix(text_embed, vid_embed)
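
The fixes above all touch one data flow: the video backbone returns a tuple whose
first element is the frame embedding, which is projected by vid_proj and then scored
against text embeddings with a cosine-similarity matrix (sim_matrix). Below is a
minimal standalone sketch of that flow; DummyVideoModel, the feature sizes
(768/512/256), and the random inputs are illustrative assumptions, not the actual
OATrans/FrozenInTime modules.

    import torch

    def sim_matrix(a, b, eps=1e-8):
        # Row-wise L2 normalisation, then all-pairs dot products:
        # entry (i, j) is the cosine similarity of text i and video j.
        a_n, b_n = a.norm(dim=1)[:, None], b.norm(dim=1)[:, None]
        a_norm = a / torch.max(a_n, eps * torch.ones_like(a_n))
        b_norm = b / torch.max(b_n, eps * torch.ones_like(b_n))
        return torch.mm(a_norm, b_norm.transpose(0, 1))

    class DummyVideoModel(torch.nn.Module):
        # Stand-in for the video backbone: it returns a tuple, so callers must
        # unpack index 0 first (the bug fixed in compute_video above).
        def forward(self, video):                   # video: (B, T, C, H, W)
            feats = video.flatten(1).mean(dim=1, keepdim=True).repeat(1, 768)
            return feats, None                      # (embeddings, aux output)

    video_model = DummyVideoModel()
    vid_proj = torch.nn.Linear(768, 256)
    txt_proj = torch.nn.Linear(512, 256)

    video = torch.randn(4, 8, 3, 224, 224)          # batch of 4 clips, 8 frames each
    text_feats = torch.randn(4, 512)                # e.g. pooled text-encoder features

    video_embeddings = video_model(video)[0]        # take the first tuple element
    video_embeddings = vid_proj(video_embeddings)
    text_embeddings = txt_proj(text_feats)

    sims = sim_matrix(text_embeddings, video_embeddings)   # (4, 4) text-to-video scores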