-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpredict.py
51 lines (45 loc) · 2.33 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import os, argparse, json
from models.rnn import RNN
from util import clean_text
def _parse_bool(value):
    """Convert a command-line string into a real boolean.

    ``type=bool`` cannot be used for this: ``bool("False")`` is ``True``
    because every non-empty string is truthy, so ``--random False`` would
    silently turn the option on.

    Args:
        value: The raw string taken from the command line (a ``bool`` is
            passed through unchanged).

    Returns:
        ``True`` or ``False``.

    Raises:
        argparse.ArgumentTypeError: If *value* is not a recognised spelling
            of a boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "on", "1"):
        return True
    # The empty string stays falsy, as it was with ``type=bool``.
    if normalized in ("false", "f", "no", "n", "off", "0", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


# Directory containing this script.
# NOTE: the name shadows the builtin ``dir()``; it is kept because later
# statements in this script read it.
dir = os.path.dirname(os.path.realpath(__file__))

# Command-line interface. The boolean flags may be given bare (``--random``)
# or with an explicit value (``--random false``).
parser = argparse.ArgumentParser()
parser.add_argument("--debug", nargs="?", const=True, default=False, type=_parse_bool, help="Debug mode (default: %(default)s)")
parser.add_argument("--model_dir", default="rnn", type=str, help="The trained RNN dir (default: %(default)s)")
parser.add_argument("--inputs", default="Un iench", type=str, help="Choose the beginning of the predicted sentence (see this as an initialization)")
parser.add_argument("--random", nargs="?", const=True, default=False, type=_parse_bool, help="Add some randomness, predict using temperature (default: %(default)s)")
parser.add_argument("--temperature", default=1., type=float, help="The temperature for predictions (default: %(default)s)")
parser.add_argument("--top_k", default=1, type=int, help="Return the top K prediction (default: %(default)s)")
parser.add_argument("--nb_word", default=-1, type=int, help="How many words should it return (default: %(default)s, -1: no limit)")
parser.add_argument("--nb_sentence", default=-1, type=int, help="How many lines should it return (default: %(default)s, -1: no limit)")
parser.add_argument("--nb_para", default=1, type=int, help="How many paragraph should it return (default: %(default)s, -1: no limit)")
# parser.add_argument('--use_server', nargs="?", const=True, default=False, type=_parse_bool, help='Should use the Server architecture')
# Read the command line; everything below works from the resulting namespace.
args = parser.parse_args()

# Paths are built relative to the directory this script lives in.
results_dir = '/'.join((dir, 'results'))
rnn_dir = '/'.join((dir, args.model_dir))

# vars() returns the namespace's own attribute dict, so the keys added here
# are also visible as attributes on ``args``.
config = vars(args)
config.update({
    'log_dir': rnn_dir,
    'restore_embedding': False,
    'seq_length': None,
})

# Run the seed text through the project's clean_text helper before prediction.
input_words = clean_text(args.inputs)

# if args.use_server is True:
# with open('clusterSpec.json') as f:
# clusterSpec = json.load(f)
# config['target'] = 'grpc://' + clusterSpec['server'][0]
# pass

# Build the model from the config and ask it for a prediction; ``y`` is
# joined with spaces when the result is printed.
rnn = RNN(config)
y = rnn.predict(input_words, config)
# Emit the result as one JSON document framed by two marker lines.
print('__BBB_START__')  # Marker for the Regexp used in the App, do not remove
# Bound to ``payload`` rather than ``json`` so the imported ``json`` module is
# not replaced by a string for the rest of the script.
payload = json.dumps({
    # Echo the generation settings so the consumer can see how the text was made.
    'config': {
        'inputs': args.inputs,
        'random': args.random,
        'temperature': args.temperature,
        'top_k': args.top_k,
        'nb_word': args.nb_word,
        'nb_sentence': args.nb_sentence,
        'nb_para': args.nb_para,
    },
    # The predicted items, joined into a single space-separated string.
    'output': ' '.join(y),
})
print(payload)
print('__BBB_END__')