Skip to content

Commit

Permalink
add cugraphops args
Browse files Browse the repository at this point in the history
  • Loading branch information
mnabian authored Apr 25, 2023
1 parent 875a795 commit 4d129f0
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
3 changes: 3 additions & 0 deletions recipes/gnn/graphcast/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@ class Constants(BaseModel):
checkpoint_processor_finetune: bool = True
checkpoint_decoder_finetune: bool = True
concat_trick: bool = True
cugraphops_encoder: bool = False
cugraphops_processor: bool = False
cugraphops_decoder: bool = False
recompute_activation: bool = False
wb_mode: str = "disabled"
dataset_path: str = "datasets/ngc_era5_data"
static_dataset_path: str = "datasets/static"
Expand Down
6 changes: 5 additions & 1 deletion recipes/gnn/graphcast/train_graphcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ def __init__(self, wb, dist, rank_zero_logger):
processor_layers=C.processor_layers,
hidden_dim=C.hidden_dim,
do_concat_trick=C.concat_trick,
use_cugraphops_encoder=C.cugraphops_encoder,
use_cugraphops_processor=C.cugraphops_processor,
use_cugraphops_decoder=C.cugraphops_decoder,
recompute_activation=C.recompute_activation,
)

# set gradient checkpointing
Expand Down Expand Up @@ -293,7 +297,7 @@ def __init__(self, wb, dist, rank_zero_logger):
)
update_dataloader = False
rank_zero_logger.info(
dist, f"Switching to {num_rollout_steps}-step rollout!"
f"Switching to {num_rollout_steps}-step rollout!"
)
break

Expand Down

0 comments on commit 4d129f0

Please sign in to comment.