-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpreprocess.jl
185 lines (158 loc) · 6.1 KB
/
preprocess.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Sequential preprocessor for large text files.
# Groups sequences of equal length so that iterating over the resulting Data object yields batches of same-length sequences.
import Base: start, next, done
const SOS = "<s>"
const EOS = "</s>"
const UNK = "<unk>"
const LLIMIT = 2
# Container for a preprocessed corpus: the vocabulary maps plus the corpus
# sentences bucketed by length so same-length batches can be served.
type Data
    word_to_index::Dict{AbstractString, Int}   # token => integer id (ids 1-3 are <s>, </s>, <unk> when the vocab is built by the constructor below)
    index_to_word::Vector{AbstractString}      # inverse mapping id => token
    batchsize::Int                             # sentences per served batch
    serve_type::AbstractString                 # batch representation; currently only "bitarray" is honored
    sequences::Dict{Int64, Array{Any, 1}}      # sentence length => all sentences of that length (concrete type: the constructor always builds exactly this Dict)
end
"""
    Data(datafile; word_to_index=nothing, vocabfile=nothing, serve_type="bitarray", batchsize=20)

Read `datafile` line by line, map every token to an integer id, wrap each
sentence in <s>/</s> markers, and bucket the resulting id sequences by their
length (marker tokens included).

- `word_to_index`: pass an existing vocabulary to freeze it — unseen tokens
  then map to the <unk> id. When `nothing`, a fresh vocabulary is grown.
- `vocabfile`: optional closed-vocabulary file (see `vocab_from_file`); while
  growing the vocabulary, tokens outside it are replaced by <unk>.
"""
function Data(datafile; word_to_index=nothing, vocabfile=nothing, serve_type="bitarray", batchsize=20)
    # TODO: right now Data type can only be a bit array because of Knet support, no need to following line
    # (serve_type == nothing) && error("Please specify the data serve type: onehot, bitarray or sequence")
    existing_vocab = (word_to_index != nothing)
    if !existing_vocab
        # ids 1-3 are reserved for the sentence markers and the unknown token
        word_to_index = Dict{AbstractString, Int}(SOS=>1, EOS=>2, UNK=>3)
        vocabfile != nothing && info("Working with provided vocabfile : $vocabfile")
        vocabfile != nothing && (V = vocab_from_file(vocabfile))
    end
    stream = open(datafile)
    # sentence length => all sentences of that length, each a Vector{Int32} of ids
    sequences = Dict{Int64, Array{Any, 1}}()
    for line in eachline(stream)
        words = Int32[]
        push!(words, word_to_index[SOS])
        for word in split(line)
            # with a closed vocabfile, out-of-vocabulary tokens collapse to <unk>
            if !existing_vocab && vocabfile != nothing && !(word in V)
                word = UNK
            end
            if existing_vocab
                # frozen vocabulary: fall back to the <unk> id for unseen words
                index = get(word_to_index, word, word_to_index[UNK])
            else
                # growing vocabulary: assign the next free id on first sight
                index = get!(word_to_index, word, 1+length(word_to_index))
            end
            push!(words, index)
        end
        push!(words, word_to_index[EOS])
        skey = length(words)
        # length == LLIMIT means an empty sentence (only <s> and </s> were pushed)
        (skey == LLIMIT) && println("LLMIT needs to be checked") # because we already put <s> and </s> tokens
        (!haskey(sequences, skey)) && (sequences[skey] = Any[])
        push!(sequences[skey], words)
    end
    close(stream)
    vocabsize = length(word_to_index)
    # build the inverse mapping id => word
    index_to_word = Array(AbstractString, vocabsize)
    for (word, index) in word_to_index
        index_to_word[index] = word
    end
    Data(word_to_index, index_to_word, batchsize, serve_type, sequences)
end
"""
    sentenbatch(nom, from, batchsize, vocabsize; serve_type="bitarr")

Slice one batch of `batchsize` sentences out of `nom` starting at index
`from` (all sentences in `nom` share the same length) and encode it as a
vector of `seqlen` bit matrices of size `batchsize` x `vocabsize`, where
entry `[row, index]` is set when sentence `row` carries word id `index` at
that time step.

Returns `(data, new_from)`: `new_from` is the start index of the next batch,
or 1 after the final batch (signalling wrap-around to the iterator).
Returns `(nothing, 1)` when fewer than `batchsize` sentences remain — this
should never happen once `clean_seqdict!` has removed the surplus.
"""
function sentenbatch(nom::Array{Any,1}, from::Int, batchsize::Int, vocabsize::Int; serve_type="bitarr")
    total = length(nom)
    to = (from + batchsize - 1 < total) ? (from + batchsize - 1) : total
    # not to work with surplus sentences
    if (to - from + 1 < batchsize)
        # BUGFIX: the original called `warning(...)`, which is not defined in
        # Julia and would raise UndefVarError; `warn` emits the intended warning.
        warn("Surplus was not cleaned correctly!")
        return (nothing, 1)
    end
    new_from = (to == total) ? 1 : (to + 1)
    seqlen = length(nom[1]) # TODO: get rid of length computation give it as an extra argument!
    sentences = nom[from:to]
    # If Knet supports lookup that part will work correctly, it is commented for speed purposes
    #if serve_type == "lookup"
    #    return (sentences, new_from)
    #end
    scount = batchsize # modified future code
    # one bit matrix per time step; a falses() matrix starts all-clear
    data = [ falses(scount, vocabsize) for i=1:seqlen ]
    for cursor=1:seqlen
        for row=1:scount
            index = sentences[row][cursor]
            data[cursor][row, index] = 1
        end
    end
    return (data, new_from)
end
"""Randomly drops sentences from each length bucket until every bucket size
is a multiple of `batchsize`; empty buckets are removed entirely."""
function clean_seqdict!(seqdict::Dict{Int64,Array{Any,1}}, batchsize::Int)
    for slen in collect(keys(seqdict))
        bucket = seqdict[slen]
        surplus = rem(length(bucket), batchsize)
        # delete `surplus` randomly chosen sentences, one at a time
        for _ in 1:surplus
            deleteat!(bucket, rand(1:length(bucket)))
        end
        isempty(bucket) && delete!(seqdict, slen)
    end
end
# Iterator setup: snapshot the sequence buckets, drop surplus sentences,
# shuffle the sentence lengths, and pick the first length to serve.
# State tuple: (buckets, current length, remaining lengths, cursor, vocabsize);
# cursor == nothing marks "no batch served yet".
function start(s::Data)
    pool = deepcopy(s.sequences)
    clean_seqdict!(pool, s.batchsize)
    @assert (!isempty(pool)) "There is not enough data with that batchsize $(s.batchsize)"
    remaining = shuffle!(collect(keys(pool)))
    current = pop!(remaining)
    return (pool, current, remaining, nothing, length(s.word_to_index))
end
# Iterator step: serve the next batch of the current sentence length.
# cursor == nothing means the very first batch (start at index 1);
# cursor == 1 means the previous bucket wrapped around, so advance to the
# next sentence length before slicing.
function next(s::Data, state)
    (pool, curlen, remaining, cursor, vsize) = state
    if cursor == nothing
        cursor = 1
    elseif cursor == 1
        curlen = pop!(remaining)
    end
    (item, nextcursor) = sentenbatch(pool[curlen], cursor, s.batchsize, vsize)
    return (item, (pool, curlen, remaining, nextcursor, vsize))
end
# Iteration ends once no sentence lengths remain AND the last served batch
# wrapped its bucket around (cursor back at 1).
function done(s::Data, state)
    (pool, curlen, remaining, cursor, vsize) = state
    return isempty(remaining) && (cursor == 1)
end
"""Reads `vocabfile` into a `Set` of vocabulary words.

With `sorted_counted=false` (default) each line holds the word in its first
whitespace-separated field; with `sorted_counted=true` the file is raw
create_vocab.sh output, i.e. "count word" pairs, and the second field is taken.
"""
function vocab_from_file(vocabfile; sorted_counted=false)
    vocab = Set{AbstractString}()
    open(vocabfile) do file
        for raw in eachline(file)
            tokens = split(raw)
            if sorted_counted
                length(tokens) > 1 && push!(vocab, tokens[2])
            else
                !isempty(tokens) && push!(vocab, tokens[1])
            end
        end
    end
    return vocab
end
""" Decodes the kth sentence of a batch back into words: for each time-step
bit matrix, the set column in row `kth` is the word id at that position. """
function ibuild_sentence(tdata::Data, sequence::Array{BitArray{2},1}, kth::Int)
    ids = Any[]
    for step in sequence
        # exactly one bit per row is set by sentenbatch; find returns its column
        append!(ids, find(x -> x == true, step[kth, :]))
    end
    return map(i -> tdata.index_to_word[i], ids)
end
# FUTURE CODE: if one day knet8 allows us to change batchsize on the fly, following lines will implement surplus batch implementation, this code snippet would be put on sentenbatch
# (length(sentences) != batchsize) && (println("I am using the surplus sentences:) $from : $to"))
# scount = length(sentences) # it can be either batchsize or the surplus sentences
# data = [ falses(scount, vocabsize) for i=1:seqlen ]
# for cursor=1:seqlen
# for row=1:scount
# index = sentences[row][cursor]
# data[cursor][row, index] = 1
# end
# end