From 03d1ca6ba3d4a102d49777f59cd5ed8d43c48204 Mon Sep 17 00:00:00 2001 From: beckynevin Date: Sat, 13 Jul 2024 16:10:02 -0600 Subject: [PATCH] also prints out chk on 99, data_dim added to check filename --- src/train/train.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/train/train.py b/src/train/train.py index da62b1b..0e3a712 100644 --- a/src/train/train.py +++ b/src/train/train.py @@ -27,6 +27,7 @@ def train_DER( path_to_model="models/", data_prescription="linear_homoskedastic", inject_type="predictive", + data_dim="0D", noise_level="low", save_all_checkpoints=False, save_final_checkpoint=False, @@ -95,7 +96,8 @@ def train_DER( best_loss = np.inf # init to infinity model, lossFn = models.model_setup_DER(loss_type, DEVICE, - n_hidden=n_hidden) + n_hidden=n_hidden, + data_type=data_dim) if verbose: print("model is", model, "lossfn", lossFn) opt = torch.optim.Adam(model.parameters(), lr=INIT_LR) @@ -127,7 +129,6 @@ def train_DER( # send the input to the device # (x, y) = (x.to(device), y.to(device)) # perform a forward pass and calculate the training loss - pred = model(x) loss = lossFn(pred, y, COEFF) if plot or savefig: @@ -259,6 +260,8 @@ def train_DER( + str(data_prescription) + "_" + str(inject_type) + + "_" + + str(data_dim) + "_loss_" + str(loss_type) + "_COEFF_" @@ -280,6 +283,8 @@ def train_DER( + str(data_prescription) + "_" + str(inject_type) + + "_" + + str(data_dim) + "_noise_" + str(noise_level) + "_loss_" @@ -313,6 +318,8 @@ def train_DER( }, filename ) + if epoch == 99: + print('checkpoint saved here', filename) if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0): filename = ( @@ -323,6 +330,8 @@ def train_DER( + str(data_prescription) + "_" + str(inject_type) + + "_" + + str(data_dim) + "_noise_" + str(noise_level) + "_loss_"