diff --git a/3c_training_main.py b/3c_training_main.py
deleted file mode 100644
index 298391b..0000000
--- a/3c_training_main.py
+++ /dev/null
@@ -1,181 +0,0 @@
-import os
-import sys
-import warnings
-from datetime import datetime
-
-import pandas as pd
-import torch
-import torch.multiprocessing as mp
-
-from helpers import system_inputs
-from nn_architecture.models import TtsClassifier, TtsDiscriminator
-from helpers.get_master import find_free_port
-from helpers.dataloader import Dataloader
-from helpers.trainer_3c import Trainer, DDPTrainer
-from helpers.ddp_training_classifier import run
-
-"""Train a classifier to distinguish samples between two conditions"""
-
-
-if __name__ == "__main__":
-
-    # sys.argv = ["experiment", "path_dataset=data\ganTrialERP_len100_test_shuffled.csv", "path_test=data\ganTrialERP_len100_train_shuffled.csv", "n_epochs=1", "sample_interval=10", "path_critic=trained_models\sd_len100_train20_500ep.pt"]#, "load_checkpoint", "path_checkpoint=trained_3c\\3c_exp_train20_3ep.pt"]
-    # sys.argv = ["generated", "path_test=trained_classifier\\cl_exp_109ep.pt", "path_dataset=generated_samples\\sd_len100_10000ep.csv", "n_epochs=2", "sample_interval=10"]
-    default_args = system_inputs.parse_arguments(sys.argv, system_inputs.default_inputs_training_classifier())
-
-    if not default_args['experiment'] and not default_args['generated'] and not default_args['testing']:
-        raise ValueError("At least one of the following flags must be set: 'experiment', 'generated', 'testing'.")
-
-    if default_args['load_checkpoint']:
-        print(f"Resuming training from checkpoint {default_args['path_checkpoint']}.")
-
-    # Look for cuda
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not default_args['ddp'] else torch.device("cpu")
-    world_size = torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count()
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    default_dict = 'trained_3c'
-
-    opt = {
-        'n_epochs': default_args['n_epochs'],
-        'sequence_length': default_args['sequence_length'],
-        'load_checkpoint': default_args['load_checkpoint'],
-        'path_checkpoint': default_args['path_checkpoint'],
-        'path_dataset': default_args['path_dataset'],
-        'path_test': default_args['path_test'],
-        'batch_size': default_args['batch_size'],
-        'learning_rate': default_args['learning_rate'],
-        'n_conditions': len(default_args['conditions']),
-        'patch_size': default_args['patch_size'],
-        'hidden_dim': 128,  # Dimension of hidden layers in discriminator and generator
-        'world_size': world_size,  # number of processes for distributed training
-        'device': device,
-    }
-
-    # TODO: implement data concatenation of experiment and generator data
-    # Load dataset as tensor
-    train_data = None
-    train_labels = None
-    test_data = None
-    test_labels = None
-
-    # if in testing mode and path_test is None, use the dataset from the specified checkpoint
-    if default_args['testing'] and default_args['path_test'] == 'None':
-        default_args['path_test'] = default_args['path_checkpoint']
-        opt['path_test'] = default_args['path_checkpoint']
-
-    # Get test data if provided
-    if default_args['path_test'] != 'None':
-        if default_args['path_test'].endswith('.pt'):
-            # load checkpoint and extract test_dataset
-            test_data = torch.load(default_args['path_test'], map_location=device)['test_dataset'][:,
-                        opt['n_conditions']:].float()
-            test_labels = torch.load(default_args['path_test'], map_location=device)['test_dataset'][:,
-                          :opt['n_conditions']].float()
-        elif default_args['path_test'].endswith('.csv'):
-            # load csv
-            dataloader = Dataloader(default_args['path_test'],
-                                    kw_timestep=default_args['kw_timestep_dataset'],
-                                    col_label=default_args['conditions'],
-                                    norm_data=True)
-            test_data = dataloader.get_data()[:, opt['n_conditions']:].float()
-            test_labels = dataloader.get_data()[:, :opt['n_conditions']].float()
-            # test_data = torch.tensor(pd.read_csv(default_args['path_test']).to_numpy()[:, opt['n_conditions']:]).float()
-            # test_labels = torch.tensor(pd.read_csv(default_args['path_test']).to_numpy()[:, :opt['n_conditions']]).float()
-
-    if default_args['experiment']:
-        # Get experiment's data as training data
-        dataloader = Dataloader(default_args['path_dataset'],
-                                kw_timestep=default_args['kw_timestep_dataset'],
-                                col_label=default_args['conditions'],
-                                norm_data=True)
-        if test_data is None:
-            train_idx, test_idx = dataloader.dataset_split(train_size=.8)
-
-            train_data = dataloader.get_data()[train_idx][:, dataloader.labels.shape[1]:]
-            train_labels = dataloader.get_data()[train_idx][:, :dataloader.labels.shape[1]]
-            test_data = dataloader.get_data()[test_idx][:, dataloader.labels.shape[1]:]
-            test_labels = dataloader.get_data()[test_idx][:, :dataloader.labels.shape[1]]
-        else:
-            train_data = dataloader.get_data()[:, dataloader.labels.shape[1]:].float()
-            train_labels = dataloader.get_data()[:, :dataloader.labels.shape[1]].float()
-
-    if default_args['generated']:
-        # Get generated data as training data
-        train_data = torch.tensor(pd.read_csv(default_args['path_dataset']).to_numpy()[:, opt['n_conditions']:]).float()
-        train_labels = torch.tensor(pd.read_csv(default_args['path_dataset']).to_numpy()[:, :opt['n_conditions']]).float()
-        if test_data is None:
-            # Split train data into train and test
-            train_idx, test_idx = Dataloader().dataset_split(train_data, train_size=.8)
-            test_data = train_data[test_idx].view(train_data[test_idx].shape).float()
-            test_labels = train_labels[test_idx].view(train_labels[test_idx].shape).float()
-            train_data = train_data[train_idx].view(train_data[train_idx].shape).float()
-            train_labels = train_labels[train_idx].view(train_labels[train_idx].shape).float()
-
-    opt['sequence_length'] = test_data.shape[1]  # - len(default_args['conditions'])
-
-    if opt['sequence_length'] % opt['patch_size'] != 0:
-        warnings.warn(
-            f"Sequence length ({opt['sequence_length']}) must be a multiple of patch size ({default_args['patch_size']}).\n"
-            f"The sequence is padded with zeros to fit the condition.")
-        padding = 0
-        while (opt['sequence_length'] + padding) % default_args['patch_size'] != 0:
-            padding += 1
-        opt['sequence_length'] += padding
-        train_data = torch.cat((train_data, torch.zeros(train_data.shape[0], padding)), dim=-1)
-        test_data = torch.cat((test_data, torch.zeros(test_data.shape[0], padding)), dim=-1)
-
-    # Load model and optimizer
-    # if not default_args['testing']:
-    critic_configuration = torch.load(default_args['path_critic'], map_location='cpu')
-
-    critic = TtsDiscriminator(seq_length=critic_configuration['configuration']['sequence_length'],
-                              patch_size=critic_configuration['configuration']['patch_size'],
-                              in_channels=1 + critic_configuration['configuration']['n_conditions'])
-    critic.load_state_dict(critic_configuration['discriminator'])
-    critic.eval()
-    classifier = TtsClassifier(seq_length=opt['sequence_length'],
-                               patch_size=opt['patch_size'],
-                               n_classes=int(opt['n_conditions']),
-                               in_channels=3,
-                               softmax=True).to(device)
-
-    # Test model
-    if default_args['testing']:
-        classifier.load_state_dict(torch.load(default_args['path_checkpoint'], map_location=device)['classifier'])
-        trainer = Trainer(classifier, critic, opt)
-        fake_labels = torch.where(test_labels == 0, 1, 0)
-        scores = trainer.compute_scores(test_data, test_labels, fake_labels)
-        test_data_temp = test_data.view(-1, 1, 1, test_data.shape[-1])
-        scores = scores.view(-1, 2, 1, 1).repeat(1, 1, 1, test_data.shape[-1]).to(trainer.device)
-        test_data_temp = torch.concat((test_data_temp, scores), dim=1).to(trainer.device)
-        test_loss, test_acc = trainer.test(test_data_temp, test_labels)
-        print(f"Test loss: {test_loss:.4f} - Test accuracy: {test_acc:.4f}")
-        exit()
-
-    # Train model
-    if default_args['ddp']:
-        # DDP Training
-        trainer = DDPTrainer(classifier, critic, opt)
-        if default_args['load_checkpoint']:
-            trainer.load_checkpoint(default_args['path_checkpoint'])
-        mp.spawn(run,
-                 args=(world_size, find_free_port(), default_args['ddp_backend'], trainer, train_data, train_labels, test_data, test_labels),
-                 nprocs=world_size, join=True)
-    else:
-        # Regular training
-        trainer = Trainer(classifier, critic, opt)
-        if default_args['load_checkpoint']:
-            trainer.load_checkpoint(default_args['path_checkpoint'])
-        loss = trainer.train(train_data, train_labels, test_data, test_labels)
-
-    # Save model
-    path = 'trained_3c'
-    filename = '3c_' + timestamp + '.pt'
-    filename = os.path.join(path, filename)
-    trainer.save_checkpoint(filename, torch.concat((test_labels, test_data), dim=1), loss)
-
-    print("Classifier training finished.")
-    print("Model states, losses and test dataset saved to file: "
-          f"\n{filename}.")
-
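Note: the patch-size padding in the removed script above (and in classifier_training_main.py further down) increments `padding` until the sequence length becomes divisible by the patch size. A minimal sketch of the equivalent closed-form computation; `pad_to_patch_size` is a hypothetical helper for illustration, not part of the repository:

    def pad_to_patch_size(seq_len: int, patch_size: int) -> int:
        # smallest number of zeros so that (seq_len + padding) % patch_size == 0
        return (-seq_len) % patch_size

    assert pad_to_patch_size(100, 8) == 4  # 100 -> 104 = 13 * 8
    assert pad_to_patch_size(96, 8) == 0   # already a multiple
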
Checkpoint is set to 'trained_models/checkpoint.pt'.") + default_args['checkpoint'] = os.path.join('trained_ae', 'checkpoint.pt') + print(f"Resuming training from checkpoint {default_args['checkpoint']}.") + else: + default_args['checkpoint'] = os.path.join('trained_ae', 'checkpoint.pt') # User inputs opt = { - 'path_dataset': default_args['path_dataset'], - 'path_checkpoint': default_args['path_checkpoint'], + 'data': default_args['data'], + 'checkpoint': default_args['checkpoint'], 'save_name': default_args['save_name'], 'target': default_args['target'], 'sample_interval': default_args['sample_interval'], - 'channel_label': default_args['channel_label'], + 'kw_channel': default_args['kw_channel'], 'channels_out': default_args['channels_out'], - 'timeseries_out': default_args['timeseries_out'], + 'time_out': default_args['time_out'], 'n_epochs': default_args['n_epochs'], 'batch_size': default_args['batch_size'], 'train_ratio': default_args['train_ratio'], @@ -48,23 +64,32 @@ def main(): 'num_heads': default_args['num_heads'], 'num_layers': default_args['num_layers'], 'ddp': default_args['ddp'], - 'ddp_backend': default_args['ddp_backend'], + 'ddp_backend': "nccl", #default_args['ddp_backend'], 'norm_data': True, 'std_data': False, 'diff_data': False, - 'kw_timestep': default_args['kw_timestep'], + 'kw_time': default_args['kw_time'], 'device': torch.device("cuda" if torch.cuda.is_available() else "cpu"), 'world_size': torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count(), 'history': None, - 'trained_epochs': 0 + 'trained_epochs': 0, + 'seed': default_args['seed'], } - + + # set a seed for reproducibility if desired + if opt['seed'] is not None: + np.random.seed(opt['seed']) + torch.manual_seed(opt['seed']) + torch.cuda.manual_seed(opt['seed']) + torch.cuda.manual_seed_all(opt['seed']) + torch.backends.cudnn.deterministic = True + # ---------------------------------------------------------------------------------------------------------------------- # Load, process, and split data # ---------------------------------------------------------------------------------------------------------------------- - data = Dataloader(path=opt['path_dataset'], - channel_label=opt['channel_label'], kw_timestep=opt['kw_timestep'], + data = Dataloader(path=opt['data'], + kw_channel=opt['kw_channel'], kw_time=opt['kw_time'], norm_data=opt['norm_data'], std_data=opt['std_data'], diff_data=opt['diff_data'],) dataset = data.get_data() @@ -84,12 +109,12 @@ def split_data(dataset, train_size=.8): opt['n_channels'] = dataset.shape[-1] opt['sequence_length'] = dataset.shape[1] opt['channels_in'] = opt['n_channels'] - opt['timeseries_in'] = opt['sequence_length'] + opt['time_in'] = opt['sequence_length'] # Split dataset and convert to pytorch dataloader class test_dataset, train_dataset = split_data(dataset, opt['train_ratio']) - test_dataloader = DataLoader(test_dataset, batch_size=opt['batch_size'], shuffle=True) - train_dataloader = DataLoader(train_dataset, batch_size=opt['batch_size'], shuffle=True) + test_dataloader = DataLoader(test_dataset, batch_size=opt['batch_size'], shuffle=True, pin_memory=True) + train_dataloader = DataLoader(train_dataset, batch_size=opt['batch_size'], shuffle=True, pin_memory=True) # ------------------------------------------------------------------------------------------------------------------ # Initiate and train autoencoder @@ -97,31 +122,31 @@ def split_data(dataset, train_size=.8): # Initiate autoencoder model_dict = None - if default_args['load_checkpoint'] 
-    if default_args['load_checkpoint'] and os.path.isfile(opt['path_checkpoint']):
-        model_dict = torch.load(opt['path_checkpoint'])
+    if default_args['load_checkpoint'] and os.path.isfile(opt['checkpoint']):
+        model_dict = torch.load(opt['checkpoint'])
         target_old = opt['target']
         channels_out_old = opt['channels_out']
-        timeseries_out_old = opt['timeseries_out']
+        time_out_old = opt['time_out']
 
         opt['target'] = model_dict['configuration']['target']
         opt['channels_out'] = model_dict['configuration']['channels_out']
-        opt['timeseries_out'] = model_dict['configuration']['timeseries_out']
+        opt['time_out'] = model_dict['configuration']['time_out']
 
         # Report changes to user
-        print(f"Loading model {opt['path_checkpoint']}.\n\nInhereting the following parameters:")
+        print(f"Loading model {opt['checkpoint']}.\n\nInheriting the following parameters:")
         print("parameter:\t\told value -> new value")
         print(f"target:\t\t\t{target_old} -> {opt['target']}")
         print(f"channels_out:\t{channels_out_old} -> {opt['channels_out']}")
-        print(f"timeseries_out:\t{timeseries_out_old} -> {opt['timeseries_out']}")
+        print(f"time_out:\t{time_out_old} -> {opt['time_out']}")
         print('-----------------------------------\n')
 
-    elif default_args['load_checkpoint'] and not os.path.isfile(opt['path_checkpoint']):
-        raise FileNotFoundError(f"Checkpoint file {opt['path_checkpoint']} not found.")
+    elif default_args['load_checkpoint'] and not os.path.isfile(opt['checkpoint']):
+        raise FileNotFoundError(f"Checkpoint file {opt['checkpoint']} not found.")
 
     # Add parameters for tracking
     opt['input_dim'] = opt['n_channels'] if opt['target'] in ['channels', 'full'] else opt['sequence_length']
-    opt['output_dim'] = opt['channels_out'] if opt['target'] in ['channels', 'full'] else opt['timeseries_out']
+    opt['output_dim'] = opt['channels_out'] if opt['target'] in ['channels', 'full'] else opt['time_out']
     opt['output_dim_2'] = opt['sequence_length'] if opt['target'] in ['channels'] else opt['n_channels']
 
     if opt['target'] == 'channels':
@@ -144,9 +169,9 @@ def split_data(dataset, train_size=.8):
                                           activation=opt['activation']).to(opt['device'])
     elif opt['target'] == 'full':
         model_1 = TransformerDoubleAutoencoder(channels_in=opt['channels_in'],
-                                               timeseries_in=opt['timeseries_in'],
+                                               time_in=opt['time_in'],
                                                channels_out=opt['channels_out'],
-                                               timeseries_out=opt['timeseries_out'],
+                                               time_out=opt['time_out'],
                                                hidden_dim=opt['hidden_dim'],
                                                num_layers=opt['num_layers'],
                                                num_heads=opt['num_heads'],
@@ -154,9 +179,9 @@ def split_data(dataset, train_size=.8):
                                                training_level=1).to(opt['device'])
 
         model_2 = TransformerDoubleAutoencoder(channels_in=opt['channels_in'],
-                                               timeseries_in=opt['timeseries_in'],
+                                               time_in=opt['time_in'],
                                                channels_out=opt['channels_out'],
-                                               timeseries_out=opt['timeseries_out'],
+                                               time_out=opt['time_out'],
                                                hidden_dim=opt['hidden_dim'],
                                                num_layers=opt['num_layers'],
                                                num_heads=opt['num_heads'],
@@ -185,16 +210,30 @@ def split_data(dataset, train_size=.8):
     opt['training_levels'] = training_levels
 
     if opt['ddp']:
+        warnings.warn(f"""The default autoencoder is a small model and DDP training adds a lot of overhead when transferring data to GPUs.
+        As such, it might be useful to test both GPU and CPU training and see what works best for your use case.
+        Although DDP training will result in better performance than CPU with the same number of training epochs,
+        you can often achieve the same performance more quickly by adding epochs with CPU training.""", stacklevel=3)
 
         for training_level in range(1,training_levels+1):
+            opt['training_level'] = training_level
+
             if training_levels == 2 and training_level == 1:
                 print('Training the first level of the autoencoder...')
-                model = model_1
+                trainer = AEDDPTrainer(model_1, opt)
             elif training_levels == 2 and training_level == 2:
                 print('Training the second level of the autoencoder...')
-                model = model_2
-            trainer = AEDDPTrainer(model, opt)
+                model_1_sd = trainer.model.state_dict()
+                model_1_osd = trainer.optimizer.state_dict()
+                trainer = AEDDPTrainer(model_2, opt)
+                trainer.model1_states = {
+                    'model': model_1_sd,
+                    'optimizer': model_1_osd
+                }
+            else:
+                trainer = AEDDPTrainer(model, opt)
+
             if default_args['load_checkpoint']:
-                trainer.load_checkpoint(default_args['path_checkpoint'])
+                trainer.load_checkpoint(default_args['checkpoint'])
             mp.spawn(run,
                      args=(opt['world_size'], find_free_port(), opt['ddp_backend'], trainer, opt),
                      nprocs=opt['world_size'], join=True)
@@ -211,13 +250,21 @@ def split_data(dataset, train_size=.8):
 
             if training_levels == 2 and training_level == 1:
                 print('Training the first level of the autoencoder...')
-                model = model_1
+                trainer = AETrainer(model_1, opt)
             elif training_levels == 2 and training_level == 2:
                 print('Training the second level of the autoencoder...')
-                model = model_2
-            trainer = AETrainer(model, opt)
+                model_1_sd = trainer.model.state_dict()
+                model_1_osd = trainer.optimizer.state_dict()
+                trainer = AETrainer(model_2, opt)
+                trainer.model1_states = {
+                    'model': model_1_sd,
+                    'optimizer': model_1_osd
+                }
+            else:
+                trainer = AETrainer(model, opt)
+
             if default_args['load_checkpoint']:
-                trainer.load_checkpoint(default_args['path_checkpoint'])
+                trainer.load_checkpoint(default_args['checkpoint'])
             samples = trainer.training(train_dataloader, test_dataloader)
 
             if training_levels == 2 and training_level == 1:
@@ -237,10 +284,17 @@ def split_data(dataset, train_size=.8):
 # ----------------------------------------------------------------------------------------------------------------------
 
     # Save model
-    if opt['save_name'] is None:
-        fn = opt['path_dataset'].split('/')[-1].split('.csv')[0]
-        opt['save_name'] = os.path.join("trained_ae", f"ae_{fn}_{str(time.time()).split('.')[0]}.pt")
-
+    path = 'trained_ae'
+    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
+    if opt['save_name'] != '':
+        # check if .pt extension is already included in the save_name
+        if not opt['save_name'].endswith('.pt'):
+            opt['save_name'] += '.pt'
+        filename = opt['save_name']
+    else:
+        filename = f'ae_{trainer.epochs}ep_' + timestamp + '.pt'
+
+    opt['save_name'] = os.path.join(path, filename)
     trainer.save_checkpoint(opt['save_name'], update_history=True, samples=samples)
     print(f"Model and configuration saved in {opt['save_name']}")
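Note: the hunk headers above reference a `split_data(dataset, train_size=.8)` helper whose body lies outside this diff. A plausible minimal implementation, consistent with the call site `test_dataset, train_dataset = split_data(dataset, opt['train_ratio'])` (note the test split is returned first); this is an assumption for illustration, not the repository's actual code:

    import torch

    def split_data(dataset, train_size=.8):
        # shuffle trials, then split; the call site unpacks the test split first
        num_train = int(dataset.shape[0] * train_size)
        perm = torch.randperm(dataset.shape[0])
        return dataset[perm[num_train:]], dataset[perm[:num_train]]
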
diff --git a/auxiliary/compare_averaged_conditions.py b/auxiliary/compare_averaged_conditions.py
deleted file mode 100644
index 93e13da..0000000
--- a/auxiliary/compare_averaged_conditions.py
+++ /dev/null
@@ -1,58 +0,0 @@
-import os.path
-
-import numpy as np
-import pandas as pd
-import matplotlib.pyplot as plt
-
-from helpers.dataloader import Dataloader
-
-if __name__ == "__main__":
-    """Create a plot of two curves. Each one represents the averaged samples of one condition.
-    The datasets to be processed have to be either in the directory:
-        - generated_samples if generated samples in the common csv-format (dimensions: (number samples, (condition, measurement)))
-        - data if a study's dataset is taken"""
-
-    # average over all samples
-
-    # for generated sample
-    file = r'C:\Users\Daniel\PycharmProjects\GanInNeuro\generated_samples\gan_train05_2500ep.csv'
-
-    # for experiment files
-    # file = r'C:\Users\Daniel\PycharmProjects\GanInNeuro\data\ganAverageERP_len100.csv'
-
-    if os.path.sep + 'data' + os.path.sep in file:
-        # file is in data folder is thus an experiment file
-        data = Dataloader(file, norm_data=True).get_data(shuffle=False)
-    elif os.path.sep + 'generated_samples' + os.path.sep in file:
-        # file is in generated_samples folder is thus a generated sample file
-        data = pd.read_csv(file)
-        data = data.to_numpy()
-    else:
-        raise ValueError('File is not in data or generated_samples folder')
-
-    # sort samples into respective bins
-    data_cond0 = []
-    data_cond1 = []
-    for i in range(data.shape[0]):
-        if data[i, 0] == 0:
-            data_cond0.append(data[i, 1:].tolist())
-        else:
-            data_cond1.append(data[i, 1:].tolist())
-
-    data_cond0 = np.array(data_cond0)
-    data_cond1 = np.array(data_cond1)
-    data_all = [data_cond1, data_cond0]
-
-    erp = []
-    legend = ['Condition 1', 'Condition 0']
-
-    for i, f in enumerate(data_all):
-        f = f.mean(axis=0)
-        erp.append(f)
-        plt.plot(f)
-    plt.legend(legend)
-    filename = os.path.basename(file).split('.')[0] + '_avg.png'
-    filename = os.path.join(r'C:\Users\Daniel\PycharmProjects\GanInNeuro\plots', filename)
-    # plt.savefig(filename)
-    # plt.ylim((0.45, 0.6))
-    plt.show()
diff --git a/auxiliary/data_split.py b/auxiliary/data_split.py
deleted file mode 100644
index a729032..0000000
--- a/auxiliary/data_split.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import os
-
-import numpy as np
-import pandas as pd
-
-from helpers.dataloader import Dataloader
-
-
-if __name__ == "__main__":
-    """Use this script to split the dataset into train and test data.
-    Use the train data to train the GAN and the classifier.
-    Use the test data to evaluate the classifier."""
-
-    # setup
-    file_dataset = r'C:\Users\Daniel\PycharmProjects\GanInNeuro\data\ganTrialERP_len100.csv'
-    conditions_dataset = ['Condition']  # the column name of the condition to train on
-    n_data_col = 4  # number of column when the actual data begins
-    train_size = 0.8
-    shuffle_data = False
-
-    # split dataset into train and test
-    dataloader = Dataloader(file_dataset, col_label=conditions_dataset)
-    train_data, test_data, train_idx, test_idx = dataloader.dataset_split(train_size=train_size, shuffle=shuffle_data)
-
-    # train_data = dataloader.get_data()[train_idx][:, len(conditions_dataset):].detach().cpu().numpy()
-    # test_data = dataloader.get_data()[test_idx][:, len(conditions_dataset):].detach().cpu().numpy()
-
-    # load original data as dataframe
-    df = pd.read_csv(file_dataset)
-    columns = df.columns
-    # get first n columns
-    cond_train = df[columns[:n_data_col-1]].to_numpy()[train_idx]
-    cond_test = df[columns[:n_data_col-1]].to_numpy()[test_idx]
-
-    # to dataframe
-    train_data = np.concatenate((cond_train, train_data[:, len(conditions_dataset):]), axis=1)
-    test_data = np.concatenate((cond_test, test_data[:, len(conditions_dataset):]), axis=1)
-    train_df = pd.DataFrame(train_data, columns=columns)
-    test_df = pd.DataFrame(test_data, columns=columns)
-
-    # save to csv
-    path = os.path.dirname(file_dataset)
-    file = os.path.basename(file_dataset)
-    file = file.split('.')[0]
-    filename_train = f'{file}_train.csv'
-    filename_test = f'{file}_test.csv'
-    train_df.to_csv(os.path.join(path, filename_train), index=False)
-    test_df.to_csv(os.path.join(path, filename_test), index=False)
-
-    print('Done')
diff --git a/auxiliary/get_critic_scores.py b/auxiliary/get_critic_scores.py
deleted file mode 100644
index 6d183de..0000000
--- a/auxiliary/get_critic_scores.py
+++ /dev/null
@@ -1,50 +0,0 @@
-import numpy as np
-import torch
-from matplotlib import pyplot as plt
-
-from helpers.dataloader import Dataloader
-from nn_architecture.models import TtsDiscriminator
-
-if __name__ == "__main__":
-    """Feed some generated samples into the discriminator and obtain the given scores.
-    The file needs to be a checkpoint file to load the GAN"""
-
-    # load discriminator to check its performance
-    file = r'C:\Users\Daniel\PycharmProjects\GanInNeuro\trained_models\sd_len100_train20_500ep.pt'
-    dc = torch.load(file, map_location=torch.device('cpu'))
-    critic = TtsDiscriminator(seq_length=dc['configuration']['sequence_length'],
-                              patch_size=dc['configuration']['patch_size'],
-                              in_channels=1 + dc['configuration']['n_conditions'])
-    critic.load_state_dict(dc['discriminator'])
-    critic.eval()
-
-    # load data
-    file = r'C:\Users\Daniel\PycharmProjects\GanInNeuro\data\ganTrialERP_len100_train.csv'
-    data = Dataloader(file, norm_data=True).get_data()
-    labels = data[:, :1]
-    data = data[:, 1:]
-    # get negated labels
-    labels_neg = 1 - labels
-
-    # get predictions
-    for i in range(10):
-        # random index
-        idx = np.random.randint(0, data.shape[0], 10)
-        # get data
-        data_batch = data[idx, :]
-        labels_batch = labels[idx, :]
-        labels_neg_batch = labels_neg[idx, :]
-        # get predictions
-        labels = [labels_batch, labels_neg_batch]
-        score = torch.zeros((labels_batch.shape[0], len(labels)))
-        for j, label in enumerate(labels):
-            batch_labels = label.view(-1, 1, 1, 1).repeat(1, 1, 1, data.shape[1])
-            batch_data = data_batch.view(-1, 1, 1, data.shape[1])
-            batch_data = torch.cat((batch_data, batch_labels), dim=1)
-            validity = critic(batch_data)
-            score[:, j] = validity[:, 0]
-        plt.plot(score[:, 0].detach().numpy(), label='score real labels')
-        plt.plot(score[:, 1].detach().numpy(), label='score negated labels')
-        plt.plot(labels_batch.detach().numpy(), label='real labels')
-        plt.legend()
-        plt.show()
\ No newline at end of file
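Note: the critic input assembled in the loop above is a 4D tensor whose second dimension stacks the data channel and one broadcast label channel per condition. A self-contained shape check of that construction (illustrative sizes only):

    import torch

    N, L = 10, 100                                   # 10 trials, 100 time steps
    data_batch = torch.randn(N, L)
    labels_batch = torch.randint(0, 2, (N, 1)).float()

    batch_labels = labels_batch.view(-1, 1, 1, 1).repeat(1, 1, 1, L)  # (N, 1, 1, L)
    batch_data = data_batch.view(-1, 1, 1, L)                         # (N, 1, 1, L)
    critic_input = torch.cat((batch_data, batch_labels), dim=1)       # (N, 2, 1, L)
    assert critic_input.shape == (N, 2, 1, L)  # matches in_channels = 1 + n_conditions
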
diff --git a/classifier_training_main.py b/classifier_training_main.py
deleted file mode 100644
index 8e4daa4..0000000
--- a/classifier_training_main.py
+++ /dev/null
@@ -1,170 +0,0 @@
-import os
-import sys
-import warnings
-from datetime import datetime
-
-import pandas as pd
-import torch
-import torch.multiprocessing as mp
-
-from helpers import system_inputs
-from nn_architecture.models import TtsClassifier, TtsDiscriminator
-from helpers.get_master import find_free_port
-from helpers.dataloader import Dataloader
-from helpers.trainer_classifier import Trainer, DDPTrainer
-from helpers.ddp_training_classifier import run
-
-"""Train a classifier to distinguish samples between two conditions"""
-
-
-if __name__ == "__main__":
-
-    # sys.argv = ["generated", "path_test=trained_classifier\cl_exp_109ep.pt", "path_dataset=generated_samples\sd_len100_10000ep.csv", "n_epochs=2", "sample_interval=10"]
-    # sys.argv = ["experiment", "path_dataset=data\ganTrialERP_len100_train_shuffled.csv", "path_test=data\ganTrialERP_len100_test_shuffled.csv", "n_epochs=1", "sample_interval=10"]
-    default_args = system_inputs.parse_arguments(sys.argv, system_inputs.default_inputs_training_classifier())
-
-    if not default_args['experiment'] and not default_args['generated'] and not default_args['testing']:
-        raise ValueError("At least one of the following flags must be set: 'experiment', 'generated', 'testing'.")
-
-    if default_args['load_checkpoint']:
-        print(f"Resuming training from checkpoint {default_args['path_checkpoint']}.")
-
-    # Look for cuda
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if not default_args['ddp'] else torch.device("cpu")
-    world_size = torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count()
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
-    default_dict = 'trained_classifier'
-
-    opt = {
-        'n_epochs': default_args['n_epochs'],
-        'sequence_length': default_args['sequence_length'],
-        'load_checkpoint': default_args['load_checkpoint'],
-        'path_checkpoint': default_args['path_checkpoint'],
-        'path_dataset': default_args['path_dataset'],
-        'batch_size': default_args['batch_size'],
-        'learning_rate': default_args['learning_rate'],
-        'n_conditions': len(default_args['conditions']),
-        'patch_size': default_args['patch_size'],
-        'hidden_dim': 128,  # Dimension of hidden layers in discriminator and generator
-        'world_size': world_size,  # number of processes for distributed training
-        'device': device,
-    }
-
-    # TODO: implement data concatenation of experiment and generator data
-    # Load dataset as tensor
-    train_data = None
-    train_labels = None
-    test_data = None
-    test_labels = None
-
-    # if in testing mode and path_test is None, use the dataset from the specified checkpoint
-    if default_args['testing'] and default_args['path_test'] == 'None':
-        default_args['path_test'] = default_args['path_checkpoint']
-        opt['path_test'] = default_args['path_checkpoint']
-
-    # Get test data if provided
-    if default_args['path_test'] != 'None':
-        if default_args['path_test'].endswith('.pt'):
-            # load checkpoint and extract test_dataset
-            test_data = torch.load(default_args['path_test'], map_location=device)['test_dataset'][:,
-                        opt['n_conditions']:].float()
-            test_labels = torch.load(default_args['path_test'], map_location=device)['test_dataset'][:,
-                          :opt['n_conditions']].float()
-        elif default_args['path_test'].endswith('.csv'):
-            # load csv
-            dataloader = Dataloader(default_args['path_test'],
-                                    kw_timestep=default_args['kw_timestep_dataset'],
-                                    col_label=default_args['conditions'],
-                                    norm_data=True)
-            test_data = dataloader.get_data()[:, opt['n_conditions']:].float()
-            test_labels = dataloader.get_data()[:, :opt['n_conditions']].float()
-            # test_data = torch.tensor(pd.read_csv(default_args['path_test']).to_numpy()[:, opt['n_conditions']:]).float()
-            # test_labels = torch.tensor(pd.read_csv(default_args['path_test']).to_numpy()[:, :opt['n_conditions']]).float()
-
-    if default_args['experiment']:
-        # Get experiment's data as training data
-        dataloader = Dataloader(default_args['path_dataset'],
-                                kw_timestep=default_args['kw_timestep_dataset'],
-                                col_label=default_args['conditions'],
-                                norm_data=True)
-        if test_data is None:
-            _, _, train_idx, test_idx = dataloader.dataset_split(train_size=.8)
-
-            train_data = dataloader.get_data()[train_idx][:, dataloader.labels.shape[1]:]
-            train_labels = dataloader.get_data()[train_idx][:, :dataloader.labels.shape[1]]
-            test_data = dataloader.get_data()[test_idx][:, dataloader.labels.shape[1]:]
-            test_labels = dataloader.get_data()[test_idx][:, :dataloader.labels.shape[1]]
-        else:
-            train_data = dataloader.get_data()[:, dataloader.labels.shape[1]:].float()
-            train_labels = dataloader.get_data()[:, :dataloader.labels.shape[1]].float()
-
-    if default_args['generated']:
-        # Get generated data as training data
-        train_data = torch.tensor(pd.read_csv(default_args['path_dataset']).to_numpy()[:, opt['n_conditions']:]).float()
-        train_labels = torch.tensor(pd.read_csv(default_args['path_dataset']).to_numpy()[:, :opt['n_conditions']]).float()
-        if test_data is None:
-            # Split train data into train and test
-            _, _, train_idx, test_idx = Dataloader().dataset_split(train_data, train_size=.8)
-            test_data = train_data[test_idx].view(train_data[test_idx].shape).float()
-            test_labels = train_labels[test_idx].view(train_labels[test_idx].shape).float()
-            train_data = train_data[train_idx].view(train_data[train_idx].shape).float()
-            train_labels = train_labels[train_idx].view(train_labels[train_idx].shape).float()
-
-    # TODO: implement data concatenation of experiment and generator data
-
-    opt['sequence_length'] = test_data.shape[1]  # - len(default_args['conditions'])
-
-    if opt['sequence_length'] % opt['patch_size'] != 0:
-        warnings.warn(
-            f"Sequence length ({opt['sequence_length']}) must be a multiple of patch size ({default_args['patch_size']}).\n"
-            f"The sequence is padded with zeros to fit the condition.")
-        padding = 0
-        while (opt['sequence_length'] + padding) % default_args['patch_size'] != 0:
-            padding += 1
-        opt['sequence_length'] += padding
-        train_data = torch.cat((train_data, torch.zeros(train_data.shape[0], padding)), dim=-1)
-        test_data = torch.cat((test_data, torch.zeros(test_data.shape[0], padding)), dim=-1)
-
-    # Load model and optimizer
-    # TODO: For new classifier, replace the following line with the new classifier
-    classifier = TtsClassifier(seq_length=opt['sequence_length'],
-                               patch_size=opt['patch_size'],
-                               n_classes=int(opt['n_conditions']),
-                               softmax=True).to(device)
-
-    # Test model
-    if default_args['testing']:
-        classifier.load_state_dict(torch.load(default_args['path_checkpoint'], map_location=device)['model'])
-        trainer = Trainer(classifier, opt)
-        test_loss, test_acc = trainer.test(test_data, test_labels)
-        print(f"Test loss: {test_loss:.4f} - Test accuracy: {test_acc:.4f}")
-        exit()
-
-    # Train model
-    if default_args['ddp']:
-        # DDP Training
-        trainer = DDPTrainer(classifier, opt)
-        if default_args['load_checkpoint']:
-            trainer.load_checkpoint(default_args['path_checkpoint'])
-        mp.spawn(run,
-                 args=(world_size, find_free_port(), default_args['ddp_backend'], trainer, train_data, train_labels, test_data, test_labels),
-                 nprocs=world_size, join=True)
-    else:
-        # Regular training
-        trainer = Trainer(classifier, opt)
-        if default_args['load_checkpoint']:
-            trainer.load_checkpoint(default_args['path_checkpoint'])
-        # TODO: Adjust train-method to your needs
-        loss = trainer.train(train_data, train_labels, test_data, test_labels)
-
-    # Save model
-    path = 'trained_classifier'
-    filename = 'classifier_' + timestamp + '.pt'
-    filename = os.path.join(path, filename)
-    trainer.save_checkpoint(filename, test_data, loss)
-
-    print("Classifier training finished.")
-    print("Model states, losses and test dataset saved to file: "
-          f"\n{filename}.")
-
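Note: both removed training scripts store their held-out data inside the checkpoint as a single 'test_dataset' tensor with the condition columns prepended, which is why the loading code above slices at `opt['n_conditions']`. A small round-trip sketch of that layout (illustrative values):

    import torch

    n_conditions = 1
    test_labels = torch.tensor([[0.], [1.]])
    test_data = torch.randn(2, 100)

    stored = torch.concat((test_labels, test_data), dim=1)  # what save_checkpoint receives
    assert torch.equal(stored[:, :n_conditions], test_labels)
    assert torch.equal(stored[:, n_conditions:], test_data)
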
diff --git a/gan_training_main.py b/gan_training_main.py
index c5cc4f2..3f5d80a 100644
--- a/gan_training_main.py
+++ b/gan_training_main.py
@@ -6,15 +6,12 @@
 import torch
 import torch.multiprocessing as mp
 from torch.utils.data import DataLoader
-from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
 
 from helpers.trainer import GANTrainer
 from helpers.get_master import find_free_port
 from helpers.ddp_training import run, GANDDPTrainer
-from nn_architecture.models import TransformerGenerator, TransformerDiscriminator, FFGenerator, FFDiscriminator, TTSGenerator, TTSDiscriminator, DecoderGenerator, EncoderDiscriminator
-from nn_architecture.ae_networks import TransformerAutoencoder, TransformerDoubleAutoencoder, TransformerFlattenAutoencoder
 from helpers.dataloader import Dataloader
-from helpers.initialize_gan import gan_architectures, gan_types, init_gan
+from helpers.initialize_gan import init_gan
 from helpers import system_inputs
 
 """Implementation of the training process of a GAN for the generation of synthetic sequential data.
@@ -29,7 +26,14 @@
 
 def main():
-    """Main function of the training process."""
+    """Main function of the training process.
+    For input help use the command 'python gan_training_main.py help' in the terminal."""
+
+    # create directory 'trained_models' if it does not exist
+    if not os.path.exists('trained_models'):
+        os.makedirs('trained_models')
+        print('Directory "trained_models" created to store checkpoints and final model.')
+
     default_args = system_inputs.parse_arguments(sys.argv, file='gan_training_main.py')
 
 # ----------------------------------------------------------------------------------------------------------------------
@@ -38,9 +42,8 @@ def main():
 
     # Training configuration
     ddp = default_args['ddp']
-    ddp_backend = default_args['ddp_backend']
-    load_checkpoint = default_args['load_checkpoint']
-    path_checkpoint = default_args['path_checkpoint']
+    ddp_backend = "nccl"  # default_args['ddp_backend']
+    checkpoint = default_args['checkpoint']
 
     # Data configuration
     diff_data = False  # Differentiate data
@@ -51,94 +54,91 @@ def main():
     if std_data and norm_data:
         raise Warning("Standardization and normalization are used at the same time.")
 
-    if load_checkpoint:
-        print(f'Resuming training from checkpoint {path_checkpoint}.')
+    if default_args['checkpoint'] != '':
+        # check if checkpoint exists and otherwise fall back to trained_models/checkpoint.pt
+        if not os.path.exists(default_args['checkpoint']):
+            print(f"Checkpoint {default_args['checkpoint']} does not exist. Checkpoint is set to 'trained_models/checkpoint.pt'.")
+            default_args['checkpoint'] = os.path.join('trained_models', 'checkpoint.pt')
+        checkpoint = default_args['checkpoint']
+        print(f'Resuming training from checkpoint {checkpoint}.')
 
     # GAN configuration
     opt = {
-        'gan_type': default_args['type'],
         'n_epochs': default_args['n_epochs'],
-        'input_sequence_length': default_args['input_sequence_length'],
-        # 'seq_len_generated': default_args['seq_len_generated'],
-        'load_checkpoint': default_args['load_checkpoint'],
-        'path_checkpoint': default_args['path_checkpoint'],
-        'path_dataset': default_args['path_dataset'],
-        'path_autoencoder': default_args['path_autoencoder'],
+        'checkpoint': default_args['checkpoint'],
+        'data': default_args['data'],
+        'autoencoder': default_args['autoencoder'],
         'batch_size': default_args['batch_size'],
-        'learning_rate': default_args['learning_rate'],
         'discriminator_lr': default_args['discriminator_lr'],
         'generator_lr': default_args['generator_lr'],
         'sample_interval': default_args['sample_interval'],
-        'n_conditions': len(default_args['conditions']) if default_args['conditions'][0] != '' else 0,
+        'n_conditions': len(default_args['kw_conditions']) if default_args['kw_conditions'][0] != '' else 0,
         'patch_size': default_args['patch_size'],
-        'kw_timestep': default_args['kw_timestep'],
-        'conditions': default_args['conditions'],
+        'kw_time': default_args['kw_time'],
+        'kw_conditions': default_args['kw_conditions'],
         'sequence_length': -1,
         'hidden_dim': default_args['hidden_dim'],  # Dimension of hidden layers in discriminator and generator
         'num_layers': default_args['num_layers'],
-        'activation': default_args['activation'],
         'latent_dim': 128,  # Dimension of the latent space
         'critic_iterations': 5,  # number of iterations of the critic per generator iteration for Wasserstein GAN
         'lambda_gp': 10,  # Gradient penalty lambda for Wasserstein GAN-GP
         'device': torch.device("cuda" if torch.cuda.is_available() else "cpu") if not ddp else torch.device("cpu"),
         'world_size': torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count(),  # number of processes for distributed training
-        # 'multichannel': default_args['multichannel'],
-        'channel_label': default_args['channel_label'],
+        'kw_channel': default_args['kw_channel'],
         'norm_data': norm_data,
         'std_data': std_data,
         'diff_data': diff_data,
-        'lr_scheduler': default_args['lr_scheduler'],
-        'scheduler_warmup': default_args['scheduler_warmup'],
-        'scheduler_target': default_args['scheduler_target'],
+        'seed': default_args['seed'],
+        'save_name': default_args['save_name'],
+        'history': None,
     }
-
+
+    # set a seed for reproducibility if desired
+    if opt['seed'] is not None:
+        np.random.seed(opt['seed'])
+        torch.manual_seed(opt['seed'])
+        torch.cuda.manual_seed(opt['seed'])
+        torch.cuda.manual_seed_all(opt['seed'])
+        torch.backends.cudnn.deterministic = True
+
     # Load dataset as tensor
-    dataloader = Dataloader(default_args['path_dataset'],
-                            kw_timestep=default_args['kw_timestep'],
-                            col_label=default_args['conditions'],
+    dataloader = Dataloader(default_args['data'],
+                            kw_time=default_args['kw_time'],
+                            kw_conditions=default_args['kw_conditions'],
                             norm_data=norm_data,
                             std_data=std_data,
                             diff_data=diff_data,
-                            channel_label=default_args['channel_label'])
+                            kw_channel=default_args['kw_channel'])
     dataset = dataloader.get_data()
 
     opt['channel_names'] = dataloader.channels
     opt['n_channels'] = dataset.shape[-1]
     opt['sequence_length'] = dataset.shape[1] - dataloader.labels.shape[1]
-    if opt['input_sequence_length'] == -1:
-        opt['input_sequence_length'] = opt['sequence_length']
     opt['n_samples'] = dataset.shape[0]
 
-    ae_dict = torch.load(opt['path_autoencoder'], map_location=torch.device('cpu')) if opt['path_autoencoder'] != '' else []
-    if opt['gan_type'] == 'tts' and ae_dict and (ae_dict['configuration']['target'] == 'full' or ae_dict['configuration']['target'] == 'time') and ae_dict['configuration']['timeseries_out'] % opt['patch_size'] != 0:
-        warnings.warn(
-            f"Sequence length ({ae_dict['configuration']['timeseries_out']}) must be a multiple of patch size ({default_args['patch_size']}).\n"
-            f"The sequence length is padded with zeros to fit the condition.")
-        padding = 0
-        while (ae_dict['configuration']['timeseries_out'] + padding) % default_args['patch_size'] != 0:
-            padding += 1
-
-        padding = torch.zeros((dataset.shape[0], padding, dataset.shape[-1]))
-        dataset = torch.cat((dataset, padding), dim=1)
-        opt['sequence_length'] = dataset.shape[1] - dataloader.labels.shape[1]
-    elif opt['gan_type'] == 'tts' and opt['sequence_length'] % opt['patch_size'] != 0:
-        warnings.warn(
-            f"Sequence length ({opt['sequence_length']}) must be a multiple of patch size ({default_args['patch_size']}).\n"
-            f"The sequence length is padded with zeros to fit the condition.")
-        padding = 0
-        while (opt['sequence_length'] + padding) % default_args['patch_size'] != 0:
-            padding += 1
-        padding = torch.zeros((dataset.shape[0], padding, dataset.shape[-1]))
-        dataset = torch.cat((dataset, padding), dim=1)
-        opt['sequence_length'] = dataset.shape[1] - dataloader.labels.shape[1]
+    ae_dict = torch.load(opt['autoencoder'], map_location=torch.device('cpu')) if opt['autoencoder'] != '' else []
+
+    # check if generated sequence is a multiple of patch size
+    encoded_sequence = False
+
+    def pad_warning(sequence_length, encoded_sequence=False):
+        error_msg = f"Sequence length ({sequence_length}) must be a multiple of patch size ({default_args['patch_size']})."
+ error_msg += " Please adjust the 'patch_size' or " + if encoded_sequence: + error_msg += "adjust the output sequence length of the autoencoder ('time_out'). The latter option requires a newly trained autoencoder." + else: + error_msg += "adjust the sequence length of the dataset." + raise ValueError(error_msg) + if ae_dict and (ae_dict['configuration']['target'] == 'full' or ae_dict['configuration']['target'] == 'time'): + generated_seq_length = ae_dict['configuration']['time_out'] + encoded_sequence = True else: - padding = torch.zeros((dataset.shape[0], 0, dataset.shape[-1])) - - opt['latent_dim_in'] = opt['latent_dim'] + opt['n_conditions'] + opt['n_channels'] if opt['input_sequence_length'] > 0 else opt['latent_dim'] + opt['n_conditions'] + generated_seq_length = opt['sequence_length'] + if generated_seq_length % default_args['patch_size'] != 0: + pad_warning(generated_seq_length, encoded_sequence) + + opt['latent_dim_in'] = opt['latent_dim'] + opt['n_conditions'] opt['channel_in_disc'] = opt['n_channels'] + opt['n_conditions'] - opt['sequence_length_generated'] = opt['sequence_length'] - opt['input_sequence_length'] if opt['input_sequence_length'] != opt['sequence_length'] else opt['sequence_length'] - opt['padding'] = padding.shape[1] - + opt['sequence_length_generated'] = opt['sequence_length'] + # -------------------------------------------------------------------------------- # Initialize generator, discriminator and trainer # -------------------------------------------------------------------------------- @@ -146,6 +146,28 @@ def main(): generator, discriminator = init_gan(**opt) print("Generator and discriminator initialized.") + # -------------------------------------------------------------------------------- + # Setup History + # -------------------------------------------------------------------------------- + + # Populate model configuration + history = {} + for key in opt.keys(): + if (not key == 'history') | (not key == 'trained_epochs'): + history[key] = [opt[key]] + history['trained_epochs'] = [] + + if default_args['checkpoint'] != '': + + # load checkpoint + model_dict = torch.load(default_args['checkpoint']) + + # update history + for key in history.keys(): + history[key] = model_dict['configuration']['history'][key] + history[key] + + opt['history'] = history + # ---------------------------------------------------------------------------------------------------------------------- # Start training process # ---------------------------------------------------------------------------------------------------------------------- @@ -156,32 +178,39 @@ def main(): print('-----------------------------------------\n') if ddp: trainer = GANDDPTrainer(generator, discriminator, opt) - if default_args['load_checkpoint']: - trainer.load_checkpoint(default_args['path_checkpoint']) + if default_args['checkpoint'] != '': + trainer.load_checkpoint(default_args['checkpoint']) mp.spawn(run, args=(opt['world_size'], find_free_port(), ddp_backend, trainer, opt), nprocs=opt['world_size'], join=True) + + print("GAN training finished.") + else: trainer = GANTrainer(generator, discriminator, opt) - if default_args['load_checkpoint']: - trainer.load_checkpoint(default_args['path_checkpoint']) - dataset = DataLoader(dataset, batch_size=trainer.batch_size, shuffle=True) + if default_args['checkpoint'] != '': + trainer.load_checkpoint(default_args['checkpoint']) + dataset = DataLoader(dataset, batch_size=trainer.batch_size, shuffle=True, pin_memory=True) gen_samples = trainer.training(dataset) 
         # save final models, optimizer states, generated samples, losses and configuration as final result
         path = 'trained_models'
         timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
-        filename = f'gan_{trainer.epochs}ep_' + timestamp + '.pt'
-        trainer.save_checkpoint(path_checkpoint=os.path.join(path, filename), samples=gen_samples, update_history=True)
-
-        print(f"Checkpoint saved to {path_checkpoint}.")
+        if opt['save_name'] != '':
+            # check if .pt extension is already included in the save_name
+            if not opt['save_name'].endswith('.pt'):
+                opt['save_name'] += '.pt'
+            filename = opt['save_name']
+        else:
+            filename = f'gan_{trainer.epochs}ep_' + timestamp + '.pt'
+        path_checkpoint = os.path.join(path, filename)
+        trainer.save_checkpoint(path_checkpoint=path_checkpoint, samples=gen_samples, update_history=True)
 
         generator = trainer.generator
         discriminator = trainer.discriminator
 
         print("GAN training finished.")
-        print(f"Model states and generated samples saved to file {os.path.join(path, filename)}.")
-
+
     return generator, discriminator, opt, gen_samples
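Note: with the prediction/sequence-2-sequence path removed, the generator input is simply the latent draw concatenated with the condition labels along the feature dimension, which is why `latent_dim_in = latent_dim + n_conditions` above. A shape sketch of that concatenation (the random draw below merely stands in for `GANTrainer.sample_latent_variable`, whose exact distribution is defined elsewhere in the repo):

    import torch

    B, seq_len, latent_dim, n_conditions = 4, 1, 128, 1
    z = torch.randn(B, seq_len, latent_dim)                        # stand-in latent draw
    cond_labels = torch.zeros(B, seq_len, n_conditions) + torch.tensor([1.0])
    gen_input = torch.cat((z, cond_labels), dim=-1)
    assert gen_input.shape == (B, seq_len, latent_dim + n_conditions)
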
diff --git a/generate_samples_main.py b/generate_samples_main.py
index 7157fd7..fb65700 100644
--- a/generate_samples_main.py
+++ b/generate_samples_main.py
@@ -5,19 +5,27 @@
 import pandas as pd
 import torch
 from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
+from torch.utils.data import DataLoader
 
 from helpers import system_inputs
 from helpers.dataloader import Dataloader
-from helpers.initialize_gan import init_gan, gan_types
+from helpers.initialize_gan import init_gan
 from helpers.trainer import GANTrainer
-from nn_architecture.models import DecoderGenerator, TransformerGenerator, AutoencoderGenerator
-from nn_architecture.ae_networks import TransformerDoubleAutoencoder, TransformerAutoencoder, \
-    TransformerFlattenAutoencoder
+from nn_architecture.models import DecoderGenerator
+from nn_architecture.vae_networks import VariationalAutoencoder
 
 
 def main():
     default_args = system_inputs.parse_arguments(sys.argv, file='generate_samples_main.py')
 
+    # set a seed for reproducibility if desired
+    if default_args['seed'] is not None:
+        np.random.seed(default_args['seed'])
+        torch.manual_seed(default_args['seed'])
+        torch.cuda.manual_seed(default_args['seed'])
+        torch.cuda.manual_seed_all(default_args['seed'])
+        torch.backends.cudnn.deterministic = True
+
     print('\n-----------------------------------------')
     print("System output:")
     print('-----------------------------------------\n')
@@ -32,14 +40,14 @@ def main():
     if len(condition) == 1 and condition[0] == 'None':
         condition = []
 
-    file = default_args['path_file']
+    file = default_args['model']
     if file.split(os.path.sep)[0] == file and file.split('/')[0] == file:
         # use default path if no path is given
         path = 'trained_models'
         file = os.path.join(path, file)
 
-    path_samples = default_args['path_samples']
-    if path_samples == 'None':
+    path_samples = default_args['save_name']
+    if path_samples == '':
         # Use checkpoint filename as path
         path_samples = os.path.basename(file).split('.')[0] + '.csv'
     if path_samples.split(os.path.sep)[0] == path_samples:
@@ -51,152 +59,162 @@ def main():
 
     state_dict = torch.load(file, map_location='cpu')
 
-    # load model/training configuration
-    n_conditions = state_dict['configuration']['n_conditions']
-    n_channels = state_dict['configuration']['n_channels']
-    channel_names = state_dict['configuration']['channel_names']
-    latent_dim = state_dict['configuration']['latent_dim']
-    sequence_length = state_dict['configuration']['sequence_length']
-    input_sequence_length = state_dict['configuration']['input_sequence_length']
-
-    assert n_conditions == len(condition), f"Number of conditions in model ({n_conditions}) does not match number of conditions given ({len(condition)})."
-
-    if input_sequence_length != 0 and input_sequence_length != sequence_length:
-        raise NotImplementedError(f"Prediction case detected.\nInput sequence length ({input_sequence_length}) > 0 and != sequence length ({sequence_length}).\nPrediction is not implemented yet.")
-
-    # get data from dataset if sequence2sequence or prediction case
-    if input_sequence_length != 0:
-        dataloader = Dataloader(**state_dict['configuration']['dataloader'])
-        dataset = dataloader.get_data()
-        if n_conditions > 0:
-            raise NotImplementedError(
-                f"Prediction or Sequence-2-Sequence case detected.\nGeneration with conditions in on of these cases is not implemented yet.\nPlease generate without conditions.")
-    else:
-        dataset = None
-
-    # define device
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-    # Initialize generator
-    print("Initializing generator...")
-    latent_dim_in = latent_dim + n_conditions + n_channels if input_sequence_length > 0 else latent_dim + n_conditions
-
-    for k, v in gan_types.items():
-        if state_dict['configuration']['generator_class'] in v:
-            gan_type = k
-            break
-
-    generator, _ = init_gan(gan_type=gan_type,
-                            latent_dim_in=latent_dim_in,
-                            channel_in_disc=n_channels,
-                            n_channels=n_channels,
-                            n_conditions=n_conditions,
-                            sequence_length_generated=sequence_length,
-                            device=device,
-                            hidden_dim=state_dict['configuration']['hidden_dim'],
-                            num_layers=state_dict['configuration']['num_layers'],
-                            activation=state_dict['configuration']['activation'],
-                            input_sequence_length=input_sequence_length,
-                            patch_size=state_dict['configuration']['patch_size'],
-                            path_autoencoder=state_dict['configuration']['path_autoencoder'],
-                            padding=state_dict['configuration']['padding'],
-                            )
-    generator.eval()
-    if isinstance(generator, DecoderGenerator):
-        generator.padding = state_dict['configuration']['padding']  # TODO: ADD BACK
-        generator.decode_output()
-
-    # load generator weights
-    generator.load_state_dict(state_dict['generator'])
-    generator.to(device)
-
-    # check given conditions that they are numeric
-    for i, x in enumerate(condition):
-        if x == -1 or x == -2:
-            continue
-        else:
-            try:
-                condition[i] = float(x)
-            except ValueError:
-                raise ValueError(f"Condition {x} is not numeric.")
-
-    # create condition labels if conditions are given but differ from number of conditions in model
-    if n_conditions != len(condition):
-        if n_conditions > len(condition) and len(condition) == 1 and condition[0] == -1:
-            # if only one condition is given and it is -1, then all conditions are set to -1
-            condition = condition * n_conditions
-        else:
-            raise ValueError(
-                f"Number of conditions in model (={n_conditions}) does not match number of conditions given ={len(condition)}.")
-
-    seq_len = max(1, input_sequence_length)
-    cond_labels = torch.zeros((num_samples_parallel, seq_len, n_conditions)).to(device) + torch.tensor(condition).to(device)
-    cond_labels = cond_labels.to(device)
+    device = torch.device('cpu')
 
-    # generate samples
-    num_sequences = num_samples_total // num_samples_parallel
-    print("Generating samples...")
+    # check if column condition labels are given
+    n_conditions = len(state_dict['configuration']['kw_conditions']) if state_dict['configuration']['kw_conditions'] and state_dict['configuration']['kw_conditions'] != [''] else 0
+    if n_conditions > 0:
+        col_labels = state_dict['configuration']['dataloader']['kw_conditions']
+    else:
+        col_labels = []
+
+    # check if channel label is given
+    if not state_dict['configuration']['dataloader']['kw_channel'] in [None, '']:
+        kw_channel = [state_dict['configuration']['dataloader']['kw_channel']]
+    else:
+        kw_channel = ['Electrode']
 
-    all_samples = np.zeros((num_samples_parallel * num_sequences * n_channels, n_conditions + 1 + sequence_length))
+    # get keyword for time step labels
+    if state_dict['configuration']['dataloader']['kw_time']:
+        kw_time = state_dict['configuration']['dataloader']['kw_time']
+    else:
+        kw_time = 'Time'
+
+    if state_dict['configuration']['model_class'] != 'VariationalAutoencoder':
+
+        # load model/training configuration
+        n_conditions = state_dict['configuration']['n_conditions']
+        n_channels = state_dict['configuration']['n_channels']
+        channel_names = state_dict['configuration']['channel_names']
+        latent_dim = state_dict['configuration']['latent_dim']
+        sequence_length = state_dict['configuration']['sequence_length']
+        # input_sequence_length = state_dict['configuration']['input_sequence_length']
+
+        if n_conditions != len(condition):
+            raise ValueError(f"Number of conditions in model ({n_conditions}) does not match number of conditions given ({len(condition)}).")
+
+        # define device
+        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+        # Initialize generator
+        print("Initializing generator...")
+        latent_dim_in = latent_dim + n_conditions
+
+        generator, _ = init_gan(latent_dim_in=latent_dim_in,
+                                channel_in_disc=n_channels,
+                                n_channels=n_channels,
+                                n_conditions=n_conditions,
+                                sequence_length_generated=sequence_length,
+                                device=device,
+                                hidden_dim=state_dict['configuration']['hidden_dim'],
+                                num_layers=state_dict['configuration']['num_layers'],
+                                patch_size=state_dict['configuration']['patch_size'],
+                                autoencoder=state_dict['configuration']['autoencoder'],
+                                )
+        generator.eval()
+        if isinstance(generator, DecoderGenerator):
+            generator.decode_output()
+
+        # load generator weights
+        generator.load_state_dict(state_dict['generator'])
+        generator.to(device)
+
+        # check given conditions that they are numeric
+        for i, x in enumerate(condition):
+            if x == -1 or x == -2:
+                continue
+            else:
+                try:
+                    condition[i] = float(x)
+                except ValueError:
+                    raise ValueError(f"Condition {x} is not numeric.")
+
+        seq_len = 1  # max(1, input_sequence_length)
+        cond_labels = torch.zeros((num_samples_parallel, seq_len, n_conditions)).to(device) + torch.tensor(condition).to(device)
+        cond_labels = cond_labels.to(device)
+
+        # generate samples
+        num_sequences = num_samples_total // num_samples_parallel
+        print("Generating samples...")
+
+        all_samples = np.zeros((num_samples_parallel * num_sequences * n_channels, n_conditions + 1 + sequence_length))
+
+        for i in range(num_sequences):
+            print(f"Generating sequence {i + 1}/{num_sequences}...")
+            with torch.no_grad():
+                # draw latent variable
+                z = GANTrainer.sample_latent_variable(batch_size=num_samples_parallel, latent_dim=latent_dim,
+                                                      sequence_length=seq_len, device=device)
+                # concat with conditions and input sequence
+                z = torch.cat((z, cond_labels), dim=-1).float().to(device)
+                # generate samples
+                samples = generator(z).cpu().numpy()
+
+            # reshape samples by concatenating over channels in incrementing channel name order
+            new_samples = np.zeros((num_samples_parallel * n_channels, n_conditions + 1 + sequence_length))
+            for j, channel in enumerate(channel_names):
+                # padding = np.zeros((samples.shape[0], state_dict['configuration']['padding']))
+                # new_samples[j::n_channels] = np.concatenate((cond_labels.cpu().numpy()[:, 0, :], np.zeros((num_samples_parallel, 1)) + channel, np.concatenate((samples[:, :, j], padding), axis=1)), axis=-1)
+                new_samples[j::n_channels] = np.concatenate((cond_labels.cpu().numpy()[:, 0, :], np.zeros((num_samples_parallel, 1)) + channel, samples[:, :, j]), axis=-1)
+            # add samples to all_samples
+            all_samples[i * num_samples_parallel * n_channels:(i + 1) * num_samples_parallel * n_channels] = new_samples
+
+    elif state_dict['configuration']['model_class'] == 'VariationalAutoencoder':
+
+        # load data
+        dataloader = Dataloader(path=state_dict['configuration']['dataloader']['data'],
+                                kw_channel=kw_channel[0],
+                                kw_conditions=state_dict['configuration']['dataloader']['kw_conditions'],
+                                kw_time=state_dict['configuration']['dataloader']['kw_time'],
+                                norm_data=state_dict['configuration']['dataloader']['norm_data'],
+                                std_data=state_dict['configuration']['dataloader']['std_data'],
+                                diff_data=state_dict['configuration']['dataloader']['diff_data'])
+        dataset = dataloader.get_data()
+        dataset = DataLoader(dataset, batch_size=state_dict['configuration']['batch_size'], shuffle=True)
 
-    for i in range(num_sequences):
-        print(f"Generating sequence {i + 1}/{num_sequences}...")
-        # get input sequence by drawing randomly num_samples_parallel input sequences from dataset
-        if input_sequence_length > 0 and dataset:
-            input_sequence = dataset[np.random.randint(0, len(dataset), num_samples_parallel), :input_sequence_length, :]
-            labels_in = torch.cat((cond_labels, input_sequence), dim=1).float()
+        sequence_length = int(state_dict['configuration']['input_dim']/dataset.dataset.shape[-1])
+        channel_names = dataloader.channels
+        n_conditions = len(default_args['conditions'])
+        if condition:
+            cond_labels = torch.zeros((num_samples_total, state_dict['configuration']['input_dim'], len(default_args['conditions']))).to(device) + torch.tensor(condition).to(device)
         else:
-            labels_in = cond_labels
-            input_sequence = None
-        with torch.no_grad():
-            # draw latent variable
-            z = GANTrainer.sample_latent_variable(batch_size=num_samples_parallel, latent_dim=latent_dim,
-                                                  sequence_length=seq_len, device=device)
-            # concat with conditions and input sequence
-            z = torch.cat((z, labels_in), dim=-1).float().to(device)
-            # generate samples
-            samples = generator(z).cpu().numpy()
-        # if prediction case, concatenate input sequence and generated sequence
-        if input_sequence_length > 0 and input_sequence_length != sequence_length and input_sequence is not None:
-            samples = np.concatenate((input_sequence, samples), axis=1)
-        # reshape samples by concatenating over channels in incrementing channel name order
-        new_samples = np.zeros((num_samples_parallel * n_channels, n_conditions + 1 + sequence_length))
+            cond_labels = torch.zeros((num_samples_total, state_dict['configuration']['input_dim'], 1)).to(device) + torch.tensor([-1]).to(device)
+        cond_labels = cond_labels.to(device)
+
+        # load VAE
+        model = VariationalAutoencoder(input_dim=state_dict['configuration']['input_dim'],
+                                       hidden_dim=state_dict['configuration']['hidden_dim'],
+                                       encoded_dim=state_dict['configuration']['encoded_dim'],
+                                       activation=state_dict['configuration']['activation'],
+                                       device=device).to(device)
+
+        consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.')
+        model.load_state_dict(state_dict['model'])
+
+        # generate samples
+        samples = model.generate_samples(loader=dataset, condition=condition, num_samples=num_samples_total)
+ + # reconfigure samples to a 2D matrix for saving + new_samples = [] for j, channel in enumerate(channel_names): - padding = np.zeros((samples.shape[0], state_dict['configuration']['padding'])) - new_samples[j::n_channels] = np.concatenate((cond_labels.cpu().numpy()[:, 0, :], np.zeros((num_samples_parallel, 1)) + channel, np.concatenate((samples[:, :, j], padding), axis=1)), axis=-1) + new_samples.append(np.concatenate((cond_labels.cpu().numpy()[:, 0, :], np.zeros((num_samples_total, 1)) + channel, samples[:, 1:, j]), axis=-1)) # add samples to all_samples - all_samples[i * num_samples_parallel * n_channels:(i + 1) * num_samples_parallel * n_channels] = new_samples + all_samples = np.vstack(new_samples) + + else: + raise NotImplementedError(f"The model class {state_dict['configuration']['model_class']} is not recognized.") # save samples print("Saving samples...") - # check if column condition labels are given - if state_dict['configuration']['dataloader']['column_label'] and len( - state_dict['configuration']['dataloader']['column_label']) == n_conditions: - col_labels = state_dict['configuration']['dataloader']['column_label'] - else: - if n_conditions > 0: - col_labels = [f'Condition {i}' for i in range(n_conditions)] - else: - col_labels = [] - # check if channel label is given - if state_dict['configuration']['dataloader']['channel_label']: - channel_label = [state_dict['configuration']['dataloader']['channel_label']] - else: - channel_label = ['Channel'] - # get keyword for time step labels - if state_dict['configuration']['dataloader']['kw_timestep']: - kw_timestep = state_dict['configuration']['dataloader']['kw_timestep'] - else: - kw_timestep = 'Time' + # create time step labels - time_labels = [f'Time{i}' for i in range(sequence_length)] + time_labels = [f'{kw_time}{i}' for i in range(sequence_length)] # create dataframe - df = pd.DataFrame(all_samples, columns=[col_labels + channel_label + time_labels]) + df = pd.DataFrame(all_samples, columns=[col_labels + kw_channel + time_labels]) df.to_csv(path_samples, index=False) print("Generated samples were saved to " + path_samples) - - + if __name__ == '__main__': - # sys.argv = ["file=gan_1830ep.pt", "conditions=1"] main() diff --git a/get_gan_config.py b/get_gan_config.py index f5f6061..9612b95 100644 --- a/get_gan_config.py +++ b/get_gan_config.py @@ -7,7 +7,7 @@ def main(): default_args = system_inputs.parse_arguments(sys.argv, kw_dict=system_inputs.default_inputs_get_gan_config()) - file = default_args['path_file'] + file = default_args['model'] if file.split(os.path.sep)[0] == file: # use default path if no path is given @@ -21,5 +21,4 @@ def main(): print('\n') if __name__ == "__main__": - # sys.argv = ["path_file=trained_models\gan_ddp_8000ep_tanh.pt"] main() \ No newline at end of file diff --git a/helpers/dataloader.py b/helpers/dataloader.py index e20e60b..24e66c4 100644 --- a/helpers/dataloader.py +++ b/helpers/dataloader.py @@ -16,7 +16,7 @@ class Dataloader: def __init__(self, path=None, diff_data=False, std_data=False, norm_data=False, - kw_timestep='Time', col_label='', channel_label=None):#, multichannel: Union[bool, List[str]]=False): + kw_time='Time', kw_conditions='', kw_channel=None):#, multichannel: Union[bool, List[str]]=False): """Load data from csv as pandas dataframe and convert to tensor. 
Args: @@ -31,47 +31,36 @@ def __init__(self, path=None, # reshape and filter data based on channel specifications channels = [0] - if channel_label != '': - channels = df[channel_label].unique() + if kw_channel != '': + channels = df[kw_channel].unique() assert len(df)%len(channels)==0, f"Number of rows ({len(df)}) must be a multiple of number of channels ({len(channels)}).\nThis could be caused by missing data for some channels." - # if type(multichannel) == list: - # channels = [channel for channel in channels if channel in multichannel] - # # filter data for specified channels - # df = df.loc[df[channel_label].isin(multichannel)] n_channels = len(channels) self.channels = channels # get first column index of a time step - n_col_data = [index for index in range(len(df.columns)) if kw_timestep in df.columns[index]] + n_col_data = [index for index in range(len(df.columns)) if kw_time in df.columns[index]] - if not isinstance(col_label, list): - col_label = [col_label] + if not isinstance(kw_conditions, list): + kw_conditions = [kw_conditions] # Get labels and data dataset = torch.FloatTensor(df.to_numpy()[:, n_col_data]) - n_labels = len(col_label) if col_label[0] != '' else 0 + n_labels = len(kw_conditions) if kw_conditions[0] != '' else 0 labels = torch.zeros((dataset.shape[0], n_labels)) if n_labels: - for i, l in enumerate(col_label): + for i, l in enumerate(kw_conditions): labels[:, i] = torch.FloatTensor(df[l]) - # if multichannel: - # channel_labels = torch.FloatTensor(df[channel_label]) - if diff_data: # Diff of data dataset = dataset[:, 1:] - dataset[:, :-1] - # self.dataset_min = None - # self.dataset_max = None self.dataset_min = torch.min(dataset) self.dataset_max = torch.max(dataset) if norm_data: # Normalize data dataset = (dataset - self.dataset_min) / (self.dataset_max - self.dataset_min) - # self.dataset_mean = None - # self.dataset_std = None self.dataset_mean = dataset.mean(dim=0).unsqueeze(0) self.dataset_std = dataset.std(dim=0).unsqueeze(0) if std_data: @@ -80,7 +69,7 @@ def __init__(self, path=None, # reshape data to separate electrodes --> new shape: (trial, sequence, channel) if len(self.channels) > 1: - sort_index = df.sort_values(channel_label, kind="mergesort").index + sort_index = df.sort_values(kw_channel, kind="mergesort").index dataset = dataset[sort_index].contiguous().view(n_channels, dataset.shape[0]//n_channels, dataset.shape[1]).permute(1, 2, 0) labels = labels[sort_index].contiguous().view(n_channels, labels.shape[0]//n_channels, labels.shape[1]).permute(1, 2, 0) else: diff --git a/helpers/ddp_training.py b/helpers/ddp_training.py index 4c8c813..3409df5 100644 --- a/helpers/ddp_training.py +++ b/helpers/ddp_training.py @@ -44,11 +44,7 @@ def print_log(self, current_epoch, d_loss, g_loss): def manage_checkpoints(self, path_checkpoint: str, checkpoint_files: list, generator=None, discriminator=None, samples=None, update_history=False): if self.rank == 0: - # print(f'Rank {self.rank} is managing checkpoints.') super().manage_checkpoints(path_checkpoint, checkpoint_files, generator=self.generator.module, discriminator=self.discriminator.module, samples=samples, update_history=update_history) - # print(f'Rank {self.rank} finished managing checkpoints.') - # print(f'Rank {self.rank} reached barrier.') - # dist.barrier() def set_device(self, rank): self.rank = rank @@ -66,12 +62,10 @@ def set_ddp_framework(self): d_opt_state = self.discriminator_optimizer.state_dict() self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), - 
lr=self.learning_rate, betas=(self.b1, self.b2))
+ lr=self.g_lr, betas=(self.b1, self.b2))
self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(),
- lr=self.learning_rate, betas=(self.b1, self.b2))
+ lr=self.d_lr, betas=(self.b1, self.b2))
- self.generator_optimizer.load_state_dict(g_opt_state)
- self.discriminator_optimizer.load_state_dict(d_opt_state)
class AEDDPTrainer(trainer.AETrainer):
"""Trainer for Autoencoder."""
@@ -95,8 +89,6 @@ def save_checkpoint(self, path_checkpoint=None, model=None, update_history=False
# dist.barrier()
def print_log(self, current_epoch, train_loss, test_loss):
- # if self.rank == 0:
- # average the loss across all processes before printing
reduce_tensor = torch.tensor([train_loss, test_loss], dtype=torch.float32, device=self.device)
dist.all_reduce(reduce_tensor, op=dist.ReduceOp.SUM)
reduce_tensor /= self.world_size
@@ -105,11 +97,7 @@ def manage_checkpoints(self, path_checkpoint: str, checkpoint_files: list, model=None, update_history=False, samples=None):
if self.rank == 0:
- # print(f'Rank {self.rank} is managing checkpoints.')
super().manage_checkpoints(path_checkpoint, checkpoint_files, model=self.model.module, update_history=update_history, samples=samples)
- # print(f'Rank {self.rank} finished managing checkpoints.')
- # print(f'Rank {self.rank} reached barrier.')
- # dist.barrier()
def set_device(self, rank):
self.rank = rank
@@ -124,20 +112,21 @@ def set_ddp_framework(self):
# save optimizer state_dicts, init new ddp optimizer and load state_dicts
opt_state = self.optimizer.state_dict()
self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate)
- self.optimizer.load_state_dict(opt_state)
def run(rank, world_size, master_port, backend, trainer_ddp, opt):
- _setup(rank, world_size, master_port, backend)
- trainer_ddp = _setup_trainer(rank, trainer_ddp)
- _ddp_training(trainer_ddp, opt)
- dist.destroy_process_group()
+ try:
+ _setup(rank, world_size, master_port, backend)
+ trainer_ddp = _setup_trainer(rank, trainer_ddp)
+ _ddp_training(trainer_ddp, opt)
+ dist.destroy_process_group()
+ except Exception as error:
+ dist.destroy_process_group()
+ raise ValueError(f"Error in DDP training: {error}") from error
def _setup(rank, world_size, master_port, backend):
- # print(f"Initializing process group on rank {rank}")# on master port {self.master_port}.")
-
- os.environ['MASTER_ADDR'] = 'localhost' # '127.0.0.1'
+ os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = str(master_port)
# create default process group
@@ -152,21 +141,30 @@ def _setup_trainer(rank, trainer_ddp):
# construct DDP model
trainer_ddp.set_ddp_framework()
- # load checkpoint
- # if training.use_checkpoint:
- # training.load_checkpoint(training.path_checkpoint)
-
return trainer_ddp
def _ddp_training(trainer_ddp, opt):
# load data
if 'conditions' not in opt:
opt['conditions'] = ['']
- dataloader = Dataloader(opt['path_dataset'],
- kw_timestep=opt['kw_timestep'],
- col_label=opt['conditions'],
- norm_data=True,
- channel_label=opt['channel_label'])
+ if isinstance(trainer_ddp, GANDDPTrainer):
+ dataloader = Dataloader(opt['data'],
+ kw_time=opt['kw_time'],
+ kw_conditions=opt['kw_conditions'],
+ norm_data=opt['norm_data'],
+ std_data=opt['std_data'],
+ diff_data=opt['diff_data'],
+ kw_channel=opt['kw_channel'])
+ elif isinstance(trainer_ddp, AEDDPTrainer):
+ dataloader = Dataloader(opt['data'],
+ kw_time=opt['kw_time'],
+ norm_data=opt['norm_data'],
+ std_data=opt['std_data'], + diff_data=opt['diff_data'], + kw_channel=opt['kw_channel']) + else: + raise ValueError(f"Trainer type {type(trainer_ddp)} not supported.") + dataset = dataloader.get_data() opt['sequence_length'] = dataset.shape[2] - dataloader.labels.shape[2] @@ -192,8 +190,18 @@ def _ddp_training(trainer_ddp, opt): # save checkpoint if trainer_ddp.rank == 0: + + # save final models, optimizer states, generated samples, losses and configuration as final result + path = 'trained_models' timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'{model_prefix}_ddp_{trainer_ddp.epochs}ep_' + timestamp + '.pt' + if opt['save_name'] != '': + # check if .pt extension is already included in the save_name + if not opt['save_name'].endswith('.pt'): + opt['save_name'] += '.pt' + filename = opt['save_name'] + else: + filename = f'{model_prefix}_ddp_{trainer_ddp.epochs}ep_' + timestamp + '.pt' + if isinstance(trainer_ddp, GANDDPTrainer): trainer_ddp.save_checkpoint(path_checkpoint=os.path.join(path, filename), samples=gen_samples) elif isinstance(trainer_ddp, AEDDPTrainer): @@ -204,5 +212,5 @@ def _ddp_training(trainer_ddp, opt): samples.append(np.concatenate([inputs.unsqueeze(1).detach().cpu().numpy(), outputs.unsqueeze(1).detach().cpu().numpy()], axis=1)) trainer_ddp.save_checkpoint(path_checkpoint=os.path.join(path, filename), samples=samples) - print("GAN training finished.") + print("Model training finished.") print(f"Model states and generated samples saved to file {os.path.join(path, filename)}.") \ No newline at end of file diff --git a/helpers/ddp_training_classifier.py b/helpers/ddp_training_classifier.py deleted file mode 100644 index c71ee3a..0000000 --- a/helpers/ddp_training_classifier.py +++ /dev/null @@ -1,63 +0,0 @@ -import os -import torch -import torch.distributed as dist -from datetime import datetime, timedelta -from helpers.trainer_classifier import DDPTrainer - - -def run(rank, world_size, master_port, backend, trainer, train_data, train_labels, test_data, test_labels): - _setup(rank, world_size, master_port, backend) - trainer = _setup_trainer(rank, trainer) - _ddp_training(trainer, train_data, train_labels, test_data, test_labels) - dist.destroy_process_group() - - -def _setup(rank, world_size, master_port, backend): - # print(f"Initializing process group on rank {rank}")# on master port {self.master_port}.") - - os.environ['MASTER_ADDR'] = 'localhost' # '127.0.0.1' - os.environ['MASTER_PORT'] = str(master_port) - - # create default process group - dist.init_process_group(backend, rank=rank, world_size=world_size, timeout=timedelta(seconds=30)) - - -def _setup_trainer(rank, trainer: DDPTrainer): - # set device - trainer.set_device(rank) - print(f"Using device {trainer.device}.") - - # load checkpoint - # if trainer.use_checkpoint: - # trainer.load_checkpoint(trainer.path_checkpoint) - - # construct DDP model - trainer.set_ddp_framework() - - return trainer - - -def _ddp_training(trainer: DDPTrainer, train_data, train_labels, test_data, test_labels): - # take partition of dataset for each process - # start_index = int(len(train_data) / trainer.world_size * trainer.rank) - # end_index = int(len(train_data) / trainer.world_size * (trainer.rank + 1)) - # train_data = train_data[start_index:end_index] - # train_labels = train_labels[start_index:end_index] - # - # if trainer.batch_size > len(train_data): - # raise ValueError(f"Batch size {trainer.batch_size} is larger than the partition size {len(train_data)}.") - - # train - loss = trainer.train(train_data, 
train_labels, test_data, test_labels) - - # save checkpoint - if trainer.rank == 0: - test_dataset = torch.concat((test_labels, test_data), dim=1) - path = '../trained_classifier' - timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') - filename = f'classifier_ddp_{trainer.epochs}ep_' + timestamp + '.pt' - trainer.save_checkpoint(os.path.join(path, filename), test_dataset, loss) - - print("Classifier training finished.") - print("Model states, losses and test dataset saved to file: " - f"\n{filename}.") \ No newline at end of file diff --git a/helpers/initialize_gan.py b/helpers/initialize_gan.py index 6612552..2859c03 100644 --- a/helpers/initialize_gan.py +++ b/helpers/initialize_gan.py @@ -2,27 +2,20 @@ from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present from nn_architecture.ae_networks import TransformerAutoencoder, TransformerDoubleAutoencoder -from nn_architecture.models import FFGenerator, FFDiscriminator, TransformerGenerator, TransformerDiscriminator, TTSGenerator, TTSDiscriminator, DecoderGenerator, EncoderDiscriminator +from nn_architecture.models import TTSGenerator, TTSDiscriminator, DecoderGenerator, EncoderDiscriminator gan_architectures = { - 'FFGenerator': lambda latent_dim, channels, seq_len, hidden_dim, num_layers, dropout, activation, **kwargs: FFGenerator(latent_dim, channels, seq_len, hidden_dim, num_layers, dropout, activation), - 'FFDiscriminator': lambda channels, seq_len, hidden_dim, num_layers, dropout, **kwargs: FFDiscriminator(channels, seq_len, hidden_dim, num_layers, dropout), - 'TransformerGenerator': lambda latent_dim, channels, seq_len, hidden_dim, num_layers, num_heads, dropout, **kwargs: TransformerGenerator(latent_dim, channels, seq_len, hidden_dim, num_layers, num_heads, dropout), - 'TransformerDiscriminator': lambda channels, seq_len, hidden_dim, num_layers, num_heads, dropout, **kwargs: TransformerDiscriminator(channels, seq_len, 1, hidden_dim, num_layers, num_heads, dropout), 'TTSGenerator': lambda seq_len, hidden_dim, patch_size, channels, latent_dim, num_layers, num_heads, **kwargs: TTSGenerator(seq_len, patch_size, channels, 1, latent_dim, 10, num_layers, num_heads, 0.5, 0.5), 'TTSDiscriminator': lambda channels, hidden_dim, patch_size, seq_len, num_layers, **kwargs: TTSDiscriminator(channels, patch_size, 50, seq_len, num_layers, 1), } gan_types = { - 'ff': ['FFGenerator', 'FFDiscriminator'], - 'tr': ['TransformerGenerator', 'TransformerDiscriminator'], 'tts': ['TTSGenerator', 'TTSDiscriminator'], } -def init_gan(gan_type, - latent_dim_in, +def init_gan(latent_dim_in, channel_in_disc, n_channels, n_conditions, @@ -33,14 +26,12 @@ def init_gan(gan_type, activation='tanh', input_sequence_length=0, patch_size=-1, - path_autoencoder='', - padding=0, + autoencoder='', **kwargs, ): - if path_autoencoder == '': + if autoencoder == '': # no autoencoder defined -> use transformer GAN - generator = gan_architectures[gan_types[gan_type][0]]( - # FFGenerator inputs: latent_dim, channels, hidden_dim, num_layers, dropout, activation + generator = gan_architectures[gan_types['tts'][0]]( latent_dim=latent_dim_in, channels=n_channels, seq_len=sequence_length_generated, @@ -48,23 +39,18 @@ def init_gan(gan_type, num_layers=num_layers, dropout=0.1, activation=activation, - - # additional TransformerGenerator inputs: num_heads num_heads=4, # additional TTSGenerator inputs: patch_size patch_size=patch_size, ) - discriminator = gan_architectures[gan_types[gan_type][1]]( - # FFDiscriminator inputs: input_dim, hidden_dim, num_layers, dropout + 
discriminator = gan_architectures[gan_types['tts'][1]]( channels=channel_in_disc, hidden_dim=hidden_dim, num_layers=num_layers, dropout=0.1, seq_len=sequence_length_generated, - - # TransformerDiscriminator inputs: channels, n_classes, hidden_dim, num_layers, num_heads, dropout num_heads=4, # additional TTSDiscriminator inputs: patch_size @@ -74,7 +60,7 @@ def init_gan(gan_type, # initialize an autoencoder-GAN # initialize the autoencoder - ae_dict = torch.load(path_autoencoder, map_location=torch.device('cpu')) + ae_dict = torch.load(autoencoder, map_location=torch.device('cpu')) if ae_dict['configuration']['target'] == 'channels': ae_dict['configuration']['target'] = TransformerAutoencoder.TARGET_CHANNELS autoencoder = TransformerAutoencoder(**ae_dict['configuration']).to(device) @@ -93,22 +79,16 @@ def init_gan(gan_type, for param in autoencoder.parameters(): param.requires_grad = False autoencoder.eval() - - # if prediction or seq2seq, adjust latent_dim_in to encoded input size - if input_sequence_length != 0: - new_input_dim = autoencoder.output_dim if not hasattr(autoencoder, 'output_dim_2') else autoencoder.output_dim*autoencoder.output_dim_2 - latent_dim_in += new_input_dim - autoencoder.input_dim - + # adjust generator output_dim to match the output_dim of the autoencoder n_channels = autoencoder.output_dim if autoencoder.target in [autoencoder.TARGET_CHANNELS, autoencoder.TARGET_BOTH] else autoencoder.output_dim_2 - sequence_length_generated = autoencoder.output_dim_2+padding if autoencoder.target in [autoencoder.TARGET_CHANNELS, autoencoder.TARGET_BOTH] else autoencoder.output_dim + sequence_length_generated = autoencoder.output_dim_2 if autoencoder.target in [autoencoder.TARGET_CHANNELS, autoencoder.TARGET_BOTH] else autoencoder.output_dim # adjust discriminator input_dim to match the output_dim of the autoencoder channel_in_disc = n_channels + n_conditions generator = DecoderGenerator( - generator=gan_architectures[gan_types[gan_type][0]]( - # FFGenerator inputs: latent_dim, output_dim, hidden_dim, num_layers, dropout, activation + generator=gan_architectures[gan_types['tts'][0]]( latent_dim=latent_dim_in, channels=n_channels, seq_len=sequence_length_generated, @@ -116,8 +96,6 @@ def init_gan(gan_type, num_layers=num_layers, dropout=0.1, activation=activation, - - # TransformerGenerator inputs: latent_dim, channels, seq_len, hidden_dim, num_layers, num_heads, dropout num_heads=4, # additional TTSGenerator inputs: patch_size @@ -127,15 +105,12 @@ def init_gan(gan_type, ) discriminator = EncoderDiscriminator( - discriminator=gan_architectures[gan_types[gan_type][1]]( - # FFDiscriminator inputs: input_dim, hidden_dim, num_layers, dropout + discriminator=gan_architectures[gan_types['tts'][1]]( channels=channel_in_disc, hidden_dim=hidden_dim, num_layers=num_layers, dropout=0.1, seq_len=sequence_length_generated, - - # additional TransformerDiscriminator inputs: num_heads num_heads=4, # additional TTSDiscriminator inputs: patch_size diff --git a/helpers/system_inputs.py b/helpers/system_inputs.py index 2347962..9ee16df 100644 --- a/helpers/system_inputs.py +++ b/helpers/system_inputs.py @@ -120,7 +120,7 @@ def print_help(self): '\n\t\tpython gan_training_main.py load_checkpoint path_checkpoint="path/to/file.pt"') print( '4.\tIf you want to use a different dataset, you can use the following command:' - '\n\tpython gan_training_main.py path_dataset="path/to/file.csv"' + '\n\tpython gan_training_main.py data="path/to/file.csv"' '\n\tThe default dataset is 
"data/gansEEGTrainingData.csv"') print( '6.\tThe keyword "input_sequence_length" describes the length of a sequence taken as input for the generator.' @@ -136,6 +136,49 @@ def print_help(self): self.end_line() +class HelperVAE(Helper): + def __init__(self, kw_dict): + super().__init__(kw_dict) + + def print_help(self): + """Print help message for vae_training_main.py regarding special features.""" + super().print_help() + print( + '1.\tThe training works with two levels of checkpoint files:' + '\n\t1.1 During the training:' + '\n\t\tCheckpoints are saved every "sample_interval" batches as either "checkpoint_01.pt"' + '\n\t\tor "checkpoint_02.pt". These checkpoints are considered as low-level checkpoints since they are only ' + '\n\t\tnecessary in the case of training interruption. Hereby, they can be used to continue the training from ' + '\n\t\tthe most recent sample. To continue training, the most recent checkpoint file must be renamed to ' + '\n\t\t"checkpoint.pt".' + '\n\t\tFurther, these low-level checkpoints carry the generated samples for inference purposes.' + '\n\t1.2 After finishing the training:' + '\n\t\tA high-level checkpoint is saved as "checkpoint.pt", which is used to ' + '\n\t\tcontinue training in another session. This high-level checkpoint does not carry the generated samples.' + '\n\t\tTo continue training from this checkpoint file no further adjustments are necessary. ' + '\n\t\tSimply give the keyword "load_checkpoint" when calling the training process.' + '\n\t\tThe low-level checkpoints are deleted after creating the high-level checkpoint.' + '\n\t1.3 For inference purposes:' + '\n\t\tAnother dictionary is saved as "vae_{n_epochs}ep_{timestamp}.pt".' + '\n\t\tThis file contains everything the checkpoint file contains, plus the generated samples.') + print( + '2.\tUse "ddp" to activate distributed training. ' + '\n\tOnly if multiple GPUs are available for one node.' + '\n\tAll available GPUs are used for training.' + '\n\tEach GPU trains on the whole dataset. ' + '\n\tHence, the number of training epochs is multiplied by the number of GPUs') + print( + '3.\tIf you want to load a pre-trained VAE, you can use the following command:' + '\n\tpython gan_training_main.py load_checkpoint; The default file is "trained_models/checkpoint.pt"' + '\n\tIf you want to use another file, you can use the following command:' + '\n\t\tpython vae_training_main.py load_checkpoint path_checkpoint="path/to/file.pt"') + print( + '4.\tIf you want to use a different dataset, you can use the following command:' + '\n\tpython gan_training_main.py path_dataset="path/to/file.csv"' + '\n\tThe default dataset is "data/gansEEGTrainingData.csv"') + self.start_line() + self.end_line() + class HelperAutoencoder(Helper): def __init__(self, kw_dict): super().__init__(kw_dict) @@ -161,23 +204,23 @@ def __init__(self, kw_dict): def print_help(self): super().print_help() - print('1.\tEither the keyword "checkpoint" or "csv" must be given.' - '\n\t1.1 If the keyword "checkpoint" is given' - '\n\t\t"path_dataset" must point to a pt-file.' - '\n\t\t"path_dataset" may point to a GAN or an Autoencoder checkpoint file.' - '\n\t\tthe keyword "conditions" will be ignored since the conditions are taken from the checkpoint file.' - '\n\t\tthe keyword "channel_label" will be ignored since the samples are already sorted channel-wise.' - '\n\t\tthe samples will be drawn evenly from the saved samples to show the training progress.' 
- '\n\t1.2 If the keyword "csv" is given'
- '\n\t\t"path_dataset" must point to a csv-file.'
- '\n\t\tthe keyword "conditions" must be given to identify the condition column.'
+ print('1.\tEither the keyword "model" or "data" must be given.'
+ '\n\t1.1 If the keyword "model" is given'
+ '\n\t\t"model" must point to a pt-file.'
+ '\n\t\t"model" may point to a GAN or an Autoencoder checkpoint file.'
+ '\n\t\tthe keyword "kw_conditions" will be ignored since the conditions are taken from the checkpoint file.'
+ '\n\t\tthe keyword "kw_channel" will be ignored since the samples are already sorted channel-wise.'
+ '\n\t\tthe samples are drawn evenly from the saved samples to show the training progress.'
+ '\n\t1.2 If the keyword "data" is given'
+ '\n\t\t"data" must point to a csv-file.'
+ '\n\t\tthe keyword "kw_conditions" must be given to identify the condition column(s).'
'\n\t\tthe samples will be drawn randomly from the dataset.')
print('2.\tThe keyword "loss" works only with the keyword "checkpoint".')
print('3.\tThe keyword "average" averages either'
'\n\tall the samples (if no condition is given)'
'\n\talong each combination of conditions that is given. The conditions are shown in the legend.')
- print('4.\tWhen using the keywords "pca" or "tsne" the keyword "path_comp_dataset" must be defined.'
- '\n\tExcept for the case "checkpoint" is given and the checkpoint file is an Autoencoder file.'
+ print('4.\tWhen using the keywords "pca" or "tsne" the keyword "comp_data" must be defined.'
+ '\n\tExcept for the case "model" is given and the checkpoint file is an Autoencoder file.'
'\n\tIn this case, the comparison dataset (original data) is taken from the Autoencoder file directly.')
print('5.\tThe keyword "channel_plots" can be used to enhance the visualization.'
'\n\tThis way, the channels are shown in different subplots along the columns.') @@ -215,30 +258,22 @@ def print_help(self): def default_inputs_training_gan(): kw_dict = { 'ddp': [bool, 'Activate distributed training', False, 'Distributed training is active'], - 'load_checkpoint': [bool, 'Load a pre-trained GAN', False, 'Using a pre-trained GAN'], - 'channel_recovery': [bool, 'Training regime for channel recovery', False, 'Channel recovery training regime'], + 'seed': [bool, 'Set seed for reproducibility', None, 'Manual seed: '], 'n_epochs': [int, 'Number of epochs', 100, 'Number of epochs: '], 'batch_size': [int, 'Batch size', 128, 'Batch size: '], - 'input_sequence_length': [int, 'The generator makes predictions based on the input sequence length; If -1, no prediction but sequence-to-sequence-mapping of full sequence (not implemented yet)', 0, 'Input sequence length: '], 'sample_interval': [int, 'Interval of epochs between saving samples', 100, 'Sample interval: '], 'hidden_dim': [int, 'Hidden dimension of the GAN components', 16, 'Hidden dimension: '], 'num_layers': [int, 'Number of layers of the GAN components', 4, 'Number of layers: '], - 'patch_size': [int, 'Patch size of the divided sequence (only for TTS-GAN)', 20, 'Patch size: '], - 'learning_rate': [float, 'Learning rate of the GAN', 0.0001, 'Learning rate: '], - 'discriminator_lr': [float, 'If used, it overrides the general learning rate for the discriminator', None, 'Discriminator learning rate: '], - 'generator_lr': [float, 'If used, it overrides the general learning rate for the generator', None, 'Generator learning rate: '], - 'activation': [str, 'Activation function of the GAN components; Options: [relu, leakyrelu, sigmoid, tanh, linear]', 'tanh', 'Activation function: '], - 'type': [str, 'Type of the GAN; Options: [ff, tr, tts]', 'tr', 'GAN Type: '], - 'path_dataset': [str, 'Path to the dataset', os.path.join('data', 'gansEEGTrainingData.csv'), 'Dataset: '], - 'path_checkpoint': [str, 'Path to the checkpoint', os.path.join('trained_models', 'checkpoint.pt'), 'Checkpoint: '], - 'path_autoencoder': [str, 'Path to the autoencoder; Only usable with Autoencoder-GAN', '', 'Autoencoder checkpoint: '], - 'ddp_backend': [str, 'Backend for the DDP-Training; "nccl" for GPU; "gloo" for CPU;', 'nccl', 'DDP backend: '], - 'conditions': [str, '** Conditions to be used', '', 'Conditions: '], - 'kw_timestep': [str, 'Keyword for the time step of the dataset', 'Time', 'Keyword for the time step of the dataset: '], - 'channel_label': [str, 'Column name to detect used channels', '', 'Channel label: '], - 'lr_scheduler': [str, 'The learning rate scheduler to use; Options: [CyclicLR]', '', 'Learning rate scheduler: '], - 'scheduler_warmup': [int, 'Number of epochs before the scheduler will be initiated, if applicable', 0, 'Scheduler warmup: '], - 'scheduler_target': [str, 'Which part of the GAN to apply the learning rate scheduler, if applicable; Options: [discriminator, generator, both]', 'both', 'LR Scheduler Target: '] + 'patch_size': [int, 'Patch size of the divided sequence', 20, 'Patch size: '], + 'discriminator_lr': [float, 'Learning rate for the discriminator', 0.0001, 'Discriminator learning rate: '], + 'generator_lr': [float, 'Learning rate for the generator', 0.0001, 'Generator learning rate: '], + 'data': [str, 'Path to a dataset', os.path.join('data', 'gansEEGTrainingData.csv'), 'Dataset: '], + 'checkpoint': [str, 'Path to a pre-trained GAN', '', 'Using pre-trained GAN: '], + 'autoencoder': [str, 'Path to an autoencoder', '', 'Using 
autoencoder: '], + 'kw_conditions': [str, '** Conditions to be used', '', 'Conditions: '], + 'kw_time': [str, 'Keyword to detect the time steps of the dataset; e.g. if [Time1, Time2, ...] -> use Time', 'Time', 'Time label: '], + 'kw_channel': [str, 'Keyword to detect used channels', '', 'Channel label: '], + 'save_name': [str, 'Name to save model', '', 'Model save name: '], } return kw_dict @@ -248,17 +283,16 @@ def default_inputs_training_autoencoder(): kw_dict = { 'ddp': [bool, 'Activate distributed training', False, 'Distributed training is active'], 'load_checkpoint': [bool, 'Load a pre-trained AE', False, 'Loading a trained autoencoder model'], - 'ddp_backend': [str, 'Backend for the DDP-Training; "nccl" for GPU; "gloo" for CPU;', 'nccl', 'DDP backend: '], - 'path_dataset': [str, 'Path to the dataset', os.path.join('data', 'gansEEGTrainingData.csv'), 'Dataset: '], - 'path_checkpoint': [str, 'Path to a trained model to continue training', os.path.join('trained_ae', 'checkpoint.pt'), 'Checkpoint: '], - 'save_name': [str, 'Name to save model', None, 'Model save name: '], + 'seed': [bool, 'Set seed for reproducibility', None, 'Manual seed: '], + 'data': [str, 'Path to the dataset', os.path.join('data', 'gansEEGTrainingData.csv'), 'Dataset: '], + 'checkpoint': [str, 'Path to a pre-trained AE', '', 'Using pre-trained AE: '], + 'save_name': [str, 'Name to save model', '', 'Model save name: '], 'target': [str, 'Target dimension (channel, time, full) to encode; full is recommended for multi-channel data;', 'full', 'Target: '], - # 'conditions': [str, '** Conditions to be used', '', 'Conditions: '], - 'channel_label': [str, 'Column name to detect used channels', '', 'Channel label: '], - 'kw_timestep': [str, 'Keyword for the time step of the dataset', 'Time', 'Keyword for the time step of the dataset: '], - 'activation': [str, 'Activation function of the AE components; Options: [relu, leakyrelu, sigmoid, tanh, linear]', 'sigmoid', 'Activation function: '], + 'kw_time': [str, 'Keyword to detect the time steps of the dataset; e.g. if [Time1, Time2, ...] 
-> use Time', 'Time', 'Time label: '], + 'kw_channel': [str, 'Keyword to detect used channels', '', 'Channel label: '], + 'activation': [str, 'Activation function of the AE decoder; Options: [relu, leakyrelu, sigmoid, tanh, linear]', 'sigmoid', 'Activation function: '], 'channels_out': [int, 'Size of the encoded channels', 10, 'Encoded channels size: '], - 'timeseries_out': [int, 'Size of the encoded timeseries', 10, 'Encoded time series size: '], + 'time_out': [int, 'Size of the encoded timeseries', 10, 'Encoded time series size: '], 'n_epochs': [int, 'Number of epochs to train for', 100, 'Number of epochs: '], 'batch_size': [int, 'Batch size', 128, 'Batch size: '], 'sample_interval': [int, 'Interval of epochs between saving samples', 100, 'Sample interval: '], @@ -266,40 +300,35 @@ def default_inputs_training_autoencoder(): 'num_layers': [int, 'Number of layers of the transformer', 2, 'Number of layers: '], 'num_heads': [int, 'Number of heads of the transformer', 8, 'Number of heads: '], 'train_ratio': [float, 'Ratio of training data to total data', 0.8, 'Training ratio: '], - 'learning_rate': [float, 'Learning rate of the GAN', 0.0001, 'Learning rate: '], + 'learning_rate': [float, 'Learning rate of the AE', 0.0001, 'Learning rate: '], } return kw_dict -def default_inputs_training_classifier(): +def default_inputs_training_vae(): kw_dict = { - 'experiment': [bool, "Use experiment's samples as dataset", False, "Use experiment's samples as dataset"], - 'generated': [bool, 'Use generated samples as dataset', False, 'Use generated samples as dataset'], - 'ddp': [bool, 'Activate distributed training', False, 'Distributed training is active'], - 'testing': [bool, 'Only test. No training', False, 'Testing only'], - 'load_checkpoint': [bool, 'Load a pre-trained GAN', False, 'Using a pre-trained GAN'], - 'n_epochs': [int, 'Number of epochs', 100, 'Number of epochs: '], + 'load_checkpoint': [bool, 'Load a pre-trained AE', False, 'Loading a trained autoencoder model'], + 'data': [str, 'Path to the dataset', os.path.join('data', 'ganTrialElectrodeERP_p500_e1_SS100_Run00.csv'), 'Dataset: '], #TODO: REMOVE THIS + 'path_checkpoint': [str, 'Path to a trained model to continue training', os.path.join('trained_ae', 'checkpoint.pt'), 'Checkpoint: '], + 'save_name': [str, 'Name to save model', None, 'Model save name: '], + 'sample_interval': [int, 'Interval of epochs between saving samples', 100, 'Sample interval: '], + 'kw_channel': [str, 'Column name to detect used channels', '', 'Channel label: '], + 'kw_conditions': [str, '** Conditions to be used', 'Condition', 'Conditions: '], + 'kw_time': [str, 'Keyword for the time step of the dataset', 'Time', 'Keyword for the time step of the dataset: '], + # TODO: check for which components of VAE the parameter 'activation' applies + 'activation': [str, 'Activation function of the AE components; Options: [relu, leakyrelu, sigmoid, tanh, linear]', 'tanh', 'Activation function: '], + 'n_epochs': [int, 'Number of epochs to train for', 1000, 'Number of epochs: '], 'batch_size': [int, 'Batch size', 128, 'Batch size: '], - 'patch_size': [int, 'Patch size', 20, 'Patch size: '], - 'sequence_length': [int, 'Used length of the datasets sequences; If None, then the whole sequence is used', -1, 'Total sequence length: '], - 'sample_interval': [int, 'Interval of epochs between saving samples', 1000, 'Sample interval: '], - 'learning_rate': [float, 'Learning rate of the GAN', 0.0001, 'Learning rate: '], - 'path_dataset': [str, 'Path to the dataset', os.path.join('data', 
'ganAverageERP_len100.csv'), 'Dataset: '],
- 'path_test': [str, 'Path to the test dataset if using generated samples', 'None', 'Test dataset: '],
- 'path_checkpoint': [str, 'Path to the checkpoint', os.path.join('trained_classifier', 'checkpoint.pt'), 'Checkpoint: '],
- 'path_critic': [str, 'Path to the trained critic', os.path.join('trained_models', 'checkpoint.pt'), 'Critic: '],
- 'ddp_backend': [str, 'Backend for the DDP-Training; "nccl" for GPU; "gloo" for CPU;', 'nccl', 'DDP backend: '],
- 'conditions': [str, '** Conditions to be used', 'Condition', 'Conditions: '],
- 'kw_timestep_dataset': [str, 'Keyword for the time step of the dataset', 'Time', 'Keyword for the time step of the dataset: '],
- }
-
+ 'hidden_dim': [int, 'Hidden dimension of the network', 256, 'Hidden dimension: '],
+ 'encoded_dim': [int, 'Encoded dimension of mu and sigma', 25, 'Encoded dimension: '],
+ 'learning_rate': [float, 'Learning rate of the VAE', 3e-4, 'Learning rate: '],
+ 'kl_alpha': [float, 'Weight of the KL divergence in loss', 0.00005, 'KL alpha: '],
+ }
return kw_dict
def default_inputs_visualize():
kw_dict = {
- 'checkpoint': [bool, 'Use samples from checkpoint file', False, 'Using samples from checkpoint file'],
- 'csv': [bool, 'Use samples from csv-file', False, 'Using samples from csv-file'],
'loss': [bool, 'Plot training loss', False, 'Plotting training loss'],
'average': [bool, 'Average over all samples to get one averaged curve (per condition, if any is given)', False, 'Averaging over all samples'],
'pca': [bool, 'Use PCA to reduce the dimensionality of the data', False, 'Using PCA'],
@@ -307,15 +336,13 @@
'spectogram': [bool, 'Use a spectrogram to visualize the frequency distribution of the data', False, 'Using spectrogram'],
'fft': [bool, 'Use a FFT-histogram to visualize the frequency distribution of the data', False, 'Using FFT-Hist'],
'channel_plots': [bool, 'Plot each channel in a separate column', False, 'Plotting each channel in a separate column'],
- 'path_dataset': [str, 'File to be used', os.path.join('trained_models', 'checkpoint.pt'), 'File: '],
- 'path_comp_dataset': [str, 'Path to a csv dataset for comparison; comparison only for t-SNE or PCA;', os.path.join('data', 'ganAverageERP.csv'), 'Training dataset: '],
- 'kw_timestep': [str, 'Keyword for the time step of the dataset', 'Time', 'Keyword for the time step of the dataset: '],
- 'conditions': [str, '** Conditions to be used', '', 'Conditions: '],
- 'channel_label': [str, 'Column name to detect used channels', '', 'Channel label: '],
+ 'model': [str, 'Use samples from checkpoint file', '', 'Using samples from model/checkpoint file (.pt)'],
+ 'data': [str, 'Use samples from csv-file', '', 'Using samples from csv-file'],
+ 'comp_data': [str, 'Path to a csv dataset for comparison; comparison only for t-SNE or PCA;', os.path.join('data', 'ganAverageERP.csv'), 'Comparison dataset: '],
+ 'kw_conditions': [str, '** Conditions to be used', '', 'Conditions: '],
+ 'kw_time': [str, 'Keyword to detect the time steps of the dataset; e.g. if [Time1, Time2, ...] 
-> use Time', 'Time', 'Time label: '], + 'kw_channel': [str, 'Keyword to detect used channels', '', 'Channel label: '], 'n_samples': [int, 'Total number of samples to be plotted', 0, 'Number of plotted samples: '], - # 'n_subplots': [int, 'Number of samples in one plot', 8, 'Number of samples in one plot: '], - # 'starting_row': [int, 'Starting row of the dataset', 0, 'Starting to plot from row: '], - # 'save': [bool, 'Save the generated plots in the directory "plots" instead of showing them', False, 'Saving plots'], 'channel_index': [int, '**Index of the channel to be plotted; If -1, all channels will be plotted;', -1, 'Index of the channels to be plotted: '], 'tsne_perplexity': [int, 'Perplexity of t-SNE', 40, 'Perplexity of t-SNE: '], 'tsne_iterations': [int, 'Number of iterations of t-SNE', 1000, 'Number of iterations of t-SNE: '], @@ -326,7 +353,7 @@ def default_inputs_visualize(): def default_inputs_checkpoint_to_csv(): kw_dict = { - 'file': [str, 'File to be used', os.path.join('trained_models', 'checkpoint.pt'), 'File: '], + 'model': [str, 'File to be used', os.path.join('trained_models', 'checkpoint.pt'), 'Model: '], 'key': [str, '** Key of the checkpoint file to be saved; "losses" or "generated_samples"', 'generated_samples', 'Key: '], } @@ -335,14 +362,14 @@ def default_inputs_checkpoint_to_csv(): def default_inputs_generate_samples(): kw_dict = { - 'path_file': [str, 'File which contains the trained model and its configuration', os.path.join('trained_models', 'checkpoint.pt'), 'File: '], - 'path_samples': [str, 'File where to store the generated samples; If None, then checkpoint name is used', 'None', 'Saving generated samples to file: '], - 'kw_timestep_dataset': [str, 'Keyword for the time step of the dataset; to determine the sequence length', 'Time', 'Keyword for the time step of the dataset: '], + 'seed': [bool, 'Set seed for reproducibility', None, 'Manual seed: '], + 'model': [str, 'File which contains the trained model and its configuration', os.path.join('trained_models', 'checkpoint.pt'), 'File: '], + 'save_name': [str, 'File where to store the generated samples; If None, then checkpoint name is used', '', 'Saving generated samples to file: '], + 'kw_time': [str, 'Keyword for the time step of the dataset; to determine the sequence length', 'Time', 'Keyword for the time step of the dataset: '], 'sequence_length': [int, 'total sequence length of generated sample; if -1, then sequence length from training dataset', -1, 'Total sequence length of a generated sample: '], 'num_samples_total': [int, 'total number of generated samples', 1000, 'Total number of generated samples: '], 'num_samples_parallel': [int, 'number of samples generated in parallel', 50, 'Number of samples generated in parallel: '], 'conditions': [int, '** Specific numeric conditions', None, 'Conditions: '], - 'average': [int, 'Average over n latent variables to get an averaged one', 1, 'Average over n latent variables: '], } return kw_dict @@ -350,7 +377,7 @@ def default_inputs_generate_samples(): def default_inputs_get_gan_config(): kw_dict = { - 'path_file': [str, 'File to be used', os.path.join('trained_models', 'checkpoint.pt'), 'File: '], + 'model': [str, 'File to be used', os.path.join('trained_models', 'checkpoint.pt'), 'File: '], } return kw_dict @@ -401,6 +428,9 @@ def parse_arguments(arguments, kw_dict=None, file=None): elif file == 'autoencoder_training_main.py': system_args = default_inputs_training_autoencoder() helper = HelperAutoencoder(system_args) + elif file == 'vae_training_main.py': + 
system_args = default_inputs_training_vae()
+ helper = HelperVAE(system_args)
else:
raise ValueError(f'File {file} not recognized.')
else:
diff --git a/helpers/trainer.py b/helpers/trainer.py
index a40d1de..672c637 100644
--- a/helpers/trainer.py
+++ b/helpers/trainer.py
@@ -11,7 +11,7 @@
import nn_architecture.losses as losses
from nn_architecture.losses import WassersteinGradientPenaltyLoss as Loss
-from nn_architecture.models import TransformerGenerator, TransformerDiscriminator, FFGenerator, FFDiscriminator, TTSGenerator, TTSDiscriminator, DecoderGenerator, EncoderDiscriminator
+from nn_architecture.models import DecoderGenerator, EncoderDiscriminator
class Trainer:
@@ -47,28 +47,21 @@ def __init__(self, generator, discriminator, opt):
self.device = opt['device'] if 'device' in opt else 'cuda' if torch.cuda.is_available() else 'cpu'
self.sequence_length = opt['sequence_length'] if 'sequence_length' in opt else 0
self.input_sequence_length = opt['input_sequence_length'] if 'input_sequence_length' in opt else 0
- self.sequence_length_generated = self.sequence_length-self.input_sequence_length if self.sequence_length != self.input_sequence_length else self.sequence_length
+ self.sequence_length_generated = self.sequence_length
self.batch_size = opt['batch_size'] if 'batch_size' in opt else 32
self.epochs = opt['n_epochs'] if 'n_epochs' in opt else 10
- self.use_checkpoint = opt['load_checkpoint'] if 'load_checkpoint' in opt else False
- self.path_checkpoint = opt['path_checkpoint'] if 'path_checkpoint' in opt else None
self.latent_dim = opt['latent_dim'] if 'latent_dim' in opt else 10
self.critic_iterations = opt['critic_iterations'] if 'critic_iterations' in opt else 5
self.lambda_gp = opt['lambda_gp'] if 'lambda_gp' in opt else 10
self.sample_interval = opt['sample_interval'] if 'sample_interval' in opt else 100
- self.learning_rate = opt['learning_rate'] if 'learning_rate' in opt else 0.0001
- self.discriminator_lr = opt['discriminator_lr']
- self.generator_lr = opt['generator_lr']
+ self.d_lr = opt['discriminator_lr'] if 'discriminator_lr' in opt else 0.0001
+ self.g_lr = opt['generator_lr'] if 'generator_lr' in opt else 0.0001
self.n_conditions = opt['n_conditions'] if 'n_conditions' in opt else 0
self.n_channels = opt['n_channels'] if 'n_channels' in opt else 1
self.channel_names = opt['channel_names'] if 'channel_names' in opt else list(range(0, self.n_channels))
self.b1, self.b2 = 0, 0.9 # alternative values: .5, 0.999
self.rank = 0 # Device: cuda:0, cuda:1, ... 
--> Device: cuda:rank - self.lr_scheduler = opt['lr_scheduler'] - self.scheduler_warmup = opt['scheduler_warmup'] - self.scheduler_target = opt['scheduler_target'] self.start_time = time.time() - self.padding = opt['padding'] self.generator = generator self.discriminator = discriminator @@ -79,23 +72,10 @@ def __init__(self, generator, discriminator, opt): self.generator.to(self.device) self.discriminator.to(self.device) - self.d_lr = self.discriminator_lr if self.discriminator_lr is not None else self.learning_rate - self.g_lr = self.generator_lr if self.generator_lr is not None else self.learning_rate - self.generator_optimizer = torch.optim.Adam(self.generator.parameters(), lr=self.g_lr, betas=(self.b1, self.b2)) self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=self.d_lr, betas=(self.b1, self.b2)) - - self.generator_scheduler = None - self.discriminator_scheduler = None - if self.lr_scheduler.lower() == 'cycliclr': - self.generator_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=self.generator_optimizer, base_lr=self.g_lr*.1, max_lr=self.g_lr, step_size_up=500, mode='exp_range', cycle_momentum=False, verbose=False) - self.discriminator_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=self.discriminator_optimizer, base_lr=self.d_lr*.1, max_lr=self.d_lr, step_size_up=500, mode='exp_range', cycle_momentum=False, verbose=False) - if self.scheduler_target.lower() == 'generator': - self.discriminator_scheduler = None - if self.scheduler_target.lower() == 'discriminator': - self.generator_scheduler = None self.loss = Loss() if isinstance(self.loss, losses.WassersteinGradientPenaltyLoss): @@ -112,20 +92,18 @@ def __init__(self, generator, discriminator, opt): 'device': self.device, 'generator_class': generator_class, 'discriminator_class': discriminator_class, + 'model_class': 'GAN', 'sequence_length': self.sequence_length, 'sequence_length_generated': self.sequence_length_generated, - 'input_sequence_length': self.input_sequence_length, 'num_layers': opt['num_layers'], 'hidden_dim': opt['hidden_dim'], - 'activation': opt['activation'], 'latent_dim': self.latent_dim, 'batch_size': self.batch_size, 'epochs': self.epochs, 'trained_epochs': self.trained_epochs, 'sample_interval': self.sample_interval, - 'learning_rate': self.learning_rate, - 'discriminator_lr': self.discriminator_lr, - 'generator_lr': self.generator_lr, + 'discriminator_lr': self.d_lr, + 'generator_lr': self.g_lr, 'n_conditions': self.n_conditions, 'latent_dim': self.latent_dim, 'critic_iterations': self.critic_iterations, @@ -133,22 +111,23 @@ def __init__(self, generator, discriminator, opt): 'patch_size': opt['patch_size'] if 'patch_size' in opt else None, 'b1': self.b1, 'b2': self.b2, - 'path_dataset': opt['path_dataset'] if 'path_dataset' in opt else None, - 'path_autoencoder': opt['path_autoencoder'] if 'path_autoencoder' in opt else None, + 'data': opt['data'] if 'data' in opt else None, + 'autoencoder': opt['autoencoder'] if 'autoencoder' in opt else None, 'n_channels': self.n_channels, 'channel_names': self.channel_names, - 'lr_scheduler': self.lr_scheduler, - 'scheduler_warmup': self.scheduler_warmup, - 'scheduler_target': self.scheduler_target, - 'padding': self.padding, + 'seed': opt['seed'], + 'kw_conditions': opt['kw_conditions'] if 'kw_conditions' in opt else None, + 'kw_time': opt['kw_time'] if 'kw_time' in opt else None, + 'kw_channel': opt['kw_channel'] if 'kw_channel' in opt else None, + 'save_name': opt['save_name'] if 'save_name' in opt else '', 'dataloader': { 
- 'path_dataset': opt['path_dataset'] if 'path_dataset' in opt else None, - 'column_label': opt['conditions'] if 'conditions' in opt else None, + 'data': opt['data'] if 'data' in opt else None, + 'kw_conditions': opt['kw_conditions'] if 'kw_conditions' in opt else None, 'diff_data': opt['diff_data'] if 'diff_data' in opt else None, 'std_data': opt['std_data'] if 'std_data' in opt else None, 'norm_data': opt['norm_data'] if 'norm_data' in opt else None, - 'kw_timestep': opt['kw_timestep'] if 'kw_timestep' in opt else None, - 'channel_label': opt['channel_label'] if 'channel_label' in opt else None, + 'kw_time': opt['kw_time'] if 'kw_time' in opt else None, + 'kw_channel': opt['kw_channel'] if 'kw_channel' in opt else None, }, 'history': opt['history'] if 'history' in opt else {}, } @@ -158,8 +137,6 @@ def training(self, dataset: DataLoader): gen_samples = [] # checkpoint file settings; toggle between two checkpoints to avoid corrupted file if training is interrupted path_checkpoint = 'trained_models' - if not os.path.exists(path_checkpoint): - os.makedirs(path_checkpoint) trigger_checkpoint_01 = True checkpoint_01_file = 'checkpoint_01.pt' checkpoint_02_file = 'checkpoint_02.pt' @@ -168,52 +145,49 @@ def training(self, dataset: DataLoader): batch = None loop = tqdm(range(self.epochs)) - for epoch in loop: - # for-loop for number of batch_size entries in sessions - i_batch = 0 - d_loss_batch = 0 - g_loss_batch = 0 - for batch in dataset: - # draw batch_size samples from sessions - data = batch[:, self.n_conditions:].to(self.device) - data_labels = batch[:, :self.n_conditions, 0].unsqueeze(1).to(self.device) - - # update generator every n iterations as suggested in paper - if i_batch % self.critic_iterations == 0: - train_generator = True - else: - train_generator = False - - d_loss, g_loss, gen_samples_batch = self.batch_train(data, data_labels, train_generator) - - d_loss_batch += d_loss - g_loss_batch += g_loss - i_batch += 1 - self.d_losses.append(d_loss_batch/i_batch) - self.g_losses.append(g_loss_batch/i_batch) - - if self.scheduler_warmup < epoch: - if self.lr_scheduler.lower() == 'cycliclr': - if self.scheduler_target.lower() == 'generator' or self.scheduler_target.lower() == 'both': - self.generator_scheduler.step() - if self.scheduler_target.lower() == 'discriminator' or self.scheduler_target.lower() == 'both': - self.discriminator_scheduler.step() - - # Save a checkpoint of the trained GAN and the generated samples every sample interval - if epoch % self.sample_interval == 0: - gen_samples.append(gen_samples_batch[np.random.randint(0, len(batch))].detach().cpu().numpy()) - # save models and optimizer states as checkpoints - # toggle between checkpoint files to avoid corrupted file during training - if trigger_checkpoint_01: - self.save_checkpoint(os.path.join(path_checkpoint, checkpoint_01_file), samples=gen_samples) - trigger_checkpoint_01 = False - else: - self.save_checkpoint(os.path.join(path_checkpoint, checkpoint_02_file), samples=gen_samples) - trigger_checkpoint_01 = True + # try/except for KeyboardInterrupt --> Abort training and save model + try: + for epoch in loop: + # for-loop for number of batch_size entries in sessions + i_batch = 0 + d_loss_batch = 0 + g_loss_batch = 0 + for batch in dataset: + # draw batch_size samples from sessions + data = batch[:, self.n_conditions:].to(self.device) + data_labels = batch[:, :self.n_conditions, 0].unsqueeze(1).to(self.device) + + # update generator every n iterations as suggested in paper + if i_batch % self.critic_iterations == 
0: + train_generator = True + else: + train_generator = False + + d_loss, g_loss, gen_samples_batch = self.batch_train(data, data_labels, train_generator) - self.trained_epochs += 1 - #self.print_log(epoch + 1, d_loss_batch/i_batch, g_loss_batch/i_batch) - loop.set_postfix_str(f"D LOSS: {np.round(d_loss_batch/i_batch,6)}, G LOSS: {np.round(g_loss_batch/i_batch,6)}") + d_loss_batch += d_loss + g_loss_batch += g_loss + i_batch += 1 + self.d_losses.append(d_loss_batch/i_batch) + self.g_losses.append(g_loss_batch/i_batch) + + # Save a checkpoint of the trained GAN and the generated samples every sample interval + if epoch % self.sample_interval == 0: + gen_samples.append(gen_samples_batch[np.random.randint(0, len(batch))].detach().cpu().numpy()) + # save models and optimizer states as checkpoints + # toggle between checkpoint files to avoid corrupted file during training + if trigger_checkpoint_01: + self.save_checkpoint(os.path.join(path_checkpoint, checkpoint_01_file), samples=gen_samples) + trigger_checkpoint_01 = False + else: + self.save_checkpoint(os.path.join(path_checkpoint, checkpoint_02_file), samples=gen_samples) + trigger_checkpoint_01 = True + + self.trained_epochs += 1 + loop.set_postfix_str(f"D LOSS: {np.round(d_loss_batch/i_batch,6)}, G LOSS: {np.round(g_loss_batch/i_batch,6)}") + except KeyboardInterrupt: + # save model at KeyboardInterrupt + print("Keyboard interrupt detected.\nCancel training and continue with further operations.") self.manage_checkpoints(path_checkpoint, [checkpoint_01_file, checkpoint_02_file], samples=gen_samples, update_history=True) @@ -243,7 +217,6 @@ def batch_train(self, data, data_labels, train_generator): # if self.generator is instance of EncoderGenerator encode gen_cond_data to speed up training if isinstance(self.generator, DecoderGenerator) and self.input_sequence_length != 0: - # pad gen_cond_data to match input sequence length of Encoder gen_cond_data_orig = gen_cond_data gen_cond_data = torch.cat((torch.zeros((batch_size, self.sequence_length - self.input_sequence_length, self.n_channels)).to(self.device), gen_cond_data), dim=1) gen_cond_data = self.generator.decoder.decode(gen_cond_data) @@ -276,8 +249,6 @@ def batch_train(self, data, data_labels, train_generator): fake_data = self.make_fake_data(gen_imgs, data_labels, gen_cond_data) # Compute loss/validity of generated data and update generator - pad = torch.zeros((fake_data.shape[0], self.padding, fake_data.shape[-1])) - fake_data = torch.cat((fake_data, pad.to(self.device)), dim=1) validity = self.discriminator(fake_data) g_loss = self.loss.generator(validity) self.generator_optimizer.zero_grad() @@ -321,11 +292,9 @@ def batch_train(self, data, data_labels, train_generator): if decode_imgs: if not hasattr(self.generator, 'module'): fake_input = fake_data[:,:,:self.generator.channels].reshape(-1, self.generator.seq_len, self.generator.channels) - fake_input = fake_input[:,:-self.padding,:] if self.padding > 0 else fake_input gen_samples = self.generator.decoder.decode(fake_input) else: fake_input = fake_data[:,:,:self.generator.module.channels].reshape(-1, self.generator.module.seq_len, self.generator.module.channels) - fake_input = fake_input[:,:-self.padding,:] if self.padding > 0 else fake_input gen_samples = self.generator.module.decoder.decode(fake_input) # concatenate gen_cond_data_orig with decoded fake_data @@ -348,8 +317,6 @@ def batch_train(self, data, data_labels, train_generator): real_data = self.discriminator.module.encoder.encode(data) if 
isinstance(self.discriminator.module, EncoderDiscriminator) and not self.discriminator.module.encode else data real_data = self.make_fake_data(real_data, disc_labels) - padding = torch.zeros((real_data.shape[0], self.padding, real_data.shape[-1])) - real_data = torch.cat((real_data, padding.to(self.device)), dim=1) # Loss for real and generated samples real_data.requires_grad = True @@ -378,7 +345,7 @@ def save_checkpoint(self, path_checkpoint=None, samples=None, generator=None, di if update_history: self.configuration['trained_epochs'] = self.trained_epochs - self.configuration['history']['trained_epochs'] = [self.trained_epochs] + self.configuration['history']['trained_epochs'] = self.configuration['history']['trained_epochs'] + [self.trained_epochs] self.configuration['train_time'] = time.strftime('%H:%M:%S', time.gmtime(time.time() - self.start_time)) state_dict = { @@ -386,8 +353,6 @@ def save_checkpoint(self, path_checkpoint=None, samples=None, generator=None, di 'discriminator': discriminator.state_dict(), 'generator_optimizer': self.generator_optimizer.state_dict(), 'discriminator_optimizer': self.discriminator_optimizer.state_dict(), - 'generator_scheduler': None if self.generator_scheduler is None else self.generator_scheduler.state_dict(), - 'discriminator_scheduler': None if self.discriminator_scheduler is None else self.discriminator_scheduler.state_dict(), 'generator_loss': self.g_losses, 'discriminator_loss': self.d_losses, 'samples': samples, @@ -408,33 +373,6 @@ def load_checkpoint(self, path_checkpoint): self.discriminator.load_state_dict(state_dict['discriminator']) self.generator_optimizer.load_state_dict(state_dict['generator_optimizer']) self.discriminator_optimizer.load_state_dict(state_dict['discriminator_optimizer']) - - ''' - self.generator_scheduler = None - self.discriminator_scheduler = None - if self.lr_scheduler.lower() == 'cycliclr': - self.generator_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=self.generator_optimizer, base_lr=self.g_lr*.1, max_lr=self.g_lr, step_size_up=500, mode='exp_range', cycle_momentum=False, verbose=False) - self.discriminator_scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer=self.discriminator_optimizer, base_lr=self.d_lr*.1, max_lr=self.d_lr, step_size_up=500, mode='exp_range', cycle_momentum=False, verbose=False) - elif self.lr_scheduler.lower() == 'reducelronplateau': - self.generator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.generator_optimizer, factor=0.1, cooldown=50, verbose=False) - self.discriminator_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.discriminator_optimizer, factor=0.1, cooldown=50, verbose=False) - if self.scheduler_target.lower() == 'generator': - self.discriminator_scheduler = None - if self.scheduler_target.lower() == 'discriminator': - self.generator_scheduler = None - ''' - - if self.lr_scheduler != '' and state_dict['configuration']['lr_scheduler'] != '': - if self.scheduler_target == 'generator' or self.scheduler_target == 'both': - self.generator_scheduler.load_state_dict(state_dict['generator_scheduler']) - for i in range(len(self.generator_optimizer.param_groups)): - self.generator_optimizer.param_groups[i]['lr'] = self.generator_scheduler.get_last_lr()[0] - - if self.scheduler_target == 'discriminator' or self.scheduler_target == 'both': - self.discriminator_scheduler.load_state_dict(state_dict['discriminator_scheduler']) - for i in range(len(self.generator_optimizer.param_groups)): - 
self.discriminator_optimizer.param_groups[i]['lr'] = self.discriminator_scheduler.get_last_lr()[0] - print(f"Device {self.device}:{self.rank}: Using pretrained GAN.") else: Warning("No checkpoint-file found. Using random initialization.") @@ -498,8 +436,7 @@ def make_fake_data(self, gen_imgs, data_labels, condition_data=None): class AETrainer(Trainer): - """Trainer for conditional Wasserstein-GAN with gradient penalty. - Source: https://arxiv.org/pdf/1704.00028.pdf""" + """Trainer for Autoencoder.""" def __init__(self, model, opt): # training configuration @@ -510,6 +447,9 @@ def __init__(self, model, opt): self.sample_interval = opt['sample_interval'] if 'sample_interval' in opt else 100 self.learning_rate = opt['learning_rate'] if 'learning_rate' in opt else 0.0001 self.rank = 0 # Device: cuda:0, cuda:1, ... --> Device: cuda:rank + self.training_levels = opt['training_levels'] + self.training_level = opt['training_level'] + self.start_time = time.time() # model self.model = model @@ -532,47 +472,51 @@ def __init__(self, model, opt): 'sample_interval': self.sample_interval, 'learning_rate': self.learning_rate, 'hidden_dim': opt['hidden_dim'], - 'path_dataset': opt['path_dataset'] if 'path_dataset' in opt else None, - 'path_checkpoint': opt['path_checkpoint'] if 'path_checkpoint' in opt else None, + 'data': opt['data'] if 'data' in opt else None, + 'checkpoint': opt['checkpoint'] if 'checkpoint' in opt else None, 'channels_in': opt['channels_in'], - 'timeseries_in': opt['timeseries_in'], - 'timeseries_out': opt['timeseries_out'] if 'timeseries_out' in opt else None, + 'time_in': opt['time_in'], + 'time_out': opt['time_out'] if 'time_out' in opt else None, 'channels_out': opt['channels_out'] if 'channels_out' in opt else None, 'sequence_length': opt['sequence_length'], 'target': opt['target'] if 'target' in opt else None, - # 'conditions': opt['conditions'] if 'conditions' in opt else None, - 'channel_label': opt['channel_label'] if 'channel_label' in opt else None, 'trained_epochs': self.trained_epochs, 'input_dim': opt['input_dim'], 'output_dim': opt['output_dim'], 'output_dim_2': opt['output_dim_2'], 'num_layers': opt['num_layers'], 'num_heads': opt['num_heads'], - 'activation': opt['activation'], + 'seed': opt['seed'], + 'kw_time': opt['kw_time'] if 'kw_time' in opt else None, + 'kw_channel': opt['kw_channel'] if 'kw_channel' in opt else None, + 'kw_conditions': opt['kw_conditions'] if 'kw_conditions' in opt else None, + 'save_name': opt['save_name'] if 'save_name' in opt else '', 'dataloader': { - 'path_dataset': opt['path_dataset'] if 'path_dataset' in opt else None, - 'col_label': opt['conditions'] if 'conditions' in opt else None, + 'data': opt['data'] if 'data' in opt else None, 'diff_data': opt['diff_data'] if 'diff_data' in opt else None, 'std_data': opt['std_data'] if 'std_data' in opt else None, 'norm_data': opt['norm_data'] if 'norm_data' in opt else None, - 'kw_timestep': opt['kw_timestep'] if 'kw_timestep' in opt else None, - 'channel_label': opt['channel_label'] if 'channel_label' in opt else None, + 'kw_conditions': opt['kw_conditions'] if 'kw_conditions' in opt else None, + 'kw_time': opt['kw_time'] if 'kw_time' in opt else None, + 'kw_channel': opt['kw_channel'] if 'kw_channel' in opt else None, }, 'history': opt['history'] if 'history' in opt else None, } def training(self, train_data, test_data): - try: - path_checkpoint = 'trained_ae' - if not os.path.exists(path_checkpoint): - os.makedirs(path_checkpoint) - trigger_checkpoint_01 = True - checkpoint_01_file = 
'checkpoint_01.pt' - checkpoint_02_file = 'checkpoint_02.pt' + path_checkpoint = 'trained_ae' + if not os.path.exists(path_checkpoint): + os.makedirs(path_checkpoint) + trigger_checkpoint_01 = True + checkpoint_01_file = 'checkpoint_01.pt' + checkpoint_02_file = 'checkpoint_02.pt' - samples = [] + samples = [] - loop = tqdm(range(self.epochs)) + loop = tqdm(range(self.epochs)) + + # try/except for KeyboardInterrupt --> Abort training and save model + try: for epoch in loop: train_loss, test_loss, sample = self.batch_train(train_data, test_data) self.train_loss.append(train_loss) @@ -583,7 +527,7 @@ def training(self, train_data, test_data): if len(sample) > 0: samples.append(sample) - # Save a checkpoint of the trained GAN and the generated samples every sample interval + # Save a checkpoint of the trained AE and the generated samples every sample interval if epoch % self.sample_interval == 0: # save models and optimizer states as checkpoints # toggle between checkpoint files to avoid corrupted file during training @@ -595,15 +539,12 @@ def training(self, train_data, test_data): trigger_checkpoint_01 = True self.trained_epochs += 1 - #self.print_log(epoch + 1, train_loss, test_loss) - - self.manage_checkpoints(path_checkpoint, [checkpoint_01_file, checkpoint_02_file], update_history=True, samples=samples) - return samples - except KeyboardInterrupt: # save model at KeyboardInterrupt - print("Keyboard interrupt detected.\nSaving checkpoint...") - self.save_checkpoint(update_history=True, samples=samples) + print("Keyboard interrupt detected.\nCancel training and continue with further operations.") + + self.manage_checkpoints(path_checkpoint, [checkpoint_01_file, checkpoint_02_file], update_history=True, samples=samples) + return samples def batch_train(self, train_data, test_data): train_loss = self.train_model(train_data) @@ -615,8 +556,6 @@ def train_model(self, data): total_loss = 0 for batch in data: self.optimizer.zero_grad() - # inputs = nn.BatchNorm1d(batch.shape[-1])(batch.float().permute(0, 2, 1)).permute(0, 2, 1) - # inputs = filter(inputs.detach().cpu().numpy(), win_len=random.randint(29, 50), dtype=torch.Tensor) inputs = batch.float().to(self.model.device) outputs = self.model(inputs) loss = self.loss(outputs, inputs) @@ -636,9 +575,9 @@ def test_model(self, data): loss = self.loss(outputs, inputs) total_loss += loss.item() if self.trained_epochs % self.sample_interval == 0: - samples.append(np.concatenate([inputs.unsqueeze(1).detach().cpu().numpy(), outputs.unsqueeze(1).detach().cpu().numpy()], axis=1)) + samples.append(np.stack([inputs.cpu().numpy(), outputs.cpu().numpy()], axis=1)) if len(samples) > 0: - samples = np.concatenate(samples, axis=0)[np.random.randint(0, len(samples))] + samples = np.concatenate(samples, axis=0)[np.random.randint(0, len(samples))].reshape(1, *samples[0].shape[1:]) return total_loss / len(data), samples def save_checkpoint(self, path_checkpoint=None, model=None, update_history=False, samples=None): @@ -654,7 +593,8 @@ def save_checkpoint(self, path_checkpoint=None, model=None, update_history=False if update_history: self.configuration['trained_epochs'] = self.trained_epochs self.configuration['history']['trained_epochs'] = self.configuration['history']['trained_epochs'] + [self.trained_epochs] - + self.configuration['train_time'] = time.strftime('%H:%M:%S', time.gmtime(time.time() - self.start_time)) + checkpoint_dict = { 'model': model.state_dict(), 'optimizer': self.optimizer.state_dict(), @@ -665,6 +605,202 @@ def save_checkpoint(self, 
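The restructured training() above wraps only the epoch loop in try/except, so a Ctrl+C no longer returns early: both the normal and the interrupted path fall through to the same checkpoint consolidation. The control flow as a minimal sketch (names and the loop body are illustrative):

def training(n_epochs: int) -> list:
    samples = []
    try:
        for epoch in range(n_epochs):
            samples.append(epoch)  # stands in for one epoch of training
    except KeyboardInterrupt:
        print("Keyboard interrupt detected.\nCancel training and continue with further operations.")
    # reached after normal completion *and* after an interrupt
    return samples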
path_checkpoint=None, model=None, update_history=False 'configuration': self.configuration, } + if self.training_levels == 2 and self.training_level == 2: + checkpoint_dict['model_1'] = self.model1_states['model'] + checkpoint_dict['model_1_optimizer'] = self.model1_states['optimizer'] + + if not (self.training_levels == 2 and self.training_level == 1) or 'checkpoint.pt' not in path_checkpoint: + torch.save(checkpoint_dict, path_checkpoint) + + def load_checkpoint(self, path_checkpoint): + if os.path.isfile(path_checkpoint): + # load state_dicts + state_dict = torch.load(path_checkpoint, map_location=self.device) + consume_prefix_in_state_dict_if_present(state_dict['model'], 'module.') + if self.training_levels == 2 and self.training_level == 1: + self.model.load_state_dict(state_dict['model_1']) + self.optimizer.load_state_dict(state_dict['model_1_optimizer']) + else: + self.model.load_state_dict(state_dict['model']) + self.optimizer.load_state_dict(state_dict['optimizer']) + else: + raise FileNotFoundError(f"Checkpoint-file {path_checkpoint} was not found.") + + def manage_checkpoints(self, path_checkpoint: str, checkpoint_files: list, model=None, update_history=False, samples=None): + """if training was successful delete the sub-checkpoint files and save the most current state as checkpoint, + but without generated samples to keep memory usage low. Checkpoint should be used for further training only. + Therefore, there's no need for the saved samples.""" + + print("Managing checkpoints...") + # save current model as checkpoint.pt + self.save_checkpoint(path_checkpoint=os.path.join(path_checkpoint, 'checkpoint.pt'), model=None, update_history=update_history, samples=samples) + + for f in checkpoint_files: + if os.path.exists(os.path.join(path_checkpoint, f)): + os.remove(os.path.join(path_checkpoint, f)) + + def print_log(self, current_epoch, train_loss, test_loss): + print( + "[Epoch %d/%d] [Train loss: %f] [Test loss: %f]" % (current_epoch, self.epochs, train_loss, test_loss) + ) + + def set_optimizer_state(self, optimizer): + self.optimizer.load_state_dict(optimizer) + print('Optimizer state loaded successfully.') + +class VAETrainer(Trainer): + """Trainer for VAE""" + + def __init__(self, model, opt): + # training configuration + super().__init__() + self.device = opt['device'] if 'device' in opt else 'cuda' if torch.cuda.is_available() else 'cpu' + self.batch_size = opt['batch_size'] if 'batch_size' in opt else 32 + self.epochs = opt['n_epochs'] if 'n_epochs' in opt else 10 + self.sample_interval = opt['sample_interval'] if 'sample_interval' in opt else 100 + self.learning_rate = opt['learning_rate'] if 'learning_rate' in opt else 0.0001 + self.rank = 0 # Device: cuda:0, cuda:1, ... 
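load_checkpoint above relies on consume_prefix_in_state_dict_if_present to strip the 'module.' prefix that DistributedDataParallel adds to parameter names, so checkpoints written from a DDP-wrapped model load into a plain one. A self-contained sketch, assuming a PyTorch version that ships this helper (1.9 or later):

import torch
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present

model = torch.nn.Linear(4, 2)
ddp_state = {'module.' + k: v for k, v in model.state_dict().items()}  # mimics a DDP checkpoint
consume_prefix_in_state_dict_if_present(ddp_state, 'module.')  # strips the prefix in place
model.load_state_dict(ddp_state)  # now loads cleanly into the unwrapped module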
--> Device: cuda:rank + self.kl_alpha = opt['kl_alpha'] if 'kl_alpha' in opt else .00001 + self.n_conditions = len(opt['kw_conditions']) if 'kw_conditions' in opt else 0 + + # model + self.model = model + self.model.to(self.device) + + # optimizer and loss + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) + self.loss = torch.nn.MSELoss() + + # training statistics + self.trained_epochs = 0 + self.train_loss = [] + + self.configuration = { + 'device': self.device, + 'model_class': str(self.model.__class__.__name__), + 'batch_size': self.batch_size, + 'n_epochs': self.epochs, + 'sample_interval': self.sample_interval, + 'learning_rate': self.learning_rate, + 'hidden_dim': opt['hidden_dim'], + 'encoded_dim': opt['encoded_dim'], + 'path_dataset': opt['path_dataset'] if 'path_dataset' in opt else None, + 'path_checkpoint': opt['path_checkpoint'] if 'path_checkpoint' in opt else None, + 'kw_channel': opt['kw_channel'] if 'kw_channel' in opt else None, + 'kw_conditions': opt['kw_conditions'] if 'kw_conditions' in opt else None, + 'kw_time': opt['kw_time'] if 'kw_time' in opt else None, + 'trained_epochs': self.trained_epochs, + 'input_dim': opt['input_dim'], + 'save_name': opt['save_name'] if 'save_name' in opt else '', + + 'dataloader': { + 'data': opt['data'] if 'data' in opt else None, + 'diff_data': opt['diff_data'] if 'diff_data' in opt else None, + 'std_data': opt['std_data'] if 'std_data' in opt else None, + 'norm_data': opt['norm_data'] if 'norm_data' in opt else None, + 'kw_time': opt['kw_time'] if 'kw_time' in opt else None, + 'kw_conditions': opt['kw_conditions'] if 'kw_conditions' in opt else None, + 'kw_channel': opt['kw_channel'] if 'kw_channel' in opt else None, + }, + 'history': opt['history'] if 'history' in opt else None, + } + + def training(self, dataset: DataLoader): + try: + self.recon_losses = [] + self.kl_losses = [] + self.losses = [] + gen_samples = [] + + path_checkpoint = 'trained_vae' + if not os.path.exists(path_checkpoint): + os.makedirs(path_checkpoint) + trigger_checkpoint_01 = True + checkpoint_01_file = 'checkpoint_01.pt' + checkpoint_02_file = 'checkpoint_02.pt' + + loop = tqdm(range(self.epochs)) + for epoch in loop: + self.epoch = epoch + epoch_loss = self.batch_train(dataset) + self.train_loss.append(epoch_loss) + loop.set_postfix(loss=self.batch_loss.item()) + + #Generate samples on interval + if self.epoch % self.sample_interval == 0: + generated_samples = torch.Tensor(self.model.generate_samples(loader=dataset,condition=0,num_samples=1000)).to(self.device) + gen_samples.append(generated_samples[np.random.randint(0, generated_samples.shape[0])].detach().tolist()) #TODO: Not sure if this is the same as the GAN + + # save models and optimizer states as checkpoints + # toggle between checkpoint files to avoid corrupted file during training + if trigger_checkpoint_01: + self.save_checkpoint(os.path.join(path_checkpoint, checkpoint_01_file), samples=gen_samples) + trigger_checkpoint_01 = False + else: + self.save_checkpoint(os.path.join(path_checkpoint, checkpoint_02_file), samples=gen_samples) + trigger_checkpoint_01 = True + + self.trained_epochs += 1 + + self.manage_checkpoints(path_checkpoint, [checkpoint_01_file, checkpoint_02_file], update_history=True, samples=gen_samples) + + return gen_samples + + except KeyboardInterrupt: + # save model at KeyboardInterrupt + print("Keyboard interrupt detected.\nSaving checkpoint...") + self.save_checkpoint(update_history=True, samples=gen_samples) + + def batch_train(self, data): + + 
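The batch_train body that follows implements the usual VAE objective: an MSE reconstruction term plus the closed-form KL divergence between N(mu, exp(sigma)) and N(0, I), where sigma holds the log-variance. The same computation as a standalone function (shapes and names are illustrative):

import torch
import torch.nn.functional as F

def vae_loss(x_hat, x, mu, log_var, kl_alpha=1e-5):
    reconstruction = F.mse_loss(x_hat, x)
    kl_div = torch.mean(-0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var), dim=1))
    return reconstruction + kl_alpha * kl_div

x = torch.randn(8, 16)
mu, log_var = torch.randn(8, 4), torch.randn(8, 4)
print(vae_loss(x, x, mu, log_var))  # reconstruction term is 0 here, so this is the scaled KL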
self.model.train()
+        total_loss = 0
+        for batch in data:
+
+            # Run data through model
+            inputs = batch[:, self.n_conditions:, :].to(self.model.device)
+            x_reconstruction, mu, sigma = self.model(inputs)
+
+            # Loss
+            reconstruction_loss = self.loss(x_reconstruction, inputs)
+            kl_div = torch.mean(-0.5 * torch.sum(1 + sigma - mu**2 - torch.exp(sigma), dim=1), dim=0)
+            self.batch_loss = reconstruction_loss + kl_div * self.kl_alpha
+
+            # Update
+            self.optimizer.zero_grad()
+            self.batch_loss.backward()
+            self.optimizer.step()
+            total_loss += self.batch_loss.item()
+
+            self.recon_losses.append(reconstruction_loss.detach().tolist())
+            self.kl_losses.append(kl_div.detach().tolist())
+            self.losses.append(self.batch_loss.detach().tolist())
+
+        return total_loss / len(data)
+
+    def save_checkpoint(self, path_checkpoint=None, model=None, update_history=False, samples=None):
+        if path_checkpoint is None:
+            default_path = 'trained_vae'
+            if not os.path.exists(default_path):
+                os.makedirs(default_path)
+            path_checkpoint = os.path.join(default_path, 'checkpoint.pt')
+
+        if model is None:
+            model = self.model
+
+        if update_history:
+            self.configuration['trained_epochs'] = self.trained_epochs
+            self.configuration['history']['trained_epochs'] = self.configuration['history']['trained_epochs'] + [self.trained_epochs]
+
+        checkpoint_dict = {
+            'model': model.state_dict(),
+            'optimizer': self.optimizer.state_dict(),
+            'train_loss': self.train_loss,
+            'trained_epochs': self.trained_epochs,
+            'samples': samples,
+            'configuration': self.configuration,
+        }
+        torch.save(checkpoint_dict, path_checkpoint)

     def load_checkpoint(self, path_checkpoint):
diff --git a/helpers/trainer_3c.py b/helpers/trainer_3c.py
deleted file mode 100644
index d7b6535..0000000
--- a/helpers/trainer_3c.py
+++ /dev/null
@@ -1,206 +0,0 @@
-import os
-
-import torch
-import torch.distributed as dist
-from torch.nn.parallel import DistributedDataParallel as DDP
-
-
-class Trainer:
-    def __init__(self, classifier, critic, opt):
-        self.device = opt['device']
-        self.classifier = classifier.to(self.device)
-        self.critic = critic
-        if self.critic is not None:
-            self.critic = self.critic.to(self.device)
-
-        self.loss_fn = torch.nn.MSELoss()
-        self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=opt['learning_rate'])
-
-        self.epochs = opt['n_epochs'] if 'n_epochs' in opt else 1
-        self.batch_size = opt['batch_size'] if 'batch_size' in opt else 1
-        self.learning_rate = opt['learning_rate'] if 'learning_rate' in opt else 1e-3
-        self.sample_interval = opt['sample_interval'] if 'sample_interval' in opt else 50
-
-        self.opt = opt
-
-        self.rank = 0
-
-        self.use_checkpoint = opt['load_checkpoint'] if 'load_checkpoint' in opt else False
-        self.path_checkpoint = os.path.join('trained_3c', 'checkpoint.pt')
-
-    def load_checkpoint(self, path):
-        state_dict = torch.load(path, map_location=self.device)
-        self.classifier.load_state_dict(state_dict['classifier'])
-        self.critic.load_state_dict(state_dict['critic'])
-        self.optimizer.load_state_dict(state_dict['optimizer'])
-
-    def save_checkpoint(self, path, test_dataset=None, loss=None, classifier=None):
-        if classifier is None:
-            classifier = self.classifier
-
-        state_dict = {
-            'classifier': classifier.state_dict(),
-            'critic': self.critic.state_dict(),
-            'optimizer': self.optimizer.state_dict(),
-            'configuration': self.opt,
-            'train_loss': [l[0] for l in loss] if loss is not None else None,
-            'test_loss': [l[1] for l in loss] if loss is not None else None,
-            'accuracy': [l[2] for l in loss] if loss is not None
else None - } - torch.save(state_dict, path) - - def train(self, train_data, train_labels, test_data, test_labels): - if train_data.shape[0] != train_labels.shape[0]: - raise RuntimeError("Train data and labels must have the same number of samples.") - if test_data.shape[0] != test_labels.shape[0]: - raise RuntimeError("Test data and labels must have the same number of samples.") - - test_data = test_data.to(self.device) - test_labels = test_labels.to(self.device) - - # compute scores on test data - ones_test = torch.ones_like(test_labels) - zeros_test = torch.zeros_like(test_labels) - scores = self.compute_scores(test_data, ones_test, zeros_test) - test_data = test_data.view(-1, 1, 1, test_data.shape[-1]) - scores = scores.view(-1, 2, 1, 1).repeat(1, 1, 1, test_data.shape[-1]).to(self.device) - test_data = torch.concat((test_data, scores), dim=1).to(self.device) - - loss_train = 9e9 - loss_test = 9e9 - accuracy = 0 - loss = [] - - for epoch in range(self.epochs): - # Train - # shuffle train_data and train_labels - idx = torch.randperm(train_data.shape[0]) - train_data = train_data[idx] - train_labels = train_labels[idx] - - for batch in range(0, train_data.shape[0], self.batch_size): - # Check if remaining samples are enough for a batch and adjust if not - if batch + self.batch_size > train_data.shape[0]: - batch_size = train_data.shape[0] - batch - else: - batch_size = self.batch_size - - data = train_data[batch_size:batch_size + self.batch_size].to(self.device) - real_labels = train_labels[batch_size:batch_size + self.batch_size].to(self.device) - - # get scores for all types of conditions from critic and attach to data - ones = torch.ones_like(real_labels) - zeros = torch.zeros_like(real_labels) - scores = self.compute_scores(data, ones, zeros) - data = data.view(-1, 1, 1, data.shape[-1]) - scores = scores.view(-1, 2, 1, 1).repeat(1, 1, 1, data.shape[-1]).to(self.device) - data = torch.concat((data, scores), dim=1) - loss_train = self.batch_train(data, real_labels) - - # Test - loss_test, accuracy = self.test(test_data, test_labels) - - loss.append((loss_train, loss_test, accuracy)) - - # save checkpoint every n epochs - if epoch % self.sample_interval == 0: - self.save_checkpoint(self.path_checkpoint, None, loss) - - print(f"Epoch [{epoch + 1}/{self.epochs}]: " - f"Loss train: {loss_train:.4f}, " - f"Loss test: {loss_test:.4f}, " - f"Accuracy: {accuracy:.4f}") - - if self.rank == 0: - self.save_checkpoint(self.path_checkpoint, None, loss, None) - - return loss - - def batch_train(self, data, labels): - self.classifier.train() - data, labels = data.to(self.device), labels.to(self.device) - - # calc loss - # shape of data: (batch_size, channels, 1, sequence_length) - # shape of labels/output: (batch_size, n_conditions) - output = self.classifier(data) - loss = self.loss_fn(output, labels) - - # optimize - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - return loss.item() - - def test(self, data, labels): - self.classifier.eval() - data, labels = data.to(self.device), labels.to(self.device) - - output = self.classifier(data) - loss = self.loss_fn(output, labels) - - # accuracy - output = output.round() - accuracy = (output == labels).sum() / labels.shape[0] - - return loss.item(), accuracy.item() - - def compute_scores(self, data, real_labels, *args): - """Compute the scores for the given data and all combinations of labels.""" - labels = [real_labels] - labels.extend(args) - score = torch.zeros((real_labels.shape[0], len(labels))) - for j, label in 
enumerate(labels): - batch_labels = label.view(-1, 1, 1, 1).repeat(1, 1, 1, data.shape[1]) - batch_data = data.view(-1, 1, 1, data.shape[1]) - batch_data = torch.cat((batch_data, batch_labels), dim=1).to(self.device) - validity = self.critic(batch_data) - score[:, j] = validity[:, 0] - - return score - - def print_log(self, current_epoch, train_loss, test_loss, test_accuracy): - print( - "[Epoch %d/%d] [Train loss: %f] [Test loss: %f] [Accuracy: %f]" - % (current_epoch, self.epochs, - train_loss, test_loss, test_accuracy) - ) - - -class DDPTrainer(Trainer): - - def __init__(self, classifier, critic, opt): - super(Trainer, self).__init__() - - # training configuration - super().__init__(classifier, critic, opt) - - self.world_size = opt['world_size'] if 'world_size' in opt else 1 - - def set_ddp_framework(self): - # set ddp generator and discriminator - self.classifier.to(self.rank) - self.classifier = DDP(self.classifier, device_ids=[self.rank]) - - # set ddp optimizer - opt_state = self.optimizer.state_dict() - self.optimizer = torch.optim.Adam(self.classifier.parameters(), lr=self.learning_rate) - self.optimizer.load_state_dict(opt_state) - - def set_device(self, rank): - self.rank = rank - self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else f'cpu:{rank}') - - def save_checkpoint(self, path_checkpoint=None, test_dataset=None, loss=None, classifier=None): - if self.rank == 0: - super().save_checkpoint(path_checkpoint, test_dataset=test_dataset, loss=loss, classifier=self.classifier.module) - # dist.barrier() - - def print_log(self, current_epoch, train_loss, test_loss, test_accuracy): - # average the loss across all processes before printing - reduce_tensor = torch.tensor([train_loss, test_loss, test_accuracy], dtype=torch.float32, device=self.device) - dist.all_reduce(reduce_tensor, op=dist.ReduceOp.SUM) - reduce_tensor /= self.world_size - if self.rank == 0: - super().print_log(current_epoch, reduce_tensor[0], reduce_tensor[1], reduce_tensor[2]) diff --git a/helpers/trainer_classifier.py b/helpers/trainer_classifier.py deleted file mode 100644 index da72217..0000000 --- a/helpers/trainer_classifier.py +++ /dev/null @@ -1,172 +0,0 @@ -import os - -import torch -import torch.distributed as dist -from torch.nn.parallel import DistributedDataParallel as DDP - - -class Trainer: - def __init__(self, model, opt): - self.model = model - self.loss_fn = torch.nn.MSELoss() - self.device = opt['device'] - - self.optimizer = torch.optim.Adam(model.parameters(), lr=opt['learning_rate']) - - self.epochs = opt['n_epochs'] if 'n_epochs' in opt else 1 - self.batch_size = opt['batch_size'] if 'batch_size' in opt else 1 - self.learning_rate = opt['learning_rate'] if 'learning_rate' in opt else 1e-3 - self.sample_interval = opt['sample_interval'] if 'sample_interval' in opt else 50 - - self.opt = opt - - self.rank = 0 - - self.use_checkpoint = opt['load_checkpoint'] if 'load_checkpoint' in opt else False - self.path_checkpoint = os.path.join('trained_classifier', 'checkpoint.pt') - # self.path_checkpoint = opt['path_checkpoint'] if 'path_checkpoint' in opt else None - - def load_checkpoint(self, path): - state_dict = torch.load(path, map_location=self.device) - self.model.load_state_dict(state_dict['model']) - self.optimizer.load_state_dict(state_dict['optimizer']) - - def save_checkpoint(self, path, test_dataset=None, loss=None, model=None): - if model is None: - model = self.model - - state_dict = { - 'model': model.state_dict(), - 'optimizer': self.optimizer.state_dict(), - 
'configuration': self.opt, - 'test_dataset': test_dataset, - 'train_loss': [l[0] for l in loss] if loss is not None else None, - 'test_loss': [l[1] for l in loss] if loss is not None else None, - 'accuracy': [l[2] for l in loss] if loss is not None else None - } - torch.save(state_dict, path) - - def train(self, train_data, train_labels, test_data, test_labels): - if train_data.shape[0] != train_labels.shape[0]: - raise RuntimeError("Train data and labels must have the same number of samples.") - if test_data.shape[0] != test_labels.shape[0]: - raise RuntimeError("Test data and labels must have the same number of samples.") - - loss_train = 9e9 - loss_test = 9e9 - accuracy = 0 - loss = [] - - test_data = test_data.to(self.device) - test_labels = test_labels.to(self.device) - - for epoch in range(self.epochs): - # Train - # shuffle train_data and train_labels - idx = torch.randperm(train_data.shape[0]) - train_data = train_data[idx] - train_labels = train_labels[idx] - for batch in range(0, train_data.shape[0], self.batch_size): - # Check if remaining samples are enough for a batch and adjust if not - if batch + self.batch_size > train_data.shape[0]: - batch_size = train_data.shape[0] - batch - else: - batch_size = self.batch_size - - data = train_data[batch_size:batch_size + self.batch_size].to(self.device) - labels = train_labels[batch_size:batch_size + self.batch_size].to(self.device) - loss_train = self.batch_train(data, labels) - - # Test - loss_test, accuracy = self.test(test_data, test_labels) - - loss.append((loss_train, loss_test, accuracy)) - - # save checkpoint every n epochs - if epoch % self.sample_interval == 0: - self.save_checkpoint(self.path_checkpoint, torch.concat((test_labels, test_data), dim=1), loss) - - print(f"Epoch [{epoch + 1}/{self.epochs}]: " - f"Loss train: {loss_train:.4f}, " - f"Loss test: {loss_test:.4f}, " - f"Accuracy: {accuracy:.4f}") - - if self.rank == 0: - self.save_checkpoint(self.path_checkpoint, None, loss, None) - - return loss - - def batch_train(self, data, labels): - self.model.train() - data, labels = data.to(self.device), labels.to(self.device) - - # calc loss - # shape of data: (batch_size, channels, 1, sequence_length) - # shape of labels/output: (batch_size, n_conditions) - output = self.model(data.view(data.shape[0], 1, 1, data.shape[-1])) - loss = self.loss_fn(output, labels) - - # optimize - self.optimizer.zero_grad() - loss.backward() - self.optimizer.step() - - return loss.item() - - def test(self, data, labels): - self.model.eval() - data, labels = data.to(self.device), labels.to(self.device) - - output = self.model(data.view(data.shape[0], 1, 1, data.shape[-1])) - loss = self.loss_fn(output, labels) - - # accuracy - output = output.round() - accuracy = (output == labels).sum() / labels.shape[0] - - return loss.item(), accuracy.item() - - def print_log(self, current_epoch, train_loss, test_loss, test_accuracy): - print( - "[Epoch %d/%d] [Train loss: %f] [Test loss: %f] [Accuracy: %f]" - % (current_epoch, self.epochs, - train_loss, test_loss, test_accuracy) - ) - - -class DDPTrainer(Trainer): - - def __init__(self, model, opt): - super(Trainer, self).__init__() - - # training configuration - super().__init__(model, opt) - - self.world_size = opt['world_size'] if 'world_size' in opt else 1 - - def set_ddp_framework(self): - # set ddp generator and discriminator - self.model.to(self.rank) - self.model = DDP(self.model, device_ids=[self.rank]) - - # set ddp optimizer - opt_state = self.optimizer.state_dict() - self.optimizer = 
torch.optim.Adam(self.model.parameters(), lr=self.learning_rate) - self.optimizer.load_state_dict(opt_state) - - def set_device(self, rank): - self.rank = rank - self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else f'cpu:{rank}') - - def save_checkpoint(self, path_checkpoint=None, test_dataset=None, loss=None, model=None): - if self.rank == 0: - super().save_checkpoint(path_checkpoint, test_dataset=test_dataset, loss=loss, model=self.model.module) - # dist.barrier() - - def print_log(self, current_epoch, train_loss, test_loss, test_accuracy): - # average the loss across all processes before printing - reduce_tensor = torch.tensor([train_loss, test_loss, test_accuracy], dtype=torch.float32, device=self.device) - dist.all_reduce(reduce_tensor, op=dist.ReduceOp.SUM) - reduce_tensor /= self.world_size - if self.rank == 0: - super().print_log(current_epoch, reduce_tensor[0], reduce_tensor[1], reduce_tensor[2]) diff --git a/helpers/visualize_pca.py b/helpers/visualize_pca.py index 0fdf7e6..7c8e560 100644 --- a/helpers/visualize_pca.py +++ b/helpers/visualize_pca.py @@ -130,7 +130,6 @@ def visualization_dim_reduction(ori_data, generated_data, analysis, save, save_n # Load data dataloader = Dataloader(path=ori_file, norm_data=True) ori_data = dataloader.get_data().unsqueeze(-1).detach().cpu().numpy() - # ori_data = np.load('data/real_data.npy') gen_data = pd.read_csv(gen_file, header=None).to_numpy() gen_data = gen_data.reshape(gen_data.shape[0], gen_data.shape[1], 1) diff --git a/helpers/visualize_spectogram.py b/helpers/visualize_spectogram.py index e417d9f..be15a22 100644 --- a/helpers/visualize_spectogram.py +++ b/helpers/visualize_spectogram.py @@ -43,10 +43,6 @@ def plot_fft_hist(data, save=False, path_save=None): else: plt.show() - # plot nspectrum per frequency, with a semilog scale on nspectrum - # plt.semilogy(freq, nspectrum[10]) - # plt.show() - return xbins, ybins, h.T @@ -54,18 +50,10 @@ def plot_spectogram(x, save=False, path_save=None): """Plot the spectogram of a dataset along the time axis (dim=1).""" fs = 500 - - # interpolate the data to have a length of fs*100 - # x_new = np.zeros((x.shape[0], fs * 100)) - # for i in range(x.shape[0]): - # x_new[i] = np.interp(np.linspace(0, x.shape[1], fs * 100), np.arange(x.shape[1]), x[i]) - # x = x_new - f, t, Sxx = signal.spectrogram(x.T, fs) Sxx = np.sum(Sxx, axis=0) plt.pcolormesh(t, f, Sxx, shading='gouraud') - # plt.yscale('log') plt.ylim(10**-3, 50**1) plt.ylabel('Frequency [Hz]') plt.xlabel('Time [sec]') diff --git a/nn_architecture/ae_networks.py b/nn_architecture/ae_networks.py index 86fca4e..a36a946 100644 --- a/nn_architecture/ae_networks.py +++ b/nn_architecture/ae_networks.py @@ -1,16 +1,6 @@ -import math -import os -import random -import warnings -from typing import Optional - -import pandas as pd import torch import torch.nn as nn -from torch import Tensor - -# from utils.get_filter import moving_average as filter class Autoencoder(nn.Module): @@ -18,7 +8,7 @@ class Autoencoder(nn.Module): TARGET_TIMESERIES = 1 TARGET_BOTH = 2 - def __init__(self, input_dim: int, output_dim: int, output_dim_2: int, hidden_dim: int, target: int, num_layers=3, dropout=0.1, activation='linear', **kwargs): + def __init__(self, input_dim: int, output_dim: int, output_dim_2: int, hidden_dim: int, target: int, num_layers=3, dropout=0.1, activation_decoder='linear', **kwargs): super(Autoencoder, self).__init__() self.input_dim = input_dim self.output_dim = output_dim @@ -27,18 +17,19 @@ def __init__(self, input_dim: int, 
output_dim: int, output_dim_2: int, hidden_di self.num_layers = num_layers self.dropout = dropout self.target = target - if activation == 'relu': - self.activation = nn.ReLU() - elif activation == 'sigmoid': - self.activation = nn.Sigmoid() - elif activation == 'tanh': - self.activation = nn.Tanh() - elif activation == 'leakyrelu': - self.activation = nn.LeakyReLU() - elif activation == 'linear': - self.activation = nn.Identity() + self.activation_encoder = nn.Tanh() + if activation_decoder == 'relu': + self.activation_decoder = nn.ReLU() + elif activation_decoder == 'sigmoid': + self.activation_decoder = nn.Sigmoid() + elif activation_decoder == 'tanh': + self.activation_decoder = nn.Tanh() + elif activation_decoder == 'leakyrelu': + self.activation_decoder = nn.LeakyReLU() + elif activation_decoder == 'linear': + self.activation_decoder = nn.Identity() else: - raise ValueError(f"Activation function of type '{activation}' was not recognized.") + raise ValueError(f"Activation function of type '{activation_decoder}' was not recognized.") self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # encoder block of linear layers constructed in a loop and passed to a sequential container @@ -54,7 +45,7 @@ def __init__(self, input_dim: int, output_dim: int, output_dim_2: int, hidden_di encoder_block.append(nn.Tanh()) encoder_block.append(nn.Linear(hidden_dim, output_dim)) # encoder_block.append(self.activation) - encoder_block.append(nn.Tanh()) + encoder_block.append(self.activation_encoder) self.encoder = nn.Sequential(*encoder_block) # decoder block of linear layers constructed in a loop and passed to a sequential container @@ -69,7 +60,7 @@ def __init__(self, input_dim: int, output_dim: int, output_dim_2: int, hidden_di # decoder_block.append(self.activation) decoder_block.append(nn.Tanh()) decoder_block.append(nn.Linear(hidden_dim, input_dim)) - decoder_block.append(self.activation) + decoder_block.append(self.activation_decoder) self.decoder = nn.Sequential(*decoder_block) def forward(self, x): @@ -96,12 +87,11 @@ def decode(self, encoded): class TransformerAutoencoder(Autoencoder): - def __init__(self, input_dim: int, output_dim: int, output_dim_2: int, target: int, hidden_dim=256, num_layers=3, num_heads=4, dropout=0.1, activation='linear', **kwargs): - super(TransformerAutoencoder, self).__init__(input_dim, output_dim, output_dim_2, hidden_dim, target, num_layers, dropout, activation) + def __init__(self, input_dim: int, output_dim: int, output_dim_2: int, target: int, hidden_dim=256, num_layers=3, num_heads=4, dropout=0.1, activation_decoder='linear', **kwargs): + super(TransformerAutoencoder, self).__init__(input_dim, output_dim, output_dim_2, hidden_dim, target, num_layers, dropout, activation_decoder) self.num_heads = num_heads self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.tanh = nn.Tanh() # self.pe_enc = PositionalEncoder(batch_first=True, d_model=input_dim) self.linear_enc_in = nn.Linear(input_dim, hidden_dim) @@ -128,7 +118,7 @@ def encode(self, data): x = self.encoder(x) x = self.linear_enc_out(x) # x = self.activation(x) - x = self.tanh(x) + x = self.activation_encoder(x) if self.target == self.TARGET_TIMESERIES: x = x.permute(0, 2, 1) return x @@ -140,20 +130,16 @@ def decode(self, encoded): x = self.linear_dec_in(encoded) x = self.decoder(x) x = self.linear_dec_out(x) - x = self.activation(x) + x = self.activation_decoder(x) if self.target == self.TARGET_TIMESERIES: x = x.permute(0, 2, 1) return x - def save(self, path): - 
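The constructor change above keeps a fixed Tanh on the encoder output and only makes the decoder's output activation configurable. The same string-to-module dispatch, written as a dictionary lookup (an equivalent, illustrative form, not the repository's code):

import torch.nn as nn

_ACTIVATIONS = {'relu': nn.ReLU, 'sigmoid': nn.Sigmoid, 'tanh': nn.Tanh,
                'leakyrelu': nn.LeakyReLU, 'linear': nn.Identity}

def make_decoder_activation(name: str) -> nn.Module:
    if name not in _ACTIVATIONS:
        raise ValueError(f"Activation function of type '{name}' was not recognized.")
    return _ACTIVATIONS[name]()

print(make_decoder_activation('leakyrelu'))  # LeakyReLU(negative_slope=0.01)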
path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) class TransformerDoubleAutoencoder(Autoencoder): - def __init__(self, channels_in: int, timeseries_in: int, channels_out: int, timeseries_out: int, hidden_dim=256, num_layers=3, num_heads=8, dropout=0.1, activation='linear', training_level=2, **kwargs): + def __init__(self, channels_in: int, time_in: int, channels_out: int, time_out: int, hidden_dim=256, num_layers=3, num_heads=8, dropout=0.1, activation_decoder='linear', training_level=2, **kwargs): target = Autoencoder.TARGET_BOTH - super(TransformerDoubleAutoencoder, self).__init__(channels_in, channels_out, timeseries_out, hidden_dim, target, num_layers, dropout, activation) + super(TransformerDoubleAutoencoder, self).__init__(channels_in, channels_out, time_out, hidden_dim, target, num_layers, dropout, activation_decoder) ''' Note that this double autoencoder trains two autoencoders - the first is a timeseries autoencoder and the second is a channels autoencoder. @@ -165,17 +151,16 @@ def __init__(self, channels_in: int, timeseries_in: int, channels_out: int, time ''' self.training_level = training_level - self.sequence_length = timeseries_in + self.sequence_length = time_in self.num_heads = num_heads - self.tanh = nn.Tanh() self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Timeseries Encoder - self.linear_enc_in_timeseries = nn.Linear(timeseries_in, hidden_dim) + self.linear_enc_in_timeseries = nn.Linear(time_in, hidden_dim) self.encoder_layer_timeseries = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) self.encoder_timeseries = nn.TransformerEncoder(self.encoder_layer_timeseries, num_layers=num_layers) - self.linear_enc_out_timeseries = nn.Linear(hidden_dim, timeseries_out) + self.linear_enc_out_timeseries = nn.Linear(hidden_dim, time_out) # Channel Encoder self.linear_enc_in_channels = nn.Linear(channels_in, hidden_dim) @@ -190,19 +175,17 @@ def __init__(self, channels_in: int, timeseries_in: int, channels_out: int, time self.linear_dec_out_channels = nn.Linear(hidden_dim, channels_in) # Timeseries Decoder - self.linear_dec_in_timeseries = nn.Linear(timeseries_out, hidden_dim) + self.linear_dec_in_timeseries = nn.Linear(time_out, hidden_dim) self.decoder_layer_timeseries = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=num_heads, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) self.decoder_timeseries = nn.TransformerEncoder(self.decoder_layer_timeseries, num_layers=num_layers) - self.linear_dec_out_timeseries = nn.Linear(hidden_dim, timeseries_in) + self.linear_dec_out_timeseries = nn.Linear(hidden_dim, time_in) - def forward(self, data): - x = self.encode(data.to(self.device)) + def forward(self, x): + x = self.encode(x.to(self.device)) x = self.decode(x) return x - def encode(self, data): - x = data - + def encode(self, x): if self.training_level == 1: #Encode timeseries @@ -210,7 +193,7 @@ def encode(self, data): x = self.linear_enc_in_timeseries(x) x = self.encoder_timeseries(x) x = self.linear_enc_out_timeseries(x) - x = self.tanh(x) + x = self.activation_encoder(x) x = x.permute(0, 2, 1) if self.training_level == 2: @@ -222,20 +205,18 @@ def encode(self, data): x = self.linear_enc_in_channels(x) x = self.encoder_channels(x) x = self.linear_enc_out_channels(x) - x = self.tanh(x) + x = self.activation_encoder(x) return x - def decode(self, encoded): - x = encoded - + 
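The double autoencoder described in the docstring above compresses along two axes in turn: training level 1 learns a code over the time axis, level 2 over the channel axis, as the renamed decode below mirrors in reverse. A toy illustration of the axis handling only (the dimensions and Linear layers are invented stand-ins, not the transformer model):

import torch

x = torch.randn(8, 100, 30)  # (batch, time, channels)
time_encoder = torch.nn.Linear(100, 10)  # level 1 acts on the time axis
z1 = time_encoder(x.permute(0, 2, 1)).permute(0, 2, 1)  # (batch, 10, 30)
channel_encoder = torch.nn.Linear(30, 4)  # level 2 acts on the channel axis
z2 = channel_encoder(z1)  # (batch, 10, 4)
print(z2.shape)  # torch.Size([8, 10, 4])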
def decode(self, x): if self.training_level == 1: x = x.permute(0, 2, 1) x = self.linear_dec_in_timeseries(x) x = self.decoder_timeseries(x) x = self.linear_dec_out_timeseries(x) - x = self.activation(x) + x = self.activation_decoder(x) x = x.permute(0, 2, 1) if self.training_level == 2: @@ -247,389 +228,6 @@ def decode(self, encoded): x = self.linear_dec_in_channels(x) x = self.decoder_channels(x) x = self.linear_dec_out_channels(x) - x = self.activation(x) - - return x - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - - -class TransformerFlattenAutoencoder(Autoencoder): - def __init__(self, input_dim, output_dim, sequence_length, hidden_dim=1024, num_layers=3, dropout=0.1, activation='linear', **kwargs): - super(TransformerFlattenAutoencoder, self).__init__(input_dim, output_dim, hidden_dim, num_layers, dropout, activation) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.sequence_length = sequence_length - - # self.pe_enc = PositionalEncoder(batch_first=True, d_model=input_dim) - self.linear_enc_in = nn.Linear(input_dim, input_dim) - encoder_layer = nn.TransformerEncoderLayer(d_model=1, nhead=1, dropout=dropout, batch_first=True) - self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) - self.linear_enc_out_1 = nn.Linear(sequence_length*input_dim, hidden_dim) - self.linear_enc_out_2 = nn.Linear(hidden_dim, output_dim) - - # self.pe_dec = PositionalEncoder(batch_first=True, d_model=output_dim) - self.linear_dec_in = nn.Linear(output_dim, output_dim) - decoder_layer = nn.TransformerEncoderLayer(d_model=1, nhead=1, dropout=dropout, batch_first=True) - self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=num_layers) - self.linear_dec_out_1 = nn.Linear(output_dim, hidden_dim) - self.linear_dec_out_2 = nn.Linear(hidden_dim, input_dim*sequence_length) - - self.tanh = nn.Sigmoid() - - def forward(self, data): - x = self.encode(data.to(self.device)) - x = self.decode(x) - return x - - def encode(self, data): - # x = self.pe_enc(data) - x = self.linear_enc_in(data).reshape(data.shape[0], self.sequence_length*self.input_dim, 1) - x = self.encoder(x) - x = self.linear_enc_out_1(x.permute(0, 2, 1)) - x = self.linear_enc_out_2(x) - x = self.tanh(x) - return x - - def decode(self, encoded): - # x = self.pe_dec(encoded) - x = self.linear_dec_in(encoded) - x = self.decoder(x.permute(0, 2, 1)) - x = self.linear_dec_out_1(x.permute(0, 2, 1)) - x = self.linear_dec_out_2(x).reshape(encoded.shape[0], self.sequence_length, self.input_dim) - x = self.tanh(x) - return x - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - - -class PositionalEncoder(nn.Module): - """ - The authors of the original transformer paper describe very succinctly what - the positional encoding layer does and why it is needed: - - "Since our model contains no recurrence and no convolution, in order for the - model to make use of the order of the sequence, we must inject some - information about the relative or absolute position of the tokens in the - sequence." 
(Vaswani et al, 2017) - Adapted from: - https://pytorch.org/tutorials/beginner/transformer_tutorial.html - """ - - def __init__( - self, - dropout: float = 0.1, - max_seq_len: int = 5000, - d_model: int = 512, - batch_first: bool = True - ): - """ - Parameters: - dropout: the dropout rate - max_seq_len: the maximum length of the input sequences - d_model: The dimension of the output of sub-layers in the model - (Vaswani et al, 2017) - """ - - super().__init__() - - self.d_model = d_model - - self.dropout = nn.Dropout(p=dropout) - - self.batch_first = batch_first - - self.x_dim = 1 if batch_first else 0 - - # copy pasted from PyTorch tutorial - position = torch.arange(max_seq_len).unsqueeze(1) - # print(f"shape of position is {position.shape}") - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - # print(f"shape of div_term is {div_term.shape}") - pe = torch.zeros(1, max_seq_len, d_model) - - pe[0, :, 0::2] = torch.sin(position * div_term) - - pe[0, :, 1::2] = torch.cos(position * div_term) - - self.register_buffer('pe', pe) - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [batch_size, enc_seq_len, dim_val] or - [enc_seq_len, batch_size, dim_val] - """ - x = x + self.pe[:, :x.size(self.x_dim)] - - return self.dropout(x) - - -class LSTMAutoencoder(Autoencoder): - def __init__(self, input_dim, output_dim, sequence_length, hidden_dim=256, num_layers=3, dropout=0.1, activation=nn.Sigmoid(), **kwargs): - super(LSTMAutoencoder, self).__init__(input_dim, output_dim, hidden_dim, num_layers, dropout) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.activation = activation - self.sequence_length = sequence_length - - # encoder block - self.enc_lin_in = nn.Linear(self.input_dim, self.input_dim) - self.enc_lstm = nn.LSTM(self.input_dim, self.output_dim, num_layers=self.num_layers, dropout=self.dropout, batch_first=True) - self.enc_lin_out = nn.Linear(self.output_dim, self.output_dim) - self.enc_dropout = nn.Dropout(self.dropout) - - # decoder block - # decoder_block = nn.ModuleList() - # if self.num_layers > 1: - # decoder_block.append(nn.Linear(self.output_dim, hidden_dim)) - # decoder_block.append(self.activation) - # decoder_block.append(nn.Dropout(self.dropout)) - # if self.num_layers > 2: - # for _ in range(self.num_layers-2): - # decoder_block.append(nn.Linear(self.hidden_dim, self.hidden_dim)) - # decoder_block.append(self.activation) - # decoder_block.append(nn.Dropout(self.dropout)) - # if self.num_layers == 1: - # decoder_block.append(nn.Linear(self.output_dim, self.input_dim*self.sequence_length)) - # else: - # decoder_block.append(nn.Linear(self.hidden_dim, self.input_dim*self.sequence_length)) - # decoder_block.append(self.activation) - # self.decoder = nn.Sequential(*decoder_block) - self.dec_lin_in = nn.Linear(self.output_dim, self.output_dim) - self.dec_lstm = nn.LSTM(self.output_dim, self.input_dim, num_layers=self.num_layers, dropout=self.dropout, batch_first=True) - self.dec_lin_out = nn.Linear(self.input_dim, self.input_dim) - - def forward(self, data): - return self.decode(self.encode(data)) - - def encode(self, data): - # flip data along time axis - # data = torch.flip(data, [1]) - x = self.enc_lin_in(data) - x = self.activation(x) - x = self.enc_dropout(x) - x = self.enc_lstm(x)[0]#.reshape(-1, self.hidden_dim//2*self.sequence_length) - # x = self.enc_lin_out(x) - # x = self.activation(x) - return x - - def decode(self, encoded): - x = self.dec_lin_in(encoded) - x = 
self.activation(x) - x = self.enc_dropout(x) - x = self.dec_lstm(x)[0] # .reshape(-1, self.hidden_dim//2*self.sequence_length) - # x = self.dec_lin_out(x) - # x = self.activation(x) - return x - # return self.decoder(encoded)#.reshape(-1, self.sequence_length, self.input_dim) - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - - -class LSTMDoubleAutoencoder(Autoencoder): - def __init__(self, input_dim, output_dim, sequence_length, output_dim_2, hidden_dim=256, num_layers=3, dropout=0.1, activation=nn.Sigmoid(), **kwargs): - super(LSTMDoubleAutoencoder, self).__init__(input_dim, output_dim, hidden_dim, num_layers, dropout) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.activation = activation - self.sequence_length = sequence_length - self.output_dim_2 = output_dim_2 - - # encoder block 1 - self.enc_lin_in = nn.Linear(self.input_dim, self.input_dim) - self.enc_lstm = nn.LSTM(self.input_dim, self.output_dim, num_layers=self.num_layers, dropout=self.dropout, batch_first=True) - self.enc_lin_out = nn.Linear(self.output_dim, self.output_dim) - self.enc_dropout = nn.Dropout(self.dropout) - - # encoder block 2 - self.enc_lin_in2 = nn.Linear(self.sequence_length, self.sequence_length) - self.enc_lstm2 = nn.LSTM(self.sequence_length, self.output_dim_2, num_layers=self.num_layers, dropout=self.dropout, batch_first=True) - self.enc_lin_out2 = nn.Linear(self.output_dim_2, self.output_dim_2) - - # decoder block - # decoder_block = nn.ModuleList() - # if self.num_layers > 1: - # decoder_block.append(nn.Linear(self.output_dim, hidden_dim)) - # decoder_block.append(self.activation) - # decoder_block.append(nn.Dropout(self.dropout)) - # if self.num_layers > 2: - # for _ in range(self.num_layers-2): - # decoder_block.append(nn.Linear(self.hidden_dim, self.hidden_dim)) - # decoder_block.append(self.activation) - # decoder_block.append(nn.Dropout(self.dropout)) - # if self.num_layers == 1: - # decoder_block.append(nn.Linear(self.output_dim, self.input_dim*self.sequence_length)) - # else: - # decoder_block.append(nn.Linear(self.hidden_dim, self.input_dim*self.sequence_length)) - # decoder_block.append(self.activation) - # self.decoder = nn.Sequential(*decoder_block) - - # decoder block 2 - self.dec_lin_in2 = nn.Linear(self.output_dim_2, self.output_dim_2) - self.dec_lstm2 = nn.LSTM(self.output_dim_2, self.sequence_length, num_layers=self.num_layers, dropout=self.dropout, batch_first=True) - self.dec_lin_out2 = nn.Linear(self.sequence_length, self.sequence_length) - - # decoder block 1 - self.dec_lin_in = nn.Linear(self.output_dim, self.output_dim) - self.dec_lstm = nn.LSTM(self.output_dim, self.input_dim, num_layers=self.num_layers, dropout=self.dropout, batch_first=True) - self.dec_lin_out = nn.Linear(self.input_dim, self.input_dim) - - def forward(self, data): - return self.decode(self.encode(data)) - - def encode(self, data): - # encoder block 1 - x = self.enc_lin_in(data) - x = self.activation(x) - x = self.enc_dropout(x) - x = self.enc_lstm(x)[0]#.reshape(-1, self.hidden_dim//2*self.sequence_length) - x = self.enc_lin_out(x) - x = self.activation(x) - - # encoder block 2 - x = self.enc_lin_in2(x.permute(0, 2, 1)) - x = self.activation(x) - x = self.enc_dropout(x) - x = self.enc_lstm2(x)[0]#.reshape(-1, self.hidden_dim//2*self.sequence_length) - x = self.enc_lin_out2(x) - x = self.activation(x) - return x.permute(0, 2, 1) - - def decode(self, encoded): - # 
decoder block 2 - x = self.dec_lin_in2(encoded.permute(0, 2, 1)) - x = self.activation(x) - x = self.enc_dropout(x) - x = self.dec_lstm2(x)[0]#.reshape(-1, self.hidden_dim//2*self.sequence_length) - x = self.dec_lin_out2(x) - x = self.activation(x) - - # decoder block 1 - x = self.dec_lin_in(x.permute(0, 2, 1)) - x = self.activation(x) - x = self.enc_dropout(x) - x = self.dec_lstm(x)[0] # .reshape(-1, self.hidden_dim//2*self.sequence_length) - x = self.dec_lin_out(x) - x = self.activation(x) - return x - # return self.decoder(encoded)#.reshape(-1, self.sequence_length, self.input_dim) - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - - -class LSTMTransformerAutoencoder(Autoencoder): - def __init__(self, input_dim, output_dim, hidden_dim=256, num_layers=3, dropout=0.1, **kwargs): - super(LSTMTransformerAutoencoder, self).__init__(input_dim, output_dim, hidden_dim, num_layers, dropout) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.tanh = nn.Tanh() - - self.pe_enc = PositionalEncoder(batch_first=True, d_model=input_dim) - self.linear_enc_in = nn.Linear(input_dim, input_dim) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=5, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) - self.enc_lstm = nn.LSTM(input_dim, output_dim, num_layers=1, dropout=dropout, batch_first=True) - self.linear_enc_out = nn.Linear(output_dim, output_dim) - - self.pe_dec = PositionalEncoder(batch_first=True, d_model=output_dim) - self.linear_dec_in = nn.Linear(output_dim, output_dim) - self.decoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=5, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=num_layers) - self.dec_lstm = nn.LSTM(output_dim, input_dim, num_layers=1, dropout=dropout, batch_first=True) - self.linear_dec_out = nn.Linear(input_dim, input_dim) - - - def forward(self, data): - x = self.encode(data.to(self.device)) - x = self.decode(x) - return x - - def encode(self, data): - x = self.pe_enc(data) - x = self.linear_enc_in(x) - x = self.encoder(x) - x = self.enc_lstm(x)[0] - x = self.linear_enc_out(x) - x = self.tanh(x) - return x - - def decode(self, encoded): - x = self.pe_dec(encoded) - x = self.linear_dec_in(x) - x = self.decoder(encoded) - x = self.dec_lstm(x)[0] - x = self.linear_dec_out(x) - x = self.tanh(x) - return x + x = self.activation_decoder(x) - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - - -def train_model(model, data, optimizer, criterion, batch_size=32): - model.train() - total_loss = 0 - for batch in data: - optimizer.zero_grad() - # inputs = nn.BatchNorm1d(batch.shape[-1])(batch.float().permute(0, 2, 1)).permute(0, 2, 1) - # inputs = filter(inputs.detach().cpu().numpy(), win_len=random.randint(29, 50), dtype=torch.Tensor) - inputs = batch.float() - outputs = model(inputs.to(model.device)) - loss = criterion(outputs, inputs) - loss.backward() - optimizer.step() - total_loss += loss.item() - return total_loss / len(data) - - -def test_model(model, dataloader, criterion): - model.eval() - total_loss = 0 - with torch.no_grad(): - for batch in dataloader: - # inputs = 
nn.BatchNorm1d(batch.shape[-1])(batch.float().permute(0, 2, 1)).permute(0, 2, 1) - inputs = batch.float() - outputs = model(inputs.to(model.device)) - loss = criterion(outputs, inputs) - total_loss += loss.item() - return total_loss / len(dataloader) - - -def train(num_epochs, model, train_data, test_data, optimizer, criterion, configuration: Optional[dict] = None): - try: - train_losses = [] - test_losses = [] - for epoch in range(num_epochs): - train_loss = train_model(model, train_data, optimizer, criterion) - test_loss = test_model(model, test_data, criterion) - train_losses.append(train_loss) - test_losses.append(test_loss) - print(f"Epoch {epoch + 1}/{num_epochs}: train_loss={train_loss:.8f}, test_loss={test_loss:.8f}") - return train_losses, test_losses, model - except KeyboardInterrupt: - # save model at KeyboardInterrupt - print("keyboard interrupt detected.") - if configuration is not None: - print("Configuration found.") - configuration["model"]["state_dict"] = model.state_dict() # update model's state dict - save(configuration, configuration["general"]["default_save_path"]) - - -def save(configuration, path): - torch.save(configuration, path) - print("Saved model and configuration to " + path) + return x \ No newline at end of file diff --git a/nn_architecture/losses.py b/nn_architecture/losses.py index 527b235..3adb298 100644 --- a/nn_architecture/losses.py +++ b/nn_architecture/losses.py @@ -72,10 +72,6 @@ def discriminator(self, *args): def _gradient_penalty(self, discriminator: torch.nn.Module, real_images: torch.Tensor, fake_images: torch.Tensor): """Calculates the gradient penalty for WGAN-GP""" - - # adjust dimensions of real_labels, fake_labels and eta to to match the dimensions of real_images - # if real_labels.shape != fake_labels.shape: - # raise ValueError("real_labels and fake_labels must have the same shape!") if real_images.shape != fake_images.shape: raise ValueError("real_images and fake_images must have the same shape!") @@ -101,13 +97,9 @@ def _gradient_penalty(self, discriminator: torch.nn.Module, real_images: torch.T interpolated = (eta * real_images.detach() + ((1 - eta) * fake_images.detach())) interpolated.requires_grad = True - # deprecated - define it to calculate gradient - # interpolated = autograd.Variable(interpolated, requires_grad=True) - # calculate probability of interpolated examples prob_interpolated = discriminator(interpolated) - # fake = autograd.Variable(torch.ones((real_images.shape[0], 1)).to(real_images.device), requires_grad=False) fake = torch.ones((real_images.shape[0], 1), requires_grad=False).to(real_images.device) # calculate gradients of probabilities with respect to examples @@ -117,47 +109,4 @@ def _gradient_penalty(self, discriminator: torch.nn.Module, real_images: torch.T create_graph=True, retain_graph=True)[0] grad_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.gradient_penalty_weight - return grad_penalty - - # TODO: Check why this gradient penalty is not working - # def _gradient_penalty(self, discriminator, real_samples, fake_samples): - # """Calculates the gradient penalty for WGAN-GP""" - - # batch_size = real_samples.size(0) - # device = real_samples.device - - # # Generate random epsilon - # # TODO: this is old version - # # epsilon = torch.rand(batch_size, 1, 1, device=device, requires_grad=True) - # # epsilon = epsilon.expand_as(real_samples) - # # TODO: this is new version - replace old version with this - # # epsilon = torch.rand(*real_samples.shape, device=device, requires_grad=True) - # # TODO: 
this is the original pip version; delete if supposed to - # epsilon = torch.FloatTensor(real_samples.shape[0], 1).uniform_(0, 1).repeat((1, real_samples.shape[-1])).to(real_samples.device) - # while epsilon.dim() < real_samples.dim(): - # epsilon = epsilon.unsqueeze(1) - - # # Interpolate between real and fake samples - # interpolated_samples = epsilon * real_samples + (1 - epsilon) * fake_samples - # interpolated_samples = torch.autograd.Variable(interpolated_samples, requires_grad=True) - - # # Calculate critic scores for interpolated samples - # critic_scores = discriminator(interpolated_samples) - - # # TODO: check out if fake works with new version (commented out) or only with old one - # # fake = torch.ones(critic_scores.size(), device=device) - # fake = autograd.Variable(torch.ones((real_samples.shape[0], 1)).to(real_samples.device), requires_grad=False) - # while fake.dim() < critic_scores.dim(): - # fake = fake.unsqueeze(-1) - - # # Compute gradients of critic scores with respect to interpolated samples - # gradients = torch.autograd.grad(outputs=critic_scores, - # inputs=interpolated_samples, - # grad_outputs=fake, - # create_graph=True, - # retain_graph=True)[0] - - # # Calculate gradient penalty - # gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * self.gradient_penalty_weight - - # return gradient_penalty + return grad_penalty \ No newline at end of file diff --git a/nn_architecture/models.py b/nn_architecture/models.py index fcfd340..6dd5cf9 100644 --- a/nn_architecture/models.py +++ b/nn_architecture/models.py @@ -1,8 +1,4 @@ -import math -import warnings - -import torch -from torch import nn, Tensor +from torch import nn from nn_architecture.ae_networks import Autoencoder from nn_architecture.tts_gan_components import Generator as TTSGenerator_Org, Discriminator as TTSDiscriminator_Org @@ -24,328 +20,6 @@ def forward(self, z): raise NotImplementedError -class FFGenerator(Generator): - def __init__(self, latent_dim, channels, seq_len, hidden_dim=256, num_layers=4, dropout=.1, activation='tanh', **kwargs): - """ - :param latent_dim: latent dimension - :param channels: output dimension - :param hidden_dim: hidden dimension - :param num_layers: number of layers - :param dropout: dropout rate - - """ - - super(Generator, self).__init__() - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.latent_dim = latent_dim - self.hidden_dim = hidden_dim - self.channels = channels - self.seq_len = seq_len - self.num_layers = num_layers - self.dropout = dropout - if activation == 'relu': - self.act_out = nn.ReLU() - elif activation == 'sigmoid': - self.act_out = nn.Sigmoid() - elif activation == 'tanh': - self.act_out = nn.Tanh() - elif activation == 'leakyrelu': - self.act_out = nn.LeakyReLU() - elif activation == 'linear': - self.act_out = nn.Identity() - else: - self.act_out = nn.Identity() - warnings.warn( - f"Activation function of type '{activation}' was not recognized. 
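The gradient penalty retained in losses.py above is the standard WGAN-GP term: score random interpolates of real and fake samples, then penalize critic gradients whose norm deviates from 1. A compact, self-contained version of the same idea (the critic and tensor shapes here are illustrative):

import torch

def gradient_penalty(critic, real, fake, weight=10.0):
    # random interpolation coefficient, broadcast over all non-batch dims
    eta = torch.rand(real.size(0), *([1] * (real.dim() - 1)), device=real.device)
    interpolated = (eta * real.detach() + (1 - eta) * fake.detach()).requires_grad_(True)
    scores = critic(interpolated)
    gradients = torch.autograd.grad(outputs=scores, inputs=interpolated,
                                    grad_outputs=torch.ones_like(scores),
                                    create_graph=True, retain_graph=True)[0]
    return weight * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

critic = torch.nn.Linear(16, 1)
print(gradient_penalty(critic, torch.randn(8, 16), torch.randn(8, 16)))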
Activation function was set to 'linear'.") - - modulelist = nn.ModuleList() - modulelist.append(nn.Linear(latent_dim, hidden_dim)) - modulelist.append(nn.LeakyReLU(0.1)) - modulelist.append(nn.Dropout(dropout)) - for _ in range(num_layers): - modulelist.append(nn.Linear(hidden_dim, hidden_dim)) - modulelist.append(nn.LeakyReLU(0.1)) - modulelist.append(nn.Dropout(dropout)) - modulelist.append(nn.Linear(hidden_dim, channels * seq_len)) - modulelist.append(self.act_out) - - self.block = nn.Sequential(*modulelist) - - def forward(self, z): - return self.block(z).reshape(-1, self.seq_len, self.channels) - - -class FFDiscriminator(Discriminator): - def __init__(self, channels, seq_len, hidden_dim=256, num_layers=4, dropout=.1, **kwargs): - """ - :param channels: input dimension - :param hidden_dim: hidden dimension - :param num_layers: number of layers - :param dropout: dropout rate - """ - super(Discriminator, self).__init__() - - self.channels = channels - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.dropout = dropout - self.seq_len = seq_len - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - modulelist = nn.ModuleList() - modulelist.append(nn.Linear(channels * seq_len, hidden_dim)) - modulelist.append(nn.LeakyReLU(0.1)) - modulelist.append(nn.Dropout(dropout)) - for _ in range(num_layers): - modulelist.append(nn.Linear(hidden_dim, hidden_dim)) - modulelist.append(nn.LeakyReLU(0.1)) - modulelist.append(nn.Dropout(dropout)) - modulelist.append(nn.Linear(hidden_dim, 1)) - - self.block = nn.Sequential(*modulelist) - - def forward(self, data): - if data.dim() == 4: - # this is probably a tts-format -> transform - data = data.squeeze(2).permute(0, 2, 1) - - return self.block(data.reshape(-1, self.seq_len * self.channels)) - - -class AutoencoderGenerator(FFGenerator): - """Autoencoder generator""" - - def __init__(self, latent_dim, autoencoder: Autoencoder, **kwargs): - """ - :param autoencoder: Autoencoder model; Decoder takes in array and decodes into multidimensional array of shape (batch, sequence_length, channels) - """ - self.output_dim_1 = autoencoder.output_dim if autoencoder.target in [autoencoder.TARGET_CHANNELS, autoencoder.TARGET_BOTH] else autoencoder.output_dim_2 - self.output_dim_2 = autoencoder.output_dim_2 if autoencoder.target in [autoencoder.TARGET_CHANNELS, autoencoder.TARGET_BOTH] else autoencoder.output_dim - super(AutoencoderGenerator, self).__init__(latent_dim, self.output_dim_1*self.output_dim_2, **kwargs) - self.autoencoder = autoencoder - self.decode = True - - def forward(self, z): - """ - :param z: input array of shape (batch, latent_dim) - :return: output array of shape (batch, sequence_length, channels) - """ - x = super(AutoencoderGenerator, self).forward(z) - if self.decode: - x = self.autoencoder.decode(x.reshape(-1, self.output_dim_2, self.channels // self.output_dim_2)) - return x - - def decode_output(self, mode=True): - self.decode = mode - - -class AutoencoderDiscriminator(FFDiscriminator): - """Autoencoder discriminator""" - - def __init__(self, channels, autoencoder: Autoencoder, **kwargs): - """ - :param autoencoder: Autoencoder model; Encoder takes in multidimensional array of shape (batch, sequence_length, channels) and encodes into array - """ - n_channels = autoencoder.input_dim if autoencoder.target in [autoencoder.TARGET_CHANNELS, autoencoder.TARGET_BOTH] else autoencoder.output_dim_2 - channels = channels - n_channels + autoencoder.output_dim * autoencoder.output_dim_2 - 
super(AutoencoderDiscriminator, self).__init__(channels, **kwargs) - self.autoencoder = autoencoder - self.encode = True - - def forward(self, z): - """ - :param z: input array of shape (batch, sequence_length, channels + conditions) - :return: output array of shape (batch, 1) - """ - if self.encode: - x = self.autoencoder.encode(z[:, :, :self.autoencoder.input_dim]) - # flatten x - x = x.reshape(-1, 1, x.shape[-2]*x.shape[-1]) - conditions = z[:, 0, self.autoencoder.input_dim:] - if conditions.dim() < x.dim(): - conditions = conditions.unsqueeze(1) - x = self.block(torch.concat((x, conditions), dim=-1)) - else: - x = self.block(z) - return x - - def encode_input(self, mode=True): - self.encode = mode - - -class PositionalEncoder(nn.Module): - """ - The authors of the original transformer paper describe very succinctly what - the positional encoding layer does and why it is needed: - - "Since our model contains no recurrence and no convolution, in order for the - model to make use of the order of the sequence, we must inject some - information about the relative or absolute position of the tokens in the - sequence." (Vaswani et al, 2017) - Adapted from: - https://pytorch.org/tutorials/beginner/transformer_tutorial.html - """ - - def __init__( - self, - dropout: float = 0.1, - max_seq_len: int = 100, - d_model: int = 512, - batch_first: bool = True - ): - """ - Parameters: - dropout: the dropout rate - max_seq_len: the maximum length of the input sequences - d_model: The dimension of the output of sub-layers in the model - (Vaswani et al, 2017) - """ - - super().__init__() - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - self.d_model = d_model - - self.dropout = nn.Dropout(p=dropout) - - self.batch_first = batch_first - - self.x_dim = 1 if batch_first else 0 - - # copy pasted from PyTorch tutorial - position = torch.arange(max_seq_len).unsqueeze(1) - # print(f"shape of position is {position.shape}") - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - # print(f"shape of div_term is {div_term.shape}") - pe = torch.zeros(1, max_seq_len, d_model) - - pe[0, :, 0::2] = torch.sin(position * div_term) - - pe[0, :, 1::2] = torch.cos(position * div_term) - - self.register_buffer('pe', pe) - - def forward(self, x: Tensor) -> Tensor: - """ - Args: - x: Tensor, shape [batch_size, enc_seq_len, dim_val] or - [enc_seq_len, batch_size, dim_val] - """ - x = x + self.pe[0, :x.size(self.x_dim)] - - return self.dropout(x) - - -class TransformerGenerator(Generator): - def __init__(self, latent_dim, channels, seq_len, hidden_dim=8, num_layers=2, num_heads=4, dropout=.1, **kwargs): - super(TransformerGenerator, self).__init__() - - self.latent_dim = latent_dim - self.channels = channels - self.seq_len = seq_len - self.hidden_dim = hidden_dim - self.num_heads = num_heads - self.num_layers = num_layers - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # self.pe = PositionalEncoder(batch_first=True, d_model=latent_dim) - self.linear_enc_in = nn.Linear(latent_dim, hidden_dim*seq_len) - # self.linear_enc_in = nn.LSTM(latent_dim, hidden_dim, batch_first=True, dropout=dropout, num_layers=2) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, - nhead=num_heads, - dim_feedforward=hidden_dim, - dropout=dropout, - batch_first=True, - ) - self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) - self.linear_enc_out = nn.Linear(hidden_dim, channels) - self.act_out = nn.Tanh() - - # 
self.deconv = nn.Sequential( - # nn.Conv2d(self.embed_dim, self.channels, 1, 1, 0) - # ) - - # TODO: Put in autoencoder - # encoder needs as input dim n_channels - # decoder needs as output dim n_channels - # self.linear_enc_in and self.pe need as input dim embedding_dim of the autoencoder - - # self.encoder = encoder if encoder is not None else nn.Identity() - # for param in self.encoder.parameters(): - # param.requires_grad = False - # self.decoder = decoder if decoder is not None else nn.Identity() - # for param in self.decoder.parameters(): - # param.requires_grad = False - - def forward(self, data): - # x = self.pe(data) - x = self.linear_enc_in(data).reshape(-1, self.seq_len, self.hidden_dim) # [0] --> only for lstm - x = self.encoder(x) - x = self.act_out(self.linear_enc_out(x)) - # x = x.reshape(x.shape[0], 1, x.shape[1], x.shape[2]) - # output = self.deconv(x.permute(0, 3, 1, 2)) - # output = output.view(-1, self.channels, H, W) - # x = self.mask(x, data[:, :, self.latent_dim - self.channels:].diff(dim=1)) - # x = self.tanh(x) - # x = self.decoder(x) - return x - - def mask(self, data, data_ref, mask=0): - # mask predictions if ALL preceding values (axis=sequence) were 'mask' - # return indices to mask - mask_index = (data_ref.sum(dim=1) == mask).unsqueeze(1).repeat(1, data.shape[1], 1) - data[mask_index] = mask - return data - - -class TransformerDiscriminator(Discriminator): - def __init__(self, channels, seq_len, n_classes=1, hidden_dim=8, num_layers=2, num_heads=4, dropout=.1, **kwargs): - super(TransformerDiscriminator, self).__init__() - - self.hidden_dim = hidden_dim - self.channels = channels - self.n_classes = n_classes - self.num_heads = num_heads - self.num_layers = num_layers - self.seq_len = seq_len - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # self.pe = PositionalEncoder(batch_first=True, d_model=channels) - self.linear_enc_in = nn.Linear(channels, hidden_dim) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, - nhead=num_heads, - dim_feedforward=hidden_dim, - dropout=dropout, - batch_first=True, - ) - self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) - self.linear_enc_out = nn.Linear(hidden_dim*seq_len, n_classes) - self.tanh = nn.Tanh() - - # self.decoder = decoder if decoder is not None else nn.Identity() - # for param in self.decoder.parameters(): - # param.requires_grad = False - - def forward(self, data): - if data.dim() == 4: - # this is probably a tts-format -> transform - data = data.squeeze(2).permute(0, 2, 1) - - # x = self.pe(data) - x = self.linear_enc_in(data) - x = self.encoder(x).reshape(-1, self.seq_len*self.hidden_dim) - x = self.linear_enc_out(x) # .reshape(-1, self.channels) - # x = self.mask(x, data[:,:,self.latent_dim-self.channels:].diff(dim=1)) - # x = self.tanh(x) - # x = self.decoder(x) - return x - - class TTSGenerator(TTSGenerator_Org): def __init__(self, seq_len=150, patch_size=15, channels=3, num_classes=9, latent_dim=100, embed_dim=10, depth=3, num_heads=5, forward_drop_rate=0.5, attn_drop_rate=0.5): @@ -357,6 +31,7 @@ class TTSDiscriminator(TTSDiscriminator_Org): def __init__(self, in_channels=3, patch_size=15, emb_size=50, seq_length=150, depth=3, n_classes=1, **kwargs): super(TTSDiscriminator, self).__init__(in_channels, patch_size, emb_size, seq_length, depth, n_classes, **kwargs) + class DecoderGenerator(Generator): """ DecoderGenerator serves as a wrapper for a generator. 
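The hunk below simplifies DecoderGenerator.forward() by dropping the padding-trim, so the wrapper now assumes the generator's output length already matches what the decoder expects. For orientation, a minimal usage sketch of the wrapper pattern — `gen`, `ae`, and `latent_dim` are placeholders for any Generator from this module, a trained Autoencoder, and its latent size, none of which are defined in this diff:

    import torch
    from nn_architecture.models import DecoderGenerator

    dec_gen = DecoderGenerator(generator=gen, decoder=ae)

    dec_gen.decode_output(False)        # train against the critic in latent space
    z = torch.randn(16, latent_dim)     # assumed (batch, latent_dim) input
    latent_fake = dec_gen(z)            # raw generator output, not decoded

    dec_gen.decode_output(True)         # decode to the full time series for sampling
    full_fake = dec_gen(z)              # equivalent to ae.decode(gen(z))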
@@ -381,15 +56,14 @@ def __init__(self, generator: Generator, decoder: Autoencoder): def forward(self, data): if self.decode: - data_input = self.generator(data) - data_input = data_input[:,:-self.padding,:] if self.padding > 0 else data_input - return self.decoder.decode(data_input) + return self.decoder.decode(self.generator(data)) else: return self.generator(data) def decode_output(self, mode=True): self.decode = mode + class EncoderDiscriminator(Discriminator): """ EncoderDiscriminator serves as a wrapper for a discriminator. @@ -418,314 +92,4 @@ def forward(self, data): return self.discriminator(data) def encode_input(self, mode=True): - self.encode = mode - - -''' -# ---------------------------------------------------------------------------------------------------------------------- -# Autoencoders -# ---------------------------------------------------------------------------------------------------------------------- - -class TransformerAutoencoder(nn.Module): - def __init__(self, input_dim, output_dim, hidden_dim=256, num_layers=3, dropout=0.1, **kwargs): - super(TransformerAutoencoder, self).__init__() - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.input_dim = input_dim - self.output_dim = output_dim - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.dropout = dropout - - #self.pe_enc = PositionalEncoder(batch_first=True, d_model=input_dim) - self.linear_enc_in = nn.Linear(input_dim, input_dim) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=2, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) - self.linear_enc_out = nn.Linear(input_dim, output_dim) - - #self.pe_dec = PositionalEncoder(batch_first=True, d_model=output_dim) - self.linear_dec_in = nn.Linear(output_dim, output_dim) - self.decoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=2, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=num_layers) - self.linear_dec_out = nn.Linear(output_dim, input_dim) - - self.tanh = nn.Tanh() - - def forward(self, data): - x = self.encode(data.to(self.device)) - x = self.decode(x) - return x - - def encode(self, data): - #x = self.pe_enc(data) - #x = self.linear_enc_in(x) - x = self.linear_enc_in(data) - x = self.encoder(x) - x = self.linear_enc_out(x) - x = self.tanh(x) - return x - - def decode(self, encoded): - #x = self.pe_dec(encoded) - #x = self.linear_dec_in(x) - x = self.linear_dec_in(encoded) - x = self.decoder(x) - x = self.linear_dec_out(x) - x = self.tanh(x) - return x - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - -class Autoencoder(nn.Module): - def __init__(self, input_dim, output_dim, hidden_dim, num_layers=3, dropout=0.1, **kwargs): - super(Autoencoder, self).__init__() - self.input_dim = input_dim - self.output_dim = output_dim - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.dropout = dropout - self.activation = nn.Sigmoid() - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - # encoder block of linear layers constructed in a loop and passed to a sequential container - encoder_block = [] - encoder_block.append(nn.Linear(input_dim, hidden_dim)) - encoder_block.append(nn.Dropout(dropout)) - encoder_block.append(self.activation) - for i 
in range(num_layers): - encoder_block.append(nn.Linear(hidden_dim, hidden_dim)) - encoder_block.append(nn.Dropout(dropout)) - encoder_block.append(self.activation) - encoder_block.append(nn.Linear(hidden_dim, output_dim)) - encoder_block.append(self.activation) - self.encoder = nn.Sequential(*encoder_block) - - # decoder block of linear layers constructed in a loop and passed to a sequential container - decoder_block = [] - decoder_block.append(nn.Linear(output_dim, hidden_dim)) - decoder_block.append(nn.Dropout(dropout)) - decoder_block.append(self.activation) - for i in range(num_layers): - decoder_block.append(nn.Linear(hidden_dim, hidden_dim)) - decoder_block.append(nn.Dropout(dropout)) - decoder_block.append(self.activation) - decoder_block.append(nn.Linear(hidden_dim, input_dim)) - decoder_block.append(self.activation) - self.decoder = nn.Sequential(*decoder_block) - - def forward(self, x): - encoded = self.encoder(x.to(self.device)) - decoded = self.decoder(encoded) - return decoded - - def encode(self, data): - return self.encoder(data.to(self.device)) - - def decode(self, encoded): - return self.decoder(encoded) - -class TransformerFlattenAutoencoder(Autoencoder): - def __init__(self, input_dim, output_dim, sequence_length, hidden_dim=1024, num_layers=3, dropout=0.1, **kwargs): - super(TransformerFlattenAutoencoder, self).__init__(input_dim, output_dim, hidden_dim, num_layers, dropout) - - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.sequence_length = sequence_length - - #self.pe_enc = PositionalEncoder(batch_first=True, d_model=input_dim) - self.linear_enc_in = nn.Linear(input_dim, input_dim) - encoder_layer = nn.TransformerEncoderLayer(d_model=1, nhead=1, dropout=dropout, batch_first=True) - self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers) - self.linear_enc_out_1 = nn.Linear(sequence_length*input_dim, hidden_dim) - self.linear_enc_out_2 = nn.Linear(hidden_dim, output_dim) - - #self.pe_dec = PositionalEncoder(batch_first=True, d_model=output_dim) - self.linear_dec_in = nn.Linear(output_dim, output_dim) - decoder_layer = nn.TransformerEncoderLayer(d_model=1, nhead=1, dropout=dropout, batch_first=True) - self.decoder = nn.TransformerEncoder(decoder_layer, num_layers=num_layers) - self.linear_dec_out_1 = nn.Linear(output_dim, hidden_dim) - self.linear_dec_out_2 = nn.Linear(hidden_dim, input_dim*sequence_length) - - self.tanh = nn.Sigmoid() - - def forward(self, data): - x = self.encode(data.to(self.device)) - x = self.decode(x) - return x - - def encode(self, data): - #x = self.pe_enc(data) - #x = self.linear_enc_in(x).reshape(data.shape[0], self.sequence_length*self.input_dim, 1) - x = self.linear_enc_in(data).reshape(data.shape[0], self.sequence_length*self.input_dim, 1) - x = self.encoder(x) - x = self.linear_enc_out_1(x.permute(0, 2, 1)) - x = self.linear_enc_out_2(x) - x = self.tanh(x) - return x - - def decode(self, encoded): - #x = self.pe_dec(encoded) - #x = self.linear_dec_in(x) - x = self.linear_dec_in(encoded) - x = self.decoder(x.permute(0, 2, 1)) - x = self.linear_dec_out_1(x.permute(0, 2, 1)) - x = self.linear_dec_out_2(x).reshape(encoded.shape[0], self.sequence_length, self.input_dim) - x = self.tanh(x) - return x - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - -class TransformerDoubleAutoencoder(nn.Module): - def __init__(self, input_dim, output_dim, sequence_length, output_dim_2, 
hidden_dim=256, num_layers=3, dropout=0.1, **kwargs): - super(TransformerDoubleAutoencoder, self).__init__() - - # parameters - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.input_dim = input_dim - self.output_dim = output_dim - self.hidden_dim = hidden_dim - self.num_layers = num_layers - self.dropout = dropout - - # encoder block features - #self.pe_enc = PositionalEncoder(batch_first=True, d_model=input_dim) - self.linear_enc_in = nn.Linear(input_dim, input_dim) - self.encoder_layer = nn.TransformerEncoderLayer(d_model=input_dim, nhead=2, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=num_layers) - self.linear_enc_out = nn.Linear(input_dim, output_dim) - - # encoder block sequence - #self.pe_enc_seq = PositionalEncoder(batch_first=True, d_model=sequence_length) - self.linear_enc_in_seq = nn.Linear(sequence_length, sequence_length) - self.encoder_layer_seq = nn.TransformerEncoderLayer(d_model=sequence_length, nhead=2, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.encoder_seq = nn.TransformerEncoder(self.encoder_layer_seq, num_layers=num_layers) - self.linear_enc_out_seq = nn.Linear(sequence_length, output_dim_2) - - # decoder block sequence - #self.pe_dec_seq = PositionalEncoder(batch_first=True, d_model=output_dim_2) - self.linear_dec_in_seq = nn.Linear(output_dim_2, output_dim_2) - self.decoder_layer_seq = nn.TransformerEncoderLayer(d_model=output_dim_2, nhead=2, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.decoder_seq = nn.TransformerEncoder(self.decoder_layer_seq, num_layers=num_layers) - self.linear_dec_out_seq = nn.Linear(output_dim_2, sequence_length) - - # decoder block features - #self.pe_dec = PositionalEncoder(batch_first=True, d_model=output_dim) - self.linear_dec_in = nn.Linear(output_dim, output_dim) - self.decoder_layer = nn.TransformerEncoderLayer(d_model=output_dim, nhead=2, dim_feedforward=hidden_dim, dropout=dropout, batch_first=True) - self.decoder = nn.TransformerEncoder(self.decoder_layer, num_layers=num_layers) - self.linear_dec_out = nn.Linear(output_dim, input_dim) - - self.tanh = nn.Tanh() - - def forward(self, data): - x = self.encode(data.to(self.device)) - x = self.decode(x) - return x - - def encode(self, data): - # encoder features - #x = self.pe_enc(data) - #x = self.linear_enc_in(x) - x = self.linear_enc_in(data) - x = self.encoder(x) - x = self.linear_enc_out(x) - x = self.tanh(x) - - # encoder sequence - #x = self.pe_enc_seq(x.permute(0, 2, 1)) - #x = self.linear_enc_in_seq(x) - x = self.linear_enc_in_seq(x.permute(0, 2, 1)) - x = self.encoder_seq(x) - x = self.linear_enc_out_seq(x) - x = self.tanh(x) - return x.permute(0, 2, 1) - - def decode(self, encoded): - # decoder sequence - #x = self.pe_dec_seq(encoded.permute(0, 2, 1)) - #x = self.linear_dec_in_seq(x) - x = self.linear_dec_in_seq(encoded.permute(0, 2, 1)) - x = self.decoder_seq(x) - x = self.linear_dec_out_seq(x) - x = self.tanh(x) - - # decoder features - #x = self.pe_dec(x.permute(0, 2, 1)) - #x = self.linear_dec_in(x) - x = self.linear_dec_in(x.permute(0, 2, 1)) - x = self.decoder(x) - x = self.linear_dec_out(x) - x = self.tanh(x) - return x - - def save(self, path): - path = '../trained_ae' - file = f'ae_{pd.Timestamp.now().strftime("%Y%m%d_%H%M%S")}.pth' - # torch.save(save, os.path.join(path, file)) - -def train_model(model, dataloader, optimizer, criterion): - model.train() #Sets it into training mode - total_loss = 0 - 
for batch in dataloader: - optimizer.zero_grad() - inputs = batch.float() - outputs = model(inputs) - loss = criterion(outputs, inputs) - loss.backward() - optimizer.step() - total_loss += loss.item() - return total_loss / len(dataloader) - -def test_model(model, dataloader, criterion): - model.eval() - total_loss = 0 - with torch.no_grad(): - batch = dataloader.dataset[np.random.randint(0, len(dataloader), dataloader.batch_size)] - inputs = batch.float() - outputs = model(inputs) - loss = criterion(outputs, inputs) - total_loss += loss.item() - return total_loss / len(dataloader) - -def train(num_epochs, model, train_dataloader, test_dataloader, optimizer, criterion, configuration: Optional[dict] = None): - try: - train_losses = [] - test_losses = [] - trigger = True - for epoch in range(num_epochs): - train_loss = train_model(model, train_dataloader, optimizer, criterion) - test_loss = test_model(model, test_dataloader, criterion) - train_losses.append(train_loss) - test_losses.append(test_loss) - model.config['trained_epochs'][-1] += 1 - print(f"Epoch {epoch + 1}/{num_epochs} (Model Total: {str(sum(model.config['trained_epochs']))}): train_loss={train_loss:.6f}, test_loss={test_loss:.6f}") - trigger = save_checkpoint(model, epoch, trigger, 100) - return train_losses, test_losses, model - except KeyboardInterrupt: - print("keyboard interrupt detected.") - return train_losses, test_losses, model - -def save_checkpoint(model, epoch, trigger, criterion = 100): - if (epoch+1) % criterion == 0: - model_dict = dict(state_dict = model.state_dict(), config = model.config) - - # toggle between checkpoint files to avoid corrupted file during training - if trigger: - save(model_dict, 'checkpoint_01.pth', verbose=False) - trigger = False - else: - save(model_dict, 'checkpoint_02.pth', verbose=False) - trigger = True - - return trigger - -def save(model, file, path = 'trained_ae', verbose = True): - torch.save(model, os.path.join(path, file)) - if verbose: - print("Saved model and configuration to " + os.path.join(path, file)) -''' \ No newline at end of file + self.encode = mode \ No newline at end of file diff --git a/nn_architecture/vae_networks.py b/nn_architecture/vae_networks.py new file mode 100644 index 0000000..16b270b --- /dev/null +++ b/nn_architecture/vae_networks.py @@ -0,0 +1,179 @@ +import numpy as np + +import torch +import torch.nn.functional as F +from torch import nn + +import matplotlib.pyplot as plt + +class VariationalAutoencoder(nn.Module): + + def __init__(self, input_dim, hidden_dim=256, encoded_dim=25, activation='tanh', device=None, **kwargs): + super().__init__() + + #Variables + self.input_dim = input_dim + self.hidden_dim = hidden_dim + self.encoded_dim = encoded_dim + self.num_electrodes = None + if device: + self.device = device + else: + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + #Activation: Although we have options, tanh and linear are the only ones that works effectively + if activation == 'relu': + self.activation = nn.ReLU() + elif activation == 'sigmoid': + self.activation = nn.Sigmoid() + elif activation == 'tanh': + self.activation = nn.Tanh() + elif activation == 'leakyrelu': + self.activation = nn.LeakyReLU() + elif activation == 'linear': + self.activation = nn.Identity() + else: + raise ValueError(f"Activation function of type '{activation}' was not recognized.") + + #Encoder + encoding_layers = [ + nn.Linear(input_dim, hidden_dim), + self.activation + ] + self._encode = nn.Sequential(*encoding_layers) + + #Distributions 
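+ # The two linear heads below parameterize the approximate posterior q(z|x):
+ # mu_refactor outputs the mean and sigma_refactor the log-variance.
+ # sample() turns the log-variance into a standard deviation via
+ # exp(0.5 * sigma) and applies the reparameterization trick z = mu + std * eps.
+ # Note: sample() draws eps with torch.randn on the default device, which is
+ # safe while vae_training_main.py pins training to the CPU.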
+ self.mu_refactor = nn.Sequential( + nn.Linear(hidden_dim, encoded_dim), + ) + + self.sigma_refactor = nn.Sequential( + nn.Linear(hidden_dim, encoded_dim), + ) + + #Decoder + decoding_layers = [ + nn.Linear(encoded_dim, hidden_dim), + self.activation, + nn.Linear(hidden_dim, input_dim), + nn.Sigmoid() + ] + self._decode = nn.Sequential(*decoding_layers) + + def encode(self, x): + + x = torch.flatten(x, start_dim=1) + x = self._encode(x) + mu = self.mu_refactor(x) + sigma = self.sigma_refactor(x) + + return mu, sigma + + def sample(self, mu, sigma): + std = torch.exp(0.5 * sigma) + z = torch.randn(std.size(0),std.size(1)) + z = z*std + mu + + return z + + def decode(self, x): + + x = self._decode(x) + x = x.reshape((x.shape[0], int(self.input_dim/self.num_electrodes), self.num_electrodes)) + + return x + + def forward(self, x): + if self.num_electrodes == None: + self.num_electrodes = x.shape[-1] + + mu, sigma = self.encode(x) + z_reparametrized = self.sample(mu, sigma) + x_reconstructed = self.decode(z_reparametrized) + + return x_reconstructed, mu, sigma + + def generate_samples(self, loader, condition=0, num_samples=2500): + + if not type(condition) == list: + condition = [condition] + + if not condition: + raise NotImplementedError('You must specify a condition to generate samples with the VAE') + else: + condition = condition[0] + + self.num_electrodes = next(iter(loader)).shape[-1] + + with torch.no_grad(): + generated_samples = np.empty((0,int(self.input_dim/self.num_electrodes)+1,self.num_electrodes)) + while generated_samples.shape[0] < num_samples: + for i, x in enumerate(loader): + y = x[:,[0],:].to(self.device) + x = x[:,1:,:].to(self.device) + mu, sigma = self.encode(x) + z = mu + sigma * torch.randn_like(sigma) + sample_decoded = self.decode(z) + gen_sample = torch.concat((y, sample_decoded), dim=1) + gen_sample = gen_sample[gen_sample[:,0,0]==condition,:,:] + + generated_samples = np.vstack((generated_samples, gen_sample.detach().numpy())) + + return generated_samples[:num_samples,:] + + def plot_samples(self, loader, epoch): + + empirical_samples = np.empty((0,int(self.input_dim/self.num_electrodes)+1,self.num_electrodes)) + for i, x in enumerate(loader): + empirical_samples = np.vstack((empirical_samples, x.detach().numpy())) + + syn0 = self.generate_samples(loader, condition=0, num_samples=2500)[:,1:,:] + syn1 = self.generate_samples(loader, condition=1, num_samples=2500)[:,1:,:] + emp0 = empirical_samples[empirical_samples[:,0,0]==0,1:,:] + emp1 = empirical_samples[empirical_samples[:,0,0]==1,1:,:] + + if self.num_electrodes == 1: + fig, ax = plt.subplots(1,2) + ax[1].plot(np.mean(syn0, axis=0), alpha=.5) + ax[1].plot(np.mean(syn1,axis=0), alpha=.5) + ax[1].set_title('VAE-Generated') + ax[0].plot(np.mean(emp0,axis=0), alpha=.5) + ax[0].plot(np.mean(emp1,axis=0), alpha=.5) + ax[0].set_title('Empirical') + else: + fig, ax = plt.subplots(2,self.num_electrodes) + for electrode_index in range(self.num_electrodes): + ax[1, electrode_index].plot(np.mean(syn0[:,:,electrode_index], axis=0), alpha=.5) + ax[1, electrode_index].plot(np.mean(syn1[:,:,electrode_index],axis=0), alpha=.5) + + ax[0, electrode_index].plot(np.mean(emp0[:,:,electrode_index],axis=0), alpha=.5) + ax[0, electrode_index].plot(np.mean(emp1[:,:,electrode_index],axis=0), alpha=.5) + + ax[1,0].set_title('VAE-Generated') + ax[0,0].set_title('Empirical') + + plt.savefig(f'generated_images/generated_average_ep{epoch}.png') + plt.close() + + for _ in range(200): + c0_sample = syn0[np.random.randint(0,len(syn0)),:] + 
c1_sample = syn1[np.random.randint(0,len(syn1)),:] + plt.plot(c0_sample, alpha=.1, label='c0', color='C0') + plt.plot(c1_sample, alpha=.1, label='c1', color='C1') + plt.savefig(f'generated_images/generated_trials_ep{epoch}.png') + plt.close() + + def plot_losses(self, recon_losses, kl_losses, losses): + + fig, ax = plt.subplots(3) + ax[0].plot(recon_losses) + ax[0].set_title('Reconstruction Losses') + + ax[1].plot(kl_losses) + ax[1].set_title('KL Losses') + + ax[2].plot(losses) + ax[2].set_title('Losses') + + plt.savefig(f'generated_images/vae_loss.png') + plt.close() \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 60ba4d7..ab403a9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,11 @@ +torch~=1.12.1 +torchvision~=0.13.1 +torchaudio~=0.12.1 +torchsummary~=1.5.1 pandas~=1.3.4 numpy~=1.21.4 matplotlib~=3.5.0 scipy~=1.8.0 -torch~=1.12.1 -torchvision~=0.13.1 einops~=0.4.1 -torchsummary~=1.5.1 -torchaudio~=0.12.1 -scikit-learn~=1.1.2 \ No newline at end of file +scikit-learn~=1.1.2 +tqdm~=4.66.1 \ No newline at end of file diff --git a/testing_suite/ae_visualize_main.py b/testing_suite/ae_visualize_main.py index decfdb8..d6ce858 100644 --- a/testing_suite/ae_visualize_main.py +++ b/testing_suite/ae_visualize_main.py @@ -12,7 +12,7 @@ ae_checkpoint = 'trained_ae/ae_ddp_5000ep_p100_e8_enc50-4.pt' #### Load data #### -dataloader = Dataloader(data_checkpoint, col_label='Condition', channel_label='Electrode') +dataloader = Dataloader(data_checkpoint, kw_conditions='Condition', kw_channel='Electrode') dataset = dataloader.get_data().detach().numpy() norm = lambda data: (data-np.min(data)) / (np.max(data) - np.min(data)) dataset = np.concatenate((dataset[:,[0],:], norm(dataset[:,1:,:])), axis=1) diff --git a/tests/test_ae_training.py b/tests/test_ae_training.py index 5caa4ca..96e04e7 100644 --- a/tests/test_ae_training.py +++ b/tests/test_ae_training.py @@ -1,28 +1,29 @@ import sys import traceback +import os + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))) from autoencoder_training_main import main if __name__ == '__main__': configurations = { # configurations for normal GAN - # 'basic': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "channels_out=2", "timeseries_out=10"], - 'target_time': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "target=time", "timeseries_out=10"], - 'target_channels': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "target=channels", "channels_out=1"], - 'target_full': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "target=full", "timeseries_out=10", "channels_out=1"], - # 'load_checkpoint': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "load_checkpoint"], - # 'load_checkpoint_specific_file': ["path_dataset=../data/gansMultiCondition.csv", "load_checkpoint", "path_checkpoint=trained_ae/checkpoint.pt"], + # 'target_time': ["data=data/gansMultiCondition_SHORT.csv", "target=time", "time_out=10", "save_name=ae_target_time.pt"], + # 'target_channels': ["data=data/gansMultiCondition_SHORT.csv", "target=channels", "channels_out=1", "save_name=ae_target_channels.pt"], + 'target_full': ["data=data/gansMultiCondition.csv", "target=full", "time_out=10", "channels_out=1", "save_name=ae_target_full.pt"], + # 'load_checkpoint': ["data=data/gansMultiCondition.csv", "checkpoint=x"], } # general parameters n_epochs = 1 batch_size = 32 - channel_label = "Electrode" + kw_channel = "Electrode" sample_interval = 1 for key in configurations.keys(): try: 
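# Each configuration above is a list of CLI-style "key=value" tokens; the loop
# joins them with the shared parameters and hands them to main() via sys.argv,
# so each test run mirrors invoking the script from the shell.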
print(f"Running configuration {key}...") - sys.argv = configurations[key] + [f"n_epochs={n_epochs}", f"batch_size={batch_size}", f"channel_label={channel_label}", f"sample_interval={sample_interval}"] + sys.argv = configurations[key] + [f"n_epochs={n_epochs}", f"batch_size={batch_size}", f"kw_channel={kw_channel}", f"sample_interval={sample_interval}"] main() print(f"\nConfiguration {key} finished successfully.\n\n") # if an error occurs, print key and full error message with traceback and exit diff --git a/tests/test_gan_training.py b/tests/test_gan_training.py index b19704a..02497d0 100644 --- a/tests/test_gan_training.py +++ b/tests/test_gan_training.py @@ -1,62 +1,50 @@ import sys import traceback +import os +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))) from gan_training_main import main if __name__ == '__main__': configurations = { # configurations for normal GAN - 'basic': ["path_dataset=./data/gansMultiCondition_SHORT.csv"], - '1condition': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "conditions=Condition"], - # '2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "conditions=Trial,Condition"], - # '2channels': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "channel_label=Electrode"], - # '2channels_1condition': ["sample_interval=1", "path_dataset=./data/gansMultiCondition_SHORT.csv", "channel_label=Electrode", "conditions=Condition"], - # '2channels_2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "channel_label=Electrode", "conditions=Trial,Condition"], - # 'prediction': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=70"], - # 'prediction_1condition': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "conditions=Condition"], - # 'prediction_2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "conditions=Trial,Condition"], - # 'prediction_2channels': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "channel_label=Electrode"], - # 'prediction_2channels_1condition': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "channel_label=Electrode", "conditions=Condition"], - # 'prediction_2channels_2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "channel_label=Electrode", "conditions=Trial,Condition"], - # 'seq2seq': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1"], - # 'seq2seq_1condition': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "conditions=Condition"], - # 'seq2seq_2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "conditions=Trial,Condition"], - # 'seq2seq_2channels': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "channel_label=Electrode"], - # 'seq2seq_2channels_1condition': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Condition"], - # 'seq2seq_2channels_2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Trial,Condition"], - + 'basic': ["data=data/gansMultiCondition_SHORT.csv", "save_name=gan_basic.pt"], + '1condition': ["data=data/gansMultiCondition_SHORT.csv", 
"kw_conditions=Condition", "save_name=gan_1cond.pt"], + 'load_checkpoint': ["data=data/gansMultiCondition_SHORT.csv", "checkpoint=x", "kw_conditions=Condition"], + '2conditions': ["data=data/gansMultiCondition_SHORT.csv", "kw_conditions=Trial,Condition", "save_name=gan_2cond.pt"], + '2channels': ["data=data/gansMultiCondition_SHORT.csv", "kw_channel=Electrode", "save_name=gan_2ch.pt"], + '2channels_1condition': ["sample_interval=1", "data=data/gansMultiCondition_SHORT.csv", "kw_channel=Electrode", "kw_conditions=Condition", "save_name=gan_2ch_1cond.pt"], + '2channels_2conditions': ["data=data/gansMultiCondition_SHORT.csv", "kw_channel=Electrode", "kw_conditions=Trial,Condition", "save_name=gan_2ch_2cond.pt"], + # configurations for autoencoder GAN - # 'autoencoder_basic': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "channel_label=Electrode"], - # 'autoencoder_1condition': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "channel_label=Electrode", "conditions=Condition"], - # 'autoencoder_2conditions': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "channel_label=Electrode", "conditions=Trial,Condition", "hidden_dim=64", "activation=leakyrelu", "num_layers=1",], - # 'autoencoder_2conditions_channels': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_channels.pt", "channel_label=Electrode", "conditions=Trial,Condition", "hidden_dim=64", "activation=leakyrelu", "num_layers=1",], - # 'autoencoder_2conditions_time': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_time.pt", "channel_label=Electrode", "conditions=Trial,Condition", "hidden_dim=64", "activation=leakyrelu", "num_layers=1",], - # 'autoencoder_2conditions_full': ["path_dataset=./data/gansMultiCondition_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "channel_label=Electrode", "conditions=Trial,Condition", "hidden_dim=64", "activation=leakyrelu", "num_layers=1",], - # 'autoencoder_prediction': ["path_dataset=./data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "input_sequence_length=70", "channel_label=Electrode"], - # 'autoencoder_prediction_1condition': ["path_dataset=./data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "input_sequence_length=70", "channel_label=Electrode", "conditions=Condition"], - # 'autoencoder_prediction_2conditions': ["path_dataset=./data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "input_sequence_length=70", "channel_label=Electrode", "conditions=Trial,Condition"], - # 'autoencoder_seq2seq': ["path_dataset=./data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "input_sequence_length=-1", "channel_label=Electrode"], - # 'autoencoder_seq2seq_1condition': ["path_dataset=./data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Condition"], - # 'autoencoder_seq2seq_2conditions': 
["path_dataset=./data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "path_autoencoder=./trained_ae/ae_gansMultiCondition_SHORT_full.pt", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Trial,Condition"], + 'autoencoder_basic': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "save_name=gan_ae.pt"], + 'autoencoder_1condition': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "kw_conditions=Condition", "save_name=gan_ae_1cond.pt"], + 'autoencoder_2conditions': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "kw_conditions=Trial,Condition", "save_name=gan_ae_2cond.pt"], + 'autoencoder_2channels': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "save_name=gan_ae_2ch.pt"], + 'autoencoder_2channels_1conditions': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "kw_conditions=Condition", "save_name=gan_ae_2ch_1cond.pt"], + 'autoencoder_2channels_2conditions': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "kw_conditions=Trial,Condition", "save_name=gan_ae_2ch_2cond.pt"], + + # 'autoencoder_2conditions_channels': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT_channels.pt", "kw_channel=Electrode", "kw_conditions=Trial,Condition", "hidden_dim=64", "num_layers=1",], + # 'autoencoder_2conditions_time': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT_time.pt", "kw_channel=Electrode", "kw_conditions=Trial,Condition", "hidden_dim=64", "num_layers=1",], + # 'autoencoder_2conditions_full': ["data=data/gansMultiCondition_SHORT.csv", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_channel=Electrode", "kw_conditions=Trial,Condition", "hidden_dim=64", "num_layers=1",], + # 'load_checkpoint': ["data=data/gansMultiCondition_SHORT.csv", "checkpoint=x", "autoencoder=trained_ae/ae_gansMultiCondition_SHORT.pt", "kw_conditions=Condition", "kw_channel=Electrode"], } # general parameters n_epochs = 1 batch_size = 32 - gan_type = ['tts','ff','tr'] patch_size = 10 - for gan in gan_type: - key = None - try: - for key in configurations.keys(): - print(f"Running configuration {key}...") - sys.argv = configurations[key] + [f"n_epochs={n_epochs}", f"batch_size={batch_size}", f"type={gan}", f"patch_size={patch_size}"] - generator, discriminator, opt, gen_samples = main() - print(f"\nConfiguration {key} finished successfully.\n\n") - # if an error occurs, print key and full error message with traceback and exit - except: - print(f"Configuration {key} failed.") - traceback.print_exc() - exit(1) + key = None + try: + for key in configurations.keys(): + print(f"Running configuration {key}...") + sys.argv = configurations[key] + [f"n_epochs={n_epochs}", f"batch_size={batch_size}", f"patch_size={patch_size}"] + generator, discriminator, opt, gen_samples = main() + print(f"\nConfiguration {key} finished successfully.\n\n") + # if an error occurs, print key and full error message with traceback and exit + except: + print(f"Configuration {key} failed.") + traceback.print_exc() + exit(1) diff --git a/tests/test_generate_samples.py b/tests/test_generate_samples.py 
index 661c42d..8affd2a 100644 --- a/tests/test_generate_samples.py +++ b/tests/test_generate_samples.py @@ -1,41 +1,27 @@ +import os import sys import traceback + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))) from generate_samples_main import main if __name__ == '__main__': configurations = { # configurations for normal GAN - # 'basic': ["file=..\\trained_models\\gan_1ep_basic.pt", "path_samples=..\\generated_samples\\gan_1ep_basic.csv"], - # '1condition': ["file=..\\trained_models\\gan_1ep_1cond.pt", "path_samples=..\\generated_samples\\gan_1ep_1cond.csv", "conditions=0"], - # '2conditions': ["file=..\\trained_models\\gan_1ep_2cond.pt", "path_samples=..\\generated_samples\\gan_1ep_2cond.csv", "conditions=0,1"], - # '2channel': ["file=..\\trained_models\\gan_1ep_2chan.pt", "path_samples=..\\generated_samples\\gan_1ep_2chan.csv"], - # '2channel_1condition': ["file=..\\trained_models\\gan_1ep_2chan_1cond.pt", "path_samples=..\\generated_samples\\gan_1ep_2chan_1cond.csv", "conditions=0"], - - # '2channel_2conditions': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "channel_label=Electrode", "conditions=Trial,Condition"], - # 'prediction': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70"], - # 'prediction_1condition': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "conditions=Condition"], - # 'prediction_2conditions': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "conditions=Trial,Condition"], - # 'prediction_2channel': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "channel_label=Electrode"], - # 'prediction_2channel_1condition': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "channel_label=Electrode", "conditions=Condition"], - # 'prediction_2channel_2conditions': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "patch_size=20", "input_sequence_length=70", "channel_label=Electrode", "conditions=Trial,Condition"], - # 'seq2seq': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1"], - # 'seq2seq_1condition': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "conditions=Condition"], - # 'seq2seq_2conditions': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "conditions=Trial,Condition"], - # 'seq2seq_2channel': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "channel_label=Electrode"], - # 'seq2seq_2channel_1condition': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Condition"], - # 'seq2seq_2channel_2conditions': ["path_dataset=../data/gansMultiCondition_SHORT.csv", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Trial,Condition"], - + 'basic': ["model=trained_models/gan_basic.pt"], + '1condition': ["model=trained_models/gan_1cond.pt", "conditions=0"], + '2conditions': ["model=trained_models/gan_2cond.pt", "conditions=0,1"], + '2channel': ["model=trained_models/gan_2ch.pt", "save_name=generated_samples/gan_2ch.csv"], + '2channel_1condition': ["model=trained_models/gan_2ch_1cond.pt", "conditions=0"], + '2channel_2conditions': ["model=trained_models/gan_2ch_2cond.pt", "conditions=0,1"], + # configurations for autoencoder GAN - # 'autoencoder_basic': 
["path_file=..\\trained_models\\checkpoint.pt", "path_samples=..\\generated_samples\\test.csv"], - 'autoencoder_1condition': ["path_file=..\\trained_models\\gan_ddp_2000ep_20230913_191233.pt", "path_samples=..\\generated_samples\\gan_1ep_ae_1cond.csv", "conditions=0"], - # 'autoencoder_2conditions': ["path_file=..\\trained_models\\checkpoint.pt", "path_samples=..\\generated_samples\\gan_1ep_ae_2cond.csv", "conditions=0,1"], - - # 'autoencoder_prediction': ["path_dataset=../data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "gan_type=autoencoder", "input_sequence_length=70", "channel_label=Electrode"], - # 'autoencoder_prediction_1condition': ["path_dataset=../data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "gan_type=autoencoder", "input_sequence_length=70", "channel_label=Electrode", "conditions=Condition"], - # 'autoencoder_prediction_2conditions': ["path_dataset=../data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "gan_type=autoencoder", "input_sequence_length=70", "channel_label=Electrode", "conditions=Trial,Condition"], - # 'autoencoder_seq2seq': ["path_dataset=../data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "gan_type=autoencoder", "input_sequence_length=-1", "channel_label=Electrode"], - # 'autoencoder_seq2seq_1condition': ["path_dataset=../data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "gan_type=autoencoder", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Condition"], - # 'autoencoder_seq2seq_2conditions': ["path_dataset=../data/ganTrialElectrodeERP_p50_e8_len100_SHORT.csv", "gan_type=autoencoder", "input_sequence_length=-1", "channel_label=Electrode", "conditions=Trial,Condition"], + 'ae_basic': ["model=trained_models/gan_ae.pt"], + 'ae_1condition': ["model=trained_models/gan_ae_1cond.pt", "conditions=0"], + 'ae_2conditions': ["model=trained_models/gan_ae_2cond.pt", "conditions=0,1"], + 'ae_2channel': ["model=trained_models/gan_ae_2ch.pt"], + 'ae_2channel_1condition': ["model=trained_models/gan_ae_2ch_1cond.pt", "conditions=0"], + 'ae_2channel_2conditions': ["model=trained_models/gan_ae_2ch_2cond.pt", "conditions=0,1"], } key = None diff --git a/tests/test_multi-electrode.py b/tests/test_multi-electrode.py index 6b6ed6e..83c875f 100644 --- a/tests/test_multi-electrode.py +++ b/tests/test_multi-electrode.py @@ -30,8 +30,8 @@ def generate_fake_data(n_channels=1, label_channels=False, data_path=None): def run_test_reshaping_data(data_path, n_channels, ): dataloader = Dataloader(data_path, - kw_timestep='Time', - col_label=['Condition', 'Electrode'], + kw_time='Time', + kw_conditions=['Condition', 'Electrode'], n_channels=n_channels) dataset = dataloader.get_data(sequence_length=-1) # channel is in right dimension diff --git a/tests/test_visualization.py b/tests/test_visualization.py new file mode 100644 index 0000000..98a00f4 --- /dev/null +++ b/tests/test_visualization.py @@ -0,0 +1,56 @@ +import sys +import os +import traceback + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir))) +from visualize_main import main + +if __name__ == '__main__': + configurations = { + # configuration for real data + 'dataset': ["data=data/testMultiConditionMultiChannel.csv"], + 'dataset_2condition': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2"], + 'dataset_2condition_2channels': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "kw_channel=Channel"], + 'dataset_2condition_channelplot': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", 
"channel_plots"], + 'dataset_2condition_2channels_channelindex': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "kw_channel=Channel", "channel_index=0"], + 'dataset_2condition_2channels_channelplot': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "kw_channel=Channel", "channel_plots"], + 'dataset_avg': ["data=data/testMultiConditionMultiChannel.csv", "average"], + 'dataset_2condition_avg': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "average"], + 'dataset_1condition_2channels_avg': ["data=data/testMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "kw_channel=Channel", "average"], + 'dataset_pca': ["data=data/testMultiConditionMultiChannel.csv", "pca", "comp_data=data/testCompMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "kw_channel=Channel"], + 'dataset_tsne': ["data=data/testMultiConditionMultiChannel.csv", "tsne", "comp_data=data/testCompMultiConditionMultiChannel.csv", "kw_conditions=Cond1,Cond2", "kw_channel=Channel"], + + # configurations for synthetic data + 'synt': ["data=generated_samples/gan_basic.csv"], + 'synt_ae': ["data=generated_samples/gan_ae_2ch.csv", "kw_channel=Electrode"], + 'synt_spectogram': ["data=generated_samples/gan_ae_2ch.csv", "kw_channel=Electrode", "spectogram"], + 'synt_fft': ["data=generated_samples/gan_ae_2ch.csv", "kw_channel=Electrode", "fft"], + + # configurations for normal GAN + 'basic': ["model=trained_models/gan_basic.pt"], + 'basic_2condition': ["model=trained_models/gan_2cond.pt"], + 'basic_2condition_2channels': ["model=trained_models/gan_2ch_2cond.pt"], + 'basic_2condition_2_channels_channelplot': ["model=trained_models/gan_2ch_2cond.pt", "channel_plots"], + + # configurations for autoencoder GAN + 'gan_ae_basic': ["model=trained_models/gan_ae.pt"], + + # configuration for autoencoder + 'ae_basic': ["model=trained_ae/ae_target_full.pt"], + 'ae_pca': ["model=trained_ae/ae_target_full.pt", "pca"], + } + + n_samples = 4 + + key = None + try: + for key in configurations.keys(): + print(f"Running configuration {key}...") + sys.argv = configurations[key] + [f"n_samples={n_samples}"] + main() + print(f"\nConfiguration {key} finished successfully.\n\n") + # if an error occurs, print key and full error message with traceback and exit + except: + print(f"Configuration {key} failed.") + traceback.print_exc() + exit(1) diff --git a/vae_training_main.py b/vae_training_main.py new file mode 100644 index 0000000..7d4e16c --- /dev/null +++ b/vae_training_main.py @@ -0,0 +1,144 @@ +import os +import sys +import multiprocessing as mp +from datetime import datetime + +import numpy as np +import torch +from torch.utils.data import DataLoader + +from helpers import system_inputs +from helpers.dataloader import Dataloader +from helpers.trainer import VAETrainer +from helpers.get_master import find_free_port +from helpers.ddp_training import run#, VAEDDPTrainer +from nn_architecture.vae_networks import VariationalAutoencoder + +def main(): + # ------------------------------------------------------------------------------------------------------------------ + # Configure training parameters + # ------------------------------------------------------------------------------------------------------------------ + + default_args = system_inputs.parse_arguments(sys.argv, file='vae_training_main.py') + print('-----------------------------------------\n') + + if default_args['load_checkpoint']: + print(f'Resuming training from checkpoint 
{default_args["path_checkpoint"]}.') + + #User input + opt = { + 'data': default_args['data'], + 'path_checkpoint': default_args['path_checkpoint'], + 'save_name': default_args['save_name'], + 'sample_interval': default_args['sample_interval'], + 'kw_channel': default_args['kw_channel'], + 'kw_conditions': default_args['kw_conditions'], + 'n_epochs': default_args['n_epochs'], + 'batch_size': default_args['batch_size'], + 'learning_rate': default_args['learning_rate'], + 'hidden_dim': default_args['hidden_dim'], + 'encoded_dim': default_args['encoded_dim'], + 'activation': default_args['activation'], + 'kl_alpha': default_args['kl_alpha'], + 'norm_data': True, + 'std_data': False, + 'diff_data': False, + 'kw_time': default_args['kw_time'], + 'world_size': torch.cuda.device_count() if torch.cuda.is_available() else mp.cpu_count(), + 'history': None, + 'trained_epochs': 0 + } + + #opt['device'] = torch.device("cuda" if torch.cuda.is_available() and opt['ddp'] else "cpu") + opt['device'] = torch.device("cpu") + + # raise warning if no normalization and standardization is used at the same time + if opt['std_data'] and opt['norm_data']: + raise Warning("Standardization and normalization are used at the same time.") + + # ---------------------------------------------------------------------------------------------------------------------- + # Load, process, and split data + # ---------------------------------------------------------------------------------------------------------------------- + data = Dataloader(path=opt['data'], + kw_channel=opt['kw_channel'], + kw_conditions=opt['kw_conditions'], + kw_time=opt['kw_time'], + norm_data=opt['norm_data'], + std_data=opt['std_data'], + diff_data=opt['diff_data']) + dataset = data.get_data() + + opt['input_dim'] = (dataset.shape[1] - len(opt['kw_conditions'])) * dataset.shape[-1] + + # ------------------------------------------------------------------------------------------------------------------ + # Load VAE checkpoint and populate configuration + # ------------------------------------------------------------------------------------------------------------------ + + # Load VAE + model_dict = None + if default_args['load_checkpoint'] and os.path.isfile(opt['path_checkpoint']): + model_dict = torch.load(opt['path_checkpoint']) + elif default_args['load_checkpoint'] and not os.path.isfile(opt['path_checkpoint']): + raise FileNotFoundError(f"Checkpoint file {opt['path_checkpoint']} not found.") + + # Populate model configuration + history = {} + for key in opt.keys(): + if (not key == 'history') | (not key == 'trained_epochs'): + history[key] = [opt[key]] + history['trained_epochs'] = [] + + if model_dict is not None: + # update history + for key in history.keys(): + history[key] = model_dict['configuration']['history'][key] + history[key] + opt['history'] = history + + # ------------------------------------------------------------------------------------------------------------------ + # Initiate VAE + # ------------------------------------------------------------------------------------------------------------------ + + model = VariationalAutoencoder(input_dim=opt['input_dim'], + hidden_dim=opt['hidden_dim'], + encoded_dim=opt['encoded_dim'], + activation=opt['activation'], + device=opt['device']).to(opt['device']) + + print('Variational autoencoder initialized') + + # ------------------------------------------------------------------------------------------------------------------ + # Train VAE + # 
------------------------------------------------------------------------------------------------------------------ + + # VAE-Training + print('\n-----------------------------------------') + print("Training VAE...") + print('-----------------------------------------\n') + + trainer = VAETrainer(model, opt) + if default_args['load_checkpoint']: + trainer.load_checkpoint(default_args['path_checkpoint']) + dataset = DataLoader(dataset, batch_size=trainer.batch_size, shuffle=True) + gen_samples = trainer.training(dataset) + + # save final models, optimizer states, generated samples, losses and configuration as final result + if not opt['save_name']: + path = 'trained_vae' + timestamp = datetime.now().strftime('%Y%m%d_%H%M%S') + filename = f'vae_{trainer.epochs}ep_' + timestamp + '.pt' + save_filename = os.path.join(path, filename) + else: + save_filename = opt['save_name'] + trainer.save_checkpoint(path_checkpoint=save_filename, samples=gen_samples, update_history=True) + + print(f"Checkpoint saved to {default_args['path_checkpoint']}.") + + model = trainer.model + + print("VAE training finished.") + print(f"Model states and generated samples saved to file {save_filename}.") + + return model, opt, gen_samples + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/visualize_main.py b/visualize_main.py index 196d769..d0d59dc 100644 --- a/visualize_main.py +++ b/visualize_main.py @@ -18,45 +18,39 @@ def main(): print('\n-----------------------------------------') print("System output:") print('-----------------------------------------\n') - - if default_args['csv'] + default_args['checkpoint'] != 1: # + default_args['experiment'] - raise ValueError("Please specify only one of the following arguments: csv, checkpoint") + + if default_args['data'] != '' and default_args['model'] != '': + raise ValueError("Please specify only one of the following arguments: data, model") if default_args['channel_index'][0] > -1 and (default_args['pca'] or default_args['tsne']): print("Warning: channel_index is set to a specific value, but PCA or t-SNE is enabled.\n" "PCA and t-SNE are only available for all channels. Ignoring channel_index.") # throw error if checkpoint but csv-file is specified - if default_args['checkpoint'] and default_args['path_dataset'].split('.')[-1] == 'csv': - raise ValueError("Inconsistent parameter specification. 'checkpoint' was specified but a csv-file was given.") + if default_args['model'] != '' and not default_args['model'].endswith('.pt'): + raise ValueError("Inconsistent parameter specification. 'model' was specified but no model-file (.pt) was given.") + if default_args['data'] != '' and not default_args['data'].endswith('.csv'): + raise ValueError("Inconsistent parameter specification. 'data' was specified but no csv-file was given.") # throw warning if checkpoint and conditions are given - if default_args['checkpoint'] and default_args['conditions'][0] != '': - warnings.warn("Conditions are given, but checkpoint is specified. Given conditions will be ignored and taken from the checkpoint file if the checkpoint file contains the conditions parameter.") + if default_args['model'] != '' and default_args['kw_conditions'][0] != '': + warnings.warn("Conditions are given, but model is specified. 
@@ -142,7 +136,7 @@ def main():
     # -----------------------------
 
     try:
-        if default_args['loss'] and not default_args['checkpoint']:
+        if default_args['loss'] and default_args['model'] == '':
             raise ValueError("Loss plotting only available for checkpoint and not csv")
         elif default_args['loss']:
             print("Plotting losses...")
@@ -200,8 +194,7 @@ def main():
                     axs[jcol].plot(averaged_data[i, :, j])
                 else:
                     axs[i, jcol].plot(averaged_data[i, :, j])
-            # axs[i].set_title(f'condition {cond}')
-            # set legend at the right hand side of the plot;
+            # set legend at the right hand side of the plot;
             # legend carries the condition information
             # make graph and legend visible within the figure
             if not default_args['channel_plots']:
@@ -221,16 +214,16 @@ def main():
     # -----------------------------
 
     if default_args['pca'] or default_args['tsne']:
-        if original_data is None and default_args['path_comp_dataset'] != '':
+        if original_data is None and default_args['comp_data'] != '':
             # load comparison data
-            dataloader_comp = Dataloader(path=default_args['path_comp_dataset'],
+            dataloader_comp = Dataloader(path=default_args['comp_data'],
                                          norm_data=True,
-                                         kw_timestep=default_args['kw_timestep'],
-                                         col_label=default_args['conditions'],
-                                         channel_label=default_args['channel_label'], )
+                                         kw_time=default_args['kw_time'],
+                                         kw_conditions=default_args['kw_conditions'],
+                                         kw_channel=default_args['kw_channel'], )
             original_data = dataloader_comp.get_data(shuffle=False)[:, n_conditions:].numpy()
-        elif original_data is None and default_args['path_comp_dataset'] == '':
-            raise ValueError("No comparison data found for PCA or t-SNE. Please specify a comparison dataset with the argument 'path_comp_dataset'.")
+        elif original_data is None and default_args['comp_data'] == '':
+            raise ValueError("No comparison data found for PCA or t-SNE. Please specify a comparison dataset with the argument 'comp_data'.")
 
         if default_args['pca']:
             print("Plotting PCA...")
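The PCA branch above projects the generated samples and the comparison dataset into a common low-dimensional space. As a rough sketch of that comparison, assuming scikit-learn and matplotlib rather than the repository's own plotting helpers, and with 'generated' and 'original' as placeholder arrays of shape (n_samples, sequence_length):

    import matplotlib.pyplot as plt
    from sklearn.decomposition import PCA

    def plot_pca_comparison(generated, original):
        """Fit PCA on the original data and project both sample sets into its space."""
        pca = PCA(n_components=2).fit(original.reshape(original.shape[0], -1))
        orig_2d = pca.transform(original.reshape(original.shape[0], -1))
        gen_2d = pca.transform(generated.reshape(generated.shape[0], -1))
        plt.scatter(orig_2d[:, 0], orig_2d[:, 1], alpha=.3, label='original')
        plt.scatter(gen_2d[:, 0], gen_2d[:, 1], alpha=.3, label='generated')
        plt.legend()
        plt.show()

If the two clouds overlap, the generator reproduces the dominant variance structure of the real data; fitting the projection on the original data only keeps the generated samples from defining the axes.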
@@ -268,23 +261,4 @@ def main():
 
 
 if __name__ == '__main__':
-    # sys.argv = [
-    #     # 'csv',
-    #     # 'path_dataset=generated_samples/gan_1ep_2chan_1cond.csv',
-    #     'checkpoint',
-    #     'path_dataset=trained_ae/ae_gansMultiCondition.pt',
-    #     # 'conditions=Condition',
-    #     'channel_label=Electrode',
-    #     'n_samples=8',
-    #     # 'channel_plots',
-    #     # 'channel_index=0',
-    #     'loss',
-    #     # 'average',
-    #     # 'spectogram',
-    #     # 'fft',
-    #     'pca',
-    #     'tsne',
-    #     # 'path_comp_dataset=data/gansMultiCondition_SHORT.csv',
-    #     # 'path_comp_dataset=data/gansMultiCondition.csv',
-    #     ]
     main()
\ No newline at end of file
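Both scripts exchange data through the .pt checkpoints written by the trainers; the diff above reads state_dict['configuration'] (with its nested 'history') and state_dict['samples']. A minimal inspection sketch, with a placeholder path and no assumptions about keys beyond those visible in the diff:

    import torch

    state_dict = torch.load('trained_vae/vae_500ep.pt', map_location='cpu')  # placeholder path
    print(state_dict['configuration']['history'].keys())  # per-option history of the training runs
    samples = state_dict['samples']  # generated samples saved at the end of training
    print(len(samples))  # np.concatenate(state_dict['samples']) in visualize_main.py implies a sequence of arrays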