forked from FAIR-Chem/fairchem
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
127 lines (111 loc) · 4.28 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Copyright (c) Facebook, Inc. and its affiliates.
This source code is licensed under the MIT license found in the
LICENSE file in the root directory of this source tree.
"""
import copy
import logging
import os
import sys
import time
from pathlib import Path
import submitit
from ocpmodels.common import distutils
from ocpmodels.common.flags import flags
from ocpmodels.common.registry import registry
from ocpmodels.common.utils import (
build_config,
create_grid,
save_experiment_log,
setup_imports,
setup_logging,
)
class Runner(submitit.helpers.Checkpointable):
    """Callable job wrapper: build a trainer and task from a config and run them.

    An instance is either called directly (local run) or handed to
    ``submitit`` via ``executor.map_array`` (cluster run). Subclassing
    ``submitit.helpers.Checkpointable`` provides the :meth:`checkpoint` hook
    that submitit uses to requeue an interrupted job.
    """

    def __init__(self):
        # Replaced with a deep copy of the run config on each __call__.
        self.config = None

    def __call__(self, config):
        """Set up (optionally distributed) execution, build trainer and task, run.

        Args:
            config (dict): Run configuration, e.g. from ``build_config`` or
                ``create_grid``. Keys indexed directly here (and therefore
                required): ``task``, ``model``, ``dataset``, ``optim``,
                ``identifier``, ``local_rank`` and ``mode``. All other trainer
                settings fall back to the defaults shown below.

        Note:
            Whether distributed setup/cleanup happens is decided by the
            module-level ``args.distributed`` (parsed in the ``__main__``
            block), not by ``config``.
        """
        setup_logging()
        # Private copy: checkpoint() mutates self.config before resubmitting,
        # and the task is built from this copy.
        self.config = copy.deepcopy(config)

        if args.distributed:
            distutils.setup(config)

        try:
            setup_imports()
            trainer_cls = registry.get_trainer_class(
                config.get("trainer", "energy")
            )
            self.trainer = trainer_cls(
                task=config["task"],
                model=config["model"],
                dataset=config["dataset"],
                optimizer=config["optim"],
                identifier=config["identifier"],
                timestamp_id=config.get("timestamp_id", None),
                run_dir=config.get("run_dir", "./"),
                is_debug=config.get("is_debug", False),
                print_every=config.get("print_every", 10),
                seed=config.get("seed", 0),
                logger=config.get("logger", "tensorboard"),
                local_rank=config["local_rank"],
                amp=config.get("amp", False),
                cpu=config.get("cpu", False),
                slurm=config.get("slurm", {}),
                noddp=config.get("noddp", False),
                # Optional like the other trainer settings: pass None instead
                # of raising KeyError when the config carries no DeepSpeed
                # configuration.
                deepspeed_config=config.get("deepspeed_config", None),
                profiler=config.get("profiler", {"enabled": False}),
            )
            self.task = registry.get_task_class(config["mode"])(self.config)
            self.task.setup(self.trainer)
            start_time = time.time()
            self.task.run()
            # Wait for all ranks before reporting the elapsed time once.
            distutils.synchronize()
            if distutils.is_master():
                logging.info(f"Total time taken: {time.time() - start_time}")
        finally:
            # Always tear down the process group, even if training raised.
            if args.distributed:
                distutils.cleanup()

    def checkpoint(self, *_args, **_kwargs):
        """Save training state and return a resubmission of this job.

        Writes ``checkpoint.pt`` with training state, then updates the stored
        config so ``checkpoint`` points at the task's checkpoint path and
        ``timestamp_id`` matches the current trainer. Marks the logger (if
        any) as preempting.

        Varargs are named ``_args``/``_kwargs`` so they do not shadow the
        module-level ``args`` that ``__call__`` reads; they are unused.

        Returns:
            submitit.helpers.DelayedSubmission: a fresh ``Runner`` paired with
            the updated config.
        """
        new_runner = Runner()
        self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
        self.config["checkpoint"] = self.task.chkpt_path
        self.config["timestamp_id"] = self.trainer.timestamp_id
        if self.trainer.logger is not None:
            self.trainer.logger.mark_preempting()
        return submitit.helpers.DelayedSubmission(new_runner, self.config)
if __name__ == "__main__":
    setup_logging()

    # NOTE: ``args`` must remain a module-level global (do not move this into
    # a function): Runner.__call__ reads ``args.distributed`` from module scope.
    parser = flags.get_parser()
    args, override_args = parser.parse_known_args()
    config = build_config(args, override_args)

    if args.submit:  # Run on cluster
        # Additional slurm arguments supplied in the config, if any (else None).
        slurm_add_params = config.get("slurm", None)

        if args.sweep_yml:  # Run grid search
            configs = create_grid(config, args.sweep_yml)
        else:
            configs = [config]

        logging.info(f"Submitting {len(configs)} jobs")
        executor = submitit.AutoExecutor(
            folder=args.logdir / "%j", slurm_max_num_timeout=3
        )
        executor.update_parameters(
            name=args.identifier,
            mem_gb=args.slurm_mem,
            # args.slurm_timeout appears to be in hours; submitit wants minutes.
            timeout_min=args.slurm_timeout * 60,
            slurm_partition=args.slurm_partition,
            gpus_per_node=args.num_gpus,
            # Presumably one CPU per data-loader worker plus one for the main
            # process — confirm against the trainer's data loading.
            cpus_per_task=(config["optim"]["num_workers"] + 1),
            # One task per GPU when distributed, otherwise a single task.
            tasks_per_node=(args.num_gpus if args.distributed else 1),
            nodes=args.num_nodes,
            slurm_additional_parameters=slurm_add_params,
        )

        # Record the resolved slurm settings in every job's config. A distinct
        # loop name keeps the base ``config`` from being rebound to the last
        # job's config.
        for job_config in configs:
            job_config["slurm"] = copy.deepcopy(executor.parameters)
            job_config["slurm"]["folder"] = str(executor.folder)

        jobs = executor.map_array(Runner(), configs)
        logging.info(
            f"Submitted jobs: {', '.join(job.job_id for job in jobs)}"
        )
        log_file = save_experiment_log(args, jobs, configs)
        logging.info(f"Experiment log saved to: {log_file}")
    else:  # Run locally
        Runner()(config)