Feat/time net model #2538

Open
wants to merge 12 commits into
base: master

Contributor

@gdevos010 gdevos010 commented Sep 23, 2024

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

TODO

  • update doc strings
  • add tests
  • add proper attribution
  • add future covariates support (this will probably need to be a separate MR)

Fixes #2537.

Summary

Adds the TimesNet model, based on this code:
https://github.com/thuml/Time-Series-Library/blob/main/models/TimesNet.py

Other Information

[Image: Forecast_0_0 264_TimesNetModel_batch_size=16,hidden_size=128,num_layers=2,num_kernels=2,top_k=5,use_reversible_instance_norm=True]

[Image: Forecast_0_1 786_TimesNetModel_batch_size=32,hidden_size=16,num_layers=3,num_kernels=4,top_k=5,use_reversible_instance_norm=True]

@gdevos010 gdevos010 marked this pull request as draft September 23, 2024 03:45
@gdevos010 gdevos010 marked this pull request as ready for review September 25, 2024 18:35
codecov bot commented Sep 25, 2024

Codecov Report

Attention: Patch coverage is 95.19651% with 11 lines in your changes missing coverage. Please review.

Project coverage is 94.11%. Comparing base (d909589) to head (c985d54).

Files with missing lines                       Patch %   Lines
darts/models/forecasting/times_net_model.py    93.20%    7 Missing ⚠️
darts/models/components/embed.py               97.56%    3 Missing ⚠️
darts/models/__init__.py                       66.66%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2538      +/-   ##
==========================================
- Coverage   94.14%   94.11%   -0.04%     
==========================================
  Files         139      141       +2     
  Lines       14884    15113     +229     
==========================================
+ Hits        14013    14223     +210     
- Misses        871      890      +19     
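
As a quick sanity check on the figures in this report, the percentages can be recomputed from the raw counts. The sketch below assumes that coverage is simply hits divided by total lines, and that the patch consists of the 229 added lines, 11 of which are uncovered. The helper name `pct` is made up for illustration and is not part of any Codecov tooling.

```python
def pct(covered: int, total: int) -> float:
    """Coverage as a percentage: covered lines over total lines."""
    return 100.0 * covered / total


# Counts copied from the coverage diff above.
base_hits, base_lines = 14013, 14884  # master
head_hits, head_lines = 14223, 15113  # this PR

base_cov = pct(base_hits, base_lines)
head_cov = pct(head_hits, head_lines)
delta = head_cov - base_cov

# Assumption: the patch is the 229 added lines, 11 of them uncovered.
patch_lines, patch_missing = 229, 11
patch_cov = pct(patch_lines - patch_missing, patch_lines)

print(f"base  {base_cov:.3f}%")
print(f"head  {head_cov:.3f}%")
print(f"delta {delta:+.3f}")
print(f"patch {patch_cov:.5f}%")
```

Under these assumptions the patch figure reproduces the reported 95.19651%, and the project-level change comes out at roughly -0.04 percentage points.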

Contributor Author

gdevos010 commented Sep 29, 2024

@dennisbader This model is ready for review.

Collaborator

@madtoinou madtoinou left a comment

Thanks a lot for this great PR @gdevos010, and sorry for the delay; we have been quite busy with other responsibilities.

Some comments to reduce the code redundancy in the implementation of this quite complex model.

Also, could you please include an example notebook comparing it to models such as TiDEModel and/or TSMixerModel on a toy example, just to make sure that the default configuration yields acceptable results?

from darts.utils.torch import MonteCarloDropout


class PositionalEmbedding(nn.Module):
Collaborator

The PositionalEncoding class in darts/models/forecasting/transformer_model.py already implements this logic. Let's maybe move it here and keep the old name.

Comment on lines +88 to +97
w = torch.zeros(c_in, d_model).float()
w.require_grad = False

position = torch.arange(0, c_in).float().unsqueeze(1)
div_term = (
    torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
).exp()

w[:, 0::2] = torch.sin(position * div_term)
w[:, 1::2] = torch.cos(position * div_term)
Collaborator

This code snippet can also be found in PositionalEncoding. Let's abstract it to reduce redundancy.
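
One way to factor out the shared computation is a small helper that builds the sinusoidal table once, which both classes could then call. The following is a dependency-free sketch in plain Python, chosen so that the arithmetic is easy to check. The name `sinusoidal_table` is hypothetical and not part of darts, and in the model code the result would be built as a torch tensor rather than nested lists. It assumes an even `d_model`, which the slicing in the quoted snippet also requires implicitly.

```python
import math
from typing import List


def sinusoidal_table(num_positions: int, d_model: int) -> List[List[float]]:
    """Return a (num_positions, d_model) table of fixed sinusoidal encodings.

    Even columns hold sin(pos * freq), odd columns hold cos(pos * freq),
    with freq = exp(-i * ln(10000) / d_model) for each even column index i.
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be even for paired sin/cos columns")

    table = [[0.0] * d_model for _ in range(num_positions)]
    for pos in range(num_positions):
        for i in range(0, d_model, 2):
            freq = math.exp(-i * math.log(10000.0) / d_model)
            table[pos][i] = math.sin(pos * freq)
            table[pos][i + 1] = math.cos(pos * freq)
    return table


# Hypothetical usage: a table for 3 positions and 4 embedding dimensions.
example = sinusoidal_table(3, 4)
```

Position 0 yields alternating 0 and 1 (sin 0, cos 0), which makes for a cheap unit test of any shared implementation. Separately, the tensor attribute in PyTorch is spelled `requires_grad`, so the `w.require_grad = False` line in the quoted snippet only sets an unused Python attribute.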

return self.dropout(x)


class DataEmbedding_inverted(nn.Module):
Collaborator

Suggested change
- class DataEmbedding_inverted(nn.Module):
+ class DataEmbeddingInverted(nn.Module):

return self.dropout(x)


class DataEmbedding_wo_pos(nn.Module):
Collaborator

The logic is so similar to DataEmbedding and DataEmbedding_inverted that they should probably be combined into a single class, with the differences in logic implemented in the forward() method instead (by adding a parameter/attribute type that could take the values "normal", "inverted" or "wopos").
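
A minimal sketch of what such a combined class could look like is given below. It is deliberately dependency-free: nested lists stand in for tensors of a single sample, the three sub-embeddings are passed in as plain callables, and dropout is left out. The class name `UnifiedDataEmbedding`, the parameter name `embed_type`, and the exact behavior of each branch (the normal variant sums value, temporal and positional embeddings, "wopos" drops the positional term, "inverted" transposes the input so that each variable becomes a token) are assumptions made for illustration, not the PR's actual code.

```python
from typing import Callable, List, Optional

Matrix = List[List[float]]  # one sample, shaped (rows, columns)
Embed = Callable[[Matrix], Matrix]

EMBED_TYPES = ("normal", "inverted", "wopos")


def _add(a: Matrix, b: Matrix) -> Matrix:
    """Element-wise sum of two equally shaped matrices."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]


def _transpose(m: Matrix) -> Matrix:
    """Swap rows and columns."""
    return [list(col) for col in zip(*m)]


class UnifiedDataEmbedding:
    """Single embedding class; `embed_type` selects the branch in forward()."""

    def __init__(
        self,
        value_embedding: Embed,
        temporal_embedding: Embed,
        position_embedding: Embed,
        embed_type: str = "normal",
    ) -> None:
        if embed_type not in EMBED_TYPES:
            raise ValueError(f"embed_type must be one of {EMBED_TYPES}")
        self.value_embedding = value_embedding
        self.temporal_embedding = temporal_embedding
        self.position_embedding = position_embedding
        self.embed_type = embed_type

    def forward(self, x: Matrix, x_mark: Optional[Matrix] = None) -> Matrix:
        if self.embed_type == "inverted":
            # (time, variables) -> (variables, time): each variable is a token.
            tokens = _transpose(x)
            if x_mark is not None:
                # Time features are appended as extra tokens.
                tokens = tokens + _transpose(x_mark)
            return self.value_embedding(tokens)

        out = self.value_embedding(x)
        if x_mark is not None:
            out = _add(out, self.temporal_embedding(x_mark))
        if self.embed_type == "normal":
            # Only the normal variant adds the positional term.
            out = _add(out, self.position_embedding(x))
        return out


# Toy stand-ins for the learned sub-embeddings (hypothetical).
def double(m: Matrix) -> Matrix:
    return [[2.0 * v for v in row] for row in m]


def ones(m: Matrix) -> Matrix:
    return [[1.0 for _ in row] for row in m]


def tens(m: Matrix) -> Matrix:
    return [[10.0 for _ in row] for row in m]


x = [[1.0, 2.0], [3.0, 4.0]]
normal_out = UnifiedDataEmbedding(double, ones, tens, "normal").forward(x)
```

Keeping the mode check inside forward() means the three variants share one constructor and one set of sub-modules, at the cost of a constructor argument whose meaning (for example the input width of the value embedding) depends on the selected mode.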

from darts.models.forecasting.torch_forecasting_model import PastCovariatesTorchModel


class Inception_Block_V1(nn.Module):
Collaborator

Just for the sake of homogeneity.

Suggested change
- class Inception_Block_V1(nn.Module):
+ class InceptionBlock(nn.Module):

@@ -480,7 +480,7 @@ def encode_year(idx):
}
..
random_state
- Control the randomness of the weights initialization. Check this
+ Control the randomness of the weight's initialization. Check this
Collaborator

Let's revert this change: since there are several weights, the apostrophe does not seem appropriate.

Alternatively, we could rephrase it as "Control the randomness in the initialization of the weights" if you have the impression that the original sentence was not clear.

Collaborator

Let's rename the file to "timesnet_model.py" so that it's homogeneous with the others.

Collaborator

Let's rename the file to "test_timesnet_model.py" to keep the number of underscores to a minimum.

@@ -619,13 +620,13 @@ Patch release
- Added support for past and future covariates to `residuals()` function. [#1223](https://github.com/unit8co/darts/pull/1223) by [Eliane Maalouf](https://github.com/eliane-maalouf).
- Added support for retraining model(s) every `n` iteration and on custom conditions in `historical_forecasts` method of `ForecastingModel`s. [#1139](https://github.com/unit8co/darts/pull/1139) by [Francesco Bruzzesi](https://github.com/fbruzzesi).
- Added support for beta-NLL in `GaussianLikelihood`s, as proposed in [this paper](https://arxiv.org/abs/2203.09168). [#1162](https://github.com/unit8co/darts/pull/1162) by [Julien Herzen](https://github.com/hrzn).
-- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/pull/1113) by [Greg DeVos](https://github.com/gdevos010).
+- New LayerNorm alternatives, RMSNorm and LayerNormNoBias [#1113](https://github.com/unit8co/darts/issues/1113) by [Greg DeVosNouri](https://github.com/gdevos010).
Collaborator

Incorrect link

Successfully merging this pull request may close these issues.

Add TimesNet model
2 participants