Skip to content

Commit

Permalink
figured out how to get checkpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed May 7, 2024
1 parent ab55afe commit 21ff521
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 243 deletions.
17 changes: 13 additions & 4 deletions src/analyze/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,14 @@ class AggregateCheckpoints:
# def load_final_checkpoints():
# def load_all_checkpoints():
# functions for loading model checkpoints
def load_DE_checkpoint(
self, model_name, nmodel, epoch, beta, device, path="models/checkpoints/"
def load_checkpoint(
self,
model_name,
noise,
nmodel,
epoch,
beta, device,
path="models/"
):
"""
Load PyTorch model checkpoint from a .pt file.
Expand All @@ -36,8 +42,11 @@ def load_DE_checkpoint(
:param model: PyTorch model to load the checkpoint into
:return: Loaded model
"""
file_name = path + f"{model_name}_beta_{beta}_nmodel_{nmodel}_epoch_{epoch}.pt"
checkpoint = torch.load(file_name, map_location=device)
if model_name[0:2] == "DE":
file_name = str(path) + "checkpoints/" + f"{model_name}_noise_{noise}_beta_{beta}_nmodel_{nmodel}_epoch_{epoch}.pt"
checkpoint = torch.load(file_name, map_location=device)
else:
STOP
return checkpoint

def ep_al_checkpoint_DE(checkpoint):
Expand Down
Loading

0 comments on commit 21ff521

Please sign in to comment.