From 4d129f089504239a5d61b81bd08eee86f07c4bb1 Mon Sep 17 00:00:00 2001 From: Mohammad Amin Nabian Date: Tue, 25 Apr 2023 09:54:34 -0700 Subject: [PATCH] add cugraphops args --- recipes/gnn/graphcast/constants.py | 3 +++ recipes/gnn/graphcast/train_graphcast.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/recipes/gnn/graphcast/constants.py b/recipes/gnn/graphcast/constants.py index 89d86c6..fdcb812 100644 --- a/recipes/gnn/graphcast/constants.py +++ b/recipes/gnn/graphcast/constants.py @@ -33,7 +33,10 @@ class Constants(BaseModel): checkpoint_processor_finetune: bool = True checkpoint_decoder_finetune: bool = True concat_trick: bool = True + cugraphops_encoder: bool = False cugraphops_processor: bool = False + cugraphops_decoder: bool = False + recompute_activation: bool = False wb_mode: str = "disabled" dataset_path: str = "datasets/ngc_era5_data" static_dataset_path: str = "datasets/static" diff --git a/recipes/gnn/graphcast/train_graphcast.py b/recipes/gnn/graphcast/train_graphcast.py index 30c216d..6f2f37b 100644 --- a/recipes/gnn/graphcast/train_graphcast.py +++ b/recipes/gnn/graphcast/train_graphcast.py @@ -88,6 +88,10 @@ def __init__(self, wb, dist, rank_zero_logger): processor_layers=C.processor_layers, hidden_dim=C.hidden_dim, do_concat_trick=C.concat_trick, + use_cugraphops_encoder=C.cugraphops_encoder, + use_cugraphops_processor=C.cugraphops_processor, + use_cugraphops_decoder=C.cugraphops_decoder, + recompute_activation=C.recompute_activation, ) # set gradient checkpointing @@ -293,7 +297,7 @@ def __init__(self, wb, dist, rank_zero_logger): ) update_dataloader = False rank_zero_logger.info( - dist, f"Switching to {num_rollout_steps}-step rollout!" + f"Switching to {num_rollout_steps}-step rollout!" ) break