-- example_part1.lua
require 'rnn'
cmd = torch.CmdLine()
cmd:text()
cmd:text('Simple LSTM example for the RNN library')
cmd:text()
cmd:text('Options')
cmd:option('-use_saved',false,'Use previously saved inputs and trained network instead of new')
cmd:text()
-- parse input params
opt = cmd:parse(arg)
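-- Example invocation (assuming the standard Torch launcher `th`):
--   th example_part1.lua              -- generate fresh data and train from scratch
--   th example_part1.lua -use_saved   -- reuse training.t7, targets.t7 and trained-model.t7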
-- Keep the input layer small so the network trains and converges quickly
local inputSize = 10
-- Larger values here let the network fit more complex problems, but can also over-fit; 256 works well for now
local hiddenSize = 256
-- We want the network to classify the inputs using a one-hot representation of the outputs
local outputSize = 3
-- the dataset size is the total number of examples we want to present to the LSTM
local dsSize=200
-- We present the dataset to the network in batches where batchSize << dsSize
local batchSize=5
-- seqLength is the length of each sequence, i.e. the number of "events" we pass to the LSTM
-- to make up a single example. Ideally this would be dynamic for the YOOCHOOSE dataset.
local seqLength=5
-- the number of target classes or labels; it must match outputSize above,
-- or we get the dreaded "ClassNLLCriterion.lua:46: Assertion `cur_target >= 0 && cur_target < n_classes' failed."
local nClass = 3
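-- Taken together: the dataset is dsSize "steps", each step a batchSize x inputSize
-- mini-batch paired with a LongTensor of batchSize class labels in 1..nClass; during
-- training, seqLength consecutive steps form one sequence that the Sequencer-wrapped
-- LSTM processes in a single forward/backward pass.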
function build_data()
  local inputs = {}
  local targets = {}
  if opt.use_saved then
    -- Use previously created and saved data
    -- (the trained model itself is loaded in build_network below)
    inputs = torch.load('training.t7')
    targets = torch.load('targets.t7')
  else
    -- populate both tables to get ready for training
    for i = 1, dsSize do
      table.insert(inputs, torch.randn(batchSize, inputSize))
      table.insert(targets, torch.LongTensor(batchSize):random(1, nClass))
    end
  end
  return inputs, targets
end
function build_network(inputSize, hiddenSize, outputSize)
  if opt.use_saved then
    rnn = torch.load('trained-model.t7')
  else
    rnn = nn.Sequential()
      :add(nn.Linear(inputSize, hiddenSize))
      :add(nn.LSTM(hiddenSize, hiddenSize))
      :add(nn.LSTM(hiddenSize, hiddenSize))
      :add(nn.Linear(hiddenSize, outputSize))
      :add(nn.LogSoftMax())
    -- wrap this in a Sequencer so that we can forward/backward
    -- entire sequences of length seqLength at once
    rnn = nn.Sequencer(rnn)
  end
  return rnn
end
function save(inputs, targets, rnn)
  -- Save out the tensors we created and the model itself so we can load them back in
  -- when -use_saved is set
  torch.save('training.t7', inputs)
  torch.save('targets.t7', targets)
  torch.save('trained-model.t7', rnn)
end
-- two tables to hold the *full* dataset input and target tensors
local inputs, targets = build_data()
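-- Quick sanity check of the sizes described above; it only applies to freshly
-- generated data, since a saved dataset may have been built with different settings
if not opt.use_saved then
  assert(#inputs == dsSize and #targets == dsSize)
  assert(inputs[1]:size(1) == batchSize and inputs[1]:size(2) == inputSize)
  assert(targets[1]:size(1) == batchSize)
end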
local rnn = build_network(inputSize, hiddenSize, outputSize)
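-- Optional smoke test (an addition, not part of the original example): the Sequencer
-- consumes a table of seqLength tensors, each batchSize x inputSize, and returns a table
-- of seqLength log-probability tensors, each batchSize x outputSize. Uncomment to verify:
-- local testSeq = {}
-- for i = 1, seqLength do table.insert(testSeq, torch.randn(batchSize, inputSize)) end
-- local testOut = rnn:forward(testSeq)
-- assert(#testOut == seqLength and testOut[1]:size(1) == batchSize and testOut[1]:size(2) == outputSize)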
-- Decorate the regular nn Criterion with a SequencerCriterion as this simplifies training quite a bit
-- SequencerCriterion requires tables as input, and this affects the code we have to write inside the training for loop
local seqC = nn.SequencerCriterion(nn.ClassNLLCriterion())
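-- SequencerCriterion applies the wrapped ClassNLLCriterion at each of the seqLength steps
-- of a sequence and sums the per-step losses, which is what we accumulate into err below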
-- Now let's train our network on the small, fake dataset we generated earlier
rnn:training()
-- Run 200 epochs; in each epoch we feed the LSTM all dsSize steps of the dataset,
-- grouped into sequences of seqLength mini-batches
for epoch = 1, 200 do
  local start = torch.tic()
  local err = 0
  for offset = 1, dsSize - seqLength + 1, seqLength do
    -- collect the seqLength consecutive steps (each a batchSize x inputSize mini-batch)
    -- that make up the next sequence for the Sequencer
    local batchInputs = {}
    local batchTargets = {}
    for i = 0, seqLength - 1 do
      table.insert(batchInputs, inputs[offset + i])
      table.insert(batchTargets, targets[offset + i])
    end
    -- forward
    local out = rnn:forward(batchInputs)
    err = err + seqC:forward(out, batchTargets)
    -- backward
    local gradOut = seqC:backward(out, batchTargets)
    rnn:backward(batchInputs, gradOut)
    -- update the parameters at the end of each sequence (plain SGD with a 0.05 learning
    -- rate), then reset the accumulated gradients
    rnn:updateParameters(0.05)
    rnn:zeroGradParameters()
  end
  local currT = torch.toc(start)
  print('epoch ' .. epoch .. ': loss ' .. err / dsSize .. ' in ' .. currT .. ' s')
end
save(inputs, targets, rnn)
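-- A minimal inference sketch (an addition, not part of the original example): switch to
-- evaluation mode and classify the first seqLength steps of the dataset, printing the
-- most likely class for every example in the mini-batch at each step.
rnn:evaluate()
local sampleSeq = {}
for i = 1, seqLength do table.insert(sampleSeq, inputs[i]) end
local logProbs = rnn:forward(sampleSeq)
for step = 1, seqLength do
  local _, predicted = logProbs[step]:max(2)
  print('step ' .. step .. ' predictions:', predicted:view(-1))
end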