forked from tech-srl/code2seq
-
Notifications
You must be signed in to change notification settings - Fork 17
/
code2seq.py
39 lines (33 loc) · 1.26 KB
/
code2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import numpy as np
import tensorflow as tf
from config import Config
from interactive_predict import InteractivePredictor
from modelrunner import ModelRunner
from args import read_args
if __name__ == '__main__':
    # Entry point: configure TensorFlow, build one model from the CLI
    # arguments, then optionally train, evaluate and/or run interactive
    # prediction with that same model.

    # Enable on-demand GPU memory allocation on *every* visible GPU. Setting it
    # only on the first device leaves TensorFlow free to reserve all memory on
    # the remaining GPUs. Memory growth must be configured before any GPU has
    # been initialized, hence it happens first.
    for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
    # tf.config.set_visible_devices([], 'GPU')  # uncomment to hide all GPUs

    args = read_args()
    # Seed both NumPy and TensorFlow from the same CLI value.
    np.random.seed(args.seed)
    tf.random.set_seed(args.seed)

    if args.debug:
        config = Config.get_debug_config(args)
        # Execute tf.function-decorated code eagerly to ease debugging.
        # `run_functions_eagerly` supersedes the deprecated
        # `experimental_run_functions_eagerly`; fall back on older TF builds
        # that only provide the experimental name.
        run_functions_eagerly = getattr(tf.config, 'run_functions_eagerly', None)
        if run_functions_eagerly is None:
            run_functions_eagerly = tf.config.experimental_run_functions_eagerly
        run_functions_eagerly(True)
    else:
        config = Config.get_default_config(args)

    # Build a single ModelRunner and share it across training, evaluation and
    # prediction. Constructing a new ModelRunner per phase (as before) meant
    # evaluation/prediction after train() ran on a different instance than the
    # one that was just trained (unless ModelRunner reloads a checkpoint in its
    # constructor - not visible here).
    model = ModelRunner(config)
    print('Created model')

    if config.TRAIN_PATH:
        model.train()
    if config.TEST_PATH and not args.data_path:
        results, precision, recall, f1, rouge = model.evaluate()
        print('Accuracy: ' + str(results))
        print('Precision: ' + str(precision) + ', recall: ' + str(recall) + ', F1: ' + str(f1))
        print('Rouge: ', rouge)
    if args.predict:
        predictor = InteractivePredictor(config, model, args.predict)
        predictor.predict()