forked from lukalabs/cakechat
-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_predictions.py
174 lines (140 loc) · 7.36 KB
/
generate_predictions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
import argparse
import os
import sys
# Append the directory two levels above this file to sys.path before the
# `cakechat` imports below are executed.
# NOTE(review): presumably this lets the script run directly from a checkout
# without installing the package — confirm.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from cakechat.utils.env import init_cuda_env
# NOTE(review): init_cuda_env() is deliberately called before the remaining
# cakechat imports — presumably the CUDA environment has to be configured
# before the model modules are loaded; confirm before reordering these lines.
init_cuda_env()
from cakechat.dialog_model.factory import get_reverse_model
from cakechat.dialog_model.model import CakeChatModel
from cakechat.dialog_model.model_utils import transform_contexts_to_token_ids, lines_to_context
from cakechat.dialog_model.quality import log_predictions, calculate_and_log_val_metrics
from cakechat.utils.files_utils import is_non_empty_file
from cakechat.utils.logger import get_tools_logger
from cakechat.utils.data_types import ModelParam
from cakechat.utils.dataset_loader import get_tokenized_test_lines, load_context_free_val, \
load_context_sensitive_val, get_validation_data_id, get_validation_sets_names
from cakechat.utils.text_processing import get_index_to_token_path, load_index_to_item, get_index_to_condition_path
from cakechat.utils.w2v.model import get_w2v_model_id
from cakechat.config import BASE_CORPUS_NAME, QUESTIONS_CORPUS_NAME, INPUT_SEQUENCE_LENGTH, INPUT_CONTEXT_SIZE, \
PREDICTION_MODES, PREDICTION_MODE_FOR_TESTS, RESULTS_PATH, DEFAULT_TEMPERATURE, TRAIN_CORPUS_NAME, \
USE_PRETRAINED_W2V_EMBEDDINGS_LAYER
# Module-level logger for this script, obtained from get_tools_logger with this file's path.
_logger = get_tools_logger(__file__)
def _save_test_results(test_dataset, predictions_filename, nn_model, prediction_mode, **kwargs):
    """
    Log validation metrics for the model, then log its predictions for the test set.

    :param test_dataset: test lines; turned into contexts with lines_to_context
    :param predictions_filename: file name handed to log_predictions
    :param nn_model: model object exposing token_to_index and condition_to_index
    :param prediction_mode: mode used for the metrics and as the only entry of prediction_modes
    :param kwargs: extra prediction parameters forwarded unchanged to log_predictions
    """
    # Validation metrics come first; n-gram distance calculation is switched off.
    val_with_context = load_context_sensitive_val(nn_model.token_to_index, nn_model.condition_to_index)
    val_without_context = load_context_free_val(nn_model.token_to_index)
    calculate_and_log_val_metrics(
        nn_model,
        val_with_context,
        val_without_context,
        prediction_mode,
        calculate_ngram_distance=False)

    # Convert the test lines into contexts and then into token ids.
    test_contexts = list(lines_to_context(test_dataset))
    test_context_ids = transform_contexts_to_token_ids(
        test_contexts, nn_model.token_to_index, INPUT_SEQUENCE_LENGTH, INPUT_CONTEXT_SIZE)

    log_predictions(
        predictions_filename,
        test_context_ids,
        nn_model,
        prediction_modes=[prediction_mode],
        **kwargs)
def _build_prediction_params(reverse_model_weights, temperatures):
    """Return one kwargs dict per (reverse model weight, temperature) combination.

    The weight key is added only when reverse_model_weights is non-empty; the
    temperature key is always added. Weights vary in the outer loop and
    temperatures in the inner one.
    """
    prediction_params = [dict()]
    if reverse_model_weights:
        prediction_params = [
            dict(params, mmi_reverse_model_score_weight=weight)
            for params in prediction_params
            for weight in reverse_model_weights
        ]
    return [dict(params, temperature=t) for params in prediction_params for t in temperatures]


def _build_predictions_paths(base_path, prediction_params):
    """Return one '.tsv' output path per parameters combination.

    A '<key>_<value>' suffix is added only when there is more than one
    combination; a single combination is written to '<base_path>.tsv'.
    """
    if len(prediction_params) <= 1:
        return [base_path + '.tsv']
    return [
        '{base_path}_{params_str}.tsv'.format(
            base_path=base_path,
            params_str='_'.join('{}_{}'.format(k, v) for k, v in cur_params.items()))
        for cur_params in prediction_params
    ]


def predict(model_path,
            tokens_index_path=None,
            conditions_index_path=None,
            default_predictions_path=None,
            reverse_model_weights=None,
            temperatures=None,
            prediction_mode=None):
    """
    Load a model and log its predictions on the questions corpus for every parameters combination.

    :param model_path: passed to CakeChatModel as model_init_path
    :param tokens_index_path: index_to_token file; defaults to get_index_to_token_path(BASE_CORPUS_NAME)
    :param conditions_index_path: index_to_condition file;
        defaults to get_index_to_condition_path(BASE_CORPUS_NAME)
    :param default_predictions_path: output path without extension ('.tsv' is appended);
        defaults to RESULTS_PATH/results/predictions_<model name>
    :param reverse_model_weights: optional list of values for mmi_reverse_model_score_weight
    :param temperatures: list of temperature values; defaults to [DEFAULT_TEMPERATURE]
    :param prediction_mode: prediction mode; defaults to PREDICTION_MODE_FOR_TESTS
    :return: None. Logs a warning and returns early when the tokens index file
        does not pass is_non_empty_file.
    """
    if not tokens_index_path:
        tokens_index_path = get_index_to_token_path(BASE_CORPUS_NAME)
    if not conditions_index_path:
        conditions_index_path = get_index_to_condition_path(BASE_CORPUS_NAME)
    if not temperatures:
        temperatures = [DEFAULT_TEMPERATURE]
    if not prediction_mode:
        prediction_mode = PREDICTION_MODE_FOR_TESTS

    # Construct list of parameters values for all possible combinations of passed parameters.
    # temperatures is never empty here because of the default above.
    prediction_params = _build_prediction_params(reverse_model_weights, temperatures)

    if not is_non_empty_file(tokens_index_path):
        _logger.warning('Couldn\'t find tokens_index file:\n{}. \nExiting...'.format(tokens_index_path))
        return

    index_to_token = load_index_to_item(tokens_index_path)
    # NOTE(review): conditions_index_path is loaded without the is_non_empty_file
    # check applied to tokens_index_path above — confirm whether that is intended.
    index_to_condition = load_index_to_item(conditions_index_path)

    w2v_model_id = get_w2v_model_id() if USE_PRETRAINED_W2V_EMBEDDINGS_LAYER else None

    nn_model = CakeChatModel(
        index_to_token,
        index_to_condition,
        training_data_param=ModelParam(value=None, id=TRAIN_CORPUS_NAME),
        validation_data_param=ModelParam(value=None, id=get_validation_data_id(get_validation_sets_names())),
        w2v_model_param=ModelParam(value=None, id=w2v_model_id),
        model_init_path=model_path,
        reverse_model=get_reverse_model(prediction_mode))

    nn_model.init_model()
    nn_model.resolve_model()

    if not default_predictions_path:
        default_predictions_path = os.path.join(RESULTS_PATH, 'results', 'predictions_' + nn_model.model_name)

    # Get path for each combination of parameters
    predictions_paths = _build_predictions_paths(default_predictions_path, prediction_params)

    _logger.info('Model for prediction: {}'.format(nn_model.model_path))
    _logger.info('Tokens index: {}'.format(tokens_index_path))
    _logger.info('File with questions: {}'.format(QUESTIONS_CORPUS_NAME))
    _logger.info('Files to dump responses: {}'.format('\n'.join(predictions_paths)))
    _logger.info('Prediction parameters {}'.format('\n'.join([str(x) for x in prediction_params])))

    processed_test_set = get_tokenized_test_lines(QUESTIONS_CORPUS_NAME, set(index_to_token.values()))

    for cur_params, cur_path in zip(prediction_params, predictions_paths):
        _logger.info('Predicting with the following params: {}'.format(cur_params))
        _save_test_results(processed_test_set, cur_path, nn_model, prediction_mode, **cur_params)
def parse_args():
    """
    Parse and validate the command line arguments of the script.

    :return: argparse.Namespace with prediction_mode, model, tokens_index, conditions_index,
        output, reverse_model_weights and temperatures; every option defaults to None
    :raises Exception: if --reverse-model-weights is given while the effective prediction
        mode is not one of the *-reranking modes
    """
    argparser = argparse.ArgumentParser()

    argparser.add_argument(
        '-p', '--prediction-mode', action='store', help='Prediction mode', choices=PREDICTION_MODES, default=None)

    argparser.add_argument(
        '-m',
        '--model',
        action='store',
        help='Path to the file with your model. '
        'Be careful, model parameters are inferred from config, not from the filename',
        default=None)

    argparser.add_argument(
        '-i',
        '--tokens_index',
        action='store',
        help='Path to the json file with index_to_token dictionary.',
        default=None)

    argparser.add_argument(
        '-c',
        '--conditions_index',
        action='store',
        help='Path to the json file with index_to_condition dictionary.',
        default=None)

    # The two help literals are concatenated, so the first one needs a trailing space.
    argparser.add_argument(
        '-o',
        '--output',
        action='store',
        help='Path to the file to dump predictions. '
        'Be careful, file extension ".tsv" is appended to the filename automatically',
        default=None)

    argparser.add_argument(
        '-r',
        '--reverse-model-weights',
        action='append',
        type=float,
        help='Reverse model score weight for prediction with MMI-reranking objective. Used only in *-reranking modes',
        default=None)

    argparser.add_argument('-t', '--temperatures', action='append', help='temperature values', default=None, type=float)

    args = argparser.parse_args()

    # Extra params validation.
    # When -p is omitted, predict() falls back to PREDICTION_MODE_FOR_TESTS, so the
    # check is done against that effective mode instead of the raw None value.
    # args.prediction_mode itself is left untouched.
    reranking_modes = [PREDICTION_MODES.beamsearch_reranking, PREDICTION_MODES.sampling_reranking]
    effective_mode = args.prediction_mode or PREDICTION_MODE_FOR_TESTS
    if args.reverse_model_weights and effective_mode not in reranking_modes:
        raise Exception('--reverse-model-weights param can be specified only for *-reranking prediction modes.')

    return args
if __name__ == '__main__':
    # Script entry point: the first four options are passed to predict() positionally,
    # whatever is left in the dict after popping them is forwarded as keyword arguments.
    cli_options = vars(parse_args())
    cli_model = cli_options.pop('model')
    cli_tokens_index = cli_options.pop('tokens_index')
    cli_conditions_index = cli_options.pop('conditions_index')
    cli_output = cli_options.pop('output')
    predict(cli_model, cli_tokens_index, cli_conditions_index, cli_output, **cli_options)