Skip to content

Commit

Permalink
Merge pull request #14 to fix inference issue
Browse files Browse the repository at this point in the history
Small hacks to fix #13, which is caused by the lack of tags during the inference stage
  • Loading branch information
kylase authored Jan 4, 2019
2 parents 188721d + 9f59121 commit dca0646
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 4 deletions.
8 changes: 7 additions & 1 deletion loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,13 @@ def f(x):
chars = [[char_to_id[c] for c in w if c in char_to_id]
for w in str_words]
caps = [cap_feature(w) for w in str_words]
tags = [tag_to_id[w[-1]] for w in s]

# Hack: This is for an inference stage where tag_to_id is not necessary
if tag_to_id:
tags = [tag_to_id[w[-1]] for w in s]
else:
tags = tag_to_id

data.append({
'str_words': str_words,
'words': words,
Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
file.write('\n'.join(string.split()) + '\n')
file.close()
test_sentences = load_sentences(test_file, lower, zeros)
data = prepare_dataset(test_sentences, word_to_id, char_to_id, lower, True)
data = prepare_dataset(test_sentences, word_to_id, char_to_id, {}, lower, True)

for citation in data:
inputs = create_input(citation, model.parameters, False)
Expand Down
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@
# Train network
#
singletons = set([word_to_id[k] for k, v in dico_words_train.items() if v == 1])
n_epochs = 10 # number of epochs over the training set
n_epochs = 1 # number of epochs over the training set
freq_eval = 1000 # evaluate on dev every freq_eval steps
best_dev = -np.inf
best_test = -np.inf
Expand Down
3 changes: 2 additions & 1 deletion utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import codecs
import numpy as np
import theano
from sklearn import metrics

models_path = "./models"
eval_path = "./evaluation"
Expand Down Expand Up @@ -223,6 +222,8 @@ def evaluate(parameters, f_eval, raw_sentences, parsed_sentences,
"""
Evaluate current model using CoNLL script.
"""
# Make sklearn import at runtime only
from sklearn import metrics
results = {'real': [], 'predicted': []}

for _, data in zip(raw_sentences, parsed_sentences):
Expand Down

0 comments on commit dca0646

Please sign in to comment.