-
Notifications
You must be signed in to change notification settings - Fork 6
/
main.py
34 lines (23 loc) · 1.06 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import argparse
import importlib
# Work types this entry point knows how to run. Each name must have a
# matching ``parsers.<name>`` module exposing ``Parser`` and a
# ``trainers.<name>`` module exposing ``Trainer``; adding a new task only
# requires appending its name here.
_WORK_TYPES = (
    'classification_TU',
    'classification_OGB',
    'reconstruction_ZINC',
    'classification_node',
)


def main(work_type_args):
    """Run training for the selected work type.

    Loads the task-specific ``Parser`` class from ``parsers.<type>`` and the
    ``Trainer`` class from ``trainers.<type>``, builds the run arguments with
    ``Parser().parse()``, then calls ``Trainer(args).train()``.

    Args:
        work_type_args: Object with a ``type`` attribute naming the work type
            (for example, the namespace returned by ``argparse``).

    Raises:
        ValueError: If ``work_type_args.type`` is not one of ``_WORK_TYPES``.
    """
    import importlib

    work_type = work_type_args.type
    # Check against a fixed whitelist *before* importing, so an arbitrary
    # command-line string is never turned into a module path.
    if work_type not in _WORK_TYPES:
        raise ValueError("Work Type Name <{}> is Unknown".format(work_type))

    # Imports happen here rather than at module top, so only the selected
    # task's parser/trainer modules are ever imported.
    Parser = importlib.import_module('parsers.' + work_type).Parser
    Trainer = importlib.import_module('trainers.' + work_type).Trainer

    # NOTE(review): Parser().parse() takes no arguments; presumably it reads
    # sys.argv itself (including flags the entry point left unparsed) — confirm.
    args = Parser().parse()
    trainer = Trainer(args)
    trainer.train()
if __name__ == '__main__':
    # Command-line entry point: only ``--type`` is parsed here. Any other
    # flags are tolerated (parse_known_args) and discarded at this level;
    # presumably the task-specific Parser reads them itself — confirm.
    cli = argparse.ArgumentParser()
    cli.add_argument('--type', type=str, required=True)
    known_args, _leftover = cli.parse_known_args()
    main(known_args)