-
Notifications
You must be signed in to change notification settings - Fork 4
/
interactive_predict.py
55 lines (48 loc) · 2.3 KB
/
interactive_predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import traceback
from common import common
from extractor import Extractor
# Passed as ``topk`` to ``common.parse_results`` when formatting predictions.
SHOW_TOP_CONTEXTS = 10
# Path-extraction limits forwarded to the ``Extractor`` constructor.
MAX_PATH_LENGTH, MAX_PATH_WIDTH = 8, 2
# Relative path of the extractor jar handed to ``Extractor(jar_path=...)``.
JAR_PATH = 'JavaExtractor/JPredict/target/JavaExtractor-0.0.1-SNAPSHOT.jar'
class InteractivePredictor:
    """Interactive command-line loop that predicts method names for a source file.

    On each iteration the user edits the input file and presses Enter; the
    file is run through the path extractor, the model predicts names for the
    methods found in it, and the predictions, their attention contexts and
    (optionally) code vectors are printed to stdout.
    """

    # User inputs (compared case-insensitively, ignoring surrounding
    # whitespace) that end the interactive loop.
    exit_keywords = ['exit', 'quit', 'q']

    def __init__(self, config, model):
        """Store the model and config and build the path extractor.

        Args:
            config: configuration object; forwarded to ``Extractor`` and its
                ``EXPORT_CODE_VECTORS`` attribute is read by ``predict``.
            model: object whose ``predict(lines)`` returns a
                ``(results, code_vectors)`` pair.
        """
        # NOTE(review): an empty-batch predict before anything else looks like a
        # warm-up / initialisation call — confirm against the model class.
        model.predict([])
        self.model = model
        self.config = config
        self.path_extractor = Extractor(config,
                                        jar_path=JAR_PATH,
                                        max_path_length=MAX_PATH_LENGTH,
                                        max_path_width=MAX_PATH_WIDTH)

    def read_file(self, input_filename):
        """Return the lines of ``input_filename`` as a list (line endings kept)."""
        with open(input_filename, 'r') as file:
            return file.readlines()

    def predict(self, input_filename='input.ts'):
        """Run the interactive prediction loop until the user asks to stop.

        Args:
            input_filename: path of the file re-read on every iteration
                (defaults to ``'input.ts'``).

        Returns:
            None, once the user enters one of ``exit_keywords`` or stdin
            reaches end-of-file.
        """
        print('Starting interactive prediction...')
        while True:
            print(
                'Modify the file: "%s" and press any key when ready, or "q" / "quit" / "exit" to exit' % input_filename)
            try:
                user_input = input()
            except EOFError:
                # stdin closed (e.g. piped input exhausted): there is no way to
                # get another command, so stop cleanly instead of crashing.
                print('Exiting...')
                return
            if user_input.strip().lower() in self.exit_keywords:
                print('Exiting...')
                return
            try:
                predict_lines, hash_to_string_dict = self.path_extractor.extract_paths(input_filename)
            except ValueError as e:
                # Report extraction problems and let the user fix the file and retry.
                print(e)
                continue
            results, code_vectors = self.model.predict(predict_lines)
            prediction_results = common.parse_results(results, hash_to_string_dict, topk=SHOW_TOP_CONTEXTS)
            self._print_results(prediction_results, code_vectors)

    def _print_results(self, prediction_results, code_vectors):
        """Print predicted names, attention contexts and optional code vectors.

        Args:
            prediction_results: iterable of objects with ``original_name``,
                ``predictions`` (dicts with ``'probability'``/``'name'``) and
                ``attention_paths`` (dicts with ``'score'``, ``'token1'``,
                ``'path'``, ``'token2'``).
            code_vectors: indexable by position; only read when
                ``self.config.EXPORT_CODE_VECTORS`` is truthy.
                NOTE(review): assumes one vector per method prediction, in the
                same order — confirm against the model's ``predict``.
        """
        for i, method_prediction in enumerate(prediction_results):
            print('Original name:\t' + method_prediction.original_name)
            for name_prob_pair in method_prediction.predictions:
                print('\t(%f) predicted: %s' % (name_prob_pair['probability'], name_prob_pair['name']))
            print('Attention:')
            for attention_obj in method_prediction.attention_paths:
                print('%f\tcontext: %s,%s,%s' % (
                    attention_obj['score'], attention_obj['token1'], attention_obj['path'], attention_obj['token2']))
            if self.config.EXPORT_CODE_VECTORS:
                print('Code vector:')
                print(' '.join(map(str, code_vectors[i])))