from copy import deepcopy

from experiments.components_clmr import (
    continual_learning_evaluator_l2center,
    continual_learning_ewc_trainer,
    continual_learning_gem_trainer,
    continual_learning_icarl_trainer,
    continual_learning_l2center_trainer,
    continual_learning_replay_trainer,
    continual_learning_trainer,
    evaluator,
    mert_data_transform_vocalset,
    oracle_evaluator,
    oracle_trainer,
)

###############################################################
###########                SCENARIOS                ###########
###############################################################

# Every VocalSet technique label used in this file, grouped two per task.
all_tasks = [
    ["vibrato", "straight"],
    ["belt", "breathy"],
    ["lip_trill", "spoken"],
    ["inhaled", "trill"],
    ["trillo", "vocal_fry"],
]

# Scenario 1 presents the tasks in exactly the order of `all_tasks`.
# It is an independent deep copy, so editing one list in place never
# affects the other.
scenario1 = deepcopy(all_tasks)

# Scenario 2: the same ten labels, paired and ordered differently.
scenario2 = [
    ["belt", "trill"],
    ["vibrato", "inhaled"],
    ["breathy", "straight"],
    ["vocal_fry", "lip_trill"],
    ["spoken", "trillo"],
]

# Scenario 3: a third pairing/ordering of the same ten labels.
scenario3 = [
    ["spoken", "breathy"],
    ["straight", "inhaled"],
    ["lip_trill", "trillo"],
    ["vibrato", "vocal_fry"],
    ["trill", "belt"],
]

###############################################################
###########               COMPONENTS                ###########
###############################################################

# Number of output classes; equals the count of technique labels listed in
# the scenarios above (5 tasks x 2 labels).
num_classes = 10

# Data sources
# The train split is the template; val and test are deep copies that only
# override the split name and the evaluation flag.
train_vocalsettech_data_source = dict(
    name="VocalSetTechDataSource",
    args=dict(
        split="train",
        chunk_length=2.6780,  # 59049 samples
        is_eval=False,
    ),
)
val_vocalsettech_data_source = deepcopy(train_vocalsettech_data_source)
val_vocalsettech_data_source["args"].update(split="val", is_eval=True)
test_vocalsettech_data_source = deepcopy(train_vocalsettech_data_source)
test_vocalsettech_data_source["args"].update(split="test", is_eval=True)

# Metrics
# One multiclass metric config per (name, averaging mode) pair. Accuracy is
# micro-averaged; the remaining metrics are macro-averaged.
classification_metrics_vocalsettech = [
    {
        "name": metric_name,
        "args": {
            "task": "multiclass",
            "average": averaging,
            "num_classes": num_classes,
        },
    }
    for metric_name, averaging in (
        ("Accuracy", "micro"),
        ("F1 Score", "macro"),
        ("Precision", "macro"),
        ("Recall", "macro"),
    )
]


###########                TRAINERS                 ###########

# Model config used by the oracle trainer (and reused by the oracle
# evaluator): a classification head over a pretrained ClmrEncoder.
oracle_train_model_vocalsettech = dict(
    name="TorchClmrClassificationModel",
    args=dict(
        num_classes=num_classes,
        encoder=dict(
            name="ClmrEncoder",
            args=dict(pretrained=True),
        ),
    ),
)

# Dataset-specific overrides applied to every trainer config in this file.
# The values are stored by reference: dict.update() does not copy, so every
# config updated with this mapping shares these objects until that config is
# itself deep-copied.
# NOTE(review): the transform is named mert_* although it is imported from
# experiments.components_clmr -- confirm it is the intended one for CLMR.
update_trainer_vocalsettech = {
    "train_data_source": train_vocalsettech_data_source,
    "val_data_source": val_vocalsettech_data_source,
    "metrics_config": classification_metrics_vocalsettech,
    "train_data_transform": mert_data_transform_vocalset,
    "val_data_transform": mert_data_transform_vocalset,
}

# Oracle
# Private copy of the imported base oracle trainer, completed with the model
# config first and the dataset overrides second.
oracle_trainer_vocalsettech = deepcopy(oracle_trainer)
oracle_trainer_vocalsettech["args"].update(
    train_model=oracle_train_model_vocalsettech,
    **update_trainer_vocalsettech,
)


## Finetuning
# Scenario 1 is the template: a private copy of the imported base trainer
# with the task order set first and the dataset overrides applied after.
continual_learning_trainer_vocalsettech_scenario1 = deepcopy(continual_learning_trainer)
continual_learning_trainer_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_trainer_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_trainer_vocalsettech_scenario2 = deepcopy(continual_learning_trainer_vocalsettech_scenario1)
continual_learning_trainer_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_trainer_vocalsettech_scenario3 = deepcopy(continual_learning_trainer_vocalsettech_scenario1)
continual_learning_trainer_vocalsettech_scenario3["args"].update(tasks=scenario3)


## Replay
# Scenario 1 is the template: a private copy of the imported base replay
# trainer with the task order set first and the dataset overrides after.
continual_learning_replay_trainer_vocalsettech_scenario1 = deepcopy(continual_learning_replay_trainer)
continual_learning_replay_trainer_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_trainer_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_replay_trainer_vocalsettech_scenario2 = deepcopy(
    continual_learning_replay_trainer_vocalsettech_scenario1
)
continual_learning_replay_trainer_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_replay_trainer_vocalsettech_scenario3 = deepcopy(
    continual_learning_replay_trainer_vocalsettech_scenario1
)
continual_learning_replay_trainer_vocalsettech_scenario3["args"].update(tasks=scenario3)


## iCaRL
# Scenario 1 is the template: a private copy of the imported base iCaRL
# trainer with the task order set first and the dataset overrides after.
continual_learning_icarl_trainer_vocalsettech_scenario1 = deepcopy(continual_learning_icarl_trainer)
continual_learning_icarl_trainer_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_trainer_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_icarl_trainer_vocalsettech_scenario2 = deepcopy(
    continual_learning_icarl_trainer_vocalsettech_scenario1
)
continual_learning_icarl_trainer_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_icarl_trainer_vocalsettech_scenario3 = deepcopy(
    continual_learning_icarl_trainer_vocalsettech_scenario1
)
continual_learning_icarl_trainer_vocalsettech_scenario3["args"].update(tasks=scenario3)


## GEM
# Scenario 1 is the template: a private copy of the imported base GEM
# trainer with the task order set first and the dataset overrides after.
continual_learning_gem_trainer_vocalsettech_scenario1 = deepcopy(continual_learning_gem_trainer)
continual_learning_gem_trainer_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_trainer_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_gem_trainer_vocalsettech_scenario2 = deepcopy(continual_learning_gem_trainer_vocalsettech_scenario1)
continual_learning_gem_trainer_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_gem_trainer_vocalsettech_scenario3 = deepcopy(continual_learning_gem_trainer_vocalsettech_scenario1)
continual_learning_gem_trainer_vocalsettech_scenario3["args"].update(tasks=scenario3)


## EWC
# Scenario 1 is the template: a private copy of the imported base EWC
# trainer with the task order set first and the dataset overrides after.
continual_learning_ewc_trainer_vocalsettech_scenario1 = deepcopy(continual_learning_ewc_trainer)
continual_learning_ewc_trainer_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_trainer_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_ewc_trainer_vocalsettech_scenario2 = deepcopy(continual_learning_ewc_trainer_vocalsettech_scenario1)
continual_learning_ewc_trainer_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_ewc_trainer_vocalsettech_scenario3 = deepcopy(continual_learning_ewc_trainer_vocalsettech_scenario1)
continual_learning_ewc_trainer_vocalsettech_scenario3["args"].update(tasks=scenario3)

## L2Center
# Scenario 1 is the template: a private copy of the imported base L2Center
# trainer with the task order set first and the dataset overrides after.
continual_learning_l2center_trainer_vocalsettech_scenario1 = deepcopy(continual_learning_l2center_trainer)
continual_learning_l2center_trainer_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_trainer_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_l2center_trainer_vocalsettech_scenario2 = deepcopy(
    continual_learning_l2center_trainer_vocalsettech_scenario1
)
continual_learning_l2center_trainer_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_l2center_trainer_vocalsettech_scenario3 = deepcopy(
    continual_learning_l2center_trainer_vocalsettech_scenario1
)
continual_learning_l2center_trainer_vocalsettech_scenario3["args"].update(tasks=scenario3)


###########               EVALUATORS                ###########

# Dataset-specific overrides applied to every evaluator config in this file:
# the test split, the classification metrics and the input transform.
# The values are stored by reference (dict.update() does not copy).
update_evaluator_vocalsettech = {
    "data_source": test_vocalsettech_data_source,
    "metrics_config": classification_metrics_vocalsettech,
    "data_transform": mert_data_transform_vocalset,
}

# Oracle
# Private copy of the imported base oracle evaluator, completed with the
# oracle model config first and the dataset overrides second.
oracle_evaluator_vocalsettech = deepcopy(oracle_evaluator)
oracle_evaluator_vocalsettech["args"].update(
    model=oracle_train_model_vocalsettech,
    **update_evaluator_vocalsettech,
)

# One deep copy per scenario; each differs only in its task order.
oracle_evaluator_vocalsettech_scenario1 = deepcopy(oracle_evaluator_vocalsettech)
oracle_evaluator_vocalsettech_scenario1["args"].update(tasks=scenario1)
oracle_evaluator_vocalsettech_scenario2 = deepcopy(oracle_evaluator_vocalsettech)
oracle_evaluator_vocalsettech_scenario2["args"].update(tasks=scenario2)
oracle_evaluator_vocalsettech_scenario3 = deepcopy(oracle_evaluator_vocalsettech)
oracle_evaluator_vocalsettech_scenario3["args"].update(tasks=scenario3)

## Finetuning
# Scenario 1 is the template: a private copy of the imported base evaluator
# with the task order set first and the dataset overrides applied after.
continual_learning_evaluator_vocalsettech_scenario1 = deepcopy(evaluator)
continual_learning_evaluator_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_evaluator_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_evaluator_vocalsettech_scenario2 = deepcopy(continual_learning_evaluator_vocalsettech_scenario1)
continual_learning_evaluator_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_evaluator_vocalsettech_scenario3 = deepcopy(continual_learning_evaluator_vocalsettech_scenario1)
continual_learning_evaluator_vocalsettech_scenario3["args"].update(tasks=scenario3)

## L2Center
# Scenario 1 is the template: a private copy of the imported base L2Center
# evaluator with the task order set first and the dataset overrides after.
continual_learning_l2center_evaluator_vocalsettech_scenario1 = deepcopy(continual_learning_evaluator_l2center)
continual_learning_l2center_evaluator_vocalsettech_scenario1["args"].update(
    tasks=scenario1,
    **update_evaluator_vocalsettech,
)

# Scenarios 2 and 3 are deep copies of the scenario-1 config that differ
# only in their task order.
continual_learning_l2center_evaluator_vocalsettech_scenario2 = deepcopy(
    continual_learning_l2center_evaluator_vocalsettech_scenario1
)
continual_learning_l2center_evaluator_vocalsettech_scenario2["args"].update(tasks=scenario2)
continual_learning_l2center_evaluator_vocalsettech_scenario3 = deepcopy(
    continual_learning_l2center_evaluator_vocalsettech_scenario1
)
continual_learning_l2center_evaluator_vocalsettech_scenario3["args"].update(tasks=scenario3)


###############################################################
###########               EXPERIMENTS               ###########
###############################################################

###########                BASELINES                ###########

# Oracle baselines: one experiment per scenario. All three use the same
# oracle trainer config and differ only in the scenario-specific evaluator.
# NOTE(review): the three experiments also share one experiment_name
# ("..._all"). Confirm that anything keyed on experiment_name downstream is
# meant to be shared between them.
clmrsamplecnn_base_oracle_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_base_oracle_vocalsettech_all",
    experiment_type="Baseline",
    experiment_subtype="Oracle",
    train=dict(trainer=oracle_trainer_vocalsettech),
    evaluate=dict(evaluator=oracle_evaluator_vocalsettech_scenario1),
)

clmrsamplecnn_base_oracle_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_base_oracle_vocalsettech_all",
    experiment_type="Baseline",
    experiment_subtype="Oracle",
    train=dict(trainer=oracle_trainer_vocalsettech),
    evaluate=dict(evaluator=oracle_evaluator_vocalsettech_scenario2),
)

clmrsamplecnn_base_oracle_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_base_oracle_vocalsettech_all",
    experiment_type="Baseline",
    experiment_subtype="Oracle",
    train=dict(trainer=oracle_trainer_vocalsettech),
    evaluate=dict(evaluator=oracle_evaluator_vocalsettech_scenario3),
)


###########            CONTINUAL LEARNING           ###########

# SCENARIO 1
# Each experiment pairs a method-specific trainer with an evaluator. All
# methods except L2Center are scored with the same scenario-1 evaluator;
# L2Center has its own.

clmrsamplecnn_finetuning_cl_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_finetuning_cl_vocalsettech_scenario1",
    experiment_type="CL",
    experiment_subtype="Finetuning",
    train=dict(trainer=continual_learning_trainer_vocalsettech_scenario1),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario1),
)

clmrsamplecnn_replay_cl_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_replay_cl_vocalsettech_scenario1",
    experiment_type="CL",
    experiment_subtype="Replay",
    train=dict(trainer=continual_learning_replay_trainer_vocalsettech_scenario1),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario1),
)

clmrsamplecnn_icarl_cl_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_icarl_cl_vocalsettech_scenario1",
    experiment_type="CL",
    experiment_subtype="iCaRL",
    train=dict(trainer=continual_learning_icarl_trainer_vocalsettech_scenario1),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario1),
)

clmrsamplecnn_gem_cl_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_gem_cl_vocalsettech_scenario1",
    experiment_type="CL",
    experiment_subtype="GEM",
    train=dict(trainer=continual_learning_gem_trainer_vocalsettech_scenario1),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario1),
)

clmrsamplecnn_ewc_cl_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_ewc_cl_vocalsettech_scenario1",
    experiment_type="CL",
    experiment_subtype="EWC",
    train=dict(trainer=continual_learning_ewc_trainer_vocalsettech_scenario1),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario1),
)

clmrsamplecnn_l2center_cl_vocalsettech_scenario1 = dict(
    experiment_name="clmrsamplecnn_l2center_cl_vocalsettech_scenario1",
    experiment_type="CL",
    experiment_subtype="L2Center",
    train=dict(trainer=continual_learning_l2center_trainer_vocalsettech_scenario1),
    evaluate=dict(evaluator=continual_learning_l2center_evaluator_vocalsettech_scenario1),
)


# SCENARIO 2
# Each experiment pairs a method-specific trainer with an evaluator. All
# methods except L2Center are scored with the same scenario-2 evaluator;
# L2Center has its own.

clmrsamplecnn_finetuning_cl_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_finetuning_cl_vocalsettech_scenario2",
    experiment_type="CL",
    experiment_subtype="Finetuning",
    train=dict(trainer=continual_learning_trainer_vocalsettech_scenario2),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario2),
)

clmrsamplecnn_replay_cl_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_replay_cl_vocalsettech_scenario2",
    experiment_type="CL",
    experiment_subtype="Replay",
    train=dict(trainer=continual_learning_replay_trainer_vocalsettech_scenario2),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario2),
)

clmrsamplecnn_icarl_cl_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_icarl_cl_vocalsettech_scenario2",
    experiment_type="CL",
    experiment_subtype="iCaRL",
    train=dict(trainer=continual_learning_icarl_trainer_vocalsettech_scenario2),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario2),
)

clmrsamplecnn_gem_cl_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_gem_cl_vocalsettech_scenario2",
    experiment_type="CL",
    experiment_subtype="GEM",
    train=dict(trainer=continual_learning_gem_trainer_vocalsettech_scenario2),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario2),
)

clmrsamplecnn_ewc_cl_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_ewc_cl_vocalsettech_scenario2",
    experiment_type="CL",
    experiment_subtype="EWC",
    train=dict(trainer=continual_learning_ewc_trainer_vocalsettech_scenario2),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario2),
)

clmrsamplecnn_l2center_cl_vocalsettech_scenario2 = dict(
    experiment_name="clmrsamplecnn_l2center_cl_vocalsettech_scenario2",
    experiment_type="CL",
    experiment_subtype="L2Center",
    train=dict(trainer=continual_learning_l2center_trainer_vocalsettech_scenario2),
    evaluate=dict(evaluator=continual_learning_l2center_evaluator_vocalsettech_scenario2),
)


# SCENARIO 3
# Each experiment pairs a method-specific trainer with an evaluator. All
# methods except L2Center are scored with the same scenario-3 evaluator;
# L2Center has its own.

clmrsamplecnn_finetuning_cl_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_finetuning_cl_vocalsettech_scenario3",
    experiment_type="CL",
    experiment_subtype="Finetuning",
    train=dict(trainer=continual_learning_trainer_vocalsettech_scenario3),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario3),
)

clmrsamplecnn_replay_cl_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_replay_cl_vocalsettech_scenario3",
    experiment_type="CL",
    experiment_subtype="Replay",
    train=dict(trainer=continual_learning_replay_trainer_vocalsettech_scenario3),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario3),
)

clmrsamplecnn_icarl_cl_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_icarl_cl_vocalsettech_scenario3",
    experiment_type="CL",
    experiment_subtype="iCaRL",
    train=dict(trainer=continual_learning_icarl_trainer_vocalsettech_scenario3),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario3),
)

clmrsamplecnn_gem_cl_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_gem_cl_vocalsettech_scenario3",
    experiment_type="CL",
    experiment_subtype="GEM",
    train=dict(trainer=continual_learning_gem_trainer_vocalsettech_scenario3),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario3),
)

clmrsamplecnn_ewc_cl_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_ewc_cl_vocalsettech_scenario3",
    experiment_type="CL",
    experiment_subtype="EWC",
    train=dict(trainer=continual_learning_ewc_trainer_vocalsettech_scenario3),
    evaluate=dict(evaluator=continual_learning_evaluator_vocalsettech_scenario3),
)

clmrsamplecnn_l2center_cl_vocalsettech_scenario3 = dict(
    experiment_name="clmrsamplecnn_l2center_cl_vocalsettech_scenario3",
    experiment_type="CL",
    experiment_subtype="L2Center",
    train=dict(trainer=continual_learning_l2center_trainer_vocalsettech_scenario3),
    evaluate=dict(evaluator=continual_learning_l2center_evaluator_vocalsettech_scenario3),
)