-
Notifications
You must be signed in to change notification settings - Fork 70
/
data.lua
68 lines (59 loc) · 1.92 KB
/
data.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
--
---- Copyright (c) 2014, Facebook, Inc.
---- All rights reserved.
----
---- This source code is licensed under the Apache 2 license found in the
---- LICENSE file in the root directory of this source tree.
----
local stringx = require('pl.stringx')
local file = require('pl.file')
-- Directory containing ptb.train.txt, ptb.valid.txt and ptb.test.txt.
local ptb_path = "./data/"
-- Vocabulary shared by every load_data call: vocab_map maps a token string to
-- its integer id, and vocab_idx is the last id handed out (ids start at 1).
local vocab_idx, vocab_map = 0, {}
-- Splits the 1-D sequence x_inp into batch_size contiguous chunks and stores
-- chunk i as column i of the result. Every chunk has
-- rows = floor(s / batch_size) elements (s = x_inp:size(1)), and chunk i
-- starts at position round((i - 1) * s / batch_size) + 1, so the chunks are
-- spread evenly over the whole sequence. The s mod batch_size tokens that do
-- not fit fall in the gaps between chunks (or at the end) and are skipped.
--
-- @param x_inp       1-D tensor (e.g. the token ids returned by load_data).
-- @param batch_size  number of columns; must satisfy 1 <= batch_size <= s.
-- @return            rows x batch_size tensor.
local function replicate(x_inp, batch_size)
  local s = x_inp:size(1)
  -- Fail early with a clear message: batch_size == 0 would divide by zero and
  -- batch_size > s would request a matrix with zero rows.
  assert(type(batch_size) == "number" and batch_size >= 1 and batch_size <= s,
         string.format("replicate: batch_size must be in [1, %d], got %s",
                       s, tostring(batch_size)))
  local rows = math.floor(s / batch_size)
  local x = torch.zeros(rows, batch_size)
  for i = 1, batch_size do
    -- The offset is non-negative, so adding 0.5 before flooring rounds it to
    -- the nearest integer (same result as the previous torch.round).
    local start = math.floor((i - 1) * s / batch_size + 0.5) + 1
    local finish = start + rows - 1
    x:sub(1, rows, i, i):copy(x_inp:sub(start, finish))
  end
  return x
end
-- Reads a whitespace-separated text file and converts it into a 1-D tensor of
-- token ids. Every newline is turned into an explicit '<eos>' token.
-- Ids come from the module-level vocabulary (vocab_map / vocab_idx), which
-- persists across calls. Loading several files therefore yields consistent
-- ids, and a token not seen before gets the next free id.
--
-- @param fname  path of the text file to load.
-- @return       1-D tensor holding one id per token.
-- @raise        error carrying file.read's message if the file cannot be read.
local function load_data(fname)
  -- file.read returns nil plus an error message on failure; surface it here
  -- instead of failing later inside stringx.replace.
  local data = assert(file.read(fname))
  -- Pad the marker with spaces so it always splits off as its own token, even
  -- when a line does not end (or the next one does not start) with whitespace.
  -- Otherwise "word\nnext" would become the single token "word<eos>next".
  data = stringx.replace(data, '\n', ' <eos> ')
  data = stringx.split(data)
  print(string.format("Loading %s, size of data = %d", fname, #data))
  local x = torch.zeros(#data)
  for i = 1, #data do
    local token = data[i]
    local id = vocab_map[token]
    if id == nil then
      vocab_idx = vocab_idx + 1
      id = vocab_idx
      vocab_map[token] = id
    end
    x[i] = id
  end
  return x
end
-- Loads the PTB training split and lays it out with replicate as a
-- matrix with batch_size columns, one contiguous chunk of the text per column.
local function traindataset(batch_size)
  local path = ptb_path .. "ptb.train.txt"
  return replicate(load_data(path), batch_size)
end
-- Intentionally we repeat dimensions without offseting.
-- Pass over this batch corresponds to the fully sequential processing.
local function testdataset(batch_size)
local x = load_data(ptb_path .. "ptb.test.txt")
x = x:resize(x:size(1), 1):expand(x:size(1), batch_size)
return x
end
-- Loads the PTB validation split and lays it out with replicate as a
-- matrix with batch_size columns, one contiguous chunk of the text per column.
local function validdataset(batch_size)
  local path = ptb_path .. "ptb.valid.txt"
  return replicate(load_data(path), batch_size)
end
-- Public interface of the module: one constructor per PTB split.
return {
  traindataset = traindataset,
  validdataset = validdataset,
  testdataset = testdataset,
}