-
Notifications
You must be signed in to change notification settings - Fork 1.1k
/
FiLM.py
268 lines (236 loc) · 11.3 KB
/
FiLM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy import signal
from scipy import special as ss
# Module-level default device: the first CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def transition(N):
    """
    Build the continuous-time state matrices consumed by HiPPO_LegT.

    N: number of coefficients (state size)
    returns: (A, B) as float64 numpy arrays of shape (N, N) and (N, 1)
    """
    orders = np.arange(N, dtype=np.float64)
    row = orders[:, None]  # row index n, as a column vector
    col = orders[None, :]  # column index k, as a row vector
    scale = 2 * row + 1  # per-row factor (2n + 1)
    # Above the diagonal every entry is -1; on and below it the sign alternates.
    alternating = (-1.) ** (row - col + 1)
    A = np.where(row < col, -1, alternating) * scale
    B = (-1.) ** row * scale
    return A, B
class HiPPO_LegT(nn.Module):
    """
    Discretized HiPPO-LegT recurrence that compresses a sequence into N coefficients.

    The continuous (A, B) pair from `transition` is discretized once at construction
    and stored as buffers, together with a matrix of Legendre polynomial values used
    to map coefficients back to samples.
    """

    def __init__(self, N, dt=1.0, discretization='bilinear'):
        """
        N: the order of the HiPPO projection
        dt: discretization step size - should be roughly inverse to the length of the sequence
        discretization: method name forwarded to scipy.signal.cont2discrete
        """
        super(HiPPO_LegT, self).__init__()
        self.N = N
        A, B = transition(N)
        # C and D are only needed to satisfy cont2discrete's state-space signature;
        # their discretized counterparts are discarded below.
        C = np.ones((1, N))
        D = np.zeros((1,))
        A, B, _, _, _ = signal.cont2discrete((A, B, C, D), dt=dt, method=discretization)
        B = B.squeeze(-1)
        # Buffers start on the module-level default device and follow later .to() calls.
        self.register_buffer('A', torch.Tensor(A).to(device))
        self.register_buffer('B', torch.Tensor(B).to(device))
        # Legendre polynomials of degree 0..N-1 evaluated at 1 - 2t for t in [0, 1) with
        # step dt; after the transpose the shape is (len(vals), N).
        vals = np.arange(0.0, 1.0, dt)
        self.register_buffer('eval_matrix', torch.Tensor(
            ss.eval_legendre(np.arange(N)[:, None], 1 - 2 * vals).T).to(device))

    def forward(self, inputs):
        """
        inputs : (batch, channels, length) - the last axis is scanned one step at a time
        output : (length, batch, channels, N) where N is the order of the HiPPO projection
        """
        # Allocate the running state next to the input instead of on the module-level
        # default device, so the layer still works after .to(...) or multi-GPU replication.
        state_shape = tuple(inputs.shape[:-1]) + (self.N,)
        c = torch.zeros(state_shape, dtype=self.A.dtype, device=inputs.device)
        cs = []
        # Move the scanned axis to the front so iteration yields one time step per pass.
        for f in inputs.permute([-1, 0, 1]):
            f = f.unsqueeze(-1)
            new = f @ self.B.unsqueeze(0)
            # Linear recurrence: c <- c A^T + f B
            c = F.linear(c, self.A) + new
            cs.append(c)
        return torch.stack(cs, dim=0)

    def reconstruct(self, c):
        """
        c : (..., N) coefficients
        output : (..., len(vals)) values obtained by multiplying with eval_matrix
        """
        return (self.eval_matrix @ c.unsqueeze(-1)).squeeze(-1)
class SpectralConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, seq_len, ratio=0.5):
        """
        1D Fourier layer. It does FFT, linear transform, and Inverse FFT.

        in_channels / out_channels: sizes of the axis mixed by the learned weights
        seq_len: nominal sequence length; at most min(32, seq_len // 2) lowest modes are kept
        ratio: stored on the instance but not otherwise used by this layer
        """
        super(SpectralConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.ratio = ratio
        self.modes = min(32, seq_len // 2)
        self.index = list(range(0, self.modes))
        self.scale = (1 / (in_channels * out_channels))
        # Complex weights are stored as separate real and imaginary parameters.
        self.weights_real = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.float))
        self.weights_imag = nn.Parameter(
            self.scale * torch.rand(in_channels, out_channels, len(self.index), dtype=torch.float))

    def compl_mul1d(self, order, x, weights_real, weights_imag):
        """
        Complex product of x with (weights_real + i * weights_imag), contracted with the
        einsum pattern `order`: (a + ib)(c + id) = (ac - bd) + i(ad + bc).
        """
        real_part = torch.einsum(order, x.real, weights_real) - torch.einsum(order, x.imag, weights_imag)
        imag_part = torch.einsum(order, x.real, weights_imag) + torch.einsum(order, x.imag, weights_real)
        return torch.complex(real_part, imag_part)

    def forward(self, x):
        """
        x : 4-D tensor (B, H, in_channels, N); the FFT runs over the last axis
        output : (B, H, out_channels, N)
        """
        B, H, _, N = x.shape
        x_ft = torch.fft.rfft(x)
        n_bins = x_ft.size(-1)
        # An input shorter than the configured seq_len can expose fewer frequency bins
        # than self.modes; clamp so the spectrum slice and the weights always agree.
        modes = min(self.modes, n_bins)
        out_ft = torch.zeros(B, H, self.out_channels, n_bins, device=x.device, dtype=torch.cfloat)
        out_ft[:, :, :, :modes] = self.compl_mul1d(
            "bjix,iox->bjox",
            x_ft[:, :, :, :modes],
            self.weights_real[:, :, :modes],
            self.weights_imag[:, :, :modes],
        )
        # Bins above `modes` stay zero, so the layer also acts as a low-pass filter.
        return torch.fft.irfft(out_ft, n=N)
class Model(nn.Module):
    """
    FiLM model: Legendre-projection memory (HiPPO_LegT) combined with a Fourier layer
    (SpectralConv1d) at several input scales, merged by a linear layer.

    Paper link: https://arxiv.org/abs/2205.08897
    """

    def __init__(self, configs):
        """
        configs: object providing task_name, seq_len, label_len, pred_len, e_layers and
            enc_in; additionally d_model and c_out for imputation / anomaly_detection,
            and dropout and num_class for classification.
            Note: configs.ratio is overwritten with 0.5.
        """
        super(Model, self).__init__()
        self.task_name = configs.task_name
        self.configs = configs
        self.seq_len = configs.seq_len
        self.label_len = configs.label_len
        # pred_len == 0 means "produce an output as long as the input".
        self.pred_len = configs.seq_len if configs.pred_len == 0 else configs.pred_len
        self.seq_len_all = self.seq_len + self.label_len
        self.layers = configs.e_layers
        self.enc_in = configs.enc_in
        self.e_layers = configs.e_layers
        # Learnable per-feature affine, broadcast over the first two axes of the input.
        self.affine_weight = nn.Parameter(torch.ones(1, 1, configs.enc_in))
        self.affine_bias = nn.Parameter(torch.zeros(1, 1, configs.enc_in))
        self.multiscale = [1, 2, 4]
        self.window_size = [256]
        configs.ratio = 0.5
        # One (HiPPO_LegT, SpectralConv1d) pair per (window size, scale) combination.
        self.legts = nn.ModuleList(
            [HiPPO_LegT(N=n, dt=1. / self.pred_len / i) for n in self.window_size for i in self.multiscale])
        self.spec_conv_1 = nn.ModuleList([SpectralConv1d(in_channels=n, out_channels=n,
                                                         seq_len=min(self.pred_len, self.seq_len),
                                                         ratio=configs.ratio) for n in
                                          self.window_size for _ in range(len(self.multiscale))])
        # Merges the per-scale predictions into a single one.
        self.mlp = nn.Linear(len(self.multiscale) * len(self.window_size), 1)
        if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
            # Not called by the task methods of this class; kept so the set of
            # registered parameters (and therefore state_dict keys) is unchanged.
            self.projection = nn.Linear(
                configs.d_model, configs.c_out, bias=True)
        if self.task_name == 'classification':
            self.act = F.gelu
            self.dropout = nn.Dropout(configs.dropout)
            self.projection = nn.Linear(
                configs.enc_in * configs.seq_len, configs.num_class)

    def _normalize(self, x_enc):
        """Standardize x_enc along axis 1; returns (normalized, means, stdev), statistics detached."""
        # Normalization from Non-stationary Transformer
        means = x_enc.mean(1, keepdim=True).detach()
        x_enc = x_enc - means
        stdev = torch.sqrt(torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5).detach()
        x_enc = x_enc / stdev
        return x_enc, means, stdev

    def _denormalize(self, x_dec, means, stdev):
        """Undo the learnable affine and the standardization applied to the input."""
        # De-Normalization from Non-stationary Transformer
        x_dec = x_dec - self.affine_bias
        x_dec = x_dec / (self.affine_weight + 1e-10)  # epsilon guards against a zero weight
        x_dec = x_dec * stdev
        x_dec = x_dec + means
        return x_dec

    def _multiscale_decode(self, x_enc):
        """
        Apply the affine, run every scale's Legendre/Fourier branch and merge the results.

        x_enc : (batch, time, features)
        output : (batch, pred_len, features)
        """
        x_enc = x_enc * self.affine_weight + self.affine_bias
        # Offset into the coefficient sequence; fixed at 0, so the slice below keeps everything.
        jump_dist = 0
        n_scales = len(self.multiscale)
        x_decs = []
        for i, (legt, spec_conv) in enumerate(zip(self.legts, self.spec_conv_1)):
            # Each scale looks back over a different multiple of pred_len.
            x_in_len = self.multiscale[i % n_scales] * self.pred_len
            x_in = x_enc[:, -x_in_len:]
            # (time, batch, features, N) -> (batch, features, N, time)
            x_in_c = legt(x_in.transpose(1, 2)).permute([1, 2, 3, 0])[:, :, :, jump_dist:]
            out1 = spec_conv(x_in_c)
            # Pick the coefficient vector at one time step of the filtered sequence.
            if self.seq_len >= self.pred_len:
                x_dec_c = out1.transpose(2, 3)[:, :, self.pred_len - 1 - jump_dist, :]
            else:
                x_dec_c = out1.transpose(2, 3)[:, :, -1, :]
            # Expand the coefficients into pred_len samples with the Legendre basis.
            x_dec = x_dec_c @ legt.eval_matrix[-self.pred_len:, :].T
            x_decs.append(x_dec)
        x_dec = torch.stack(x_decs, dim=-1)
        return self.mlp(x_dec).squeeze(-1).permute(0, 2, 1)

    def forecast(self, x_enc, x_mark_enc, x_dec_true, x_mark_dec):
        """Forecast from x_enc; the remaining arguments are accepted but unused."""
        x_enc, means, stdev = self._normalize(x_enc)
        x_dec = self._multiscale_decode(x_enc)
        return self._denormalize(x_dec, means, stdev)

    def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
        """Reconstruct x_enc; the remaining arguments (including mask) are accepted but unused."""
        x_enc, means, stdev = self._normalize(x_enc)
        x_dec = self._multiscale_decode(x_enc)
        return self._denormalize(x_dec, means, stdev)

    def anomaly_detection(self, x_enc):
        """Reconstruct x_enc with the same pipeline as forecast / imputation."""
        x_enc, means, stdev = self._normalize(x_enc)
        x_dec = self._multiscale_decode(x_enc)
        return self._denormalize(x_dec, means, stdev)

    def classification(self, x_enc, x_mark_enc):
        """Return class logits; no input standardization is applied on this path."""
        x_dec = self._multiscale_decode(x_enc)
        # Output from Non-stationary Transformer
        output = self.act(x_dec)
        output = self.dropout(output)
        # x_mark_enc is multiplied in per time step - presumably a padding mask; verify against caller.
        output = output * x_mark_enc.unsqueeze(-1)
        output = output.reshape(output.shape[0], -1)
        output = self.projection(output)
        return output

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
        """Dispatch on task_name; returns None when the task name is not recognised."""
        if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
            dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.task_name == 'imputation':
            dec_out = self.imputation(x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
            return dec_out  # [B, L, D]
        if self.task_name == 'anomaly_detection':
            dec_out = self.anomaly_detection(x_enc)
            return dec_out  # [B, L, D]
        if self.task_name == 'classification':
            dec_out = self.classification(x_enc, x_mark_enc)
            return dec_out  # [B, N]
        return None