import argparse
import logging
import os
import sys

import numpy as np
import torch
from torch.autograd import Variable
from torch.optim import Adam
from torch.utils.data import DataLoader

sys.path.append('..')
from EMNQA import util
from EMNQA.data_set import QAdataset
from EMNQA.model import DMN

logger = logging.getLogger()
logger.setLevel(logging.INFO)
fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
console = logging.StreamHandler()
console.setFormatter(fmt)
logger.addHandler(console)

HIDDEN_SIZE = 80
BATCH_SIZE = 64
LR = 0.001
EPOCH = 50
NUM_EPISODE = 3
EARLY_STOPPING = False
DATA_WORKERS = 4

USE_CUDA = torch.cuda.is_available()

FloatTensor = torch.cuda.FloatTensor if USE_CUDA else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if USE_CUDA else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if USE_CUDA else torch.ByteTensor
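# Pre-0.4 PyTorch idiom: bind the CPU or CUDA tensor constructors once, so
# e.g. LongTensor([1, 2, 3]) lands on the GPU whenever one is available.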


def prepare_data(filename):  # data -> dataloader
    data_p = util.bAbI_data_load(filename)

    word2idx, idx2word = util.build_words_dict(data_p)

    data_set = QAdataset(data_p, word2idx)

    train_dataloader = DataLoader(data_set,
                                  batch_size=BATCH_SIZE,
                                  num_workers=DATA_WORKERS,
                                  collate_fn=util.pad_to_batch,
                                  pin_memory=USE_CUDA)

    return train_dataloader, word2idx
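

# A sketch of how this (currently unused) DataLoader path would be wired in.
# Note that util.pad_to_batch is called as pad_to_batch(batch, word2idx) in the
# training loop below, so it would need a one-argument wrapper here to match
# DataLoader's collate_fn signature:
#   train_loader, word2idx = prepare_data(args_dic.train_data_file)
#   for facts, fact_masks, questions, question_masks, answers in train_loader:
#       ...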


def seq2variable(data, word2id):  # token lists -> index Variables, in place
    # Each item is [facts, question, answer]; every token list is converted to
    # a Variable of word indices via util.prepare_sequence.
    for t in data:
        for i, f in enumerate(t[0]):
            t[0][i] = util.prepare_sequence(f, word2id)
        t[1] = util.prepare_sequence(t[1], word2id)
        t[2] = util.prepare_sequence(t[2], word2id)
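

# util.prepare_sequence presumably returns a Variable wrapping a (1, n)
# LongTensor of word indices, e.g. (hypothetical indices):
#   util.prepare_sequence(['mary', 'went'], word2idx)
#   -> Variable(LongTensor([[4, 9]]))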


def train_from_scratch(filename):  # training
    train_data = util.bAbI_data_load(filename)
    test_data = util.bAbI_data_load(args_dic.test_data_file)
    word2idx, idx2word = util.build_words_dict(train_data)
    test_data = util.bAbI_data_test(test_data, word2idx)
    seq2variable(train_data, word2idx)
    print('Initializing model...')
    model = DMN(HIDDEN_SIZE, len(word2idx), len(word2idx), word2idx)

    if USE_CUDA:
        model = model.cuda()
    model.init_weight()
    # data_loader = prepare_data(filename)

    optimizer = Adam(model.parameters(), lr=LR)
    loss_fun = torch.nn.CrossEntropyLoss(ignore_index=0)
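    # ignore_index=0 assumes build_words_dict maps '<PAD>' to index 0, so
    # padded answer positions are excluded from the loss.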

    EARLY_STOPPING = False  # local flag; shadows the module-level default

    print('Begin Training!')
    for i in range(EPOCH):
        losses = []
        if EARLY_STOPPING: break

        for j, batch in enumerate(util.getbatch(train_data, BATCH_SIZE)):
            facts, fact_masks, questions, question_masks, answers = util.pad_to_batch(batch, word2idx)

            model.zero_grad()
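            # forward: facts, fact masks, questions, question masks, decode
            # length, number of memory episodes, and (presumably) a training flag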
            pred = model(facts, fact_masks, questions, question_masks, answers.size(1), NUM_EPISODE, True)
            loss = loss_fun(pred, answers.view(-1))
            losses.append(loss.data.tolist()[0])

            loss.backward()
            optimizer.step()

            if j % 100 == 0:
                logger.info("[%d/%d] mean_loss : %0.2f" % (i, EPOCH, np.mean(losses)))

                if np.mean(losses) < 0.01:
                    EARLY_STOPPING = True
                    print("Early Stopping!")
                    torch.save({'state_dict': model.state_dict(), 'word2idx': model.word2index},
                               'earlystopping-%s' % args_dic.model_file)
                    break
                losses = []
    if not EARLY_STOPPING:
        # save the final checkpoint in the same format as the early-stopping one
        torch.save({'state_dict': model.state_dict(), 'word2idx': model.word2index},
                   args_dic.model_file)
    print('Training finished. Evaluating...')
    evaluation(word2idx, model, test_data)
    print('Done.')
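

# Checkpoints are plain dicts of the form {'state_dict': ..., 'word2idx': ...};
# train_from_model below restores one with the matching torch.load call.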


def pad_fact(fact, x_to_ix):  # batch a single story's facts, for inference
    # Right-pad every fact to the longest one in the story, then build a byte
    # mask that is 1 exactly at the <PAD> positions.
    max_x = max([s.size(1) for s in fact])
    x_p = []
    for i in range(len(fact)):
        if fact[i].size(1) < max_x:
            pad = Variable(LongTensor([x_to_ix['<PAD>']] * (max_x - fact[i].size(1)))).view(1, -1)
            x_p.append(torch.cat([fact[i], pad], 1))
        else:
            x_p.append(fact[i])

    fact = torch.cat(x_p)
    fact_mask = Variable(fact.data.eq(x_to_ix['<PAD>']))
    return fact, fact_mask
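

# For example, facts of token lengths 4, 6 and 5 come back as a single 3 x 6
# LongTensor right-padded with <PAD>, plus a 3 x 6 ByteTensor mask flagging
# the padded positions.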


def evaluation(word2id, model, test_data):
    accuracy = 0
    for d in test_data:
        facts, facts_mask = pad_fact(d[0], word2id)
        question = d[1]
        question_mask = Variable(ByteTensor([0] * question.size(1))).unsqueeze(0)  # nothing is padded at inference
        answer = d[2].squeeze(0)  # drop the batch dimension -> 1-D answer sequence

        model.zero_grad()
        score = model([facts], [facts_mask], question, question_mask, num_decode=answer.size(0))

        # exact sequence match: correct only if every decoded token equals the
        # gold answer token
        if score.max(1)[1].data.tolist() == answer.data.tolist():
            accuracy += 1

    print('Test accuracy: %.2f%%' % (accuracy / len(test_data) * 100))


def train_from_model():
    print('Initializing model from checkpoint...')
    m = torch.load('earlystopping-%s' % args_dic.model_file, map_location=lambda storage, loc: storage)
    word2idx = m['word2idx']
    model = DMN(HIDDEN_SIZE, len(word2idx), len(word2idx), word2idx)
    model.load_state_dict(state_dict=m['state_dict'])

    logger.info('Loaded state dict. Evaluating now...')
    test_data = util.bAbI_data_load(args_dic.test_data_file)
    test_data = util.bAbI_data_test(test_data, word2idx)
    evaluation(word2idx, model, test_data=test_data)


if __name__ == '__main__':
    args = argparse.ArgumentParser()
    args.add_argument('--train-data-file', type=str, default='qa5_three-arg-relations_train.txt',
                      help='Input the train QA data')
    args.add_argument('--test-data-file', type=str, default='qa5_three-arg-relations_test.txt',
                      help='Input the test QA data')
    args.add_argument('--model-file', type=str, default='EMNQA.model',
                      help='Model file saved')

    args_dic = args.parse_args()
    data_file = args_dic.train_data_file

    logger.info('Use CUDA : %s' % USE_CUDA)
    if os.path.isfile('earlystopping-%s' % args_dic.model_file):
        logger.info('Found a saved model state dict. Initializing model from it...')
        train_from_model()
    else:
        logger.info('No saved model state dict found. Initializing model from scratch!')
        train_from_scratch(data_file)
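
# Example invocation (script name is hypothetical; the defaults match the
# bundled bAbI task-5 files):
#   python train.py --train-data-file qa5_three-arg-relations_train.txt \
#                   --test-data-file qa5_three-arg-relations_test.txt \
#                   --model-file EMNQA.model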