diff --git a/nbs/models.tft.ipynb b/nbs/models.tft.ipynb
index ec38bf09a..2207fc64d 100644
--- a/nbs/models.tft.ipynb
+++ b/nbs/models.tft.ipynb
@@ -1,5 +1,22 @@
{
"cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "env: PYTORCH_ENABLE_MPS_FALLBACK=1\n"
+ ]
+ }
+ ],
+ "source": [
+ "%set_env PYTORCH_ENABLE_MPS_FALLBACK=1"
+ ]
+ },
{
"cell_type": "code",
"execution_count": null,
@@ -44,7 +61,7 @@
"outputs": [],
"source": [
"#| export\n",
- "from typing import Tuple, Optional\n",
+ "from typing import Tuple, Optional, Callable\n",
"\n",
"import torch\n",
"import torch.nn as nn\n",
@@ -124,6 +141,19 @@
"outputs": [],
"source": [
"#| exporti\n",
+ "def get_activation_fn(activation_str: str) -> Callable:\n",
+ "    \"\"\"Return the activation callable named `activation_str`; unknown names warn and fall back to ELU.\"\"\"\n",
+ "    activation_map = {\n",
+ "        'ReLU': F.relu, 'Softplus': F.softplus,\n",
+ "        'Tanh': F.tanh, 'SELU': F.selu,\n",
+ "        'LeakyReLU': F.leaky_relu, 'Sigmoid': F.sigmoid,\n",
+ "        'ELU': F.elu, 'GLU': F.glu,\n",
+ "    }\n",
+ "    if activation_str not in activation_map:\n",
+ "        import warnings  # local import so no other cell has to change\n",
+ "        warnings.warn(f'Unknown GRN activation {activation_str!r}; falling back to ELU.')\n",
+ "    return activation_map.get(activation_str, F.elu)\n",
+ "\n",
"class MaybeLayerNorm(nn.Module):\n",
" def __init__(self, output_size, hidden_size, eps):\n",
" super().__init__()\n",
@@ -149,12 +179,12 @@
"class GRN(nn.Module):\n",
" def __init__(self,\n",
" input_size,\n",
- " hidden_size, \n",
+ " hidden_size,\n",
" output_size=None,\n",
" context_hidden_size=None,\n",
- " dropout=0):\n",
+ " dropout=0,\n",
+ " activation='ELU',):\n",
" super().__init__()\n",
- " \n",
" self.layer_norm = MaybeLayerNorm(output_size, hidden_size, eps=1e-3)\n",
" self.lin_a = nn.Linear(input_size, hidden_size)\n",
" if context_hidden_size is not None:\n",
@@ -163,12 +193,13 @@
" self.glu = GLU(hidden_size, output_size if output_size else hidden_size)\n",
" self.dropout = nn.Dropout(dropout)\n",
" self.out_proj = nn.Linear(input_size, output_size) if output_size else None\n",
- "\n",
+ " self.activation_fn = get_activation_fn(activation)\n",
+ " \n",
" def forward(self, a: Tensor, c: Optional[Tensor] = None):\n",
" x = self.lin_a(a)\n",
" if c is not None:\n",
" x = x + self.lin_c(c).unsqueeze(1)\n",
- " x = F.elu(x)\n",
+ " x = self.activation_fn(x)\n",
" x = self.lin_i(x)\n",
" x = self.dropout(x)\n",
" x = self.glu(x)\n",
@@ -292,15 +323,16 @@
" return s_inp, k_inp, o_inp, target_inp\n",
"\n",
"class VariableSelectionNetwork(nn.Module):\n",
- " def __init__(self, hidden_size, num_inputs, dropout):\n",
+ " def __init__(self, hidden_size, num_inputs, dropout, grn_activation):\n",
" super().__init__()\n",
" self.joint_grn = GRN(input_size=hidden_size*num_inputs, \n",
" hidden_size=hidden_size, \n",
" output_size=num_inputs, \n",
- " context_hidden_size=hidden_size)\n",
+ " context_hidden_size=hidden_size,\n",
+ " activation=grn_activation)\n",
" self.var_grns = nn.ModuleList(\n",
" [GRN(input_size=hidden_size, \n",
- " hidden_size=hidden_size, dropout=dropout)\n",
+ " hidden_size=hidden_size, dropout=dropout, activation=grn_activation)\n",
" for _ in range(num_inputs)])\n",
"\n",
" def forward(self, x: Tensor, context: Optional[Tensor] = None):\n",
@@ -445,10 +477,10 @@
"source": [
"#| exporti\n",
"class StaticCovariateEncoder(nn.Module):\n",
- " def __init__(self, hidden_size, num_static_vars, dropout):\n",
+ " def __init__(self, hidden_size, num_static_vars, dropout, grn_activation):\n",
" super().__init__()\n",
" self.vsn = VariableSelectionNetwork(\n",
- " hidden_size=hidden_size, num_inputs=num_static_vars, dropout=dropout\n",
+ " hidden_size=hidden_size, num_inputs=num_static_vars, dropout=dropout, grn_activation=grn_activation\n",
" )\n",
" self.context_grns = nn.ModuleList(\n",
" [\n",
@@ -502,18 +534,18 @@
"source": [
"#| exporti\n",
"class TemporalCovariateEncoder(nn.Module):\n",
- " def __init__(self, hidden_size, num_historic_vars, num_future_vars, dropout):\n",
+ " def __init__(self, hidden_size, num_historic_vars, num_future_vars, dropout, grn_activation):\n",
" super(TemporalCovariateEncoder, self).__init__()\n",
"\n",
" self.history_vsn = VariableSelectionNetwork(\n",
- " hidden_size=hidden_size, num_inputs=num_historic_vars, dropout=dropout\n",
+ " hidden_size=hidden_size, num_inputs=num_historic_vars, dropout=dropout, grn_activation=grn_activation\n",
" )\n",
" self.history_encoder = nn.LSTM(\n",
" input_size=hidden_size, hidden_size=hidden_size, batch_first=True\n",
" )\n",
"\n",
" self.future_vsn = VariableSelectionNetwork(\n",
- " hidden_size=hidden_size, num_inputs=num_future_vars, dropout=dropout\n",
+ " hidden_size=hidden_size, num_inputs=num_future_vars, dropout=dropout, grn_activation=grn_activation\n",
" )\n",
" self.future_encoder = nn.LSTM(\n",
" input_size=hidden_size, hidden_size=hidden_size, batch_first=True\n",
@@ -567,7 +599,7 @@
"#| exporti\n",
"class TemporalFusionDecoder(nn.Module):\n",
" def __init__(\n",
- " self, n_head, hidden_size, example_length, encoder_length, attn_dropout, dropout\n",
+ " self, n_head, hidden_size, example_length, encoder_length, attn_dropout, dropout, grn_activation\n",
" ):\n",
" super(TemporalFusionDecoder, self).__init__()\n",
" self.encoder_length = encoder_length\n",
@@ -578,6 +610,7 @@
" hidden_size=hidden_size,\n",
" context_hidden_size=hidden_size,\n",
" dropout=dropout,\n",
+ " activation=grn_activation\n",
" )\n",
" self.attention = InterpretableMultiHeadAttention(\n",
" n_head=n_head,\n",
@@ -590,7 +623,7 @@
" self.attention_ln = LayerNorm(normalized_shape=hidden_size, eps=1e-3)\n",
"\n",
" self.positionwise_grn = GRN(\n",
- " input_size=hidden_size, hidden_size=hidden_size, dropout=dropout\n",
+ " input_size=hidden_size, hidden_size=hidden_size, dropout=dropout, activation=grn_activation\n",
" )\n",
"\n",
" # ---------------------- Decoder -----------------------#\n",
@@ -652,8 +685,7 @@
" `dropout`: float (0, 1), dropout of inputs VSNs.
\n",
" `n_head`: int=4, number of attention heads in temporal fusion decoder.
\n",
" `attn_dropout`: float (0, 1), dropout of fusion decoder's attention layer.
\n",
- " `shared_weights`: bool, If True, all blocks within each stack will share parameters.
\n",
- " `activation`: str, activation from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'PReLU', 'Sigmoid'].
\n",
+ "    `grn_activation`: str='ELU', activation for the GRN module from ['ReLU', 'Softplus', 'Tanh', 'SELU', 'LeakyReLU', 'Sigmoid', 'ELU', 'GLU']; unrecognized names fall back to 'ELU'.
\n",
" `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n",
" `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
\n",
" `max_steps`: int=1000, maximum number of training steps.
\n",
@@ -700,6 +732,7 @@
" hidden_size: int = 128,\n",
" n_head: int = 4,\n",
" attn_dropout: float = 0.0,\n",
+ " grn_activation: str = 'ELU',\n",
" dropout: float = 0.1,\n",
" loss=MAE(),\n",
" valid_loss=None,\n",
@@ -758,6 +791,7 @@
" self.example_length = input_size + h\n",
" self.interpretability_params = dict([]) # type: ignore\n",
" self.tgt_size = tgt_size\n",
+ " self.grn_activation = grn_activation\n",
" futr_exog_size = max(self.futr_exog_size, 1)\n",
" num_historic_vars = futr_exog_size + self.hist_exog_size + tgt_size\n",
"\n",
@@ -772,13 +806,15 @@
" self.static_encoder = StaticCovariateEncoder(\n",
" hidden_size=hidden_size,\n",
" num_static_vars=self.stat_exog_size,\n",
- " dropout=dropout)\n",
+ " dropout=dropout,\n",
+ " grn_activation=self.grn_activation)\n",
"\n",
" self.temporal_encoder = TemporalCovariateEncoder(\n",
" hidden_size=hidden_size,\n",
" num_historic_vars=num_historic_vars,\n",
" num_future_vars=futr_exog_size,\n",
" dropout=dropout,\n",
+ " grn_activation=self.grn_activation\n",
" )\n",
"\n",
" # ------------------------------ Decoders -----------------------------#\n",
@@ -789,6 +825,7 @@
" encoder_length=self.input_size,\n",
" attn_dropout=attn_dropout,\n",
" dropout=dropout,\n",
+ " grn_activation=self.grn_activation\n",
" )\n",
"\n",
" # Adapter with Loss dependent dimensions\n",
@@ -992,7 +1029,7 @@
"> TFT.fit (dataset, val_size=0, test_size=0, random_seed=None,\n",
"> distributed_config=None)\n",
"\n",
- "*Fit.\n",
+ "Fit.\n",
"\n",
"The `fit` method, optimizes the neural network's weights using the\n",
"initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n",
@@ -1011,7 +1048,7 @@
"`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n",
"`val_size`: int, validation size for temporal cross-validation.
\n",
"`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n",
- "`test_size`: int, test size for temporal cross-validation.
*"
+ "`test_size`: int, test size for temporal cross-validation.
"
],
"text/plain": [
"---\n",
@@ -1021,7 +1058,7 @@
"> TFT.fit (dataset, val_size=0, test_size=0, random_seed=None,\n",
"> distributed_config=None)\n",
"\n",
- "*Fit.\n",
+ "Fit.\n",
"\n",
"The `fit` method, optimizes the neural network's weights using the\n",
"initialization parameters (`learning_rate`, `windows_batch_size`, ...)\n",
@@ -1040,7 +1077,7 @@
"`dataset`: NeuralForecast's `TimeSeriesDataset`, see [documentation](https://nixtla.github.io/neuralforecast/tsdataset.html).
\n",
"`val_size`: int, validation size for temporal cross-validation.
\n",
"`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n",
- "`test_size`: int, test size for temporal cross-validation.
*"
+ "`test_size`: int, test size for temporal cross-validation.
"
]
},
"execution_count": null,
@@ -1067,7 +1104,7 @@
"> TFT.predict (dataset, test_size=None, step_size=1, random_seed=None,\n",
"> **data_module_kwargs)\n",
"\n",
- "*Predict.\n",
+ "Predict.\n",
"\n",
"Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
"\n",
@@ -1076,7 +1113,7 @@
"`test_size`: int=None, test size for temporal cross-validation.
\n",
"`step_size`: int=1, Step size between each window.
\n",
"`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n",
- "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).*"
+ "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule)."
],
"text/plain": [
"---\n",
@@ -1086,7 +1123,7 @@
"> TFT.predict (dataset, test_size=None, step_size=1, random_seed=None,\n",
"> **data_module_kwargs)\n",
"\n",
- "*Predict.\n",
+ "Predict.\n",
"\n",
"Neural network prediction with PL's `Trainer` execution of `predict_step`.\n",
"\n",
@@ -1095,7 +1132,7 @@
"`test_size`: int=None, test size for temporal cross-validation.
\n",
"`step_size`: int=1, Step size between each window.
\n",
"`random_seed`: int=None, random_seed for pytorch initializer and numpy generators, overwrites model.__init__'s.
\n",
- "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule).*"
+ "`**data_module_kwargs`: PL's TimeSeriesDataModule args, see [documentation](https://pytorch-lightning.readthedocs.io/en/1.6.1/extensions/datamodules.html#using-a-datamodule)."
]
},
"execution_count": null,
@@ -1117,34 +1154,34 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L630){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L678){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.feature_importances,\n",
"\n",
"> TFT.feature_importances, ()\n",
"\n",
- "*Compute the feature importances for historical, future, and static features.\n",
+ "Compute the feature importances for historical, future, and static features.\n",
"\n",
"Returns:\n",
" dict: A dictionary containing the feature importances for each feature type.\n",
" The keys are 'hist_vsn', 'future_vsn', and 'static_vsn', and the values\n",
- " are pandas DataFrames with the corresponding feature importances.*"
+ " are pandas DataFrames with the corresponding feature importances."
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L630){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L678){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.feature_importances,\n",
"\n",
"> TFT.feature_importances, ()\n",
"\n",
- "*Compute the feature importances for historical, future, and static features.\n",
+ "Compute the feature importances for historical, future, and static features.\n",
"\n",
"Returns:\n",
" dict: A dictionary containing the feature importances for each feature type.\n",
" The keys are 'hist_vsn', 'future_vsn', and 'static_vsn', and the values\n",
- " are pandas DataFrames with the corresponding feature importances.*"
+ " are pandas DataFrames with the corresponding feature importances."
]
},
"execution_count": null,
@@ -1166,30 +1203,30 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L688){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L736){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.attention_weights\n",
"\n",
"> TFT.attention_weights ()\n",
"\n",
- "*Batch average attention weights\n",
+ "Batch average attention weights\n",
"\n",
"Returns:\n",
- "np.ndarray: A 1D array containing the attention weights for each time step.*"
+ "np.ndarray: A 1D array containing the attention weights for each time step."
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L688){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L736){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.attention_weights\n",
"\n",
"> TFT.attention_weights ()\n",
"\n",
- "*Batch average attention weights\n",
+ "Batch average attention weights\n",
"\n",
"Returns:\n",
- "np.ndarray: A 1D array containing the attention weights for each time step.*"
+ "np.ndarray: A 1D array containing the attention weights for each time step."
]
},
"execution_count": null,
@@ -1211,30 +1248,30 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L688){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L736){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.attention_weights\n",
"\n",
"> TFT.attention_weights ()\n",
"\n",
- "*Batch average attention weights\n",
+ "Batch average attention weights\n",
"\n",
"Returns:\n",
- "np.ndarray: A 1D array containing the attention weights for each time step.*"
+ "np.ndarray: A 1D array containing the attention weights for each time step."
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L688){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L736){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.attention_weights\n",
"\n",
"> TFT.attention_weights ()\n",
"\n",
- "*Batch average attention weights\n",
+ "Batch average attention weights\n",
"\n",
"Returns:\n",
- "np.ndarray: A 1D array containing the attention weights for each time step.*"
+ "np.ndarray: A 1D array containing the attention weights for each time step."
]
},
"execution_count": null,
@@ -1256,30 +1293,30 @@
"text/markdown": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L706){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L754){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.feature_importance_correlations\n",
"\n",
"> TFT.feature_importance_correlations ()\n",
"\n",
- "*Compute the correlation between the past and future feature importances and the mean attention weights.\n",
+ "Compute the correlation between the past and future feature importances and the mean attention weights.\n",
"\n",
"Returns:\n",
- "pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights.*"
+ "pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights."
],
"text/plain": [
"---\n",
"\n",
- "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L706){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
+ "[source](https://github.com/Nixtla/neuralforecast/blob/main/neuralforecast/models/tft.py#L754){target=\"_blank\" style=\"float:right; font-size:smaller\"}\n",
"\n",
"### TFT.feature_importance_correlations\n",
"\n",
"> TFT.feature_importance_correlations ()\n",
"\n",
- "*Compute the correlation between the past and future feature importances and the mean attention weights.\n",
+ "Compute the correlation between the past and future feature importances and the mean attention weights.\n",
"\n",
"Returns:\n",
- "pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights.*"
+ "pd.DataFrame: A DataFrame containing the correlation coefficients between the past feature importances and the mean attention weights."
]
},
"execution_count": null,
@@ -1304,10 +1341,21 @@
"execution_count": null,
"metadata": {},
"outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Seed set to 1\n",
+ "GPU available: True (mps), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
+ },
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "28d80883ad5c4fbcb09c287416508cff",
+ "model_id": "fb09f042057e45d181a21ef46b8d933d",
"version_major": 2,
"version_minor": 0
},
@@ -1321,7 +1369,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "164f652306bc4443b667b223fd216571",
+ "model_id": "e7853b68bc174f7bb1a022a466da2186",
"version_major": 2,
"version_minor": 0
},
@@ -1335,7 +1383,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "9f5d2857dc80499ab589276dc2ef2ef7",
+ "model_id": "86866ee63d314c69b955af148d6d28c9",
"version_major": 2,
"version_minor": 0
},
@@ -1349,7 +1397,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "a61af16db6fa430e8bad2e56cbb6e204",
+ "model_id": "f1e52f5aba3e40d9beca566cb62ef5b4",
"version_major": 2,
"version_minor": 0
},
@@ -1363,7 +1411,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "3d6a111c2c4a4310a668494267427ee3",
+ "model_id": "4a058b96deae42dabda97307249d15a6",
"version_major": 2,
"version_minor": 0
},
@@ -1377,7 +1425,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "053994a65f284431aabfd6ba7aa46e08",
+ "model_id": "779e686e6a86423ebda2172ea6268786",
"version_major": 2,
"version_minor": 0
},
@@ -1391,7 +1439,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "21573be21eea4e858a75e23f7f090e9c",
+ "model_id": "3866504eb4ac47a398f82b67b4528e68",
"version_major": 2,
"version_minor": 0
},
@@ -1405,7 +1453,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "65b34751133141388175b695fe42d76b",
+ "model_id": "652afb9d18c9401e800372fe47e3bd0c",
"version_major": 2,
"version_minor": 0
},
@@ -1419,7 +1467,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "87a7b895c8b14b41a0365e1e0eebab3f",
+ "model_id": "aaca1e1ab85e4988acfc52b976e06b62",
"version_major": 2,
"version_minor": 0
},
@@ -1433,7 +1481,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "45d33286f3ec4611bf269d991c984283",
+ "model_id": "17af4b56375e4eda9b42038b0ddca23b",
"version_major": 2,
"version_minor": 0
},
@@ -1447,7 +1495,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "a3df1d8c667b467a9c95a674e8e70790",
+ "model_id": "bc0cf020042b480e865786435fe86174",
"version_major": 2,
"version_minor": 0
},
@@ -1461,7 +1509,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "1676e1cec80a43feb6fff5b101aa9e74",
+ "model_id": "406fc612f09c45489dcfd57159646a9a",
"version_major": 2,
"version_minor": 0
},
@@ -1475,7 +1523,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "ce5140456f724787b8c521c71877161f",
+ "model_id": "9cfb437ad5fe4b60bc407ddfd3a54eef",
"version_major": 2,
"version_minor": 0
},
@@ -1489,7 +1537,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "bcf26a086e48436dbec942a16d1068e5",
+ "model_id": "4caea18544d44beca9f03f4ba5432051",
"version_major": 2,
"version_minor": 0
},
@@ -1503,7 +1551,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "82080ec388b64b36a9329573ba7f6c72",
+ "model_id": "c02d707ed0974a1ba562e6b261370815",
"version_major": 2,
"version_minor": 0
},
@@ -1517,7 +1565,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "9cb23c364290468982ab7db4066d41c4",
+ "model_id": "356f7ff7f9654864aa8a95d1516e0e03",
"version_major": 2,
"version_minor": 0
},
@@ -1531,7 +1579,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "b17d7126e58f4765aa5d5bf685fb1c76",
+ "model_id": "ac4a9d8fa3e1494ab986fe59b47c3bd2",
"version_major": 2,
"version_minor": 0
},
@@ -1545,7 +1593,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "b56b7debd7374198bf86a3c4ded57a40",
+ "model_id": "90906beaebd84ac6bd4b996816d0efe8",
"version_major": 2,
"version_minor": 0
},
@@ -1559,7 +1607,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "e25fceeff8aa41c0852b5f40c9a73489",
+ "model_id": "89d445aa28fd43568b55c5674577d986",
"version_major": 2,
"version_minor": 0
},
@@ -1573,7 +1621,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "39a9ebe57449439992be2036b0311762",
+ "model_id": "1e0aa3e904954d89b1f7446da474c971",
"version_major": 2,
"version_minor": 0
},
@@ -1585,23 +1633,20 @@
"output_type": "display_data"
},
{
- "data": {
- "application/vnd.jupyter.widget-view+json": {
- "model_id": "9bdae404519d47769ed1940b9c7e26a0",
- "version_major": 2,
- "version_minor": 0
- },
- "text/plain": [
- "Validation: | | 0/? [00:00, ?it/s]"
- ]
- },
- "metadata": {},
- "output_type": "display_data"
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Trainer already configured with model summary callbacks: []. Skipping setting a default `ModelSummary` callback.\n",
+ "GPU available: True (mps), used: True\n",
+ "TPU available: False, using: 0 TPU cores\n",
+ "IPU available: False, using: 0 IPUs\n",
+ "HPU available: False, using: 0 HPUs\n"
+ ]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "970de95dd7a84c40bf408a433189067a",
+ "model_id": "23e3f9ec08ab4b4697bf2479a6bd0bc4",
"version_major": 2,
"version_minor": 0
},
@@ -1624,7 +1669,7 @@
},
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
"