From 9e2ff35fa13aeccd7a9331cd6739b56901cc6bab Mon Sep 17 00:00:00 2001 From: julianser Date: Sun, 20 Mar 2016 10:04:08 -0400 Subject: [PATCH] Improved script for generating context/response pairs. --- create-text-file-for-tests.py | 98 ++++++++++++++++++++++++----------- 1 file changed, 67 insertions(+), 31 deletions(-) diff --git a/create-text-file-for-tests.py b/create-text-file-for-tests.py index c3e8f01..4b87c53 100644 --- a/create-text-file-for-tests.py +++ b/create-text-file-for-tests.py @@ -13,6 +13,8 @@ python create-text-file-for-tests.py Test_SplitByDialogues.dialogues.pkl --utterances_to_predict 2 --max_words_in_context 300 +NOTE: It's better to use the original dialogues in plain text for building the context/response pairs, since we can then avoid unknown tokens! + @author Iulian Vlad Serban """ @@ -23,7 +25,6 @@ import logging import time import sys -import search import collections import string @@ -52,7 +53,7 @@ def parse_args(): help="Path to the model prefix (without _model.npz or _state.pkl)") parser.add_argument("test_file", - help="Path to the test file (pickled list, with one dialogue per entry)") + help="Path to the test file (pickled list, with one dialogue per entry; or plain text file with one dialogue per line)") parser.add_argument("--utterances_to_predict", type=int, default=1, @@ -76,51 +77,86 @@ def main(): # Load dictionary # Load dictionaries to convert str to idx and vice-versa - #raw_dict = cPickle.load(open(state['dictionary'], 'r')) - raw_dict = cPickle.load(open('../Data/Dataset.dict.pkl', 'r')) # HACK + raw_dict = cPickle.load(open(state['dictionary'], 'r')) str_to_idx = dict([(tok, tok_id) for tok, tok_id, _, _ in raw_dict]) idx_to_str = dict([(tok_id, tok) for tok, tok_id, freq, _ in raw_dict]) - - test_dialogues = cPickle.load(open(args.test_file, 'r')) - assert args.utterances_to_predict > 0 - utterances_to_predict = args.utterances_to_predict - + assert len(args.test_file) > 3 test_contexts = '' test_responses = '' + utterances_to_predict = args.utterances_to_predict + assert args.utterances_to_predict > 0 + + # Is it a pickle file? Then process using model dictionaries.. + if args.test_file[len(args.test_file)-4:len(args.test_file)] == '.pkl': + test_dialogues = cPickle.load(open(args.test_file, 'r')) + for test_dialogueid,test_dialogue in enumerate(test_dialogues): + if test_dialogueid % 100 == 0: + print 'test_dialogue', test_dialogueid + + utterances = [] + current_utterance = [] + for word in test_dialogue: + current_utterance += [word] + if word == state['eos_sym']: + utterances += [current_utterance] + current_utterance = [] + + + + context_utterances = [] + prediction_utterances = [] + for utteranceid, utterance in enumerate(utterances): + if utteranceid >= len(utterances) - utterances_to_predict: + prediction_utterances += utterance + else: + context_utterances += utterance + + if args.max_words_in_context > 0: + while len(context_utterances) > args.max_words_in_context: + del context_utterances[0] + + + test_contexts += indices_to_words(idx_to_str, context_utterances) + '\n' + test_responses += indices_to_words(idx_to_str, prediction_utterances) + '\n' + else: # Assume it's a text file - for test_dialogueid,test_dialogue in enumerate(test_dialogues): - if test_dialogueid % 100 == 0: - print 'test_dialogue', test_dialogueid + test_dialogues = [[]] + lines = open(args.test_file, "r").readlines() + if len(lines): + test_dialogues = [x.strip() for x in lines] - utterances = [] - current_utterance = [] - for word in test_dialogue: - current_utterance += [word] - if word == state['eos_sym']: - utterances += [current_utterance] - current_utterance = [] + for test_dialogueid,test_dialogue in enumerate(test_dialogues): + if test_dialogueid % 100 == 0: + print 'test_dialogue', test_dialogueid + utterances = [] + current_utterance = [] + for word in test_dialogue.split(): + current_utterance += [word] + if word == state['end_sym_sentence']: + utterances += [current_utterance] + current_utterance = [] + context_utterances = [] + prediction_utterances = [] + for utteranceid, utterance in enumerate(utterances): + if utteranceid >= len(utterances) - utterances_to_predict: + prediction_utterances += utterance + else: + context_utterances += utterance - context_utterances = [] - prediction_utterances = [] - for utteranceid, utterance in enumerate(utterances): - if utteranceid >= len(utterances) - utterances_to_predict: - prediction_utterances += utterance - else: - context_utterances += utterance + if args.max_words_in_context > 0: + while len(context_utterances) > args.max_words_in_context: + del context_utterances[0] - if args.max_words_in_context > 0: - while len(context_utterances) > args.max_words_in_context: - del context_utterances[0] + test_contexts += ' '.join(context_utterances) + '\n' + test_responses += ' '.join(prediction_utterances) + '\n' - test_contexts += indices_to_words(idx_to_str, context_utterances) + '\n' - test_responses += indices_to_words(idx_to_str, prediction_utterances) + '\n' print('Writing to files...') f = open('test_contexts.txt','w')