From ef94792cb2748b9487c0c97ae9264679dff0ac96 Mon Sep 17 00:00:00 2001 From: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> Date: Fri, 6 Oct 2023 13:48:59 -0700 Subject: [PATCH] Adding per rank mlflow tracking location to fix race condition and updating fcn_afno config for easier benchmarking (#121) Signed-off-by: Akshay Subramaniam <6964110+akshaysubr@users.noreply.github.com> --- examples/weather/fcn_afno/conf/config.yaml | 2 ++ examples/weather/fcn_afno/train_era5.py | 4 ++-- modulus/launch/logging/mlflow.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/examples/weather/fcn_afno/conf/config.yaml b/examples/weather/fcn_afno/conf/config.yaml index 2b7b13d..b3a1e5a 100644 --- a/examples/weather/fcn_afno/conf/config.yaml +++ b/examples/weather/fcn_afno/conf/config.yaml @@ -19,7 +19,9 @@ hydra: dir: ./outputs/ wb_artifacts: False +use_mlflow: True start_epoch: 0 +max_epoch: 80 num_samples_per_year_train: 1456 diff --git a/examples/weather/fcn_afno/train_era5.py b/examples/weather/fcn_afno/train_era5.py index 11fe035..2cb571a 100644 --- a/examples/weather/fcn_afno/train_era5.py +++ b/examples/weather/fcn_afno/train_era5.py @@ -109,7 +109,7 @@ def main(cfg: DictConfig) -> None: user_name="Modulus User", mode="offline", ) - LaunchLogger.initialize(use_mlflow=True) # Modulus launch logger + LaunchLogger.initialize(use_mlflow=cfg.use_mlflow) # Modulus launch logger logger = PythonLogger("main") # General python logger datapipe = ERA5HDF5Datapipe( @@ -199,7 +199,7 @@ def train_step_forward(my_model, invar, outvar): return loss # Main training loop - max_epoch = 80 + max_epoch = cfg.max_epoch for epoch in range(max(1, loaded_epoch + 1), max_epoch + 1): # Wrap epoch in launch logger for console / WandB logs with LaunchLogger( diff --git a/modulus/launch/logging/mlflow.py b/modulus/launch/logging/mlflow.py index 64e3fd3..628893e 100644 --- a/modulus/launch/logging/mlflow.py +++ b/modulus/launch/logging/mlflow.py @@ -90,7 +90,7 @@ def initialize_mlflow( group_name = f"{run_name}_{time_string}" # Set default value here for Hydra if tracking_location is None: - tracking_location = str(Path("./mlruns").absolute()) + tracking_location = str(Path(f"./mlruns_{dist.rank}").absolute()) # Set up URI (remote or local) if mode == "online":