diff --git a/nbs/docs/capabilities/01_overview.ipynb b/nbs/docs/capabilities/01_overview.ipynb index 12fe0f29e..11b964a7f 100644 --- a/nbs/docs/capabilities/01_overview.ipynb +++ b/nbs/docs/capabilities/01_overview.ipynb @@ -34,6 +34,7 @@ "|`NHITS` | `AutoNHITS` | MLP | Univariate | Direct | F/H/S | \n", "|`NLinear` | `AutoNLinear` | MLP | Univariate | Direct | - | \n", "|`PatchTST` | `AutoPatchTST` | Transformer | Univariate | Direct | - | \n", + "|`RMoK` | `AutoRMoK` | KAN | Multivariate | Direct | - |\n", "|`RNN` | `AutoRNN` | RNN | Univariate | Recursive | F/H/S | \n", "|`SOFTS` | `AutoSOFTS` | MLP | Multivariate | Direct | - | \n", "|`StemGNN` | `AutoStemGNN` | GNN | Multivariate | Direct | - | \n", diff --git a/nbs/imgs_models/rmok.png b/nbs/imgs_models/rmok.png new file mode 100644 index 000000000..11c7cf4db Binary files /dev/null and b/nbs/imgs_models/rmok.png differ diff --git a/nbs/models.ipynb b/nbs/models.ipynb index f693134cb..018525399 100644 --- a/nbs/models.ipynb +++ b/nbs/models.ipynb @@ -66,6 +66,7 @@ "from neuralforecast.models.itransformer import iTransformer\n", "\n", "from neuralforecast.models.kan import KAN\n", + "from neuralforecast.models.rmok import RMoK\n", "\n", "from neuralforecast.models.stemgnn import StemGNN\n", "from neuralforecast.models.hint import HINT\n", @@ -2300,6 +2301,14 @@ "model.fit(dataset=dataset)" ] }, + { + "cell_type": "markdown", + "id": "8efc1692", + "metadata": {}, + "source": [ + "## C. KAN-Based" + ] + }, { "cell_type": "code", "execution_count": null, @@ -2444,7 +2453,7 @@ "id": "fd705a56", "metadata": {}, "source": [ - "## C. Transformer-Based" + "## D. Transformer-Based" ] }, { @@ -3427,7 +3436,7 @@ "id": "57d6cb1f", "metadata": {}, "source": [ - "## D. CNN Based" + "## E. CNN Based" ] }, { @@ -3573,7 +3582,7 @@ "id": "e6fd22c7", "metadata": {}, "source": [ - "## E. Multivariate" + "## F. 
Multivariate" ] }, { @@ -4680,6 +4689,158 @@ "model.fit(dataset=dataset)" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "ab15c4b6", + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "class AutoRMoK(BaseAuto):\n", + "\n", + " default_config = {\n", + " \"input_size_multiplier\": [1, 2, 3, 4, 5],\n", + " \"h\": None,\n", + " \"n_series\": None,\n", + " \"taylor_order\": tune.choice([3, 4, 5]),\n", + " \"jacobi_degree\": tune.choice([4, 5, 6]),\n", + " \"wavelet_function\": tune.choice(['mexican_hat', 'morlet', 'dog', 'meyer', 'shannon']),\n", + " \"learning_rate\": tune.loguniform(1e-4, 1e-1),\n", + " \"scaler_type\": tune.choice([None, 'robust', 'standard', 'identity']),\n", + " \"max_steps\": tune.choice([500, 1000, 2000]),\n", + " \"batch_size\": tune.choice([32, 64, 128, 256]),\n", + " \"loss\": None,\n", + " \"random_seed\": tune.randint(1, 20),\n", + " }\n", + "\n", + " def __init__(self,\n", + " h,\n", + " n_series,\n", + " loss=MAE(),\n", + " valid_loss=None,\n", + " config=None, \n", + " search_alg=BasicVariantGenerator(random_state=1),\n", + " num_samples=10,\n", + " refit_with_val=False,\n", + " cpus=cpu_count(),\n", + " gpus=torch.cuda.device_count(),\n", + " verbose=False,\n", + " alias=None,\n", + " backend='ray',\n", + " callbacks=None):\n", + " \n", + " # Define search space, input/output sizes\n", + " if config is None:\n", + " config = self.get_default_config(h=h, backend=backend, n_series=n_series) \n", + "\n", + " # Always use n_series from parameters, raise exception with Optuna because we can't enforce it\n", + " if backend == 'ray':\n", + " config['n_series'] = n_series\n", + " elif backend == 'optuna':\n", + " mock_trial = MockTrial()\n", + " if ('n_series' in config(mock_trial) and config(mock_trial)['n_series'] != n_series) or ('n_series' not in config(mock_trial)):\n", + " raise Exception(f\"config needs 'n_series': {n_series}\") \n", + "\n", + " super(AutoRMoK, self).__init__(\n", + " cls_model=RMoK, \n", + " h=h,\n", + " loss=loss,\n", + " valid_loss=valid_loss,\n", + " config=config,\n", + " search_alg=search_alg,\n", + " num_samples=num_samples, \n", + " refit_with_val=refit_with_val,\n", + " cpus=cpus,\n", + " gpus=gpus,\n", + " verbose=verbose,\n", + " alias=alias,\n", + " backend=backend,\n", + " callbacks=callbacks, \n", + " )\n", + "\n", + " @classmethod\n", + " def get_default_config(cls, h, backend, n_series):\n", + " config = cls.default_config.copy() \n", + " config['input_size'] = tune.choice([h * x \\\n", + " for x in config[\"input_size_multiplier\"]])\n", + "\n", + " # Rolling windows with step_size=1 or step_size=h\n", + " # See `BaseWindows` and `BaseRNN`'s create_windows\n", + " config['step_size'] = tune.choice([1, h])\n", + " del config[\"input_size_multiplier\"]\n", + " if backend == 'optuna':\n", + " # Always use n_series from parameters\n", + " config['n_series'] = n_series\n", + " config = cls._ray_config_to_optuna(config) \n", + "\n", + " return config " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "111d8d3b", + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(AutoRMoK, title_level=3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2073d4aa", + "metadata": {}, + "outputs": [], + "source": [ + "%%capture\n", + "# Use your own config or AutoRMoK.default_config\n", + "config = dict(max_steps=1, val_check_steps=1, input_size=12, learning_rate=1e-2)\n", + "model = AutoRMoK(h=12, n_series=1, config=config, num_samples=1, cpus=1)\n", + "\n", + "# 
Fit and predict\n", + "model.fit(dataset=dataset)\n", + "y_hat = model.predict(dataset=dataset)\n", + "\n", + "# Optuna\n", + "model = AutoRMoK(h=12, n_series=1, config=None, backend='optuna')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ebe2c500", + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# Check Optuna\n", + "assert model.config(MockTrial())['h'] == 12\n", + "\n", + "# Unit test to test that Auto* model contains all required arguments from BaseAuto\n", + "test_args(AutoRMoK, exclude_args=['cls_model']) \n", + "\n", + "# Unit test for situation: Optuna with updated default config\n", + "my_config = AutoRMoK.get_default_config(h=12, n_series=1, backend='optuna')\n", + "def my_config_new(trial):\n", + " config = {**my_config(trial)}\n", + " config.update({'max_steps': 1, 'val_check_steps': 1, 'input_size': 12, 'learning_rate': 1e-1})\n", + " return config\n", + "\n", + "model = AutoRMoK(h=12, n_series=1, config=my_config_new, backend='optuna', num_samples=1, cpus=1)\n", + "model.fit(dataset=dataset)\n", + "\n", + "# Unit test for situation: Ray with updated default config\n", + "my_config = AutoRMoK.get_default_config(h=12, n_series=1, backend='ray')\n", + "my_config['max_steps'] = 1\n", + "my_config['val_check_steps'] = 1\n", + "my_config['input_size'] = 12\n", + "my_config['learning_rate'] = 1e-1\n", + "model = AutoRMoK(h=12, n_series=1, config=my_config, backend='ray', num_samples=1, cpus=1)\n", + "model.fit(dataset=dataset)" + ] + }, { "attachments": {}, "cell_type": "markdown", diff --git a/nbs/models.rmok.ipynb b/nbs/models.rmok.ipynb new file mode 100644 index 000000000..08fee40c0 --- /dev/null +++ b/nbs/models.rmok.ipynb @@ -0,0 +1,650 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| default_exp models.rmok" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "%load_ext autoreload\n", + "%autoreload 2" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Reversible Mixture of KAN - RMoK\n", + "The Reversible Mixture of KAN (RMoK) is a KAN-based model for time series forecasting which uses a mixture-of-experts structure to assign variables to different KAN experts, such as WaveKAN, TaylorKAN and JacobiKAN.\n", + "\n", + "**Reference**\n", + "- [Xiao Han, Xinfeng Zhang, Yiling Wu, Zhenduo Zhang, Zhe Wu.\"KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?\"](https://arxiv.org/abs/2408.11306)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "![Figure 1. 
Architecture of RMoK.](imgs_models/rmok.png)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "from fastcore.test import test_eq\n", + "from nbdev.showdoc import show_doc" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "import math\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "\n", + "from neuralforecast.losses.pytorch import MAE\n", + "from neuralforecast.common._base_multivariate import BaseMultivariate" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 1. Auxiliary functions\n", + "### 1.1 WaveKAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class WaveKANLayer(nn.Module):\n", + " '''This is a sample code for the simulations of the paper:\n", + " Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024)\n", + "\n", + " https://arxiv.org/abs/2405.12832\n", + " and also available at:\n", + " https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325\n", + " We used efficient KAN notation and some part of the code:+\n", + "\n", + " '''\n", + "\n", + " def __init__(self, in_features, out_features, wavelet_type='mexican_hat', with_bn=True, device=\"cpu\"):\n", + " super(WaveKANLayer, self).__init__()\n", + " self.in_features = in_features\n", + " self.out_features = out_features\n", + " self.wavelet_type = wavelet_type\n", + " self.with_bn = with_bn\n", + "\n", + " # Parameters for wavelet transformation\n", + " self.scale = nn.Parameter(torch.ones(out_features, in_features))\n", + " self.translation = nn.Parameter(torch.zeros(out_features, in_features))\n", + "\n", + " # self.weight1 is not used; you may use it for weighting base activation and adding it like Spl-KAN paper\n", + " self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features))\n", + " self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features))\n", + "\n", + " nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5))\n", + " nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))\n", + "\n", + " # Base activation function #not used for this experiment\n", + " self.base_activation = nn.SiLU()\n", + "\n", + " # Batch normalization\n", + " if self.with_bn:\n", + " self.bn = nn.BatchNorm1d(out_features)\n", + "\n", + " def wavelet_transform(self, x):\n", + " if x.dim() == 2:\n", + " x_expanded = x.unsqueeze(1)\n", + " else:\n", + " x_expanded = x\n", + "\n", + " translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1)\n", + " scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1)\n", + " x_scaled = (x_expanded - translation_expanded) / scale_expanded\n", + "\n", + " # Implementation of different wavelet types\n", + " if self.wavelet_type == 'mexican_hat':\n", + " term1 = ((x_scaled ** 2) - 1)\n", + " term2 = torch.exp(-0.5 * x_scaled ** 2)\n", + " wavelet = (2 / (math.sqrt(3) * math.pi ** 0.25)) * term1 * term2\n", + " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n", + " wavelet_output = wavelet_weighted.sum(dim=2)\n", + " elif self.wavelet_type == 'morlet':\n", + " omega0 = 5.0 # Central frequency\n", + " real = torch.cos(omega0 * x_scaled)\n", + " envelope = torch.exp(-0.5 * x_scaled ** 2)\n", + " wavelet = envelope * real\n", + " wavelet_weighted = wavelet * 
self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n", + " wavelet_output = wavelet_weighted.sum(dim=2)\n", + "\n", + " elif self.wavelet_type == 'dog':\n", + " # Implementing Derivative of Gaussian Wavelet\n", + " dog = -x_scaled * torch.exp(-0.5 * x_scaled ** 2)\n", + " wavelet = dog\n", + " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n", + " wavelet_output = wavelet_weighted.sum(dim=2)\n", + " elif self.wavelet_type == 'meyer':\n", + " # Implement Meyer Wavelet here\n", + " # Constants for the Meyer wavelet transition boundaries\n", + " v = torch.abs(x_scaled)\n", + " pi = math.pi\n", + "\n", + " def meyer_aux(v):\n", + " return torch.where(v <= 1 / 2, torch.ones_like(v),\n", + " torch.where(v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * nu(2 * v - 1))))\n", + "\n", + " def nu(t):\n", + " return t ** 4 * (35 - 84 * t + 70 * t ** 2 - 20 * t ** 3)\n", + "\n", + " # Meyer wavelet calculation using the auxiliary function\n", + " wavelet = torch.sin(pi * v) * meyer_aux(v)\n", + " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n", + " wavelet_output = wavelet_weighted.sum(dim=2)\n", + " elif self.wavelet_type == 'shannon':\n", + " # Windowing the sinc function to limit its support\n", + " pi = math.pi\n", + " sinc = torch.sinc(x_scaled / pi) # sinc(x) = sin(pi*x) / (pi*x)\n", + "\n", + " # Applying a Hamming window to limit the infinite support of the sinc function\n", + " window = torch.hamming_window(x_scaled.size(-1), periodic=False, dtype=x_scaled.dtype,\n", + " device=x_scaled.device)\n", + " # Shannon wavelet is the product of the sinc function and the window\n", + " wavelet = sinc * window\n", + " wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as(wavelet)\n", + " wavelet_output = wavelet_weighted.sum(dim=2)\n", + " # You can try many more wavelet types ...\n", + " else:\n", + " raise ValueError(\"Unsupported wavelet type\")\n", + "\n", + " return wavelet_output\n", + "\n", + " def forward(self, x):\n", + " wavelet_output = self.wavelet_transform(x)\n", + " # You may like test the cases like Spl-KAN\n", + " # wav_output = F.linear(wavelet_output, self.weight)\n", + " # base_output = F.linear(self.base_activation(x), self.weight1)\n", + "\n", + " # base_output = F.linear(x, self.weight1)\n", + " combined_output = wavelet_output # + base_output\n", + "\n", + " # Apply batch normalization\n", + " if self.with_bn:\n", + " return self.bn(combined_output)\n", + " else:\n", + " return combined_output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.2 TaylorKAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class TaylorKANLayer(nn.Module):\n", + " \"\"\"\n", + " https://github.com/Muyuzhierchengse/TaylorKAN/\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, out_dim, order, addbias=True):\n", + " super(TaylorKANLayer, self).__init__()\n", + " self.input_dim = input_dim\n", + " self.out_dim = out_dim\n", + " self.order = order\n", + " self.addbias = addbias\n", + "\n", + " self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order) * 0.01)\n", + " if self.addbias:\n", + " self.bias = nn.Parameter(torch.zeros(1, out_dim))\n", + "\n", + " def forward(self, x):\n", + " shape = x.shape\n", + " outshape = shape[0:-1] + (self.out_dim,)\n", + " x = torch.reshape(x, (-1, self.input_dim))\n", + " x_expanded = x.unsqueeze(1).expand(-1, self.out_dim, -1)\n", + "\n", + 
" y = torch.zeros((x.shape[0], self.out_dim), device=x.device)\n", + "\n", + " for i in range(self.order):\n", + " term = (x_expanded ** i) * self.coeffs[:, :, i]\n", + " y += term.sum(dim=-1)\n", + "\n", + " if self.addbias:\n", + " y += self.bias\n", + "\n", + " y = torch.reshape(y, outshape)\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.3. JacobiKAN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class JacobiKANLayer(nn.Module):\n", + " \"\"\"\n", + " https://github.com/SpaceLearner/JacobiKAN/blob/main/JacobiKANLayer.py\n", + " \"\"\"\n", + "\n", + " def __init__(self, input_dim, output_dim, degree, a=1.0, b=1.0):\n", + " super(JacobiKANLayer, self).__init__()\n", + " self.inputdim = input_dim\n", + " self.outdim = output_dim\n", + " self.a = a\n", + " self.b = b\n", + " self.degree = degree\n", + "\n", + " self.jacobi_coeffs = nn.Parameter(torch.empty(input_dim, output_dim, degree + 1))\n", + "\n", + " nn.init.normal_(self.jacobi_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)))\n", + "\n", + " def forward(self, x):\n", + " x = torch.reshape(x, (-1, self.inputdim)) # shape = (batch_size, inputdim)\n", + " # Since Jacobian polynomial is defined in [-1, 1]\n", + " # We need to normalize x to [-1, 1] using tanh\n", + " x = torch.tanh(x)\n", + " # Initialize Jacobian polynomial tensors\n", + " jacobi = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device)\n", + " if self.degree > 0: ## degree = 0: jacobi[:, :, 0] = 1 (already initialized) ; degree = 1: jacobi[:, :, 1] = x ; d\n", + " jacobi[:, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) * x) / 2\n", + " for i in range(2, self.degree + 1):\n", + " theta_k = (2 * i + self.a + self.b) * (2 * i + self.a + self.b - 1) / (2 * i * (i + self.a + self.b))\n", + " theta_k1 = (2 * i + self.a + self.b - 1) * (self.a * self.a - self.b * self.b) / (\n", + " 2 * i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))\n", + " theta_k2 = (i + self.a - 1) * (i + self.b - 1) * (2 * i + self.a + self.b) / (\n", + " i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2))\n", + " jacobi[:, :, i] = (theta_k * x + theta_k1) * jacobi[:, :, i - 1].clone() - theta_k2 * jacobi[:, :,\n", + " i - 2].clone() # 2 * x * jacobi[:, :, i - 1].clone() - jacobi[:, :, i - 2].clone()\n", + " # Compute the Jacobian interpolation\n", + " y = torch.einsum('bid,iod->bo', jacobi, self.jacobi_coeffs) # shape = (batch_size, outdim)\n", + " y = y.view(-1, self.outdim)\n", + " return y" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 1.4 RevIN" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class RevIN(nn.Module):\n", + " def __init__(self, num_features: int, eps=1e-5, affine=True):\n", + " \"\"\"\n", + " :param num_features: the number of features or channels\n", + " :param eps: a value added for numerical stability\n", + " :param affine: if True, RevIN has learnable affine parameters\n", + " \"\"\"\n", + " super(RevIN, self).__init__()\n", + "\n", + " self.num_features = num_features\n", + " self.eps = eps\n", + " self.affine = affine\n", + "\n", + " if self.affine:\n", + " self._init_params()\n", + "\n", + " def forward(self, x, mode: str):\n", + " if mode == 'norm':\n", + " self._get_statistics(x)\n", + " x = self._normalize(x)\n", + "\n", + " elif mode == 'denorm':\n", + " x = 
self._denormalize(x)\n", + "\n", + " else:\n", + " raise NotImplementedError\n", + "\n", + " return x\n", + "\n", + " def _init_params(self):\n", + " # initialize RevIN params: (C,)\n", + " self.affine_weight = nn.Parameter(torch.ones(self.num_features))\n", + " self.affine_bias = nn.Parameter(torch.zeros(self.num_features))\n", + "\n", + " def _get_statistics(self, x):\n", + " dim2reduce = tuple(range(1, x.ndim - 1))\n", + " self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()\n", + " self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()\n", + "\n", + " def _normalize(self, x):\n", + " x = x - self.mean\n", + " x = x / self.stdev\n", + " if self.affine:\n", + " x = x * self.affine_weight\n", + " x = x + self.affine_bias\n", + "\n", + " return x\n", + "\n", + " def _denormalize(self, x):\n", + " if self.affine:\n", + " x = x - self.affine_bias\n", + " x = x / (self.affine_weight + self.eps * self.eps)\n", + " x = x * self.stdev\n", + " x = x + self.mean\n", + "\n", + " return x" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| export\n", + "\n", + "class RMoK(BaseMultivariate):\n", + " \"\"\" Reversible Mixture of KAN\n", + " **Parameters**
\n", + " `h`: int, Forecast horizon.
\n", + " `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
\n", + " `n_series`: int, number of time-series.
\n", + " `futr_exog_list`: str list, future exogenous columns.
\n", + " `hist_exog_list`: str list, historic exogenous columns.
\n", + " `stat_exog_list`: str list, static exogenous columns.
\n", + " `taylor_order`: int, order of the Taylor polynomial.
\n", + " `jacobi_degree`: int, degree of the Jacobi polynomial.
\n", + " `wavelet_function`: str, wavelet function to use in the WaveKAN. Choose from [\"mexican_hat\", \"morlet\", \"dog\", \"meyer\", \"shannon\"]
\n", + " `dropout`: float, dropout rate.
\n", + " `revin_affine`: bool=False, bool to use affine in RevIn.
\n", + " `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", + " `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n", + " `max_steps`: int=1000, maximum number of training steps.
\n", + " `learning_rate`: float=1e-3, Learning rate between (0, 1).
\n", + " `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
\n", + " `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
\n", + " `val_check_steps`: int=100, Number of training steps between every validation loss check.
\n", + " `batch_size`: int=32, number of different series in each batch.
\n", + " `step_size`: int=1, step size between each window of temporal data.
\n", + " `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
\n", + " `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
\n", + " `num_workers_loader`: int=os.cpu_count(), workers to be used by `TimeSeriesDataLoader`.
\n", + " `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
\n", + " `alias`: str, optional, Custom name of the model.
\n", + " `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
\n", + " `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
\n", + " `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
\n", + " `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
\n", + " `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lighning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).
\n", + "\n", + " Reference
\n", + " [Xiao Han, Xinfeng Zhang, Yiling Wu, Zhenduo Zhang, Zhe Wu.\"KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?\"](https://arxiv.org/abs/2408.11306)\n", + " \"\"\"\n", + "\n", + " # Class attributes\n", + " SAMPLING_TYPE = 'multivariate'\n", + " EXOGENOUS_FUTR = False\n", + " EXOGENOUS_HIST = False\n", + " EXOGENOUS_STAT = False\n", + "\n", + " def __init__(self,\n", + " h,\n", + " input_size,\n", + " n_series,\n", + " futr_exog_list = None,\n", + " hist_exog_list = None,\n", + " stat_exog_list = None,\n", + " taylor_order: int = 3,\n", + " jacobi_degree: int = 6,\n", + " wavelet_function: str = 'mexican_hat',\n", + " dropout: float = 0.1,\n", + " revine_affine: bool = True,\n", + " loss = MAE(),\n", + " valid_loss = None,\n", + " max_steps: int = 1000,\n", + " learning_rate: float = 1e-3,\n", + " num_lr_decays: int = -1,\n", + " early_stop_patience_steps: int =-1,\n", + " val_check_steps: int = 100,\n", + " batch_size: int = 32,\n", + " step_size: int = 1,\n", + " scaler_type: str = 'identity',\n", + " random_seed: int = 1,\n", + " num_workers_loader: int = 0,\n", + " drop_last_loader: bool = False,\n", + " optimizer = None,\n", + " optimizer_kwargs = None,\n", + " lr_scheduler = None,\n", + " lr_scheduler_kwargs = None, \n", + " **trainer_kwargs):\n", + " \n", + " super(RMoK, self).__init__(h=h,\n", + " input_size=input_size,\n", + " n_series=n_series,\n", + " stat_exog_list = None,\n", + " futr_exog_list = None,\n", + " hist_exog_list = None,\n", + " loss=loss,\n", + " valid_loss=valid_loss,\n", + " max_steps=max_steps,\n", + " learning_rate=learning_rate,\n", + " num_lr_decays=num_lr_decays,\n", + " early_stop_patience_steps=early_stop_patience_steps,\n", + " val_check_steps=val_check_steps,\n", + " batch_size=batch_size,\n", + " step_size=step_size,\n", + " scaler_type=scaler_type,\n", + " random_seed=random_seed,\n", + " num_workers_loader=num_workers_loader,\n", + " drop_last_loader=drop_last_loader,\n", + " optimizer=optimizer,\n", + " optimizer_kwargs=optimizer_kwargs,\n", + " lr_scheduler=lr_scheduler,\n", + " lr_scheduler_kwargs=lr_scheduler_kwargs,\n", + " **trainer_kwargs)\n", + " \n", + " self.input_size = input_size\n", + " self.h = h\n", + " self.n_series = n_series\n", + " self.dropout = nn.Dropout(dropout)\n", + " self.revin_affine = revine_affine\n", + "\n", + " self.taylor_order = taylor_order\n", + " self.jacobi_degree = jacobi_degree\n", + " self.wavelet_function = wavelet_function\n", + "\n", + " self.experts = nn.ModuleList([\n", + " TaylorKANLayer(self.input_size, self.h, order=self.taylor_order, addbias=True),\n", + " JacobiKANLayer(self.input_size, self.h, degree=self.jacobi_degree),\n", + " WaveKANLayer(self.input_size, self.h, wavelet_type=self.wavelet_function),\n", + " nn.Linear(self.input_size, self.h),\n", + " ])\n", + " \n", + " self.num_experts = len(self.experts)\n", + " self.gate = nn.Linear(self.input_size, self.num_experts)\n", + " self.softmax = nn.Softmax(dim=-1)\n", + " self.rev = RevIN(self.n_series, affine=self.revin_affine)\n", + "\n", + " def forward(self, windows_batch):\n", + " insample_y = windows_batch['insample_y']\n", + " B, L, N = insample_y.shape\n", + " x = self.rev(insample_y, 'norm') if self.rev else insample_y\n", + " x = self.dropout(x).transpose(1, 2).reshape(B * N, L)\n", + "\n", + " score = F.softmax(self.gate(x), dim=-1)\n", + " expert_outputs = torch.stack([self.experts[i](x) for i in range(self.num_experts)], dim=-1)\n", + "\n", + " y_pred = torch.einsum(\"BLE,BE->BL\", expert_outputs, 
score).reshape(B, N, -1).permute(0, 2, 1)\n", + " y_pred = self.rev(y_pred, 'denorm')\n", + " y_pred = self.loss.domain_map(y_pred)\n", + "\n", + " # domain_map might have squeezed the last dimension in case n_series == 1\n", + " if y_pred.ndim == 2:\n", + " return y_pred.unsqueeze(-1)\n", + " else:\n", + " return y_pred" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(RMoK)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(RMoK.fit, name='RMoK.fit')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "show_doc(RMoK.predict, name='RMoK.predict')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3. Usage example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| eval: false\n", + "import pandas as pd\n", + "import matplotlib.pyplot as plt\n", + "\n", + "from neuralforecast import NeuralForecast\n", + "from neuralforecast.models import RMoK\n", + "from neuralforecast.utils import AirPassengersPanel, AirPassengersStatic\n", + "from neuralforecast.losses.pytorch import MSE\n", + "\n", + "Y_train_df = AirPassengersPanel[AirPassengersPanel.ds=AirPassengersPanel['ds'].values[-12]].reset_index(drop=True) # 12 test\n", + "\n", + "model = RMoK(h=12,\n", + " input_size=24,\n", + " n_series=2,\n", + " taylor_order=3,\n", + " jacobi_degree=6,\n", + " wavelet_function='mexican_hat',\n", + " dropout=0.1,\n", + " revine_affine=True,\n", + " loss=MSE(),\n", + " valid_loss=MAE(),\n", + " early_stop_patience_steps=3,\n", + " batch_size=32)\n", + "\n", + "fcst = NeuralForecast(models=[model], freq='M')\n", + "fcst.fit(df=Y_train_df, static_df=AirPassengersStatic, val_size=12)\n", + "forecasts = fcst.predict(futr_df=Y_test_df)\n", + "\n", + "# Plot predictions\n", + "fig, ax = plt.subplots(1, 1, figsize = (20, 7))\n", + "Y_hat_df = forecasts.reset_index(drop=False).drop(columns=['unique_id','ds'])\n", + "plot_df = pd.concat([Y_test_df, Y_hat_df], axis=1)\n", + "plot_df = pd.concat([Y_train_df, plot_df])\n", + "\n", + "plot_df = plot_df[plot_df.unique_id=='Airline1'].drop('unique_id', axis=1)\n", + "plt.plot(plot_df['ds'], plot_df['y'], c='black', label='True')\n", + "plt.plot(plot_df['ds'], plot_df['RMoK'], c='blue', label='Forecast')\n", + "ax.set_title('AirPassengers Forecast', fontsize=22)\n", + "ax.set_ylabel('Monthly Passengers', fontsize=20)\n", + "ax.set_xlabel('Year', fontsize=20)\n", + "ax.legend(prop={'size': 15})\n", + "ax.grid()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "python3", + "language": "python", + "name": "python3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/neuralforecast/_modidx.py b/neuralforecast/_modidx.py index a58491846..add079b58 100644 --- a/neuralforecast/_modidx.py +++ b/neuralforecast/_modidx.py @@ -96,6 +96,10 @@ 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoPatchTST.get_default_config': ( 'models.html#autopatchtst.get_default_config', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoRMoK': ('models.html#autormok', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoRMoK.__init__': ('models.html#autormok.__init__', 'neuralforecast/auto.py'), + 'neuralforecast.auto.AutoRMoK.get_default_config': ( 
'models.html#autormok.get_default_config', + 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoRNN': ('models.html#autornn', 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoRNN.__init__': ('models.html#autornn.__init__', 'neuralforecast/auto.py'), 'neuralforecast.auto.AutoRNN.get_default_config': ( 'models.html#autornn.get_default_config', @@ -1037,6 +1041,44 @@ 'neuralforecast/models/patchtst.py'), 'neuralforecast.models.patchtst.positional_encoding': ( 'models.patchtst.html#positional_encoding', 'neuralforecast/models/patchtst.py')}, + 'neuralforecast.models.rmok': { 'neuralforecast.models.rmok.JacobiKANLayer': ( 'models.rmok.html#jacobikanlayer', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.JacobiKANLayer.__init__': ( 'models.rmok.html#jacobikanlayer.__init__', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.JacobiKANLayer.forward': ( 'models.rmok.html#jacobikanlayer.forward', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RMoK': ('models.rmok.html#rmok', 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RMoK.__init__': ( 'models.rmok.html#rmok.__init__', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RMoK.forward': ( 'models.rmok.html#rmok.forward', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN': ('models.rmok.html#revin', 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN.__init__': ( 'models.rmok.html#revin.__init__', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN._denormalize': ( 'models.rmok.html#revin._denormalize', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN._get_statistics': ( 'models.rmok.html#revin._get_statistics', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN._init_params': ( 'models.rmok.html#revin._init_params', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN._normalize': ( 'models.rmok.html#revin._normalize', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.RevIN.forward': ( 'models.rmok.html#revin.forward', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.TaylorKANLayer': ( 'models.rmok.html#taylorkanlayer', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.TaylorKANLayer.__init__': ( 'models.rmok.html#taylorkanlayer.__init__', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.TaylorKANLayer.forward': ( 'models.rmok.html#taylorkanlayer.forward', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.WaveKANLayer': ( 'models.rmok.html#wavekanlayer', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.WaveKANLayer.__init__': ( 'models.rmok.html#wavekanlayer.__init__', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.WaveKANLayer.forward': ( 'models.rmok.html#wavekanlayer.forward', + 'neuralforecast/models/rmok.py'), + 'neuralforecast.models.rmok.WaveKANLayer.wavelet_transform': ( 'models.rmok.html#wavekanlayer.wavelet_transform', + 'neuralforecast/models/rmok.py')}, 'neuralforecast.models.rnn': { 'neuralforecast.models.rnn.RNN': ('models.rnn.html#rnn', 'neuralforecast/models/rnn.py'), 'neuralforecast.models.rnn.RNN.__init__': ( 'models.rnn.html#rnn.__init__', 'neuralforecast/models/rnn.py'), diff --git a/neuralforecast/auto.py b/neuralforecast/auto.py index c69b406b1..b3c85892a 100644 --- a/neuralforecast/auto.py +++ b/neuralforecast/auto.py @@ -5,7 +5,7 @@ 'AutoNBEATSx', 'AutoNHITS', 'AutoDLinear', 'AutoNLinear', 'AutoTiDE', 
'AutoDeepNPTS', 'AutoKAN', 'AutoTFT', 'AutoVanillaTransformer', 'AutoInformer', 'AutoAutoformer', 'AutoFEDformer', 'AutoPatchTST', 'AutoiTransformer', 'AutoTimesNet', 'AutoStemGNN', 'AutoHINT', 'AutoTSMixer', 'AutoTSMixerx', - 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer'] + 'AutoMLPMultivariate', 'AutoSOFTS', 'AutoTimeMixer', 'AutoRMoK'] # %% ../nbs/models.ipynb 2 from os import cpu_count @@ -44,6 +44,7 @@ from .models.itransformer import iTransformer from .models.kan import KAN +from .models.rmok import RMoK from .models.stemgnn import StemGNN from .models.hint import HINT @@ -1108,7 +1109,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 74 +# %% ../nbs/models.ipynb 75 class AutoKAN(BaseAuto): default_config = { @@ -1177,7 +1178,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 79 +# %% ../nbs/models.ipynb 80 class AutoTFT(BaseAuto): default_config = { @@ -1245,7 +1246,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 83 +# %% ../nbs/models.ipynb 84 class AutoVanillaTransformer(BaseAuto): default_config = { @@ -1313,7 +1314,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 87 +# %% ../nbs/models.ipynb 88 class AutoInformer(BaseAuto): default_config = { @@ -1381,7 +1382,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 91 +# %% ../nbs/models.ipynb 92 class AutoAutoformer(BaseAuto): default_config = { @@ -1449,7 +1450,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 95 +# %% ../nbs/models.ipynb 96 class AutoFEDformer(BaseAuto): default_config = { @@ -1516,7 +1517,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 99 +# %% ../nbs/models.ipynb 100 class AutoPatchTST(BaseAuto): default_config = { @@ -1586,7 +1587,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 103 +# %% ../nbs/models.ipynb 104 class AutoiTransformer(BaseAuto): default_config = { @@ -1671,7 +1672,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 108 +# %% ../nbs/models.ipynb 109 class AutoTimesNet(BaseAuto): default_config = { @@ -1739,7 +1740,7 @@ def get_default_config(cls, h, backend, n_series=None): return config -# %% ../nbs/models.ipynb 113 +# %% ../nbs/models.ipynb 114 class AutoStemGNN(BaseAuto): default_config = { @@ -1824,7 +1825,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 117 +# %% ../nbs/models.ipynb 118 class AutoHINT(BaseAuto): def __init__( @@ -1896,7 +1897,7 @@ def _fit_model( def get_default_config(cls, h, backend, n_series=None): raise Exception("AutoHINT has no default configuration.") -# %% ../nbs/models.ipynb 122 +# %% ../nbs/models.ipynb 123 class AutoTSMixer(BaseAuto): default_config = { @@ -1982,7 +1983,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 126 +# %% ../nbs/models.ipynb 127 class AutoTSMixerx(BaseAuto): default_config = { @@ -2068,7 +2069,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 130 +# %% ../nbs/models.ipynb 131 class AutoMLPMultivariate(BaseAuto): default_config = { @@ -2153,7 +2154,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 134 +# %% 
../nbs/models.ipynb 135 class AutoSOFTS(BaseAuto): default_config = { @@ -2238,7 +2239,7 @@ def get_default_config(cls, h, backend, n_series): return config -# %% ../nbs/models.ipynb 138 +# %% ../nbs/models.ipynb 139 class AutoTimeMixer(BaseAuto): default_config = { @@ -2323,3 +2324,91 @@ def get_default_config(cls, h, backend, n_series): config = cls._ray_config_to_optuna(config) return config + +# %% ../nbs/models.ipynb 143 +class AutoRMoK(BaseAuto): + + default_config = { + "input_size_multiplier": [1, 2, 3, 4, 5], + "h": None, + "n_series": None, + "taylor_order": tune.choice([3, 4, 5]), + "jacobi_degree": tune.choice([4, 5, 6]), + "wavelet_function": tune.choice( + ["mexican_hat", "morlet", "dog", "meyer", "shannon"] + ), + "learning_rate": tune.loguniform(1e-4, 1e-1), + "scaler_type": tune.choice([None, "robust", "standard", "identity"]), + "max_steps": tune.choice([500, 1000, 2000]), + "batch_size": tune.choice([32, 64, 128, 256]), + "loss": None, + "random_seed": tune.randint(1, 20), + } + + def __init__( + self, + h, + n_series, + loss=MAE(), + valid_loss=None, + config=None, + search_alg=BasicVariantGenerator(random_state=1), + num_samples=10, + refit_with_val=False, + cpus=cpu_count(), + gpus=torch.cuda.device_count(), + verbose=False, + alias=None, + backend="ray", + callbacks=None, + ): + + # Define search space, input/output sizes + if config is None: + config = self.get_default_config(h=h, backend=backend, n_series=n_series) + + # Always use n_series from parameters, raise exception with Optuna because we can't enforce it + if backend == "ray": + config["n_series"] = n_series + elif backend == "optuna": + mock_trial = MockTrial() + if ( + "n_series" in config(mock_trial) + and config(mock_trial)["n_series"] != n_series + ) or ("n_series" not in config(mock_trial)): + raise Exception(f"config needs 'n_series': {n_series}") + + super(AutoRMoK, self).__init__( + cls_model=RMoK, + h=h, + loss=loss, + valid_loss=valid_loss, + config=config, + search_alg=search_alg, + num_samples=num_samples, + refit_with_val=refit_with_val, + cpus=cpus, + gpus=gpus, + verbose=verbose, + alias=alias, + backend=backend, + callbacks=callbacks, + ) + + @classmethod + def get_default_config(cls, h, backend, n_series): + config = cls.default_config.copy() + config["input_size"] = tune.choice( + [h * x for x in config["input_size_multiplier"]] + ) + + # Rolling windows with step_size=1 or step_size=h + # See `BaseWindows` and `BaseRNN`'s create_windows + config["step_size"] = tune.choice([1, h]) + del config["input_size_multiplier"] + if backend == "optuna": + # Always use n_series from parameters + config["n_series"] = n_series + config = cls._ray_config_to_optuna(config) + + return config diff --git a/neuralforecast/models/__init__.py b/neuralforecast/models/__init__.py index ee404e31c..414689631 100644 --- a/neuralforecast/models/__init__.py +++ b/neuralforecast/models/__init__.py @@ -2,7 +2,7 @@ 'MLP', 'NHITS', 'NBEATS', 'NBEATSx', 'DLinear', 'NLinear', 'TFT', 'VanillaTransformer', 'Informer', 'Autoformer', 'PatchTST', 'FEDformer', 'StemGNN', 'HINT', 'TimesNet', 'TimeLLM', 'TSMixer', 'TSMixerx', 'MLPMultivariate', - 'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN' + 'iTransformer', 'BiTCN', 'TiDE', 'DeepNPTS', 'SOFTS', 'TimeMixer', 'KAN', 'RMoK', ] from .rnn import RNN @@ -37,3 +37,4 @@ from .softs import SOFTS from .timemixer import TimeMixer from .kan import KAN +from .rmok import RMoK diff --git a/neuralforecast/models/rmok.py b/neuralforecast/models/rmok.py new file mode 
100644 index 000000000..c83585e1b --- /dev/null +++ b/neuralforecast/models/rmok.py @@ -0,0 +1,473 @@ +# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/models.rmok.ipynb. + +# %% auto 0 +__all__ = ['WaveKANLayer', 'TaylorKANLayer', 'JacobiKANLayer', 'RevIN', 'RMoK'] + +# %% ../../nbs/models.rmok.ipynb 6 +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from ..losses.pytorch import MAE +from ..common._base_multivariate import BaseMultivariate + +# %% ../../nbs/models.rmok.ipynb 8 +class WaveKANLayer(nn.Module): + """This is a sample code for the simulations of the paper: + Bozorgasl, Zavareh and Chen, Hao, Wav-KAN: Wavelet Kolmogorov-Arnold Networks (May, 2024) + + https://arxiv.org/abs/2405.12832 + and also available at: + https://papers.ssrn.com/sol3/papers.cfm?abstract_id=4835325 + We used efficient KAN notation and some part of the code:+ + + """ + + def __init__( + self, + in_features, + out_features, + wavelet_type="mexican_hat", + with_bn=True, + device="cpu", + ): + super(WaveKANLayer, self).__init__() + self.in_features = in_features + self.out_features = out_features + self.wavelet_type = wavelet_type + self.with_bn = with_bn + + # Parameters for wavelet transformation + self.scale = nn.Parameter(torch.ones(out_features, in_features)) + self.translation = nn.Parameter(torch.zeros(out_features, in_features)) + + # self.weight1 is not used; you may use it for weighting base activation and adding it like Spl-KAN paper + self.weight1 = nn.Parameter(torch.Tensor(out_features, in_features)) + self.wavelet_weights = nn.Parameter(torch.Tensor(out_features, in_features)) + + nn.init.kaiming_uniform_(self.wavelet_weights, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + + # Base activation function #not used for this experiment + self.base_activation = nn.SiLU() + + # Batch normalization + if self.with_bn: + self.bn = nn.BatchNorm1d(out_features) + + def wavelet_transform(self, x): + if x.dim() == 2: + x_expanded = x.unsqueeze(1) + else: + x_expanded = x + + translation_expanded = self.translation.unsqueeze(0).expand(x.size(0), -1, -1) + scale_expanded = self.scale.unsqueeze(0).expand(x.size(0), -1, -1) + x_scaled = (x_expanded - translation_expanded) / scale_expanded + + # Implementation of different wavelet types + if self.wavelet_type == "mexican_hat": + term1 = (x_scaled**2) - 1 + term2 = torch.exp(-0.5 * x_scaled**2) + wavelet = (2 / (math.sqrt(3) * math.pi**0.25)) * term1 * term2 + wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as( + wavelet + ) + wavelet_output = wavelet_weighted.sum(dim=2) + elif self.wavelet_type == "morlet": + omega0 = 5.0 # Central frequency + real = torch.cos(omega0 * x_scaled) + envelope = torch.exp(-0.5 * x_scaled**2) + wavelet = envelope * real + wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as( + wavelet + ) + wavelet_output = wavelet_weighted.sum(dim=2) + + elif self.wavelet_type == "dog": + # Implementing Derivative of Gaussian Wavelet + dog = -x_scaled * torch.exp(-0.5 * x_scaled**2) + wavelet = dog + wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as( + wavelet + ) + wavelet_output = wavelet_weighted.sum(dim=2) + elif self.wavelet_type == "meyer": + # Implement Meyer Wavelet here + # Constants for the Meyer wavelet transition boundaries + v = torch.abs(x_scaled) + pi = math.pi + + def meyer_aux(v): + return torch.where( + v <= 1 / 2, + torch.ones_like(v), + torch.where( + v >= 1, torch.zeros_like(v), torch.cos(pi / 2 * 
nu(2 * v - 1)) + ), + ) + + def nu(t): + return t**4 * (35 - 84 * t + 70 * t**2 - 20 * t**3) + + # Meyer wavelet calculation using the auxiliary function + wavelet = torch.sin(pi * v) * meyer_aux(v) + wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as( + wavelet + ) + wavelet_output = wavelet_weighted.sum(dim=2) + elif self.wavelet_type == "shannon": + # Windowing the sinc function to limit its support + pi = math.pi + sinc = torch.sinc(x_scaled / pi) # sinc(x) = sin(pi*x) / (pi*x) + + # Applying a Hamming window to limit the infinite support of the sinc function + window = torch.hamming_window( + x_scaled.size(-1), + periodic=False, + dtype=x_scaled.dtype, + device=x_scaled.device, + ) + # Shannon wavelet is the product of the sinc function and the window + wavelet = sinc * window + wavelet_weighted = wavelet * self.wavelet_weights.unsqueeze(0).expand_as( + wavelet + ) + wavelet_output = wavelet_weighted.sum(dim=2) + # You can try many more wavelet types ... + else: + raise ValueError("Unsupported wavelet type") + + return wavelet_output + + def forward(self, x): + wavelet_output = self.wavelet_transform(x) + # You may like test the cases like Spl-KAN + # wav_output = F.linear(wavelet_output, self.weight) + # base_output = F.linear(self.base_activation(x), self.weight1) + + # base_output = F.linear(x, self.weight1) + combined_output = wavelet_output # + base_output + + # Apply batch normalization + if self.with_bn: + return self.bn(combined_output) + else: + return combined_output + +# %% ../../nbs/models.rmok.ipynb 10 +class TaylorKANLayer(nn.Module): + """ + https://github.com/Muyuzhierchengse/TaylorKAN/ + """ + + def __init__(self, input_dim, out_dim, order, addbias=True): + super(TaylorKANLayer, self).__init__() + self.input_dim = input_dim + self.out_dim = out_dim + self.order = order + self.addbias = addbias + + self.coeffs = nn.Parameter(torch.randn(out_dim, input_dim, order) * 0.01) + if self.addbias: + self.bias = nn.Parameter(torch.zeros(1, out_dim)) + + def forward(self, x): + shape = x.shape + outshape = shape[0:-1] + (self.out_dim,) + x = torch.reshape(x, (-1, self.input_dim)) + x_expanded = x.unsqueeze(1).expand(-1, self.out_dim, -1) + + y = torch.zeros((x.shape[0], self.out_dim), device=x.device) + + for i in range(self.order): + term = (x_expanded**i) * self.coeffs[:, :, i] + y += term.sum(dim=-1) + + if self.addbias: + y += self.bias + + y = torch.reshape(y, outshape) + return y + +# %% ../../nbs/models.rmok.ipynb 12 +class JacobiKANLayer(nn.Module): + """ + https://github.com/SpaceLearner/JacobiKAN/blob/main/JacobiKANLayer.py + """ + + def __init__(self, input_dim, output_dim, degree, a=1.0, b=1.0): + super(JacobiKANLayer, self).__init__() + self.inputdim = input_dim + self.outdim = output_dim + self.a = a + self.b = b + self.degree = degree + + self.jacobi_coeffs = nn.Parameter( + torch.empty(input_dim, output_dim, degree + 1) + ) + + nn.init.normal_( + self.jacobi_coeffs, mean=0.0, std=1 / (input_dim * (degree + 1)) + ) + + def forward(self, x): + x = torch.reshape(x, (-1, self.inputdim)) # shape = (batch_size, inputdim) + # Since Jacobian polynomial is defined in [-1, 1] + # We need to normalize x to [-1, 1] using tanh + x = torch.tanh(x) + # Initialize Jacobian polynomial tensors + jacobi = torch.ones(x.shape[0], self.inputdim, self.degree + 1, device=x.device) + if ( + self.degree > 0 + ): ## degree = 0: jacobi[:, :, 0] = 1 (already initialized) ; degree = 1: jacobi[:, :, 1] = x ; d + jacobi[:, :, 1] = ((self.a - self.b) + (self.a + self.b + 2) 
* x) / 2 + for i in range(2, self.degree + 1): + theta_k = ( + (2 * i + self.a + self.b) + * (2 * i + self.a + self.b - 1) + / (2 * i * (i + self.a + self.b)) + ) + theta_k1 = ( + (2 * i + self.a + self.b - 1) + * (self.a * self.a - self.b * self.b) + / (2 * i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2)) + ) + theta_k2 = ( + (i + self.a - 1) + * (i + self.b - 1) + * (2 * i + self.a + self.b) + / (i * (i + self.a + self.b) * (2 * i + self.a + self.b - 2)) + ) + jacobi[:, :, i] = (theta_k * x + theta_k1) * jacobi[ + :, :, i - 1 + ].clone() - theta_k2 * jacobi[ + :, :, i - 2 + ].clone() # 2 * x * jacobi[:, :, i - 1].clone() - jacobi[:, :, i - 2].clone() + # Compute the Jacobian interpolation + y = torch.einsum( + "bid,iod->bo", jacobi, self.jacobi_coeffs + ) # shape = (batch_size, outdim) + y = y.view(-1, self.outdim) + return y + +# %% ../../nbs/models.rmok.ipynb 14 +class RevIN(nn.Module): + def __init__(self, num_features: int, eps=1e-5, affine=True): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(RevIN, self).__init__() + + self.num_features = num_features + self.eps = eps + self.affine = affine + + if self.affine: + self._init_params() + + def forward(self, x, mode: str): + if mode == "norm": + self._get_statistics(x) + x = self._normalize(x) + + elif mode == "denorm": + x = self._denormalize(x) + + else: + raise NotImplementedError + + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim - 1)) + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt( + torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps + ).detach() + + def _normalize(self, x): + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + + return x + + def _denormalize(self, x): + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps * self.eps) + x = x * self.stdev + x = x + self.mean + + return x + +# %% ../../nbs/models.rmok.ipynb 16 +class RMoK(BaseMultivariate): + """Reversible Mixture of KAN + **Parameters**
+ `h`: int, Forecast horizon.
+ `input_size`: int, autoregressive input size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].<br>
+ `n_series`: int, number of time-series.
+ `futr_exog_list`: str list, future exogenous columns.
+ `hist_exog_list`: str list, historic exogenous columns.
+ `stat_exog_list`: str list, static exogenous columns.
+ `taylor_order`: int, order of the Taylor polynomial.
+ `jacobi_degree`: int, degree of the Jacobi polynomial.
+ `wavelet_function`: str, wavelet function to use in the WaveKAN layer. Choose from ["mexican_hat", "morlet", "dog", "meyer", "shannon"].<br>
+ `dropout`: float, dropout rate.
+ `revine_affine`: bool=True, if True, RevIN uses learnable affine parameters.<br>
+ `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
+ `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
+ `max_steps`: int=1000, maximum number of training steps.
+ `learning_rate`: float=1e-3, Learning rate between (0, 1).
+ `num_lr_decays`: int=-1, Number of learning rate decays, evenly distributed across max_steps.
+ `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
+ `val_check_steps`: int=100, Number of training steps between every validation loss check.
+ `batch_size`: int=32, number of different series in each batch.
+ `step_size`: int=1, step size between each window of temporal data.
+ `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
+ `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
+ `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.<br>
+ `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
+ `alias`: str, optional, Custom name of the model.
+ `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
+ `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
+ `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
+ `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
+ `**trainer_kwargs`: int, keyword trainer arguments inherited from [PyTorch Lightning's trainer](https://pytorch-lightning.readthedocs.io/en/stable/api/pytorch_lightning.trainer.trainer.Trainer.html?highlight=trainer).<br>
+ + **References**<br>
+ [Xiao Han, Xinfeng Zhang, Yiling Wu, Zhenduo Zhang, Zhe Wu."KAN4TSF: Are KAN and KAN-based models Effective for Time Series Forecasting?"](https://arxiv.org/abs/2408.11306) + """ + + # Class attributes + SAMPLING_TYPE = "multivariate" + EXOGENOUS_FUTR = False + EXOGENOUS_HIST = False + EXOGENOUS_STAT = False + + def __init__( + self, + h, + input_size, + n_series, + futr_exog_list=None, + hist_exog_list=None, + stat_exog_list=None, + taylor_order: int = 3, + jacobi_degree: int = 6, + wavelet_function: str = "mexican_hat", + dropout: float = 0.1, + revine_affine: bool = True, + loss=MAE(), + valid_loss=None, + max_steps: int = 1000, + learning_rate: float = 1e-3, + num_lr_decays: int = -1, + early_stop_patience_steps: int = -1, + val_check_steps: int = 100, + batch_size: int = 32, + step_size: int = 1, + scaler_type: str = "identity", + random_seed: int = 1, + num_workers_loader: int = 0, + drop_last_loader: bool = False, + optimizer=None, + optimizer_kwargs=None, + lr_scheduler=None, + lr_scheduler_kwargs=None, + **trainer_kwargs + ): + + super(RMoK, self).__init__( + h=h, + input_size=input_size, + n_series=n_series, + stat_exog_list=None, + futr_exog_list=None, + hist_exog_list=None, + loss=loss, + valid_loss=valid_loss, + max_steps=max_steps, + learning_rate=learning_rate, + num_lr_decays=num_lr_decays, + early_stop_patience_steps=early_stop_patience_steps, + val_check_steps=val_check_steps, + batch_size=batch_size, + step_size=step_size, + scaler_type=scaler_type, + random_seed=random_seed, + num_workers_loader=num_workers_loader, + drop_last_loader=drop_last_loader, + optimizer=optimizer, + optimizer_kwargs=optimizer_kwargs, + lr_scheduler=lr_scheduler, + lr_scheduler_kwargs=lr_scheduler_kwargs, + **trainer_kwargs + ) + + self.input_size = input_size + self.h = h + self.n_series = n_series + self.dropout = nn.Dropout(dropout) + self.revin_affine = revine_affine + + self.taylor_order = taylor_order + self.jacobi_degree = jacobi_degree + self.wavelet_function = wavelet_function + + self.experts = nn.ModuleList( + [ + TaylorKANLayer( + self.input_size, self.h, order=self.taylor_order, addbias=True + ), + JacobiKANLayer(self.input_size, self.h, degree=self.jacobi_degree), + WaveKANLayer( + self.input_size, self.h, wavelet_type=self.wavelet_function + ), + nn.Linear(self.input_size, self.h), + ] + ) + + self.num_experts = len(self.experts) + self.gate = nn.Linear(self.input_size, self.num_experts) + self.softmax = nn.Softmax(dim=-1) + self.rev = RevIN(self.n_series, affine=self.revin_affine) + + def forward(self, windows_batch): + insample_y = windows_batch["insample_y"] + B, L, N = insample_y.shape + x = self.rev(insample_y, "norm") if self.rev else insample_y + x = self.dropout(x).transpose(1, 2).reshape(B * N, L) + + score = F.softmax(self.gate(x), dim=-1) + expert_outputs = torch.stack( + [self.experts[i](x) for i in range(self.num_experts)], dim=-1 + ) + + y_pred = ( + torch.einsum("BLE,BE->BL", expert_outputs, score) + .reshape(B, N, -1) + .permute(0, 2, 1) + ) + y_pred = self.rev(y_pred, "denorm") + y_pred = self.loss.domain_map(y_pred) + + # domain_map might have squeezed the last dimension in case n_series == 1 + if y_pred.ndim == 2: + return y_pred.unsqueeze(-1) + else: + return y_pred
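As a reading aid for the forward pass added above, here is a minimal standalone sketch of the mixture-of-experts gating that RMoK applies per series. It is an illustration under simplifying assumptions, not the library code: plain nn.Linear experts stand in for the TaylorKANLayer/JacobiKANLayer/WaveKANLayer experts, and the RevIN normalize/denormalize steps and dropout are omitted; the tensor shapes mirror RMoK.forward.

import torch
import torch.nn as nn
import torch.nn.functional as F

B, N, L, h = 4, 2, 24, 12   # batch, n_series, input_size, horizon
# Stand-in experts; in RMoK these are TaylorKAN, JacobiKAN, WaveKAN and a Linear layer
experts = nn.ModuleList([nn.Linear(L, h) for _ in range(4)])
gate = nn.Linear(L, len(experts))

insample_y = torch.randn(B, L, N)                  # windows_batch['insample_y']: (batch, time, series)
x = insample_y.transpose(1, 2).reshape(B * N, L)   # each series becomes its own row

score = F.softmax(gate(x), dim=-1)                              # (B*N, num_experts) mixture weights
expert_outputs = torch.stack([e(x) for e in experts], dim=-1)   # (B*N, h, num_experts)
y = torch.einsum("BLE,BE->BL", expert_outputs, score)           # weighted sum over experts -> (B*N, h)
y = y.reshape(B, N, h).permute(0, 2, 1)                         # back to (batch, h, series)
print(y.shape)                                                  # torch.Size([4, 12, 2])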