# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
import argparse
import torch.multiprocessing as mp
import pprint
import yaml
from src.paws_train import main as paws
from src.suncet_train import main as suncet
from src.fine_tune import main as fine_tune
from src.snn_fine_tune import main as snn_fine_tune
from src.utils import init_distributed

parser = argparse.ArgumentParser()
parser.add_argument(
    '--fname', type=str,
    help='name of config file to load',
    default='configs.yaml')
parser.add_argument(
    '--devices', type=str, nargs='+', default=['cuda:0'],
    help='which devices to use on local machine')
parser.add_argument(
    '--sel', type=str,
    help='which script to run',
    choices=[
        'paws_train',
        'suncet_train',
        'fine_tune',
        'snn_fine_tune'
    ])
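
# Example invocation (a sketch, not taken from the source; substitute a config
# file and device list that match your machine):
#
#   python main.py --sel paws_train --fname configs.yaml --devices cuda:0 cuda:1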


def process_main(rank, sel, fname, world_size, devices):
    # -- pin this worker to its assigned GPU before torch initializes CUDA
    import os
    os.environ['CUDA_VISIBLE_DEVICES'] = str(devices[rank].split(':')[-1])

    import logging
    logging.basicConfig()
    logger = logging.getLogger()

    logger.info(f'called-params {sel} {fname}')

    # -- load script params
    params = None
    with open(fname, 'r') as y_file:
        params = yaml.load(y_file, Loader=yaml.FullLoader)
        logger.info('loaded params...')
        if rank == 0:
            pp = pprint.PrettyPrinter(indent=4)
            pp.pprint(params)

    # -- save a copy of the resolved params alongside the logs (rank 0 only)
    if rank == 0:
        dump = os.path.join(params['logging']['folder'], f'params-{sel}.yaml')
        with open(dump, 'w') as f:
            yaml.dump(params, f)

    world_size, rank = init_distributed(rank_and_world_size=(rank, world_size))

    # -- make sure all processes correctly initialized torch-distributed
    logger.info(f'Running {sel} (rank: {rank}/{world_size})')

    # -- turn off info-logging for ranks > 0, otherwise too much std output
    if rank == 0:
        logger.setLevel(logging.INFO)
    else:
        logger.setLevel(logging.ERROR)

    # -- dispatch to the selected training / fine-tuning entry point
    if sel == 'paws_train':
        return paws(params)
    elif sel == 'suncet_train':
        return suncet(params)
    elif sel == 'fine_tune':
        return fine_tune(params)
    elif sel == 'snn_fine_tune':
        return snn_fine_tune(params)
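
# Minimal sketch of the config shape this file itself depends on. Only the
# 'logging.folder' key is certain from the code above; everything else in the
# YAML is defined by whichever training script is selected:
#
#   logging:
#     folder: /path/to/log/dir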


if __name__ == '__main__':
    args = parser.parse_args()
    num_gpus = len(args.devices)
    mp.spawn(
        process_main,
        nprocs=num_gpus,
        args=(args.sel, args.fname, num_gpus, args.devices))
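
# Note: mp.spawn prepends each worker's process index to the call, so
# process_main receives it as `rank`, followed by the values in `args`.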