diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a5c8aaa --- /dev/null +++ b/.gitignore @@ -0,0 +1,41 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST +detectron/ +*.ipynb + +# Datasets, pretrained models, checkpoints and preprocessed files +data/ +!visdialch/data/ +checkpoints/ +logs/ + +# IPython Notebook +.ipynb_checkpoints + +# virtualenv +venv/ +ENV/ \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..fab9d6d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: +- repo: https://github.com/ambv/black + rev: 19.3b0 + hooks: + - id: black + language_version: python3.6 +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.1.0 + hooks: + - id: flake8 + - id: trailing-whitespace + - id: check-added-large-files + - id: end-of-file-fixer diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..842ed8a --- /dev/null +++ b/LICENSE @@ -0,0 +1,60 @@ +BSD 3-Clause License + +Copyright (c) 2018, Yulei Niu +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +BSD 3-Clause License + +Copyright (c) 2018, Karan Desai +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4ecfb7e --- /dev/null +++ b/README.md @@ -0,0 +1,137 @@ +Recursive Visual Attention in Visual Dialog +==================================== + +This repository contains the code for the following paper: + +* Yulei Niu, Hanwang Zhang, Manli Zhang, Jianhong Zhang, Zhiwu Lu, Ji-Rong Wen, *Recursive Visual Attention in Visual Dialog*. In CVPR, 2019. ([PDF](https://arxiv.org/pdf/1812.02664.pdf)) + +``` +@InProceedings{Niu_2019_CVPR, + author = {Niu, Yulei and Zhang, Hanwang and Zhang, Manli and Zhang, Jianhong and Lu, Zhiwu and Wen, Ji-Rong}, + title = {Recursive Visual Attention in Visual Dialog}, + booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, + month = {June}, + year = {2019} +} + +``` + +This code is reimplemented as a fork of [batra-mlp-lab/visdial-challenge-starter-pytorch][6]. + + +Setup and Dependencies +---------------------- + +This code is implemented using PyTorch v1.0, and provides out of the box support with CUDA 9 and CuDNN 7. Anaconda/Miniconda is the recommended to set up this codebase: + +### Anaconda or Miniconda + +1. Install Anaconda or Miniconda distribution based on Python3+ from their [downloads' site][1]. +2. Clone this repository and create an environment: + +```shell +git clone https://www.github.com/yuleiniu/rva +conda create -n visdial-ch python=3.6 + +# activate the environment and install all dependencies +conda activate visdial-ch +cd rva/ +pip install -r requirements.txt + +# install this codebase as a package in development version +python setup.py develop +``` + + +Download Data +------------- + +1. Download the VisDial v1.0 dialog json files from [here][3] and keep it under `$PROJECT_ROOT/data` directory, for default arguments to work effectively. + +2. Get the word counts for VisDial v1.0 train split [here][4]. They are used to build the vocabulary. + +3. [batra-mlp-lab][6] provides pre-extracted image features of VisDial v1.0 images, using a Faster-RCNN pre-trained on Visual Genome. If you wish to extract your own image features, skip this step and download VisDial v1.0 images from [here][3] instead. Extracted features for v1.0 train, val and test are available for download at these links. Note that these files do not contain the bounding box information. + + * [`features_faster_rcnn_x101_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_train.h5): Bottom-up features of 36 proposals from images of `train` split. + * [`features_faster_rcnn_x101_val.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_val.h5): Bottom-up features of 36 proposals from images of `val` split. 
+ * [`features_faster_rcnn_x101_test.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_faster_rcnn_x101_test.h5): Bottom-up features of 36 proposals from images of `test` split. + +4. [batra-mlp-lab][6] also provides pre-extracted FC7 features from VGG16. + + * [`features_vgg16_fc7_train.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_vgg16_fc7_train.h5): VGG16 FC7 features from images of `train` split. + * [`features_vgg16_fc7_val.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_vgg16_fc7_val.h5): VGG16 FC7 features from images of `val` split. + * [`features_vgg16_fc7_test.h5`](https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/features_vgg16_fc7_test.h5): VGG16 FC7 features from images of `test` split. + +5. Download the GloVe pretrained word vectors from [here][12], and keep `glove.6B.300d.txt` under `$PROJECT_ROOT/data` directory. + +Extracting Features (Optional) +------------- + +### With Docker (Optional) +For Dockerfile, please refer to [batra-mlp-lab/visdial-challenge-starter-pytorch][8]. + +### Without Docker (Optional) + +0. Set up opencv, [cocoapi][9] and [Detectron][10]. + +1. Prepare the [MSCOCO][11] and [Flickr][3] images. + +2. Extract visual features. +```shell +python ./data/extract_features_detectron.py --image-root /path/to/MSCOCO/train2014/ /path/to/MSCOCO/val2014/ --save-path /path/to/feature --split train # Bottom-up features of 36 proposals from images of train split. +python ./data/extract_features_detectron.py --image-root /path/to/Flickr/VisualDialog_val2018 --save-path /path/to/feature --split val # Bottom-up features of 36 proposals from images of val split. +python ./data/extract_features_detectron.py --image-root /path/to/Flickr/VisualDialog_test2018 --save-path /path/to/feature --split test # Bottom-up features of 36 proposals from images of test split. +``` + +Initializing GloVe Word Embeddings +-------------- +Simply run +```shell +python data/init_glove.py +``` + + +Training +-------- + +Train the model provided in this repository as: + +```shell +python train.py --config-yml configs/rva.yml --gpu-ids 0 # provide more ids for multi-GPU execution other args... +``` + +### Saving model checkpoints + +This script will save model checkpoints at every epoch as per path specified by `--save-dirpath`. Refer [visdialch/utils/checkpointing.py][7] for more details on how checkpointing is managed. + +### Logging + +We use [Tensorboard][2] for logging training progress. Recommended: execute `tensorboard --logdir /path/to/save_dir --port 8008` and visit `localhost:8008` in the browser. + + +Evaluation +---------- + +Evaluation of a trained model checkpoint can be done as follows: + +```shell +python evaluate.py --config-yml /path/to/config.yml --load-pthpath /path/to/checkpoint.pth --split val --gpu-ids 0 +``` + +This will generate an EvalAI submission file, and report metrics from the [Visual Dialog paper][5] (Mean reciprocal rank, R@{1, 5, 10}, Mean rank), and Normalized Discounted Cumulative Gain (NDCG), introduced in the first Visual Dialog Challenge (in 2018). + +The metrics reported here would be the same as those reported through EvalAI by making a submission in `val` phase. To generate a submission file for `test-std` or `test-challenge` phase, replace `--split val` with `--split test`. 
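+
+Each entry of the saved ranks file follows the schema below (a sketch with
+hypothetical values; the keys and the 100-option ranking match what
+`evaluate.py` writes):
+
+```python
+# One element of the list written to --save-ranks-path by evaluate.py.
+# "round_id" is 1-10; "ranks" holds the predicted rank of each of the
+# 100 candidate answers for that round.
+example_entry = {
+    "image_id": 185565,            # hypothetical image id
+    "round_id": 1,
+    "ranks": list(range(1, 101)),  # placeholder ranks, one per answer option
+}
+```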
+ + +[1]: https://conda.io/docs/user-guide/install/download.html +[2]: https://www.github.com/lanpa/tensorboardX +[3]: https://visualdialog.org/data +[4]: https://s3.amazonaws.com/visual-dialog/data/v1.0/2019/visdial_1.0_word_counts_train.json +[5]: https://arxiv.org/abs/1611.08669 +[6]: https://www.github.com/batra-mlp-lab/visdial-challenge-starter-pytorch +[7]: https://www.github.com/yuleiniu/rva/blob/master/visdialch/utils/checkpointing.py +[8]: https://www.github.com/batra-mlp-lab/visdial-challenge-starter-pytorch#docker +[9]: https://www.github.com/cocodataset/cocoapi +[10]: https://www.github.com/facebookresearch/Detectron +[11]: http://cocodataset.org/#download +[12]: http://nlp.stanford.edu/data/glove.6B.zip \ No newline at end of file diff --git a/configs/rva.yml b/configs/rva.yml new file mode 100644 index 0000000..f70a6bb --- /dev/null +++ b/configs/rva.yml @@ -0,0 +1,40 @@ +# Dataset reader arguments +dataset: + image_features_train_h5: 'data/features_faster_rcnn_x101_train.h5' + image_features_val_h5: 'data/features_faster_rcnn_x101_val.h5' + image_features_test_h5: 'data/features_faster_rcnn_x101_test.h5' + word_counts_json: 'data/visdial_1.0_word_counts_train.json' + glove_npy: 'data/glove.npy' + + img_norm: 1 + concat_history: false + max_sequence_length: 20 + vocab_min_count: 5 + + +# Model related arguments +model: + encoder: 'rva' + decoder: 'disc' + + img_feature_size: 2048 + word_embedding_size: 300 + lstm_hidden_size: 512 + lstm_num_layers: 2 + dropout: 0.5 + dropout_fc: 0.3 + + relu: 'ReLU' + +# Optimization related arguments +solver: + batch_size: 24 # 32 x num_gpus is a good rule of thumb + num_epochs: 15 + initial_lr: 0.01 + training_splits: "train" # "trainval" + lr_gamma: 0.1 + lr_milestones: # epochs when lr —> lr * lr_gamma + - 5 + - 10 + warmup_factor: 0.2 + warmup_epochs: 1 diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..8c495a0 --- /dev/null +++ b/evaluate.py @@ -0,0 +1,240 @@ +import argparse +import json +import os + +import torch +from torch import nn +from torch.utils.data import DataLoader +from tqdm import tqdm +import yaml + +from visdialch.data.dataset import VisDialDataset +from visdialch.encoders import Encoder +from visdialch.decoders import Decoder +from visdialch.metrics import SparseGTMetrics, NDCG, scores_to_ranks +from visdialch.model import EncoderDecoderModel +from visdialch.utils.checkpointing import load_checkpoint + + +parser = argparse.ArgumentParser( + "Evaluate and/or generate EvalAI submission file." +) +parser.add_argument( + "--config-yml", + default="configs/rva.yml", + help="Path to a config file listing reader, model and optimization " + "parameters.", +) +parser.add_argument( + "--split", + default="val", + choices=["val", "test"], + help="Which split to evaluate upon.", +) +parser.add_argument( + "--val-json", + default="data/visdial_1.0_val.json", + help="Path to VisDial v1.0 val data. This argument doesn't work when " + "--split=test.", +) +parser.add_argument( + "--val-dense-json", + default="data/visdial_1.0_val_dense_annotations.json", + help="Path to VisDial v1.0 val dense annotations (if evaluating on val " + "split). This argument doesn't work when --split=test.", +) +parser.add_argument( + "--test-json", + default="data/visdial_1.0_test.json", + help="Path to VisDial v1.0 test data. 
This argument doesn't work when " + "--split=val.", +) + +parser.add_argument_group("Evaluation related arguments") +parser.add_argument( + "--load-pthpath", + default="checkpoints/checkpoint_xx.pth", + help="Path to .pth file of pretrained checkpoint.", +) + +parser.add_argument_group( + "Arguments independent of experiment reproducibility" +) +parser.add_argument( + "--gpu-ids", + nargs="+", + type=int, + default=-1, + help="List of ids of GPUs to use.", +) +parser.add_argument( + "--cpu-workers", + type=int, + default=4, + help="Number of CPU workers for reading data.", +) +parser.add_argument( + "--overfit", + action="store_true", + help="Overfit model on 5 examples, meant for debugging.", +) +parser.add_argument( + "--in-memory", + action="store_true", + help="Load the whole dataset and pre-extracted image features in memory. " + "Use only in presence of large RAM, atleast few tens of GBs.", +) + +parser.add_argument_group("Submission related arguments") +parser.add_argument( + "--save-ranks-path", + default="logs/ranks.json", + help="Path (json) to save ranks, in a EvalAI submission format.", +) + +# For reproducibility. +# Refer https://pytorch.org/docs/stable/notes/randomness.html +torch.manual_seed(0) +torch.cuda.manual_seed_all(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + +# ============================================================================= +# INPUT ARGUMENTS AND CONFIG +# ============================================================================= + +args = parser.parse_args() + +# keys: {"dataset", "model", "solver"} +config = yaml.load(open(args.config_yml)) + +if isinstance(args.gpu_ids, int): + args.gpu_ids = [args.gpu_ids] +device = ( + torch.device("cuda", args.gpu_ids[0]) + if args.gpu_ids[0] >= 0 + else torch.device("cpu") +) + +# Print config and args. +print(yaml.dump(config, default_flow_style=False)) +for arg in vars(args): + print("{:<20}: {}".format(arg, getattr(args, arg))) + + +# ============================================================================= +# SETUP DATASET, DATALOADER, MODEL +# ============================================================================= + +if args.split == "val": + val_dataset = VisDialDataset( + config["dataset"], + args.val_json, + args.val_dense_json, + overfit=args.overfit, + in_memory=args.in_memory, + return_options=True, + add_boundary_toks=False + if config["model"]["decoder"] == "disc" + else True, + ) +else: + val_dataset = VisDialDataset( + config["dataset"], + args.test_json, + overfit=args.overfit, + in_memory=args.in_memory, + return_options=True, + add_boundary_toks=False + if config["model"]["decoder"] == "disc" + else True, + ) +val_dataloader = DataLoader( + val_dataset, + batch_size=config["solver"]["batch_size"] + if config["model"]["decoder"] == "disc" + else 5, + num_workers=args.cpu_workers, +) + +# Pass vocabulary to construct Embedding layer. +encoder = Encoder(config["model"], val_dataset.vocabulary) +decoder = Decoder(config["model"], val_dataset.vocabulary) +print("Encoder: {}".format(config["model"]["encoder"])) +print("Decoder: {}".format(config["model"]["decoder"])) + +# Share word embedding between encoder and decoder. +decoder.word_embed = encoder.word_embed + +# Wrap encoder and decoder in a model. 
+model = EncoderDecoderModel(encoder, decoder).to(device) +if -1 not in args.gpu_ids: + model = nn.DataParallel(model, args.gpu_ids) + +model_state_dict, _ = load_checkpoint(args.load_pthpath) +if isinstance(model, nn.DataParallel): + model.module.load_state_dict(model_state_dict) +else: + model.load_state_dict(model_state_dict) +print("Loaded model from {}".format(args.load_pthpath)) + +# Declare metric accumulators (won't be used if --split=test) +sparse_metrics = SparseGTMetrics() +ndcg = NDCG() + +# ============================================================================= +# EVALUATION LOOP +# ============================================================================= + +model.eval() +ranks_json = [] + +for _, batch in enumerate(tqdm(val_dataloader)): + for key in batch: + batch[key] = batch[key].to(device) + with torch.no_grad(): + output = model(batch) + + ranks = scores_to_ranks(output) + for i in range(len(batch["img_ids"])): + # Cast into types explicitly to ensure no errors in schema. + # Round ids are 1-10, not 0-9 + if args.split == "test": + ranks_json.append( + { + "image_id": batch["img_ids"][i].item(), + "round_id": int(batch["num_rounds"][i].item()), + "ranks": [ + rank.item() + for rank in ranks[i][batch["num_rounds"][i] - 1] + ], + } + ) + else: + for j in range(batch["num_rounds"][i]): + ranks_json.append( + { + "image_id": batch["img_ids"][i].item(), + "round_id": int(j + 1), + "ranks": [rank.item() for rank in ranks[i][j]], + } + ) + + if args.split == "val": + sparse_metrics.observe(output, batch["ans_ind"]) + if "gt_relevance" in batch: + output = output[ + torch.arange(output.size(0)), batch["round_id"] - 1, : + ] + ndcg.observe(output, batch["gt_relevance"]) + +if args.split == "val": + all_metrics = {} + all_metrics.update(sparse_metrics.retrieve(reset=True)) + all_metrics.update(ndcg.retrieve(reset=True)) + for metric_name, metric_value in all_metrics.items(): + print(f"{metric_name}: {metric_value}") + +print("Writing ranks to {}".format(args.save_ranks_path)) +os.makedirs(os.path.dirname(args.save_ranks_path), exist_ok=True) +json.dump(ranks_json, open(args.save_ranks_path, "w")) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..dfab025 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,16 @@ +[tool.black] +line-length = 79 +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.hg + | \.mypy_cache + | \.tox + | \.venv + | _build + | buck-out + | build + | dist +)/ +''' diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e0d7a45 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +cython==0.29.1 +h5py==2.8.0 +nltk==3.4 +numpy==1.15.4 +Pillow==5.3.0 +pyyaml>=4.2b1 +six==1.11.0 +tensorboardX==1.2 +tensorflow==1.12.0 +torch==1.0.0 +tqdm==4.28.1 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..e0e83d6 --- /dev/null +++ b/setup.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python +from setuptools import setup + + +setup( + name="visdial-ch", + version="2019.0.0", + author="Yulei Niu", + url="https://github.com/yuleiniu/rva", + description="Code for RvA", + license="BSD", + zip_safe=True, +) diff --git a/train.py b/train.py new file mode 100644 index 0000000..4beb176 --- /dev/null +++ b/train.py @@ -0,0 +1,344 @@ +import argparse +import itertools + +from tensorboardX import SummaryWriter +import torch +from torch import nn, optim +from torch.optim import lr_scheduler +from torch.utils.data import DataLoader +from tqdm import tqdm +import yaml +from bisect import bisect + +import datetime 
+import numpy as np + +from visdialch.data.dataset import VisDialDataset +from visdialch.encoders import Encoder +from visdialch.decoders import Decoder +from visdialch.metrics import SparseGTMetrics, NDCG +from visdialch.model import EncoderDecoderModel +from visdialch.utils.checkpointing import CheckpointManager, load_checkpoint + + +parser = argparse.ArgumentParser() +parser.add_argument( + "--config-yml", + default="configs/rva.yml", + help="Path to a config file listing reader, model and solver parameters.", +) +parser.add_argument( + "--train-json", + default="data/visdial_1.0_train.json", + help="Path to json file containing VisDial v1.0 training data.", +) +parser.add_argument( + "--val-json", + default="data/visdial_1.0_val.json", + help="Path to json file containing VisDial v1.0 validation data.", +) +parser.add_argument( + "--val-dense-json", + default="data/visdial_1.0_val_dense_annotations.json", + help="Path to json file containing VisDial v1.0 validation dense ground " + "truth annotations.", +) + + +parser.add_argument_group( + "Arguments independent of experiment reproducibility" +) +parser.add_argument( + "--gpu-ids", + nargs="+", + type=int, + default=0, + help="List of ids of GPUs to use.", +) +parser.add_argument( + "--cpu-workers", + type=int, + default=4, + help="Number of CPU workers for dataloader.", +) +parser.add_argument( + "--overfit", + action="store_true", + help="Overfit model on 5 examples, meant for debugging.", +) +parser.add_argument( + "--validate", + action="store_true", + help="Whether to validate on val split after every epoch.", +) +parser.add_argument( + "--in-memory", + action="store_true", + help="Load the whole dataset and pre-extracted image features in memory. " + "Use only in presence of large RAM, atleast few tens of GBs.", +) + + +parser.add_argument_group("Checkpointing related arguments") +parser.add_argument( + "--save-dirpath", + default="checkpoints/", + help="Path of directory to create checkpoint directory and save " + "checkpoints.", +) +parser.add_argument( + "--load-pthpath", + default="", + help="To continue training, path to .pth file of saved checkpoint.", +) + +# For reproducibility. +# Refer https://pytorch.org/docs/stable/notes/randomness.html +torch.manual_seed(0) +torch.cuda.manual_seed_all(0) +torch.backends.cudnn.benchmark = False +torch.backends.cudnn.deterministic = True + + +# ============================================================================= +# INPUT ARGUMENTS AND CONFIG +# ============================================================================= + +args = parser.parse_args() + +# keys: {"dataset", "model", "solver"} +config = yaml.load(open(args.config_yml)) + +if isinstance(args.gpu_ids, int): + args.gpu_ids = [args.gpu_ids] +device = ( + torch.device("cuda", args.gpu_ids[0]) + if args.gpu_ids[0] >= 0 + else torch.device("cpu") +) + +# Print config and args. 
+print(yaml.dump(config, default_flow_style=False)) +for arg in vars(args): + print("{:<20}: {}".format(arg, getattr(args, arg))) + + +# ============================================================================= +# SETUP DATASET, DATALOADER, MODEL, CRITERION, OPTIMIZER, SCHEDULER +# ============================================================================= + +train_dataset = VisDialDataset( + config["dataset"], + args.train_json, + overfit=args.overfit, + in_memory=args.in_memory, + return_options=True if config["model"]["decoder"] == "disc" else False, + add_boundary_toks=False if config["model"]["decoder"] == "disc" else True, +) +train_dataloader = DataLoader( + train_dataset, + batch_size=config["solver"]["batch_size"], + num_workers=args.cpu_workers, + shuffle=True, +) + +val_dataset = VisDialDataset( + config["dataset"], + args.val_json, + args.val_dense_json, + overfit=args.overfit, + in_memory=args.in_memory, + return_options=True, + add_boundary_toks=False if config["model"]["decoder"] == "disc" else True, +) +val_dataloader = DataLoader( + val_dataset, + batch_size=config["solver"]["batch_size"] + if config["model"]["decoder"] == "disc" + else 5, + num_workers=args.cpu_workers, +) + +# Pass vocabulary to construct Embedding layer. +encoder = Encoder(config["model"], train_dataset.vocabulary) +decoder = Decoder(config["model"], train_dataset.vocabulary) +print("Encoder: {}".format(config["model"]["encoder"])) +print("Decoder: {}".format(config["model"]["decoder"])) + +# New: Initializing word_embed using GloVe +if config["dataset"]["glove_npy"] != '': + encoder.word_embed.weight.data = torch.from_numpy(np.load(config["dataset"]["glove_npy"])) + print("Loaded glove vectors from {}".format(config["dataset"]["glove_npy"])) + +# Share word embedding between encoder and decoder. +decoder.word_embed = encoder.word_embed + +# Wrap encoder and decoder in a model. +model = EncoderDecoderModel(encoder, decoder).to(device) +if -1 not in args.gpu_ids: + model = nn.DataParallel(model, args.gpu_ids) + +# Loss function. +if config["model"]["decoder"] == "disc": + criterion = nn.CrossEntropyLoss() +elif config["model"]["decoder"] == "gen": + criterion = nn.CrossEntropyLoss( + ignore_index=train_dataset.vocabulary.PAD_INDEX + ) +else: + raise NotImplementedError + +if config["solver"]["training_splits"] == "trainval": + iterations = (len(train_dataset) + len(val_dataset)) // config["solver"][ + "batch_size" + ] + 1 +else: + iterations = len(train_dataset) // config["solver"]["batch_size"] + 1 + + +def lr_lambda_fun(current_iteration: int) -> float: + """Returns a learning rate multiplier. + + Till `warmup_epochs`, learning rate linearly increases to `initial_lr`, + and then gets multiplied by `lr_gamma` every time a milestone is crossed. 
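+
+    With the values in configs/rva.yml (warmup_factor=0.2, warmup_epochs=1,
+    lr_gamma=0.1, lr_milestones=[5, 10]), the multiplier is 0.2 * 0.5 + 0.5 = 0.6
+    halfway through the first epoch, reaches 1.0 at the end of warmup, then
+    drops to 0.1 after epoch 5 and to 0.01 after epoch 10.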
+ """ + current_epoch = float(current_iteration) / iterations + if current_epoch <= config["solver"]["warmup_epochs"]: + alpha = current_epoch / float(config["solver"]["warmup_epochs"]) + return config["solver"]["warmup_factor"] * (1.0 - alpha) + alpha + else: + idx = bisect(config["solver"]["lr_milestones"], current_epoch) + return pow(config["solver"]["lr_gamma"], idx) + + +optimizer = optim.Adamax(model.parameters(), lr=config["solver"]["initial_lr"]) +scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda_fun) + + +# ============================================================================= +# SETUP BEFORE TRAINING LOOP +# ============================================================================= +start_time = datetime.datetime.strftime(datetime.datetime.utcnow(), '%d-%b-%Y-%H:%M:%S') +if args.save_dirpath == 'checkpoints/': + args.save_dirpath += '%s+%s/%s' % (config["model"]["encoder"], config["model"]["decoder"], start_time) +summary_writer = SummaryWriter(log_dir=args.save_dirpath) +checkpoint_manager = CheckpointManager( + model, optimizer, args.save_dirpath, config=config +) +sparse_metrics = SparseGTMetrics() +ndcg = NDCG() + +# If loading from checkpoint, adjust start epoch and load parameters. +if args.load_pthpath == "": + start_epoch = 1 +else: + # "path/to/checkpoint_xx.pth" -> xx + start_epoch = int(args.load_pthpath.split("_")[-1][:-4]) + + model_state_dict, optimizer_state_dict = load_checkpoint(args.load_pthpath) + if isinstance(model, nn.DataParallel): + model.module.load_state_dict(model_state_dict) + else: + model.load_state_dict(model_state_dict) + optimizer.load_state_dict(optimizer_state_dict) + print("Loaded model from {}".format(args.load_pthpath)) + +# ============================================================================= +# TRAINING LOOP +# ============================================================================= + +# Forever increasing counter to keep track of iterations (for tensorboard log). 
+global_iteration_step = (start_epoch - 1) * iterations + +running_loss = 0.0 # New +train_begin = datetime.datetime.utcnow() # New +for epoch in range(start_epoch, config["solver"]["num_epochs"]): + + # ------------------------------------------------------------------------- + # ON EPOCH START (combine dataloaders if training on train + val) + # ------------------------------------------------------------------------- + if config["solver"]["training_splits"] == "trainval": + combined_dataloader = itertools.chain(train_dataloader, val_dataloader) + else: + combined_dataloader = itertools.chain(train_dataloader) + + print(f"\nTraining for epoch {epoch}:") + for i, batch in enumerate(combined_dataloader): + for key in batch: + batch[key] = batch[key].to(device) + + optimizer.zero_grad() + output = model(batch) + target = ( + batch["ans_ind"] + if config["model"]["decoder"] == "disc" + else batch["ans_out"] + ) + batch_loss = criterion( + output.view(-1, output.size(-1)), target.view(-1) + ) + batch_loss.backward() + optimizer.step() + + # -------------------------------------------------------------------- + # update running loss and decay learning rates + # -------------------------------------------------------------------- + if running_loss > 0.0: + running_loss = 0.95 * running_loss + 0.05 * batch_loss.item() + else: + running_loss = batch_loss.item() + + scheduler.step(global_iteration_step) + global_iteration_step += 1 + torch.cuda.empty_cache() + + if global_iteration_step % 100 == 0: + # print current time, running average, learning rate, iteration, epoch + print("[{}][Epoch: {:3d}][Iter: {:6d}][Loss: {:6f}][lr: {:7f}]".format( + datetime.datetime.utcnow() - train_begin, epoch, + global_iteration_step, running_loss, + optimizer.param_groups[0]['lr'])) + + # tensorboardX + summary_writer.add_scalar( + "train/loss", batch_loss, global_iteration_step + ) + summary_writer.add_scalar( + "train/lr", optimizer.param_groups[0]["lr"], global_iteration_step + ) + + # ------------------------------------------------------------------------- + # ON EPOCH END (checkpointing and validation) + # ------------------------------------------------------------------------- + checkpoint_manager.step() + + # Validate and report automatic metrics. + if args.validate: + + # Switch dropout, batchnorm etc to the correct mode. 
+ model.eval() + + print(f"\nValidation after epoch {epoch}:") + for i, batch in enumerate(tqdm(val_dataloader)): + for key in batch: + batch[key] = batch[key].to(device) + with torch.no_grad(): + output = model(batch) + sparse_metrics.observe(output, batch["ans_ind"]) + if "gt_relevance" in batch: + output = output[ + torch.arange(output.size(0)), batch["round_id"] - 1, : + ] + ndcg.observe(output, batch["gt_relevance"]) + + all_metrics = {} + all_metrics.update(sparse_metrics.retrieve(reset=True)) + all_metrics.update(ndcg.retrieve(reset=True)) + for metric_name, metric_value in all_metrics.items(): + print(f"{metric_name}: {metric_value}") + summary_writer.add_scalars( + "metrics", all_metrics, global_iteration_step + ) + + model.train() + torch.cuda.empty_cache() diff --git a/visdialch/__init__.py b/visdialch/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/visdialch/data/__init__.py b/visdialch/data/__init__.py new file mode 100644 index 0000000..f535ea8 --- /dev/null +++ b/visdialch/data/__init__.py @@ -0,0 +1,2 @@ +from visdialch.data.dataset import VisDialDataset +from visdialch.data.vocabulary import Vocabulary diff --git a/visdialch/data/dataset.py b/visdialch/data/dataset.py new file mode 100644 index 0000000..fc8e265 --- /dev/null +++ b/visdialch/data/dataset.py @@ -0,0 +1,311 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.functional import normalize +from torch.nn.utils.rnn import pad_sequence +from torch.utils.data import Dataset + +from visdialch.data.readers import ( + DialogsReader, + DenseAnnotationsReader, + ImageFeaturesHdfReader, +) +from visdialch.data.vocabulary import Vocabulary + + +class VisDialDataset(Dataset): + """ + A full representation of VisDial v1.0 (train/val/test) dataset. According + to the appropriate split, it returns dictionary of question, image, + history, ground truth answer, answer options, dense annotations etc. + """ + + def __init__( + self, + config: Dict[str, Any], + dialogs_jsonpath: str, + dense_annotations_jsonpath: Optional[str] = None, + overfit: bool = False, + in_memory: bool = False, + return_options: bool = True, + add_boundary_toks: bool = False, + ): + super().__init__() + self.config = config + self.return_options = return_options + self.add_boundary_toks = add_boundary_toks + self.dialogs_reader = DialogsReader(dialogs_jsonpath) + + if "val" in self.split and dense_annotations_jsonpath is not None: + self.annotations_reader = DenseAnnotationsReader( + dense_annotations_jsonpath + ) + else: + self.annotations_reader = None + + self.vocabulary = Vocabulary( + config["word_counts_json"], min_count=config["vocab_min_count"] + ) + + # Initialize image features reader according to split. + image_features_hdfpath = config["image_features_train_h5"] + if "val" in self.dialogs_reader.split: + image_features_hdfpath = config["image_features_val_h5"] + elif "test" in self.dialogs_reader.split: + image_features_hdfpath = config["image_features_test_h5"] + + self.hdf_reader = ImageFeaturesHdfReader( + image_features_hdfpath, in_memory + ) + + # Keep a list of image_ids as primary keys to access data. + self.image_ids = list(self.dialogs_reader.dialogs.keys()) + if overfit: + self.image_ids = self.image_ids[:5] + + @property + def split(self): + return self.dialogs_reader.split + + def __len__(self): + return len(self.image_ids) + + def __getitem__(self, index): + # Get image_id, which serves as a primary key for current instance. 
+ image_id = self.image_ids[index] + + # Get image features for this image_id using hdf reader. + image_features = self.hdf_reader[image_id] + image_features = torch.tensor(image_features) + # Normalize image features at zero-th dimension (since there's no batch + # dimension). + if self.config["img_norm"]: + image_features = normalize(image_features, dim=0, p=2) + + # Retrieve instance for this image_id using json reader. + visdial_instance = self.dialogs_reader[image_id] + caption = visdial_instance["caption"] + dialog = visdial_instance["dialog"] + + # Convert word tokens of caption, question, answer and answer options + # to integers. + caption = self.vocabulary.to_indices(caption) + for i in range(len(dialog)): + dialog[i]["question"] = self.vocabulary.to_indices( + dialog[i]["question"] + ) + if self.add_boundary_toks: + dialog[i]["answer"] = self.vocabulary.to_indices( + [self.vocabulary.SOS_TOKEN] + + dialog[i]["answer"] + + [self.vocabulary.EOS_TOKEN] + ) + else: + dialog[i]["answer"] = self.vocabulary.to_indices( + dialog[i]["answer"] + ) + + if self.return_options: + for j in range(len(dialog[i]["answer_options"])): + if self.add_boundary_toks: + dialog[i]["answer_options"][ + j + ] = self.vocabulary.to_indices( + [self.vocabulary.SOS_TOKEN] + + dialog[i]["answer_options"][j] + + [self.vocabulary.EOS_TOKEN] + ) + else: + dialog[i]["answer_options"][ + j + ] = self.vocabulary.to_indices( + dialog[i]["answer_options"][j] + ) + + questions, question_lengths = self._pad_sequences( + [dialog_round["question"] for dialog_round in dialog] + ) + history, history_lengths = self._get_history( + caption, + [dialog_round["question"] for dialog_round in dialog], + [dialog_round["answer"] for dialog_round in dialog], + ) + answers_in, answer_lengths = self._pad_sequences( + [dialog_round["answer"][:-1] for dialog_round in dialog] + ) + answers_out, _ = self._pad_sequences( + [dialog_round["answer"][1:] for dialog_round in dialog] + ) + + # Collect everything as tensors for ``collate_fn`` of dataloader to + # work seamlessly questions, history, etc. are converted to + # LongTensors, for nn.Embedding input. 
+ item = {} + item["img_ids"] = torch.tensor(image_id).long() + item["img_feat"] = image_features + item["ques"] = questions.long() + item["hist"] = history.long() + item["ans_in"] = answers_in.long() + item["ans_out"] = answers_out.long() + item["ques_len"] = torch.tensor(question_lengths).long() + item["hist_len"] = torch.tensor(history_lengths).long() + item["ans_len"] = torch.tensor(answer_lengths).long() + item["num_rounds"] = torch.tensor( + visdial_instance["num_rounds"] + ).long() + + if self.return_options: + if self.add_boundary_toks: + answer_options_in, answer_options_out = [], [] + answer_option_lengths = [] + for dialog_round in dialog: + options, option_lengths = self._pad_sequences( + [ + option[:-1] + for option in dialog_round["answer_options"] + ] + ) + answer_options_in.append(options) + + options, _ = self._pad_sequences( + [ + option[1:] + for option in dialog_round["answer_options"] + ] + ) + answer_options_out.append(options) + + answer_option_lengths.append(option_lengths) + answer_options_in = torch.stack(answer_options_in, 0) + answer_options_out = torch.stack(answer_options_out, 0) + + item["opt_in"] = answer_options_in.long() + item["opt_out"] = answer_options_out.long() + item["opt_len"] = torch.tensor(answer_option_lengths).long() + else: + answer_options = [] + answer_option_lengths = [] + for dialog_round in dialog: + options, option_lengths = self._pad_sequences( + dialog_round["answer_options"] + ) + answer_options.append(options) + answer_option_lengths.append(option_lengths) + answer_options = torch.stack(answer_options, 0) + + item["opt"] = answer_options.long() + item["opt_len"] = torch.tensor(answer_option_lengths).long() + + if "test" not in self.split: + answer_indices = [ + dialog_round["gt_index"] for dialog_round in dialog + ] + item["ans_ind"] = torch.tensor(answer_indices).long() + + # Gather dense annotations. + if "val" in self.split: + dense_annotations = self.annotations_reader[image_id] + item["gt_relevance"] = torch.tensor( + dense_annotations["gt_relevance"] + ).float() + item["round_id"] = torch.tensor( + dense_annotations["round_id"] + ).long() + + return item + + def _pad_sequences(self, sequences: List[List[int]]): + """Given tokenized sequences (either questions, answers or answer + options, tokenized in ``__getitem__``), padding them to maximum + specified sequence length. Return as a tensor of size + ``(*, max_sequence_length)``. + + This method is only called in ``__getitem__``, chunked out separately + for readability. + + Parameters + ---------- + sequences : List[List[int]] + List of tokenized sequences, each sequence is typically a + List[int]. + + Returns + ------- + torch.Tensor, torch.Tensor + Tensor of sequences padded to max length, and length of sequences + before padding. + """ + + for i in range(len(sequences)): + sequences[i] = sequences[i][ + : self.config["max_sequence_length"] - 1 + ] + sequence_lengths = [len(sequence) for sequence in sequences] + + # Pad all sequences to max_sequence_length. 
+ maxpadded_sequences = torch.full( + (len(sequences), self.config["max_sequence_length"]), + fill_value=self.vocabulary.PAD_INDEX, + ) + padded_sequences = pad_sequence( + [torch.tensor(sequence) for sequence in sequences], + batch_first=True, + padding_value=self.vocabulary.PAD_INDEX, + ) + maxpadded_sequences[:, : padded_sequences.size(1)] = padded_sequences + return maxpadded_sequences, sequence_lengths + + def _get_history( + self, + caption: List[int], + questions: List[List[int]], + answers: List[List[int]], + ): + # Allow double length of caption, equivalent to a concatenated QA pair. + caption = caption[: self.config["max_sequence_length"] * 2 - 1] + + for i in range(len(questions)): + questions[i] = questions[i][ + : self.config["max_sequence_length"] - 1 + ] + + for i in range(len(answers)): + answers[i] = answers[i][: self.config["max_sequence_length"] - 1] + + # History for first round is caption, else concatenated QA pair of + # previous round. + history = [] + history.append(caption) + for question, answer in zip(questions, answers): + history.append(question + answer + [self.vocabulary.EOS_INDEX]) + # Drop last entry from history (there's no eleventh question). + history = history[:-1] + max_history_length = self.config["max_sequence_length"] * 2 + + if self.config.get("concat_history", False): + # Concatenated_history has similar structure as history, except it + # contains concatenated QA pairs from previous rounds. + concatenated_history = [] + concatenated_history.append(caption) + for i in range(1, len(history)): + concatenated_history.append([]) + for j in range(i + 1): + concatenated_history[i].extend(history[j]) + + max_history_length = ( + self.config["max_sequence_length"] * 2 * len(history) + ) + history = concatenated_history + + history_lengths = [len(round_history) for round_history in history] + maxpadded_history = torch.full( + (len(history), max_history_length), + fill_value=self.vocabulary.PAD_INDEX, + ) + padded_history = pad_sequence( + [torch.tensor(round_history) for round_history in history], + batch_first=True, + padding_value=self.vocabulary.PAD_INDEX, + ) + maxpadded_history[:, : padded_history.size(1)] = padded_history + return maxpadded_history, history_lengths \ No newline at end of file diff --git a/visdialch/data/readers.py b/visdialch/data/readers.py new file mode 100644 index 0000000..4079039 --- /dev/null +++ b/visdialch/data/readers.py @@ -0,0 +1,205 @@ +""" +A Reader simply reads data from disk and returns it almost as is, based on a "primary key", which +for the case of VisDial v1.0 dataset, is the ``image_id``. Readers should be utilized by +torch ``Dataset``s. Any type of data pre-processing is not recommended in the reader, such as +tokenizing words to integers, embedding tokens, or passing an image through a pre-trained CNN. + +Each reader must atleast implement three methods: + - ``__len__`` to return the length of data this Reader can read. + - ``__getitem__`` to return data based on ``image_id`` in VisDial v1.0 dataset. + - ``keys`` to return a list of possible ``image_id``s this Reader can provide data of. +""" + +import copy +import json +from typing import Dict, List, Union + +import h5py +# A bit slow, and just splits sentences to list of words, can be doable in `DialogsReader`. +from nltk.tokenize import word_tokenize +from tqdm import tqdm + +import numpy as np + +class DialogsReader(object): + """ + A simple reader for VisDial v1.0 dialog data. 
The json file must have the same structure as + mentioned on ``https://visualdialog.org/data``. + + Parameters + ---------- + dialogs_jsonpath : str + Path to a json file containing VisDial v1.0 train, val or test dialog data. + """ + + def __init__(self, dialogs_jsonpath: str): + with open(dialogs_jsonpath, "r") as visdial_file: + visdial_data = json.load(visdial_file) + self._split = visdial_data["split"] + + self.questions = visdial_data["data"]["questions"] + self.answers = visdial_data["data"]["answers"] + + # Add empty question, answer at the end, useful for padding dialog rounds for test. + self.questions.append("") + self.answers.append("") + + # Image_id serves as key for all three dicts here. + self.captions = {} + self.dialogs = {} + self.num_rounds = {} + + for dialog_for_image in visdial_data["data"]["dialogs"]: + self.captions[dialog_for_image["image_id"]] = dialog_for_image["caption"] + + # Record original length of dialog, before padding. + # 10 for train and val splits, 10 or less for test split. + self.num_rounds[dialog_for_image["image_id"]] = len(dialog_for_image["dialog"]) + + # Pad dialog at the end with empty question and answer pairs (for test split). + while len(dialog_for_image["dialog"]) < 10: + dialog_for_image["dialog"].append({"question": -1, "answer": -1}) + + # Add empty answer /answer options if not provided (for test split). + for i in range(len(dialog_for_image["dialog"])): + if "answer" not in dialog_for_image["dialog"][i]: + dialog_for_image["dialog"][i]["answer"] = -1 + if "answer_options" not in dialog_for_image["dialog"][i]: + dialog_for_image["dialog"][i]["answer_options"] = [-1] * 100 + + self.dialogs[dialog_for_image["image_id"]] = dialog_for_image["dialog"] + + print(f"[{self._split}] Tokenizing questions...") + for i in tqdm(range(len(self.questions))): + self.questions[i] = word_tokenize(self.questions[i] + "?") + + print(f"[{self._split}] Tokenizing answers...") + for i in tqdm(range(len(self.answers))): + self.answers[i] = word_tokenize(self.answers[i]) + + print(f"[{self._split}] Tokenizing captions...") + for image_id, caption in tqdm(self.captions.items()): + self.captions[image_id] = word_tokenize(caption) + + def __len__(self): + return len(self.dialogs) + + def __getitem__(self, image_id: int) -> Dict[str, Union[int, str, List]]: + caption_for_image = self.captions[image_id] + dialog_for_image = copy.deepcopy(self.dialogs[image_id]) + num_rounds = self.num_rounds[image_id] + + # Replace question and answer indices with actual word tokens. + for i in range(len(dialog_for_image)): + dialog_for_image[i]["question"] = self.questions[dialog_for_image[i]["question"]] + dialog_for_image[i]["answer"] = self.answers[dialog_for_image[i]["answer"]] + for j, answer_option in enumerate(dialog_for_image[i]["answer_options"]): + dialog_for_image[i]["answer_options"][j] = self.answers[answer_option] + + return { + "image_id": image_id, + "caption": caption_for_image, + "dialog": dialog_for_image, + "num_rounds": num_rounds + } + + def keys(self) -> List[int]: + return list(self.dialogs.keys()) + + @property + def split(self): + return self._split + + +class DenseAnnotationsReader(object): + """ + A reader for dense annotations for val split. The json file must have the same structure as mentioned + on ``https://visualdialog.org/data``. 
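+    Each entry holds an ``image_id``, a ``round_id`` and ``gt_relevance`` scores
+    over the answer options of that round (see ``__getitem__`` below).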
+ + Parameters + ---------- + dense_annotations_jsonpath : str + Path to a json file containing VisDial v1.0 + """ + + def __init__(self, dense_annotations_jsonpath: str): + with open(dense_annotations_jsonpath, "r") as visdial_file: + self._visdial_data = json.load(visdial_file) + self._image_ids = [entry["image_id"] for entry in self._visdial_data] + + def __len__(self): + return len(self._image_ids) + + def __getitem__(self, image_id: int) -> Dict[str, Union[int, List]]: + index = self._image_ids.index(image_id) + # keys: {"image_id", "round_id", "gt_relevance"} + return self._visdial_data[index] + + @property + def split(self): + # always + return "val" + + +class ImageFeaturesHdfReader(object): + """ + A reader for HDF files containing pre-extracted image features. A typical HDF file is expected + to have a column named "image_id", and another column named "features". + + Example of an HDF file: + ``` + visdial_train_faster_rcnn_bottomup_features.h5 + |--- "image_id" [shape: (num_images, )] + |--- "features" [shape: (num_images, num_proposals, feature_size)] + +--- .attrs ("split", "train") + ``` + Refer ``$PROJECT_ROOT/data/extract_bottomup.py`` script for more details about HDF structure. + + Parameters + ---------- + features_hdfpath : str + Path to an HDF file containing VisDial v1.0 train, val or test split image features. + in_memory : bool + Whether to load the whole HDF file in memory. Beware, these files are sometimes tens of GBs + in size. Set this to true if you have sufficient RAM - trade-off between speed and memory. + """ + + def __init__(self, features_hdfpath: str, in_memory: bool = False): + self.features_hdfpath = features_hdfpath + self._in_memory = in_memory + + with h5py.File(self.features_hdfpath, "r") as features_hdf: + self._split = features_hdf.attrs["split"] + self.image_id_list = list(features_hdf["image_id"]) + # "features" is List[np.ndarray] if the dataset is loaded in-memory + # If not loaded in memory, then list of None. + self.features = [None] * len(self.image_id_list) + + + def __len__(self): + return len(self.image_id_list) + + def __getitem__(self, image_id: int): + index = self.image_id_list.index(image_id) + if self._in_memory: + # Load features during first epoch, all not loaded together as it has a slow start. + if self.features[index] is not None: + image_id_features = self.features[index] + else: + with h5py.File(self.features_hdfpath, "r") as features_hdf: + image_id_features = features_hdf["features"][index] + image_id_loc = self.loc_feats(features_hdf, index) + self.features[index] = image_id_features + else: + # Read chunk from file everytime if not loaded in memory. + with h5py.File(self.features_hdfpath, "r") as features_hdf: + image_id_features = features_hdf["features"][index] + + return image_id_features + + def keys(self) -> List[int]: + return self.image_id_list + + @property + def split(self): + return self._split \ No newline at end of file diff --git a/visdialch/data/vocabulary.py b/visdialch/data/vocabulary.py new file mode 100644 index 0000000..ae31100 --- /dev/null +++ b/visdialch/data/vocabulary.py @@ -0,0 +1,86 @@ +""" +A Vocabulary maintains a mapping between words and corresponding unique integers, holds special +integers (tokens) for indicating start and end of sequence, and offers functionality to map +out-of-vocabulary words to the corresponding token. 
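+
+Example (a minimal sketch; the word-counts file is the one linked in the README):
+
+    >>> vocabulary = Vocabulary("data/visdial_1.0_word_counts_train.json", min_count=5)
+    >>> vocabulary.to_indices(["is", "there", "a", "dog", "?"])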
+""" +import json +import os +from typing import List, Union + + +class Vocabulary(object): + """ + A simple Vocabulary class which maintains a mapping between words and integer tokens. Can be + initialized either by word counts from the VisDial v1.0 train dataset, or a pre-saved + vocabulary mapping. + + Parameters + ---------- + word_counts_path: str + Path to a json file containing counts of each word across captions, questions and answers + of the VisDial v1.0 train dataset. + min_count : int, optional (default=0) + When initializing the vocabulary from word counts, you can specify a minimum count, and + every token with a count less than this will be excluded from vocabulary. + """ + + PAD_TOKEN = "" + SOS_TOKEN = "" + EOS_TOKEN = "" + UNK_TOKEN = "" + + PAD_INDEX = 0 + SOS_INDEX = 1 + EOS_INDEX = 2 + UNK_INDEX = 3 + + def __init__(self, word_counts_path: str, min_count: int = 5): + if not os.path.exists(word_counts_path): + raise FileNotFoundError(f"Word counts do not exist at {word_counts_path}") + + with open(word_counts_path, "r") as word_counts_file: + word_counts = json.load(word_counts_file) + + # form a list of (word, count) tuples and apply min_count threshold + word_counts = [ + (word, count) for word, count in word_counts.items() if count >= min_count + ] + # sort in descending order of word counts + word_counts = sorted(word_counts, key=lambda wc: -wc[1]) + words = [w[0] for w in word_counts] + + self.word2index = {} + self.word2index[self.PAD_TOKEN] = self.PAD_INDEX + self.word2index[self.SOS_TOKEN] = self.SOS_INDEX + self.word2index[self.EOS_TOKEN] = self.EOS_INDEX + self.word2index[self.UNK_TOKEN] = self.UNK_INDEX + for index, word in enumerate(words): + self.word2index[word] = index + 4 + + self.index2word = {index: word for word, index in self.word2index.items()} + + @classmethod + def from_saved(cls, saved_vocabulary_path: str) -> "Vocabulary": + """Build the vocabulary from a json file saved by ``save`` method. + + Parameters + ---------- + saved_vocabulary_path : str + Path to a json file containing word to integer mappings (saved vocabulary). 
+ """ + with open(saved_vocabulary_path, "r") as saved_vocabulary_file: + cls.word2index = json.load(saved_vocabulary_file) + cls.index2word = {index: word for word, index in cls.word2index.items()} + + def to_indices(self, words: List[str]) -> List[int]: + return [self.word2index.get(word, self.UNK_INDEX) for word in words] + + def to_words(self, indices: List[int]) -> List[str]: + return [self.index2word.get(index, self.UNK_TOKEN) for index in indices] + + def save(self, save_vocabulary_path: str) -> None: + with open(save_vocabulary_path, "w") as save_vocabulary_file: + json.dump(self.word2index, saved_vocabulary_file) + + def __len__(self): + return len(self.index2word) diff --git a/visdialch/decoders/__init__.py b/visdialch/decoders/__init__.py new file mode 100644 index 0000000..6e6b528 --- /dev/null +++ b/visdialch/decoders/__init__.py @@ -0,0 +1,7 @@ +from visdialch.decoders.disc import DiscriminativeDecoder +from visdialch.decoders.gen import GenerativeDecoder + + +def Decoder(model_config, *args): + name_dec_map = {"disc": DiscriminativeDecoder, "gen": GenerativeDecoder} + return name_dec_map[model_config["decoder"]](model_config, *args) diff --git a/visdialch/decoders/disc.py b/visdialch/decoders/disc.py new file mode 100644 index 0000000..798cc37 --- /dev/null +++ b/visdialch/decoders/disc.py @@ -0,0 +1,94 @@ +import torch +from torch import nn + +from visdialch.utils import DynamicRNN + + +class DiscriminativeDecoder(nn.Module): + def __init__(self, config, vocabulary): + super().__init__() + self.config = config + + self.word_embed = nn.Embedding( + len(vocabulary), + config["word_embedding_size"], + padding_idx=vocabulary.PAD_INDEX, + ) + self.option_rnn = nn.LSTM( + config["word_embedding_size"], + config["lstm_hidden_size"], + config["lstm_num_layers"], + batch_first=True, + dropout=config["dropout"], + ) + + # Options are variable length padded sequences, use DynamicRNN. + self.option_rnn = DynamicRNN(self.option_rnn) + + def forward(self, encoder_output, batch): + """Given `encoder_output` + candidate option sequences, predict a score + for each option sequence. + + Parameters + ---------- + encoder_output: torch.Tensor + Output from the encoder through its forward pass. + (batch_size, num_rounds, lstm_hidden_size) + """ + + options = batch["opt"] + batch_size, num_rounds, num_options, max_sequence_length = ( + options.size() + ) + options = options.view( + batch_size * num_rounds * num_options, max_sequence_length + ) + + options_length = batch["opt_len"] + options_length = options_length.view( + batch_size * num_rounds * num_options + ) + + # Pick options with non-zero length (relevant for test split). 
+ nonzero_options_length_indices = options_length.nonzero().squeeze() + nonzero_options_length = options_length[nonzero_options_length_indices] + nonzero_options = options[nonzero_options_length_indices] + + # shape: (batch_size * num_rounds * num_options, max_sequence_length, + # word_embedding_size) + # FOR TEST SPLIT, shape: (batch_size * 1, num_options, + # max_sequence_length, word_embedding_size) + nonzero_options_embed = self.word_embed(nonzero_options) + + # shape: (batch_size * num_rounds * num_options, lstm_hidden_size) + # FOR TEST SPLIT, shape: (batch_size * 1, num_options, + # lstm_hidden_size) + _, (nonzero_options_embed, _) = self.option_rnn( + nonzero_options_embed, nonzero_options_length + ) + + options_embed = torch.zeros( + batch_size * num_rounds * num_options, + nonzero_options_embed.size(-1), + device=nonzero_options_embed.device, + ) + options_embed[nonzero_options_length_indices] = nonzero_options_embed + + # Repeat encoder output for every option. + # shape: (batch_size, num_rounds, num_options, max_sequence_length) + encoder_output = encoder_output.unsqueeze(2).repeat( + 1, 1, num_options, 1 + ) + + # Shape now same as `options`, can calculate dot product similarity. + # shape: (batch_size * num_rounds * num_options, lstm_hidden_state) + encoder_output = encoder_output.view( + batch_size * num_rounds * num_options, + self.config["lstm_hidden_size"], + ) + + # shape: (batch_size * num_rounds * num_options) + scores = torch.sum(options_embed * encoder_output, 1) + # shape: (batch_size, num_rounds, num_options) + scores = scores.view(batch_size, num_rounds, num_options) + return scores diff --git a/visdialch/decoders/gen.py b/visdialch/decoders/gen.py new file mode 100644 index 0000000..6c1f3dc --- /dev/null +++ b/visdialch/decoders/gen.py @@ -0,0 +1,131 @@ +import torch +from torch import nn + + +class GenerativeDecoder(nn.Module): + def __init__(self, config, vocabulary): + super().__init__() + self.config = config + + self.word_embed = nn.Embedding( + len(vocabulary), + config["word_embedding_size"], + padding_idx=vocabulary.PAD_INDEX, + ) + self.answer_rnn = nn.LSTM( + config["word_embedding_size"], + config["lstm_hidden_size"], + config["lstm_num_layers"], + batch_first=True, + dropout=config["dropout"], + ) + + self.lstm_to_words = nn.Linear( + self.config["lstm_hidden_size"], len(vocabulary) + ) + + self.dropout = nn.Dropout(p=config["dropout"]) + self.logsoftmax = nn.LogSoftmax(dim=-1) + + def forward(self, encoder_output, batch): + """Given `encoder_output`, learn to autoregressively predict + ground-truth answer word-by-word during training. + + During evaluation, assign log-likelihood scores to all answer options. + + Parameters + ---------- + encoder_output: torch.Tensor + Output from the encoder through its forward pass. + (batch_size, num_rounds, lstm_hidden_size) + """ + + if self.training: + + ans_in = batch["ans_in"] + batch_size, num_rounds, max_sequence_length = ans_in.size() + + ans_in = ans_in.view(batch_size * num_rounds, max_sequence_length) + + # shape: (batch_size * num_rounds, max_sequence_length, + # word_embedding_size) + ans_in_embed = self.word_embed(ans_in) + + # reshape encoder output to be set as initial hidden state of LSTM. 
+ # shape: (lstm_num_layers, batch_size * num_rounds, + # lstm_hidden_size) + init_hidden = encoder_output.view(1, batch_size * num_rounds, -1) + init_hidden = init_hidden.repeat( + self.config["lstm_num_layers"], 1, 1 + ) + init_cell = torch.zeros_like(init_hidden) + + # shape: (batch_size * num_rounds, max_sequence_length, + # lstm_hidden_size) + ans_out, (hidden, cell) = self.answer_rnn( + ans_in_embed, (init_hidden, init_cell) + ) + ans_out = self.dropout(ans_out) + + # shape: (batch_size * num_rounds, max_sequence_length, + # vocabulary_size) + ans_word_scores = self.lstm_to_words(ans_out) + return ans_word_scores + + else: + + ans_in = batch["opt_in"] + batch_size, num_rounds, num_options, max_sequence_length = ( + ans_in.size() + ) + + ans_in = ans_in.view( + batch_size * num_rounds * num_options, max_sequence_length + ) + + # shape: (batch_size * num_rounds * num_options, max_sequence_length + # word_embedding_size) + ans_in_embed = self.word_embed(ans_in) + + # reshape encoder output to be set as initial hidden state of LSTM. + # shape: (lstm_num_layers, batch_size * num_rounds * num_options, + # lstm_hidden_size) + init_hidden = encoder_output.view(batch_size, num_rounds, 1, -1) + init_hidden = init_hidden.repeat(1, 1, num_options, 1) + init_hidden = init_hidden.view( + 1, batch_size * num_rounds * num_options, -1 + ) + init_hidden = init_hidden.repeat( + self.config["lstm_num_layers"], 1, 1 + ) + init_cell = torch.zeros_like(init_hidden) + + # shape: (batch_size * num_rounds * num_options, + # max_sequence_length, lstm_hidden_size) + ans_out, (hidden, cell) = self.answer_rnn( + ans_in_embed, (init_hidden, init_cell) + ) + + # shape: (batch_size * num_rounds * num_options, + # max_sequence_length, vocabulary_size) + ans_word_scores = self.logsoftmax(self.lstm_to_words(ans_out)) + + # shape: (batch_size * num_rounds * num_options, + # max_sequence_length) + target_ans_out = batch["opt_out"].view( + batch_size * num_rounds * num_options, -1 + ) + + # shape: (batch_size * num_rounds * num_options, + # max_sequence_length) + ans_word_scores = torch.gather( + ans_word_scores, -1, target_ans_out.unsqueeze(-1) + ).squeeze() + ans_word_scores = ( + ans_word_scores * (target_ans_out > 0).float().cuda() + ) # ugly + + ans_scores = torch.sum(ans_word_scores, -1) + ans_scores = ans_scores.view(batch_size, num_rounds, num_options) + + return ans_scores diff --git a/visdialch/encoders/__init__.py b/visdialch/encoders/__init__.py new file mode 100644 index 0000000..7f8e0bc --- /dev/null +++ b/visdialch/encoders/__init__.py @@ -0,0 +1,10 @@ +from visdialch.encoders.lf import LateFusionEncoder +from visdialch.encoders.rva import RvAEncoder + + +def Encoder(model_config, *args): + name_enc_map = { + "lf": LateFusionEncoder, + "rva": RvAEncoder, + } + return name_enc_map[model_config["encoder"]](model_config, *args) diff --git a/visdialch/encoders/lf.py b/visdialch/encoders/lf.py new file mode 100644 index 0000000..8a0de1a --- /dev/null +++ b/visdialch/encoders/lf.py @@ -0,0 +1,135 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from visdialch.utils import DynamicRNN + + +class LateFusionEncoder(nn.Module): + def __init__(self, config, vocabulary): + super().__init__() + self.config = config + + self.word_embed = nn.Embedding( + len(vocabulary), + config["word_embedding_size"], + padding_idx=vocabulary.PAD_INDEX, + ) + self.hist_rnn = nn.LSTM( + config["word_embedding_size"], + config["lstm_hidden_size"], + config["lstm_num_layers"], + batch_first=True, + 
dropout=config["dropout"], + ) + self.ques_rnn = nn.LSTM( + config["word_embedding_size"], + config["lstm_hidden_size"], + config["lstm_num_layers"], + batch_first=True, + dropout=config["dropout"], + ) + self.dropout = nn.Dropout(p=config["dropout"]) + + # questions and history are right padded sequences of variable length + # use the DynamicRNN utility module to handle them properly + self.hist_rnn = DynamicRNN(self.hist_rnn) + self.ques_rnn = DynamicRNN(self.ques_rnn) + + # project image features to lstm_hidden_size for computing attention + self.image_features_projection = nn.Linear( + config["img_feature_size"], config["lstm_hidden_size"] + ) + + # fc layer for image * question to attention weights + self.attention_proj = nn.Linear(config["lstm_hidden_size"], 1) + + # fusion layer (attended_image_features + question + history) + fusion_size = ( + config["img_feature_size"] + config["lstm_hidden_size"] * 2 + ) + self.fusion = nn.Linear(fusion_size, config["lstm_hidden_size"]) + + nn.init.kaiming_uniform_(self.image_features_projection.weight) + nn.init.constant_(self.image_features_projection.bias, 0) + nn.init.kaiming_uniform_(self.fusion.weight) + nn.init.constant_(self.fusion.bias, 0) + + def forward(self, batch): + # shape: (batch_size, img_feature_size) - CNN fc7 features + # shape: (batch_size, num_proposals, img_feature_size) - RCNN features + img = batch["img_feat"] + # shape: (batch_size, 10, max_sequence_length) + ques = batch["ques"] + # shape: (batch_size, 10, max_sequence_length * 2 * 10) + # concatenated qa * 10 rounds + hist = batch["hist"] + # num_rounds = 10, even for test (padded dialog rounds at the end) + batch_size, num_rounds, max_sequence_length = ques.size() + + # embed questions + ques = ques.view(batch_size * num_rounds, max_sequence_length) + ques_embed = self.word_embed(ques) + + # shape: (batch_size * num_rounds, max_sequence_length, + # lstm_hidden_size) + _, (ques_embed, _) = self.ques_rnn(ques_embed, batch["ques_len"]) + + # project down image features and ready for attention + # shape: (batch_size, num_proposals, lstm_hidden_size) + projected_image_features = self.image_features_projection(img) + + # repeat image feature vectors to be provided for every round + # shape: (batch_size * num_rounds, num_proposals, lstm_hidden_size) + projected_image_features = ( + projected_image_features.view( + batch_size, 1, -1, self.config["lstm_hidden_size"] + ) + .repeat(1, num_rounds, 1, 1) + .view(batch_size * num_rounds, -1, self.config["lstm_hidden_size"]) + ) + + # computing attention weights + # shape: (batch_size * num_rounds, num_proposals) + projected_ques_features = ques_embed.unsqueeze(1).repeat( + 1, img.shape[1], 1 + ) + projected_ques_image = ( + projected_ques_features * projected_image_features + ) + projected_ques_image = self.dropout(projected_ques_image) + image_attention_weights = self.attention_proj( + projected_ques_image + ).squeeze() + image_attention_weights = F.softmax(image_attention_weights, dim=-1) + + # shape: (batch_size * num_rounds, num_proposals, img_features_size) + img = ( + img.view(batch_size, 1, -1, self.config["img_feature_size"]) + .repeat(1, num_rounds, 1, 1) + .view(batch_size * num_rounds, -1, self.config["img_feature_size"]) + ) + + # multiply image features with their attention weights + # shape: (batch_size * num_rounds, num_proposals, img_feature_size) + image_attention_weights = image_attention_weights.unsqueeze(-1).repeat( + 1, 1, self.config["img_feature_size"] + ) + # shape: (batch_size * num_rounds, 
img_feature_size) + attended_image_features = (image_attention_weights * img).sum(1) + img = attended_image_features + + # embed history + hist = hist.view(batch_size * num_rounds, max_sequence_length * 20) + hist_embed = self.word_embed(hist) + + # shape: (batch_size * num_rounds, lstm_hidden_size) + _, (hist_embed, _) = self.hist_rnn(hist_embed, batch["hist_len"]) + + fused_vector = torch.cat((img, ques_embed, hist_embed), 1) + fused_vector = self.dropout(fused_vector) + + fused_embedding = torch.tanh(self.fusion(fused_vector)) + # shape: (batch_size, num_rounds, lstm_hidden_size) + fused_embedding = fused_embedding.view(batch_size, num_rounds, -1) + return fused_embedding diff --git a/visdialch/encoders/modules.py b/visdialch/encoders/modules.py new file mode 100644 index 0000000..b526536 --- /dev/null +++ b/visdialch/encoders/modules.py @@ -0,0 +1,237 @@ +import torch +from torch import nn +from torch.nn import functional as F +from visdialch.utils import GumbelSoftmax, GatedTrans + +class ATT_MODULE(nn.Module): + """docstring for ATT_MODULE""" + def __init__(self, config): + super(ATT_MODULE, self).__init__() + + self.V_embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["img_feature_size"], + config["lstm_hidden_size"] + ), + ) + self.Q_embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["word_embedding_size"], + config["lstm_hidden_size"] + ), + ) + self.att = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["lstm_hidden_size"], + 1 + ) + ) + + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, img, ques): + # input + # img - shape: (batch_size, num_proposals, img_feature_size) + # ques - shape: (batch_size, num_rounds, word_embedding_size) + # output + # att - shape: (batch_size, num_rounds, num_proposals) + + batch_size = ques.size(0) + num_rounds = ques.size(1) + num_proposals = img.size(1) + + img_embed = img.view(-1, img.size(-1)) # shape: (batch_size * num_proposals, img_feature_size) + img_embed = self.V_embed(img_embed) # shape: (batch_size, num_proposals, lstm_hidden_size) + img_embed = img_embed.view(batch_size, num_proposals, img_embed.size(-1)) # shape: (batch_size, num_proposals, lstm_hidden_size) + img_embed = img_embed.unsqueeze(1).repeat(1, num_rounds, 1, 1) # shape: (batch_size, num_rounds, num_proposals, lstm_hidden_size) + + ques_embed = ques.view(-1, ques.size(-1)) # shape: (batch_size * num_rounds, word_embedding_size) + ques_embed = self.Q_embed(ques_embed) # shape: (batch_size, num_rounds, lstm_hidden_size) + ques_embed = ques_embed.view(batch_size, num_rounds, ques_embed.size(-1)) # shape: (batch_size, num_rounds, lstm_hidden_size) + ques_embed = ques_embed.unsqueeze(2).repeat(1, 1, num_proposals, 1) # shape: (batch_size, num_rounds, num_proposals, lstm_hidden_size) + + att_embed = F.normalize(img_embed * ques_embed, p=2, dim=-1) # (batch_size, num_rounds, num_proposals, lstm_hidden_size) + att_embed = self.att(att_embed).squeeze(-1) # (batch_size, num_rounds, num_proposals) + att = self.softmax(att_embed) # shape: (batch_size, num_rounds, num_proposals) + + return att + +class PAIR_MODULE(nn.Module): + """docstring for PAIR_MODULE""" + def __init__(self, config): + super(PAIR_MODULE, self).__init__() + + self.H_embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + 
config["lstm_hidden_size"]*2, + config["lstm_hidden_size"]), + ) + self.Q_embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["lstm_hidden_size"]*2, + config["lstm_hidden_size"]), + ) + self.MLP = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["lstm_hidden_size"]*2, + config["lstm_hidden_size"]), + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["lstm_hidden_size"], + 1 + ) + ) + self.att = nn.Linear(2, 1) + + self.G_softmax = GumbelSoftmax() + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, hist, ques): + # input + # ques shape: (batch_size, num_rounds, lstm_hidden_size*2) + # hist shape: (batch_size, num_rounds, lstm_hidden_size*2) + # output + # hist_gs_set - shape: (batch_size, num_rounds, num_rounds) + + batch_size = ques.size(0) + num_rounds = ques.size(1) + + hist_embed = self.H_embed(hist) # shape: (batch_size, num_rounds, lstm_hidden_size) + hist_embed = hist_embed.unsqueeze(1).repeat(1, num_rounds, 1, 1) # shape: (batch_size, num_rounds, num_rounds, lstm_hidden_size) + + ques_embed = self.Q_embed(ques) # shape: (batch_size, num_rounds, lstm_hidden_size) + ques_embed = ques_embed.unsqueeze(2).repeat(1, 1, num_rounds, 1) # shape: (batch_size, num_rounds, num_rounds, lstm_hidden_size) + + att_embed = torch.cat((hist_embed, ques_embed), dim=-1) + score = self.MLP(att_embed) + + delta_t = torch.tril(torch.ones(size=[num_rounds, num_rounds], requires_grad=False)).cumsum(dim=0) # (num_rounds, num_rounds) + delta_t = delta_t.view(1, num_rounds, num_rounds, 1).repeat(batch_size, 1, 1, 1) # (batch_size, num_rounds, num_rounds, 1) + delta_t = delta_t.cuda() + att_embed = torch.cat((score, delta_t), dim=-1) # (batch_size, num_rounds, num_rounds, lstm_hidden_size*2) + + hist_logits = self.att(att_embed).squeeze(-1) # (batch_size, num_rounds, num_rounds) + + # PAIR + hist_gs_set = torch.zeros_like(hist_logits) + for i in range(num_rounds): + # one-hot + hist_gs_set[:, i, :(i+1)] = self.G_softmax(hist_logits[:, i, :(i+1)]) # shape: (batch_size, i+1) + + return hist_gs_set + +class INFER_MODULE(nn.Module): + """docstring for INFER_MODULE""" + def __init__(self, config): + super(INFER_MODULE, self).__init__() + + self.embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["word_embedding_size"], + config["lstm_hidden_size"]), + ) + self.att = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["lstm_hidden_size"], + 2 + ) + ) + + self.softmax = nn.Softmax(dim=-1) + self.G_softmax = GumbelSoftmax() + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, ques): + # input + # ques - shape: (batch_size, num_rounds, word_embedding_size) + # output + # ques_gs - shape: (batch_size, num_rounds, 2) + # Lambda - shape: (batch_size, num_rounds, 2) + + batch_size = ques.size(0) + num_rounds = ques.size(1) + + ques_embed = self.embed(ques) # shape: (batch_size, num_rounds, quen_len_max, lstm_hidden_size) + ques_embed = F.normalize(ques_embed, p=2, dim=-1) # shape: (batch_size, num_rounds, quen_len_max, lstm_hidden_size) + ques_logits = self.att(ques_embed) # shape: (batch_size, num_rounds, 2) + # ignore word + ques_gs = self.G_softmax(ques_logits.view(-1, 2)).view(-1, num_rounds, 2) + Lambda = self.softmax(ques_logits) + + return 
ques_gs, Lambda + +class RvA_MODULE(nn.Module): + """docstring for R_CALL""" + def __init__(self, config): + super(RvA_MODULE, self).__init__() + + self.INFER_MODULE = INFER_MODULE(config) + self.PAIR_MODULE = PAIR_MODULE(config) + self.ATT_MODULE = ATT_MODULE(config) + + def forward(self, img, ques, hist): + # img shape: [batch_size, num_proposals, i_dim] + # img_att_ques shape: [batch_size, num_rounds, num_proposals] + # img_att_cap shape: [batch_size, 1, num_proposals] + # ques_gs shape: [batch_size, num_rounds, 2] + # hist_logits shape: [batch_size, num_rounds, num_rounds] + # ques_gs_prob shape: [batch_size, num_rounds, 2] + + cap_feat, ques_feat, ques_encoded = ques + + batch_size = ques_feat.size(0) + num_rounds = ques_feat.size(1) + num_proposals = img.size(1) + + ques_gs, ques_gs_prob = self.INFER_MODULE(ques_feat) # (batch_size, num_rounds, 2) + hist_gs_set = self.PAIR_MODULE(hist, ques_encoded) + img_att_ques = self.ATT_MODULE(img, ques_feat) + img_att_cap = self.ATT_MODULE(img, cap_feat) + + # soft + ques_prob_single = torch.Tensor(data=[1, 0]).view(1, -1).repeat(batch_size, 1) # shape: [batch_size, 2] + ques_prob_single = ques_prob_single.cuda() + ques_prob_single.requires_grad = False + + img_att_refined = img_att_ques.data.clone().zero_() # shape: [batch_size, num_rounds, num_proposals] + for i in range(num_rounds): + if i == 0: + img_att_temp = img_att_cap.view(-1, img_att_cap.size(-1)) # shape: [batch_size, num_proposals] + else: + hist_gs = hist_gs_set[:, i, :(i+1)] # shape: [batch_size, i+1] + img_att_temp = torch.cat((img_att_cap, img_att_refined[:, :i, :]), dim=1) # shape: [batch_size, i+1, num_proposals] + img_att_temp = torch.sum(hist_gs.unsqueeze(-1) * img_att_temp, dim=-2) # shape: [batch_size, num_proposals] + img_att_cat = torch.cat((img_att_ques[:, i, :].unsqueeze(1), img_att_temp.unsqueeze(1)), dim=1) # shape: [batch_size ,2, num_proposals] + # soft + ques_prob_pair = ques_gs_prob[:, i, :] + ques_prob = torch.cat((ques_prob_single, ques_prob_pair), dim=-1) # shape: [batch_size, 2] + ques_prob = ques_prob.view(-1, 2, 2) # shape: [batch_size, 2, 2] + ques_prob_refine = torch.bmm(ques_gs[:, i, :].view(-1, 1, 2), ques_prob).view(-1, 1, 2) # shape: [batch_size, num_rounds, 2] + + img_att_refined[:, i, :] = torch.bmm(ques_prob_refine, img_att_cat).view(-1, num_proposals) # shape: [batch_size, num_proposals] + + return img_att_refined, (ques_gs, hist_gs_set, img_att_ques) \ No newline at end of file diff --git a/visdialch/encoders/rva.py b/visdialch/encoders/rva.py new file mode 100644 index 0000000..4091844 --- /dev/null +++ b/visdialch/encoders/rva.py @@ -0,0 +1,171 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from visdialch.utils import DynamicRNN, GumbelSoftmax +from visdialch.utils import Q_ATT, H_ATT, V_Filter +from .modules import RvA_MODULE + +class RvAEncoder(nn.Module): + def __init__(self, config, vocabulary): + super().__init__() + self.config = config + + self.word_embed = nn.Embedding( + len(vocabulary), + config["word_embedding_size"], + padding_idx=vocabulary.PAD_INDEX + ) + + self.hist_rnn = nn.LSTM( + config["word_embedding_size"], + config["lstm_hidden_size"], + config["lstm_num_layers"], + batch_first=True, + dropout=config["dropout"], + bidirectional=True + ) + self.ques_rnn = nn.LSTM( + config["word_embedding_size"], + config["lstm_hidden_size"], + config["lstm_num_layers"], + batch_first=True, + dropout=config["dropout"], + bidirectional=True + ) + # questions and history are right padded sequences of variable 
length + # use the DynamicRNN utility module to handle them properly + self.hist_rnn = DynamicRNN(self.hist_rnn) + self.ques_rnn = DynamicRNN(self.ques_rnn) + + # self attention for question + self.Q_ATT_ans = Q_ATT(config) + self.Q_ATT_ref = Q_ATT(config) + # question-based history attention + self.H_ATT_ans = H_ATT(config) + + # modules + self.RvA_MODULE = RvA_MODULE(config) + self.V_Filter = V_Filter(config) + + # fusion layer + self.fusion = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["img_feature_size"] + config["word_embedding_size"] + config["lstm_hidden_size"] * 2, + config["lstm_hidden_size"] + ) + ) + # other useful functions + self.softmax = nn.Softmax(dim=-1) + self.G_softmax = GumbelSoftmax() + + # initialization + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, batch, return_att=False): + # img - shape: (batch_size, num_proposals, img_feature_size) - RCNN bottom-up features + img = batch["img_feat"] + batch_size = batch["ques"].size(0) + num_rounds = batch["ques"].size(1) + + # init language embedding + # ques_word_embed - shape: (batch_size, num_rounds, quen_len_max, word_embedding_size) + # ques_word_encoded - shape: (batch_size, num_rounds, quen_len_max, lstm_hidden_size) + # ques_not_pad - shape: (batch_size, num_rounds, quen_len_max) + # ques_encoded - shape: (batch_size, num_rounds, lstm_hidden_size) + ques_word_embed, ques_word_encoded, ques_not_pad, ques_encoded = self.init_q_embed(batch) + # hist_word_embed - shape: (batch_size, num_rounds, quen_len_max, word_embedding_size) + # hist_encoded - shape: (batch_size, num_rounds, lstm_hidden_size) + hist_word_embed, hist_encoded = self.init_h_embed(batch) + # cap_word_embed - shape: (batch_size, num_rounds, quen_len_max, word_embedding_size) + # cap_word_encoded - shape: (batch_size, num_rounds, quen_len_max, lstm_hidden_size) + # cap_not_pad - shape: (batch_size, num_rounds, quen_len_max) + cap_word_embed, cap_word_encoded, cap_not_pad = self.init_cap_embed(batch) + + # question feature for RvA + # ques_ref_feat - shape: (batch_size, num_rounds, lstm_hidden_size) + # ques_ref_att - shape: (batch_size, num_rounds, quen_len_max) + ques_ref_feat, ques_ref_att = self.Q_ATT_ref(ques_word_embed, ques_word_encoded, ques_not_pad) + # cap_ref_feat - shape: (batch_size, num_rounds, lstm_hidden_size) + cap_ref_feat, _ = self.Q_ATT_ref(cap_word_embed, cap_word_encoded, cap_not_pad) + + # RvA module + ques_feat = (cap_ref_feat, ques_ref_feat, ques_encoded) + # img_att - shape: (batch_size, num_rounds, num_proposals) + img_att, att_set = self.RvA_MODULE(img, ques_feat, hist_encoded) + # img_feat - shape: (batch_size, num_rounds, img_feature_size) + img_feat = torch.bmm(img_att, img) + + # ans_feat for joint embedding + # hist_ans_feat - shape: (batch_size, num_rounds, lstm_hidden_size*2) + hist_ans_feat = self.H_ATT_ans(hist_encoded, ques_encoded) + # ques_ans_feat - shape: (batch_size, num_rounds, word_embedding_size) + # ques_ans_att - shape: (batch_size, num_rounds, quen_len_max) + ques_ans_feat, ques_ans_att = self.Q_ATT_ans(ques_word_embed, ques_word_encoded, ques_not_pad) + # img_ans_feat - shape: (batch_size, num_rounds, img_feature_size) + img_ans_feat = self.V_Filter(img_feat, ques_ans_feat) + + # joint embedding + fused_vector = torch.cat((img_ans_feat, ques_ans_feat, hist_ans_feat), -1) + # img_ans_feat - shape: (batch_size, num_rounds, lstm_hidden_size) + 
fused_embedding = torch.tanh(self.fusion(fused_vector)) + + if return_att: + return fused_embedding, att_set + (ques_ref_att, ques_ans_att) + else: + return fused_embedding + + def init_q_embed(self, batch): + ques = batch["ques"] # shape: (batch_size, num_rounds, quen_len_max) + batch_size, num_rounds, _ = ques.size() + lstm_hidden_size = self.config["lstm_hidden_size"] + + # question feature + ques_not_pad = (ques!=0).float() # shape: (batch_size, num_rounds, quen_len_max) + ques = ques.view(-1, ques.size(-1)) # shape: (batch_size*num_rounds, quen_len_max) + ques_word_embed = self.word_embed(ques) # shape: (batch_size*num_rounds, quen_len_max, lstm_hidden_size) + ques_word_encoded, _ = self.ques_rnn(ques_word_embed, batch['ques_len']) # shape: (batch_size*num_rounds, quen_len_max, lstm_hidden_size*2) + quen_len_max = ques_word_encoded.size(1) + loc = batch['ques_len'].view(-1).cpu().numpy()-1 + ques_encoded_forawrd = ques_word_encoded[range(num_rounds*batch_size), loc, :lstm_hidden_size] # shape: (batch_size*num_rounds, lstm_hidden_size) + ques_encoded_backward = ques_word_encoded[:, 0, lstm_hidden_size:] # shape: (batch_size*num_rounds, lstm_hidden_size) + ques_encoded = torch.cat((ques_encoded_forawrd, ques_encoded_backward), dim=-1) + ques_encoded = ques_encoded.view(-1, num_rounds, ques_encoded.size(-1)) # shape: (batch_size, num_rounds, lstm_hidden_size*2) + ques_word_encoded = ques_word_encoded.view(-1, num_rounds, quen_len_max, ques_word_encoded.size(-1)) # shape: (batch_size, num_rounds, quen_len_max, lstm_hidden_size) + ques_word_embed = ques_word_embed.view(-1, num_rounds, quen_len_max, ques_word_embed.size(-1)) # shape: (batch_size, num_rounds, quen_len_max, word_embedding_size) + + return ques_word_embed, ques_word_encoded, ques_not_pad, ques_encoded + + def init_h_embed(self, batch): + hist = batch["hist"] # shape: (batch_size, num_rounds, hist_len_max) + batch_size, num_rounds, _ = hist.size() + lstm_hidden_size = self.config["lstm_hidden_size"] + + hist = hist.view(-1, hist.size(-1)) # shape: (batch_size*num_rounds, hist_len_max) + hist_word_embed = self.word_embed(hist) # shape: (batch_size*num_rounds, hist_len_max, word_embedding_size) + hist_word_encoded, _ = self.hist_rnn(hist_word_embed, batch['hist_len']) # shape: (batch_size*num_rounds, hist_len_max, lstm_hidden_size*2) + loc = batch['hist_len'].view(-1).cpu().numpy()-1 + hist_encoded_forward = hist_word_encoded[range(num_rounds*batch_size), loc, :lstm_hidden_size] # shape: (batch_size*num_rounds, hist_len_max, lstm_hidden_size*2) + hist_encoded_backward = hist_word_encoded[:, 0, lstm_hidden_size:] # shape: (batch_size*num_rounds, lstm_hidden_size) + hist_encoded = torch.cat((hist_encoded_forward, hist_encoded_backward), dim=-1) + hist_encoded = hist_encoded.view(-1, num_rounds, hist_encoded.size(-1)) # shape: (batch_size, num_rounds, lstm_hidden_size) + + return hist_word_embed, hist_encoded + + def init_cap_embed(self, batch): + cap = batch["hist"][:, :1, :] # shape: (batch_size, 1, hist_len_max) + + # caption feature like question + cap_not_pad = (cap!=0).float() # shape: (batch_size, 1, hist_len_max) + cap_word_embed = self.word_embed(cap.squeeze(1)) # shape: (batch_size*1, hist_len_max, lstm_hidden_size) + cap_len = batch['hist_len'][:, :1] + cap_word_encoded, _ = self.ques_rnn(cap_word_embed, cap_len) # shape: (batch_size*1, hist_len_max, lstm_hidden_size) + cap_word_encoded = cap_word_encoded.unsqueeze(1) # shape: (batch_size, 1, hist_len_max, lstm_hidden_size) + cap_word_embed = cap_word_embed.unsqueeze(1) # 
shape: (batch_size, 1, hist_len_max, lstm_hidden_size) + + return cap_word_embed, cap_word_encoded, cap_not_pad \ No newline at end of file diff --git a/visdialch/metrics.py b/visdialch/metrics.py new file mode 100644 index 0000000..7eea9ce --- /dev/null +++ b/visdialch/metrics.py @@ -0,0 +1,173 @@ +""" +A Metric observes output of certain model, for example, in form of logits or +scores, and accumulates a particular metric with reference to some provided +targets. In context of VisDial, we use Recall (@ 1, 5, 10), Mean Rank, Mean +Reciprocal Rank (MRR) and Normalized Discounted Cumulative Gain (NDCG). + +Each ``Metric`` must atleast implement three methods: + - ``observe``, update accumulated metric with currently observed outputs + and targets. + - ``retrieve`` to return the accumulated metric., an optionally reset + internally accumulated metric (this is commonly done between two epochs + after validation). + - ``reset`` to explicitly reset the internally accumulated metric. + +Caveat, if you wish to implement your own class of Metric, make sure you call +``detach`` on output tensors (like logits), else it will cause memory leaks. +""" +import torch + + +def scores_to_ranks(scores: torch.Tensor): + """Convert model output scores into ranks.""" + batch_size, num_rounds, num_options = scores.size() + scores = scores.view(-1, num_options) + + # sort in descending order - largest score gets highest rank + sorted_ranks, ranked_idx = scores.sort(1, descending=True) + + # i-th position in ranked_idx specifies which score shall take this + # position but we want i-th position to have rank of score at that + # position, do this conversion + ranks = ranked_idx.clone().fill_(0) + for i in range(ranked_idx.size(0)): + for j in range(num_options): + ranks[i][ranked_idx[i][j]] = j + # convert from 0-99 ranks to 1-100 ranks + ranks += 1 + ranks = ranks.view(batch_size, num_rounds, num_options) + return ranks + + +class SparseGTMetrics(object): + """ + A class to accumulate all metrics with sparse ground truth annotations. + These include Recall (@ 1, 5, 10), Mean Rank and Mean Reciprocal Rank. + """ + + def __init__(self): + self._rank_list = [] + + def observe( + self, predicted_scores: torch.Tensor, target_ranks: torch.Tensor + ): + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, num_rounds, num_options) + predicted_ranks = scores_to_ranks(predicted_scores) + batch_size, num_rounds, num_options = predicted_ranks.size() + + # collapse batch dimension + predicted_ranks = predicted_ranks.view( + batch_size * num_rounds, num_options + ) + + # shape: (batch_size * num_rounds, ) + target_ranks = target_ranks.view(batch_size * num_rounds).long() + + # shape: (batch_size * num_rounds, ) + predicted_gt_ranks = predicted_ranks[ + torch.arange(batch_size * num_rounds), target_ranks + ] + self._rank_list.extend(list(predicted_gt_ranks.cpu().numpy())) + + def retrieve(self, reset: bool = True): + num_examples = len(self._rank_list) + if num_examples > 0: + # convert to numpy array for easy calculation. 
+ __rank_list = torch.tensor(self._rank_list).float() + metrics = { + "r@1": torch.mean((__rank_list <= 1).float()).item(), + "r@5": torch.mean((__rank_list <= 5).float()).item(), + "r@10": torch.mean((__rank_list <= 10).float()).item(), + "mean": torch.mean(__rank_list).item(), + "mrr": torch.mean(__rank_list.reciprocal()).item(), + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._rank_list = [] + + +class NDCG(object): + def __init__(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 + + def observe( + self, predicted_scores: torch.Tensor, target_relevance: torch.Tensor + ): + """ + Observe model output scores and target ground truth relevance and + accumulate NDCG metric. + + Parameters + ---------- + predicted_scores: torch.Tensor + A tensor of shape (batch_size, num_options), because dense + annotations are available for 1 randomly picked round out of 10. + target_relevance: torch.Tensor + A tensor of shape same as predicted scores, indicating ground truth + relevance of each answer option for a particular round. + """ + predicted_scores = predicted_scores.detach() + + # shape: (batch_size, 1, num_options) + predicted_scores = predicted_scores.unsqueeze(1) + predicted_ranks = scores_to_ranks(predicted_scores) + + # shape: (batch_size, num_options) + predicted_ranks = predicted_ranks.squeeze() + batch_size, num_options = predicted_ranks.size() + + k = torch.sum(target_relevance != 0, dim=-1) + + # shape: (batch_size, num_options) + _, rankings = torch.sort(predicted_ranks, dim=-1) + # Sort relevance in descending order so highest relevance gets top rnk. + _, best_rankings = torch.sort( + target_relevance, dim=-1, descending=True + ) + + # shape: (batch_size, ) + batch_ndcg = [] + for batch_index in range(batch_size): + num_relevant = k[batch_index] + dcg = self._dcg( + rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + best_dcg = self._dcg( + best_rankings[batch_index][:num_relevant], + target_relevance[batch_index], + ) + batch_ndcg.append(dcg / best_dcg) + + self._ndcg_denominator += batch_size + self._ndcg_numerator += sum(batch_ndcg) + + def _dcg(self, rankings: torch.Tensor, relevance: torch.Tensor): + sorted_relevance = relevance[rankings].cpu().float() + discounts = torch.log2(torch.arange(len(rankings)).float() + 2) + return torch.sum(sorted_relevance / discounts, dim=-1) + + def retrieve(self, reset: bool = True): + if self._ndcg_denominator > 0: + metrics = { + "ndcg": float(self._ndcg_numerator / self._ndcg_denominator) + } + else: + metrics = {} + + if reset: + self.reset() + return metrics + + def reset(self): + self._ndcg_numerator = 0.0 + self._ndcg_denominator = 0.0 diff --git a/visdialch/model.py b/visdialch/model.py new file mode 100644 index 0000000..ed984d5 --- /dev/null +++ b/visdialch/model.py @@ -0,0 +1,21 @@ +from torch import nn + + +class EncoderDecoderModel(nn.Module): + """Convenience wrapper module, wrapping Encoder and Decoder modules. 
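+
+    The forward pass feeds the batch through the encoder and passes the
+    resulting representation, together with the batch, to the decoder.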
+ + Parameters + ---------- + encoder: nn.Module + decoder: nn.Module + """ + + def __init__(self, encoder, decoder): + super().__init__() + self.encoder = encoder + self.decoder = decoder + + def forward(self, batch): + encoder_output = self.encoder(batch) + decoder_output = self.decoder(encoder_output, batch) + return decoder_output diff --git a/visdialch/utils/__init__.py b/visdialch/utils/__init__.py new file mode 100644 index 0000000..61487bd --- /dev/null +++ b/visdialch/utils/__init__.py @@ -0,0 +1,3 @@ +from .dynamic_rnn import DynamicRNN +from .layers import Q_ATT, H_ATT, V_Filter, GatedTrans +from .gumbel_softmax import GumbelSoftmax \ No newline at end of file diff --git a/visdialch/utils/checkpointing.py b/visdialch/utils/checkpointing.py new file mode 100644 index 0000000..eb4b41c --- /dev/null +++ b/visdialch/utils/checkpointing.py @@ -0,0 +1,179 @@ +""" +A checkpoint manager periodically saves model and optimizer as .pth +files during training. + +Checkpoint managers help with experiment reproducibility, they record +the commit SHA of your current codebase in the checkpoint saving +directory. While loading any checkpoint from other commit, they raise a +friendly warning, a signal to inspect commit diffs for potential bugs. +Moreover, they copy experiment hyper-parameters as a YAML config in +this directory. + +That said, always run your experiments after committing your changes, +this doesn't account for untracked or staged, but uncommitted changes. +""" +from pathlib import Path +from subprocess import PIPE, Popen +import warnings + +import torch +from torch import nn, optim +import yaml + + +class CheckpointManager(object): + """A checkpoint manager saves state dicts of model and optimizer + as .pth files in a specified directory. This class closely follows + the API of PyTorch optimizers and learning rate schedulers. + + Note:: + For ``DataParallel`` modules, ``model.module.state_dict()`` is + saved, instead of ``model.state_dict()``. + + Parameters + ---------- + model: nn.Module + Wrapped model, which needs to be checkpointed. + optimizer: optim.Optimizer + Wrapped optimizer which needs to be checkpointed. + checkpoint_dirpath: str + Path to an empty or non-existent directory to save checkpoints. + step_size: int, optional (default=1) + Period of saving checkpoints. + last_epoch: int, optional (default=-1) + The index of last epoch. + + Example + -------- + >>> model = torch.nn.Linear(10, 2) + >>> optimizer = torch.optim.Adam(model.parameters()) + >>> ckpt_manager = CheckpointManager(model, optimizer, "/tmp/ckpt") + >>> for epoch in range(20): + ... for batch in dataloader: + ... do_iteration(batch) + ... ckpt_manager.step() + """ + + def __init__( + self, + model, + optimizer, + checkpoint_dirpath, + step_size=1, + last_epoch=-1, + **kwargs, + ): + + if not isinstance(model, nn.Module): + raise TypeError("{} is not a Module".format(type(model).__name__)) + + if not isinstance(optimizer, optim.Optimizer): + raise TypeError( + "{} is not an Optimizer".format(type(optimizer).__name__) + ) + + self.model = model + self.optimizer = optimizer + self.ckpt_dirpath = Path(checkpoint_dirpath) + self.step_size = step_size + self.last_epoch = last_epoch + self.init_directory(**kwargs) + + def init_directory(self, config={}): + """Initialize empty checkpoint directory and record commit SHA + in it. Also save hyper-parameters config in this directory to + associate checkpoints with their hyper-parameters. 
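+
+        The commit SHA is recorded as an empty ``.commit-<sha>`` file, and the
+        config is written to ``config.yml`` inside the checkpoint directory.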
+ """ + + self.ckpt_dirpath.mkdir(parents=True, exist_ok=True) + # save current git commit hash in this checkpoint directory + commit_sha_subprocess = Popen( + ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE + ) + commit_sha, _ = commit_sha_subprocess.communicate() + commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") + commit_sha_filepath = self.ckpt_dirpath / f".commit-{commit_sha}" + commit_sha_filepath.touch() + yaml.dump( + config, + open(str(self.ckpt_dirpath / "config.yml"), "w"), + default_flow_style=False, + ) + + def step(self, epoch=None): + """Save checkpoint if step size conditions meet. """ + + if not epoch: + epoch = self.last_epoch + 1 + self.last_epoch = epoch + + if not self.last_epoch % self.step_size: + torch.save( + { + "model": self._model_state_dict(), + "optimizer": self.optimizer.state_dict(), + }, + self.ckpt_dirpath / f"checkpoint_{self.last_epoch}.pth", + ) + + def _model_state_dict(self): + """Returns state dict of model, taking care of DataParallel case.""" + if isinstance(self.model, nn.DataParallel): + return self.model.module.state_dict() + else: + return self.model.state_dict() + + +def load_checkpoint(checkpoint_pthpath): + """Given a path to saved checkpoint, load corresponding state dicts + of model and optimizer from it. This method checks if the current + commit SHA of codebase matches the commit SHA recorded when this + checkpoint was saved by checkpoint manager. + + Parameters + ---------- + checkpoint_pthpath: str or pathlib.Path + Path to saved checkpoint (as created by ``CheckpointManager``). + + Returns + ------- + nn.Module, optim.Optimizer + Model and optimizer state dicts loaded from checkpoint. + + Raises + ------ + UserWarning + If commit SHA do not match, or if the directory doesn't have + the recorded commit SHA. + """ + + if isinstance(checkpoint_pthpath, str): + checkpoint_pthpath = Path(checkpoint_pthpath) + checkpoint_dirpath = checkpoint_pthpath.resolve().parent + checkpoint_commit_sha = list(checkpoint_dirpath.glob(".commit-*")) + + if len(checkpoint_commit_sha) == 0: + warnings.warn( + "Commit SHA was not recorded while saving checkpoints." + ) + else: + # verify commit sha, raise warning if it doesn't match + commit_sha_subprocess = Popen( + ["git", "rev-parse", "--short", "HEAD"], stdout=PIPE, stderr=PIPE + ) + commit_sha, _ = commit_sha_subprocess.communicate() + commit_sha = commit_sha.decode("utf-8").strip().replace("\n", "") + + # remove ".commit-" + checkpoint_commit_sha = checkpoint_commit_sha[0].name[8:] + + if commit_sha != checkpoint_commit_sha: + warnings.warn( + f"Current commit ({commit_sha}) and the commit " + f"({checkpoint_commit_sha}) at which checkpoint was saved," + " are different. This might affect reproducibility." + ) + + # load encoder, decoder, optimizer state_dicts + components = torch.load(checkpoint_pthpath) + return components["model"], components["optimizer"] diff --git a/visdialch/utils/dynamic_rnn.py b/visdialch/utils/dynamic_rnn.py new file mode 100644 index 0000000..e014c08 --- /dev/null +++ b/visdialch/utils/dynamic_rnn.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence + + +class DynamicRNN(nn.Module): + def __init__(self, rnn_model): + super().__init__() + self.rnn_model = rnn_model + + def forward(self, seq_input, seq_lens, initial_state=None): + """A wrapper over pytorch's rnn to handle sequences of variable length. 
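+
+        Sequences are sorted by length, packed with ``pack_padded_sequence``,
+        run through the wrapped RNN, and the outputs and final states are
+        restored to the original batch order.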
+ + Arguments + --------- + seq_input : torch.Tensor + Input sequence tensor (padded) for RNN model. + Shape: (batch_size, max_sequence_length, embed_size) + seq_lens : torch.LongTensor + Length of sequences (b, ) + initial_state : torch.Tensor + Initial (hidden, cell) states of RNN model. + + Returns + ------- + Single tensor of shape (batch_size, rnn_hidden_size) corresponding + to the outputs of the RNN model at the last time step of each input + sequence. + """ + max_sequence_length = seq_input.size(1) + sorted_len, fwd_order, bwd_order = self._get_sorted_order(seq_lens) + sorted_seq_input = seq_input.index_select(0, fwd_order) + packed_seq_input = pack_padded_sequence( + sorted_seq_input, lengths=sorted_len, batch_first=True + ) + + if initial_state is not None: + hx = initial_state + assert hx[0].size(0) == self.rnn_model.num_layers + else: + sorted_hx = None + + self.rnn_model.flatten_parameters() + + outputs, (h_n, c_n) = self.rnn_model(packed_seq_input, sorted_hx) + + # pick hidden and cell states of last layer + h_n = h_n[-1].index_select(dim=0, index=bwd_order) + c_n = c_n[-1].index_select(dim=0, index=bwd_order) + + outputs = pad_packed_sequence( + outputs, batch_first=True, total_length=max_sequence_length + )[0].index_select(dim=0, index=bwd_order) + + return outputs, (h_n, c_n) + + @staticmethod + def _get_sorted_order(lens): + sorted_len, fwd_order = torch.sort( + lens.contiguous().view(-1), 0, descending=True + ) + _, bwd_order = torch.sort(fwd_order) + sorted_len = list(sorted_len) + return sorted_len, fwd_order, bwd_order diff --git a/visdialch/utils/gumbel_softmax.py b/visdialch/utils/gumbel_softmax.py new file mode 100644 index 0000000..cbedc55 --- /dev/null +++ b/visdialch/utils/gumbel_softmax.py @@ -0,0 +1,68 @@ +import torch +from torch import nn +import torch.nn.functional as F + +class GumbelSoftmax(nn.Module): + def __init__(self, temperature=1, hard=False): + super(GumbelSoftmax, self).__init__() + self.hard = hard + self.gpu = False + + self.temperature = temperature + + def cuda(self): + self.gpu = True + + def cpu(self): + self.gpu = False + + def sample_gumbel(self, shape, eps=1e-10): + """Sample from Gumbel(0, 1)""" + noise = torch.rand(shape) + noise.add_(eps).log_().neg_() + noise.add_(eps).log_().neg_() + if self.gpu: + return noise.detach().cuda() + else: + return noise.detach() + + def sample_gumbel_like(self, template_tensor, eps=1e-10): + uniform_samples_tensor = template_tensor.clone().uniform_() + gumble_samples_tensor = - torch.log(eps - torch.log(uniform_samples_tensor + eps)) + return gumble_samples_tensor + + def gumbel_softmax_sample(self, logits, temperature): + """ Draw a sample from the Gumbel-Softmax distribution""" + dim = logits.size(-1) + gumble_samples_tensor = self.sample_gumbel_like(logits.detach()) # 0.4 + gumble_trick_log_prob_samples = logits + gumble_samples_tensor.detach() + soft_samples = F.softmax(gumble_trick_log_prob_samples / temperature, -1) + return soft_samples + + def gumbel_softmax(self, logits, temperature): + """Sample from the Gumbel-Softmax distribution and optionally discretize. + Args: + logits: [batch_size, n_class] unnormalized log-probs + temperature: non-negative scalar + Returns: + [batch_size, n_class] sample from the Gumbel-Softmax distribution. 
+ If hard=True, then the returned sample will be one-hot, otherwise it will + be a probabilitiy distribution that sums to 1 across classes + """ + if self.training: + y = self.gumbel_softmax_sample(logits, temperature) + _, max_value_indexes = y.detach().max(1, keepdim=True) + y_hard = logits.detach().clone().zero_().scatter_(1, max_value_indexes, 1) + y = y_hard - y.detach() + y + else: + _, max_value_indexes = logits.detach().max(1, keepdim=True) + y = logits.detach().clone().zero_().scatter_(1, max_value_indexes, 1) + return y + + def forward(self, logits, temperature=None): + samplesize = logits.size() + + if temperature == None: + temperature = self.temperature + + return self.gumbel_softmax(logits, temperature=temperature) \ No newline at end of file diff --git a/visdialch/utils/layers.py b/visdialch/utils/layers.py new file mode 100644 index 0000000..d1bb0cc --- /dev/null +++ b/visdialch/utils/layers.py @@ -0,0 +1,171 @@ +import torch +from torch import nn +from torch.nn import functional as F + +class GatedTrans(nn.Module): + """docstring for GatedTrans""" + def __init__(self, in_dim, out_dim): + super(GatedTrans, self).__init__() + + self.embed_y = nn.Sequential( + nn.Linear( + in_dim, + out_dim + ), + nn.Tanh() + ) + self.embed_g = nn.Sequential( + nn.Linear( + in_dim, + out_dim + ), + nn.Sigmoid() + ) + + def forward(self, x_in): + x_y = self.embed_y(x_in) + x_g = self.embed_g(x_in) + x_out = x_y*x_g + + return x_out + +class Q_ATT(nn.Module): + """Self attention module of questions.""" + def __init__(self, config): + super(Q_ATT, self).__init__() + + self.embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["lstm_hidden_size"]*2, + config["lstm_hidden_size"] + ), + ) + self.att = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["lstm_hidden_size"], + 1 + ) + ) + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, ques_word, ques_word_encoded, ques_not_pad): + # ques_word shape: (batch_size, num_rounds, quen_len_max, word_embed_dim) + # ques_embed shape: (batch_size, num_rounds, quen_len_max, lstm_hidden_size * 2) + # ques_not_pad shape: (batch_size, num_rounds, quen_len_max) + # output: img_att (batch_size, num_rounds, embed_dim) + batch_size = ques_word.size(0) + num_rounds = ques_word.size(1) + quen_len_max = ques_word.size(2) + + ques_embed = self.embed(ques_word_encoded) # shape: (batch_size, num_rounds, quen_len_max, embed_dim) + ques_norm = F.normalize(ques_embed, p=2, dim=-1) # shape: (batch_size, num_rounds, quen_len_max, embed_dim) + + att = self.att(ques_norm).squeeze(-1) # shape: (batch_size, num_rounds, quen_len_max) + # ignore word + att = self.softmax(att) + att = att*ques_not_pad # shape: (batch_size, num_rounds, quen_len_max) + att = att / torch.sum(att, dim=-1, keepdim=True) # shape: (batch_size, num_rounds, quen_len_max) + feat = torch.sum(att.unsqueeze(-1) * ques_word, dim=-2) # shape: (batch_size, num_rounds, rnn_dim) + + return feat, att + +class H_ATT(nn.Module): + """question-based history attention""" + def __init__(self, config): + super(H_ATT, self).__init__() + + self.H_embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["lstm_hidden_size"]*2, + config["lstm_hidden_size"] + ), + ) + self.Q_embed = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + GatedTrans( + config["lstm_hidden_size"]*2, + 
config["lstm_hidden_size"] + ), + ) + self.att = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["lstm_hidden_size"], + 1 + ) + ) + self.softmax = nn.Softmax(dim=-1) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, hist, ques): + # hist shape: (batch_size, num_rounds, rnn_dim) + # ques shape: (batch_size, num_rounds, rnn_dim) + # output: hist_att (batch_size, num_rounds, embed_dim) + batch_size = ques.size(0) + num_rounds = ques.size(1) + + hist_embed = self.H_embed(hist) # shape: (batch_size, num_rounds, embed_dim) + hist_embed = hist_embed.unsqueeze(1).repeat(1, num_rounds, 1, 1) # shape: (batch_size, num_rounds, num_rounds, embed_dim) + + ques_embed = self.Q_embed(ques) # shape: (batch_size, num_rounds, embed_dim) + ques_embed = ques_embed.unsqueeze(2).repeat(1, 1, num_rounds, 1) # shape: (batch_size, num_rounds, num_rounds, embed_dim) + + att_embed = F.normalize(hist_embed*ques_embed, p=2, dim=-1) # (batch_size, num_rounds, num_rounds, embed_dim) + att_embed = self.att(att_embed).squeeze(-1) + att = self.softmax(att_embed) # shape: (batch_size, num_rounds, num_rounds) + att_not_pad = torch.tril(torch.ones(size=[num_rounds, num_rounds], requires_grad=False)) # shape: (num_rounds, num_rounds) + att_not_pad = att_not_pad.cuda() + att_masked = att*att_not_pad # shape: (batch_size, num_rounds, num_rounds) + att_masked = att_masked / torch.sum(att_masked, dim=-1, keepdim=True) # shape: (batch_size, num_rounds, num_rounds) + feat = torch.sum(att_masked.unsqueeze(-1) * hist.unsqueeze(1), dim=-2) # shape: (batch_size, num_rounds, rnn_dim) + + return feat + +class V_Filter(nn.Module): + """docstring for V_Filter""" + def __init__(self, config): + super(V_Filter, self).__init__() + + self.filter = nn.Sequential( + nn.Dropout(p=config["dropout_fc"]), + nn.Linear( + config["word_embedding_size"], + config["img_feature_size"] + ), + nn.Sigmoid() + ) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight.data) + if m.bias is not None: + nn.init.constant_(m.bias.data, 0) + + def forward(self, img, ques): + # img shape: (batch_size, num_rounds, i_dim) + # ques shape: (batch_size, num_rounds, q_dim) + # output: img_att (batch_size, num_rounds, embed_dim) + + batch_size = ques.size(0) + num_rounds = ques.size(1) + + ques_embed = self.filter(ques) # shape: (batch_size, num_rounds, embed_dim) + + # gated + img_fused = img * ques_embed + + return img_fused \ No newline at end of file