# NOTE(review): the lines above this script in the scraped page were GitHub UI
# chrome and a rendered line-number gutter (not part of the source file);
# they have been removed so the file is valid Python. Original file: example.py
# modified example from https://pytorch.org/docs/1.2.0/nn.html#lstm
import torch
from torch.nn import RNNCell, LSTMCell
import rnnlib
# ---- Hyperparameters shared by every example below ----
seq_len = 5  # time steps per sequence
batch_size = 3  # sequences per mini-batch
input_size = 8  # feature dimension of each input element
hidden_size = 10  # hidden-state size of every cell in the stacks below
num_layers = 2  # stacked recurrent layers; the hand-built cell lists match this
dropout = 0.3  # passed to the rnnlib frames; presumably between-layer dropout as in torch.nn.RNN — confirm in rnnlib
r_dropout = 0.25  # recurrent dropout, only used by LayerNormLSTM — semantics defined in rnnlib
bidirectional = True  # each layer runs a forward and a backward cell
num_directions = 2 if bidirectional else 1  # standard torch convention for bidirectional RNNs
# ------------------------ Examples of RNNFrame ------------------------
# Build one [forward, backward] pair of RNNCells per layer, generalized over
# num_layers/num_directions instead of hand-unrolling two layers.
# Layer 0 consumes the raw input; each deeper layer consumes the concatenated
# forward/backward outputs of the previous layer (hidden_size * num_directions).
rnn_cells = [
    [RNNCell(input_size if layer == 0 else hidden_size * num_directions,
             hidden_size)
     for _ in range(num_directions)]
    for layer in range(num_layers)
]
assert len(rnn_cells) == num_layers
assert all(len(rnn_layer_cells) == num_directions for rnn_layer_cells in rnn_cells)
# with batch_first=False ------------------------------------------------
rnn = rnnlib.RNNFrame(rnn_cells, dropout=dropout, bidirectional=bidirectional)
# Built-in equivalent for comparison:
# rnn = torch.nn.RNN(input_size, hidden_size, num_layers, bidirectional=bidirectional)
# Renamed from `input` so the builtin input() is not shadowed.
inputs = torch.randn(seq_len, batch_size, input_size)  # (seq, batch, feature)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
output, hn = rnn(inputs, h0)
# presumably (seq_len, batch_size, hidden_size * num_directions), mirroring torch.nn.RNN — confirm in rnnlib
print(output.size())
# with batch_first=True ------------------------------------------------
rnn = rnnlib.RNNFrame(rnn_cells, dropout=dropout, bidirectional=bidirectional, batch_first=True)
# Built-in equivalent for comparison:
# rnn = torch.nn.RNN(input_size, hidden_size, num_layers, bidirectional=bidirectional, batch_first=True)
# Renamed from `input` so the builtin input() is not shadowed.
inputs = torch.randn(batch_size, seq_len, input_size)  # (batch, seq, feature)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)  # h0 stays (layers*dirs, batch, hidden) even with batch_first, as in torch
output, hn = rnn(inputs, h0)
print(output.size())
# ------------------------ Examples of LSTMFrame ------------------------
# Build one [forward, backward] pair of LSTMCells per layer, generalized over
# num_layers/num_directions instead of hand-unrolling two layers.
# Layer 0 consumes the raw input; each deeper layer consumes the concatenated
# forward/backward outputs of the previous layer (hidden_size * num_directions).
rnn_cells = [
    [LSTMCell(input_size if layer == 0 else hidden_size * num_directions,
              hidden_size)
     for _ in range(num_directions)]
    for layer in range(num_layers)
]
# 'rnn_cells' is a list of forward/backward LSTM cell pairs.
# Each pair corresponds to a layer of bidirectional LSTM.
# You can replace 'LSTMCell' with your custom LSTM cell class.
# Also you can compose 'rnn_cells' with heterogeneous LSTM cells.
#
# Caution: Non-LSTM cells, which don't distinguish hidden states and cell states,
# such as 'RNNCell' or 'GRUCell', are not allowed to be included in 'rnn_cells'
assert len(rnn_cells) == num_layers
assert all(len(rnn_layer_cells) == num_directions for rnn_layer_cells in rnn_cells)
# with batch_first=False ------------------------------------------------
rnn = rnnlib.LSTMFrame(rnn_cells, dropout=dropout, bidirectional=bidirectional)
# Built-in equivalent for comparison:
# rnn = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
# Renamed from `input` so the builtin input() is not shadowed.
inputs = torch.randn(seq_len, batch_size, input_size)  # (seq, batch, feature)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)  # initial hidden state
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)  # initial cell state
output, (hn, cn) = rnn(inputs, (h0, c0))  # LSTM state is an (h, c) pair, unlike plain RNN
print(output.size())
# with batch_first=True ------------------------------------------------
rnn = rnnlib.LSTMFrame(rnn_cells, dropout=dropout, bidirectional=bidirectional, batch_first=True)
# Built-in equivalent for comparison:
# rnn = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional, batch_first=True)
# Renamed from `input` so the builtin input() is not shadowed.
inputs = torch.randn(batch_size, seq_len, input_size)  # (batch, seq, feature)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)  # states stay (layers*dirs, batch, hidden) even with batch_first, as in torch
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)
output, (hn, cn) = rnn(inputs, (h0, c0))
print(output.size())
# ------------------------ Examples of LayerNormLSTM ------------------------
# Unlike the frames above, LayerNormLSTM constructs its own cells from sizes,
# and additionally takes recurrent dropout and a layer-norm switch.
rnn = rnnlib.LayerNormLSTM(input_size, hidden_size, num_layers, dropout=dropout, r_dropout=r_dropout,
                           bidirectional=bidirectional, layer_norm_enabled=True)
# Built-in equivalent for comparison (no layer norm / recurrent dropout):
# rnn = torch.nn.LSTM(input_size, hidden_size, num_layers, bidirectional=bidirectional)
# Renamed from `input` so the builtin input() is not shadowed.
inputs = torch.randn(seq_len, batch_size, input_size)  # (seq, batch, feature)
h0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)  # initial hidden state
c0 = torch.randn(num_layers * num_directions, batch_size, hidden_size)  # initial cell state
output, (hn, cn) = rnn(inputs, (h0, c0))
print(output.size())