diff --git a/examples/cfd/darcy_fno/train_fno_darcy.py b/examples/cfd/darcy_fno/train_fno_darcy.py index 7a0572f..2851842 100644 --- a/examples/cfd/darcy_fno/train_fno_darcy.py +++ b/examples/cfd/darcy_fno/train_fno_darcy.py @@ -61,15 +61,11 @@ def darcy_trainer(cfg: DictConfig) -> None: LaunchLogger.initialize(use_mlflow=True) # Modulus launch logger # define model, loss, optimiser, scheduler, data loader - decoder = FullyConnected( - in_features=cfg.arch.fno.latent_channels, - out_features=cfg.arch.decoder.out_features, - num_layers=cfg.arch.decoder.layers, - layer_size=cfg.arch.decoder.layer_size, - ) model = FNO( - decoder_net=decoder, in_channels=cfg.arch.fno.in_channels, + out_channels=cfg.arch.decoder.out_features, + decoder_layers=cfg.arch.decoder.layers, + decoder_layer_size=cfg.arch.decoder.layer_size, dimension=cfg.arch.fno.dimension, latent_channels=cfg.arch.fno.latent_channels, num_fno_layers=cfg.arch.fno.fno_layers, diff --git a/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py b/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py index 0d832ee..eabe685 100644 --- a/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py +++ b/examples/cfd/darcy_nested_fnos/evaluate_nested_darcy.py @@ -56,15 +56,11 @@ def EvaluateModel( dist = DistributedManager() log.info(f"evaluating model {model_name}") model_cfg = cfg.arch[model_name] - decoder = FullyConnected( - in_features=model_cfg.fno.latent_channels, - out_features=model_cfg.decoder.out_features, - num_layers=model_cfg.decoder.layers, - layer_size=model_cfg.decoder.layer_size, - ) model = FNO( - decoder_net=decoder, in_channels=model_cfg.fno.in_channels, + out_channels=model_cfg.decoder.out_features, + decoder_layers=model_cfg.decoder.layers, + decoder_layer_size=model_cfg.decoder.layer_size, dimension=model_cfg.fno.dimension, latent_channels=model_cfg.fno.latent_channels, num_fno_layers=model_cfg.fno.fno_layers, diff --git a/examples/cfd/darcy_nested_fnos/train_nested_darcy.py b/examples/cfd/darcy_nested_fnos/train_nested_darcy.py index 7af47c1..0f35999 100644 --- a/examples/cfd/darcy_nested_fnos/train_nested_darcy.py +++ b/examples/cfd/darcy_nested_fnos/train_nested_darcy.py @@ -132,8 +132,10 @@ def __init__( layer_size=model_cfg.decoder.layer_size, ) self.model = FNO( - decoder_net=decoder, in_channels=model_cfg.fno.in_channels, + out_channels=model_cfg.decoder.out_features, + decoder_layers=model_cfg.decoder.layers, + decoder_layer_size=model_cfg.decoder.layer_size, dimension=model_cfg.fno.dimension, latent_channels=model_cfg.fno.latent_channels, num_fno_layers=model_cfg.fno.fno_layers, diff --git a/modulus/launch/utils/checkpoint.py b/modulus/launch/utils/checkpoint.py index 0f57547..ceb6a9c 100644 --- a/modulus/launch/utils/checkpoint.py +++ b/modulus/launch/utils/checkpoint.py @@ -38,6 +38,7 @@ def _get_checkpoint_filename( base_name: str = "checkpoint", index: Union[int, None] = None, saving: bool = False, + model_type: str = "mdlus", ) -> str: """Gets the file name /path of checkpoint @@ -58,6 +59,9 @@ def _get_checkpoint_filename( Checkpoint index, by default None saving : bool, optional Get filename for saving a new checkpoint, by default False + model_type : str + Model type, by default "mdlus" for Modulus models and "pt" for PyTorch models + Returns ------- @@ -77,14 +81,20 @@ def _get_checkpoint_filename( checkpoint_filename = str( Path(path).resolve() / f"{base_name}.{model_parallel_rank}" ) + + # File extension for Modulus models or PyTorch models + file_extension = ".mdlus" if model_type == "mdlus" else ".pt" + # If epoch is provided load that file if index is not None: checkpoint_filename = checkpoint_filename + f".{index}" - checkpoint_filename += ".pt" + checkpoint_filename += file_extension # Otherwise try loading the latest epoch or rolling checkpoint else: file_names = [] - for fname in glob.glob(checkpoint_filename + "*.pt", recursive=False): + for fname in glob.glob( + checkpoint_filename + "*" + file_extension, recursive=False + ): file_names.append(Path(fname).name) if len(file_names) > 0: @@ -92,7 +102,13 @@ def _get_checkpoint_filename( # This is the most likely line to error since it will fail with # invalid checkpoint names file_idx = [ - int(re.sub(f"^{base_name}.{model_parallel_rank}.|.pt", "", fname)) + int( + re.sub( + f"^{base_name}.{model_parallel_rank}.|" + file_extension, + "", + fname, + ) + ) for fname in file_names ] file_idx.sort() @@ -101,9 +117,9 @@ def _get_checkpoint_filename( checkpoint_filename = checkpoint_filename + f".{file_idx[-1]+1}" else: checkpoint_filename = checkpoint_filename + f".{file_idx[-1]}" - checkpoint_filename += ".pt" + checkpoint_filename += file_extension else: - checkpoint_filename += ".0.pt" + checkpoint_filename += ".0" + file_extension return checkpoint_filename @@ -163,7 +179,7 @@ def save_checkpoint( """Training checkpoint saving utility This will save a training checkpoint in the provided path following the file naming - convention "checkpoint.{model parallel id}.{epoch/index}.pt". The load checkpoint + convention "checkpoint.{model parallel id}.{epoch/index}.mdlus". The load checkpoint method in Modulus core can then be used to read this file. Parameters @@ -196,8 +212,14 @@ def save_checkpoint( models = [models] models = _unique_model_names(models) for name, model in models.items(): + # Get model type + model_type = "mdlus" if isinstance(model, modulus.models.Module) else "pt" + # Get full file path / name - file_name = _get_checkpoint_filename(path, name, index=epoch, saving=True) + file_name = _get_checkpoint_filename( + path, name, index=epoch, saving=True, model_type=model_type + ) + # Save state dictionary if isinstance(model, modulus.models.Module): model.save(file_name) @@ -223,7 +245,9 @@ def save_checkpoint( checkpoint_dict["static_capture_state_dict"] = _StaticCapture.state_dict() # Output file name - output_filename = _get_checkpoint_filename(path, index=epoch, saving=True) + output_filename = _get_checkpoint_filename( + path, index=epoch, saving=True, model_type="pt" + ) if epoch: checkpoint_dict["epoch"] = epoch @@ -287,8 +311,13 @@ def load_checkpoint( models = [models] models = _unique_model_names(models) for name, model in models.items(): + # Get model type + model_type = "mdlus" if isinstance(model, modulus.models.Module) else "pt" + # Get full file path / name - file_name = _get_checkpoint_filename(path, name, index=epoch) + file_name = _get_checkpoint_filename( + path, name, index=epoch, model_type=model_type + ) if not Path(file_name).exists(): checkpoint_logging.error( f"Could not find valid model file {file_name}, skipping load" @@ -305,7 +334,7 @@ def load_checkpoint( ) # == Loading training checkpoint == - checkpoint_filename = _get_checkpoint_filename(path, index=epoch) + checkpoint_filename = _get_checkpoint_filename(path, index=epoch, model_type="pt") if not Path(checkpoint_filename).is_file(): checkpoint_logging.warning( "Could not find valid checkpoint file, skipping load"