diff --git a/nbs/docs/capabilities/01_overview.ipynb b/nbs/docs/capabilities/01_overview.ipynb
index 12fe0f29e..11b964a7f 100644
--- a/nbs/docs/capabilities/01_overview.ipynb
+++ b/nbs/docs/capabilities/01_overview.ipynb
@@ -34,6 +34,7 @@
"|`NHITS` | `AutoNHITS` | MLP | Univariate | Direct | F/H/S | \n",
"|`NLinear` | `AutoNLinear` | MLP | Univariate | Direct | - | \n",
"|`PatchTST` | `AutoPatchTST` | Transformer | Univariate | Direct | - | \n",
+ "|`RMoK` | `AutoRMoK` | KAN | Multivariate | Direct | - |\n",
"|`RNN` | `AutoRNN` | RNN | Univariate | Recursive | F/H/S | \n",
"|`SOFTS` | `AutoSOFTS` | MLP | Multivariate | Direct | - | \n",
"|`StemGNN` | `AutoStemGNN` | GNN | Multivariate | Direct | - | \n",
diff --git a/nbs/imgs_models/rmok.png b/nbs/imgs_models/rmok.png
new file mode 100644
index 000000000..11c7cf4db
Binary files /dev/null and b/nbs/imgs_models/rmok.png differ
diff --git a/nbs/models.ipynb b/nbs/models.ipynb
index f693134cb..018525399 100644
--- a/nbs/models.ipynb
+++ b/nbs/models.ipynb
@@ -66,6 +66,7 @@
"from neuralforecast.models.itransformer import iTransformer\n",
"\n",
"from neuralforecast.models.kan import KAN\n",
+ "from neuralforecast.models.rmok import RMoK\n",
"\n",
"from neuralforecast.models.stemgnn import StemGNN\n",
"from neuralforecast.models.hint import HINT\n",
@@ -2300,6 +2301,14 @@
"model.fit(dataset=dataset)"
]
},
+ {
+ "cell_type": "markdown",
+ "id": "8efc1692",
+ "metadata": {},
+ "source": [
+ "## C. KAN-Based"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -2444,7 +2453,7 @@
"id": "fd705a56",
"metadata": {},
"source": [
- "## C. Transformer-Based"
+ "## D. Transformer-Based"
]
},
{
@@ -3427,7 +3436,7 @@
"id": "57d6cb1f",
"metadata": {},
"source": [
- "## D. CNN Based"
+ "## E. CNN Based"
]
},
{
@@ -3573,7 +3582,7 @@
"id": "e6fd22c7",
"metadata": {},
"source": [
- "## E. Multivariate"
+ "## F. Multivariate"
]
},
{
@@ -4680,6 +4689,158 @@
"model.fit(dataset=dataset)"
]
},
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ab15c4b6",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "class AutoRMoK(BaseAuto):\n",
+ "\n",
+ " default_config = {\n",
+ " \"input_size_multiplier\": [1, 2, 3, 4, 5],\n",
+ " \"h\": None,\n",
+ " \"n_series\": None,\n",
+ " \"taylor_order\": tune.choice([3, 4, 5]),\n",
+ " \"jacobi_degree\": tune.choice([4, 5, 6]),\n",
+ " \"wavelet_function\": tune.choice(['mexican_hat', 'morlet', 'dog', 'meyer', 'shannon']),\n",
+ " \"learning_rate\": tune.loguniform(1e-4, 1e-1),\n",
+ " \"scaler_type\": tune.choice([None, 'robust', 'standard', 'identity']),\n",
+ " \"max_steps\": tune.choice([500, 1000, 2000]),\n",
+ " \"batch_size\": tune.choice([32, 64, 128, 256]),\n",
+ " \"loss\": None,\n",
+ " \"random_seed\": tune.randint(1, 20),\n",
+ " }\n",
+ "\n",
+ " def __init__(self,\n",
+ " h,\n",
+ " n_series,\n",
+ " loss=MAE(),\n",
+ " valid_loss=None,\n",
+ " config=None, \n",
+ " search_alg=BasicVariantGenerator(random_state=1),\n",
+ " num_samples=10,\n",
+ " refit_with_val=False,\n",
+ " cpus=cpu_count(),\n",
+ " gpus=torch.cuda.device_count(),\n",
+ " verbose=False,\n",
+ " alias=None,\n",
+ " backend='ray',\n",
+ " callbacks=None):\n",
+ " \n",
+ " # Define search space, input/output sizes\n",
+ " if config is None:\n",
+ " config = self.get_default_config(h=h, backend=backend, n_series=n_series) \n",
+ "\n",
+ " # Always use n_series from parameters, raise exception with Optuna because we can't enforce it\n",
+ " if backend == 'ray':\n",
+ " config['n_series'] = n_series\n",
+ " elif backend == 'optuna':\n",
+ " mock_trial = MockTrial()\n",
+ " if ('n_series' in config(mock_trial) and config(mock_trial)['n_series'] != n_series) or ('n_series' not in config(mock_trial)):\n",
+ " raise Exception(f\"config needs 'n_series': {n_series}\") \n",
+ "\n",
+ " super(AutoRMoK, self).__init__(\n",
+ " cls_model=RMoK, \n",
+ " h=h,\n",
+ " loss=loss,\n",
+ " valid_loss=valid_loss,\n",
+ " config=config,\n",
+ " search_alg=search_alg,\n",
+ " num_samples=num_samples, \n",
+ " refit_with_val=refit_with_val,\n",
+ " cpus=cpus,\n",
+ " gpus=gpus,\n",
+ " verbose=verbose,\n",
+ " alias=alias,\n",
+ " backend=backend,\n",
+ " callbacks=callbacks, \n",
+ " )\n",
+ "\n",
+ " @classmethod\n",
+ " def get_default_config(cls, h, backend, n_series):\n",
+ " config = cls.default_config.copy() \n",
+ " config['input_size'] = tune.choice([h * x \\\n",
+ " for x in config[\"input_size_multiplier\"]])\n",
+ "\n",
+ " # Rolling windows with step_size=1 or step_size=h\n",
+ " # See `BaseWindows` and `BaseRNN`'s create_windows\n",
+ " config['step_size'] = tune.choice([1, h])\n",
+ " del config[\"input_size_multiplier\"]\n",
+ " if backend == 'optuna':\n",
+ " # Always use n_series from parameters\n",
+ " config['n_series'] = n_series\n",
+ " config = cls._ray_config_to_optuna(config) \n",
+ "\n",
+ " return config "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "111d8d3b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_doc(AutoRMoK, title_level=3)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2073d4aa",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%capture\n",
+ "# Use your own config or AutoRMoK.default_config\n",
+ "config = dict(max_steps=1, val_check_steps=1, input_size=12, learning_rate=1e-2)\n",
+ "model = AutoRMoK(h=12, n_series=1, config=config, num_samples=1, cpus=1)\n",
+ "\n",
+ "# Fit and predict\n",
+ "model.fit(dataset=dataset)\n",
+ "y_hat = model.predict(dataset=dataset)\n",
+ "\n",
+ "# Optuna\n",
+ "model = AutoRMoK(h=12, n_series=1, config=None, backend='optuna')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ebe2c500",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "# Check Optuna\n",
+ "assert model.config(MockTrial())['h'] == 12\n",
+ "\n",
+ "# Unit test to test that Auto* model contains all required arguments from BaseAuto\n",
+ "test_args(AutoRMoK, exclude_args=['cls_model']) \n",
+ "\n",
+ "# Unit test for situation: Optuna with updated default config\n",
+ "my_config = AutoRMoK.get_default_config(h=12, n_series=1, backend='optuna')\n",
+ "def my_config_new(trial):\n",
+ " config = {**my_config(trial)}\n",
+ " config.update({'max_steps': 1, 'val_check_steps': 1, 'input_size': 12, 'learning_rate': 1e-1})\n",
+ " return config\n",
+ "\n",
+ "model = AutoRMoK(h=12, n_series=1, config=my_config_new, backend='optuna', num_samples=1, cpus=1)\n",
+ "model.fit(dataset=dataset)\n",
+ "\n",
+ "# Unit test for situation: Ray with updated default config\n",
+ "my_config = AutoRMoK.get_default_config(h=12, n_series=1, backend='ray')\n",
+ "my_config['max_steps'] = 1\n",
+ "my_config['val_check_steps'] = 1\n",
+ "my_config['input_size'] = 12\n",
+ "my_config['learning_rate'] = 1e-1\n",
+ "model = AutoRMoK(h=12, n_series=1, config=my_config, backend='ray', num_samples=1, cpus=1)\n",
+ "model.fit(dataset=dataset)"
+ ]
+ },
{
"attachments": {},
"cell_type": "markdown",
diff --git a/nbs/models.rmok.ipynb b/nbs/models.rmok.ipynb
new file mode 100644
index 000000000..08fee40c0
--- /dev/null
+++ b/nbs/models.rmok.ipynb
@@ -0,0 +1,650 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| default_exp models.rmok"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "%load_ext autoreload\n",
+ "%autoreload 2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Reversible Mixture of KAN - RMoK\n",
+ "The Reversible Mixture of KAN (RMoK) is a KAN-based model for time series forecasting which uses a mixture-of-experts structure to assign variables to different KAN experts, such as WaveKAN, TaylorKAN and JacobiKAN.\n",
+ "\n",
+ "**Reference**\n",
+ "- [Xiao Han, Xinfeng Zhang, Yiling Wu, Zhenduo Zhang, Zhe Wu.\"KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?\"](https://arxiv.org/abs/2408.11306)"
+ ]
+ },
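+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a rough sketch (our notation; the actual implementation is in Section 2 below): each series is normalized with RevIN, its last $L$ observations $x \\in \\mathbb{R}^{L}$ are fed to a linear gate, and the forecast is the gate-weighted sum of the expert outputs\n",
+ "\n",
+ "$$\\hat{y} = \\sum_{e=1}^{E} \\mathrm{softmax}(W_g x)_e \\, f_e(x),$$\n",
+ "\n",
+ "where the experts $f_e$ are a TaylorKAN, a JacobiKAN, a WaveKAN, and a plain linear layer; the prediction is then de-normalized with RevIN."
+ ]
+ },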
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "![Figure 1. Architecture of RMoK.](imgs_models/rmok.png)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| hide\n",
+ "from fastcore.test import test_eq\n",
+ "from nbdev.showdoc import show_doc"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "import math\n",
+ "\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from neuralforecast.losses.pytorch import MAE\n",
+ "from neuralforecast.common._base_multivariate import BaseMultivariate"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 1. Auxiliary functions\n",
+ "### 1.1 WaveKAN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "\n",
+ "class WaveKANLayer(nn.Module):\n",
+ " '''This is a sample code for the simulations of the paper:\n",
+ " Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024)\n",
+ "\n",
+ " https://arxiv.org/abs/2405.12832\n",
+ " and also available at:\n",
+ " https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325\n",
+ " We used efficient KAN notation and some part of the code:+\n",
+ "\n",
+ " '''\n",
+ "\n",
+ " def __init__(self, in_features, out_features, wavelet_type='mexican_hat', with_bn=True, device=\"cpu\"):\n",
+ " super(WaveKANLayer, self).__init__()\n",
+ " self.in_features = in_features\n",
+ " self.out_features = out_features\n",
+ " self.wavelet_type = wavelet_type\n",
+ " self.with_bn = with_bn\n",
+ "\n",
+ " # Parameters for wavelet transformation\n",
+ " self.scale = nn.Parameter(torch.ones(out_features, in_features))\n",
+ " self.translation = nn.Parameter(torch.zeros(out_features, in_features))\n",
+ "\n",
+ " # self.weight1 is not used; you may use it for weighting base activation and adding it like Spl-KAN paper\n",
+ " self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features))\n",
+ " self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))\n",
+ "\n",
+ " nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))\n",
+ " nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))\n",
+ "\n",
+ " # Base activation function #not used for this experiment\n",
+ " self.base_activation = nn.SiLU()\n",
+ "\n",
+ " # Batch normalization\n",
+ " if self.with_bn:\n",
+ " self.bn = nn.BatchNorm1d(out_features)\n",
+ "\n",
+ " def wavelet_transform(self, x):\n",
+ " if x.dim() == 2:\n",
+ " x_expanded = x.unsqueeze(1)\n",
+ " else:\n",
+ " x_expanded = x\n",
+ "\n",
+ " translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)\n",
+ " scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)\n",
+ " x_scaled = (x_expanded - translation_expanded) / scale_expanded\n",
+ "\n",
+ " # Implementation of different wavelet types\n",
+ " if self.wavelet_type == 'mexican_hat':\n",
+ " term1 = ((x_scaled ** 2) - 1)\n",
+ " term2 = torch.exp(-0.5 * x_scaled ** 2)\n",
+ " wavelet = (2 / (math.sqrt(3) * math.pi ** 0.25)) * term1 * term2\n",
+ " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n",
+ " wavelet_output = wavelet_weighted.sum(dim=2)\n",
+ " elif self.wavelet_type == 'morlet':\n",
+ " omega0 = 5.0 # Central frequency\n",
+ " real = torch.cos(omega0 * x_scaled)\n",
+ " envelope = torch.exp(-0.5 * x_scaled ** 2)\n",
+ " wavelet = envelope * real\n",
+ " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n",
+ " wavelet_output = wavelet_weighted.sum(dim=2)\n",
+ "\n",
+ " elif self.wavelet_type == 'dog':\n",
+ " # Implementing Derivative of Gaussian Wavelet\n",
+ " dog = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)\n",
+ " wavelet = dog\n",
+ " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n",
+ " wavelet_output = wavelet_weighted.sum(dim=2)\n",
+ " elif self.wavelet_type == 'meyer':\n",
+ " # Implement Meyer Wavelet here\n",
+ " # Constants for the Meyer wavelet transition boundaries\n",
+ " v = torch.abs(x_scaled)\n",
+ " pi = math.pi\n",
+ "\n",
+ " def meyer_aux(v):\n",
+ " return torch.where(v <= 1 / 2, torch.ones_like(v),\n",
+ " torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * nu(2 * v - 1))))\n",
+ "\n",
+ " def nu(t):\n",
+ " return t ** 4 * (35 - 84 * t + 70 * t ** 2 - 20 * t ** 3)\n",
+ "\n",
+ " # Meyer wavelet calculation using the auxiliary function\n",
+ " wavelet = torch.sin(pi * v) * meyer_aux(v)\n",
+ " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n",
+ " wavelet_output = wavelet_weighted.sum(dim=2)\n",
+ " elif self.wavelet_type == 'shannon':\n",
+ " # Windowing the sinc function to limit its support\n",
+ " pi = math.pi\n",
+ " sinc = torch.sinc(x_scaled / pi) # sinc(x) = sin(pi*x) / (pi*x)\n",
+ "\n",
+ " # Applying a Hamming window to limit the infinite support of the sinc function\n",
+ " window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype,\n",
+ " device=x_scaled.device)\n",
+ " # Shannon wavelet is the product of the sinc function and the window\n",
+ " wavelet = sinc * window\n",
+ " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n",
+ " wavelet_output = wavelet_weighted.sum(dim=2)\n",
+ " # You can try many more wavelet types ...\n",
+ " else:\n",
+ " raise ValueError(\"Unsupported wavelet type\")\n",
+ "\n",
+ " return wavelet_output\n",
+ "\n",
+ " def forward(self, x):\n",
+ " wavelet_output = self.wavelet_transform(x)\n",
+ " # You may like test the cases like Spl-KAN\n",
+ " # wav_output = F.linear(wavelet_output, self.weight)\n",
+ " # base_output = F.linear(self.base_activation(x), self.weight1)\n",
+ "\n",
+ " # base_output = F.linear(x, self.weight1)\n",
+ " combined_output = wavelet_output # + base_output\n",
+ "\n",
+ " # Apply batch normalization\n",
+ " if self.with_bn:\n",
+ " return self.bn(combined_output)\n",
+ " else:\n",
+ " return combined_output"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.2 TaylorKAN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "\n",
+ "class TaylorKANLayer(nn.Module):\n",
+ " \"\"\"\n",
+ " https://github.com/Muyuzhierchengse/TaylorKAN/\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, input_dim, out_dim, order, addbias=True):\n",
+ " super(TaylorKANLayer, self).__init__()\n",
+ " self.input_dim = input_dim\n",
+ " self.out_dim = out_dim\n",
+ " self.order = order\n",
+ " self.addbias = addbias\n",
+ "\n",
+ " self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order) * 0.01)\n",
+ " if self.addbias:\n",
+ " self.bias = nn.Parameter(torch.zeros(1, out_dim))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " shape = x.shape\n",
+ " outshape = shape[0:-1] + (self.out_dim,)\n",
+ " x = torch.reshape(x, (-1, self.input_dim))\n",
+ " x_expanded = x.unsqueeze(1).expand(-1, self.out_dim, -1)\n",
+ "\n",
+ " y = torch.zeros((x.shape[0], self.out_dim), device=x.device)\n",
+ "\n",
+ " for i in range(self.order):\n",
+ " term = (x_expanded ** i) * self.coeffs[:, :, i]\n",
+ " y += term.sum(dim=-1)\n",
+ "\n",
+ " if self.addbias:\n",
+ " y += self.bias\n",
+ "\n",
+ " y = torch.reshape(y, outshape)\n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.3. JacobiKAN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "\n",
+ "class JacobiKANLayer(nn.Module):\n",
+ " \"\"\"\n",
+ " https://github.com/SpaceLearner/JacobiKAN/blob/main/JacobiKANLayer.py\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, input_dim, output_dim, degree, a=1.0, b=1.0):\n",
+ " super(JacobiKANLayer, self).__init__()\n",
+ " self.inputdim = input_dim\n",
+ " self.outdim = output_dim\n",
+ " self.a = a\n",
+ " self.b = b\n",
+ " self.degree = degree\n",
+ "\n",
+ " self.jacobi_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))\n",
+ "\n",
+ " nn.init.normal_(self.jacobi_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))\n",
+ "\n",
+ " def forward(self, x):\n",
+ " x = torch.reshape(x, (-1, self.inputdim)) # shape = (batch_size, inputdim)\n",
+ " # Since Jacobian polynomial is defined in [-1, 1]\n",
+ " # We need to normalize x to [-1, 1] using tanh\n",
+ " x = torch.tanh(x)\n",
+ " # Initialize Jacobian polynomial tensors\n",
+ " jacobi = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device)\n",
+ " if self.degree > 0: ## degree = 0: jacobi[:, :, 0] = 1 (already initialized) ; degree = 1: jacobi[:, :, 1] = x ; d\n",
+ " jacobi[:, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) * x) / 2\n",
+ " for i in range(2, self.degree + 1):\n",
+ " theta_k = (2 * i + self.a + self.b) * (2 * i + self.a + self.b - 1) / (2 * i * (i + self.a + self.b))\n",
+ " theta_k1 = (2 * i + self.a + self.b - 1) * (self.a * self.a - self.b * self.b) / (\n",
+ " 2 * i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))\n",
+ " theta_k2 = (i + self.a - 1) * (i + self.b - 1) * (2 * i + self.a + self.b) / (\n",
+ " i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))\n",
+ " jacobi[:, :, i] = (theta_k * x + theta_k1) * jacobi[:, :, i - 1].clone() - theta_k2 * jacobi[:, :,\n",
+ " i - 2].clone() # 2 * x * jacobi[:, :, i - 1].clone() - jacobi[:, :, i - 2].clone()\n",
+ " # Compute the Jacobian interpolation\n",
+ " y = torch.einsum('bid,iod->bo', jacobi, self.jacobi_coeffs) # shape = (batch_size, outdim)\n",
+ " y = y.view(-1, self.outdim)\n",
+ " return y"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 1.4 RevIN"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "\n",
+ "class RevIN(nn.Module):\n",
+ " def __init__(self, num_features: int, eps=1e-5, affine=True):\n",
+ " \"\"\"\n",
+ " :param num_features: the number of features or channels\n",
+ " :param eps: a value added for numerical stability\n",
+ " :param affine: if True, RevIN has learnable affine parameters\n",
+ " \"\"\"\n",
+ " super(RevIN, self).__init__()\n",
+ "\n",
+ " self.num_features = num_features\n",
+ " self.eps = eps\n",
+ " self.affine = affine\n",
+ "\n",
+ " if self.affine:\n",
+ " self._init_params()\n",
+ "\n",
+ " def forward(self, x, mode: str):\n",
+ " if mode == 'norm':\n",
+ " self._get_statistics(x)\n",
+ " x = self._normalize(x)\n",
+ "\n",
+ " elif mode == 'denorm':\n",
+ " x = self._denormalize(x)\n",
+ "\n",
+ " else:\n",
+ " raise NotImplementedError\n",
+ "\n",
+ " return x\n",
+ "\n",
+ " def _init_params(self):\n",
+ " # initialize RevIN params: (C,)\n",
+ " self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n",
+ " self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n",
+ "\n",
+ " def _get_statistics(self, x):\n",
+ " dim2reduce = tuple(range(1, x.ndim - 1))\n",
+ " self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n",
+ " self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n",
+ "\n",
+ " def _normalize(self, x):\n",
+ " x = x - self.mean\n",
+ " x = x / self.stdev\n",
+ " if self.affine:\n",
+ " x = x * self.affine_weight\n",
+ " x = x + self.affine_bias\n",
+ "\n",
+ " return x\n",
+ "\n",
+ " def _denormalize(self, x):\n",
+ " if self.affine:\n",
+ " x = x - self.affine_bias\n",
+ " x = x / (self.affine_weight + self.eps * self.eps)\n",
+ " x = x * self.stdev\n",
+ " x = x + self.mean\n",
+ "\n",
+ " return x"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 2. Model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| export\n",
+ "\n",
+ "class RMoK(BaseMultivariate):\n",
+ " \"\"\" Reversible Mixture of KAN\n",
+ " **Parameters**
\n",
+ " `h`: int, Forecast horizon.
\n",
+ " `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
\n",
+ " `n_series`: int, number of time-series.
\n",
+ " `futr_exog_list`: str list, future exogenous columns.
\n",
+ " `hist_exog_list`: str list, historic exogenous columns.
\n",
+ " `stat_exog_list`: str list, static exogenous columns.
\n",
+ " `taylor_order`: int, order of the Taylor polynomial.
\n",
+ " `jacobi_degree`: int, degree of the Jacobi polynomial.
\n",
+ " `wavelet_function`: str, wavelet function to use in the WaveKAN. Choose from [\"mexican_hat\", \"morlet\", \"dog\", \"meyer\", \"shannon\"]
\n",
+ " `dropout`: float, dropout rate.
\n",
+ " `revin_affine`: bool=False, bool to use affine in RevIn.
\n",
+ " `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n",
+ " `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n",
+ " `max_steps`: int=1000, maximum number of training steps.
\n",
+ " `learning_rate`: float=1e-3, Learning rate between (0, 1).
\n",
+ " `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
\n",
+ " `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
\n",
+ " `val_check_steps`: int=100, Number of training steps between every validation loss check.
\n",
+ " `batch_size`: int=32, number of different series in each batch.
\n",
+ " `step_size`: int=1, step size between each window of temporal data.
\n",
+ " `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
\n",
+ " `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n",
+ " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n",
+ " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n",
+ " `alias`: str, optional, Custom name of the model.
\n",
+ " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n",
+ " `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
\n",
+ " `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
\n",
+ " `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
\n",
+ " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n",
+ "\n",
+ " Reference
\n",
+ " [Xiao Han, Xinfeng Zhang, Yiling Wu, Zhenduo Zhang, Zhe Wu.\"KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?\"](https://arxiv.org/abs/2408.11306)\n",
+ " \"\"\"\n",
+ "\n",
+ " # Class attributes\n",
+ " SAMPLING_TYPE = 'multivariate'\n",
+ " EXOGENOUS_FUTR = False\n",
+ " EXOGENOUS_HIST = False\n",
+ " EXOGENOUS_STAT = False\n",
+ "\n",
+ " def __init__(self,\n",
+ " h,\n",
+ " input_size,\n",
+ " n_series,\n",
+ " futr_exog_list = None,\n",
+ " hist_exog_list = None,\n",
+ " stat_exog_list = None,\n",
+ " taylor_order: int = 3,\n",
+ " jacobi_degree: int = 6,\n",
+ " wavelet_function: str = 'mexican_hat',\n",
+ " dropout: float = 0.1,\n",
+ " revine_affine: bool = True,\n",
+ " loss = MAE(),\n",
+ " valid_loss = None,\n",
+ " max_steps: int = 1000,\n",
+ " learning_rate: float = 1e-3,\n",
+ " num_lr_decays: int = -1,\n",
+ " early_stop_patience_steps: int =-1,\n",
+ " val_check_steps: int = 100,\n",
+ " batch_size: int = 32,\n",
+ " step_size: int = 1,\n",
+ " scaler_type: str = 'identity',\n",
+ " random_seed: int = 1,\n",
+ " num_workers_loader: int = 0,\n",
+ " drop_last_loader: bool = False,\n",
+ " optimizer = None,\n",
+ " optimizer_kwargs = None,\n",
+ " lr_scheduler = None,\n",
+ " lr_scheduler_kwargs = None, \n",
+ " **trainer_kwargs):\n",
+ " \n",
+ " super(RMoK, self).__init__(h=h,\n",
+ " input_size=input_size,\n",
+ " n_series=n_series,\n",
+ " stat_exog_list = None,\n",
+ " futr_exog_list = None,\n",
+ " hist_exog_list = None,\n",
+ " loss=loss,\n",
+ " valid_loss=valid_loss,\n",
+ " max_steps=max_steps,\n",
+ " learning_rate=learning_rate,\n",
+ " num_lr_decays=num_lr_decays,\n",
+ " early_stop_patience_steps=early_stop_patience_steps,\n",
+ " val_check_steps=val_check_steps,\n",
+ " batch_size=batch_size,\n",
+ " step_size=step_size,\n",
+ " scaler_type=scaler_type,\n",
+ " random_seed=random_seed,\n",
+ " num_workers_loader=num_workers_loader,\n",
+ " drop_last_loader=drop_last_loader,\n",
+ " optimizer=optimizer,\n",
+ " optimizer_kwargs=optimizer_kwargs,\n",
+ " lr_scheduler=lr_scheduler,\n",
+ " lr_scheduler_kwargs=lr_scheduler_kwargs,\n",
+ " **trainer_kwargs)\n",
+ " \n",
+ " self.input_size = input_size\n",
+ " self.h = h\n",
+ " self.n_series = n_series\n",
+ " self.dropout = nn.Dropout(dropout)\n",
+ " self.revin_affine = revine_affine\n",
+ "\n",
+ " self.taylor_order = taylor_order\n",
+ " self.jacobi_degree = jacobi_degree\n",
+ " self.wavelet_function = wavelet_function\n",
+ "\n",
+ " self.experts = nn.ModuleList([\n",
+ " TaylorKANLayer(self.input_size, self.h, order=self.taylor_order, addbias=True),\n",
+ " JacobiKANLayer(self.input_size, self.h, degree=self.jacobi_degree),\n",
+ " WaveKANLayer(self.input_size, self.h, wavelet_type=self.wavelet_function),\n",
+ " nn.Linear(self.input_size, self.h),\n",
+ " ])\n",
+ " \n",
+ " self.num_experts = len(self.experts)\n",
+ " self.gate = nn.Linear(self.input_size, self.num_experts)\n",
+ " self.softmax = nn.Softmax(dim=-1)\n",
+ " self.rev = RevIN(self.n_series, affine=self.revin_affine)\n",
+ "\n",
+ " def forward(self, windows_batch):\n",
+ " insample_y = windows_batch['insample_y']\n",
+ " B, L, N = insample_y.shape\n",
+ " x = self.rev(insample_y, 'norm') if self.rev else insample_y\n",
+ " x = self.dropout(x).transpose(1, 2).reshape(B * N, L)\n",
+ "\n",
+ " score = F.softmax(self.gate(x), dim=-1)\n",
+ " expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=-1)\n",
+ "\n",
+ " y_pred = torch.einsum(\"BLE,BE->BL\", expert_outputs, score).reshape(B, N, -1).permute(0, 2, 1)\n",
+ " y_pred = self.rev(y_pred, 'denorm')\n",
+ " y_pred = self.loss.domain_map(y_pred)\n",
+ "\n",
+ " # domain_map might have squeezed the last dimension in case n_series == 1\n",
+ " if y_pred.ndim == 2:\n",
+ " return y_pred.unsqueeze(-1)\n",
+ " else:\n",
+ " return y_pred"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_doc(RMoK)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_doc(RMoK.fit, name='RMoK.fit')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "show_doc(RMoK.predict, name='RMoK.predict')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 3. Usage example"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "#| eval: false\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "\n",
+ "from neuralforecast import NeuralForecast\n",
+ "from neuralforecast.models import RMoK\n",
+ "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n",
+ "from neuralforecast.losses.pytorch import MSE\n",
+ "\n",
+ "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n",
+ "\n",
+ "model = RMoK(h=12,\n",
+ " input_size=24,\n",
+ " n_series=2,\n",
+ " taylor_order=3,\n",
+ " jacobi_degree=6,\n",
+ " wavelet_function='mexican_hat',\n",
+ " dropout=0.1,\n",
+ " revine_affine=True,\n",
+ " loss=MSE(),\n",
+ " valid_loss=MAE(),\n",
+ " early_stop_patience_steps=3,\n",
+ " batch_size=32)\n",
+ "\n",
+ "fcst = NeuralForecast(models=[model], freq='M')\n",
+ "fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n",
+ "forecasts = fcst.predict(futr_df=Y_test_df)\n",
+ "\n",
+ "# Plot predictions\n",
+ "fig, ax = plt.subplots(1, 1, figsize = (20, 7))\n",
+ "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n",
+ "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n",
+ "plot_df = pd.concat([Y_train_df, plot_df])\n",
+ "\n",
+ "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n",
+ "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n",
+ "plt.plot(plot_df['ds'], plot_df['RMoK'], c='blue', label='Forecast')\n",
+ "ax.set_title('AirPassengers Forecast', fontsize=22)\n",
+ "ax.set_ylabel('Monthly Passengers', fontsize=20)\n",
+ "ax.set_xlabel('Year', fontsize=20)\n",
+ "ax.legend(prop={'size': 15})\n",
+ "ax.grid()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "python3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/neuralforecast/_modidx.py b/neuralforecast/_modidx.py
index a58491846..add079b58 100644
--- a/neuralforecast/_modidx.py
+++ b/neuralforecast/_modidx.py
@@ -96,6 +96,10 @@
'neuralforecast/auto.py'),
'neuralforecast.auto.AutoPatchTST.get_default_config': ( 'models.html#autopatchtst.get_default_config',
'neuralforecast/auto.py'),
+ 'neuralforecast.auto.AutoRMoK': ('models.html#autormok', 'neuralforecast/auto.py'),
+ 'neuralforecast.auto.AutoRMoK.__init__': ('models.html#autormok.__init__', 'neuralforecast/auto.py'),
+ 'neuralforecast.auto.AutoRMoK.get_default_config': ( 'models.html#autormok.get_default_config',
+ 'neuralforecast/auto.py'),
'neuralforecast.auto.AutoRNN': ('models.html#autornn', 'neuralforecast/auto.py'),
'neuralforecast.auto.AutoRNN.__init__': ('models.html#autornn.__init__', 'neuralforecast/auto.py'),
'neuralforecast.auto.AutoRNN.get_default_config': ( 'models.html#autornn.get_default_config',
@@ -1037,6 +1041,44 @@
'neuralforecast/models/patchtst.py'),
'neuralforecast.models.patchtst.positional_encoding': ( 'models.patchtst.html#positional_encoding',
'neuralforecast/models/patchtst.py')},
+ 'neuralforecast.models.rmok': { 'neuralforecast.models.rmok.JacobiKANLayer': ( 'models.rmok.html#jacobikanlayer',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.JacobiKANLayer.__init__': ( 'models.rmok.html#jacobikanlayer.__init__',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.JacobiKANLayer.forward': ( 'models.rmok.html#jacobikanlayer.forward',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RMoK': ('models.rmok.html#rmok', 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RMoK.__init__': ( 'models.rmok.html#rmok.__init__',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RMoK.forward': ( 'models.rmok.html#rmok.forward',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN': ('models.rmok.html#revin', 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN.__init__': ( 'models.rmok.html#revin.__init__',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN._denormalize': ( 'models.rmok.html#revin._denormalize',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN._get_statistics': ( 'models.rmok.html#revin._get_statistics',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN._init_params': ( 'models.rmok.html#revin._init_params',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN._normalize': ( 'models.rmok.html#revin._normalize',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.RevIN.forward': ( 'models.rmok.html#revin.forward',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.TaylorKANLayer': ( 'models.rmok.html#taylorkanlayer',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.TaylorKANLayer.__init__': ( 'models.rmok.html#taylorkanlayer.__init__',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.TaylorKANLayer.forward': ( 'models.rmok.html#taylorkanlayer.forward',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.WaveKANLayer': ( 'models.rmok.html#wavekanlayer',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.WaveKANLayer.__init__': ( 'models.rmok.html#wavekanlayer.__init__',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.WaveKANLayer.forward': ( 'models.rmok.html#wavekanlayer.forward',
+ 'neuralforecast/models/rmok.py'),
+ 'neuralforecast.models.rmok.WaveKANLayer.wavelet_transform': ( 'models.rmok.html#wavekanlayer.wavelet_transform',
+ 'neuralforecast/models/rmok.py')},
'neuralforecast.models.rnn': { 'neuralforecast.models.rnn.RNN': ('models.rnn.html#rnn', 'neuralforecast/models/rnn.py'),
'neuralforecast.models.rnn.RNN.__init__': ( 'models.rnn.html#rnn.__init__',
'neuralforecast/models/rnn.py'),
diff --git a/neuralforecast/auto.py b/neuralforecast/auto.py
index c69b406b1..b3c85892a 100644
--- a/neuralforecast/auto.py
+++ b/neuralforecast/auto.py
@@ -5,7 +5,7 @@
'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 'AutoDeepNPTS', 'AutoKAN', 'AutoTFT',
'AutoVanillaTransformer', 'AutoInformer', 'AutoAutoformer', 'AutoFEDformer', 'AutoPatchTST',
'AutoiTransformer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer', 'AutoTSMixerx',
- 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer']
+ 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK']
# %% ../nbs/models.ipynb 2
from os import cpu_count
@@ -44,6 +44,7 @@
from .models.itransformer import iTransformer
from .models.kan import KAN
+from .models.rmok import RMoK
from .models.stemgnn import StemGNN
from .models.hint import HINT
@@ -1108,7 +1109,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 74
+# %% ../nbs/models.ipynb 75
class AutoKAN(BaseAuto):
default_config = {
@@ -1177,7 +1178,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 79
+# %% ../nbs/models.ipynb 80
class AutoTFT(BaseAuto):
default_config = {
@@ -1245,7 +1246,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 83
+# %% ../nbs/models.ipynb 84
class AutoVanillaTransformer(BaseAuto):
default_config = {
@@ -1313,7 +1314,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 87
+# %% ../nbs/models.ipynb 88
class AutoInformer(BaseAuto):
default_config = {
@@ -1381,7 +1382,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 91
+# %% ../nbs/models.ipynb 92
class AutoAutoformer(BaseAuto):
default_config = {
@@ -1449,7 +1450,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 95
+# %% ../nbs/models.ipynb 96
class AutoFEDformer(BaseAuto):
default_config = {
@@ -1516,7 +1517,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 99
+# %% ../nbs/models.ipynb 100
class AutoPatchTST(BaseAuto):
default_config = {
@@ -1586,7 +1587,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 103
+# %% ../nbs/models.ipynb 104
class AutoiTransformer(BaseAuto):
default_config = {
@@ -1671,7 +1672,7 @@ def get_default_config(cls, h, backend, n_series):
return config
-# %% ../nbs/models.ipynb 108
+# %% ../nbs/models.ipynb 109
class AutoTimesNet(BaseAuto):
default_config = {
@@ -1739,7 +1740,7 @@ def get_default_config(cls, h, backend, n_series=None):
return config
-# %% ../nbs/models.ipynb 113
+# %% ../nbs/models.ipynb 114
class AutoStemGNN(BaseAuto):
default_config = {
@@ -1824,7 +1825,7 @@ def get_default_config(cls, h, backend, n_series):
return config
-# %% ../nbs/models.ipynb 117
+# %% ../nbs/models.ipynb 118
class AutoHINT(BaseAuto):
def __init__(
@@ -1896,7 +1897,7 @@ def _fit_model(
def get_default_config(cls, h, backend, n_series=None):
raise Exception("AutoHINT has no default configuration.")
-# %% ../nbs/models.ipynb 122
+# %% ../nbs/models.ipynb 123
class AutoTSMixer(BaseAuto):
default_config = {
@@ -1982,7 +1983,7 @@ def get_default_config(cls, h, backend, n_series):
return config
-# %% ../nbs/models.ipynb 126
+# %% ../nbs/models.ipynb 127
class AutoTSMixerx(BaseAuto):
default_config = {
@@ -2068,7 +2069,7 @@ def get_default_config(cls, h, backend, n_series):
return config
-# %% ../nbs/models.ipynb 130
+# %% ../nbs/models.ipynb 131
class AutoMLPMultivariate(BaseAuto):
default_config = {
@@ -2153,7 +2154,7 @@ def get_default_config(cls, h, backend, n_series):
return config
-# %% ../nbs/models.ipynb 134
+# %% ../nbs/models.ipynb 135
class AutoSOFTS(BaseAuto):
default_config = {
@@ -2238,7 +2239,7 @@ def get_default_config(cls, h, backend, n_series):
return config
-# %% ../nbs/models.ipynb 138
+# %% ../nbs/models.ipynb 139
class AutoTimeMixer(BaseAuto):
default_config = {
@@ -2323,3 +2324,91 @@ def get_default_config(cls, h, backend, n_series):
config = cls._ray_config_to_optuna(config)
return config
+
+# %% ../nbs/models.ipynb 143
+class AutoRMoK(BaseAuto):
+
+ default_config = {
+ "input_size_multiplier": [1, 2, 3, 4, 5],
+ "h": None,
+ "n_series": None,
+ "taylor_order": tune.choice([3, 4, 5]),
+ "jacobi_degree": tune.choice([4, 5, 6]),
+ "wavelet_function": tune.choice(
+ ["mexican_hat", "morlet", "dog", "meyer", "shannon"]
+ ),
+ "learning_rate": tune.loguniform(1e-4, 1e-1),
+ "scaler_type": tune.choice([None, "robust", "standard", "identity"]),
+ "max_steps": tune.choice([500, 1000, 2000]),
+ "batch_size": tune.choice([32, 64, 128, 256]),
+ "loss": None,
+ "random_seed": tune.randint(1, 20),
+ }
+
+ def __init__(
+ self,
+ h,
+ n_series,
+ loss=MAE(),
+ valid_loss=None,
+ config=None,
+ search_alg=BasicVariantGenerator(random_state=1),
+ num_samples=10,
+ refit_with_val=False,
+ cpus=cpu_count(),
+ gpus=torch.cuda.device_count(),
+ verbose=False,
+ alias=None,
+ backend="ray",
+ callbacks=None,
+ ):
+
+ # Define search space, input/output sizes
+ if config is None:
+ config = self.get_default_config(h=h, backend=backend, n_series=n_series)
+
+ # Always use n_series from parameters, raise exception with Optuna because we can't enforce it
+ if backend == "ray":
+ config["n_series"] = n_series
+ elif backend == "optuna":
+ mock_trial = MockTrial()
+ if (
+ "n_series" in config(mock_trial)
+ and config(mock_trial)["n_series"] != n_series
+ ) or ("n_series" not in config(mock_trial)):
+ raise Exception(f"config needs 'n_series': {n_series}")
+
+ super(AutoRMoK, self).__init__(
+ cls_model=RMoK,
+ h=h,
+ loss=loss,
+ valid_loss=valid_loss,
+ config=config,
+ search_alg=search_alg,
+ num_samples=num_samples,
+ refit_with_val=refit_with_val,
+ cpus=cpus,
+ gpus=gpus,
+ verbose=verbose,
+ alias=alias,
+ backend=backend,
+ callbacks=callbacks,
+ )
+
+ @classmethod
+ def get_default_config(cls, h, backend, n_series):
+ config = cls.default_config.copy()
+ config["input_size"] = tune.choice(
+ [h * x for x in config["input_size_multiplier"]]
+ )
+
+ # Rolling windows with step_size=1 or step_size=h
+ # See `BaseWindows` and `BaseRNN`'s create_windows
+ config["step_size"] = tune.choice([1, h])
+ del config["input_size_multiplier"]
+ if backend == "optuna":
+ # Always use n_series from parameters
+ config["n_series"] = n_series
+ config = cls._ray_config_to_optuna(config)
+
+ return config
diff --git a/neuralforecast/models/__init__.py b/neuralforecast/models/__init__.py
index ee404e31c..414689631 100644
--- a/neuralforecast/models/__init__.py
+++ b/neuralforecast/models/__init__.py
@@ -2,7 +2,7 @@
'MLP', 'NHITS', 'NBEATS', 'NBEATSx', 'DLinear', 'NLinear',
'TFT', 'VanillaTransformer', 'Informer', 'Autoformer', 'PatchTST', 'FEDformer',
'StemGNN', 'HINT', 'TimesNet', 'TimeLLM', 'TSMixer', 'TSMixerx', 'MLPMultivariate',
- 'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN'
+ 'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN', 'RMoK',
]
from .rnn import RNN
@@ -37,3 +37,4 @@
from .softs import SOFTS
from .timemixer import TimeMixer
from .kan import KAN
+from .rmok import RMoK
diff --git a/neuralforecast/models/rmok.py b/neuralforecast/models/rmok.py
new file mode 100644
index 000000000..c83585e1b
--- /dev/null
+++ b/neuralforecast/models/rmok.py
@@ -0,0 +1,473 @@
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.rmok.ipynb.
+
+# %% auto 0
+__all__ = ['WaveKANLayer', 'TaylorKANLayer', 'JacobiKANLayer', 'RevIN', 'RMoK']
+
+# %% ../../nbs/models.rmok.ipynb 6
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..losses.pytorch import MAE
+from ..common._base_multivariate import BaseMultivariate
+
+# %% ../../nbs/models.rmok.ipynb 8
+class WaveKANLayer(nn.Module):
+ """This is a sample code for the simulations of the paper:
+ Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024)
+
+ https://arxiv.org/abs/2405.12832
+ and also available at:
+ https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325
+    We used the efficient KAN notation and some parts of its code.
+
+ """
+
+ def __init__(
+ self,
+ in_features,
+ out_features,
+ wavelet_type="mexican_hat",
+ with_bn=True,
+ device="cpu",
+ ):
+ super(WaveKANLayer, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.wavelet_type = wavelet_type
+ self.with_bn = with_bn
+
+ # Parameters for wavelet transformation
+ self.scale = nn.Parameter(torch.ones(out_features, in_features))
+ self.translation = nn.Parameter(torch.zeros(out_features, in_features))
+
+ # self.weight1 is not used; you may use it for weighting base activation and adding it like Spl-KAN paper
+ self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features))
+ self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))
+
+ nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))
+ nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
+
+ # Base activation function #not used for this experiment
+ self.base_activation = nn.SiLU()
+
+ # Batch normalization
+ if self.with_bn:
+ self.bn = nn.BatchNorm1d(out_features)
+
+ def wavelet_transform(self, x):
+ if x.dim() == 2:
+ x_expanded = x.unsqueeze(1)
+ else:
+ x_expanded = x
+
+ translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)
+ scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)
+ x_scaled = (x_expanded - translation_expanded) / scale_expanded
+
+ # Implementation of different wavelet types
+ if self.wavelet_type == "mexican_hat":
+ term1 = (x_scaled**2) - 1
+ term2 = torch.exp(-0.5 * x_scaled**2)
+ wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2
+ wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(
+ wavelet
+ )
+ wavelet_output = wavelet_weighted.sum(dim=2)
+ elif self.wavelet_type == "morlet":
+ omega0 = 5.0 # Central frequency
+ real = torch.cos(omega0 * x_scaled)
+ envelope = torch.exp(-0.5 * x_scaled**2)
+ wavelet = envelope * real
+ wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(
+ wavelet
+ )
+ wavelet_output = wavelet_weighted.sum(dim=2)
+
+ elif self.wavelet_type == "dog":
+ # Implementing Derivative of Gaussian Wavelet
+ dog = -x_scaled * torch.exp(-0.5 * x_scaled**2)
+ wavelet = dog
+ wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(
+ wavelet
+ )
+ wavelet_output = wavelet_weighted.sum(dim=2)
+ elif self.wavelet_type == "meyer":
+ # Implement Meyer Wavelet here
+ # Constants for the Meyer wavelet transition boundaries
+ v = torch.abs(x_scaled)
+ pi = math.pi
+
+ def meyer_aux(v):
+ return torch.where(
+ v <= 1 / 2,
+ torch.ones_like(v),
+ torch.where(
+ v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * nu(2 * v - 1))
+ ),
+ )
+
+ def nu(t):
+ return t**4 * (35 - 84 * t + 70 * t**2 - 20 * t**3)
+
+ # Meyer wavelet calculation using the auxiliary function
+ wavelet = torch.sin(pi * v) * meyer_aux(v)
+ wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(
+ wavelet
+ )
+ wavelet_output = wavelet_weighted.sum(dim=2)
+ elif self.wavelet_type == "shannon":
+ # Windowing the sinc function to limit its support
+ pi = math.pi
+ sinc = torch.sinc(x_scaled / pi) # sinc(x) = sin(pi*x) / (pi*x)
+
+ # Applying a Hamming window to limit the infinite support of the sinc function
+ window = torch.hamming_window(
+ x_scaled.size(-1),
+ periodic=False,
+ dtype=x_scaled.dtype,
+ device=x_scaled.device,
+ )
+ # Shannon wavelet is the product of the sinc function and the window
+ wavelet = sinc * window
+ wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(
+ wavelet
+ )
+ wavelet_output = wavelet_weighted.sum(dim=2)
+ # You can try many more wavelet types ...
+ else:
+ raise ValueError("Unsupported wavelet type")
+
+ return wavelet_output
+
+ def forward(self, x):
+ wavelet_output = self.wavelet_transform(x)
+        # You may want to test cases like Spl-KAN
+ # wav_output = F.linear(wavelet_output, self.weight)
+ # base_output = F.linear(self.base_activation(x), self.weight1)
+
+ # base_output = F.linear(x, self.weight1)
+ combined_output = wavelet_output # + base_output
+
+ # Apply batch normalization
+ if self.with_bn:
+ return self.bn(combined_output)
+ else:
+ return combined_output
+
+# %% ../../nbs/models.rmok.ipynb 10
+class TaylorKANLayer(nn.Module):
+ """
+ https://github.com/Muyuzhierchengse/TaylorKAN/
+ """
+
+ def __init__(self, input_dim, out_dim, order, addbias=True):
+ super(TaylorKANLayer, self).__init__()
+ self.input_dim = input_dim
+ self.out_dim = out_dim
+ self.order = order
+ self.addbias = addbias
+
+ self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order) * 0.01)
+ if self.addbias:
+ self.bias = nn.Parameter(torch.zeros(1, out_dim))
+
+ def forward(self, x):
+ shape = x.shape
+ outshape = shape[0:-1] + (self.out_dim,)
+ x = torch.reshape(x, (-1, self.input_dim))
+ x_expanded = x.unsqueeze(1).expand(-1, self.out_dim, -1)
+
+ y = torch.zeros((x.shape[0], self.out_dim), device=x.device)
+
+ for i in range(self.order):
+ term = (x_expanded**i) * self.coeffs[:, :, i]
+ y += term.sum(dim=-1)
+
+ if self.addbias:
+ y += self.bias
+
+ y = torch.reshape(y, outshape)
+ return y
+
+# %% ../../nbs/models.rmok.ipynb 12
+class JacobiKANLayer(nn.Module):
+ """
+ https://github.com/SpaceLearner/JacobiKAN/blob/main/JacobiKANLayer.py
+ """
+
+ def __init__(self, input_dim, output_dim, degree, a=1.0, b=1.0):
+ super(JacobiKANLayer, self).__init__()
+ self.inputdim = input_dim
+ self.outdim = output_dim
+ self.a = a
+ self.b = b
+ self.degree = degree
+
+ self.jacobi_coeffs = nn.Parameter(
+ torch.empty(input_dim, output_dim, degree + 1)
+ )
+
+ nn.init.normal_(
+ self.jacobi_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1))
+ )
+
+ def forward(self, x):
+ x = torch.reshape(x, (-1, self.inputdim)) # shape = (batch_size, inputdim)
+        # Since the Jacobi polynomials are defined on [-1, 1]
+ # We need to normalize x to [-1, 1] using tanh
+ x = torch.tanh(x)
+ # Initialize Jacobian polynomial tensors
+ jacobi = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device)
+ if (
+ self.degree > 0
+ ): ## degree = 0: jacobi[:, :, 0] = 1 (already initialized) ; degree = 1: jacobi[:, :, 1] = x ; d
+ jacobi[:, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) * x) / 2
+ for i in range(2, self.degree + 1):
+ theta_k = (
+ (2 * i + self.a + self.b)
+ * (2 * i + self.a + self.b - 1)
+ / (2 * i * (i + self.a + self.b))
+ )
+ theta_k1 = (
+ (2 * i + self.a + self.b - 1)
+ * (self.a * self.a - self.b * self.b)
+ / (2 * i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))
+ )
+ theta_k2 = (
+ (i + self.a - 1)
+ * (i + self.b - 1)
+ * (2 * i + self.a + self.b)
+ / (i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))
+ )
+ jacobi[:, :, i] = (theta_k * x + theta_k1) * jacobi[
+ :, :, i - 1
+ ].clone() - theta_k2 * jacobi[
+ :, :, i - 2
+ ].clone() # 2 * x * jacobi[:, :, i - 1].clone() - jacobi[:, :, i - 2].clone()
+        # Compute the Jacobi interpolation
+ y = torch.einsum(
+ "bid,iod->bo", jacobi, self.jacobi_coeffs
+ ) # shape = (batch_size, outdim)
+ y = y.view(-1, self.outdim)
+ return y
+
+# %% ../../nbs/models.rmok.ipynb 14
+class RevIN(nn.Module):
+ def __init__(self, num_features: int, eps=1e-5, affine=True):
+ """
+ :param num_features: the number of features or channels
+ :param eps: a value added for numerical stability
+ :param affine: if True, RevIN has learnable affine parameters
+ """
+ super(RevIN, self).__init__()
+
+ self.num_features = num_features
+ self.eps = eps
+ self.affine = affine
+
+ if self.affine:
+ self._init_params()
+
+ def forward(self, x, mode: str):
+ if mode == "norm":
+ self._get_statistics(x)
+ x = self._normalize(x)
+
+ elif mode == "denorm":
+ x = self._denormalize(x)
+
+ else:
+ raise NotImplementedError
+
+ return x
+
+ def _init_params(self):
+ # initialize RevIN params: (C,)
+ self.affine_weight = nn.Parameter(torch.ones(self.num_features))
+ self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
+
+ def _get_statistics(self, x):
+ dim2reduce = tuple(range(1, x.ndim - 1))
+ self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
+ self.stdev = torch.sqrt(
+ torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps
+ ).detach()
+
+ def _normalize(self, x):
+ x = x - self.mean
+ x = x / self.stdev
+ if self.affine:
+ x = x * self.affine_weight
+ x = x + self.affine_bias
+
+ return x
+
+ def _denormalize(self, x):
+ if self.affine:
+ x = x - self.affine_bias
+ x = x / (self.affine_weight + self.eps * self.eps)
+ x = x * self.stdev
+ x = x + self.mean
+
+ return x
+
+# %% ../../nbs/models.rmok.ipynb 16
+class RMoK(BaseMultivariate):
+ """Reversible Mixture of KAN
+    **Parameters**<br>
+    `h`: int, Forecast horizon.<br>
+    `input_size`: int, autoregressive input size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>
+    `n_series`: int, number of time-series.<br>
+    `futr_exog_list`: str list, future exogenous columns.<br>
+    `hist_exog_list`: str list, historic exogenous columns.<br>
+    `stat_exog_list`: str list, static exogenous columns.<br>
+    `taylor_order`: int, order of the Taylor polynomial.<br>
+    `jacobi_degree`: int, degree of the Jacobi polynomial.<br>
+    `wavelet_function`: str, wavelet function to use in the WaveKAN. Choose from ["mexican_hat", "morlet", "dog", "meyer", "shannon"].<br>
+    `dropout`: float, dropout rate.<br>
+    `revine_affine`: bool=True, whether to use a learnable affine transform in RevIN.<br>
+    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
+    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).<br>
+    `max_steps`: int=1000, maximum number of training steps.<br>
+    `learning_rate`: float=1e-3, Learning rate between (0, 1).<br>
+    `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.<br>
+    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.<br>
+    `val_check_steps`: int=100, Number of training steps between every validation loss check.<br>
+    `batch_size`: int=32, number of different series in each batch.<br>
+    `step_size`: int=1, step size between each window of temporal data.<br>
+    `scaler_type`: str='identity', type of scaler for temporal inputs normalization, see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).<br>
+    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.<br>
+    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>
+    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.<br>
+    `alias`: str, optional, Custom name of the model.<br>
+    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).<br>
+    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.<br>
+    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).<br>
+    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.<br>
+    `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>
+
+    Reference<br>
+    [Xiao Han, Xinfeng Zhang, Yiling Wu, Zhenduo Zhang, Zhe Wu. "KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?"](https://arxiv.org/abs/2408.11306)
+ """
+
+ # Class attributes
+ SAMPLING_TYPE = "multivariate"
+ EXOGENOUS_FUTR = False
+ EXOGENOUS_HIST = False
+ EXOGENOUS_STAT = False
+
+ def __init__(
+ self,
+ h,
+ input_size,
+ n_series,
+ futr_exog_list=None,
+ hist_exog_list=None,
+ stat_exog_list=None,
+ taylor_order: int = 3,
+ jacobi_degree: int = 6,
+ wavelet_function: str = "mexican_hat",
+ dropout: float = 0.1,
+ revine_affine: bool = True,
+ loss=MAE(),
+ valid_loss=None,
+ max_steps: int = 1000,
+ learning_rate: float = 1e-3,
+ num_lr_decays: int = -1,
+ early_stop_patience_steps: int = -1,
+ val_check_steps: int = 100,
+ batch_size: int = 32,
+ step_size: int = 1,
+ scaler_type: str = "identity",
+ random_seed: int = 1,
+ num_workers_loader: int = 0,
+ drop_last_loader: bool = False,
+ optimizer=None,
+ optimizer_kwargs=None,
+ lr_scheduler=None,
+ lr_scheduler_kwargs=None,
+ **trainer_kwargs
+ ):
+
+ super(RMoK, self).__init__(
+ h=h,
+ input_size=input_size,
+ n_series=n_series,
+ stat_exog_list=None,
+ futr_exog_list=None,
+ hist_exog_list=None,
+ loss=loss,
+ valid_loss=valid_loss,
+ max_steps=max_steps,
+ learning_rate=learning_rate,
+ num_lr_decays=num_lr_decays,
+ early_stop_patience_steps=early_stop_patience_steps,
+ val_check_steps=val_check_steps,
+ batch_size=batch_size,
+ step_size=step_size,
+ scaler_type=scaler_type,
+ random_seed=random_seed,
+ num_workers_loader=num_workers_loader,
+ drop_last_loader=drop_last_loader,
+ optimizer=optimizer,
+ optimizer_kwargs=optimizer_kwargs,
+ lr_scheduler=lr_scheduler,
+ lr_scheduler_kwargs=lr_scheduler_kwargs,
+ **trainer_kwargs
+ )
+
+ self.input_size = input_size
+ self.h = h
+ self.n_series = n_series
+ self.dropout = nn.Dropout(dropout)
+ self.revin_affine = revine_affine
+
+ self.taylor_order = taylor_order
+ self.jacobi_degree = jacobi_degree
+ self.wavelet_function = wavelet_function
+
+ self.experts = nn.ModuleList(
+ [
+ TaylorKANLayer(
+ self.input_size, self.h, order=self.taylor_order, addbias=True
+ ),
+ JacobiKANLayer(self.input_size, self.h, degree=self.jacobi_degree),
+ WaveKANLayer(
+ self.input_size, self.h, wavelet_type=self.wavelet_function
+ ),
+ nn.Linear(self.input_size, self.h),
+ ]
+ )
+
+ self.num_experts = len(self.experts)
+ self.gate = nn.Linear(self.input_size, self.num_experts)
+ self.softmax = nn.Softmax(dim=-1)
+ self.rev = RevIN(self.n_series, affine=self.revin_affine)
+
+ def forward(self, windows_batch):
+ insample_y = windows_batch["insample_y"]
+ B, L, N = insample_y.shape
+ x = self.rev(insample_y, "norm") if self.rev else insample_y
+ x = self.dropout(x).transpose(1, 2).reshape(B * N, L)
+
+ score = F.softmax(self.gate(x), dim=-1)
+ expert_outputs = torch.stack(
+ [self.experts[i](x) for i in range(self.num_experts)], dim=-1
+ )
+
+ y_pred = (
+ torch.einsum("BLE,BE->BL", expert_outputs, score)
+ .reshape(B, N, -1)
+ .permute(0, 2, 1)
+ )
+ y_pred = self.rev(y_pred, "denorm")
+ y_pred = self.loss.domain_map(y_pred)
+
+ # domain_map might have squeezed the last dimension in case n_series == 1
+ if y_pred.ndim == 2:
+ return y_pred.unsqueeze(-1)
+ else:
+ return y_pred