-
Notifications
You must be signed in to change notification settings - Fork 130
/
demo-lua-torch-layer.config
80 lines (72 loc) · 2.78 KB
/
demo-lua-torch-layer.config
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
#!returnn.py
# kate: syntax python;
import os
# Experiment name: this config file's own path with the extension split off.
# The extension itself is bound to `_` and not used again in this file.
demo_name, _ = os.path.splitext(__file__)
print("Hello, experiment: %s" % demo_name)

# NOTE(review): presumably selects the action the framework runs -- confirm.
task = "train"

# Feature dimensions. The "W" parameter of the network below is declared with
# n=12, matching num_inputs.
num_inputs = 12
num_outputs = 12

# Both datasets use the same dataset class; dev is smaller and pins a seed.
train = {
    "class": "TaskXmlModelingDataset",
    "num_seqs": 1000,
}
dev = {
    "class": "TaskXmlModelingDataset",
    "num_seqs": 100,
    "fixed_random_seed": 1,
}

# Batch construction.
# NOTE(review): exact meaning of these options (frames vs. sequences per
# batch, "size:step" chunking) is defined by the framework -- confirm there.
chunking = "200:200"
batching = "random"
max_seqs = 10
batch_size = 5000
# Ultimately, I want to have a bidirectional LSTM in Torch.
# Some refs:
# https://github.com/Element-Research/rnn
# https://apaszke.github.io/lstm-explained.html
# https://github.com/jcjohnson/torch-rnn
# https://github.com/karpathy/char-rnn
# https://github.com/wojzaremba/lstm/blob/master/main.lua
# https://gist.github.com/karpathy/7bae8033dcf5ca2630ba
# Network topology: one custom "torch" layer ("fw0") feeding a softmax output.
# The Lua sources below are passed along as plain strings, so every character
# inside the triple quotes (including the absence of indentation) is data.
network = {
    "fw0": {
        "class": "torch",
        "n_out": 100,
        # Forward pass: batched matrix product of x with W plus bias b,
        # returned in the same (time,batch,...) layout the input came in.
        "lua_fw_func": """
function (x, W, b, index)
-- x is (time,batch,feature)
n_time = x:size()[1]
n_batch = x:size()[2]
x = x:permute(2,1,3) -- Torch expects: (batch,time,feature)
-- For bmm, Torch expects: (batch,feature,ydim) so we get (batch,time,ydim)
-- W is (feature,ydim).
W = W:contiguous():view(1, W:size()[1], W:size()[2]):expand(n_batch, W:size()[1], W:size()[2])
b = b:contiguous():view(1, 1, b:size()[1]):expand(n_batch, n_time, b:size()[1])
return torch.bmm(x, W):add(b):permute(2,1,3)
end
""",
        # Backward pass: given the output gradient Dy, returns the gradients
        # with respect to x, W and b (in that order).
        "lua_bw_func": """
function (x, W, b, index, Dy)
-- We are supposed to return Dx, DW, Db.
n_time = x:size()[1]
n_batch = x:size()[2]
n_feat = x:size()[3]
ydim = Dy:size()[3]
x = x:permute(2,1,3) -- Torch expects: (batch,time,feature)
Dy = Dy:permute(2,1,3) -- (batch,time,ydim)
x_ = x:contiguous():view(n_time*n_batch, n_feat, 1)
Dy_ = Dy:contiguous():view(n_time*n_batch, 1, ydim)
DW = torch.sum(torch.bmm(x_, Dy_), 1):squeeze(1) -- (feature,ydim)
Db = torch.sum(Dy:contiguous():view(n_time*n_batch, ydim), 1):squeeze(1) -- (ydim,)
-- W is (feature,ydim).
W_ = W:contiguous():view(1, W:size()[1], W:size()[2]):expand(n_batch, W:size()[1], W:size()[2]) -- (batch,feature,ydim)
WT = W_:permute(1,3,2) -- (batch,ydim,feature)
Dx = torch.bmm(Dy, WT) -- (batch,time,feature)
Dx = Dx:permute(2,1,3) -- (time,batch,feature)
return Dx, DW, Db
end
""",
        # Trainable parameters handed to the Lua functions: a 12x100 weight
        # matrix "W" and a bias "b" of size 100 (matching n_out).
        "params": [
            {"name": "W", "n": 12, "m": 100},
            {"name": "b", "class": "create_bias", "n": 100},
        ],
    },
    # Output layer reading from "fw0".
    # NOTE(review): "ce" is presumably cross-entropy -- confirm.
    "output": {
        "class": "softmax",
        "loss": "ce",
        "from": ["fw0"],
    },
}
# Optimiser settings.
# NOTE(review): presumably enables Adam and leaves gradient clipping off
# (0) -- confirm against the framework's option reference.
learning_rate = 0.01
adam = True
gradient_clip = 0

# Training length and checkpointing. The model path embeds demo_name, so it
# follows whatever path this config was loaded from.
# NOTE(review): save_interval is presumably counted in epochs -- confirm.
num_epochs = 100
save_interval = 20
model = "/tmp/returnn.%s.network" % demo_name

# Logging.
log_verbosity = 4