Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/time net model #2538

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
19 changes: 10 additions & 9 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

- Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader).
- 🚀🚀 New forecasting model: `TimesNetModel` as proposed in [this paper](https://arxiv.org/abs/2210.02186). [#2538](https://github.com/unit8co/darts/pull/2538) by [Greg DeVosNouri](https://github.com/gdevos010).

**Fixed**

Expand Down Expand Up @@ -548,7 +549,7 @@ Patch release
[#1256](https://github.com/unit8co/darts/pull/1256) by [Julien Adda](https://github.com/julien12234)
and [Julien Herzen](https://github.com/hrzn).
- New forecasting models: `DLinearModel` and `NLinearModel` as proposed in [this paper](https://arxiv.org/pdf/2205.13504.pdf).
[#1139](https://github.com/unit8co/darts/pull/1139) by [Julien Herzen](https://github.com/hrzn) and [Greg DeVos](https://github.com/gdevos010).
[#1139](https://github.com/unit8co/darts/pull/1139) by [Julien Herzen](https://github.com/hrzn) and [Greg DeVosNouri](https://github.com/gdevos010).
- New forecasting model: `XGBModel` implementing XGBoost.
[#1405](https://github.com/unit8co/darts/pull/1405) by [Julien Herzen](https://github.com/hrzn).
- New `multi_models` option for all `RegressionModel`s: when set to False, uses only a single underlying
Expand Down Expand Up @@ -619,13 +620,13 @@ Patch release
- Added support for past and future covariates to `residuals()` function. [#1223](https://github.com/unit8co/darts/pull/1223) by [Eliane Maalouf](https://github.com/eliane-maalouf).
- Added support for retraining model(s) every `n` iteration and on custom conditions in `historical_forecasts` method of `ForecastingModel`s. [#1139](https://github.com/unit8co/darts/pull/1139) by [Francesco Bruzzesi](https://github.com/fbruzzesi).
- Added support for beta-NLL in `GaussianLikelihood`s, as proposed in [this paper](https://arxiv.org/abs/2203.09168). [#1162](https://github.com/unit8co/darts/pull/1162) by [Julien Herzen](https://github.com/hrzn).
- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/pull/1113) by [Greg DeVos](https://github.com/gdevos010).
- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/issues/1113) by [Greg DeVosNouri](https://github.com/gdevos010).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Incorrect link

- 🔴 Improvements to encoders: improve fitting behavior of encoders' transformers and solve a couple of issues. Remove support for absolute index encoding. [#1257](https://github.com/unit8co/darts/pull/1257) by [Dennis Bader](https://github.com/dennisbader).
- Overwrite min_train_series_length for Catboost and LightGBM [#1214](https://github.com/unit8co/darts/pull/1214) by [Anne de Vries](https://github.com/anne-devries).
- New example notebook showcasing and end-to-end example of hyperparameter optimization with Optuna [#1242](https://github.com/unit8co/darts/pull/1242) by [Julien Herzen](https://github.com/hrzn).
- New user guide section on hyperparameter optimization with Optuna and Ray Tune [#1242](https://github.com/unit8co/darts/pull/1242) by [Julien Herzen](https://github.com/hrzn).
- Documentation on model saving and loading. [#1210](https://github.com/unit8co/darts/pull/1210) by [Amadej Kocbek](https://github.com/amadejkocbek).
- 🔴 `torch_device_str` has been removed from all torch models in favor of Pytorch Lightning's `pl_trainer_kwargs` method [#1244](https://github.com/unit8co/darts/pull/1244) by [Greg DeVos](https://github.com/gdevos010).
- 🔴 `torch_device_str` has been removed from all torch models in favor of Pytorch Lightning's `pl_trainer_kwargs` method [#1244](https://github.com/unit8co/darts/pull/1244) by [Greg DeVosNouri](https://github.com/gdevos010).

**Fixed**

Expand Down Expand Up @@ -682,16 +683,16 @@ Patch release
- New Reconciliation transformers for forecast reconciliation: bottom up, top down and MinT. [#1012](https://github.com/unit8co/darts/pull/1012) by [Julien Herzen](https://github.com/hrzn).
- Added support for Monte Carlo Dropout, as a way to capture model uncertainty with torch models at inference time. [#1013](https://github.com/unit8co/darts/pull/1013) by [Julien Herzen](https://github.com/hrzn).
- New datasets: ETT and Electricity. [#617](https://github.com/unit8co/darts/pull/617)
by [Greg DeVos](https://github.com/gdevos010)
- New dataset, [Uber TLC](https://github.com/fivethirtyeight/uber-tlc-foil-response). [#1003](https://github.com/unit8co/darts/pull/1003) by [Greg DeVos](https://github.com/gdevos010).
- Model Improvements: Option for changing activation function for NHiTs and NBEATS. NBEATS support for dropout. NHiTs Support for AvgPooling1d. [#955](https://github.com/unit8co/darts/pull/955) by [Greg DeVos](https://github.com/gdevos010).
- Implemented ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202) for transformer based models (transformer and TFT). [#968](https://github.com/unit8co/darts/pull/968) by [Greg DeVos](https://github.com/gdevos010).
- Added support for torch metrics during training and validation. [#996](https://github.com/unit8co/darts/pull/996) by [Greg DeVos](https://github.com/gdevos010).
by [Greg DeVosNouri](https://github.com/gdevos010)
- New dataset, [Uber TLC](https://github.com/fivethirtyeight/uber-tlc-foil-response). [#1003](https://github.com/unit8co/darts/pull/1003) by [Greg DeVosNouri](https://github.com/gdevos010).
- Model Improvements: Option for changing activation function for NHiTs and NBEATS. NBEATS support for dropout. NHiTs Support for AvgPooling1d. [#955](https://github.com/unit8co/darts/pull/955) by [Greg DeVosNouri](https://github.com/gdevos010).
- Implemented ["GLU Variants Improve Transformer"](https://arxiv.org/abs/2002.05202) for transformer based models (transformer and TFT). [#959](https://github.com/unit8co/darts/issues/959) by [Greg DeVosNouri](https://github.com/gdevos010).
- Added support for torch metrics during training and validation. [#996](https://github.com/unit8co/darts/pull/996) by [Greg DeVosNouri](https://github.com/gdevos010).
- Better handling of logging [#1010](https://github.com/unit8co/darts/pull/1010) by [Dustin Brunner](https://github.com/brunnedu).
- Better support for Python 3.10, and dropping `prophet` as a dependency (`Prophet` model still works if `prophet` package is installed separately) [#1023](https://github.com/unit8co/darts/pull/1023) by [Julien Herzen](https://github.com/hrzn).
- Option to avoid global matplotlib configuration changes.
[#924](https://github.com/unit8co/darts/pull/924) by [Mike Richman](https://github.com/zgana).
- 🔴 `HNiTSModel` renamed to `HNiTS` [#1000](https://github.com/unit8co/darts/pull/1000) by [Greg DeVos](https://github.com/gdevos010).
- 🔴 `HNiTSModel` renamed to `HNiTS` [#1000](https://github.com/unit8co/darts/pull/1000) by [Greg DeVosNouri](https://github.com/gdevos010).

**Fixed**

Expand Down
4 changes: 4 additions & 0 deletions darts/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from darts.models.forecasting.regression_model import RegressionModel
from darts.models.forecasting.tbats_model import BATS, TBATS
from darts.models.forecasting.theta import FourTheta, Theta
from darts.models.forecasting.times_net_model import TimesNetModel
from darts.models.forecasting.varima import VARIMA

try:
Expand All @@ -50,6 +51,7 @@
from darts.models.forecasting.tcn_model import TCNModel
from darts.models.forecasting.tft_model import TFTModel
from darts.models.forecasting.tide_model import TiDEModel
from darts.models.forecasting.times_net_model import TimesNetModel
from darts.models.forecasting.transformer_model import TransformerModel
from darts.models.forecasting.tsmixer_model import TSMixerModel
except ModuleNotFoundError:
Expand All @@ -71,6 +73,7 @@
TFTModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TiDEModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TransformerModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TimesNetModel = NotImportedModule(module_name="(Py)Torch", warn=False)
TSMixerModel = NotImportedModule(module_name="(Py)Torch", warn=False)

try:
Expand Down Expand Up @@ -151,6 +154,7 @@
"TFTModel",
"TiDEModel",
"TransformerModel",
"TimesNetModel",
"TSMixerModel",
"Prophet",
"CatBoostModel",
Expand Down
237 changes: 237 additions & 0 deletions darts/models/components/embed.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
"""
TimesNet Model
-------
The implementation is built upon the Time Series Library's TimesNet model
<https://github.com/thuml/Time-Series-Library/blob/main/layers/Embed.py>

-------
MIT License

Copyright (c) 2021 THUML @ Tsinghua University

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import math

import torch
import torch.nn as nn

from darts.utils.torch import MonteCarloDropout


class PositionalEmbedding(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the PositionalEncoding class already implements this logic in darts/models/forecasting/transformer_model.py, let's maybe move it here and keep the old name

def __init__(self, d_model, max_len=5000):
super().__init__()
# Compute the positional encodings once in log space.
pe = torch.zeros(max_len, d_model).float()
pe.require_grad = False

position = torch.arange(0, max_len).float().unsqueeze(1)
div_term = (
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
).exp()

pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

pe = pe.unsqueeze(0)
self.register_buffer("pe", pe)

def forward(self, x):
return self.pe[:, : x.size(1)]


class TokenEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super().__init__()
padding = 1 if torch.__version__ >= "1.5.0" else 2
self.tokenConv = nn.Conv1d(
in_channels=c_in,
out_channels=d_model,
kernel_size=3,
padding=padding,
padding_mode="circular",
bias=False,
)
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(
m.weight, mode="fan_in", nonlinearity="leaky_relu"
)

def forward(self, x):
x = self.tokenConv(x.permute(0, 2, 1)).transpose(1, 2)
return x


class FixedEmbedding(nn.Module):
def __init__(self, c_in, d_model):
super().__init__()

w = torch.zeros(c_in, d_model).float()
w.require_grad = False

position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (
torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
).exp()

w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
Comment on lines +88 to +97
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this code snippet can also be found in PositionalEncoding, let's abstract it to reduce redundancy


self.emb = nn.Embedding(c_in, d_model)
self.emb.weight = nn.Parameter(w, requires_grad=False)

def forward(self, x):
return self.emb(x).detach()


class TemporalEmbedding(nn.Module):
def __init__(self, d_model, embed_type="fixed", freq="h"):
super().__init__()

minute_size = 4
hour_size = 24
weekday_size = 7
day_size = 32
month_size = 13

Embed = FixedEmbedding if embed_type == "fixed" else nn.Embedding
if freq == "t":
self.minute_embed = Embed(minute_size, d_model)
self.hour_embed = Embed(hour_size, d_model)
self.weekday_embed = Embed(weekday_size, d_model)
self.day_embed = Embed(day_size, d_model)
self.month_embed = Embed(month_size, d_model)

def forward(self, x):
x = x.long()
minute_x = (
self.minute_embed(x[:, :, 4]) if hasattr(self, "minute_embed") else 0.0
)
hour_x = self.hour_embed(x[:, :, 3])
weekday_x = self.weekday_embed(x[:, :, 2])
day_x = self.day_embed(x[:, :, 1])
month_x = self.month_embed(x[:, :, 0])

return hour_x + weekday_x + day_x + month_x + minute_x


class TimeFeatureEmbedding(nn.Module):
def __init__(self, d_model, embed_type="timeF", freq="h"):
super().__init__()

freq_map = {"h": 4, "t": 5, "s": 6, "m": 1, "a": 1, "w": 2, "d": 3, "b": 3}
d_inp = freq_map[freq]
self.embed = nn.Linear(d_inp, d_model, bias=False)

def forward(self, x):
return self.embed(x)


class DataEmbedding(nn.Module):
def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
super().__init__()

self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
self.position_embedding = PositionalEmbedding(d_model=d_model)
self.temporal_embedding = (
TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
if embed_type != "timeF"
else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
)
self.dropout = MonteCarloDropout(p=dropout)

def forward(self, x, x_mark):
if x_mark is None:
x = self.value_embedding(x) + self.position_embedding(x)
else:
x = (
self.value_embedding(x)
+ self.temporal_embedding(x_mark)
+ self.position_embedding(x)
)
return self.dropout(x)


class DataEmbedding_inverted(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class DataEmbedding_inverted(nn.Module):
class DataEmbeddingInverted(nn.Module):

def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
super().__init__()
self.value_embedding = nn.Linear(c_in, d_model)
self.dropout = MonteCarloDropout(p=dropout)

def forward(self, x, x_mark):
x = x.permute(0, 2, 1)
# x: [Batch Variate Time]
if x_mark is None:
x = self.value_embedding(x)
else:
x = self.value_embedding(torch.cat([x, x_mark.permute(0, 2, 1)], 1))
# x: [Batch Variate d_model]
return self.dropout(x)


class DataEmbedding_wo_pos(nn.Module):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The logic is so similar to DataEmbedding and DataEmbedding_inverted that they should probably be combined into a single class and the logic difference should be implemented in the forward() method instead (by adding a parameter/attribute type that could take the values "normal", "inverted" or "wopos".

def __init__(self, c_in, d_model, embed_type="fixed", freq="h", dropout=0.1):
super().__init__()

self.value_embedding = TokenEmbedding(c_in=c_in, d_model=d_model)
self.position_embedding = PositionalEmbedding(d_model=d_model)
self.temporal_embedding = (
TemporalEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
if embed_type != "timeF"
else TimeFeatureEmbedding(d_model=d_model, embed_type=embed_type, freq=freq)
)
self.dropout = MonteCarloDropout(p=dropout)

def forward(self, x, x_mark):
if x_mark is None:
x = self.value_embedding(x)
else:
x = self.value_embedding(x) + self.temporal_embedding(x_mark)
return self.dropout(x)


class PatchEmbedding(nn.Module):
def __init__(self, d_model, patch_len, stride, padding, dropout):
super().__init__()
# Patching
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

# Backbone, Input encoding: projection of feature vectors onto a d-dim vector space
self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

# Positional embedding
self.position_embedding = PositionalEmbedding(d_model)

# Residual dropout
self.dropout = MonteCarloDropout(p=dropout)

def forward(self, x):
# do patching
n_vars = x.shape[1]
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))
# Input encoding
x = self.value_embedding(x) + self.position_embedding(x)
return self.dropout(x), n_vars
1 change: 1 addition & 0 deletions darts/models/forecasting/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
- :class:`~darts.models.forecasting.nhits.NHiTSModel`
- :class:`~darts.models.forecasting.tcn_model.TCNModel`
- :class:`~darts.models.forecasting.transformer_model.TransformerModel`
- :class:`~darts.models.forecasting.time_net_model.TimesNetModel`
- :class:`~darts.models.forecasting.tft_model.TFTModel`
- :class:`~darts.models.forecasting.dlinear.DLinearModel`
- :class:`~darts.models.forecasting.nlinear.NLinearModel`
Expand Down
Loading
Loading