Skip to content

Commit

Permalink
also prints out chk on 99, data_dim added to check filename
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Jul 13, 2024
1 parent 235a570 commit 03d1ca6
Showing 1 changed file with 11 additions and 2 deletions.
13 changes: 11 additions & 2 deletions src/train/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def train_DER(
path_to_model="models/",
data_prescription="linear_homoskedastic",
inject_type="predictive",
data_dim="0D",
noise_level="low",
save_all_checkpoints=False,
save_final_checkpoint=False,
Expand Down Expand Up @@ -95,7 +96,8 @@ def train_DER(
best_loss = np.inf # init to infinity
model, lossFn = models.model_setup_DER(loss_type,
DEVICE,
n_hidden=n_hidden)
n_hidden=n_hidden,
data_type=data_dim)
if verbose:
print("model is", model, "lossfn", lossFn)
opt = torch.optim.Adam(model.parameters(), lr=INIT_LR)
Expand Down Expand Up @@ -127,7 +129,6 @@ def train_DER(
# send the input to the device
# (x, y) = (x.to(device), y.to(device))
# perform a forward pass and calculate the training loss

pred = model(x)
loss = lossFn(pred, y, COEFF)
if plot or savefig:
Expand Down Expand Up @@ -259,6 +260,8 @@ def train_DER(
+ str(data_prescription)
+ "_"
+ str(inject_type)
+ "_"
+ str(data_dim)
+ "_loss_"
+ str(loss_type)
+ "_COEFF_"
Expand All @@ -280,6 +283,8 @@ def train_DER(
+ str(data_prescription)
+ "_"
+ str(inject_type)
+ "_"
+ str(data_dim)
+ "_noise_"
+ str(noise_level)
+ "_loss_"
Expand Down Expand Up @@ -313,6 +318,8 @@ def train_DER(
},
filename
)
if epoch == 99:
print('checkpoint saved here', filename)

if save_final_checkpoint and (e % (EPOCHS - 1) == 0) and (e != 0):
filename = (
Expand All @@ -323,6 +330,8 @@ def train_DER(
+ str(data_prescription)
+ "_"
+ str(inject_type)
+ "_"
+ str(data_dim)
+ "_noise_"
+ str(noise_level)
+ "_loss_"
Expand Down

0 comments on commit 03d1ca6

Please sign in to comment.