forked from KasperGroesLudvigsen/influenza_transformer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
139 lines (97 loc) · 4.76 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import os
import torch
from torch.utils.data import Dataset
import pandas as pd
from typing import Tuple
class TransformerDataset(Dataset):
    """
    Dataset for sequence-to-sequence transformer models on time-series data.

    Each item is a (src, trg, trg_y) triple sliced out of one long data
    tensor according to a precomputed list of (start, end) index pairs:

        src   : encoder input,  length enc_seq_len
        trg   : decoder input,  length target_seq_len (src's last step +
                all of trg_y except its final step, i.e. shifted right by 1)
        trg_y : target sequence, length target_seq_len
    """
    def __init__(self,
        data: torch.Tensor,
        indices: list,
        enc_seq_len: int,
        dec_seq_len: int,
        target_seq_len: int
        ) -> None:
        """
        Args:
            data: tensor, the entire train, validation or test data sequence
                before any slicing. If univariate, data.size() will be
                [number of samples, number of variables]
                where the number of variables will be equal to 1 + the number of
                exogenous variables. Number of exogenous variables would be 0
                if univariate.
            indices: a list of tuples. Each tuple has two elements:
                1) the start index of a sub-sequence
                2) the end index of a sub-sequence.
                The sub-sequence is split into src, trg and trg_y later.
            enc_seq_len: int, the desired length of the input sequence given to
                the first layer of the transformer model.
            dec_seq_len: int, the desired length of the decoder input sequence.
                Stored for reference; the actual decoder input produced by
                get_src_trg has length target_seq_len.
            target_seq_len: int, the desired length of the target sequence (the
                output of the model).
        """
        super().__init__()

        self.indices = indices
        self.data = data
        print("From TransformerDataset.__init__: data size = {}".format(data.size()))

        self.enc_seq_len = enc_seq_len
        self.dec_seq_len = dec_seq_len
        self.target_seq_len = target_seq_len

    def __len__(self):
        # One sample per (start, end) index pair.
        return len(self.indices)

    def __getitem__(self, index):
        """
        Returns a tuple with 3 elements:
        1) src (the encoder input)
        2) trg (the decoder input)
        3) trg_y (the target)
        """
        # Start and end positions of the index'th sub-sequence.
        start_idx, end_idx = self.indices[index]

        sequence = self.data[start_idx:end_idx]

        src, trg, trg_y = self.get_src_trg(
            sequence=sequence,
            enc_seq_len=self.enc_seq_len,
            dec_seq_len=self.dec_seq_len,
            target_seq_len=self.target_seq_len
            )

        return src, trg, trg_y

    def get_src_trg(
        self,
        sequence: torch.Tensor,
        enc_seq_len: int,
        dec_seq_len: int,
        target_seq_len: int
        ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Generate the src (encoder input), trg (decoder input) and trg_y (the
        target) sequences from a sub-sequence.

        Args:
            sequence: tensor of shape [n, num_variables] where
                n = enc_seq_len + target_seq_len
            enc_seq_len: int, the desired length of the input to the
                transformer encoder
            dec_seq_len: int, unused here; kept for interface consistency
                with the constructor arguments
            target_seq_len: int, the desired length of the target sequence
                (the one against which the model output is compared)

        Return:
            src: tensor, [enc_seq_len, num_variables], encoder input
            trg: tensor, [target_seq_len, num_variables], decoder input
            trg_y: tensor, the target sequence against which the model output
                is compared when computing loss; last dim squeezed away, so
                [target_seq_len] when the data is univariate

        Raises:
            AssertionError: if len(sequence) != enc_seq_len + target_seq_len.
        """
        assert len(sequence) == enc_seq_len + target_seq_len, "Sequence length does not equal (input length + target length)"

        # Encoder input: the first enc_seq_len steps.
        src = sequence[:enc_seq_len]

        # Decoder input. As per the paper, it must have the same length as the
        # target sequence, and it must contain the last value of src and all
        # values of trg_y except the last (i.e. trg_y shifted right by 1).
        trg = sequence[enc_seq_len-1:len(sequence)-1]

        assert len(trg) == target_seq_len, "Length of trg does not match target sequence length"

        # The target sequence the model output will be compared against.
        trg_y = sequence[-target_seq_len:]

        assert len(trg_y) == target_seq_len, "Length of trg_y does not match target sequence length"

        # squeeze(-1) drops the trailing feature dim for univariate data:
        # [target_seq_len, 1] -> [target_seq_len].
        return src, trg, trg_y.squeeze(-1)