forked from KasperGroesLudvigsen/influenza_transformer
positional_encoder.py
import math

import torch
from torch import nn, Tensor


class PositionalEncoder(nn.Module):
    """
    The authors of the original transformer paper describe very succinctly what
    the positional encoding layer does and why it is needed:

    "Since our model contains no recurrence and no convolution, in order for the
    model to make use of the order of the sequence, we must inject some
    information about the relative or absolute position of the tokens in the
    sequence." (Vaswani et al, 2017)

    Adapted from:
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(
        self,
        dropout: float = 0.1,
        max_seq_len: int = 5000,
        d_model: int = 512,
        batch_first: bool = False
    ):
        """
        Parameters:
            dropout: the dropout rate
            max_seq_len: the maximum length of the input sequences
            d_model: the dimension of the output of sub-layers in the model
                (Vaswani et al, 2017)
            batch_first: if True, inputs are expected with shape
                [batch_size, seq_len, d_model]; otherwise [seq_len, batch_size, d_model]
        """
        super().__init__()

        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.batch_first = batch_first

        # Adapted from the PyTorch tutorial. Build the sinusoidal table:
        #   PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        #   PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))
        position = torch.arange(max_seq_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))

        if self.batch_first:
            pe = torch.zeros(1, max_seq_len, d_model)
            pe[0, :, 0::2] = torch.sin(position * div_term)
            pe[0, :, 1::2] = torch.cos(position * div_term)
        else:
            pe = torch.zeros(max_seq_len, 1, d_model)
            pe[:, 0, 0::2] = torch.sin(position * div_term)
            pe[:, 0, 1::2] = torch.cos(position * div_term)

        # Registered as a buffer so it moves with the module (e.g. to GPU)
        # but is not updated by the optimizer.
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: Tensor of shape [batch_size, enc_seq_len, dim_val] if
               batch_first is True, otherwise [enc_seq_len, batch_size, dim_val]

        Returns:
            The input with positional encodings added, passed through dropout.
        """
        if self.batch_first:
            x = x + self.pe[:, :x.size(1)]
        else:
            x = x + self.pe[:x.size(0)]

        return self.dropout(x)
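

# --- Usage sketch (not part of the original file) ---
# A minimal example of how the encoder above might be applied. The dimensions
# used here (seq_len=100, batch_size=16, d_model=512) are illustrative
# assumptions, not values taken from the original repository.
if __name__ == "__main__":
    # Time-first layout (batch_first=False), matching the class default.
    encoder = PositionalEncoder(dropout=0.1, max_seq_len=5000, d_model=512, batch_first=False)

    x = torch.randn(100, 16, 512)  # [seq_len, batch_size, d_model]
    out = encoder(x)

    print(out.shape)  # torch.Size([100, 16, 512])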