-
Notifications
You must be signed in to change notification settings - Fork 33
/
Copy pathdecoder.py
69 lines (55 loc) · 2.46 KB
/
decoder.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import traceback
import config
from model import *
def decode_python_dataset(model, dataset, verbose=True, max_cands=10):
    """Beam-decode every example in ``dataset`` and convert candidates to Python source.

    For each example, ``model.decode`` is called with ``config.beam_size`` and
    ``config.decode_max_time_step``. Each of the first ``max_cands`` returned
    candidates is converted with ``decode_tree_to_python_ast`` and then
    ``astor.to_source``. A candidate whose conversion raises is skipped. When
    ``verbose`` is set, its traceback is printed to stdout.

    Args:
        model: object providing ``decode(example, grammar, terminal_vocab,
            beam_size=..., max_time_step=...)`` that returns a sequence of
            candidates exposing ``.tree``.
        dataset: object providing ``name``, ``count``, ``examples``,
            ``grammar`` and ``terminal_vocab``.
        verbose: if true, log the dataset size, print conversion failures and
            print a progress line every 50 examples.
        max_cands: number of leading beam candidates converted per example.
            Defaults to 10, the value previously hard-coded. ``None`` keeps
            all candidates.

    Returns:
        A list with one entry per example, in dataset order. Each entry is a
        list of ``(beam_position, candidate, ast_tree, code)`` tuples, one per
        candidate that converted successfully.
    """
    import logging
    import sys
    from lang.py.parse import decode_tree_to_python_ast

    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decode_results = []
    for cum_num, example in enumerate(dataset.examples, 1):
        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,
                                 beam_size=config.beam_size,
                                 max_time_step=config.decode_max_time_step)

        exg_decode_results = []
        for cid, cand in enumerate(cand_list[:max_cands]):
            try:
                ast_tree = decode_tree_to_python_ast(cand.tree)
                code = astor.to_source(ast_tree)
            except Exception:
                # A malformed candidate must not abort the whole decoding run.
                # Report it and move on. ``Exception`` (not a bare except)
                # lets Ctrl-C / SystemExit still stop the process.
                if verbose:
                    print('Exception in converting tree to code:')
                    print('-' * 60)
                    print('raw_id: %d, beam pos: %d' % (example.raw_id, cid))
                    traceback.print_exc(file=sys.stdout)
                    print('-' * 60)
            else:
                exg_decode_results.append((cid, cand, ast_tree, code))

        if verbose and cum_num % 50 == 0:
            print('%d examples so far ...' % cum_num)
        decode_results.append(exg_decode_results)

    # NOTE(review): an earlier (disabled) option persisted the results via
    # serialize_to_file(decode_results, '%s.decode_results.profile' % dataset.name)
    return decode_results
def decode_ifttt_dataset(model, dataset, verbose=True, max_cands=10):
    """Beam-decode every example in ``dataset`` and keep the top candidates as-is.

    For each example, ``model.decode`` is called with ``config.beam_size`` and
    ``config.decode_max_time_step``. The first ``max_cands`` candidates are
    stored together with their beam position. Unlike the Python decoder, no
    tree-to-code conversion is done here.

    Args:
        model: object providing ``decode(example, grammar, terminal_vocab,
            beam_size=..., max_time_step=...)`` that returns a sliceable
            sequence of candidates.
        dataset: object providing ``name``, ``count``, ``examples``,
            ``grammar`` and ``terminal_vocab``.
        verbose: if true, log the dataset size and print a progress line every
            50 examples.
        max_cands: number of leading beam candidates kept per example.
            Defaults to 10, the value previously hard-coded. ``None`` keeps
            all candidates.

    Returns:
        A list with one entry per example, in dataset order. Each entry is a
        list of ``(beam_position, candidate)`` tuples.
    """
    import logging

    if verbose:
        logging.info('decoding [%s] set, num. examples: %d', dataset.name, dataset.count)

    decode_results = []
    for cum_num, example in enumerate(dataset.examples, 1):
        cand_list = model.decode(example, dataset.grammar, dataset.terminal_vocab,
                                 beam_size=config.beam_size,
                                 max_time_step=config.decode_max_time_step)
        # No conversion step for IFTTT, so nothing here can fail per candidate.
        # The former try/except around a plain append (with a misleading
        # "converting tree to code" message) has been removed.
        exg_decode_results = list(enumerate(cand_list[:max_cands]))

        if verbose and cum_num % 50 == 0:
            print('%d examples so far ...' % cum_num)
        decode_results.append(exg_decode_results)

    return decode_results