import random
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import numpy as np

from dataset import Dictionary, SelfCriticalDataset
from train import run
import opts


def weights_init_kn(m):
    """Kaiming-normal initialization for Linear layers (a=0.01, i.e. leaky-ReLU slope)."""
    if isinstance(m, nn.Linear):
        nn.init.kaiming_normal_(m.weight, a=0.01)


if __name__ == '__main__':
    opt = opts.parse_opt()
    
    ### Fail-safe: each dataset expects a specific importance threshold
    if opt.dataset == 'hatcp':
        assert opt.impt_threshold == 0.55, 'hatcp expects impt_threshold == 0.55'
    elif opt.dataset == 'gqacp':
        assert opt.impt_threshold == 0.3, 'gqacp expects impt_threshold == 0.3'
    elif opt.dataset == 'clevrxai':
        assert opt.impt_threshold == 0.85, 'clevrxai expects impt_threshold == 0.85'
        
    ## Set random seeds for reproducibility (opt.seed == 0 simply seeds everything with 0)
    seed = opt.seed
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # Note: benchmark=True trades determinism for speed; set
    # torch.backends.cudnn.deterministic = True instead for fully reproducible runs.
    torch.backends.cudnn.benchmark = True
    
    # load dictionary
    dictionary = Dictionary.load_from_file(f'{opt.data_dir}/dictionary.pkl')
    opt.ntokens = dictionary.ntoken
    print("dictionary ntoken", dictionary.ntoken)

    if opt.use_scr_loss:
        opt.apply_answer_weight = True

    ### creating datasets
    # train dataset
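    # (opt.split may be None for evaluation-only runs, in which case no train loader is built)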
    if opt.split is not None:
        train_dset = SelfCriticalDataset(opt.split, opt.hint_type, dictionary, opt,
                                         discard_items_without_hints=not opt.do_not_discard_items_without_hints)
        train_loader = DataLoader(train_dset,
                                  opt.batch_size,
                                  shuffle=True,
                                  num_workers=opt.num_workers)
    else:
        train_dset = None
        train_loader = None
    # val dataset
    eval_dset = SelfCriticalDataset(opt.split_test, opt.hint_type, dictionary, opt,
                                    discard_items_without_hints=opt.discard_items_for_test)
    eval_loader = DataLoader(eval_dset, opt.batch_size, shuffle=False, num_workers=opt.num_workers)

    if opt.use_two_testsets:
        assert opt.split_test_2 is not None
        eval_dset_2 = SelfCriticalDataset(opt.split_test_2, opt.hint_type, dictionary, opt,
                                          discard_items_without_hints=opt.discard_items_for_test)
        eval_loader_2 = DataLoader(eval_dset_2, opt.batch_size, shuffle=False, num_workers=opt.num_workers)
    else:
        eval_loader_2 = None

    # update opts
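    # (these dataset-derived fields are presumably used by the model constructors below
    # to size the visual input, answer classifier, and object-level layers)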
    opt.full_v_dim = eval_dset.full_v_dim
    opt.num_ans_candidates = eval_dset.num_ans_candidates
    opt.num_objects = eval_dset.num_objects
    
    
    ## Create model
    if opt.model_type == 'updn':
        from models.updn import UpDn
        model = UpDn(opt)
    elif opt.model_type == 'updn_ramen':
        from models.updn import UpDn_ramen
        model = UpDn_ramen(opt)
    elif opt.model_type == 'lang_only':
        # load language-only updn model
        from models.lang_only import LangOnly
        model = LangOnly(opt)
    elif opt.model_type == 'lxmert':
        from models.lxmert import lxmert
        model = lxmert(opt)
    else:
        raise ValueError(f"unsupported model type: {opt.model_type}")
    
    model = model.cuda()
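    # LXMERT presumably loads pretrained transformer weights, so skip the Kaiming re-init for it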
    if 'lxmert' not in opt.model_type:
        model.apply(weights_init_kn)

    model = nn.DataParallel(model)  # model is already on the GPU
    model.train()
    
    run(model,
        train_loader,
        eval_loader,
        eval_loader_2,
        opt)
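
# Example invocation (illustrative only; the actual flag names and values are defined in opts.py):
#   python main.py --dataset gqacp --impt_threshold 0.3 --model_type updn --seed 1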