import argparse
import os
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
from pytorch_lightning.loggers import TensorBoardLogger
from transformers import AutoTokenizer
from dataset import RankDataMod
from models import BertRanker


def main():
    parser = argparse.ArgumentParser(
        description='BERT-based summary ranker.', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('-d', '--data_path', default='./data/cnndm/rerank_data',
                        type=str, metavar='', help='data directory path')
    parser.add_argument('-o', '--output', default='./checkpoints/cnndm',
                        type=str, metavar='', help='checkpoint save dir')
    parser.add_argument('--model_name_or_path', default='bert-base-uncased',
                        type=str, metavar='', help='path to pretrained model or model identifier from huggingface.co/models')
    parser.add_argument('--tokenizer_name', default=None,
                        type=str, metavar='', help='path to tokenizer if not the same as model_name')
    parser.add_argument('--config_name', default=None,
                        type=str, metavar='', help='path to config if not the same as model_name')
    parser.add_argument('--loss', default='listmle',
                        type=str, metavar='', help='training loss function')
    parser.add_argument('--temperature', default=1.,
                        type=float, metavar='', help='temperature hyperparameter for softmax')
    parser.add_argument('--margin_weight', default=10.,
                        type=float, metavar='', help='margin weight hyperparameter for max_margin loss')
    parser.add_argument('--metric', default='ctc_sum',
                        type=str, metavar='', help='ranking metric')
    parser.add_argument('--learning_rate', default=5e-5,
                        type=float, metavar='', help='learning rate')
    parser.add_argument('--batch_size', default=4,
                        type=int, metavar='', help='batch size')
    parser.add_argument('--num_train_samples', default=-1,
                        type=int, metavar='', help='number of training examples')
    parser.add_argument('--checkpoint', default=None,
                        type=str, metavar='', help='checkpoint file')
    parser.add_argument('--logdir', default=None,
                        type=str, metavar='', help='logs save directory')
    parser.add_argument('--predictions_file', default='./predictions.jsonl',
                        type=str, metavar='', help='output predictions file (.jsonl)')
    parser.add_argument('--do_train', dest='do_train', action='store_true')
    parser.add_argument('--do_eval', dest='do_eval', action='store_true')
    parser.add_argument('--do_predict', dest='do_predict', action='store_true')
    parser.add_argument('--seed', default=33,
                        type=int, metavar='', help='random seed')
    args = parser.parse_args()
    if args.gpus is None:
        args.gpus = 0

    if args.seed > 0:
        seed_everything(args.seed)

    tokenizer_name = args.tokenizer_name if args.tokenizer_name is not None else args.model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=False)

    model = BertRanker(
        model_name_or_path=args.model_name_or_path,
        tokenizer=tokenizer,
        config_name=args.config_name,
        loss=args.loss,
        temperature=args.temperature,
        margin_weight=args.margin_weight,
        learning_rate=args.learning_rate,
        batch_size=args.batch_size,
        accumulate_grad_batches=args.accumulate_grad_batches,
        metric=args.metric,
        num_train_samples=args.num_train_samples,
        predictions_file=args.predictions_file,
    )

    datamodule = RankDataMod(
        path=args.data_path,
        metric=args.metric,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_train_samples=args.num_train_samples,
        predict_only=not args.do_train
    )

    checkpoint = args.checkpoint
    if args.do_train:
        if checkpoint is not None:
            model = BertRanker.load_from_checkpoint(checkpoint)

        checkpoint_callback = ModelCheckpoint(
            monitor='val_ndcg',
            dirpath=args.output,
            filename=f'ebr-{args.metric}-{args.loss}-'+'{epoch:02d}-{val_ndcg:.2f}',
            save_top_k=3,
            mode='max',
        )
        lr_monitor = LearningRateMonitor(logging_interval='step')
        logger = TensorBoardLogger(save_dir=os.getcwd(), name=args.logdir)

        trainer = Trainer.from_argparse_args(args, callbacks=[checkpoint_callback, lr_monitor], logger=logger)
        trainer.validate(model, datamodule=datamodule)
        trainer.fit(model, datamodule=datamodule)
        checkpoint = checkpoint_callback.best_model_path

    if args.do_eval:
        if checkpoint is not None:
            model = BertRanker.load_from_checkpoint(checkpoint)
        trainer = Trainer.from_argparse_args(args, logger=False)
        trainer.test(model, datamodule=datamodule)

    if args.do_predict:
        if checkpoint is not None:
            model = BertRanker.load_from_checkpoint(checkpoint)
            model.predictions_file = args.predictions_file
        trainer = Trainer.from_argparse_args(args, logger=False)
        trainer.predict(model, datamodule=datamodule)


if __name__ == '__main__':
    main()