
Add params_to_tune for RNNModel and MLPModel #1218

Merged 3 commits on Apr 12, 2023
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -50,6 +50,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0)
- Add default `params_to_tune` for `SeasonalMovingAverageModel`, `MovingAverageModel`, `NaiveModel` and `DeadlineMovingAverageModel` ([#1208](https://github.com/tinkoff-ai/etna/pull/1208))
- Add default `params_to_tune` for `DeepARModel` and `TFTModel` ([#1210](https://github.com/tinkoff-ai/etna/pull/1210))
- Add default `params_to_tune` for `HoltWintersModel`, `HoltModel` and `SimpleExpSmoothingModel` ([#1209](https://github.com/tinkoff-ai/etna/pull/1209))
- Add default `params_to_tune` for `RNNModel` and `MLPModel` ([#1218](https://github.com/tinkoff-ai/etna/pull/1218))
### Fixed
- Fix bug in `GaleShapleyFeatureSelectionTransform` with wrong number of remaining features ([#1110](https://github.com/tinkoff-ai/etna/pull/1110))
- `ProphetModel` fails with additional seasonality set ([#1157](https://github.com/tinkoff-ai/etna/pull/1157))
32 changes: 28 additions & 4 deletions etna/models/nn/mlp.py
@@ -4,19 +4,22 @@
from typing import List
from typing import Optional

+import numpy as np
import pandas as pd
from typing_extensions import TypedDict

from etna import SETTINGS
+from etna.models.base import DeepBaseModel
+from etna.models.base import DeepBaseNet

if SETTINGS.torch_required:
    import torch
    import torch.nn as nn

-import numpy as np
-
-from etna.models.base import DeepBaseModel
-from etna.models.base import DeepBaseNet
+if SETTINGS.auto_required:
+    from optuna.distributions import BaseDistribution
+    from optuna.distributions import IntUniformDistribution
+    from optuna.distributions import LogUniformDistribution


class MLPBatch(TypedDict):
@@ -231,3 +234,24 @@ def __init__(
            trainer_params=trainer_params,
            split_params=split_params,
        )

    def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
        """Get default grid for tuning hyperparameters.

        This grid doesn't tune the number of layers; that is determined by the length of the ``hidden_size`` parameter.
        The length of ``hidden_size`` is expected to be set by the user.

        Returns
        -------
        :
            Grid to tune.
        """
        grid: Dict[str, BaseDistribution] = {}

        for i in range(len(self.hidden_size)):
            key = f"hidden_size.{i}"
            value = IntUniformDistribution(low=4, high=64, step=4)
            grid[key] = value

        grid["lr"] = LogUniformDistribution(low=1e-5, high=1e-2)
        return grid
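
Because optuna distributions are flat, the per-layer sizes are exposed under dot-separated keys (``hidden_size.0``, ``hidden_size.1``, ...). Below is a minimal sketch of how a tuner could sample this grid and fold the keys back into a ``hidden_size`` list; ``sample_params`` is a hypothetical helper written for illustration, not part of this PR or of etna:

import optuna
from optuna.distributions import IntUniformDistribution
from optuna.distributions import LogUniformDistribution

from etna.models.nn import MLPModel


def sample_params(trial: optuna.Trial, grid: dict) -> dict:
    """Draw one concrete value from every distribution in the grid."""
    sampled = {}
    for name, dist in grid.items():
        if isinstance(dist, IntUniformDistribution):
            sampled[name] = trial.suggest_int(name, dist.low, dist.high, step=dist.step)
        elif isinstance(dist, LogUniformDistribution):
            sampled[name] = trial.suggest_float(name, dist.low, dist.high, log=True)
    return sampled


model = MLPModel(input_size=9, hidden_size=[5, 5], lr=1e-1, decoder_length=14)
trial = optuna.create_study().ask()
params = sample_params(trial, model.params_to_tune())

# Fold the flat "hidden_size.{i}" keys back into a list for the constructor.
hidden_size = [params.pop(f"hidden_size.{i}") for i in range(len(model.hidden_size))]
tuned = MLPModel(input_size=9, hidden_size=hidden_size, decoder_length=14, **params)
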
23 changes: 21 additions & 2 deletions etna/models/nn/rnn.py
@@ -8,13 +8,17 @@
from typing_extensions import TypedDict

from etna import SETTINGS
+from etna.models.base import DeepBaseModel
+from etna.models.base import DeepBaseNet

if SETTINGS.torch_required:
    import torch
    import torch.nn as nn

-from etna.models.base import DeepBaseModel
-from etna.models.base import DeepBaseNet
+if SETTINGS.auto_required:
+    from optuna.distributions import BaseDistribution
+    from optuna.distributions import IntUniformDistribution
+    from optuna.distributions import LogUniformDistribution


class RNNBatch(TypedDict):
@@ -278,3 +282,18 @@ def __init__(
            trainer_params=trainer_params,
            split_params=split_params,
        )

    def params_to_tune(self) -> Dict[str, "BaseDistribution"]:
        """Get default grid for tuning hyperparameters.

        Returns
        -------
        :
            Grid to tune.
        """
        return {
            "num_layers": IntUniformDistribution(low=1, high=3),
            "hidden_size": IntUniformDistribution(low=4, high=64, step=4),
            "lr": LogUniformDistribution(low=1e-5, high=1e-2),
            "encoder_length": IntUniformDistribution(low=1, high=20),
        }
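
Each key in this grid maps one-to-one onto an ``RNNModel`` constructor argument, so a sampled configuration can be applied directly. The values below are illustrative only (each drawn from the corresponding distribution above), not taken from the PR:

from etna.models.nn import RNNModel

model = RNNModel(
    input_size=1,
    decoder_length=14,
    encoder_length=10,  # from IntUniformDistribution(low=1, high=20)
    num_layers=2,       # from IntUniformDistribution(low=1, high=3)
    hidden_size=32,     # from IntUniformDistribution(low=4, high=64, step=4)
    lr=3e-4,            # from LogUniformDistribution(low=1e-5, high=1e-2)
    trainer_params=dict(max_epochs=1),
)
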
41 changes: 40 additions & 1 deletion tests/test_models/nn/test_mlp.py
@@ -13,6 +13,7 @@
from etna.transforms import LagTransform
from etna.transforms import StandardScalerTransform
from tests.test_models.utils import assert_model_equals_loaded_original
from tests.test_models.utils import assert_sampling_is_valid


@pytest.mark.parametrize(
@@ -134,10 +135,48 @@ def test_save_load(example_tsds):
        hidden_size=[10],
        lr=1e-1,
        decoder_length=14,
-        trainer_params=dict(max_epochs=2),
+        trainer_params=dict(max_epochs=1),
    )
    lag = LagTransform(in_column="target", lags=list(range(horizon, horizon + 3)))
    fourier = FourierTransform(period=7, order=3)
    std = StandardScalerTransform(in_column="target")
    transforms = [lag, fourier, std]
    assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=transforms, horizon=horizon)


@pytest.mark.parametrize(
    "model",
    [
        MLPModel(
            input_size=9,
            hidden_size=[5],
            lr=1e-1,
            decoder_length=14,
            trainer_params=dict(max_epochs=1),
        ),
        MLPModel(
            input_size=9,
            hidden_size=[5, 5],
            lr=1e-1,
            decoder_length=14,
            trainer_params=dict(max_epochs=1),
        ),
        MLPModel(
            input_size=9,
            hidden_size=[5, 5, 5],
            lr=1e-1,
            decoder_length=14,
            trainer_params=dict(max_epochs=1),
        ),
    ],
)
def test_params_to_tune(model, example_tsds):
    ts = example_tsds
    horizon = 3
    lag = LagTransform(in_column="target", lags=list(range(horizon, horizon + 3)))
    fourier = FourierTransform(period=7, order=3)
    std = StandardScalerTransform(in_column="target")
    transforms = [lag, fourier, std]
    ts.fit_transform(transforms)
    assert len(model.params_to_tune()) > 0
    assert_sampling_is_valid(model=model, ts=ts)
10 changes: 9 additions & 1 deletion tests/test_models/nn/test_rnn.py
@@ -8,6 +8,7 @@
from etna.models.nn.rnn import RNNNet
from etna.transforms import StandardScalerTransform
from tests.test_models.utils import assert_model_equals_loaded_original
from tests.test_models.utils import assert_sampling_is_valid


@pytest.mark.long_2
@@ -77,5 +78,12 @@ def test_context_size(encoder_length):


def test_save_load(example_tsds):
-    model = RNNModel(input_size=1, encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=2))
+    model = RNNModel(input_size=1, encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1))
    assert_model_equals_loaded_original(model=model, ts=example_tsds, transforms=[], horizon=3)


def test_params_to_tune(example_tsds):
    ts = example_tsds
    model = RNNModel(input_size=1, encoder_length=14, decoder_length=14, trainer_params=dict(max_epochs=1))
    assert len(model.params_to_tune()) > 0
    assert_sampling_is_valid(model=model, ts=ts)
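
Taken together, these default grids are what etna's tuning utilities consume. A hedged end-to-end sketch, assuming the ``Tune`` interface from ``etna.auto`` (interface details assumed here, not shown in this PR):

from etna.auto import Tune
from etna.metrics import SMAPE
from etna.models.nn import RNNModel
from etna.pipeline import Pipeline

pipeline = Pipeline(
    model=RNNModel(input_size=1, encoder_length=14, decoder_length=7, trainer_params=dict(max_epochs=1)),
    horizon=7,
)

# Tune is expected to sample from the pipeline's params_to_tune(), which now
# includes the default grids added in this PR (model keys prefixed with "model.").
tune = Tune(pipeline=pipeline, target_metric=SMAPE(), horizon=7)
best_pipeline = tune.fit(ts=ts, n_trials=20)  # ``ts`` is an existing TSDataset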