fixed checkpoints (NVIDIA#42)
* fixed checkpoints

* fixed nested darcy as well

* fixed checkpoint

---------

Co-authored-by: oliver <[email protected]>
loliverhennigh authored Aug 8, 2023
1 parent 3adcab4 commit b3dc226
Showing 4 changed files with 48 additions and 25 deletions.
examples/cfd/darcy_fno/train_fno_darcy.py (10 changes: 3 additions & 7 deletions)
@@ -61,15 +61,11 @@ def darcy_trainer(cfg: DictConfig) -> None:
     LaunchLogger.initialize(use_mlflow=True)  # Modulus launch logger

     # define model, loss, optimiser, scheduler, data loader
-    decoder = FullyConnected(
-        in_features=cfg.arch.fno.latent_channels,
-        out_features=cfg.arch.decoder.out_features,
-        num_layers=cfg.arch.decoder.layers,
-        layer_size=cfg.arch.decoder.layer_size,
-    )
     model = FNO(
-        decoder_net=decoder,
         in_channels=cfg.arch.fno.in_channels,
+        out_channels=cfg.arch.decoder.out_features,
+        decoder_layers=cfg.arch.decoder.layers,
+        decoder_layer_size=cfg.arch.decoder.layer_size,
         dimension=cfg.arch.fno.dimension,
         latent_channels=cfg.arch.fno.latent_channels,
         num_fno_layers=cfg.arch.fno.fno_layers,
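All three example scripts get the same migration: FNO no longer accepts a prebuilt FullyConnected decoder through decoder_net and instead builds its own decoder from the decoder_layers and decoder_layer_size arguments. A minimal sketch of the new-style construction, assuming the import path these examples use; the hyperparameter values are illustrative stand-ins, not the shipped Hydra config:

from modulus.models.fno import FNO

# Build the model directly; the decoder is now internal to FNO.
model = FNO(
    in_channels=1,          # input fields, e.g. permeability
    out_channels=1,         # output fields, e.g. pressure
    decoder_layers=1,       # depth of the built-in pointwise decoder
    decoder_layer_size=32,  # width of each decoder layer
    dimension=2,            # 2D Darcy flow
    latent_channels=32,     # channels in the FNO latent space
    num_fno_layers=4,       # number of spectral convolution blocks
)

The two nested-Darcy diffs below apply the identical substitution, with model_cfg in place of cfg.arch.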
examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py (10 changes: 3 additions & 7 deletions)
@@ -56,15 +56,11 @@ def EvaluateModel(
     dist = DistributedManager()
     log.info(f"evaluating model {model_name}")
     model_cfg = cfg.arch[model_name]
-    decoder = FullyConnected(
-        in_features=model_cfg.fno.latent_channels,
-        out_features=model_cfg.decoder.out_features,
-        num_layers=model_cfg.decoder.layers,
-        layer_size=model_cfg.decoder.layer_size,
-    )
     model = FNO(
-        decoder_net=decoder,
         in_channels=model_cfg.fno.in_channels,
+        out_channels=model_cfg.decoder.out_features,
+        decoder_layers=model_cfg.decoder.layers,
+        decoder_layer_size=model_cfg.decoder.layer_size,
         dimension=model_cfg.fno.dimension,
         latent_channels=model_cfg.fno.latent_channels,
         num_fno_layers=model_cfg.fno.fno_layers,
examples/cfd/darcy_nested_fnos/train_nested_darcy.py (4 changes: 3 additions & 1 deletion)
@@ -132,8 +132,10 @@ def __init__(
             layer_size=model_cfg.decoder.layer_size,
         )
         self.model = FNO(
-            decoder_net=decoder,
             in_channels=model_cfg.fno.in_channels,
+            out_channels=model_cfg.decoder.out_features,
+            decoder_layers=model_cfg.decoder.layers,
+            decoder_layer_size=model_cfg.decoder.layer_size,
             dimension=model_cfg.fno.dimension,
             latent_channels=model_cfg.fno.latent_channels,
             num_fno_layers=model_cfg.fno.fno_layers,
modulus/launch/utils/checkpoint.py (49 changes: 39 additions & 10 deletions)
@@ -38,6 +38,7 @@ def _get_checkpoint_filename(
     base_name: str = "checkpoint",
     index: Union[int, None] = None,
     saving: bool = False,
+    model_type: str = "mdlus",
 ) -> str:
     """Gets the file name /path of checkpoint
@@ -58,6 +59,9 @@
         Checkpoint index, by default None
     saving : bool, optional
         Get filename for saving a new checkpoint, by default False
+    model_type : str
+        Model type, by default "mdlus" for Modulus models and "pt" for PyTorch models
+
     Returns
     -------
@@ -77,22 +81,34 @@ def _get_checkpoint_filename(
     checkpoint_filename = str(
         Path(path).resolve() / f"{base_name}.{model_parallel_rank}"
     )

+    # File extension for Modulus models or PyTorch models
+    file_extension = ".mdlus" if model_type == "mdlus" else ".pt"
+
     # If epoch is provided load that file
     if index is not None:
         checkpoint_filename = checkpoint_filename + f".{index}"
-        checkpoint_filename += ".pt"
+        checkpoint_filename += file_extension
     # Otherwise try loading the latest epoch or rolling checkpoint
     else:
         file_names = []
-        for fname in glob.glob(checkpoint_filename + "*.pt", recursive=False):
+        for fname in glob.glob(
+            checkpoint_filename + "*" + file_extension, recursive=False
+        ):
             file_names.append(Path(fname).name)

         if len(file_names) > 0:
             # If checkpoint from a null index save exists load that
             # This is the most likely line to error since it will fail with
             # invalid checkpoint names
             file_idx = [
-                int(re.sub(f"^{base_name}.{model_parallel_rank}.|.pt", "", fname))
+                int(
+                    re.sub(
+                        f"^{base_name}.{model_parallel_rank}.|" + file_extension,
+                        "",
+                        fname,
+                    )
+                )
                 for fname in file_names
             ]
             file_idx.sort()
@@ -101,9 +117,9 @@
                 checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}"
             else:
                 checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}"
-            checkpoint_filename += ".pt"
+            checkpoint_filename += file_extension
         else:
-            checkpoint_filename += ".0.pt"
+            checkpoint_filename += ".0" + file_extension

     return checkpoint_filename
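The new model_type argument is easiest to see through the names the helper resolves. A hedged sketch of the expected results, assuming model parallel rank 0, a base name "fno", and an invented ./ckpts directory:

# All names illustrative; the rank suffix comes from DistributedManager.
_get_checkpoint_filename("./ckpts", "fno", index=5)                   # -> ./ckpts/fno.0.5.mdlus
_get_checkpoint_filename("./ckpts", "fno", index=5, model_type="pt")  # -> ./ckpts/fno.0.5.pt
_get_checkpoint_filename("./ckpts", "fno", saving=True)               # next free slot, e.g. ./ckpts/fno.0.6.mdlus
_get_checkpoint_filename("./ckpts", "fno")                            # latest on disk, e.g. ./ckpts/fno.0.5.mdlus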

@@ -163,7 +179,7 @@ def save_checkpoint(
     """Training checkpoint saving utility

     This will save a training checkpoint in the provided path following the file naming
-    convention "checkpoint.{model parallel id}.{epoch/index}.pt". The load checkpoint
+    convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint
     method in Modulus core can then be used to read this file.

     Parameters
@@ -196,8 +212,14 @@
         models = [models]
     models = _unique_model_names(models)
     for name, model in models.items():
+        # Get model type
+        model_type = "mdlus" if isinstance(model, modulus.models.Module) else "pt"
+
         # Get full file path / name
-        file_name = _get_checkpoint_filename(path, name, index=epoch, saving=True)
+        file_name = _get_checkpoint_filename(
+            path, name, index=epoch, saving=True, model_type=model_type
+        )
+
         # Save state dictionary
         if isinstance(model, modulus.models.Module):
             model.save(file_name)
@@ -223,7 +245,9 @@ def save_checkpoint(
         checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict()

     # Output file name
-    output_filename = _get_checkpoint_filename(path, index=epoch, saving=True)
+    output_filename = _get_checkpoint_filename(
+        path, index=epoch, saving=True, model_type="pt"
+    )
     if epoch:
         checkpoint_dict["epoch"] = epoch

@@ -287,8 +311,13 @@ def load_checkpoint(
         models = [models]
     models = _unique_model_names(models)
     for name, model in models.items():
+        # Get model type
+        model_type = "mdlus" if isinstance(model, modulus.models.Module) else "pt"
+
         # Get full file path / name
-        file_name = _get_checkpoint_filename(path, name, index=epoch)
+        file_name = _get_checkpoint_filename(
+            path, name, index=epoch, model_type=model_type
+        )
         if not Path(file_name).exists():
             checkpoint_logging.error(
                 f"Could not find valid model file {file_name}, skipping load"
@@ -305,7 +334,7 @@
         )

     # == Loading training checkpoint ==
-    checkpoint_filename = _get_checkpoint_filename(path, index=epoch)
+    checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt")
     if not Path(checkpoint_filename).is_file():
         checkpoint_logging.warning(
             "Could not find valid checkpoint file, skipping load"
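Taken together, save_checkpoint now routes each object to its native format: modulus.models.Module instances are written through model.save as .mdlus files, while plain torch.nn.Modules and the optimizer/scheduler/metadata bundle keep the .pt extension, and load_checkpoint resolves the matching extension per object. A hedged usage sketch; the directory, epoch, and hyperparameters are invented, and the exact on-disk names depend on _unique_model_names:

import torch
from modulus.models.fno import FNO
from modulus.launch.utils import save_checkpoint, load_checkpoint

model = FNO(in_channels=1, out_channels=1, dimension=2)  # a modulus.models.Module
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Model weights go to <model name>.0.10.mdlus via model.save;
# optimizer state and training metadata go to checkpoint.0.10.pt.
save_checkpoint("./checkpoints", models=model, optimizer=optimizer, epoch=10)

# Restores the .mdlus weights and the .pt training state; returns the
# epoch recorded in the checkpoint (10 here).
epoch = load_checkpoint("./checkpoints", models=model, optimizer=optimizer)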
