From 63da5dcd397c0b132071b8af2d1bbdd50bc11a06 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Tue, 26 Mar 2024 15:57:07 -0600
Subject: [PATCH 01/30] adding parseargs to io

---
 src/scripts/io.py | 146 ++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 134 insertions(+), 12 deletions(-)

diff --git a/src/scripts/io.py b/src/scripts/io.py
index b216d3c..6aebc80 100644
--- a/src/scripts/io.py
+++ b/src/scripts/io.py
@@ -1,13 +1,71 @@
 # Contains modules used to prepare a dataset
 # with varying noise properties
-
+import argparse
 import numpy as np
+from sklearn.model_selection import train_test_split
 import pickle
 from torch.distributions import Uniform
+from torch.utils.data import DataLoader, TensorDataset
 import torch
 import h5py
 
 
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="data handling module"
+    )
+    parser.add_argument(
+        "size_df",
+        type=float,
+        required=False,
+        default=1000,
+        help="Used to load the associated .h5 data file",
+    )
+    parser.add_argument(
+        "noise_level",
+        type=str,
+        required=False,
+        default='low',
+        help="low, medium, high or vhigh, used to look up associated sigma value",
+    )
+    parser.add_argument(
+        "size_df",
+        type=str,
+        nargs="?",
+        default="/repo/embargo",
+        help="Butler Repository path from which data is transferred. \
+        Input str. Default = '/repo/embargo'",
+    )
+    parser.add_argument(
+        "--normalize",
+        required=False,
+        action="store_true",
+        help="If true theres an option to normalize the dataset",
+    )
+    parser.add_argument(
+        "--val_proportion",
+        type=float,
+        required=False,
+        default=0.1,
+        help="Proportion of the dataset to use as validation",
+    )
+    parser.add_argument(
+        "--randomseed",
+        type=float,
+        required=False,
+        default=42,
+        help="Random seed used for shuffling the training and validation set",
+    )
+    parser.add_argument(
+        "--batchsize",
+        type=float,
+        required=False,
+        default=100,
+        help="Size of batched used in the traindataloader",
+    )
+    return parser.parse_args()
+
+
 class ModelLoader:
     def save_model_pkl(self, path, model_name, posterior):
         """
@@ -208,17 +266,81 @@ def get_dict(self):
     def get_data(self):
         return self.data
 
+    def get_sigma(noise):
+        if noise == 'low':
+            sigma = 1
+        if noise == 'medium':
+            sigma = 5
+        if noise == 'high':
+            sigma = 10
+        if noise == 'vhigh':
+            sigma = 100
+        return sigma
+
+    def normalize(inputs,
+                  ys_array,
+                  norm=False):
+        if norm:
+            # normalize everything before it goes into a network
+            inputmin = np.min(inputs, axis=0)
+            inputmax = np.max(inputs, axis=0)
+            outputmin = np.min(ys_array)
+            outputmax = np.max(ys_array)
+            model_inputs = (inputs - inputmin) / (inputmax - inputmin)
+            model_outputs = (ys_array - outputmin) / (outputmax - outputmin)
+        else:
+            model_inputs = inputs
+            model_outputs = ys_array
+        return model_inputs, model_outputs
+
+    def train_val_split(model_inputs,
+                        model_outputs,
+                        val_proportion=0.1,
+                        random_state=42):
+        x_train, x_val, y_train, y_val = train_test_split(model_inputs,
+                                                          model_outputs,
+                                                          test_size=val_proportion,
+                                                          random_state=random_state)
+        return x_train, x_val, y_train, y_val
+
 
 # Example usage:
 if __name__ == "__main__":
-    # Replace 'your_dataset.csv' with your actual dataset file path
-    dataset_manager = DataPreparation("your_dataset.csv")
-    dataset_manager.load_data()
-    dataset_manager.preprocess_data()
-
-    # Simulate linear data
-    dataset_manager.simulate_data("linear")
-
-    # Access the simulated data
-    simulated_data = dataset_manager.get_data()
-    print(simulated_data.head())
+    namespace = parse_args()
+    size_df = namespace.size_df
+    noise = namespace.noise_level
+    norm = namespace.normalize
+    val_prop = namespace.val_proportion
+    rs = namespace.randomseed
+    BATCH_SIZE = namespace.batchsize
+    sigma = DataPreparation.get_sigma(noise)
+    loader = DataLoader()
+    data = loader.load_data_h5('linear_sigma_'+str(sigma)+'_size_'+str(size_df))
+    len_df = len(data['params'][:, 0].numpy())
+    len_x = len(data['inputs'].numpy())
+    ms_array = np.repeat(data['params'][:, 0].numpy(), len_x)
+    bs_array = np.repeat(data['params'][:, 1].numpy(), len_x)
+    xs_array = np.tile(data['inputs'].numpy(), len_df)
+    ys_array = np.reshape(data['output'].numpy(), (len_df * len_x))
+    inputs = np.array([xs_array, ms_array, bs_array]).T
+    model_inputs, model_outputs = DataPreparation.normalize(inputs,
+                                                            ys_array,
+                                                            norm)
+    x_train, x_val, y_train, y_val = DataPreparation.train_val_split(model_inputs,
+                                                                     model_outputs,
+                                                                     test_size=val_prop,
+                                                                     random_state=rs)
+    trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
+    trainDataLoader = DataLoader(trainData,
+                                 batch_size=BATCH_SIZE,
+                                 shuffle=True)
+    '''
+    valData = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val))
+    valDataLoader = DataLoader(valData,
+                               batch_size=BATCH_SIZE)
+
+    # calculate steps per epoch for training and validation set
+    trainSteps = len(trainDataLoader.dataset) // BATCH_SIZE
+    valSteps = len(valDataLoader.dataset) // BATCH_SIZE
+    '''
+    return trainDataLoader, x_val, y_val

From 1a8e07936570442382f3c9a8d62f76f2abaa9017 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Thu, 28 Mar 2024 11:04:52 -0600
Subject: [PATCH 02/30] added all argparse to run dataloader and train

---
 src/scripts/DeepEnsemble.py | 236 +++++++++++++++++++++++++++++-------
 1 file changed, 189 insertions(+), 47 deletions(-)

diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py
index 99d6417..e25cf77 100644
--- a/src/scripts/DeepEnsemble.py
+++ b/src/scripts/DeepEnsemble.py
@@ -1,95 +1,237 @@
 import argparse
 import logging
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, TensorDataset
+from scripts import train, models, analysis, io
+
+
+def beta_type(value):
+    if isinstance(value, float):
+        return value
+    elif value.lower() == 'linear_decrease':
+        return value
+    elif value.lower() == 'step_decrease_to_0.5':
+        return value
+    elif value.lower() == 'step_decrease_to_1.0':
+        return value
+    else:
+        raise argparse.ArgumentTypeError("BETA must be a float or one of 'linear_decrease', 'step_decrease_to_0.5', 'step_decrease_to_1.0'")
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="Transferring data from embargo butler to another butler"
+        description="data handling module"
+    )
+    parser.add_argument(
+        "--size_df",
+        type=float,
+        required=False,
+        default=1000,
+        help="Used to load the associated .h5 data file",
     )
-
-    # at least one arg in dataId needed for 'where' clause.
     parser.add_argument(
-        "fromrepo",
+        "noise_level",
         type=str,
+        default='low',
+        help="low, medium, high or vhigh, used to look up associated sigma value",
+    )
+    '''
+    parser.add_argument(
+        "size_df",
+        type=str,
         nargs="?",
         default="/repo/embargo",
         help="Butler Repository path from which data is transferred. \
         Input str. Default = '/repo/embargo'",
     )
+    '''
     parser.add_argument(
-        "torepo",
-        type=str,
-        help="Repository to which data is transferred. Input str",
+        "--normalize",
+        required=False,
+        action="store_true",
+        help="If true theres an option to normalize the dataset",
     )
     parser.add_argument(
-        "instrument",
-        type=str,
-        nargs="?",
-        default="LATISS",
-        help="Instrument. Input str",
Input str", + "--val_proportion", + type=float, + required=False, + default=0.1, + help="Proportion of the dataset to use as validation", ) parser.add_argument( - "--embargohours", + "--randomseed", type=float, required=False, - default=80.0, - help="Embargo time period in hours. Input float", + default=42, + help="Random seed used for shuffling the training and validation set", ) parser.add_argument( - "--datasettype", + "--batchsize", + type=float, required=False, - nargs="+", - # default=[] - help="Dataset type. Input list or str", + default=100, + help="Size of batched used in the traindataloader", + ) + # now args for model + parser.add_argument( + "n_models", + type=float, + default=100, + help="Number of MVEs in the ensemble", ) parser.add_argument( - "--collections", - # type=str, - nargs="+", + "--init_lr", + type=float, required=False, - default="LATISS/raw/all", - help="Data Collections. Input list or str", + default=0.001, + help="Learning rate", ) parser.add_argument( - "--nowtime", + "--loss_type", type=str, required=False, - default="now", - help="Now time in (ISO, TAI timescale). If left blank it will \ - use astropy.time.Time.now.", + default="bnn_loss", + help="Loss types for MVE, options are no_var_loss, var_loss, and bnn_loss", ) parser.add_argument( - "--move", - type=str, + "--BETA", + type=beta_type, required=False, - default="False", - help="Copies if False, deletes original if True", + default=0.5, + help="If loss_type is bnn_loss, specify a beta as a float or there are string options: linear_decrease, step_decrease_to_0.5, and step_decrease_to_1.0", ) parser.add_argument( - "--log", + "--model_type", type=str, required=False, - default="False", - help="No logging if False, longlog if True", + default="DE", + help="Beginning of name for saved checkpoints and figures", + ) + parser.add_argument( + "--n_epochs", + type=float, + required=False, + default=100, + help="number of epochs for each MVE", ) parser.add_argument( - "--desturiprefix", + "--path_to_models", type=str, required=False, - default="False", - help="Define dest uri if you need to run ingest for raws", + default="models/", + help="path to where the checkpoints are saved", + ) + parser.add_argument( + "--save_all_checkpoints", + type=bool, + required=False, + default=False, + help="option to save all checkpoints", + ) + parser.add_argument( + "--save_final_checkpoints", + type=bool, + required=False, + default=False, + help="option to save the final epoch checkpoint for each ensemble", + ) + parser.add_argument( + "--overwrite_final_checkpoints", + type=bool, + required=False, + default=False, + help="option to overwite already saved checkpoints", + ) + parser.add_argument( + "--plot", + type=bool, + required=False, + default=True, + help="option to plot in notebook", + ) + parser.add_argument( + "--savefig", + type=bool, + required=False, + default=True, + help="option to save a figure of the true and predicted values", + ) + parser.add_argument( + "--verbose", + type=bool, + required=False, + default=False, + help="verbose option for train", ) return parser.parse_args() if __name__ == "__main__": namespace = parse_args() - # Define embargo and destination butler - # If move is true, then you'll need write - # permissions from the fromrepo (embargo) - dest_butler = namespace.torepo - if namespace.log == "True": - # CliLog.initLog(longlog=True) - logger = logging.getLogger("lsst.transfer.embargo") - logger.info("from path: %s", namespace.fromrepo) - logger.info("to path: %s", namespace.torepo) + size_df 
+    size_df = namespace.size_df
+    noise = namespace.noise_level
+    norm = namespace.normalize
+    val_prop = namespace.val_proportion
+    rs = namespace.randomseed
+    BATCH_SIZE = namespace.batchsize
+    sigma = io.DataPreparation.get_sigma(noise)
+    loader = io.DataLoader()
+    data = loader.load_data_h5('linear_sigma_'+str(sigma)+'_size_'+str(size_df))
+    len_df = len(data['params'][:, 0].numpy())
+    len_x = len(data['inputs'].numpy())
+    ms_array = np.repeat(data['params'][:, 0].numpy(), len_x)
+    bs_array = np.repeat(data['params'][:, 1].numpy(), len_x)
+    xs_array = np.tile(data['inputs'].numpy(), len_df)
+    ys_array = np.reshape(data['output'].numpy(), (len_df * len_x))
+    inputs = np.array([xs_array, ms_array, bs_array]).T
+    model_inputs, model_outputs = io.DataPreparation.normalize(inputs,
+                                                               ys_array,
+                                                               norm)
+    x_train, x_val, y_train, y_val = io.DataPreparation.train_val_split(model_inputs,
+                                                                        model_outputs,
+                                                                        test_size=val_prop,
+                                                                        random_state=rs)
+    trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
+    trainDataLoader = DataLoader(trainData,
+                                 batch_size=BATCH_SIZE,
+                                 shuffle=True)
+    '''
+    valData = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val))
+    valDataLoader = DataLoader(valData,
+                               batch_size=BATCH_SIZE)
+
+    # calculate steps per epoch for training and validation set
+    trainSteps = len(trainDataLoader.dataset) // BATCH_SIZE
+    valSteps = len(valDataLoader.dataset) // BATCH_SIZE
+
+    return trainDataLoader, x_val, y_val
+    '''
+    print("[INFO] initializing the gal model...")
+    # set the device we will be using to train the model
+    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    model_name = namespace.model_type + '_noise_' + noise
+
+    model, lossFn = models.model_setup_DE(namespace.loss_type, DEVICE)
+
+    model_ensemble = train.train_DE(trainDataLoader,
+                                    x_val,
+                                    y_val,
+                                    namespace.init_lr,
+                                    DEVICE,
+                                    namespace.loss_type,
+                                    namespace.n_models,
+                                    model_name,
+                                    BETA=namespace.BETA,
+                                    EPOCHS=namespace.n_epochs,
+                                    path_to_model=namespace.path_to_model,
+                                    save_all_checkpoints=namespace.save_all_checkpoints,
+                                    save_final_checkpoint=namespace.save_final_checkpoint,
+                                    overwrite_final_checkpoint=namespace.overwrite_final_checkpoint,
+                                    plot=namespace.plot,
+                                    savefig=namespace.savefig,
+                                    verbose=namespace.verbose
+                                    )
+
+

From 0412dfcc15f5c7c7a409df5bd8a3d4a072ad7008 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Thu, 28 Mar 2024 11:05:13 -0600
Subject: [PATCH 03/30] option to run io as argparse

---
 src/scripts/io.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/scripts/io.py b/src/scripts/io.py
index 6aebc80..f6e7afa 100644
--- a/src/scripts/io.py
+++ b/src/scripts/io.py
@@ -343,4 +343,4 @@ def train_val_split(model_inputs,
     trainSteps = len(trainDataLoader.dataset) // BATCH_SIZE
     valSteps = len(valDataLoader.dataset) // BATCH_SIZE
     '''
-    return trainDataLoader, x_val, y_val
+    #return trainDataLoader, x_val, y_val

From 644f5850923b12832ec1b90acf43a4315127328c Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Mon, 1 Apr 2024 07:28:09 -0600
Subject: [PATCH 04/30] argparse running and saving

---
 src/scripts/DeepEnsemble.py | 31 ++++++++++++++++++-------------
 src/scripts/models.py       |  2 +-
 src/scripts/train.py        | 20 +++++++++++---------
 3 files changed, 30 insertions(+), 23 deletions(-)

diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py
index e25cf77..5bf0f47 100644
--- a/src/scripts/DeepEnsemble.py
+++ b/src/scripts/DeepEnsemble.py
@@ -61,14 +61,14 @@ def parse_args():
     )
     parser.add_argument(
         "--randomseed",
-        type=float,
+        type=int,
         required=False,
         default=42,
         help="Random seed used for shuffling the training and validation set",
     )
     parser.add_argument(
         "--batchsize",
-        type=float,
+        type=int,
         required=False,
         default=100,
         help="Size of batched used in the traindataloader",
     )
@@ -76,7 +76,7 @@ def parse_args():
     # now args for model
     parser.add_argument(
         "n_models",
-        type=float,
+        type=int,
         default=100,
         help="Number of MVEs in the ensemble",
     )
@@ -91,7 +91,7 @@ def parse_args():
         "--loss_type",
         type=str,
         required=False,
-        default="bnn_loss",
+        default="bnll_loss",
         help="Loss types for MVE, options are no_var_loss, var_loss, and bnn_loss",
     )
@@ -101,6 +101,11 @@ def parse_args():
         default=0.5,
         help="If loss_type is bnn_loss, specify a beta as a float or there are string options: linear_decrease, step_decrease_to_0.5, and step_decrease_to_1.0",
     )
+    parser.add_argument(
+        "wd",
+        type=str,
+        help="Top level of directory",
+    )
     parser.add_argument(
         "--model_type",
         type=str,
@@ -110,7 +115,7 @@ def parse_args():
     )
     parser.add_argument(
         "--n_epochs",
-        type=float,
+        type=int,
         required=False,
         default=100,
         help="number of epochs for each MVE",
     )
@@ -147,7 +152,7 @@ def parse_args():
         "--plot",
         type=bool,
         required=False,
-        default=True,
+        default=False,
         help="option to plot in notebook",
     )
@@ -177,7 +182,8 @@ def parse_args():
     BATCH_SIZE = namespace.batchsize
     sigma = io.DataPreparation.get_sigma(noise)
     loader = io.DataLoader()
-    data = loader.load_data_h5('linear_sigma_'+str(sigma)+'_size_'+str(size_df))
+    data = loader.load_data_h5('linear_sigma_'+str(sigma)+'_size_'+str(size_df),
+                               path='/Users/rnevin/Documents/DeepUQ/data/')
     len_df = len(data['params'][:, 0].numpy())
     len_x = len(data['inputs'].numpy())
     ms_array = np.repeat(data['params'][:, 0].numpy(), len_x)
@@ -190,7 +196,7 @@ def parse_args():
                                                                norm)
     x_train, x_val, y_train, y_val = io.DataPreparation.train_val_split(model_inputs,
                                                                         model_outputs,
-                                                                        test_size=val_prop,
+                                                                        val_proportion=val_prop,
                                                                         random_state=rs)
     trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))
     trainDataLoader = DataLoader(trainData,
@@ -212,9 +218,7 @@ def parse_args():
     DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
     model_name = namespace.model_type + '_noise_' + noise
-
     model, lossFn = models.model_setup_DE(namespace.loss_type, DEVICE)
-
     model_ensemble = train.train_DE(trainDataLoader,
                                     x_val,
                                     y_val,
@@ -222,13 +226,14 @@ def parse_args():
                                     DEVICE,
                                     namespace.loss_type,
                                     namespace.n_models,
+                                    namespace.wd,
                                     model_name,
                                     BETA=namespace.BETA,
                                     EPOCHS=namespace.n_epochs,
-                                    path_to_model=namespace.path_to_model,
+                                    path_to_model=namespace.path_to_models,
                                     save_all_checkpoints=namespace.save_all_checkpoints,
-                                    save_final_checkpoint=namespace.save_final_checkpoint,
-                                    overwrite_final_checkpoint=namespace.overwrite_final_checkpoint,
+                                    save_final_checkpoint=namespace.save_final_checkpoints,
+                                    overwrite_final_checkpoint=namespace.overwrite_final_checkpoints,
                                     plot=namespace.plot,
                                     savefig=namespace.savefig,
                                     verbose=namespace.verbose

diff --git a/src/scripts/models.py b/src/scripts/models.py
index b65f640..80f25a0 100644
--- a/src/scripts/models.py
+++ b/src/scripts/models.py
@@ -58,7 +58,7 @@ def forward(self, x):
         return torch.stack((mu, var), dim=1)
 
 
-def model_setup_DE(loss_type, DEVICE):  # , INIT_LR=0.001):
+def model_setup_DE(loss_type, DEVICE):
     # initialize the model from scratch
     if loss_type == "no_var_loss":
         # model = de_no_var().to(DEVICE)

diff --git a/src/scripts/train.py b/src/scripts/train.py
index fb27131..85986c6 100644
--- a/src/scripts/train.py
+++ b/src/scripts/train.py
@@ -188,6 +188,7 @@ def train_DE(
     DEVICE,
     loss_type,
     n_models,
+    wd,
     model_name="DE",
     BETA=None,
     EPOCHS=100,
@@ -493,7 +494,8 @@ def train_DE(
             # ax1.errorbar(200, 600, yerr=5,
             #              color='red', capsize=2)
             plt.savefig(
-                "../images/animations/"
+                str(wd)
+                + "images/animations/"
                 + str(model_name)
                 + "_nmodel_"
                 + str(m)
@@ -522,8 +524,8 @@ def train_DE(
                     "x_val": x_val,
                     "y_val": y_val,
                 },
-                path_to_model
-                + "/"
+                str(wd)
+                + "models/"
                 + str(model_name)
                 + "_beta_"
                 + str(BETA)
@@ -547,8 +549,8 @@ def train_DE(
                     "x_val": x_val,
                     "y_val": y_val,
                 },
-                path_to_model
-                + "/"
+                str(wd)
+                + "models/"
                 + str(model_name)
                 + "_nmodel_"
                 + str(m)
@@ -572,8 +574,8 @@ def train_DE(
                     "x_val": x_val,
                     "y_val": y_val,
                 },
-                path_to_model
-                + "/"
+                str(wd)
+                + "models/"
                 + str(model_name)
                 + "_beta_"
                 + str(BETA)
@@ -597,8 +599,8 @@ def train_DE(
                     "x_val": x_val,
                     "y_val": y_val,
                 },
-                path_to_model
-                + "/"
+                str(wd)
+                + "models/"
                 + str(model_name)
                 + "_nmodel_"
                 + str(m)

From 8f89deb5a0cfaee96adae99b8b9e827d39bcca2a Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Mon, 1 Apr 2024 07:37:26 -0600
Subject: [PATCH 05/30] adding instructions for how to run deepensemble.py
 module

---
 README.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/README.md b/README.md
index bc6b64f..0fa1082 100644
--- a/README.md
+++ b/README.md
@@ -18,6 +18,13 @@ Getting a little more specific:
 
 ![python module overview](images/workflow_deepUQ.png)
 
+These modules can be accessed via the ipython example notebooks or via the model modules (ie `DeepEnsemble.py`). For example, to ingest data and train a Deep Ensemble:
+> cd src/scripts/
+> python DeepEnsemble.py low 10 /Users/rnevin/Documents/DeepUQ/ --save_final_checkpoints=True --savefig=True --n_epochs=10
+
+This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and final checkpoints to the specified directory. For more information on the arguments:
+> python DeepEnsemble.py --help
+
 ## Installation
 
 ### Clone this repo

From 11d5e62a0f41daaf5eb31c5747837cd9a0406df5 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Mon, 1 Apr 2024 07:38:41 -0600
Subject: [PATCH 06/30] adding a space after cd

---
 README.md | 1 +
 1 file changed, 1 insertion(+)

diff --git a/README.md b/README.md
index 0fa1082..c6ea96f 100644
--- a/README.md
+++ b/README.md
@@ -20,6 +20,7 @@ Getting a little more specific:
 
 These modules can be accessed via the ipython example notebooks or via the model modules (ie `DeepEnsemble.py`). For example, to ingest data and train a Deep Ensemble:
 > cd src/scripts/
+
 > python DeepEnsemble.py low 10 /Users/rnevin/Documents/DeepUQ/ --save_final_checkpoints=True --savefig=True --n_epochs=10
 
 This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and final checkpoints to the specified directory. For more information on the arguments:

From 9ae98ce2feb2a4b8c0e337f7101c7fd7f012cc51 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Mon, 1 Apr 2024 07:39:49 -0600
Subject: [PATCH 07/30] adding info on args

---
 README.md | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/README.md b/README.md
index c6ea96f..e5653f6 100644
--- a/README.md
+++ b/README.md
@@ -23,7 +23,9 @@ These modules can be accessed via the ipython example notebooks or via the model modules (ie `DeepEnsemble.py`). For example, to ingest data and train a Deep Ensemble:
 
 > python DeepEnsemble.py low 10 /Users/rnevin/Documents/DeepUQ/ --save_final_checkpoints=True --savefig=True --n_epochs=10
 
-This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and final checkpoints to the specified directory. For more information on the arguments:
+This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and final checkpoints to the specified directory. Required arguments are the noise setting (low/medium/high), the number of ensembles, and the working directory.
+
+For more information on the arguments:
 > python DeepEnsemble.py --help

From be1c1dfa6bd53a58dcdd3097e124bb956c6a2f70 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Mon, 1 Apr 2024 09:33:38 -0600
Subject: [PATCH 08/30] new testing module for deepensembles argparse

---
 test/test_DeepEnsemble.py | 236 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 236 insertions(+)
 create mode 100644 test/test_DeepEnsemble.py

diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py
new file mode 100644
index 0000000..148309b
--- /dev/null
+++ b/test/test_DeepEnsemble.py
@@ -0,0 +1,236 @@
+import sys
+import pytest
+import torch
+import numpy as np
+import sbi
+import os
+import subprocess
+import tempfile
+import shutil
+import unittest
+
+# flake8: noqa
+sys.path.append("..")
+#print(sys.path)
+#from scripts.evaluate import Diagnose_static, Diagnose_generative
+#from scripts.io import ModelLoader
+from scripts import evaluate, models, DeepEnsemble
+
+
+@pytest.fixture
+def temp_directory():
+    # Setup: Create a temporary directory with one folder level
+    temp_dir = tempfile.mkdtemp()
+
+    # Create subdirectories within the temporary directory
+    models_dir = os.path.join(temp_dir, "models")
+    os.makedirs(models_dir)
+
+    animations_dir = os.path.join(temp_dir, "images", "animations")
+    os.makedirs(animations_dir)
+
+    yield temp_dir  # Provide the temporary directory path to the test function
+
+    # Teardown: Remove the temporary directory and its contents
+    '''
+    for dir_path in [models_dir, animations_dir, temp_dir]:
+        os.rmdir(dir_path)
+    # Teardown: Remove the temporary directory and its contents
+    '''
+    shutil.rmtree(temp_dir)
+
+
+'''
+@pytest.fixture
+def temp_directory(tmpdir):
+    # Setup: Create a temporary directory
+    #temp_dir = tmpdir.mkdir("temp_test_directory")
+
+    #temp_dir = tmpdir.join("temp_test_directory")
+    #os.mkdir(temp_dir + '/models/')
+    #os.mkdir(temp_dir + '/images/animations/')
+    temp_dir = tmpdir / "temp_test_directory"
+    temp_dir.mkdir()
+
+    yield temp_dir  # Provide the temporary directory to the test function
+    # Teardown: Remove the temporary directory and its contents
+    temp_dir.remove(rec=True)
+'''
+
+'''
+class TestMoveEmbargoArgs(unittest.TestCase):
+    def setUp(self):
+        """
+        Performs the setup necessary to run
+        all tests
+        """
+        temp_dir = tempfile.TemporaryDirectory()
+        temp_path = os.path.join(temp_dir.name, "temp_test/")
+        self.temp_dir = temp_dir
+        self.temp_path = temp_path
+
+    def tearDown(self):
+        """
Removes all test files created by tests + """ + shutil.rmtree(self.temp_dir.name, ignore_errors=True) + +''' + +def test_run_simple_ensemble(temp_directory): + noise_level = 'low' + n_models = '10' + #here = os.getcwd() + #wd = self.temp_path + #os.path.dirname(here) + str(temp_directory) + '/' + wd = str(temp_directory) + '/' + print('wd', wd) + + subprocess_args = [ + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + n_models, + wd, + "--n_epochs", + '2'] + # now run the subprocess + subprocess.run(subprocess_args, check=True) +''' +@pytest.mark.xfail(strict=True) +def test_missing_req_arg(): + noise_level = 'low' + n_models = 10 + subprocess_args = [ + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + n_models, + "--n_epochs", + '1'] + # now run the subprocess + subprocess.run(subprocess_args, check=True) +''' + +''' +def run_ensemble(noise_level, + n_models, + wd): + subprocess_args = [ + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + n_models, + wd, + + temp_to, + "LATISS", + "--embargohours", + str(embargo_hours), + "--datasettype", + *iterable_datasettype, + "--collections", + *iterable_collections, + "--nowtime", + now_time_embargo, + "--log", + log, + "--desturiprefix", + desturiprefix, + ] + # now run the subprocess + subprocess.run(subprocess_args, check=True) +''' + +""" +@pytest.fixture +def diagnose_static_instance(): + return Diagnose_static() + +@pytest.fixture +def diagnose_generative_instance(): + return Diagnose_generative() + + +@pytest.fixture +def posterior_generative_sbi_model(): + # create a temporary directory for the saved model + #dir = "savedmodels/sbi/" + #os.makedirs(dir) + + # now save the model + low_bounds = torch.tensor([0, -10]) + high_bounds = torch.tensor([10, 10]) + + prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds) + + posterior = sbi.inference.base.infer(simulator, prior, "SNPE", num_simulations=10000) + + # Provide the posterior to the tests + yield prior, posterior + + # Teardown: Remove the temporary directory and its contents + #shutil.rmtree(dataset_dir) + +@pytest.fixture +def setup_plot_dir(): + # create a temporary directory for the saved model + dir = "tests/plots/" + os.makedirs(dir) + yield dir + +def simulator(thetas): # , percent_errors): + # convert to numpy array (if tensor): + thetas = np.atleast_2d(thetas) + # Check if the input has the correct shape + if thetas.shape[1] != 2: + raise ValueError( + "Input tensor must have shape (n, 2) \ + where n is the number of parameter sets." 
+ ) + + # Unpack the parameters + if thetas.shape[0] == 1: + # If there's only one set of parameters, extract them directly + m, b = thetas[0, 0], thetas[0, 1] + else: + # If there are multiple sets of parameters, extract them for each row + m, b = thetas[:, 0], thetas[:, 1] + x = np.linspace(0, 100, 101) + rs = np.random.RandomState() # 2147483648)# + # I'm thinking sigma could actually be a function of x + # if we want to get fancy down the road + # Generate random noise (epsilon) based + # on a normal distribution with mean 0 and standard deviation sigma + sigma = 5 + ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) + + # Initialize an empty array to store the results for each set of parameters + y = np.zeros((len(x), thetas.shape[0])) + for i in range(thetas.shape[0]): + m, b = thetas[i, 0], thetas[i, 1] + y[:, i] = m * x + b + ε[:, i] + return torch.Tensor(y.T) + + +def test_generate_sbc_samples(diagnose_generative_instance, + posterior_generative_sbi_model): + # Mock data + #low_bounds = torch.tensor([0, -10]) + #high_bounds = torch.tensor([10, 10]) + + #prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) + prior, posterior = posterior_generative_sbi_model + #inference_instance # provide a mock posterior object + simulator_test = simulator # provide a mock simulator function + num_sbc_runs = 1000 + num_posterior_samples = 1000 + + # Generate SBC samples + thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples( + prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples + ) + + # Add assertions based on the expected behavior of the method +""" From fe827f790e34b1e3fcd1ca67f8f808e91273b146 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:02:09 -0600 Subject: [PATCH 09/30] removing test_example --- test/test_example.py | 23 ----------------------- 1 file changed, 23 deletions(-) delete mode 100644 test/test_example.py diff --git a/test/test_example.py b/test/test_example.py deleted file mode 100644 index e9dadb1..0000000 --- a/test/test_example.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Example of pytest functionality - -Goes over some basic assert examples -""" -import pytest - - -def test_example_assert_equal(): - assert 0 == 0 - - -def test_example_assert_no_equal(): - assert 0 != 1 - - -def test_example_assert_almost_equal(): - assert 1.0 == pytest.approx(1.01, .1) - - -""" -To run this suite of tests, run 'pytest' in the main directory -""" From b35635a95cbeae133241af0068591e4064e25e37 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:31:13 -0600 Subject: [PATCH 10/30] lint will run on push --- .github/workflows/lint.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index dd4e094..0b9b7ea 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,6 +1,6 @@ name: Lint it -on: [pull_request] +on: push jobs: lint: From 0a74d482a84ce6dfff34e303a3806e7ee3327979 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:37:02 -0600 Subject: [PATCH 11/30] changed all bools to store_true --- src/scripts/DeepEnsemble.py | 28 +++++++++++----------------- 1 file changed, 11 insertions(+), 17 deletions(-) diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py index 5bf0f47..0c92bce 100644 --- a/src/scripts/DeepEnsemble.py +++ b/src/scripts/DeepEnsemble.py @@ -129,43 +129,37 @@ def parse_args(): ) parser.add_argument( "--save_all_checkpoints", - type=bool, - required=False, + 
action="store_true", default=False, help="option to save all checkpoints", ) parser.add_argument( - "--save_final_checkpoints", - type=bool, - required=False, - default=False, + "--save_final_checkpoint", + action="store_true", # Set to True if argument is present + default=False, # Set default value to False if argument is not present help="option to save the final epoch checkpoint for each ensemble", ) parser.add_argument( - "--overwrite_final_checkpoints", - type=bool, - required=False, + "--overwrite_final_checkpoint", + action="store_true", default=False, help="option to overwite already saved checkpoints", ) parser.add_argument( "--plot", - type=bool, - required=False, + action="store_true", default=False, help="option to plot in notebook", ) parser.add_argument( "--savefig", - type=bool, - required=False, + action="store_true", default=True, help="option to save a figure of the true and predicted values", ) parser.add_argument( "--verbose", - type=bool, - required=False, + action="store_true", default=False, help="verbose option for train", ) @@ -232,8 +226,8 @@ def parse_args(): EPOCHS=namespace.n_epochs, path_to_model=namespace.path_to_models, save_all_checkpoints=namespace.save_all_checkpoints, - save_final_checkpoint=namespace.save_final_checkpoints, - overwrite_final_checkpoint=namespace.overwrite_final_checkpoints, + save_final_checkpoint=namespace.save_final_checkpoint, + overwrite_final_checkpoint=namespace.overwrite_final_checkpoint, plot=namespace.plot, savefig=namespace.savefig, verbose=namespace.verbose From 737d4c9a1017d87f6e4a97e62563991b9b8b27ad Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:38:54 -0600 Subject: [PATCH 12/30] changed default to False for all store_true args --- src/scripts/DeepEnsemble.py | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py index 0c92bce..97fad07 100644 --- a/src/scripts/DeepEnsemble.py +++ b/src/scripts/DeepEnsemble.py @@ -36,16 +36,6 @@ def parse_args(): default='low', help="low, medium, high or vhigh, used to look up associated sigma value", ) - ''' - parser.add_argument( - "size_df", - type=str, - nargs="?", - default="/repo/embargo", - help="Butler Repository path from which data is transferred. \ - Input str. 
Default = '/repo/embargo'", - ) - ''' parser.add_argument( "--normalize", required=False, @@ -104,7 +94,7 @@ def parse_args(): parser.add_argument( "wd", type=str, - help="Top level of directory", + help="Top level of directory, required arg", ) parser.add_argument( "--model_type", @@ -154,7 +144,7 @@ def parse_args(): parser.add_argument( "--savefig", action="store_true", - default=True, + default=False, help="option to save a figure of the true and predicted values", ) parser.add_argument( From 994387757bf9df09a8de2cc20db5c4cb03586b2a Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:39:26 -0600 Subject: [PATCH 13/30] black and all tests back --- test/test_DeepEnsemble.py | 316 +++++++++++++++----------------------- 1 file changed, 128 insertions(+), 188 deletions(-) diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index 148309b..7a300c4 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -11,9 +11,9 @@ # flake8: noqa sys.path.append("..") -#print(sys.path) -#from scripts.evaluate import Diagnose_static, Diagnose_generative -#from scripts.io import ModelLoader +# print(sys.path) +# from scripts.evaluate import Diagnose_static, Diagnose_generative +# from scripts.io import ModelLoader from scripts import evaluate, models, DeepEnsemble @@ -21,216 +21,156 @@ def temp_directory(): # Setup: Create a temporary directory with one folder level temp_dir = tempfile.mkdtemp() - + # Create subdirectories within the temporary directory models_dir = os.path.join(temp_dir, "models") os.makedirs(models_dir) - + animations_dir = os.path.join(temp_dir, "images", "animations") os.makedirs(animations_dir) - + yield temp_dir # Provide the temporary directory path to the test function - + # Teardown: Remove the temporary directory and its contents - ''' + """ for dir_path in [models_dir, animations_dir, temp_dir]: os.rmdir(dir_path) # Teardown: Remove the temporary directory and its contents - ''' + """ shutil.rmtree(temp_dir) -''' -@pytest.fixture -def temp_directory(tmpdir): - # Setup: Create a temporary directory - #temp_dir = tmpdir.mkdir("temp_test_directory") - - #temp_dir = tmpdir.join("temp_test_directory") - #os.mkdir(temp_dir + '/models/') - #os.mkdir(temp_dir + '/images/animations/') - temp_dir = tmpdir / "temp_test_directory" - temp_dir.mkdir() - - yield temp_dir # Provide the temporary directory to the test function - # Teardown: Remove the temporary directory and its contents - temp_dir.remove(rec=True) -''' - -''' -class TestMoveEmbargoArgs(unittest.TestCase): - def setUp(self): - """ - Performs the setup necessary to run - all tests - """ - temp_dir = tempfile.TemporaryDirectory() - temp_path = os.path.join(temp_dir.name, "temp_test/") - self.temp_dir = temp_dir - self.temp_path = temp_path - - - def tearDown(self): - """ - Removes all test files created by tests - """ - shutil.rmtree(self.temp_dir.name, ignore_errors=True) - -''' - -def test_run_simple_ensemble(temp_directory): - noise_level = 'low' - n_models = '10' - #here = os.getcwd() - #wd = self.temp_path - #os.path.dirname(here) + str(temp_directory) + '/' - wd = str(temp_directory) + '/' - print('wd', wd) - +@pytest.mark.xfail(strict=True) +def test_no_chkpt_saved_xfail(temp_directory): + noise_level = "low" + n_models = 10 + wd = str(temp_directory) + "/" + n_epochs = 2 subprocess_args = [ - "python", - "../src/scripts/DeepEnsemble.py", - noise_level, - n_models, - wd, - "--n_epochs", - '2'] + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + str(n_models), + wd, + 
"--n_epochs", + str(n_epochs), + ] # now run the subprocess subprocess.run(subprocess_args, check=True) -''' -@pytest.mark.xfail(strict=True) -def test_missing_req_arg(): - noise_level = 'low' + # check if the right number of checkpoints are saved + models_folder = os.path.join(temp_directory, "models") + # list all files in the "models" folder + files_in_models_folder = os.listdir(models_folder) + # assert that the number of files is equal to 10 + assert ( + len(files_in_models_folder) == n_models + ), "Expected 10 files in the 'models' folder" + + +def test_no_chkpt_saved(temp_directory): + noise_level = "low" n_models = 10 + wd = str(temp_directory) + "/" + n_epochs = 2 subprocess_args = [ - "python", - "../src/scripts/DeepEnsemble.py", - noise_level, - n_models, - "--n_epochs", - '1'] + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + str(n_models), + wd, + "--n_epochs", + str(n_epochs), + ] # now run the subprocess subprocess.run(subprocess_args, check=True) -''' + # check if the right number of checkpoints are saved + models_folder = os.path.join(temp_directory, "models") + # list all files in the "models" folder + files_in_models_folder = os.listdir(models_folder) + # assert that the number of files is equal to 10 + assert len(files_in_models_folder) == 0, "Expect 0 files in the 'models' folder" + -''' -def run_ensemble(noise_level, - n_models, - wd): +def test_chkpt_saved(temp_directory): + noise_level = "low" + n_models = 10 + wd = str(temp_directory) + "/" + n_epochs = 2 subprocess_args = [ - "python", - "../src/scripts/DeepEnsemble.py", - noise_level, - n_models, - wd, - - temp_to, - "LATISS", - "--embargohours", - str(embargo_hours), - "--datasettype", - *iterable_datasettype, - "--collections", - *iterable_collections, - "--nowtime", - now_time_embargo, - "--log", - log, - "--desturiprefix", - desturiprefix, - ] + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + str(n_models), + wd, + "--n_epochs", + str(n_epochs), + "--save_final_checkpoints", + "True", + ] # now run the subprocess subprocess.run(subprocess_args, check=True) -''' - -""" -@pytest.fixture -def diagnose_static_instance(): - return Diagnose_static() + # check if the right number of checkpoints are saved + models_folder = os.path.join(temp_directory, "models") + # list all files in the "models" folder + files_in_models_folder = os.listdir(models_folder) + # assert that the number of files is equal to 10 + assert ( + len(files_in_models_folder) == n_models + ), "Expected 10 files in the 'models' folder" + + # check if the right number of images were saved + animations_folder = os.path.join(temp_directory, "images/animations") + files_in_animations_folder = os.listdir(animations_folder) + # assert that the number of files is equal to 10 + assert ( + len(files_in_animations_folder) == n_models + ), "Expected 10 files in the 'images/animations' folder" + + # also check that all files in here have the same name elements + expected_substring = "epoch_" + str(n_epochs - 1) + for file_name in files_in_models_folder: + assert ( + expected_substring in file_name + ), f"File '{file_name}' does not contain the expected substring" + + # also check that all files in here have the same name elements + for file_name in files_in_animations_folder: + assert ( + expected_substring in file_name + ), f"File '{file_name}' does not contain the expected substring" -@pytest.fixture -def diagnose_generative_instance(): - return Diagnose_generative() +def test_run_simple_ensemble(temp_directory): + noise_level = 
"low" + n_models = "10" + # here = os.getcwd() + # wd = self.temp_path + # os.path.dirname(here) + str(temp_directory) + '/' + wd = str(temp_directory) + "/" + subprocess_args = [ + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + n_models, + wd, + "--n_epochs", + "2", + ] + # now run the subprocess + subprocess.run(subprocess_args, check=True) -@pytest.fixture -def posterior_generative_sbi_model(): - # create a temporary directory for the saved model - #dir = "savedmodels/sbi/" - #os.makedirs(dir) - - # now save the model - low_bounds = torch.tensor([0, -10]) - high_bounds = torch.tensor([10, 10]) - - prior = sbi.utils.BoxUniform(low = low_bounds, high = high_bounds) - - posterior = sbi.inference.base.infer(simulator, prior, "SNPE", num_simulations=10000) - - # Provide the posterior to the tests - yield prior, posterior - - # Teardown: Remove the temporary directory and its contents - #shutil.rmtree(dataset_dir) -@pytest.fixture -def setup_plot_dir(): - # create a temporary directory for the saved model - dir = "tests/plots/" - os.makedirs(dir) - yield dir - -def simulator(thetas): # , percent_errors): - # convert to numpy array (if tensor): - thetas = np.atleast_2d(thetas) - # Check if the input has the correct shape - if thetas.shape[1] != 2: - raise ValueError( - "Input tensor must have shape (n, 2) \ - where n is the number of parameter sets." - ) - - # Unpack the parameters - if thetas.shape[0] == 1: - # If there's only one set of parameters, extract them directly - m, b = thetas[0, 0], thetas[0, 1] - else: - # If there are multiple sets of parameters, extract them for each row - m, b = thetas[:, 0], thetas[:, 1] - x = np.linspace(0, 100, 101) - rs = np.random.RandomState() # 2147483648)# - # I'm thinking sigma could actually be a function of x - # if we want to get fancy down the road - # Generate random noise (epsilon) based - # on a normal distribution with mean 0 and standard deviation sigma - sigma = 5 - ε = rs.normal(loc=0, scale=sigma, size=(len(x), thetas.shape[0])) - - # Initialize an empty array to store the results for each set of parameters - y = np.zeros((len(x), thetas.shape[0])) - for i in range(thetas.shape[0]): - m, b = thetas[i, 0], thetas[i, 1] - y[:, i] = m * x + b + ε[:, i] - return torch.Tensor(y.T) - - -def test_generate_sbc_samples(diagnose_generative_instance, - posterior_generative_sbi_model): - # Mock data - #low_bounds = torch.tensor([0, -10]) - #high_bounds = torch.tensor([10, 10]) - - #prior = sbi.utils.BoxUniform(low=low_bounds, high=high_bounds) - prior, posterior = posterior_generative_sbi_model - #inference_instance # provide a mock posterior object - simulator_test = simulator # provide a mock simulator function - num_sbc_runs = 1000 - num_posterior_samples = 1000 - - # Generate SBC samples - thetas, ys, ranks, dap_samples = diagnose_generative_instance.generate_sbc_samples( - prior, posterior, simulator_test, num_sbc_runs, num_posterior_samples - ) - - # Add assertions based on the expected behavior of the method -""" +@pytest.mark.xfail(strict=True) +def test_missing_req_arg(temp_directory): + noise_level = "low" + n_models = "10" + subprocess_args = [ + "python", + "../src/scripts/DeepEnsemble.py", + noise_level, + n_models, + "--n_epochs", + "2", + ] + # now run the subprocess + subprocess.run(subprocess_args, check=True) From 0ce803c35fa5e8ba70d6e9e11f31dd4df52cd578 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:50:35 -0600 Subject: [PATCH 14/30] modifying savefig to default to false --- src/scripts/train.py | 
3 +++ 1 file changed, 3 insertions(+) diff --git a/src/scripts/train.py b/src/scripts/train.py index 85986c6..0aff0c4 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -214,6 +214,9 @@ def train_DE( model_ensemble = [] + print('this is the value of save_final_checkpoint', + save_final_checkpoint) + for m in range(n_models): print("model", m) if not save_all_checkpoints and save_final_checkpoint: From 2cf50a17a742e7983b86bba45cd4089eb65030cf Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:52:05 -0600 Subject: [PATCH 15/30] changing readme to proper store_true flags --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e5653f6..51128e7 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Getting a little more specific: These modules can be accessed via the ipython example notebooks or via the model modules (ie `DeepEnsemble.py`). For example, to ingest data and train a Deep Ensemble: > cd src/scripts/ -> python DeepEnsemble.py low 10 /Users/rnevin/Documents/DeepUQ/ --save_final_checkpoints=True --savefig=True --n_epochs=10 +> python DeepEnsemble.py low 10 /Users/rnevin/Documents/DeepUQ/ --save_final_checkpoint --savefig --n_epochs=10 This command will train a 10 network, 10 epoch ensemble on the low noise data and will save figures and final checkpoints to the specified directory. Required arguments are the noise setting (low/medium/high), the number of ensembles, and the working directory. From 5b9507c6f053b1c465a2b56e8dd48566f559ec6e Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:58:43 -0600 Subject: [PATCH 16/30] flake8 on io.py --- src/scripts/DeepEnsemble.py | 109 +++++++++++++++++------------------- src/scripts/io.py | 80 +++++++++++--------------- 2 files changed, 85 insertions(+), 104 deletions(-) diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py index 97fad07..a800a67 100644 --- a/src/scripts/DeepEnsemble.py +++ b/src/scripts/DeepEnsemble.py @@ -1,28 +1,28 @@ import argparse -import logging import numpy as np import torch from torch.utils.data import DataLoader, TensorDataset -from scripts import train, models, analysis, io +from scripts import train, models, io def beta_type(value): if isinstance(value, float): return value - elif value.lower() == 'linear_decrease': + elif value.lower() == "linear_decrease": return value - elif value.lower() == 'step_decrease_to_0.5': + elif value.lower() == "step_decrease_to_0.5": return value - elif value.lower() == 'step_decrease_to_1.0': + elif value.lower() == "step_decrease_to_1.0": return value else: - raise argparse.ArgumentTypeError("BETA must be a float or one of 'linear_decrease', 'step_decrease_to_0.5', 'step_decrease_to_1.0'") + raise argparse.ArgumentTypeError( + "BETA must be a float or one of 'linear_decrease', \ + 'step_decrease_to_0.5', 'step_decrease_to_1.0'" + ) def parse_args(): - parser = argparse.ArgumentParser( - description="data handling module" - ) + parser = argparse.ArgumentParser(description="data handling module") parser.add_argument( "--size_df", type=float, @@ -33,8 +33,9 @@ def parse_args(): parser.add_argument( "noise_level", type=str, - default='low', - help="low, medium, high or vhigh, used to look up associated sigma value", + default="low", + help="low, medium, high or vhigh, \ + used to look up associated sigma value", ) parser.add_argument( "--normalize", @@ -82,14 +83,17 @@ def parse_args(): type=str, required=False, default="bnll_loss", - help="Loss types for MVE, options are 
no_var_loss, var_loss, and bnn_loss", + help="Loss types for MVE, options are no_var_loss, var_loss, \ + and bnn_loss", ) parser.add_argument( "--BETA", type=beta_type, required=False, default=0.5, - help="If loss_type is bnn_loss, specify a beta as a float or there are string options: linear_decrease, step_decrease_to_0.5, and step_decrease_to_1.0", + help="If loss_type is bnn_loss, specify a beta as a float or \ + there are string options: linear_decrease, \ + step_decrease_to_0.5, and step_decrease_to_1.0", ) parser.add_argument( "wd", @@ -166,61 +170,50 @@ def parse_args(): BATCH_SIZE = namespace.batchsize sigma = io.DataPreparation.get_sigma(noise) loader = io.DataLoader() - data = loader.load_data_h5('linear_sigma_'+str(sigma)+'_size_'+str(size_df), - path='/Users/rnevin/Documents/DeepUQ/data/') - len_df = len(data['params'][:, 0].numpy()) - len_x = len(data['inputs'].numpy()) - ms_array = np.repeat(data['params'][:, 0].numpy(), len_x) - bs_array = np.repeat(data['params'][:, 1].numpy(), len_x) - xs_array = np.tile(data['inputs'].numpy(), len_df) - ys_array = np.reshape(data['output'].numpy(), (len_df * len_x)) + data = loader.load_data_h5( + "linear_sigma_" + str(sigma) + "_size_" + str(size_df), + path="/Users/rnevin/Documents/DeepUQ/data/", + ) + len_df = len(data["params"][:, 0].numpy()) + len_x = len(data["inputs"].numpy()) + ms_array = np.repeat(data["params"][:, 0].numpy(), len_x) + bs_array = np.repeat(data["params"][:, 1].numpy(), len_x) + xs_array = np.tile(data["inputs"].numpy(), len_df) + ys_array = np.reshape(data["output"].numpy(), (len_df * len_x)) inputs = np.array([xs_array, ms_array, bs_array]).T model_inputs, model_outputs = io.DataPreparation.normalize(inputs, ys_array, norm) - x_train, x_val, y_train, y_val = io.DataPreparation.train_val_split(model_inputs, - model_outputs, - val_proportion=val_prop, - random_state=rs) + x_train, x_val, y_train, y_val = io.DataPreparation.train_val_split( + model_inputs, model_outputs, val_proportion=val_prop, random_state=rs + ) trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train)) trainDataLoader = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True) - ''' - valData = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val)) - valDataLoader = DataLoader(valData, - batch_size=BATCH_SIZE) - - # calculate steps per epoch for training and validation set - trainSteps = len(trainDataLoader.dataset) // BATCH_SIZE - valSteps = len(valDataLoader.dataset) // BATCH_SIZE - - return trainDataLoader, x_val, y_val - ''' print("[INFO] initializing the gal model...") # set the device we will be using to train the model DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") - model_name = namespace.model_type + '_noise_' + noise + model_name = namespace.model_type + "_noise_" + noise model, lossFn = models.model_setup_DE(namespace.loss_type, DEVICE) - model_ensemble = train.train_DE(trainDataLoader, - x_val, - y_val, - namespace.init_lr, - DEVICE, - namespace.loss_type, - namespace.n_models, - namespace.wd, - model_name, - BETA=namespace.BETA, - EPOCHS=namespace.n_epochs, - path_to_model=namespace.path_to_models, - save_all_checkpoints=namespace.save_all_checkpoints, - save_final_checkpoint=namespace.save_final_checkpoint, - overwrite_final_checkpoint=namespace.overwrite_final_checkpoint, - plot=namespace.plot, - savefig=namespace.savefig, - verbose=namespace.verbose - ) - - + model_ensemble = train.train_DE( + trainDataLoader, + x_val, + y_val, + namespace.init_lr, + DEVICE, + namespace.loss_type, + 
namespace.n_models, + namespace.wd, + model_name, + BETA=namespace.BETA, + EPOCHS=namespace.n_epochs, + path_to_model=namespace.path_to_models, + save_all_checkpoints=namespace.save_all_checkpoints, + save_final_checkpoint=namespace.save_final_checkpoint, + overwrite_final_checkpoint=namespace.overwrite_final_checkpoint, + plot=namespace.plot, + savefig=namespace.savefig, + verbose=namespace.verbose, + ) diff --git a/src/scripts/io.py b/src/scripts/io.py index f6e7afa..5f38eda 100644 --- a/src/scripts/io.py +++ b/src/scripts/io.py @@ -5,15 +5,13 @@ from sklearn.model_selection import train_test_split import pickle from torch.distributions import Uniform -from torch.utils.data import DataLoader, TensorDataset +from torch.utils.data import TensorDataset import torch import h5py def parse_args(): - parser = argparse.ArgumentParser( - description="data handling module" - ) + parser = argparse.ArgumentParser(description="data handling module") parser.add_argument( "size_df", type=float, @@ -25,8 +23,9 @@ def parse_args(): "noise_level", type=str, required=False, - default='low', - help="low, medium, high or vhigh, used to look up associated sigma value", + default="low", + help="low, medium, high or vhigh, \ + used to look up associated sigma value", ) parser.add_argument( "size_df", @@ -267,19 +266,17 @@ def get_data(self): return self.data def get_sigma(noise): - if noise == 'low': + if noise == "low": sigma = 1 - if noise == 'medium': + if noise == "medium": sigma = 5 - if noise == 'high': + if noise == "high": sigma = 10 - if noise == 'vhigh': + if noise == "vhigh": sigma = 100 return sigma - - def normalize(inputs, - ys_array, - norm=False): + + def normalize(inputs, ys_array, norm=False): if norm: # normalize everything before it goes into a network inputmin = np.min(inputs, axis=0) @@ -292,15 +289,16 @@ def normalize(inputs, model_inputs = inputs model_outputs = ys_array return model_inputs, model_outputs - - def train_val_split(model_inputs, - model_outputs, - val_proportion=0.1, - random_state=42): - x_train, x_val, y_train, y_val = train_test_split(model_inputs, - model_outputs, - test_size=val_proportion, - random_state=random_state) + + def train_val_split( + model_inputs, model_outputs, val_proportion=0.1, random_state=42 + ): + x_train, x_val, y_train, y_val = train_test_split( + model_inputs, + model_outputs, + test_size=val_proportion, + random_state=random_state, + ) return x_train, x_val, y_train, y_val @@ -315,32 +313,22 @@ def train_val_split(model_inputs, BATCH_SIZE = namespace.batchsize sigma = DataPreparation.get_sigma(noise) loader = DataLoader() - data = loader.load_data_h5('linear_sigma_'+str(sigma)+'_size_'+str(size_df)) - len_df = len(data['params'][:, 0].numpy()) - len_x = len(data['inputs'].numpy()) - ms_array = np.repeat(data['params'][:, 0].numpy(), len_x) - bs_array = np.repeat(data['params'][:, 1].numpy(), len_x) - xs_array = np.tile(data['inputs'].numpy(), len_df) - ys_array = np.reshape(data['output'].numpy(), (len_df * len_x)) + data = loader.load_data_h5("linear_sigma_" + str(sigma) + + "_size_" + str(size_df)) + len_df = len(data["params"][:, 0].numpy()) + len_x = len(data["inputs"].numpy()) + ms_array = np.repeat(data["params"][:, 0].numpy(), len_x) + bs_array = np.repeat(data["params"][:, 1].numpy(), len_x) + xs_array = np.tile(data["inputs"].numpy(), len_df) + ys_array = np.reshape(data["output"].numpy(), (len_df * len_x)) inputs = np.array([xs_array, ms_array, bs_array]).T model_inputs, model_outputs = DataPreparation.normalize(inputs, - ys_array, - 
norm) - x_train, x_val, y_train, y_val = DataPreparation.train_val_split(model_inputs, - model_outputs, - test_size=val_prop, - random_state=rs) + ys_array, + norm) + x_train, x_val, y_train, y_val = DataPreparation.train_val_split( + model_inputs, model_outputs, test_size=val_prop, random_state=rs + ) trainData = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train)) trainDataLoader = DataLoader(trainData, batch_size=BATCH_SIZE, shuffle=True) - ''' - valData = TensorDataset(torch.Tensor(x_val), torch.Tensor(y_val)) - valDataLoader = DataLoader(valData, - batch_size=BATCH_SIZE) - - # calculate steps per epoch for training and validation set - trainSteps = len(trainDataLoader.dataset) // BATCH_SIZE - valSteps = len(valDataLoader.dataset) // BATCH_SIZE - ''' - #return trainDataLoader, x_val, y_val From 48480d3834f7d65984c43801cc11de86ad77e836 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 10:58:51 -0600 Subject: [PATCH 17/30] all tests passing locally --- test/test_DeepEnsemble.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index 7a300c4..f29698c 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -104,9 +104,9 @@ def test_chkpt_saved(temp_directory): wd, "--n_epochs", str(n_epochs), - "--save_final_checkpoints", - "True", - ] + "--save_final_checkpoint", + "--savefig" + ] # now run the subprocess subprocess.run(subprocess_args, check=True) # check if the right number of checkpoints are saved From a8373baafd0badbaacf81714b6b88bf614ea225e Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 11:05:52 -0600 Subject: [PATCH 18/30] chaging from poetry to just pytest --- .github/workflows/test.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 99bc88f..7337c7e 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,4 +29,5 @@ jobs: - name: Test with pytest run: | - python3 -m poetry run pytest --cov + cd ./tests + pytest --cov \ No newline at end of file From 553a4463811f17b1e8226a08d03470a23c7061fe Mon Sep 17 00:00:00 2001 From: beckynevin Date: Mon, 1 Apr 2024 11:56:45 -0600 Subject: [PATCH 19/30] adding pwd --- .github/workflows/test.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 7337c7e..d933917 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -29,5 +29,6 @@ jobs: - name: Test with pytest run: | + pwd cd ./tests pytest --cov \ No newline at end of file From 072dba9458e420b502d70e1162641b836424cd41 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Wed, 3 Apr 2024 11:45:58 -0600 Subject: [PATCH 20/30] changed from tests to test/ folder --- .github/workflows/test.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index d933917..ae630f5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -30,5 +30,5 @@ jobs: - name: Test with pytest run: | pwd - cd ./tests + cd ./test pytest --cov \ No newline at end of file From 45531206b011e55a433f5a27f7f0fcdd3f06a187 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Wed, 3 Apr 2024 11:59:52 -0600 Subject: [PATCH 21/30] installing pytest --- .github/workflows/test.yaml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index ae630f5..b5de3b6 100644 --- 
From 45531206b011e55a433f5a27f7f0fcdd3f06a187 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Wed, 3 Apr 2024 11:59:52 -0600
Subject: [PATCH 21/30] installing pytest

---
 .github/workflows/test.yaml | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index ae630f5..b5de3b6 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -26,6 +26,9 @@ jobs:
     - name: Install dependencies
       shell: bash
       run: python -m poetry install
+
+    - name: Install pytest
+      run: python -m pip install pytest
 
     - name: Test with pytest
       run: |

From 9ac424078cc08ff5d17f4459f953c81cd81c4632 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Wed, 3 Apr 2024 12:09:23 -0600
Subject: [PATCH 22/30] removing cov

---
 .github/workflows/test.yaml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index b5de3b6..dabdde3 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -34,4 +34,4 @@ jobs:
       run: |
         pwd
         cd ./test
-        pytest --cov
\ No newline at end of file
+        pytest test_DeepEnsemble.py
\ No newline at end of file

From c78fa17902957b2c4f1c1ac08c8b8786ac700a03 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Tue, 9 Apr 2024 14:25:12 -0600
Subject: [PATCH 23/30] updating to be the same as deepdiagnostics

---
 .github/workflows/test.yaml | 23 ++++++++++++++++-------
 1 file changed, 16 insertions(+), 7 deletions(-)

diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index dabdde3..315e748 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -8,6 +8,15 @@ jobs:
     runs-on: ubuntu-latest
 
     steps:
+    - name: Cache Poetry dependencies
+      uses: actions/cache@v2
+      with:
+        path: |
+          ~/.cache
+          ~/.local/share/virtualenvs
+        key: ${{ runner.os }}-poetry-${{ hashFiles('**/pyproject.toml') }}
+        restore-keys: |
+          ${{ runner.os }}-poetry-
     - uses: actions/checkout@v2
     - name: Set up Python 3.10
      uses: actions/setup-python@v2
@@ -26,12 +35,12 @@ jobs:
     - name: Install dependencies
       shell: bash
       run: python -m poetry install
-
-    - name: Install pytest
-      run: python -m pip install pytest
+    - name: Create Environment File
+      run: echo "PYTHONPATH=$(pwd):$(pwd)/src" >> ${{ runner.workspace }}/.env
+
 
     - name: Test with pytest
-      run: |
-        pwd
-        cd ./test
-        pytest test_DeepEnsemble.py
+      run: python -m poetry run pytest --cov
+      env:
+        PYTHONPATH: ${{ env.PYTHONPATH }}
+        ENV_FILE: ${{ runner.workspace }}/.env
\ No newline at end of file

From 35fb574132493d6d30c46973dd18ecc36392da81 Mon Sep 17 00:00:00 2001
From: beckynevin
Date: Tue, 9 Apr 2024 14:51:51 -0600
Subject: [PATCH 24/30] pyproject now has a dev part

---
 poetry.lock    | 1062 +++++------------------------------------------
 pyproject.toml |   16 +-
 2 files changed, 112 insertions(+), 966 deletions(-)

diff --git a/poetry.lock b/poetry.lock
index 090224f..1230fc8 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,16 +1,5 @@
 # This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
 
-[[package]]
-name = "absl-py"
-version = "2.1.0"
-description = "Abseil Python Common Libraries, see https://github.com/abseil/abseil-py."
-optional = false -python-versions = ">=3.7" -files = [ - {file = "absl-py-2.1.0.tar.gz", hash = "sha256:7820790efbb316739cde8b4e19357243fc3608a152024288513dd968d7d959ff"}, - {file = "absl_py-2.1.0-py3-none-any.whl", hash = "sha256:526a04eadab8b4ee719ce68f204172ead1027549089702d99b9059f129ff1308"}, -] - [[package]] name = "anyio" version = "4.0.0" @@ -119,83 +108,6 @@ types-python-dateutil = ">=2.8.10" doc = ["doc8", "sphinx (>=7.0.0)", "sphinx-autobuild", "sphinx-autodoc-typehints", "sphinx_rtd_theme (>=1.3.0)"] test = ["dateparser (==1.*)", "pre-commit", "pytest", "pytest-cov", "pytest-mock", "pytz (==2021.1)", "simplejson (==3.*)"] -[[package]] -name = "arviz" -version = "0.15.1" -description = "Exploratory analysis of Bayesian models" -optional = false -python-versions = ">=3.8" -files = [ - {file = "arviz-0.15.1-py3-none-any.whl", hash = "sha256:120695738fb81cc39e8da98b8b751f8f08c618267efda2a6dcb3f1511b599311"}, - {file = "arviz-0.15.1.tar.gz", hash = "sha256:981cce0282bdf6f3b379255b95a440979f9a0ef0ae9dd88a54f763cf5b31484c"}, -] - -[package.dependencies] -h5netcdf = ">=1.0.2" -matplotlib = ">=3.2" -numpy = ">=1.20.0" -packaging = "*" -pandas = ">=1.3.0" -scipy = ">=1.8.0" -setuptools = ">=60.0.0" -typing-extensions = ">=4.1.0" -xarray = ">=0.21.0" -xarray-einstats = ">=0.3" - -[package.extras] -all = ["bokeh (>=1.4.0,<3.0)", "contourpy", "dask[distributed]", "netcdf4", "numba", "ujson", "zarr (>=2.5.0)"] - -[[package]] -name = "astropy" -version = "5.3.4" -description = "Astronomy and astrophysics core library" -optional = false -python-versions = ">=3.9" -files = [ - {file = "astropy-5.3.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6c63abc95d094cd3062e32c1ebf80c07502e4f3094b1e276458db5ce6b6a2"}, - {file = "astropy-5.3.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:e85871ec762fc7eab2f7e716c97dad1b3c546bb75941ea7fae6c8eadd51f0bf8"}, - {file = "astropy-5.3.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e82fdad3417b70af381945aa42fdae0f11bc9aaf94b95027b1e24379bf847d6"}, - {file = "astropy-5.3.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bbce56f46ec1051fd67a5e2244e5f2e08599a176fe524c0bee2294c62be317b3"}, - {file = "astropy-5.3.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a489c2322136b76a43208e3e9b5a7947a7fd624a10e49d2909b94f12b624da06"}, - {file = "astropy-5.3.4-cp310-cp310-win32.whl", hash = "sha256:c713695e39f5a874705bc3bd262c5d218890e3e7c43f0b6c0b5e7d46bdff527c"}, - {file = "astropy-5.3.4-cp310-cp310-win_amd64.whl", hash = "sha256:2576579befb0674cdfd18f5cc138c919a109c6886a25aa3d8ed8ab4e4607c581"}, - {file = "astropy-5.3.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4ce096dde6b86a87aa84aec4198732ec379fbb7649af66a96f85b96d17214c2a"}, - {file = "astropy-5.3.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:830fb4b19c36bf8092fdd74ecf9df5b78c6435bf571c5e09b7f644875148a058"}, - {file = "astropy-5.3.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9a707c534408d26d90014a1938af883f6cbf43a3dd78df8bb9a191d275c09f8d"}, - {file = "astropy-5.3.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0bb2b9b93bc879bcd032931e7fc07c3a3de6f9546fed17f0f12974e0ffc83e0"}, - {file = "astropy-5.3.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:1fa4437fe8d1e103f14cb1cb4e8449c93ae4190b5e9fd97e9c61a5155de9af0d"}, - {file = "astropy-5.3.4-cp311-cp311-win32.whl", hash = "sha256:c656c7fd3d862bcb9d3c4a87b8e9488d0c351b4edf348410c09a26641b9d4731"}, - 
{file = "astropy-5.3.4-cp311-cp311-win_amd64.whl", hash = "sha256:4c4971abae8e3ddfb8f40447d78aaf24e6ce44b976b3874770ff533609050366"}, - {file = "astropy-5.3.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:887db411555692fb1858ae305f87fd2ff42a021b68c78abbf3fa1fc64641e895"}, - {file = "astropy-5.3.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:e4033d7a6bd2da38b83ec65f7282dfeb2641f2b2d41b1cd392cdbe3d6f8abfff"}, - {file = "astropy-5.3.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2cc6503b79d4fb61ca80e1d37dd609fabca6d2e0124e17f831cc08c2e6ff75e"}, - {file = "astropy-5.3.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f3f9fe1d76d151428a8d2bc7d50f4a47ae6e7141c11880a3ad259ac7b906b03"}, - {file = "astropy-5.3.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:6e0f7ecbb2a8acb3eace99bcaca30dd1ce001e6f4750a009fd9cc3b8d1b49c58"}, - {file = "astropy-5.3.4-cp312-cp312-win32.whl", hash = "sha256:d915e6370315a1a6a40c2576e77d0063f48cc3b5f8873087cad8ad19dd429d19"}, - {file = "astropy-5.3.4-cp312-cp312-win_amd64.whl", hash = "sha256:69f5a3789a8a4cb00815630b63f950be629a983896dc1aba92566ccc7937a77d"}, - {file = "astropy-5.3.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d5d1a1be788344f11a94a5356c1a25b4d45f1736b740edb4d8e3a272b872a8fa"}, - {file = "astropy-5.3.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ae59e4d41461ad96a2573bc51408000a7b4f90dce2bad07646fa6409a12a5a74"}, - {file = "astropy-5.3.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b4c4d3a14e8e3a33208683331b16a721ab9f9493ed998d34533532fdaeaa3642"}, - {file = "astropy-5.3.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d8f58f53294f07cd3f9173bb113ad60d2cd823501c99251891936202fed76681"}, - {file = "astropy-5.3.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:f79400dc6641bb0202a8998cfb08ad1afe197818e27c946491a292e2ffd16a1b"}, - {file = "astropy-5.3.4-cp39-cp39-win32.whl", hash = "sha256:fd0baa7621d03aa74bb8ba673d7955381d15aed4f30dc2a56654560401fc3aca"}, - {file = "astropy-5.3.4-cp39-cp39-win_amd64.whl", hash = "sha256:9ed6116d07de02183d966e9a5dabc86f6fd3d86cc3e1e8b9feef89fd757be8a6"}, - {file = "astropy-5.3.4.tar.gz", hash = "sha256:d490f7e2faac2ccc01c9244202d629154259af8a979104ced89dc4ace4e6f1d8"}, -] - -[package.dependencies] -numpy = ">=1.21,<2" -packaging = ">=19.0" -pyerfa = ">=2.0" -PyYAML = ">=3.13" - -[package.extras] -all = ["asdf (>=2.10.0)", "beautifulsoup4", "bleach", "bottleneck", "certifi", "dask[array]", "fsspec[http] (>=2022.8.2)", "h5py", "html5lib", "ipython (>=4.2)", "jplephem", "matplotlib (>=3.3,!=3.4.0,!=3.5.2)", "mpmath", "pandas", "pre-commit", "pyarrow (>=5.0.0)", "pytest (>=7.0,<8)", "pytz", "s3fs (>=2022.8.2)", "scipy (>=1.5)", "sortedcontainers", "typing-extensions (>=3.10.0.1)"] -docs = ["Jinja2 (>=3.0)", "matplotlib (>=3.3,!=3.4.0,!=3.5.2)", "pytest (>=7.0,<8)", "scipy (>=1.3)", "sphinx", "sphinx-astropy (>=1.6)", "sphinx-changelog (>=1.2.0)"] -recommended = ["matplotlib (>=3.3,!=3.4.0,!=3.5.2)", "scipy (>=1.5)"] -test = ["pytest (>=7.0,<8)", "pytest-astropy (>=0.10)", "pytest-astropy-header (>=0.2.1)", "pytest-doctestplus (>=0.12)", "pytest-xdist"] -test-all = ["coverage[toml]", "ipython (>=4.2)", "objgraph", "pytest (>=7.0,<8)", "pytest-astropy (>=0.10)", "pytest-astropy-header (>=0.2.1)", "pytest-doctestplus (>=0.12)", "pytest-xdist", "sgp4 (>=2.3)", "skyfield (>=1.20)"] - [[package]] name = "asttokens" version = "2.4.0" @@ -245,21 +157,6 @@ docs = ["furo", "myst-parser", 
"sphinx", "sphinx-notfound-page", "sphinxcontrib- tests = ["attrs[tests-no-zope]", "zope-interface"] tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -[[package]] -name = "autograd" -version = "1.6.2" -description = "Efficiently computes derivatives of numpy code." -optional = false -python-versions = "*" -files = [ - {file = "autograd-1.6.2-py3-none-any.whl", hash = "sha256:208dde2a938e63b4f8f5049b1985505139e529068b0d26f8cd7771fd3eb145d5"}, - {file = "autograd-1.6.2.tar.gz", hash = "sha256:8731e08a0c4e389d8695a40072ada4512641c113b6cace8f4cfbe8eb7e9aedeb"}, -] - -[package.dependencies] -future = ">=0.15.2" -numpy = ">=1.12" - [[package]] name = "babel" version = "2.13.0" @@ -305,33 +202,33 @@ lxml = ["lxml"] [[package]] name = "black" -version = "24.2.0" +version = "24.3.0" description = "The uncompromising code formatter." optional = false python-versions = ">=3.8" files = [ - {file = "black-24.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:6981eae48b3b33399c8757036c7f5d48a535b962a7c2310d19361edeef64ce29"}, - {file = "black-24.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d533d5e3259720fdbc1b37444491b024003e012c5173f7d06825a77508085430"}, - {file = "black-24.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61a0391772490ddfb8a693c067df1ef5227257e72b0e4108482b8d41b5aee13f"}, - {file = "black-24.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:992e451b04667116680cb88f63449267c13e1ad134f30087dec8527242e9862a"}, - {file = "black-24.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:163baf4ef40e6897a2a9b83890e59141cc8c2a98f2dda5080dc15c00ee1e62cd"}, - {file = "black-24.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e37c99f89929af50ffaf912454b3e3b47fd64109659026b678c091a4cd450fb2"}, - {file = "black-24.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9de21bafcba9683853f6c96c2d515e364aee631b178eaa5145fc1c61a3cc92"}, - {file = "black-24.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:9db528bccb9e8e20c08e716b3b09c6bdd64da0dd129b11e160bf082d4642ac23"}, - {file = "black-24.2.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:d84f29eb3ee44859052073b7636533ec995bd0f64e2fb43aeceefc70090e752b"}, - {file = "black-24.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1e08fb9a15c914b81dd734ddd7fb10513016e5ce7e6704bdd5e1251ceee51ac9"}, - {file = "black-24.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:810d445ae6069ce64030c78ff6127cd9cd178a9ac3361435708b907d8a04c693"}, - {file = "black-24.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:ba15742a13de85e9b8f3239c8f807723991fbfae24bad92d34a2b12e81904982"}, - {file = "black-24.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7e53a8c630f71db01b28cd9602a1ada68c937cbf2c333e6ed041390d6968faf4"}, - {file = "black-24.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:93601c2deb321b4bad8f95df408e3fb3943d85012dddb6121336b8e24a0d1218"}, - {file = "black-24.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0057f800de6acc4407fe75bb147b0c2b5cbb7c3ed110d3e5999cd01184d53b0"}, - {file = "black-24.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:faf2ee02e6612577ba0181f4347bcbcf591eb122f7841ae5ba233d12c39dcb4d"}, - {file = "black-24.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:057c3dc602eaa6fdc451069bd027a1b2635028b575a6c3acfd63193ced20d9c8"}, - {file = "black-24.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = 
"sha256:08654d0797e65f2423f850fc8e16a0ce50925f9337fb4a4a176a7aa4026e63f8"}, - {file = "black-24.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ca610d29415ee1a30a3f30fab7a8f4144e9d34c89a235d81292a1edb2b55f540"}, - {file = "black-24.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:4dd76e9468d5536abd40ffbc7a247f83b2324f0c050556d9c371c2b9a9a95e31"}, - {file = "black-24.2.0-py3-none-any.whl", hash = "sha256:e8a6ae970537e67830776488bca52000eaa37fa63b9988e8c487458d9cd5ace6"}, - {file = "black-24.2.0.tar.gz", hash = "sha256:bce4f25c27c3435e4dace4815bcb2008b87e167e3bf4ee47ccdc5ce906eb4894"}, + {file = "black-24.3.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7d5e026f8da0322b5662fa7a8e752b3fa2dac1c1cbc213c3d7ff9bdd0ab12395"}, + {file = "black-24.3.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9f50ea1132e2189d8dff0115ab75b65590a3e97de1e143795adb4ce317934995"}, + {file = "black-24.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e2af80566f43c85f5797365077fb64a393861a3730bd110971ab7a0c94e873e7"}, + {file = "black-24.3.0-cp310-cp310-win_amd64.whl", hash = "sha256:4be5bb28e090456adfc1255e03967fb67ca846a03be7aadf6249096100ee32d0"}, + {file = "black-24.3.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f1373a7808a8f135b774039f61d59e4be7eb56b2513d3d2f02a8b9365b8a8a9"}, + {file = "black-24.3.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:aadf7a02d947936ee418777e0247ea114f78aff0d0959461057cae8a04f20597"}, + {file = "black-24.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65c02e4ea2ae09d16314d30912a58ada9a5c4fdfedf9512d23326128ac08ac3d"}, + {file = "black-24.3.0-cp311-cp311-win_amd64.whl", hash = "sha256:bf21b7b230718a5f08bd32d5e4f1db7fc8788345c8aea1d155fc17852b3410f5"}, + {file = "black-24.3.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:2818cf72dfd5d289e48f37ccfa08b460bf469e67fb7c4abb07edc2e9f16fb63f"}, + {file = "black-24.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4acf672def7eb1725f41f38bf6bf425c8237248bb0804faa3965c036f7672d11"}, + {file = "black-24.3.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7ed6668cbbfcd231fa0dc1b137d3e40c04c7f786e626b405c62bcd5db5857e4"}, + {file = "black-24.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:56f52cfbd3dabe2798d76dbdd299faa046a901041faf2cf33288bc4e6dae57b5"}, + {file = "black-24.3.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:79dcf34b33e38ed1b17434693763301d7ccbd1c5860674a8f871bd15139e7837"}, + {file = "black-24.3.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:e19cb1c6365fd6dc38a6eae2dcb691d7d83935c10215aef8e6c38edee3f77abd"}, + {file = "black-24.3.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:65b76c275e4c1c5ce6e9870911384bff5ca31ab63d19c76811cb1fb162678213"}, + {file = "black-24.3.0-cp38-cp38-win_amd64.whl", hash = "sha256:b5991d523eee14756f3c8d5df5231550ae8993e2286b8014e2fdea7156ed0959"}, + {file = "black-24.3.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c45f8dff244b3c431b36e3224b6be4a127c6aca780853574c00faf99258041eb"}, + {file = "black-24.3.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6905238a754ceb7788a73f02b45637d820b2f5478b20fec82ea865e4f5d4d9f7"}, + {file = "black-24.3.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7de8d330763c66663661a1ffd432274a2f92f07feeddd89ffd085b5744f85e7"}, + {file = "black-24.3.0-cp39-cp39-win_amd64.whl", hash = "sha256:7bb041dca0d784697af4646d3b62ba4a6b028276ae878e53f6b4f74ddd6db99f"}, + 
{file = "black-24.3.0-py3-none-any.whl", hash = "sha256:41622020d7120e01d377f74249e677039d20e6344ff5851de8a10f11f513bf93"}, + {file = "black-24.3.0.tar.gz", hash = "sha256:a0c9c4a0771afc6919578cec71ce82a3e31e054904e7197deacbc9382671c41f"}, ] [package.dependencies] @@ -367,17 +264,6 @@ webencodings = "*" [package.extras] css = ["tinycss2 (>=1.1.0,<1.2)"] -[[package]] -name = "cachetools" -version = "5.3.2" -description = "Extensible memoizing collections and decorators" -optional = false -python-versions = ">=3.7" -files = [ - {file = "cachetools-5.3.2-py3-none-any.whl", hash = "sha256:861f35a13a451f94e301ce2bec7cac63e881232ccce7ed67fab9b5df4d3beaa1"}, - {file = "cachetools-5.3.2.tar.gz", hash = "sha256:086ee420196f7b2ab9ca2db2520aca326318b68fe5ba8bc4d49cca91add450f2"}, -] - [[package]] name = "certifi" version = "2023.7.22" @@ -453,6 +339,17 @@ files = [ [package.dependencies] pycparser = "*" +[[package]] +name = "cfgv" +version = "3.4.0" +description = "Validate configuration and produce human readable error messages." +optional = false +python-versions = ">=3.8" +files = [ + {file = "cfgv-3.4.0-py2.py3-none-any.whl", hash = "sha256:b7265b1f29fd3316bfcd2b330d63d024f2bfd8bcb8b0272f8e19a504856c48f9"}, + {file = "cfgv-3.4.0.tar.gz", hash = "sha256:e52591d4c5f5dead8e0f673fb16db7949d2cfb3f7da4582893288f0ded8fe560"}, +] + [[package]] name = "charset-normalizer" version = "3.3.0" @@ -667,25 +564,6 @@ mypy = ["contourpy[bokeh,docs]", "docutils-stubs", "mypy (==1.4.1)", "types-Pill test = ["Pillow", "contourpy[test-no-images]", "matplotlib"] test-no-images = ["pytest", "pytest-cov", "wurlitzer"] -[[package]] -name = "corner" -version = "2.2.2" -description = "Make some beautiful corner plots" -optional = false -python-versions = ">=3.9" -files = [ - {file = "corner-2.2.2-py3-none-any.whl", hash = "sha256:e7577cdb59cfa304effa243b0c7ac0e3777030d3dc2f2e217a387e87a47074bb"}, - {file = "corner-2.2.2.tar.gz", hash = "sha256:4bc79f3b6778c270103f0926e64ef2606c48c3b6f92daf5382fc4babf5d608d1"}, -] - -[package.dependencies] -matplotlib = ">=2.1" - -[package.extras] -arviz = ["arviz (>=0.9)"] -docs = ["arviz (>=0.9)", "ipython", "myst-nb", "pandoc", "sphinx (>=1.7.5)", "sphinx-book-theme"] -test = ["arviz (>=0.9)", "pytest", "scipy"] - [[package]] name = "coverage" version = "7.3.2" @@ -806,26 +684,6 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] -[[package]] -name = "deepbench" -version = "0.2.2" -description = "Physics Benchmark Dataset Generator" -optional = false -python-versions = ">=3.8,<4.0" -files = [ - {file = "deepbench-0.2.2-py3-none-any.whl", hash = "sha256:309f269a5a65e681f4ce20425430449dad7f677659344b4860927eefa308c335"}, - {file = "deepbench-0.2.2.tar.gz", hash = "sha256:bab3ef5a048f9c8e9b3e162cbfb75dcb4cdbef08689d479ca29d32829f5dd887"}, -] - -[package.dependencies] -astropy = ">=5.2.2,<6.0.0" -autograd = ">=1.5,<2.0" -h5py = ">=3.9.0" -matplotlib = ">=3.7.1,<4.0.0" -numpy = ">=1.24.3,<2.0.0" -pyyaml = ">=6.0,<7.0" -scikit-image = ">=0.20.0,<0.21.0" - [[package]] name = "defusedxml" version = "0.7.1" @@ -837,6 +695,17 @@ files = [ {file = "defusedxml-0.7.1.tar.gz", hash = "sha256:1bb3032db185915b62d7c6209c5a8792be6a32ab2fedacc84e01b52c51aa3e69"}, ] +[[package]] +name = "distlib" +version = "0.3.8" +description = "Distribution utilities" +optional = false +python-versions = "*" +files = [ + {file = "distlib-0.3.8-py2.py3-none-any.whl", hash = 
"sha256:034db59a0b96f8ca18035f36290806a9a6e6bd9d1ff91e45a7f172eb17e51784"}, + {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"}, +] + [[package]] name = "exceptiongroup" version = "1.1.3" @@ -1022,157 +891,6 @@ smb = ["smbprotocol"] ssh = ["paramiko"] tqdm = ["tqdm"] -[[package]] -name = "future" -version = "0.18.3" -description = "Clean single-source support for Python 3 and 2" -optional = false -python-versions = ">=2.6, !=3.0.*, !=3.1.*, !=3.2.*" -files = [ - {file = "future-0.18.3.tar.gz", hash = "sha256:34a17436ed1e96697a86f9de3d15a3b0be01d8bc8de9c1dffd59fb8234ed5307"}, -] - -[[package]] -name = "google-auth" -version = "2.27.0" -description = "Google Authentication Library" -optional = false -python-versions = ">=3.7" -files = [ - {file = "google-auth-2.27.0.tar.gz", hash = "sha256:e863a56ccc2d8efa83df7a80272601e43487fa9a728a376205c86c26aaefa821"}, - {file = "google_auth-2.27.0-py2.py3-none-any.whl", hash = "sha256:8e4bad367015430ff253fe49d500fdc3396c1a434db5740828c728e45bcce245"}, -] - -[package.dependencies] -cachetools = ">=2.0.0,<6.0" -pyasn1-modules = ">=0.2.1" -rsa = ">=3.1.4,<5" - -[package.extras] -aiohttp = ["aiohttp (>=3.6.2,<4.0.0.dev0)", "requests (>=2.20.0,<3.0.0.dev0)"] -enterprise-cert = ["cryptography (==36.0.2)", "pyopenssl (==22.0.0)"] -pyopenssl = ["cryptography (>=38.0.3)", "pyopenssl (>=20.0.0)"] -reauth = ["pyu2f (>=0.1.5)"] -requests = ["requests (>=2.20.0,<3.0.0.dev0)"] - -[[package]] -name = "google-auth-oauthlib" -version = "1.2.0" -description = "Google Authentication Library" -optional = false -python-versions = ">=3.6" -files = [ - {file = "google-auth-oauthlib-1.2.0.tar.gz", hash = "sha256:292d2d3783349f2b0734a0a0207b1e1e322ac193c2c09d8f7c613fb7cc501ea8"}, - {file = "google_auth_oauthlib-1.2.0-py2.py3-none-any.whl", hash = "sha256:297c1ce4cb13a99b5834c74a1fe03252e1e499716718b190f56bcb9c4abc4faf"}, -] - -[package.dependencies] -google-auth = ">=2.15.0" -requests-oauthlib = ">=0.7.0" - -[package.extras] -tool = ["click (>=6.0.0)"] - -[[package]] -name = "graphviz" -version = "0.20.1" -description = "Simple Python interface for Graphviz" -optional = false -python-versions = ">=3.7" -files = [ - {file = "graphviz-0.20.1-py3-none-any.whl", hash = "sha256:587c58a223b51611c0cf461132da386edd896a029524ca61a1462b880bf97977"}, - {file = "graphviz-0.20.1.zip", hash = "sha256:8c58f14adaa3b947daf26c19bc1e98c4e0702cdc31cf99153e6f06904d492bf8"}, -] - -[package.extras] -dev = ["flake8", "pep8-naming", "tox (>=3)", "twine", "wheel"] -docs = ["sphinx (>=5)", "sphinx-autodoc-typehints", "sphinx-rtd-theme"] -test = ["coverage", "mock (>=4)", "pytest (>=7)", "pytest-cov", "pytest-mock (>=3)"] - -[[package]] -name = "grpcio" -version = "1.60.1" -description = "HTTP/2-based RPC framework" -optional = false -python-versions = ">=3.7" -files = [ - {file = "grpcio-1.60.1-cp310-cp310-linux_armv7l.whl", hash = "sha256:14e8f2c84c0832773fb3958240c69def72357bc11392571f87b2d7b91e0bb092"}, - {file = "grpcio-1.60.1-cp310-cp310-macosx_12_0_universal2.whl", hash = "sha256:33aed0a431f5befeffd9d346b0fa44b2c01aa4aeae5ea5b2c03d3e25e0071216"}, - {file = "grpcio-1.60.1-cp310-cp310-manylinux_2_17_aarch64.whl", hash = "sha256:fead980fbc68512dfd4e0c7b1f5754c2a8e5015a04dea454b9cada54a8423525"}, - {file = "grpcio-1.60.1-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:082081e6a36b6eb5cf0fd9a897fe777dbb3802176ffd08e3ec6567edd85bc104"}, - {file = 
"grpcio-1.60.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:55ccb7db5a665079d68b5c7c86359ebd5ebf31a19bc1a91c982fd622f1e31ff2"}, - {file = "grpcio-1.60.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:9b54577032d4f235452f77a83169b6527bf4b77d73aeada97d45b2aaf1bf5ce0"}, - {file = "grpcio-1.60.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7d142bcd604166417929b071cd396aa13c565749a4c840d6c702727a59d835eb"}, - {file = "grpcio-1.60.1-cp310-cp310-win32.whl", hash = "sha256:2a6087f234cb570008a6041c8ffd1b7d657b397fdd6d26e83d72283dae3527b1"}, - {file = "grpcio-1.60.1-cp310-cp310-win_amd64.whl", hash = "sha256:f2212796593ad1d0235068c79836861f2201fc7137a99aa2fea7beeb3b101177"}, - {file = "grpcio-1.60.1-cp311-cp311-linux_armv7l.whl", hash = "sha256:79ae0dc785504cb1e1788758c588c711f4e4a0195d70dff53db203c95a0bd303"}, - {file = "grpcio-1.60.1-cp311-cp311-macosx_10_10_universal2.whl", hash = "sha256:4eec8b8c1c2c9b7125508ff7c89d5701bf933c99d3910e446ed531cd16ad5d87"}, - {file = "grpcio-1.60.1-cp311-cp311-manylinux_2_17_aarch64.whl", hash = "sha256:8c9554ca8e26241dabe7951aa1fa03a1ba0856688ecd7e7bdbdd286ebc272e4c"}, - {file = "grpcio-1.60.1-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:91422ba785a8e7a18725b1dc40fbd88f08a5bb4c7f1b3e8739cab24b04fa8a03"}, - {file = "grpcio-1.60.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cba6209c96828711cb7c8fcb45ecef8c8859238baf15119daa1bef0f6c84bfe7"}, - {file = "grpcio-1.60.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:c71be3f86d67d8d1311c6076a4ba3b75ba5703c0b856b4e691c9097f9b1e8bd2"}, - {file = "grpcio-1.60.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:af5ef6cfaf0d023c00002ba25d0751e5995fa0e4c9eec6cd263c30352662cbce"}, - {file = "grpcio-1.60.1-cp311-cp311-win32.whl", hash = "sha256:a09506eb48fa5493c58f946c46754ef22f3ec0df64f2b5149373ff31fb67f3dd"}, - {file = "grpcio-1.60.1-cp311-cp311-win_amd64.whl", hash = "sha256:49c9b6a510e3ed8df5f6f4f3c34d7fbf2d2cae048ee90a45cd7415abab72912c"}, - {file = "grpcio-1.60.1-cp312-cp312-linux_armv7l.whl", hash = "sha256:b58b855d0071575ea9c7bc0d84a06d2edfbfccec52e9657864386381a7ce1ae9"}, - {file = "grpcio-1.60.1-cp312-cp312-macosx_10_10_universal2.whl", hash = "sha256:a731ac5cffc34dac62053e0da90f0c0b8560396a19f69d9703e88240c8f05858"}, - {file = "grpcio-1.60.1-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:cf77f8cf2a651fbd869fbdcb4a1931464189cd210abc4cfad357f1cacc8642a6"}, - {file = "grpcio-1.60.1-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c557e94e91a983e5b1e9c60076a8fd79fea1e7e06848eb2e48d0ccfb30f6e073"}, - {file = "grpcio-1.60.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:069fe2aeee02dfd2135d562d0663fe70fbb69d5eed6eb3389042a7e963b54de8"}, - {file = "grpcio-1.60.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb0af13433dbbd1c806e671d81ec75bd324af6ef75171fd7815ca3074fe32bfe"}, - {file = "grpcio-1.60.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:2f44c32aef186bbba254129cea1df08a20be414144ac3bdf0e84b24e3f3b2e05"}, - {file = "grpcio-1.60.1-cp312-cp312-win32.whl", hash = "sha256:a212e5dea1a4182e40cd3e4067ee46be9d10418092ce3627475e995cca95de21"}, - {file = "grpcio-1.60.1-cp312-cp312-win_amd64.whl", hash = "sha256:6e490fa5f7f5326222cb9f0b78f207a2b218a14edf39602e083d5f617354306f"}, - {file = "grpcio-1.60.1-cp37-cp37m-linux_armv7l.whl", hash = "sha256:4216e67ad9a4769117433814956031cb300f85edc855252a645a9a724b3b6594"}, - {file = 
"grpcio-1.60.1-cp37-cp37m-macosx_10_10_universal2.whl", hash = "sha256:73e14acd3d4247169955fae8fb103a2b900cfad21d0c35f0dcd0fdd54cd60367"}, - {file = "grpcio-1.60.1-cp37-cp37m-manylinux_2_17_aarch64.whl", hash = "sha256:6ecf21d20d02d1733e9c820fb5c114c749d888704a7ec824b545c12e78734d1c"}, - {file = "grpcio-1.60.1-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:33bdea30dcfd4f87b045d404388469eb48a48c33a6195a043d116ed1b9a0196c"}, - {file = "grpcio-1.60.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:53b69e79d00f78c81eecfb38f4516080dc7f36a198b6b37b928f1c13b3c063e9"}, - {file = "grpcio-1.60.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:39aa848794b887120b1d35b1b994e445cc028ff602ef267f87c38122c1add50d"}, - {file = "grpcio-1.60.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:72153a0d2e425f45b884540a61c6639436ddafa1829a42056aa5764b84108b8e"}, - {file = "grpcio-1.60.1-cp37-cp37m-win_amd64.whl", hash = "sha256:50d56280b482875d1f9128ce596e59031a226a8b84bec88cb2bf76c289f5d0de"}, - {file = "grpcio-1.60.1-cp38-cp38-linux_armv7l.whl", hash = "sha256:6d140bdeb26cad8b93c1455fa00573c05592793c32053d6e0016ce05ba267549"}, - {file = "grpcio-1.60.1-cp38-cp38-macosx_10_10_universal2.whl", hash = "sha256:bc808924470643b82b14fe121923c30ec211d8c693e747eba8a7414bc4351a23"}, - {file = "grpcio-1.60.1-cp38-cp38-manylinux_2_17_aarch64.whl", hash = "sha256:70c83bb530572917be20c21f3b6be92cd86b9aecb44b0c18b1d3b2cc3ae47df0"}, - {file = "grpcio-1.60.1-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9b106bc52e7f28170e624ba61cc7dc6829566e535a6ec68528f8e1afbed1c41f"}, - {file = "grpcio-1.60.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:30e980cd6db1088c144b92fe376747328d5554bc7960ce583ec7b7d81cd47287"}, - {file = "grpcio-1.60.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:0c5807e9152eff15f1d48f6b9ad3749196f79a4a050469d99eecb679be592acc"}, - {file = "grpcio-1.60.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:f1c3dc536b3ee124e8b24feb7533e5c70b9f2ef833e3b2e5513b2897fd46763a"}, - {file = "grpcio-1.60.1-cp38-cp38-win32.whl", hash = "sha256:d7404cebcdb11bb5bd40bf94131faf7e9a7c10a6c60358580fe83913f360f929"}, - {file = "grpcio-1.60.1-cp38-cp38-win_amd64.whl", hash = "sha256:c8754c75f55781515a3005063d9a05878b2cfb3cb7e41d5401ad0cf19de14872"}, - {file = "grpcio-1.60.1-cp39-cp39-linux_armv7l.whl", hash = "sha256:0250a7a70b14000fa311de04b169cc7480be6c1a769b190769d347939d3232a8"}, - {file = "grpcio-1.60.1-cp39-cp39-macosx_10_10_universal2.whl", hash = "sha256:660fc6b9c2a9ea3bb2a7e64ba878c98339abaf1811edca904ac85e9e662f1d73"}, - {file = "grpcio-1.60.1-cp39-cp39-manylinux_2_17_aarch64.whl", hash = "sha256:76eaaba891083fcbe167aa0f03363311a9f12da975b025d30e94b93ac7a765fc"}, - {file = "grpcio-1.60.1-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5d97c65ea7e097056f3d1ead77040ebc236feaf7f71489383d20f3b4c28412a"}, - {file = "grpcio-1.60.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2bb2a2911b028f01c8c64d126f6b632fcd8a9ac975aa1b3855766c94e4107180"}, - {file = "grpcio-1.60.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:5a1ebbae7e2214f51b1f23b57bf98eeed2cf1ba84e4d523c48c36d5b2f8829ff"}, - {file = "grpcio-1.60.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:9a66f4d2a005bc78e61d805ed95dedfcb35efa84b7bba0403c6d60d13a3de2d6"}, - {file = "grpcio-1.60.1-cp39-cp39-win32.whl", hash = "sha256:8d488fbdbf04283f0d20742b64968d44825617aa6717b07c006168ed16488804"}, - {file = 
"grpcio-1.60.1-cp39-cp39-win_amd64.whl", hash = "sha256:61b7199cd2a55e62e45bfb629a35b71fc2c0cb88f686a047f25b1112d3810904"}, - {file = "grpcio-1.60.1.tar.gz", hash = "sha256:dd1d3a8d1d2e50ad9b59e10aa7f07c7d1be2b367f3f2d33c5fade96ed5460962"}, -] - -[package.extras] -protobuf = ["grpcio-tools (>=1.60.1)"] - -[[package]] -name = "h5netcdf" -version = "1.2.0" -description = "netCDF4 via h5py" -optional = false -python-versions = ">=3.9" -files = [ - {file = "h5netcdf-1.2.0-py3-none-any.whl", hash = "sha256:aa53c39b94bcd4595a2e5a2f62f3fb4fb8a723b5ca0a05f2db352f014bcfe72c"}, - {file = "h5netcdf-1.2.0.tar.gz", hash = "sha256:7f6b2733bde06ea2575b79a6450d9bd5c38918ff4cb2a355bf22bbe8c86c6bcf"}, -] - -[package.dependencies] -h5py = "*" -packaging = "*" - -[package.extras] -test = ["netCDF4", "pytest"] - [[package]] name = "h5py" version = "3.10.0" @@ -1210,6 +928,20 @@ files = [ [package.dependencies] numpy = ">=1.17.3" +[[package]] +name = "identify" +version = "2.5.35" +description = "File identification library for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "identify-2.5.35-py2.py3-none-any.whl", hash = "sha256:c4de0081837b211594f8e877a6b4fad7ca32bbfc1a9307fdd61c28bfe923f13e"}, + {file = "identify-2.5.35.tar.gz", hash = "sha256:10a7ca245cfcd756a554a7288159f72ff105ad233c7c4b9c6f0f4d108f5f6791"}, +] + +[package.extras] +license = ["ukkonen"] + [[package]] name = "idna" version = "3.4" @@ -1221,37 +953,6 @@ files = [ {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, ] -[[package]] -name = "imageio" -version = "2.31.5" -description = "Library for reading and writing a wide range of image, video, scientific, and volumetric data formats." -optional = false -python-versions = ">=3.8" -files = [ - {file = "imageio-2.31.5-py3-none-any.whl", hash = "sha256:97f68e12ba676f2f4b541684ed81f7f3370dc347e8321bc68ee34d37b2dbac9f"}, - {file = "imageio-2.31.5.tar.gz", hash = "sha256:d8e53f9cd4054880276a3dac0a28c85ba7874084856a55a0294a8ae6ed7f3a8e"}, -] - -[package.dependencies] -numpy = "*" -pillow = ">=8.3.2" - -[package.extras] -all-plugins = ["astropy", "av", "imageio-ffmpeg", "psutil", "tifffile"] -all-plugins-pypy = ["av", "imageio-ffmpeg", "psutil", "tifffile"] -build = ["wheel"] -dev = ["black", "flake8", "fsspec[github]", "pytest", "pytest-cov"] -docs = ["numpydoc", "pydata-sphinx-theme", "sphinx (<6)"] -ffmpeg = ["imageio-ffmpeg", "psutil"] -fits = ["astropy"] -full = ["astropy", "av", "black", "flake8", "fsspec[github]", "gdal", "imageio-ffmpeg", "itk", "numpydoc", "psutil", "pydata-sphinx-theme", "pytest", "pytest-cov", "sphinx (<6)", "tifffile", "wheel"] -gdal = ["gdal"] -itk = ["itk"] -linting = ["black", "flake8"] -pyav = ["av"] -test = ["fsspec[github]", "pytest", "pytest-cov"] -tifffile = ["tifffile"] - [[package]] name = "importlib-metadata" version = "6.8.0" @@ -1908,39 +1609,6 @@ files = [ {file = "kiwisolver-1.4.5.tar.gz", hash = "sha256:e57e563a57fb22a142da34f38acc2fc1a5c864bc29ca1517a88abc963e60d6ec"}, ] -[[package]] -name = "lazy-loader" -version = "0.3" -description = "lazy_loader" -optional = false -python-versions = ">=3.7" -files = [ - {file = "lazy_loader-0.3-py3-none-any.whl", hash = "sha256:1e9e76ee8631e264c62ce10006718e80b2cfc74340d17d1031e0f84af7478554"}, - {file = "lazy_loader-0.3.tar.gz", hash = "sha256:3b68898e34f5b2a29daaaac172c6555512d0f32074f147e2254e4a6d9d838f37"}, -] - -[package.extras] -lint = ["pre-commit (>=3.3)"] -test = ["pytest (>=7.4)", "pytest-cov (>=4.1)"] - -[[package]] -name = 
"markdown" -version = "3.5.2" -description = "Python implementation of John Gruber's Markdown." -optional = false -python-versions = ">=3.8" -files = [ - {file = "Markdown-3.5.2-py3-none-any.whl", hash = "sha256:d43323865d89fc0cb9b20c75fc8ad313af307cc087e84b657d9eec768eddeadd"}, - {file = "Markdown-3.5.2.tar.gz", hash = "sha256:e1ac7b3dc550ee80e602e71c1d168002f062e49f1b11e26a36264dafd4df2ef8"}, -] - -[package.dependencies] -importlib-metadata = {version = ">=4.4", markers = "python_version < \"3.10\""} - -[package.extras] -docs = ["mdx-gh-links (>=0.2)", "mkdocs (>=1.5)", "mkdocs-gen-files", "mkdocs-literate-nav", "mkdocs-nature (>=0.6)", "mkdocs-section-index", "mkdocstrings[python]"] -testing = ["coverage", "pyyaml"] - [[package]] name = "markupsafe" version = "2.1.3" @@ -2225,21 +1893,18 @@ extra = ["lxml (>=4.6)", "pydot (>=1.4.2)", "pygraphviz (>=1.10)", "sympy (>=1.1 test = ["codecov (>=2.1)", "pytest (>=7.2)", "pytest-cov (>=4.0)"] [[package]] -name = "nflows" -version = "0.14" -description = "Normalizing flows in PyTorch." +name = "nodeenv" +version = "1.8.0" +description = "Node.js virtual environment builder" optional = false -python-versions = "*" +python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*" files = [ - {file = "nflows-0.14.tar.gz", hash = "sha256:6299844a62f9999fcdf2d95cb2d01c091a50136bd17826e303aba646b2d11b55"}, + {file = "nodeenv-1.8.0-py2.py3-none-any.whl", hash = "sha256:df865724bb3c3adc86b3876fa209771517b0cfe596beff01a92700e0e8be4cec"}, + {file = "nodeenv-1.8.0.tar.gz", hash = "sha256:d51e0c37e64fbf47d017feac3145cdbb58836d7eee8c6f6d3b6880c5456227d2"}, ] [package.dependencies] -matplotlib = "*" -numpy = "*" -tensorboard = "*" -torch = "*" -tqdm = "*" +setuptools = "*" [[package]] name = "notebook" @@ -2281,43 +1946,6 @@ jupyter-server = ">=1.8,<3" [package.extras] test = ["pytest", "pytest-console-scripts", "pytest-jupyter", "pytest-tornasync"] -[[package]] -name = "numpy" -version = "1.24.4" -description = "Fundamental package for array computing in Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "numpy-1.24.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0bfb52d2169d58c1cdb8cc1f16989101639b34c7d3ce60ed70b19c63eba0b64"}, - {file = "numpy-1.24.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ed094d4f0c177b1b8e7aa9cba7d6ceed51c0e569a5318ac0ca9a090680a6a1b1"}, - {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:79fc682a374c4a8ed08b331bef9c5f582585d1048fa6d80bc6c35bc384eee9b4"}, - {file = "numpy-1.24.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ffe43c74893dbf38c2b0a1f5428760a1a9c98285553c89e12d70a96a7f3a4d6"}, - {file = "numpy-1.24.4-cp310-cp310-win32.whl", hash = "sha256:4c21decb6ea94057331e111a5bed9a79d335658c27ce2adb580fb4d54f2ad9bc"}, - {file = "numpy-1.24.4-cp310-cp310-win_amd64.whl", hash = "sha256:b4bea75e47d9586d31e892a7401f76e909712a0fd510f58f5337bea9572c571e"}, - {file = "numpy-1.24.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f136bab9c2cfd8da131132c2cf6cc27331dd6fae65f95f69dcd4ae3c3639c810"}, - {file = "numpy-1.24.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e2926dac25b313635e4d6cf4dc4e51c8c0ebfed60b801c799ffc4c32bf3d1254"}, - {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:222e40d0e2548690405b0b3c7b21d1169117391c2e82c378467ef9ab4c8f0da7"}, - {file = "numpy-1.24.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:7215847ce88a85ce39baf9e89070cb860c98fdddacbaa6c0da3ffb31b3350bd5"}, - {file = "numpy-1.24.4-cp311-cp311-win32.whl", hash = "sha256:4979217d7de511a8d57f4b4b5b2b965f707768440c17cb70fbf254c4b225238d"}, - {file = "numpy-1.24.4-cp311-cp311-win_amd64.whl", hash = "sha256:b7b1fc9864d7d39e28f41d089bfd6353cb5f27ecd9905348c24187a768c79694"}, - {file = "numpy-1.24.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1452241c290f3e2a312c137a9999cdbf63f78864d63c79039bda65ee86943f61"}, - {file = "numpy-1.24.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:04640dab83f7c6c85abf9cd729c5b65f1ebd0ccf9de90b270cd61935eef0197f"}, - {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a5425b114831d1e77e4b5d812b69d11d962e104095a5b9c3b641a218abcc050e"}, - {file = "numpy-1.24.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd80e219fd4c71fc3699fc1dadac5dcf4fd882bfc6f7ec53d30fa197b8ee22dc"}, - {file = "numpy-1.24.4-cp38-cp38-win32.whl", hash = "sha256:4602244f345453db537be5314d3983dbf5834a9701b7723ec28923e2889e0bb2"}, - {file = "numpy-1.24.4-cp38-cp38-win_amd64.whl", hash = "sha256:692f2e0f55794943c5bfff12b3f56f99af76f902fc47487bdfe97856de51a706"}, - {file = "numpy-1.24.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2541312fbf09977f3b3ad449c4e5f4bb55d0dbf79226d7724211acc905049400"}, - {file = "numpy-1.24.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9667575fb6d13c95f1b36aca12c5ee3356bf001b714fc354eb5465ce1609e62f"}, - {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3a86ed21e4f87050382c7bc96571755193c4c1392490744ac73d660e8f564a9"}, - {file = "numpy-1.24.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d11efb4dbecbdf22508d55e48d9c8384db795e1b7b51ea735289ff96613ff74d"}, - {file = "numpy-1.24.4-cp39-cp39-win32.whl", hash = "sha256:6620c0acd41dbcb368610bb2f4d83145674040025e5536954782467100aa8835"}, - {file = "numpy-1.24.4-cp39-cp39-win_amd64.whl", hash = "sha256:befe2bf740fd8373cf56149a5c23a0f601e82869598d41f8e188a0e9869926f8"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-macosx_10_9_x86_64.whl", hash = "sha256:31f13e25b4e304632a4619d0e0777662c2ffea99fcae2029556b17d8ff958aef"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:95f7ac6540e95bc440ad77f56e520da5bf877f87dca58bd095288dce8940532a"}, - {file = "numpy-1.24.4-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e98f220aa76ca2a977fe435f5b04d7b3470c0a2e6312907b37ba6068f26787f2"}, - {file = "numpy-1.24.4.tar.gz", hash = "sha256:80f5e3a4e498641401868df4208b74581206afbee7cf7b8329daae82676d9463"}, -] - [[package]] name = "numpy" version = "1.26.0" @@ -2359,40 +1987,6 @@ files = [ {file = "numpy-1.26.0.tar.gz", hash = "sha256:f93fc78fe8bf15afe2b8d6b6499f1c73953169fad1e9a8dd086cdff3190e7fdf"}, ] -[[package]] -name = "oauthlib" -version = "3.2.2" -description = "A generic, spec-compliant, thorough implementation of the OAuth request-signing logic" -optional = false -python-versions = ">=3.6" -files = [ - {file = "oauthlib-3.2.2-py3-none-any.whl", hash = "sha256:8139f29aac13e25d502680e9e19963e83f16838d48a0d71c287fe40e7067fbca"}, - {file = "oauthlib-3.2.2.tar.gz", hash = "sha256:9859c40929662bec5d64f34d01c99e093149682a3f38915dc0655d5a633dd918"}, -] - -[package.extras] -rsa = ["cryptography (>=3.0.0)"] -signals = ["blinker (>=1.4.0)"] -signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"] - -[[package]] -name = "opt-einsum" -version = "3.3.0" 
-description = "Optimizing numpys einsum function" -optional = false -python-versions = ">=3.5" -files = [ - {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"}, - {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"}, -] - -[package.dependencies] -numpy = ">=1.7" - -[package.extras] -docs = ["numpydoc", "sphinx (==1.2.3)", "sphinx-rtd-theme", "sphinxcontrib-napoleon"] -tests = ["pytest", "pytest-cov", "pytest-pep8"] - [[package]] name = "overrides" version = "7.4.0" @@ -2638,6 +2232,24 @@ files = [ dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] +[[package]] +name = "pre-commit" +version = "3.7.0" +description = "A framework for managing and maintaining multi-language pre-commit hooks." +optional = false +python-versions = ">=3.9" +files = [ + {file = "pre_commit-3.7.0-py2.py3-none-any.whl", hash = "sha256:5eae9e10c2b5ac51577c3452ec0a490455c45a0533f7960f993a0d01e59decab"}, + {file = "pre_commit-3.7.0.tar.gz", hash = "sha256:e209d61b8acdcf742404408531f0c37d49d2c734fd7cff2d6076083d191cb060"}, +] + +[package.dependencies] +cfgv = ">=2.0.0" +identify = ">=1.0.0" +nodeenv = ">=0.11.1" +pyyaml = ">=5.1" +virtualenv = ">=20.10.0" + [[package]] name = "prometheus-client" version = "0.17.1" @@ -2666,28 +2278,6 @@ files = [ [package.dependencies] wcwidth = "*" -[[package]] -name = "protobuf" -version = "4.23.4" -description = "" -optional = false -python-versions = ">=3.7" -files = [ - {file = "protobuf-4.23.4-cp310-abi3-win32.whl", hash = "sha256:5fea3c64d41ea5ecf5697b83e41d09b9589e6f20b677ab3c48e5f242d9b7897b"}, - {file = "protobuf-4.23.4-cp310-abi3-win_amd64.whl", hash = "sha256:7b19b6266d92ca6a2a87effa88ecc4af73ebc5cfde194dc737cf8ef23a9a3b12"}, - {file = "protobuf-4.23.4-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:8547bf44fe8cec3c69e3042f5c4fb3e36eb2a7a013bb0a44c018fc1e427aafbd"}, - {file = "protobuf-4.23.4-cp37-abi3-manylinux2014_aarch64.whl", hash = "sha256:fee88269a090ada09ca63551bf2f573eb2424035bcf2cb1b121895b01a46594a"}, - {file = "protobuf-4.23.4-cp37-abi3-manylinux2014_x86_64.whl", hash = "sha256:effeac51ab79332d44fba74660d40ae79985901ac21bca408f8dc335a81aa597"}, - {file = "protobuf-4.23.4-cp37-cp37m-win32.whl", hash = "sha256:c3e0939433c40796ca4cfc0fac08af50b00eb66a40bbbc5dee711998fb0bbc1e"}, - {file = "protobuf-4.23.4-cp37-cp37m-win_amd64.whl", hash = "sha256:9053df6df8e5a76c84339ee4a9f5a2661ceee4a0dab019e8663c50ba324208b0"}, - {file = "protobuf-4.23.4-cp38-cp38-win32.whl", hash = "sha256:e1c915778d8ced71e26fcf43c0866d7499891bca14c4368448a82edc61fdbc70"}, - {file = "protobuf-4.23.4-cp38-cp38-win_amd64.whl", hash = "sha256:351cc90f7d10839c480aeb9b870a211e322bf05f6ab3f55fcb2f51331f80a7d2"}, - {file = "protobuf-4.23.4-cp39-cp39-win32.whl", hash = "sha256:6dd9b9940e3f17077e820b75851126615ee38643c2c5332aa7a359988820c720"}, - {file = "protobuf-4.23.4-cp39-cp39-win_amd64.whl", hash = "sha256:0a5759f5696895de8cc913f084e27fd4125e8fb0914bb729a17816a33819f474"}, - {file = "protobuf-4.23.4-py3-none-any.whl", hash = "sha256:e9d0be5bf34b275b9f87ba7407796556abeeba635455d036c7351f7c183ef8ff"}, - {file = "protobuf-4.23.4.tar.gz", hash = "sha256:ccd9430c0719dce806b93f89c91de7977304729e55377f872a92465d548329a9"}, -] - [[package]] name = "psutil" version = "5.9.5" @@ -2739,31 +2329,6 @@ files = [ [package.extras] tests = ["pytest"] -[[package]] -name = "pyasn1" -version = "0.5.1" -description = "Pure-Python implementation of ASN.1 types 
and DER/BER/CER codecs (X.208)" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" -files = [ - {file = "pyasn1-0.5.1-py2.py3-none-any.whl", hash = "sha256:4439847c58d40b1d0a573d07e3856e95333f1976294494c325775aeca506eb58"}, - {file = "pyasn1-0.5.1.tar.gz", hash = "sha256:6d391a96e59b23130a5cfa74d6fd7f388dbbe26cc8f1edf39fdddf08d9d6676c"}, -] - -[[package]] -name = "pyasn1-modules" -version = "0.3.0" -description = "A collection of ASN.1-based protocols modules" -optional = false -python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,>=2.7" -files = [ - {file = "pyasn1_modules-0.3.0-py2.py3-none-any.whl", hash = "sha256:d3ccd6ed470d9ffbc716be08bd90efbd44d0734bc9303818f7336070984a162d"}, - {file = "pyasn1_modules-0.3.0.tar.gz", hash = "sha256:5bd01446b736eb9d31512a30d46c1ac3395d676c6f3cafa4c03eb54b9925631c"}, -] - -[package.dependencies] -pyasn1 = ">=0.4.6,<0.6.0" - [[package]] name = "pycodestyle" version = "2.11.1" @@ -2786,62 +2351,6 @@ files = [ {file = "pycparser-2.21.tar.gz", hash = "sha256:e644fdec12f7872f86c58ff790da456218b10f863970249516d60a5eaca77206"}, ] -[[package]] -name = "pyerfa" -version = "2.0.0.3" -description = "Python bindings for ERFA" -optional = false -python-versions = ">=3.7" -files = [ - {file = "pyerfa-2.0.0.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:676515861ca3f0cb9d7e693389233e7126413a5ba93a0cc4d36b8ca933951e8d"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a438865894d226247dcfcb60d683ae075a52716504537052371b2b73458fe4fc"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:73bf7d23f069d47632a2feeb1e73454b10392c4f3c16116017a6983f1f0e9b2b"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:780b0f90adf500b8ba24e9d509a690576a7e8287e354cfb90227c5963690d3fc"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5447bb45ddedde3052693c86b941a4908f5dbeb4a697bda45b5b89de92cfb74a"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7c24e7960c6cdd3fa3f4dba5f3444a106ad48c94ff0b19eebaee06a142c18c52"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-win32.whl", hash = "sha256:170a83bd0243da518119b846f296cf33fa03f1f884a88578c1a38560182cf64e"}, - {file = "pyerfa-2.0.0.3-cp310-cp310-win_amd64.whl", hash = "sha256:51aa6e0faa4aa9ad8f0eef1c47fec76c5bebc0da7023a436089bdd6e5cfd625f"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fa9fceeb78057bfff7ae3aa6cdad3f1b193722de22bdbb75319256f4a9e2f76"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a8a2029fc62ff2369d01219f66a5ce6aed35ef33eddb06118b6c27e8573a9ed8"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:da888da2c8db5a78273fbf0af4e74f04e2d312d371c3c021cf6c3b14fa60fe3b"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7354753addba5261ec1cbf1ba45784ed3a5c42da565ecc6e0aa36b7a17fa4689"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3b55f7278c1dd362648d7956e1a5365ade5fed2fe5541b721b3ceb5271128892"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:23e5efcf96ed7161d74f79ca261d255e1f36988843d22cd97d8f60fe9c868d44"}, - {file = 
"pyerfa-2.0.0.3-cp311-cp311-win32.whl", hash = "sha256:f0e9d0b122c454bcad5dbd0c3283b200783031d3f99ca9c550f49a7a7d4c41ea"}, - {file = "pyerfa-2.0.0.3-cp311-cp311-win_amd64.whl", hash = "sha256:09af83540e23a7d61a8368b0514b3daa4ed967e1e52d0add4f501f58c500dd7f"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6a07444fd53a5dd18d7955f86f8d9b1be9a68ceb143e1145c0019a310c913c04"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:daf7364e475cff1f973e2fcf6962de9df9642c8802b010e29b2c592ae337e3c5"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8458421166f6ffe2e259aaf4aaa6e802d6539649a40e3194a81d30dccdc167a"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96ea688341176ae6220cc4743cda655549d71e3e3b60c5a99d02d5912d0ddf55"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d56f6b5a0a3ed7b80d630041829463a872946df277259b5453298842d42a54a4"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-win32.whl", hash = "sha256:3ecb598924ddb4ea2b06efc6f1e55ca70897ed178a690e2eaa1e290448466c7c"}, - {file = "pyerfa-2.0.0.3-cp37-cp37m-win_amd64.whl", hash = "sha256:1033fdb890ec70d3a511e20a464afc8abbea2180108f27b14d8f1d1addc38cbe"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:2d8c0dbb17119e52def33f9d6dbf2deaf2113ed3e657b6ff692df9b6a3598397"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8a1edd2cbe4ead3bf9a51e578d5d83bdd7ab3b3ccb69e09b89a4c42aa5b35ffb"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a04c3b715c924b6f972dd440a94a701a16a07700bc8ba9e88b1df765bdc36ad0"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d01c341c45b860ee5c7585ef003118c8015e9d65c30668d2f5bf657e1dcdd68"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:24d89ead30edc6038408336ad9b696683e74c4eef550708fca6afef3ecd5b010"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0b8c5e74d48a505a014e855cd4c7be11604901d94fd6f34b685f6720b7b20ed8"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-win32.whl", hash = "sha256:2ccba04de166d81bdd3adcf10428d908ce2f3a56ed1c2767d740fec12680edbd"}, - {file = "pyerfa-2.0.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:3df87743e27588c5bd5e1f3a886629b3277fdd418059ca048420d33169376775"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:88aa1acedf298d255cc4b0740ee11a3b303b71763dba2f039d48abf0a95cf9df"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06d4f08e96867b1fc3ae9a9e4b38693ed0806463288efc41473ad16e14774504"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f1819e0d95ff8dead80614f8063919d82b2dbb55437b6c0109d3393c1ab55954"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61f1097ac2ee8c15a2a636cdfb99340d708574d66f4610456bd457d1e6b852f4"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:36f42ee01a62c6cbba58103e6f8e600b21ad3a71262dccf03d476efb4a20ea71"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = 
"sha256:3ecd6167b48bb8f1922fae7b49554616f2e7382748a4320ad46ebd7e2cc62f3d"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-win32.whl", hash = "sha256:7f9eabfefa5317ce58fe22480102902f10f270fc64a5636c010f7c0b7e0fb032"}, - {file = "pyerfa-2.0.0.3-cp39-cp39-win_amd64.whl", hash = "sha256:4ea7ca03ecc440224c2bed8fb136fadf6cf8aea8ba67d717f635116f30c8cc8c"}, - {file = "pyerfa-2.0.0.3.tar.gz", hash = "sha256:d77fbbfa58350c194ccb99e5d93aa05d3c2b14d5aad8b662d93c6ad9fff41f39"}, -] - -[package.dependencies] -numpy = ">=1.17" - -[package.extras] -docs = ["sphinx-astropy (>=1.3)"] -test = ["pytest", "pytest-doctestplus (>=0.7)"] - [[package]] name = "pyflakes" version = "3.2.0" @@ -2867,25 +2376,6 @@ files = [ [package.extras] plugins = ["importlib-metadata"] -[[package]] -name = "pyknos" -version = "0.15.2" -description = "Conditional density estimation." -optional = false -python-versions = ">=3.6.0" -files = [ - {file = "pyknos-0.15.2-py2.py3-none-any.whl", hash = "sha256:cdd9d9c45706fe9fe5ed9991edbc6c728ad6bbfb928d6394fb858864dd7a8158"}, - {file = "pyknos-0.15.2.tar.gz", hash = "sha256:ee09ea841858e79ed9f3e104b6654aea676f3403a2f228ef76151e108968caf3"}, -] - -[package.dependencies] -matplotlib = "*" -nflows = "0.14" -numpy = "*" -tensorboard = "*" -torch = "*" -tqdm = "*" - [[package]] name = "pyparsing" version = "3.1.1" @@ -2900,48 +2390,6 @@ files = [ [package.extras] diagrams = ["jinja2", "railroad-diagrams"] -[[package]] -name = "pyro-api" -version = "0.1.2" -description = "Generic API for dispatch to Pyro backends." -optional = false -python-versions = "*" -files = [ - {file = "pyro-api-0.1.2.tar.gz", hash = "sha256:a1b900d9580aa1c2fab3b123ab7ff33413744da7c5f440bd4aadc4d40d14d920"}, - {file = "pyro_api-0.1.2-py3-none-any.whl", hash = "sha256:10e0e42e9e4401ce464dab79c870e50dfb4f413d326fa777f3582928ef9caf8f"}, -] - -[package.extras] -dev = ["ipython", "sphinx (>=2.0)", "sphinx-rtd-theme"] -test = ["flake8", "pytest (>=5.0)"] - -[[package]] -name = "pyro-ppl" -version = "1.8.6" -description = "A Python library for probabilistic modeling and inference" -optional = false -python-versions = ">=3.7" -files = [ - {file = "pyro-ppl-1.8.6.tar.gz", hash = "sha256:00d2f4dda8a53e66d955124dc6e49e92dcf570cd3bd706825091db764d93cd07"}, - {file = "pyro_ppl-1.8.6-py3-none-any.whl", hash = "sha256:18a28febe1be9c42af94a684c2971a798f2acb51f77b07e8430f146eafe11fed"}, -] - -[package.dependencies] -numpy = ">=1.7" -opt-einsum = ">=2.3.2" -pyro-api = ">=0.1.1" -torch = ">=1.11.0" -tqdm = ">=4.36" - -[package.extras] -dev = ["black (>=21.4b0)", "graphviz (>=0.8)", "jupyter (>=1.0.0)", "lap", "matplotlib (>=1.3)", "mypy (>=0.812)", "nbformat", "nbsphinx (>=0.3.2)", "nbstripout", "nbval", "ninja", "pandas", "pillow (==8.2.0)", "pypandoc", "pytest (>=5.0)", "pytest-xdist", "ruff", "scikit-learn", "scipy (>=1.1)", "seaborn (>=0.11.0)", "sphinx", "sphinx-rtd-theme", "torchvision (>=0.12.0)", "visdom (>=0.1.4,<0.2.2)", "wget", "yapf"] -extras = ["graphviz (>=0.8)", "jupyter (>=1.0.0)", "lap", "matplotlib (>=1.3)", "pandas", "pillow (==8.2.0)", "scikit-learn", "seaborn (>=0.11.0)", "torchvision (>=0.12.0)", "visdom (>=0.1.4,<0.2.2)", "wget"] -funsor = ["funsor[torch] (==0.4.4)"] -horovod = ["horovod[pytorch] (>=0.19)"] -lightning = ["pytorch-lightning"] -profile = ["prettytable", "pytest-benchmark", "snakeviz"] -test = ["black (>=21.4b0)", "graphviz (>=0.8)", "jupyter (>=1.0.0)", "lap", "matplotlib (>=1.3)", "nbval", "pandas", "pillow (==8.2.0)", "pytest (>=5.0)", "pytest-cov", "pytest-xdist", "ruff", "scikit-learn", "scipy (>=1.1)", 
"seaborn (>=0.11.0)", "torchvision (>=0.12.0)", "visdom (>=0.1.4,<0.2.2)", "wget"] - [[package]] name = "pytest" version = "7.4.2" @@ -3018,43 +2466,6 @@ files = [ {file = "pytz-2023.3.post1.tar.gz", hash = "sha256:7b4fddbeb94a1eba4b557da24f19fdf9db575192544270a9101d8509f9f43d7b"}, ] -[[package]] -name = "pywavelets" -version = "1.4.1" -description = "PyWavelets, wavelet transform module" -optional = false -python-versions = ">=3.8" -files = [ - {file = "PyWavelets-1.4.1-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:d854411eb5ee9cb4bc5d0e66e3634aeb8f594210f6a1bed96dbed57ec70f181c"}, - {file = "PyWavelets-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:231b0e0b1cdc1112f4af3c24eea7bf181c418d37922a67670e9bf6cfa2d544d4"}, - {file = "PyWavelets-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:754fa5085768227c4f4a26c1e0c78bc509a266d9ebd0eb69a278be7e3ece943c"}, - {file = "PyWavelets-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da7b9c006171be1f9ddb12cc6e0d3d703b95f7f43cb5e2c6f5f15d3233fcf202"}, - {file = "PyWavelets-1.4.1-cp310-cp310-win32.whl", hash = "sha256:67a0d28a08909f21400cb09ff62ba94c064882ffd9e3a6b27880a111211d59bd"}, - {file = "PyWavelets-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:91d3d393cffa634f0e550d88c0e3f217c96cfb9e32781f2960876f1808d9b45b"}, - {file = "PyWavelets-1.4.1-cp311-cp311-macosx_10_13_x86_64.whl", hash = "sha256:64c6bac6204327321db30b775060fbe8e8642316e6bff17f06b9f34936f88875"}, - {file = "PyWavelets-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3f19327f2129fb7977bc59b966b4974dfd72879c093e44a7287500a7032695de"}, - {file = "PyWavelets-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad987748f60418d5f4138db89d82ba0cb49b086e0cbb8fd5c3ed4a814cfb705e"}, - {file = "PyWavelets-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:875d4d620eee655346e3589a16a73790cf9f8917abba062234439b594e706784"}, - {file = "PyWavelets-1.4.1-cp311-cp311-win32.whl", hash = "sha256:7231461d7a8eb3bdc7aa2d97d9f67ea5a9f8902522818e7e2ead9c2b3408eeb1"}, - {file = "PyWavelets-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:daf0aa79842b571308d7c31a9c43bc99a30b6328e6aea3f50388cd8f69ba7dbc"}, - {file = "PyWavelets-1.4.1-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:ab7da0a17822cd2f6545626946d3b82d1a8e106afc4b50e3387719ba01c7b966"}, - {file = "PyWavelets-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:578af438a02a86b70f1975b546f68aaaf38f28fb082a61ceb799816049ed18aa"}, - {file = "PyWavelets-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9cb5ca8d11d3f98e89e65796a2125be98424d22e5ada360a0dbabff659fca0fc"}, - {file = "PyWavelets-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:058b46434eac4c04dd89aeef6fa39e4b6496a951d78c500b6641fd5b2cc2f9f4"}, - {file = "PyWavelets-1.4.1-cp38-cp38-win32.whl", hash = "sha256:de7cd61a88a982edfec01ea755b0740e94766e00a1ceceeafef3ed4c85c605cd"}, - {file = "PyWavelets-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:7ab8d9db0fe549ab2ee0bea61f614e658dd2df419d5b75fba47baa761e95f8f2"}, - {file = "PyWavelets-1.4.1-cp39-cp39-macosx_10_13_x86_64.whl", hash = "sha256:23bafd60350b2b868076d976bdd92f950b3944f119b4754b1d7ff22b7acbf6c6"}, - {file = "PyWavelets-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:d0e56cd7a53aed3cceca91a04d62feb3a0aca6725b1912d29546c26f6ea90426"}, - {file = 
"PyWavelets-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:030670a213ee8fefa56f6387b0c8e7d970c7f7ad6850dc048bd7c89364771b9b"}, - {file = "PyWavelets-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:71ab30f51ee4470741bb55fc6b197b4a2b612232e30f6ac069106f0156342356"}, - {file = "PyWavelets-1.4.1-cp39-cp39-win32.whl", hash = "sha256:47cac4fa25bed76a45bc781a293c26ac63e8eaae9eb8f9be961758d22b58649c"}, - {file = "PyWavelets-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:88aa5449e109d8f5e7f0adef85f7f73b1ab086102865be64421a3a3d02d277f4"}, - {file = "PyWavelets-1.4.1.tar.gz", hash = "sha256:6437af3ddf083118c26d8f97ab43b0724b956c9f958e9ea788659f6a2834ba93"}, -] - -[package.dependencies] -numpy = ">=1.17.3" - [[package]] name = "pywin32" version = "306" @@ -3335,24 +2746,6 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] -[[package]] -name = "requests-oauthlib" -version = "1.3.1" -description = "OAuthlib authentication support for Requests." -optional = false -python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*" -files = [ - {file = "requests-oauthlib-1.3.1.tar.gz", hash = "sha256:75beac4a47881eeb94d5ea5d6ad31ef88856affe2332b9aafb52c6452ccf0d7a"}, - {file = "requests_oauthlib-1.3.1-py2.py3-none-any.whl", hash = "sha256:2577c501a2fb8d05a304c09d090d6e47c306fef15809d102b327cf8364bddab5"}, -] - -[package.dependencies] -oauthlib = ">=3.0.0" -requests = ">=2.0.0" - -[package.extras] -rsa = ["oauthlib[signedtoken] (>=3.0.0)"] - [[package]] name = "rfc3339-validator" version = "0.1.4" @@ -3486,101 +2879,6 @@ files = [ {file = "rpds_py-0.10.4.tar.gz", hash = "sha256:18d5ff7fbd305a1d564273e9eb22de83ae3cd9cd6329fddc8f12f6428a711a6a"}, ] -[[package]] -name = "rsa" -version = "4.9" -description = "Pure-Python RSA implementation" -optional = false -python-versions = ">=3.6,<4" -files = [ - {file = "rsa-4.9-py3-none-any.whl", hash = "sha256:90260d9058e514786967344d0ef75fa8727eed8a7d2e43ce9f4bcf1b536174f7"}, - {file = "rsa-4.9.tar.gz", hash = "sha256:e38464a49c6c85d7f1351b0126661487a7e0a14a50f1675ec50eb34d4f20ef21"}, -] - -[package.dependencies] -pyasn1 = ">=0.1.3" - -[[package]] -name = "sbi" -version = "0.22.0" -description = "Simulation-based inference." 
-optional = false -python-versions = ">=3.6.0" -files = [ - {file = "sbi-0.22.0-py2.py3-none-any.whl", hash = "sha256:26dd81d3e1220c4ca16a33fc0779e18e35c296f97e7fae588712f17a412c058e"}, - {file = "sbi-0.22.0.tar.gz", hash = "sha256:e632994c0bcfbc63c110d6eb04c6e1bdf6ae4ca42211d67ab4946d9b394dc360"}, -] - -[package.dependencies] -arviz = "*" -joblib = ">=1.0.0" -matplotlib = "*" -numpy = "*" -pillow = "*" -pyknos = ">=0.15.1" -pyro-ppl = ">=1.3.1" -scikit-learn = "*" -scipy = "*" -tensorboard = "*" -torch = ">=1.8.0" -tqdm = "*" - -[package.extras] -dev = ["autoflake", "black", "deepdiff", "flake8", "isort", "jupyter", "markdown-include", "mkdocs", "mkdocs-material", "mkdocs-redirects", "mkdocstrings[python] (>=0.18)", "nbconvert", "pep517", "pre-commit", "pyright (>=1.1.300,<1.1.306)", "pytest", "pyyaml", "torchtestcase", "twine"] - -[[package]] -name = "scikit-image" -version = "0.20.0" -description = "Image processing in Python" -optional = false -python-versions = ">=3.8" -files = [ - {file = "scikit_image-0.20.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3cec8c5e8412ee19642a916648144186eb6b60c39fb6608ab478b4d1a4575e25"}, - {file = "scikit_image-0.20.0-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:0ab378822fadc93db7e917a266d489ea33df3b42edfef197caaebbabbc2e4ecc"}, - {file = "scikit_image-0.20.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6797e3ef5fc53897bde131cfc3ceba6ce247d89cfe194fc8d3aba7f5c12aaf6"}, - {file = "scikit_image-0.20.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f667dcf01737248bc5bd0a99fad58475abeb6b6a8229aecee9fdb96cf988ae85"}, - {file = "scikit_image-0.20.0-cp310-cp310-win_amd64.whl", hash = "sha256:79a400ffe35fc7f64d1d043f3d043e062015689ad5637c35cd5569edae87ae13"}, - {file = "scikit_image-0.20.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:049d955869620453b9e0568c2da62c8fec47bf3714be48b5d46bbaebb91bdc1f"}, - {file = "scikit_image-0.20.0-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:a503ee85b444234ee88f34bf8674872dc37c6124ff60b7eb9242813de012ff4e"}, - {file = "scikit_image-0.20.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3943d7355d02b40c066fd87cd5fe1b4f6637a16448e62333c4191a65ebf40a1c"}, - {file = "scikit_image-0.20.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d719242ea7e7250d49e38d1e33c44c2dd59c3414ae085881d168b98cbb6059a"}, - {file = "scikit_image-0.20.0-cp311-cp311-win_amd64.whl", hash = "sha256:fdd1fd258e78c86e382fd687177431088a40880bd785e0ab40ee5f3794366710"}, - {file = "scikit_image-0.20.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1cd0486cb769d906307a3ec3884630be822d8ec2f41069e197336f904f584a33"}, - {file = "scikit_image-0.20.0-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:2e9026161d0a698f532352dda6455a0bc13b1c9d831ea9279726b59d064df574"}, - {file = "scikit_image-0.20.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c123e6b0677dc1697c04b5bf2efb7110bcca511b4bc6967a38fa395ae5edf44"}, - {file = "scikit_image-0.20.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76f2fd12b537daea806a078df9ea76f5cc5a529d5bd7c41d7d0a101e9c5f91c4"}, - {file = "scikit_image-0.20.0-cp38-cp38-win_amd64.whl", hash = "sha256:2118d610096754bca44b5d37328e1382e5fa7c6493803685100c9238e257d848"}, - {file = "scikit_image-0.20.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:13a5c1c81ee5bcb64ee8ca8f1a2cf371b0c4345ea6fb67c3052e1c6d5edbd936"}, - {file = 
"scikit_image-0.20.0-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:1794889d2dbb385c7ad5656363371ba0057b7a3335cda093a11415af84bb96e2"}, - {file = "scikit_image-0.20.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:df14f8a55dae511749b081d9402ea215ea7c641bd6f74f06aa7b623e132817df"}, - {file = "scikit_image-0.20.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b856efc75e3051bea6d40a8ffcdaabd5682783ece1aa91c3f6777c3372a98ca1"}, - {file = "scikit_image-0.20.0-cp39-cp39-win_amd64.whl", hash = "sha256:a600374394b76b7fc260cef54e1be21047c4de0ecffb0b7f2f7392cd8ba16ffa"}, - {file = "scikit_image-0.20.0.tar.gz", hash = "sha256:2cd784fce18bd31d71ade62c6221440199ead03acf7544086261ee032264cf61"}, -] - -[package.dependencies] -imageio = ">=2.4.1" -lazy_loader = ">=0.1" -networkx = ">=2.8" -numpy = ">=1.21.1" -packaging = ">=20.0" -pillow = ">=9.0.1" -PyWavelets = ">=1.1.1" -scipy = [ - {version = ">=1.8,<1.9.2", markers = "python_version <= \"3.9\""}, - {version = ">=1.8", markers = "python_version > \"3.9\""}, -] -tifffile = ">=2019.7.26" - -[package.extras] -build = ["Cython (>=0.29.24)", "build", "meson-python (>=0.13.0rc0)", "ninja", "numpy (>=1.21.1)", "packaging (>=20)", "pythran", "setuptools (>=67)", "wheel"] -data = ["pooch (>=1.3.0)"] -default = ["PyWavelets (>=1.1.1)", "imageio (>=2.4.1)", "lazy_loader (>=0.1)", "networkx (>=2.8)", "numpy (>=1.21.1)", "packaging (>=20.0)", "pillow (>=9.0.1)", "scipy (>=1.8)", "scipy (>=1.8,<1.9.2)", "tifffile (>=2019.7.26)"] -developer = ["pre-commit", "rtoml"] -docs = ["dask[array] (>=2022.9.2)", "ipywidgets", "kaleido", "matplotlib (>=3.6)", "myst-parser", "numpydoc (>=1.5)", "pandas (>=1.5)", "plotly (>=5.10)", "pooch (>=1.6)", "pytest-runner", "scikit-learn", "seaborn (>=0.11)", "sphinx (>=5.2)", "sphinx-copybutton", "sphinx-gallery (>=0.11)", "tifffile (>=2022.8.12)"] -optional = ["SimpleITK", "astropy (>=3.1.2)", "cloudpickle (>=0.2.1)", "dask[array] (>=1.0.0,!=2.17.0)", "matplotlib (>=3.3)", "pooch (>=1.3.0)", "pyamg"] -test = ["asv", "codecov", "matplotlib (>=3.3)", "pooch (>=1.3.0)", "pytest (>=5.2.0)", "pytest-cov (>=2.7.0)", "pytest-faulthandler", "pytest-localserver"] - [[package]] name = "scikit-learn" version = "1.3.1" @@ -3623,41 +2921,6 @@ docs = ["Pillow (>=7.1.2)", "matplotlib (>=3.1.3)", "memory-profiler (>=0.57.0)" examples = ["matplotlib (>=3.1.3)", "pandas (>=1.0.5)", "plotly (>=5.14.0)", "pooch (>=1.6.0)", "scikit-image (>=0.16.2)", "seaborn (>=0.9.0)"] tests = ["black (>=23.3.0)", "matplotlib (>=3.1.3)", "mypy (>=1.3)", "numpydoc (>=1.2.0)", "pandas (>=1.0.5)", "pooch (>=1.6.0)", "pyamg (>=4.0.0)", "pytest (>=7.1.2)", "pytest-cov (>=2.9.0)", "ruff (>=0.0.272)", "scikit-image (>=0.16.2)"] -[[package]] -name = "scipy" -version = "1.9.1" -description = "SciPy: Scientific Library for Python" -optional = false -python-versions = ">=3.8,<3.12" -files = [ - {file = "scipy-1.9.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c61b4a91a702e8e04aeb0bfc40460e1f17a640977c04dda8757efb0199c75332"}, - {file = "scipy-1.9.1-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:d79da472015d0120ba9b357b28a99146cd6c17b9609403164b1a8ed149b4dfc8"}, - {file = "scipy-1.9.1-cp310-cp310-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:825951b88f56765aeb6e5e38ac9d7d47407cfaaeb008d40aa1b45a2d7ea2731e"}, - {file = "scipy-1.9.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f950a04b33e17b38ff561d5a0951caf3f5b47caa841edd772ffb7959f20a6af0"}, - {file = 
"scipy-1.9.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8cc81ac25659fec73599ccc52c989670e5ccd8974cf34bacd7b54a8d809aff1a"}, - {file = "scipy-1.9.1-cp310-cp310-win_amd64.whl", hash = "sha256:8d3faa40ac16c6357aaf7ea50394ea6f1e8e99d75e927a51102b1943b311b4d9"}, - {file = "scipy-1.9.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:7a412c476a91b080e456229e413792bbb5d6202865dae963d1e6e28c2bb58691"}, - {file = "scipy-1.9.1-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:eb954f5aca4d26f468bbebcdc5448348eb287f7bea536c6306f62ea062f63d9a"}, - {file = "scipy-1.9.1-cp38-cp38-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:3c6f5d1d4b9a5e4fe5e14f26ffc9444fc59473bbf8d45dc4a9a15283b7063a72"}, - {file = "scipy-1.9.1-cp38-cp38-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:bc4e2c77d4cd015d739e75e74ebbafed59ba8497a7ed0fd400231ed7683497c4"}, - {file = "scipy-1.9.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0419485dbcd0ed78c0d5bf234c5dd63e86065b39b4d669e45810d42199d49521"}, - {file = "scipy-1.9.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34441dfbee5b002f9e15285014fd56e5e3372493c3e64ae297bae2c4b9659f5a"}, - {file = "scipy-1.9.1-cp38-cp38-win32.whl", hash = "sha256:b97b479f39c7e4aaf807efd0424dec74bbb379108f7d22cf09323086afcd312c"}, - {file = "scipy-1.9.1-cp38-cp38-win_amd64.whl", hash = "sha256:e8fe305d9d67a81255e06203454729405706907dccbdfcc330b7b3482a6c371d"}, - {file = "scipy-1.9.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:39ab9240cd215a9349c85ab908dda6d732f7d3b4b192fa05780812495536acc4"}, - {file = "scipy-1.9.1-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:71487c503e036740635f18324f62a11f283a632ace9d35933b2b0a04fd898c98"}, - {file = "scipy-1.9.1-cp39-cp39-macosx_12_0_universal2.macosx_10_9_x86_64.whl", hash = "sha256:3bc1ab68b9a096f368ba06c3a5e1d1d50957a86665fc929c4332d21355e7e8f4"}, - {file = "scipy-1.9.1-cp39-cp39-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:f7c39f7dbb57cce00c108d06d731f3b0e2a4d3a95c66d96bce697684876ce4d4"}, - {file = "scipy-1.9.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:47d1a95bd9d37302afcfe1b84c8011377c4f81e33649c5a5785db9ab827a6ade"}, - {file = "scipy-1.9.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:96d7cf7b25c9f23c59a766385f6370dab0659741699ecc7a451f9b94604938ce"}, - {file = "scipy-1.9.1-cp39-cp39-win32.whl", hash = "sha256:09412eb7fb60b8f00b328037fd814d25d261066ebc43a1e339cdce4f7502877e"}, - {file = "scipy-1.9.1-cp39-cp39-win_amd64.whl", hash = "sha256:90c805f30c46cf60f1e76e947574f02954d25e3bb1e97aa8a07bc53aa31cf7d1"}, - {file = "scipy-1.9.1.tar.gz", hash = "sha256:26d28c468900e6d5fdb37d2812ab46db0ccd22c63baa095057871faa3a498bc9"}, -] - -[package.dependencies] -numpy = ">=1.18.5,<1.25.0" - [[package]] name = "scipy" version = "1.11.3" @@ -3841,42 +3104,6 @@ files = [ [package.dependencies] mpmath = ">=0.19" -[[package]] -name = "tensorboard" -version = "2.15.1" -description = "TensorBoard lets you watch Tensors Flow" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tensorboard-2.15.1-py3-none-any.whl", hash = "sha256:c46c1d1cf13a458c429868a78b2531d8ff5f682058d69ec0840b0bc7a38f1c0f"}, -] - -[package.dependencies] -absl-py = ">=0.4" -google-auth = ">=1.6.3,<3" -google-auth-oauthlib = ">=0.5,<2" -grpcio = ">=1.48.2" -markdown = ">=2.6.8" -numpy = ">=1.12.0" -protobuf = ">=3.19.6,<4.24" -requests = ">=2.21.0,<3" -setuptools = ">=41.0.0" -six = ">1.9" 
-tensorboard-data-server = ">=0.7.0,<0.8.0" -werkzeug = ">=1.0.1" - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -description = "Fast data loading for TensorBoard" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb"}, - {file = "tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60"}, - {file = "tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530"}, -] - [[package]] name = "terminado" version = "0.17.1" @@ -3908,23 +3135,6 @@ files = [ {file = "threadpoolctl-3.2.0.tar.gz", hash = "sha256:c96a0ba3bdddeaca37dc4cc7344aafad41cdb8c313f74fdfe387a867bba93355"}, ] -[[package]] -name = "tifffile" -version = "2023.9.26" -description = "Read and write TIFF files" -optional = false -python-versions = ">=3.9" -files = [ - {file = "tifffile-2023.9.26-py3-none-any.whl", hash = "sha256:1de47fa945fddaade256e25ad4f375ae65547f3c1354063aded881c32a64cf89"}, - {file = "tifffile-2023.9.26.tar.gz", hash = "sha256:67e355e4595aab397f8405d04afe1b4ae7c6f62a44e22d933fee1a571a48c7ae"}, -] - -[package.dependencies] -numpy = "*" - -[package.extras] -all = ["defusedxml", "fsspec", "imagecodecs (>=2023.8.12)", "lxml", "matplotlib", "zarr"] - [[package]] name = "tinycss2" version = "1.2.1" @@ -4014,26 +3224,6 @@ files = [ {file = "tornado-6.3.3.tar.gz", hash = "sha256:e7d8db41c0181c80d76c982aacc442c0783a2c54d6400fe028954201a2e032fe"}, ] -[[package]] -name = "tqdm" -version = "4.66.1" -description = "Fast, Extensible Progress Meter" -optional = false -python-versions = ">=3.7" -files = [ - {file = "tqdm-4.66.1-py3-none-any.whl", hash = "sha256:d302b3c5b53d47bce91fea46679d9c3c6508cf6332229aa1e7d8653723793386"}, - {file = "tqdm-4.66.1.tar.gz", hash = "sha256:d88e651f9db8d8551a62556d3cff9e3034274ca5d66e93197cf2490e2dcb69c7"}, -] - -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] -notebook = ["ipywidgets (>=6)"] -slack = ["slack-sdk"] -telegram = ["requests"] - [[package]] name = "traitlets" version = "5.11.2" @@ -4113,6 +3303,26 @@ secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17. 
socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "virtualenv" +version = "20.25.1" +description = "Virtual Python Environment builder" +optional = false +python-versions = ">=3.7" +files = [ + {file = "virtualenv-20.25.1-py3-none-any.whl", hash = "sha256:961c026ac520bac5f69acb8ea063e8a4f071bcc9457b9c1f28f6b085c511583a"}, + {file = "virtualenv-20.25.1.tar.gz", hash = "sha256:e08e13ecdca7a0bd53798f356d5831434afa5b07b93f0abdf0797b7a06ffe197"}, +] + +[package.dependencies] +distlib = ">=0.3.7,<1" +filelock = ">=3.12.2,<4" +platformdirs = ">=3.9.1,<5" + +[package.extras] +docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] + [[package]] name = "wcwidth" version = "0.2.8" @@ -4166,23 +3376,6 @@ docs = ["Sphinx (>=6.0)", "sphinx-rtd-theme (>=1.1.0)"] optional = ["python-socks", "wsaccel"] test = ["websockets"] -[[package]] -name = "werkzeug" -version = "3.0.1" -description = "The comprehensive WSGI web application library." -optional = false -python-versions = ">=3.8" -files = [ - {file = "werkzeug-3.0.1-py3-none-any.whl", hash = "sha256:90a285dc0e42ad56b34e696398b8122ee4c681833fb35b8334a095d82c56da10"}, - {file = "werkzeug-3.0.1.tar.gz", hash = "sha256:507e811ecea72b18a404947aded4b3390e1db8f826b494d76550ef45bb3b1dcc"}, -] - -[package.dependencies] -MarkupSafe = ">=2.1.1" - -[package.extras] -watchdog = ["watchdog (>=2.3)"] - [[package]] name = "widgetsnbextension" version = "4.0.9" @@ -4194,51 +3387,6 @@ files = [ {file = "widgetsnbextension-4.0.9.tar.gz", hash = "sha256:3c1f5e46dc1166dfd40a42d685e6a51396fd34ff878742a3e47c6f0cc4a2a385"}, ] -[[package]] -name = "xarray" -version = "2023.9.0" -description = "N-D labeled arrays and datasets in Python" -optional = false -python-versions = ">=3.9" -files = [ - {file = "xarray-2023.9.0-py3-none-any.whl", hash = "sha256:3fc4a558bd70968040a4e1cefc6ddb3f9a7a86ef6a48e67857156ffe655d3a66"}, - {file = "xarray-2023.9.0.tar.gz", hash = "sha256:271955c05dc626dad37791a7807d920aaf9c64cac71d03b45ec7e402cc646603"}, -] - -[package.dependencies] -numpy = ">=1.21" -packaging = ">=21.3" -pandas = ">=1.4" - -[package.extras] -accel = ["bottleneck", "flox", "numbagg", "scipy"] -complete = ["xarray[accel,io,parallel,viz]"] -io = ["cftime", "fsspec", "h5netcdf", "netCDF4", "pooch", "pydap", "scipy", "zarr"] -parallel = ["dask[complete]"] -viz = ["matplotlib", "nc-time-axis", "seaborn"] - -[[package]] -name = "xarray-einstats" -version = "0.6.0" -description = "Stats, linear algebra and einops for xarray" -optional = false -python-versions = ">=3.9" -files = [ - {file = "xarray_einstats-0.6.0-py3-none-any.whl", hash = "sha256:4c6f556a9d8603245545cb88583c04398b10a70c572936a2f48678330545883a"}, - {file = "xarray_einstats-0.6.0.tar.gz", hash = "sha256:ace90601505cfbe2d374762e674557ed14e1725b024823372f7ef9fd237effad"}, -] - -[package.dependencies] -numpy = ">=1.21" -scipy = ">=1.7" -xarray = ">=2022.09.0" - -[package.extras] -doc = ["furo", "jupyter-sphinx", "matplotlib", "myst-nb", "myst-parser[linkify]", "numpydoc", "sphinx (>=4)", "sphinx-copybutton", "sphinx-design", "sphinx-togglebutton", "watermark"] 
-einops = ["einops"] -numba = ["numba (>=0.55)"] -test = ["hypothesis", "packaging", "pytest", "pytest-cov"] - [[package]] name = "zipp" version = "3.17.0" @@ -4257,4 +3405,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.11" -content-hash = "29632fb37f2d97534d97e2640ed8dfbd82d26ebeabf5828277a40ce54a2412ac" +content-hash = "bbcbce1ea616fc0edb04862f3a5ded188204eed252165daed47414ca299237f1" diff --git a/pyproject.toml b/pyproject.toml index 4ca08fa..2bd02ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,20 +11,18 @@ license = "MIT" python = ">=3.9,<3.11" jupyter = "^1.0.0" matplotlib = "^3.7.1" -arviz = "^0.15.1" -corner = "^2.2.2" scikit-learn = "^1.3.0" -graphviz = "^0.20.1" seaborn = "^0.12.2" torch = "^2.0.1" -pytest-cov = "^4.1.0" -deepbench = "^0.2.2" -sbi = "^0.22.0" h5py = "^3.10.0" -flake8 = "^7.0.0" -black = "^24.2.0" +[tool.poetry.group.dev.dependencies] +pytest-cov = "^4.1.0" +flake8 = "^7.0.0" +pytest = "^7.3.2" +pre-commit = "^3.7.0" +black = "^24.3.0" [build-system] -requires = ["poetry-core"] +requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" From 83703cd7d8adfa0cfcce4a4f711a1bcac0da2270 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Tue, 9 Apr 2024 14:52:18 -0600 Subject: [PATCH 25/30] argparse should work --- src/scripts/models.py | 7 +- src/scripts/train.py | 225 ++++++++++++++++++++++++++++++++------ test/test_DeepEnsemble.py | 4 - 3 files changed, 195 insertions(+), 41 deletions(-) diff --git a/src/scripts/models.py b/src/scripts/models.py index 80f25a0..73e5802 100644 --- a/src/scripts/models.py +++ b/src/scripts/models.py @@ -28,13 +28,14 @@ def forward(self, x): return torch.stack((gamma, nu, alpha, beta), dim=1) -def model_setup_DER(DER_type, DEVICE): +def model_setup_DER(loss_type, DEVICE): + print('loss type', loss_type, type(loss_type)) # initialize the model from scratch - if DER_type == "SDER": + if loss_type == "SDER": Layer = SDERLayer # initialize our loss function lossFn = loss_sder - if DER_type == "DER": + if loss_type == "DER": Layer = DERLayer # initialize our loss function lossFn = loss_der diff --git a/src/scripts/train.py b/src/scripts/train.py index 0aff0c4..56120bd 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -14,51 +14,69 @@ def train_DER( INIT_LR, DEVICE, COEFF, - DER_type, - model_name, - EPOCHS=40, - save_checkpoints=False, + loss_type, + wd, + model_name="DER", + EPOCHS=100, path_to_model="models/", - plot=False, - verbose=True, + save_all_checkpoints=False, + save_final_checkpoint=False, + overwrite_final_checkpoint=False, + plot=True, + savefig=True, + verbose=True ): + # first determine if you even need to run anything + if not save_all_checkpoints and save_final_checkpoint: + # option to skip running the model if you don't care about + # saving all checkpoints and only want to save the final + final_chk = ( + path_to_model + + str(model_name) + + "_loss_" + + str(loss_type) + + "_epoch_" + + str(EPOCHS - 1) + + ".pt" + ) + if verbose: + print("final chk", final_chk) + # check if the final epoch checkpoint already exists + print(glob.glob(final_chk)) + if glob.glob(final_chk): + print("final model already exists") + if overwrite_final_checkpoint: + print("going to overwrite final checkpoint") + else: + print("not overwriting, exiting") + return + else: + print("model does not exist yet, going to save") # measure how long training is going to take if verbose: print("[INFO] training the 
network...") - print("saving checkpoints?") - print(save_checkpoints) + print("saving all checkpoints?") + print(save_all_checkpoints) + print("saving final checkpoint?") + print(save_final_checkpoint) + print("overwriting final checkpoint if its already there?") + print(overwrite_final_checkpoint) print(f"saving here: {path_to_model}") print(f"model name: {model_name}") startTime = time.time() start_epoch = 0 - """ - # Find last epoch saved - if save_checkpoints: - - print(glob.glob(path_to_model + "/" + str(model_name) + "*")) - list_models_run = [] - for file in glob.glob(path_to_model + "/" + str(model_name) + "*"): - list_models_run.append( - float(str.split(str(str.split(file, - model_name + "_")[1]), ".")[0]) - ) - if list_models_run: - start_epoch = max(list_models_run) + 1 - else: - start_epoch = 0 - else: - start_epoch = 0 - print("starting here", start_epoch) - """ + best_loss = np.inf # init to infinity - model, lossFn = models.model_setup_DER(DER_type, DEVICE) + model, lossFn = models.model_setup_DER(loss_type, DEVICE) + if verbose: + print('model is', model, 'lossfn', lossFn) opt = torch.optim.Adam(model.parameters(), lr=INIT_LR) # loop over our epochs for e in range(0, EPOCHS): - if plot: + if plot or savefig: plt.clf() fig, (ax1, ax2) = plt.subplots( 2, 1, figsize=(8, 6), gridspec_kw={"height_ratios": [3, 1]} @@ -84,8 +102,8 @@ def train_DER( pred = model(x) loss = lossFn(pred, y, COEFF) - if plot and (e % 5 == 0): - if i == 0: + if plot or savefig: + if (e % (EPOCHS - 1) == 0) and (e != 0): pred_loader_0 = pred[:, 0].flatten().detach().numpy() y_loader_0 = y.detach().numpy() ax1.scatter( @@ -108,10 +126,12 @@ def train_DER( xycoords="axes fraction", color="black", ) + ''' else: ax1.scatter(y, pred[:, 0].flatten().detach().numpy(), color="grey") + ''' loss_this_epoch.append(loss[0].item()) # zero out the gradients @@ -123,6 +143,112 @@ def train_DER( # optimizer takes a step based on the gradients of the parameters # here, its taking a step for every batch opt.step() + if (plot or savefig) and (e % (EPOCHS - 1) == 0) and (e != 0): + ax1.plot(range(0, 1000), + range(0, 1000), + color="black", + ls="--") + if loss_type == "no_var_loss": + ax1.scatter( + y_val, + y_pred.flatten().detach().numpy(), + color="#F45866", + edgecolor="black", + zorder=100, + label="validation dtata", + ) + else: + ax1.errorbar( + y_val, + y_pred[:, 0].flatten().detach().numpy(), + yerr=np.sqrt(y_pred[:, 1].flatten().detach().numpy()), + linestyle="None", + color="black", + capsize=2, + zorder=100, + ) + ax1.scatter( + y_val, + y_pred[:, 0].flatten().detach().numpy(), + color="#9CD08F", + s=5, + zorder=101, + label="validation data", + ) + + # add residual plot + residuals = y_pred[:, 0].flatten().detach().numpy() - y_val + ax2.errorbar( + y_val, + residuals, + yerr=np.sqrt(y_pred[:, 1].flatten().detach().numpy()), + linestyle="None", + color="black", + capsize=2, + ) + ax2.scatter(y_val, residuals, color="#9B287B", s=5, zorder=100) + ax2.axhline(0, color="black", linestyle="--", linewidth=1) + ax2.set_ylabel("Residuals") + ax2.set_xlabel("True Value") + # add annotion for loss value + if loss_type == "bnll_loss": + ax1.annotate( + r"$\beta = $" + + str(round(beta_epoch, 2)) + + "\n" + + str(loss_type) + + " = " + + str(round(loss, 2)) + + "\n" + + r"MSE = " + + str(round(mse, 2)), + xy=(0.73, 0.1), + xycoords="axes fraction", + bbox=dict( + boxstyle="round,pad=0.5", + facecolor="lightgrey", + alpha=0.5 + ), + ) + + else: + ax1.annotate( + str(loss_type) + + " = " + + str(round(loss, 2)) + + "\n" + + 
r"MSE = " + + str(round(mse, 2)), + xy=(0.73, 0.1), + xycoords="axes fraction", + bbox=dict( + boxstyle="round,pad=0.5", + facecolor="lightgrey", + alpha=0.5 + ), + ) + ax1.set_ylabel("Prediction") + ax1.set_title("Epoch " + str(e)) + ax1.set_xlim([0, 1000]) + ax1.set_ylim([0, 1000]) + ax1.legend() + if savefig: + # ax1.errorbar(200, 600, yerr=5, + # color='red', capsize=2) + plt.savefig( + str(wd) + + "images/animations/" + + str(model_name) + + "_loss_" + + str(loss_type) + + "_epoch_" + + str(epoch) + + ".png" + ) + if plot: + plt.show() + plt.close() + ''' if plot and (e % 5 == 0): ax1.set_ylabel("prediction") ax1.set_title("Epoch " + str(e)) @@ -136,6 +262,7 @@ def train_DER( plt.show() plt.close() + ''' model.eval() y_pred = model(torch.Tensor(x_val)) loss = lossFn(y_pred, torch.Tensor(y_val), COEFF) @@ -155,7 +282,7 @@ def train_DER( # best_weights = copy.deepcopy(model.state_dict()) # print('validation loss', mse) - if save_checkpoints: + if save_all_checkpoints: torch.save( { @@ -170,9 +297,39 @@ def train_DER( "std_u_al_validation": std_u_al_val, "std_u_ep_validation": std_u_ep_val, }, - path_to_model + "/" + str(model_name) - + "_epoch_" + str(epoch) + ".pt", + str(wd) + + "models/" + + str(model_name) + + "_loss_" + + str(loss_type) + + "_epoch_" + + str(epoch) + + ".pt", ) + if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0): + # option to just save final epoch + torch.save( + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "optimizer_state_dict": opt.state_dict(), + "train_loss": np.mean(loss_this_epoch), + "valid_loss": NIGloss_val, + "valid_mse": mse, + "med_u_al_validation": med_u_al_val, + "med_u_ep_validation": med_u_ep_val, + "std_u_al_validation": std_u_al_val, + "std_u_ep_validation": std_u_ep_val, + }, + str(wd) + + "models/" + + str(model_name) + + "_loss_" + + str(loss_type) + + "_epoch_" + + str(epoch) + + ".pt", + ) endTime = time.time() if verbose: print("start at", startTime, "end at", endTime) diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index f29698c..e8c4d4f 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -1,13 +1,9 @@ import sys import pytest -import torch -import numpy as np -import sbi import os import subprocess import tempfile import shutil -import unittest # flake8: noqa sys.path.append("..") From c8e2783131a94c85d3942b9fb745723bf48039c8 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Tue, 9 Apr 2024 15:02:06 -0600 Subject: [PATCH 26/30] correcting flake8 making sure everything is defined in DER --- src/scripts/train.py | 128 ++++++++++++++----------------------------- 1 file changed, 42 insertions(+), 86 deletions(-) diff --git a/src/scripts/train.py b/src/scripts/train.py index 56120bd..ce73e68 100644 --- a/src/scripts/train.py +++ b/src/scripts/train.py @@ -24,7 +24,7 @@ def train_DER( overwrite_final_checkpoint=False, plot=True, savefig=True, - verbose=True + verbose=True, ): # first determine if you even need to run anything if not save_all_checkpoints and save_final_checkpoint: @@ -38,7 +38,7 @@ def train_DER( + "_epoch_" + str(EPOCHS - 1) + ".pt" - ) + ) if verbose: print("final chk", final_chk) # check if the final epoch checkpoint already exists @@ -66,11 +66,11 @@ def train_DER( startTime = time.time() start_epoch = 0 - + best_loss = np.inf # init to infinity model, lossFn = models.model_setup_DER(loss_type, DEVICE) if verbose: - print('model is', model, 'lossfn', lossFn) + print("model is", model, "lossfn", lossFn) opt = torch.optim.Adam(model.parameters(), 
lr=INIT_LR)
@@ -104,8 +104,6 @@ def train_DER(
             loss = lossFn(pred, y, COEFF)
             if plot or savefig:
                 if (e % (EPOCHS - 1) == 0) and (e != 0):
-                    pred_loader_0 = pred[:, 0].flatten().detach().numpy()
-                    y_loader_0 = y.detach().numpy()
                     ax1.scatter(
                         y,
                         pred[:, 0].flatten().detach().numpy(),
@@ -126,12 +124,12 @@ def train_DER(
                     xycoords="axes fraction",
                     color="black",
                 )
-                '''
+                """
                 else:
                     ax1.scatter(y,
                                 pred[:, 0].flatten().detach().numpy(),
                                 color="grey")
-                '''
+                """
             loss_this_epoch.append(loss[0].item())

             # zero out the gradients
@@ -143,11 +141,25 @@ def train_DER(
             # optimizer takes a step based on the gradients of the parameters
             # here, it's taking a step for every batch
             opt.step()
+        model.eval()
+        y_pred = model(torch.Tensor(x_val))
+        loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
+        NIGloss_val = loss[0].item()
+        med_u_al_val = np.median(loss[1])
+        med_u_ep_val = np.median(loss[2])
+        std_u_al_val = np.std(loss[1])
+        std_u_ep_val = np.std(loss[2])
+
+        # lets also grab mse loss
+        mse_loss = torch.nn.MSELoss(reduction="mean")
+        mse = mse_loss(y_pred[:, 0], torch.Tensor(y_val)).item()
+        if NIGloss_val < best_loss:
+            best_loss = NIGloss_val
+            if verbose:
+                print("new best loss", NIGloss_val, "in epoch", epoch)
+            # best_weights = copy.deepcopy(model.state_dict())
         if (plot or savefig) and (e % (EPOCHS - 1) == 0) and (e != 0):
-            ax1.plot(range(0, 1000),
-                     range(0, 1000),
-                     color="black",
-                     ls="--")
+            ax1.plot(range(0, 1000), range(0, 1000), color="black", ls="--")
             if loss_type == "no_var_loss":
                 ax1.scatter(
                     y_val,
@@ -191,42 +203,21 @@ def train_DER(
             ax2.set_ylabel("Residuals")
             ax2.set_xlabel("True Value")
             # add annotation for loss value
-            if loss_type == "bnll_loss":
-                ax1.annotate(
-                    r"$\beta = $"
-                    + str(round(beta_epoch, 2))
-                    + "\n"
-                    + str(loss_type)
-                    + " = "
-                    + str(round(loss, 2))
-                    + "\n"
-                    + r"MSE = "
-                    + str(round(mse, 2)),
-                    xy=(0.73, 0.1),
-                    xycoords="axes fraction",
-                    bbox=dict(
-                        boxstyle="round,pad=0.5",
-                        facecolor="lightgrey",
-                        alpha=0.5
-                    ),
-                )
-
-            else:
-                ax1.annotate(
-                    str(loss_type)
-                    + " = "
-                    + str(round(loss, 2))
-                    + "\n"
-                    + r"MSE = "
-                    + str(round(mse, 2)),
-                    xy=(0.73, 0.1),
-                    xycoords="axes fraction",
-                    bbox=dict(
-                        boxstyle="round,pad=0.5",
-                        facecolor="lightgrey",
-                        alpha=0.5
-                    ),
-                )
+            ax1.annotate(
+                str(loss_type)
+                + " = "
+                + str(round(loss, 2))
+                + "\n"
+                + r"MSE = "
+                + str(round(mse, 2)),
+                xy=(0.73, 0.1),
+                xycoords="axes fraction",
+                bbox=dict(
+                    boxstyle="round,pad=0.5",
+                    facecolor="lightgrey",
+                    alpha=0.5
+                ),
+            )
             ax1.set_ylabel("Prediction")
             ax1.set_title("Epoch " + str(e))
             ax1.set_xlim([0, 1000])
@@ -248,40 +239,6 @@ def train_DER(
         if plot:
             plt.show()
             plt.close()
-        '''
-        if plot and (e % 5 == 0):
-            ax1.set_ylabel("prediction")
-            ax1.set_title("Epoch " + str(e))
-
-            # Residuals plot
-            residuals = pred_loader_0 - y_loader_0
-            ax2.scatter(y_loader_0, residuals, color="red")
-            ax2.axhline(0, color="black", linestyle="--", linewidth=1)
-            ax2.set_ylabel("Residuals")
-            ax2.set_xlabel("True Value")
-
-            plt.show()
-            plt.close()
-        '''
-        model.eval()
-        y_pred = model(torch.Tensor(x_val))
-        loss = lossFn(y_pred, torch.Tensor(y_val), COEFF)
-        NIGloss_val = loss[0].item()
-        med_u_al_val = np.median(loss[1])
-        med_u_ep_val = np.median(loss[2])
-        std_u_al_val = np.std(loss[1])
-        std_u_ep_val = np.std(loss[2])
-
-        # lets also grab mse loss
-        mse_loss = torch.nn.MSELoss(reduction="mean")
-        mse = mse_loss(y_pred[:, 0], torch.Tensor(y_val)).item()
-        if NIGloss_val < best_loss:
-            best_loss = NIGloss_val
-            if verbose:
-                print("new best loss", NIGloss_val, "in epoch", epoch)
-            # 
best_weights = copy.deepcopy(model.state_dict()) - # print('validation loss', mse) - if save_all_checkpoints: torch.save( @@ -307,7 +264,7 @@ def train_DER( + ".pt", ) if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0): - # option to just save final epoch + # option to just save final epoch torch.save( { "epoch": epoch, @@ -329,7 +286,7 @@ def train_DER( + "_epoch_" + str(epoch) + ".pt", - ) + ) endTime = time.time() if verbose: print("start at", startTime, "end at", endTime) @@ -371,8 +328,7 @@ def train_DE( model_ensemble = [] - print('this is the value of save_final_checkpoint', - save_final_checkpoint) + print("this is the value of save_final_checkpoint", save_final_checkpoint) for m in range(n_models): print("model", m) From 993addddaa29b2e2e3e82d9e34b937d33f28d26a Mon Sep 17 00:00:00 2001 From: beckynevin Date: Tue, 9 Apr 2024 15:07:12 -0600 Subject: [PATCH 27/30] updating torch --- poetry.lock | 226 +++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 204 insertions(+), 22 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1230fc8..e4d7529 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1987,6 +1987,147 @@ files = [ {file = "numpy-1.26.0.tar.gz", hash = "sha256:f93fc78fe8bf15afe2b8d6b6499f1c73953169fad1e9a8dd086cdff3190e7fdf"}, ] +[[package]] +name = "nvidia-cublas-cu12" +version = "12.1.3.1" +description = "CUBLAS native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"}, + {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"}, +] + +[[package]] +name = "nvidia-cuda-cupti-cu12" +version = "12.1.105" +description = "CUDA profiling tools runtime libs." 
+optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"}, + {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"}, +] + +[[package]] +name = "nvidia-cuda-nvrtc-cu12" +version = "12.1.105" +description = "NVRTC native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"}, + {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"}, +] + +[[package]] +name = "nvidia-cuda-runtime-cu12" +version = "12.1.105" +description = "CUDA Runtime native Libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"}, + {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"}, +] + +[[package]] +name = "nvidia-cudnn-cu12" +version = "8.9.2.26" +description = "cuDNN runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" + +[[package]] +name = "nvidia-cufft-cu12" +version = "11.0.2.54" +description = "CUFFT native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"}, + {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"}, +] + +[[package]] +name = "nvidia-curand-cu12" +version = "10.3.2.106" +description = "CURAND native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"}, + {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"}, +] + +[[package]] +name = "nvidia-cusolver-cu12" +version = "11.4.5.107" +description = "CUDA solver native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"}, + {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"}, +] + +[package.dependencies] +nvidia-cublas-cu12 = "*" +nvidia-cusparse-cu12 = "*" +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-cusparse-cu12" +version = "12.1.0.106" +description = "CUSPARSE native runtime libraries" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = 
"sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"}, + {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"}, +] + +[package.dependencies] +nvidia-nvjitlink-cu12 = "*" + +[[package]] +name = "nvidia-nccl-cu12" +version = "2.19.3" +description = "NVIDIA Collective Communication Library (NCCL) Runtime" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"}, +] + +[[package]] +name = "nvidia-nvjitlink-cu12" +version = "12.4.127" +description = "Nvidia JIT LTO Library" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl", hash = "sha256:06b3b9b25bf3f8af351d664978ca26a16d2c5127dbd53c0497e28d1fb9611d57"}, + {file = "nvidia_nvjitlink_cu12-12.4.127-py3-none-win_amd64.whl", hash = "sha256:fd9020c501d27d135f983c6d3e244b197a7ccad769e34df53a42e276b0e25fa1"}, +] + +[[package]] +name = "nvidia-nvtx-cu12" +version = "12.1.105" +description = "NVIDIA Tools Extension" +optional = false +python-versions = ">=3" +files = [ + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"}, + {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"}, +] + [[package]] name = "overrides" version = "7.4.0" @@ -3166,31 +3307,36 @@ files = [ [[package]] name = "torch" -version = "2.1.0" +version = "2.2.2" description = "Tensors and Dynamic neural networks in Python with strong GPU acceleration" optional = false python-versions = ">=3.8.0" files = [ - {file = "torch-2.1.0-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bf57f8184b2c317ef81fb33dc233ce4d850cd98ef3f4a38be59c7c1572d175db"}, - {file = "torch-2.1.0-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:a04a0296d47f28960f51c18c5489a8c3472f624ec3b5bcc8e2096314df8c3342"}, - {file = "torch-2.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0bd691efea319b14ef239ede16d8a45c246916456fa3ed4f217d8af679433cc6"}, - {file = "torch-2.1.0-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:101c139152959cb20ab370fc192672c50093747906ee4ceace44d8dd703f29af"}, - {file = "torch-2.1.0-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:a6b7438a90a870e4cdeb15301519ae6c043c883fcd224d303c5b118082814767"}, - {file = "torch-2.1.0-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:2224622407ca52611cbc5b628106fde22ed8e679031f5a99ce286629fc696128"}, - {file = "torch-2.1.0-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:8132efb782cd181cc2dcca5e58effbe4217cdb2581206ac71466d535bf778867"}, - {file = "torch-2.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:5c3bfa91ce25ba10116c224c59d5b64cdcce07161321d978bd5a1f15e1ebce72"}, - {file = "torch-2.1.0-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:601b0a2a9d9233fb4b81f7d47dca9680d4f3a78ca3f781078b6ad1ced8a90523"}, - {file = "torch-2.1.0-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:3cd1dedff13884d890f18eea620184fb4cd8fd3c68ce3300498f427ae93aa962"}, - {file = "torch-2.1.0-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:fb7bf0cc1a3db484eb5d713942a93172f3bac026fcb377a0cd107093d2eba777"}, - {file = "torch-2.1.0-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:761822761fffaa1c18a62c5deb13abaa780862577d3eadc428f1daa632536905"}, - 
{file = "torch-2.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:458a6d6d8f7d2ccc348ac4d62ea661b39a3592ad15be385bebd0a31ced7e00f4"}, - {file = "torch-2.1.0-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:c8bf7eaf9514465e5d9101e05195183470a6215bb50295c61b52302a04edb690"}, - {file = "torch-2.1.0-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:05661c32ec14bc3a157193d0f19a7b19d8e61eb787b33353cad30202c295e83b"}, - {file = "torch-2.1.0-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:556d8dd3e0c290ed9d4d7de598a213fb9f7c59135b4fee144364a8a887016a55"}, - {file = "torch-2.1.0-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:de7d63c6ecece118684415a3dbd4805af4a4c1ee1490cccf7405d8c240a481b4"}, - {file = "torch-2.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:2419cf49aaf3b2336c7aa7a54a1b949fa295b1ae36f77e2aecb3a74e3a947255"}, - {file = "torch-2.1.0-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:6ad491e70dbe4288d17fdbfc7fbfa766d66cbe219bc4871c7a8096f4a37c98df"}, - {file = "torch-2.1.0-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:421739685eba5e0beba42cb649740b15d44b0d565c04e6ed667b41148734a75b"}, + {file = "torch-2.2.2-cp310-cp310-manylinux1_x86_64.whl", hash = "sha256:bc889d311a855dd2dfd164daf8cc903a6b7273a747189cebafdd89106e4ad585"}, + {file = "torch-2.2.2-cp310-cp310-manylinux2014_aarch64.whl", hash = "sha256:15dffa4cc3261fa73d02f0ed25f5fa49ecc9e12bf1ae0a4c1e7a88bbfaad9030"}, + {file = "torch-2.2.2-cp310-cp310-win_amd64.whl", hash = "sha256:11e8fe261233aeabd67696d6b993eeb0896faa175c6b41b9a6c9f0334bdad1c5"}, + {file = "torch-2.2.2-cp310-none-macosx_10_9_x86_64.whl", hash = "sha256:b2e2200b245bd9f263a0d41b6a2dab69c4aca635a01b30cca78064b0ef5b109e"}, + {file = "torch-2.2.2-cp310-none-macosx_11_0_arm64.whl", hash = "sha256:877b3e6593b5e00b35bbe111b7057464e76a7dd186a287280d941b564b0563c2"}, + {file = "torch-2.2.2-cp311-cp311-manylinux1_x86_64.whl", hash = "sha256:ad4c03b786e074f46606f4151c0a1e3740268bcf29fbd2fdf6666d66341c1dcb"}, + {file = "torch-2.2.2-cp311-cp311-manylinux2014_aarch64.whl", hash = "sha256:32827fa1fbe5da8851686256b4cd94cc7b11be962862c2293811c94eea9457bf"}, + {file = "torch-2.2.2-cp311-cp311-win_amd64.whl", hash = "sha256:f9ef0a648310435511e76905f9b89612e45ef2c8b023bee294f5e6f7e73a3e7c"}, + {file = "torch-2.2.2-cp311-none-macosx_10_9_x86_64.whl", hash = "sha256:95b9b44f3bcebd8b6cd8d37ec802048c872d9c567ba52c894bba90863a439059"}, + {file = "torch-2.2.2-cp311-none-macosx_11_0_arm64.whl", hash = "sha256:49aa4126ede714c5aeef7ae92969b4b0bbe67f19665106463c39f22e0a1860d1"}, + {file = "torch-2.2.2-cp312-cp312-manylinux1_x86_64.whl", hash = "sha256:cf12cdb66c9c940227ad647bc9cf5dba7e8640772ae10dfe7569a0c1e2a28aca"}, + {file = "torch-2.2.2-cp312-cp312-manylinux2014_aarch64.whl", hash = "sha256:89ddac2a8c1fb6569b90890955de0c34e1724f87431cacff4c1979b5f769203c"}, + {file = "torch-2.2.2-cp312-cp312-win_amd64.whl", hash = "sha256:451331406b760f4b1ab298ddd536486ab3cfb1312614cfe0532133535be60bea"}, + {file = "torch-2.2.2-cp312-none-macosx_10_9_x86_64.whl", hash = "sha256:eb4d6e9d3663e26cd27dc3ad266b34445a16b54908e74725adb241aa56987533"}, + {file = "torch-2.2.2-cp312-none-macosx_11_0_arm64.whl", hash = "sha256:bf9558da7d2bf7463390b3b2a61a6a3dbb0b45b161ee1dd5ec640bf579d479fc"}, + {file = "torch-2.2.2-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:cd2bf7697c9e95fb5d97cc1d525486d8cf11a084c6af1345c2c2c22a6b0029d0"}, + {file = "torch-2.2.2-cp38-cp38-manylinux2014_aarch64.whl", hash = "sha256:b421448d194496e1114d87a8b8d6506bce949544e513742b097e2ab8f7efef32"}, + {file = 
"torch-2.2.2-cp38-cp38-win_amd64.whl", hash = "sha256:3dbcd563a9b792161640c0cffe17e3270d85e8f4243b1f1ed19cca43d28d235b"}, + {file = "torch-2.2.2-cp38-none-macosx_10_9_x86_64.whl", hash = "sha256:31f4310210e7dda49f1fb52b0ec9e59382cfcb938693f6d5378f25b43d7c1d29"}, + {file = "torch-2.2.2-cp38-none-macosx_11_0_arm64.whl", hash = "sha256:c795feb7e8ce2e0ef63f75f8e1ab52e7fd5e1a4d7d0c31367ade1e3de35c9e95"}, + {file = "torch-2.2.2-cp39-cp39-manylinux1_x86_64.whl", hash = "sha256:a6e5770d68158d07456bfcb5318b173886f579fdfbf747543901ce718ea94782"}, + {file = "torch-2.2.2-cp39-cp39-manylinux2014_aarch64.whl", hash = "sha256:67dcd726edff108e2cd6c51ff0e416fd260c869904de95750e80051358680d24"}, + {file = "torch-2.2.2-cp39-cp39-win_amd64.whl", hash = "sha256:539d5ef6c4ce15bd3bd47a7b4a6e7c10d49d4d21c0baaa87c7d2ef8698632dfb"}, + {file = "torch-2.2.2-cp39-none-macosx_10_9_x86_64.whl", hash = "sha256:dff696de90d6f6d1e8200e9892861fd4677306d0ef604cb18f2134186f719f82"}, + {file = "torch-2.2.2-cp39-none-macosx_11_0_arm64.whl", hash = "sha256:3a4dd910663fd7a124c056c878a52c2b0be4a5a424188058fe97109d4436ee42"}, ] [package.dependencies] @@ -3198,11 +3344,24 @@ filelock = "*" fsspec = "*" jinja2 = "*" networkx = "*" +nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nccl-cu12 = {version = "2.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} +nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""} sympy = "*" -typing-extensions = "*" +triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""} +typing-extensions = ">=4.8.0" [package.extras] opt-einsum = ["opt-einsum (>=3.3)"] +optree = ["optree (>=0.9.1)"] [[package]] name = "tornado" @@ -3239,6 +3398,29 @@ files = [ docs = ["myst-parser", "pydata-sphinx-theme", "sphinx"] test = ["argcomplete (>=3.0.3)", "mypy (>=1.5.1)", "pre-commit", "pytest (>=7.0,<7.5)", "pytest-mock", "pytest-mypy-testing"] +[[package]] +name = "triton" +version = "2.2.0" +description = "A language and compiler for custom Deep Learning operations" +optional = false +python-versions = "*" +files = [ + {file = "triton-2.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2294514340cfe4e8f4f9e5c66c702744c4a117d25e618bd08469d0bfed1e2e5"}, + {file = "triton-2.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = 
"sha256:da58a152bddb62cafa9a857dd2bc1f886dbf9f9c90a2b5da82157cd2b34392b0"}, + {file = "triton-2.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0af58716e721460a61886668b205963dc4d1e4ac20508cc3f623aef0d70283d5"}, + {file = "triton-2.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e8fe46d3ab94a8103e291bd44c741cc294b91d1d81c1a2888254cbf7ff846dab"}, + {file = "triton-2.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8ce26093e539d727e7cf6f6f0d932b1ab0574dc02567e684377630d86723ace"}, + {file = "triton-2.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:227cc6f357c5efcb357f3867ac2a8e7ecea2298cd4606a8ba1e931d1d5a947df"}, +] + +[package.dependencies] +filelock = "*" + +[package.extras] +build = ["cmake (>=3.20)", "lit"] +tests = ["autopep8", "flake8", "isort", "numpy", "pytest", "scipy (>=1.7.1)", "torch"] +tutorials = ["matplotlib", "pandas", "tabulate", "torch"] + [[package]] name = "types-python-dateutil" version = "2.8.19.14" From 43d4e79bd63e3fb8064e79466635a8beb9578042 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Tue, 9 Apr 2024 15:27:29 -0600 Subject: [PATCH 28/30] trying to print out cwd --- .github/workflows/test.yaml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 315e748..a39bd0f 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -40,7 +40,9 @@ jobs: run: echo "PYTHONPATH=$(pwd):$(pwd)/src" >> ${{ runner.workspace }}/.env - name: Test with pytest - run: python -m poetry run pytest --cov + run: | + pwd + python -m poetry run pytest --cov env: PYTHONPATH: ${{ env.PYTHONPATH }} ENV_FILE: ${{ runner.workspace }}/.env \ No newline at end of file From 4f0dae08824b6ca64a3ec31939de5d2de86ef72a Mon Sep 17 00:00:00 2001 From: beckynevin Date: Tue, 9 Apr 2024 15:32:30 -0600 Subject: [PATCH 29/30] should be pointing to correct test --- .github/workflows/test.yaml | 4 +- test/test_DeepEnsemble.py | 83 +++++++++++++++++-------------------- 2 files changed, 40 insertions(+), 47 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a39bd0f..315e748 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -40,9 +40,7 @@ jobs: run: echo "PYTHONPATH=$(pwd):$(pwd)/src" >> ${{ runner.workspace }}/.env - name: Test with pytest - run: | - pwd - python -m poetry run pytest --cov + run: python -m poetry run pytest --cov env: PYTHONPATH: ${{ env.PYTHONPATH }} ENV_FILE: ${{ runner.workspace }}/.env \ No newline at end of file diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index e8c4d4f..8bff976 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -4,12 +4,6 @@ import subprocess import tempfile import shutil - -# flake8: noqa -sys.path.append("..") -# print(sys.path) -# from scripts.evaluate import Diagnose_static, Diagnose_generative -# from scripts.io import ModelLoader from scripts import evaluate, models, DeepEnsemble @@ -35,22 +29,22 @@ def temp_directory(): """ shutil.rmtree(temp_dir) - -@pytest.mark.xfail(strict=True) -def test_no_chkpt_saved_xfail(temp_directory): +def test_chkpt_saved(temp_directory): noise_level = "low" n_models = 10 wd = str(temp_directory) + "/" n_epochs = 2 subprocess_args = [ "python", - "../src/scripts/DeepEnsemble.py", + "src/scripts/DeepEnsemble.py", noise_level, str(n_models), wd, "--n_epochs", str(n_epochs), - ] + "--save_final_checkpoint", + 
"--savefig" + ] # now run the subprocess subprocess.run(subprocess_args, check=True) # check if the right number of checkpoints are saved @@ -62,15 +56,38 @@ def test_no_chkpt_saved_xfail(temp_directory): len(files_in_models_folder) == n_models ), "Expected 10 files in the 'models' folder" + # check if the right number of images were saved + animations_folder = os.path.join(temp_directory, "images/animations") + files_in_animations_folder = os.listdir(animations_folder) + # assert that the number of files is equal to 10 + assert ( + len(files_in_animations_folder) == n_models + ), "Expected 10 files in the 'images/animations' folder" -def test_no_chkpt_saved(temp_directory): + # also check that all files in here have the same name elements + expected_substring = "epoch_" + str(n_epochs - 1) + for file_name in files_in_models_folder: + assert ( + expected_substring in file_name + ), f"File '{file_name}' does not contain the expected substring" + + # also check that all files in here have the same name elements + for file_name in files_in_animations_folder: + assert ( + expected_substring in file_name + ), f"File '{file_name}' does not contain the expected substring" + + + +@pytest.mark.xfail(strict=True) +def test_no_chkpt_saved_xfail(temp_directory): noise_level = "low" n_models = 10 wd = str(temp_directory) + "/" n_epochs = 2 subprocess_args = [ "python", - "../src/scripts/DeepEnsemble.py", + "src/scripts/DeepEnsemble.py", noise_level, str(n_models), wd, @@ -84,25 +101,25 @@ def test_no_chkpt_saved(temp_directory): # list all files in the "models" folder files_in_models_folder = os.listdir(models_folder) # assert that the number of files is equal to 10 - assert len(files_in_models_folder) == 0, "Expect 0 files in the 'models' folder" + assert ( + len(files_in_models_folder) == n_models + ), "Expected 10 files in the 'models' folder" -def test_chkpt_saved(temp_directory): +def test_no_chkpt_saved(temp_directory): noise_level = "low" n_models = 10 wd = str(temp_directory) + "/" n_epochs = 2 subprocess_args = [ "python", - "../src/scripts/DeepEnsemble.py", + "src/scripts/DeepEnsemble.py", noise_level, str(n_models), wd, "--n_epochs", str(n_epochs), - "--save_final_checkpoint", - "--savefig" - ] + ] # now run the subprocess subprocess.run(subprocess_args, check=True) # check if the right number of checkpoints are saved @@ -110,30 +127,8 @@ def test_chkpt_saved(temp_directory): # list all files in the "models" folder files_in_models_folder = os.listdir(models_folder) # assert that the number of files is equal to 10 - assert ( - len(files_in_models_folder) == n_models - ), "Expected 10 files in the 'models' folder" - - # check if the right number of images were saved - animations_folder = os.path.join(temp_directory, "images/animations") - files_in_animations_folder = os.listdir(animations_folder) - # assert that the number of files is equal to 10 - assert ( - len(files_in_animations_folder) == n_models - ), "Expected 10 files in the 'images/animations' folder" - - # also check that all files in here have the same name elements - expected_substring = "epoch_" + str(n_epochs - 1) - for file_name in files_in_models_folder: - assert ( - expected_substring in file_name - ), f"File '{file_name}' does not contain the expected substring" + assert len(files_in_models_folder) == 0, "Expect 0 files in the 'models' folder" - # also check that all files in here have the same name elements - for file_name in files_in_animations_folder: - assert ( - expected_substring in file_name - ), f"File '{file_name}' 
does not contain the expected substring" def test_run_simple_ensemble(temp_directory): @@ -145,7 +140,7 @@ def test_run_simple_ensemble(temp_directory): wd = str(temp_directory) + "/" subprocess_args = [ "python", - "../src/scripts/DeepEnsemble.py", + "src/scripts/DeepEnsemble.py", noise_level, n_models, wd, @@ -162,7 +157,7 @@ def test_missing_req_arg(temp_directory): n_models = "10" subprocess_args = [ "python", - "../src/scripts/DeepEnsemble.py", + "src/scripts/DeepEnsemble.py", noise_level, n_models, "--n_epochs", From 0723da7dc04fbf6ad117ba466d4d098cf86e184a Mon Sep 17 00:00:00 2001 From: beckynevin Date: Wed, 10 Apr 2024 08:32:59 -0600 Subject: [PATCH 30/30] do not want to add an .h5 file so deepensemble also give an option to create data dynamically --- src/scripts/DeepEnsemble.py | 42 ++++++++++++++++++++++++++++++------- test/test_DeepEnsemble.py | 12 +++++------ 2 files changed, 40 insertions(+), 14 deletions(-) diff --git a/src/scripts/DeepEnsemble.py b/src/scripts/DeepEnsemble.py index a800a67..759cc9c 100644 --- a/src/scripts/DeepEnsemble.py +++ b/src/scripts/DeepEnsemble.py @@ -169,17 +169,45 @@ def parse_args(): rs = namespace.randomseed BATCH_SIZE = namespace.batchsize sigma = io.DataPreparation.get_sigma(noise) + + # generate the df + data = io.DataPreparation() + data.sample_params_from_prior(size_df) + data.simulate_data(data.params, sigma, "linear_homogeneous") + df_array = data.get_dict() + # Convert non-tensor entries to tensors + df = {} + for key, value in df_array.items(): + + if isinstance(value, TensorDataset): + # Keep tensors as they are + df[key] = value + else: + # Convert lists to tensors + df[key] = torch.tensor(value) + + len_df = len(df["params"][:, 0].numpy()) + len_x = len(df["inputs"].numpy()) + ms_array = np.repeat(df["params"][:, 0].numpy(), len_x) + bs_array = np.repeat(df["params"][:, 1].numpy(), len_x) + xs_array = np.tile(df["inputs"].numpy(), len_df) + ys_array = np.reshape(df["output"].numpy(), (len_df * len_x)) + + """ loader = io.DataLoader() - data = loader.load_data_h5( + df = loader.load_data_h5( "linear_sigma_" + str(sigma) + "_size_" + str(size_df), path="/Users/rnevin/Documents/DeepUQ/data/", ) - len_df = len(data["params"][:, 0].numpy()) - len_x = len(data["inputs"].numpy()) - ms_array = np.repeat(data["params"][:, 0].numpy(), len_x) - bs_array = np.repeat(data["params"][:, 1].numpy(), len_x) - xs_array = np.tile(data["inputs"].numpy(), len_df) - ys_array = np.reshape(data["output"].numpy(), (len_df * len_x)) + len_df = len(df["params"][:, 0].numpy()) + len_x = len(df["inputs"].numpy()) + ms_array = np.repeat(df["params"][:, 0].numpy(), len_x) + bs_array = np.repeat(df["params"][:, 1].numpy(), len_x) + xs_array = np.tile(df["inputs"].numpy(), len_df) + ys_array = np.reshape(df["output"].numpy(), (len_df * len_x)) + print(df) + STOP + """ inputs = np.array([xs_array, ms_array, bs_array]).T model_inputs, model_outputs = io.DataPreparation.normalize(inputs, ys_array, diff --git a/test/test_DeepEnsemble.py b/test/test_DeepEnsemble.py index 8bff976..cfb5b8a 100644 --- a/test/test_DeepEnsemble.py +++ b/test/test_DeepEnsemble.py @@ -1,10 +1,8 @@ -import sys import pytest import os import subprocess import tempfile import shutil -from scripts import evaluate, models, DeepEnsemble @pytest.fixture @@ -29,6 +27,7 @@ def temp_directory(): """ shutil.rmtree(temp_dir) + def test_chkpt_saved(temp_directory): noise_level = "low" n_models = 10 @@ -43,8 +42,8 @@ def test_chkpt_saved(temp_directory): "--n_epochs", str(n_epochs), 
"--save_final_checkpoint", - "--savefig" - ] + "--savefig", + ] # now run the subprocess subprocess.run(subprocess_args, check=True) # check if the right number of checkpoints are saved @@ -78,7 +77,6 @@ def test_chkpt_saved(temp_directory): ), f"File '{file_name}' does not contain the expected substring" - @pytest.mark.xfail(strict=True) def test_no_chkpt_saved_xfail(temp_directory): noise_level = "low" @@ -127,8 +125,8 @@ def test_no_chkpt_saved(temp_directory): # list all files in the "models" folder files_in_models_folder = os.listdir(models_folder) # assert that the number of files is equal to 10 - assert len(files_in_models_folder) == 0, "Expect 0 files in the 'models' folder" - + assert len(files_in_models_folder) == 0, \ + "Expect 0 files in the 'models' folder" def test_run_simple_ensemble(temp_directory):