Skip to content

Commit

Permalink
Improved script for generating context/response pairs.
Browse files Browse the repository at this point in the history
  • Loading branch information
julianser committed Mar 20, 2016
1 parent e66f9e8 commit 9e2ff35
Showing 1 changed file with 67 additions and 31 deletions.
98 changes: 67 additions & 31 deletions create-text-file-for-tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
python create-text-file-for-tests.py <model> Test_SplitByDialogues.dialogues.pkl --utterances_to_predict 2 --max_words_in_context 300
NOTE: It is better to build the context/response pairs from the original plain-text dialogues, since this avoids unknown tokens!
@author Iulian Vlad Serban
"""

Expand All @@ -23,7 +25,6 @@
import logging
import time
import sys
import search

import collections
import string
Expand Down Expand Up @@ -52,7 +53,7 @@ def parse_args():
help="Path to the model prefix (without _model.npz or _state.pkl)")

parser.add_argument("test_file",
help="Path to the test file (pickled list, with one dialogue per entry)")
help="Path to the test file (pickled list, with one dialogue per entry; or plain text file with one dialogue per line)")

parser.add_argument("--utterances_to_predict",
type=int, default=1,
Expand All @@ -76,51 +77,86 @@ def main():
# Load dictionary

# Load dictionaries to convert str to idx and vice-versa
#raw_dict = cPickle.load(open(state['dictionary'], 'r'))
raw_dict = cPickle.load(open('../Data/Dataset.dict.pkl', 'r')) # HACK
raw_dict = cPickle.load(open(state['dictionary'], 'r'))

str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict])
idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq, _ in raw_dict])



test_dialogues = cPickle.load(open(args.test_file, 'r'))
assert args.utterances_to_predict > 0
utterances_to_predict = args.utterances_to_predict

assert len(args.test_file) > 3
test_contexts = ''
test_responses = ''
utterances_to_predict = args.utterances_to_predict
assert args.utterances_to_predict > 0

# Is it a pickle file? Then process using model dictionaries..
if args.test_file[len(args.test_file)-4:len(args.test_file)] == '.pkl':
test_dialogues = cPickle.load(open(args.test_file, 'r'))
for test_dialogueid,test_dialogue in enumerate(test_dialogues):
if test_dialogueid % 100 == 0:
print 'test_dialogue', test_dialogueid

utterances = []
current_utterance = []
for word in test_dialogue:
current_utterance += [word]
if word == state['eos_sym']:
utterances += [current_utterance]
current_utterance = []



context_utterances = []
prediction_utterances = []
for utteranceid, utterance in enumerate(utterances):
if utteranceid >= len(utterances) - utterances_to_predict:
prediction_utterances += utterance
else:
context_utterances += utterance

if args.max_words_in_context > 0:
while len(context_utterances) > args.max_words_in_context:
del context_utterances[0]


test_contexts += indices_to_words(idx_to_str, context_utterances) + '\n'
test_responses += indices_to_words(idx_to_str, prediction_utterances) + '\n'

else: # Assume it's a text file

for test_dialogueid,test_dialogue in enumerate(test_dialogues):
if test_dialogueid % 100 == 0:
print 'test_dialogue', test_dialogueid
test_dialogues = [[]]
lines = open(args.test_file, "r").readlines()
if len(lines):
test_dialogues = [x.strip() for x in lines]

utterances = []
current_utterance = []
for word in test_dialogue:
current_utterance += [word]
if word == state['eos_sym']:
utterances += [current_utterance]
current_utterance = []
for test_dialogueid,test_dialogue in enumerate(test_dialogues):
if test_dialogueid % 100 == 0:
print 'test_dialogue', test_dialogueid

utterances = []
current_utterance = []
for word in test_dialogue.split():
current_utterance += [word]
if word == state['end_sym_sentence']:
utterances += [current_utterance]
current_utterance = []

context_utterances = []
prediction_utterances = []
for utteranceid, utterance in enumerate(utterances):
if utteranceid >= len(utterances) - utterances_to_predict:
prediction_utterances += utterance
else:
context_utterances += utterance

context_utterances = []
prediction_utterances = []
for utteranceid, utterance in enumerate(utterances):
if utteranceid >= len(utterances) - utterances_to_predict:
prediction_utterances += utterance
else:
context_utterances += utterance
if args.max_words_in_context > 0:
while len(context_utterances) > args.max_words_in_context:
del context_utterances[0]

if args.max_words_in_context > 0:
while len(context_utterances) > args.max_words_in_context:
del context_utterances[0]

test_contexts += ' '.join(context_utterances) + '\n'
test_responses += ' '.join(prediction_utterances) + '\n'

test_contexts += indices_to_words(idx_to_str, context_utterances) + '\n'
test_responses += indices_to_words(idx_to_str, prediction_utterances) + '\n'

print('Writing to files...')
f = open('test_contexts.txt','w')
Expand Down

0 comments on commit 9e2ff35

Please sign in to comment.