diff --git a/README.md b/README.md
index dd48dffd..9bb1f76f 100644
--- a/README.md
+++ b/README.md
@@ -30,8 +30,8 @@ pip install ".[notebooks]"
- Getting started with `TinyTimeMixer (TTM)` [[Try it out]](https://github.com/ibm-granite/granite-tsfm/blob/main/notebooks/hfdemo/ttm_getting_started.ipynb)
## 📗 Google Colab Tutorials
-Run the TTM tutorial in Google Colab, and quickly build a forecasting application with pre-trained TSFM models.
-- [TTM Colab Tutorial](https://colab.research.google.com/github/IBM/tsfm/blob/main/notebooks/tutorial/ttm_tutorial.ipynb)
+Run the TTM tutorial in Google Colab, and quickly build a forecasting application with the pre-trained TSFM models.
+- [TTM Colab Tutorial](https://colab.research.google.com/github/ibm-granite/granite-tsfm/blob/main/notebooks/hfdemo/ttm_getting_started.ipynb)
## 💻 Demos Installation
The demo presented at NeurIPS 2023 is available in `tsfmhfdemos`. This demo requires you to have pre-trained and finetuned models in place (we plan to release these at a later date). To install the requirements use `pip`:
diff --git a/notebooks/hfdemo/patch_tsmixer_getting_started.ipynb b/notebooks/hfdemo/patch_tsmixer_getting_started.ipynb
index 82d2be65..eb28f908 100644
--- a/notebooks/hfdemo/patch_tsmixer_getting_started.ipynb
+++ b/notebooks/hfdemo/patch_tsmixer_getting_started.ipynb
@@ -883,7 +883,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.9.12"
}
},
"nbformat": 4,
diff --git a/notebooks/hfdemo/patch_tsmixer_transfer.ipynb b/notebooks/hfdemo/patch_tsmixer_transfer.ipynb
index f76beba0..ae3a0f89 100644
--- a/notebooks/hfdemo/patch_tsmixer_transfer.ipynb
+++ b/notebooks/hfdemo/patch_tsmixer_transfer.ipynb
@@ -1465,7 +1465,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.4"
+ "version": "3.9.12"
}
},
"nbformat": 4,
diff --git a/notebooks/hfdemo/patch_tst_getting_started.ipynb b/notebooks/hfdemo/patch_tst_getting_started.ipynb
index 56496e74..09015217 100644
--- a/notebooks/hfdemo/patch_tst_getting_started.ipynb
+++ b/notebooks/hfdemo/patch_tst_getting_started.ipynb
@@ -1,6 +1,7 @@
{
"cells": [
{
+ "attachments": {},
"cell_type": "markdown",
"id": "7478e0e2-b7af-4fd4-b44e-ca58e0c31b71",
"metadata": {},
@@ -55,6 +56,7 @@
]
},
{
+ "attachments": {},
"cell_type": "markdown",
"id": "9e4eb9be-c19f-448f-a4bd-c600e068633f",
"metadata": {},
@@ -283,6 +285,7 @@
]
},
{
+ "attachments": {},
"cell_type": "markdown",
"id": "ae939491-8813-44c9-bc3d-d1c6a5764cd4",
"metadata": {},
@@ -379,6 +382,7 @@
]
},
{
+ "attachments": {},
"cell_type": "markdown",
"id": "19456329-1293-45bf-99c7-e5ccf0534846",
"metadata": {},
@@ -606,6 +610,7 @@
]
},
{
+ "attachments": {},
"cell_type": "markdown",
"id": "5b2e8f0a-8367-4c10-84bd-a72e8f21ccc4",
"metadata": {},
@@ -671,7 +676,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
- "version": "3.10.13"
+ "version": "3.9.12"
}
},
"nbformat": 4,
diff --git a/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb b/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb
new file mode 100644
index 00000000..712537c5
--- /dev/null
+++ b/notebooks/hfdemo/patchtsmixer_HF_blog.ipynb
@@ -0,0 +1,1582 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PatchTSMixer in HuggingFace - Getting Started\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "`PatchTSMixer` is a lightweight time-series modeling approach based on the MLP-Mixer architecture. It is proposed in [TSMixer: Lightweight MLP-Mixer Model for Multivariate Time Series Forecasting](https://arxiv.org/pdf/2306.09364.pdf) by IBM Research authors `Vijay Ekambaram`, `Arindam Jati`, `Nam Nguyen`, `Phanwadee Sinthong` and `Jayant Kalagnanam`.\n",
+ "\n",
+ "For effective mindshare and to promote opensourcing - IBM Research join hands with the HuggingFace team to opensource this model in HF.\n",
+ "\n",
+ "In this [HuggingFace implementation](https://huggingface.co/docs/transformers/main/en/model_doc/patchtsmixer), we provide PatchTSMixer’s capabilities to effortlessly facilitate lightweight mixing across patches, channels, and hidden features for effective multivariate time-series modeling. It also supports various attention mechanisms starting from simple gated attention to more complex self-attention blocks that can be customized accordingly. The model can be pretrained and subsequently used for various downstream tasks such as forecasting, classification, and regression.\n",
+ "\n",
+ "`PatchTSMixer` outperforms state-of-the-art MLP and Transformer models in forecasting by a considerable margin of 8-60%. It also outperforms the latest strong benchmarks of Patch-Transformer models (by 1-2%) with a significant reduction in memory and runtime (2-3X). For more details, refer to the [paper](https://arxiv.org/pdf/2306.09364.pdf)\n",
+ "\n",
+ "In this blog, we will demonstrate examples of getting started with PatchTSMixer. We will first demonstrate the forecasting capability of `PatchTSMixer` on the Electricity data. We will then demonstrate the transfer learning capability of PatchTSMixer by using the model trained on the Electricity to do zero-shot forecasting on the ETTH2 dataset.\n",
+ "\n",
+ "\n",
+ "`Blog authors`: Arindam Jati, Vijay Ekambaram, Nam Ngugen, Wesley Gifford and Kashif Rasul\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Installation\n",
+ "This demo needs Huggingface [`transformers`](https://github.com/huggingface/transformers) for main modeling tasks, and IBM `tsfm` for auxiliary data pre-processing.\n",
+ "We can install both by cloning the `tsfm` repository and following the below steps.\n",
+ "\n",
+ "1. Clone IBM Time Series Foundation Model Repository [`tsfm`](https://github.com/ibm/tsfm).\n",
+ " ```\n",
+ " git clone git@github.com:IBM/tsfm.git\n",
+ " cd tsfm\n",
+ " ```\n",
+ "2. Install `tsfm`. This will also install Huggingface `transformers`.\n",
+ " ```\n",
+ " pip install .\n",
+ " ```\n",
+ "3. Test it with the following commands in a `python` terminal.\n",
+ " ```\n",
+ " from transformers import PatchTSMixerConfig\n",
+ " from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
+ " ```"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Part 1: Forecasting on Electricity dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+ " from .autonotebook import tqdm as notebook_tqdm\n",
+ "2023-12-11 01:25:50.313015: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
+ "2023-12-11 01:25:50.313102: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
+ "2023-12-11 01:25:50.313132: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n",
+ "2023-12-11 01:25:51.234452: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Standard\n",
+ "import os\n",
+ "import random\n",
+ "\n",
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import torch\n",
+ "\n",
+ "# Third Party\n",
+ "from transformers import (\n",
+ " EarlyStoppingCallback,\n",
+ " PatchTSMixerConfig,\n",
+ " PatchTSMixerForPrediction,\n",
+ " Trainer,\n",
+ " TrainingArguments,\n",
+ ")\n",
+ "\n",
+ "# First Party\n",
+ "from tsfm_public.toolkit.dataset import ForecastDFDataset\n",
+ "from tsfm_public.toolkit.time_series_preprocessor import TimeSeriesPreprocessor\n",
+ "from tsfm_public.toolkit.util import select_by_index"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ### Set seed"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "SEED = 42\n",
+ "torch.manual_seed(SEED)\n",
+ "random.seed(SEED)\n",
+ "np.random.seed(SEED)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load and prepare datasets\n",
+ "\n",
+ "In the next cell, please adjust the following parameters to suit your application:\n",
+ "- `PRETRAIN_AGAIN`: Set this to `True` if you want to perform pretraining again. Note that this might take some time depending on the GPU availability. Otherwise, the already pretrained model will be used.\n",
+ "- `dataset_path`: path to local .csv file, or web address to a csv file for the data of interest. Data is loaded with pandas, so anything supported by\n",
+ "`pd.read_csv` is supported: (https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html).\n",
+ "- `timestamp_column`: column name containing timestamp information, use None if there is no such column\n",
+ "- `id_columns`: List of column names specifying the IDs of different time series. If no ID column exists, use []\n",
+ "- `forecast_columns`: List of columns to be modeled\n",
+ "- `context_length`: The amount of historical data used as input to the model. Windows of the input time series data with length equal to\n",
+ "`context_length` will be extracted from the input dataframe. In the case of a multi-time series dataset, the context windows will be created\n",
+ "so that they are contained within a single time series (i.e., a single ID).\n",
+ "- `forecast_horizon`: Number of timestamps to forecast in future.\n",
+ "- `train_start_index`, `train_end_index`: the start and end indices in the loaded data which delineate the training data.\n",
+ "- `valid_start_index`, `valid_end_index`: the start and end indices in the loaded data which delineate the validation data.\n",
+ "- `test_start_index`, `test_end_index`: the start and end indices in the loaded data which delineate the test data.\n",
+ "- `patch_length`: The patch length for the `PatchTSMixer` model. It is recommended to choose a value that evenly divides `context_length`.\n",
+ "- `num_workers`: Number of dataloder workers in pytorch dataloader.\n",
+ "- `batch_size`: Batch size.\n",
+ "The data is first loaded into a Pandas dataframe and split into training, validation, and test parts. Then the pandas dataframes are converted\n",
+ "to the appropriate torch dataset needed for training."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "PRETRAIN_AGAIN = True\n",
+ "# Download ECL data from https://github.com/zhouhaoyi/Informer2020\n",
+ "dataset_path = \"~/Downloads/ECL.csv\"\n",
+ "timestamp_column = \"date\"\n",
+ "id_columns = []\n",
+ "\n",
+ "context_length = 512\n",
+ "forecast_horizon = 96\n",
+ "patch_length = 8\n",
+ "num_workers = 16 # Reduce this if you have low number of CPU cores\n",
+ "batch_size = 64 # Adjust according to GPU memory"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if PRETRAIN_AGAIN:\n",
+ " data = pd.read_csv(\n",
+ " dataset_path,\n",
+ " parse_dates=[timestamp_column],\n",
+ " )\n",
+ " forecast_columns = list(data.columns[1:])\n",
+ "\n",
+ " # get split\n",
+ " num_train = int(len(data) * 0.7)\n",
+ " num_test = int(len(data) * 0.2)\n",
+ " num_valid = len(data) - num_train - num_test\n",
+ " border1s = [\n",
+ " 0,\n",
+ " num_train - context_length,\n",
+ " len(data) - num_test - context_length,\n",
+ " ]\n",
+ " border2s = [num_train, num_train + num_valid, len(data)]\n",
+ "\n",
+ " train_start_index = border1s[0] # None indicates beginning of dataset\n",
+ " train_end_index = border2s[0]\n",
+ "\n",
+ " # we shift the start of the evaluation period back by context length so that\n",
+ " # the first evaluation timestamp is immediately following the training data\n",
+ " valid_start_index = border1s[1]\n",
+ " valid_end_index = border2s[1]\n",
+ "\n",
+ " test_start_index = border1s[2]\n",
+ " test_end_index = border2s[2]\n",
+ "\n",
+ " train_data = select_by_index(\n",
+ " data,\n",
+ " id_columns=id_columns,\n",
+ " start_index=train_start_index,\n",
+ " end_index=train_end_index,\n",
+ " )\n",
+ " valid_data = select_by_index(\n",
+ " data,\n",
+ " id_columns=id_columns,\n",
+ " start_index=valid_start_index,\n",
+ " end_index=valid_end_index,\n",
+ " )\n",
+ " test_data = select_by_index(\n",
+ " data,\n",
+ " id_columns=id_columns,\n",
+ " start_index=test_start_index,\n",
+ " end_index=test_end_index,\n",
+ " )\n",
+ "\n",
+ " tsp = TimeSeriesPreprocessor(\n",
+ " timestamp_column=timestamp_column,\n",
+ " id_columns=id_columns,\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " scaling=True,\n",
+ " )\n",
+ " tsp.train(train_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if PRETRAIN_AGAIN:\n",
+ " train_dataset = ForecastDFDataset(\n",
+ " tsp.preprocess(train_data),\n",
+ " id_columns=id_columns,\n",
+ " timestamp_column=\"date\",\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ " )\n",
+ " valid_dataset = ForecastDFDataset(\n",
+ " tsp.preprocess(valid_data),\n",
+ " id_columns=id_columns,\n",
+ " timestamp_column=\"date\",\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ " )\n",
+ " test_dataset = ForecastDFDataset(\n",
+ " tsp.preprocess(test_data),\n",
+ " id_columns=id_columns,\n",
+ " timestamp_column=\"date\",\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ " )"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ## Configure the PatchTSMixer model\n",
+ "\n",
+ " The settings below control the different components in the PatchTSMixer model.\n",
+ " - `num_input_channels`: the number of input channels (or dimensions) in the time series data. This is\n",
+ " automatically set to the number for forecast columns.\n",
+ " - `context_length`: As described above, the amount of historical data used as input to the model.\n",
+ " - `prediction_length`: This is same as the forecast horizon as decribed above.\n",
+ " - `patch_length`: The length of the patches extracted from the context window (of length `context_length``).\n",
+ " - `patch_stride`: The stride used when extracting patches from the context window.\n",
+ " - `d_model`: Hidden feature dimension of the model.\n",
+ " - `num_layers`: The number of model layers.\n",
+ " - `dropout`: Dropout probability for all fully connected layers in the encoder.\n",
+ " - `head_dropout`: Dropout probability used in the head of the model.\n",
+ " - `mode`: PatchTSMixer operating mode. \"common_channel\"/\"mix_channel\". Common-channel works in channel-independent mode. For pretraining, use \"common_channel\".\n",
+ " - `scaling`: Per-widow standard scaling. Recommended value: \"std\".\n",
+ "\n",
+ "For full details on the parameters - refer [here](https://huggingface.co/docs/transformers/main/en/model_doc/patchtsmixer)\n",
+ "\n",
+ "We recommend that you only adjust the values in the next cell."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if PRETRAIN_AGAIN:\n",
+ " config = PatchTSMixerConfig(\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ " patch_length=patch_length,\n",
+ " num_input_channels=len(forecast_columns),\n",
+ " patch_stride=patch_length,\n",
+ " d_model=16,\n",
+ " num_layers=8,\n",
+ " expansion_factor=2,\n",
+ " dropout=0.2,\n",
+ " head_dropout=0.2,\n",
+ " mode=\"common_channel\",\n",
+ " scaling=\"std\",\n",
+ " )\n",
+ " model = PatchTSMixerForPrediction(config)"
+ ]
+ },
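+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For illustration only (this variant is not used in the rest of the notebook): switching the model to channel-mixing mode is a one-line change to the configuration above. The sketch below keeps every other hyperparameter the same and only swaps `mode` to `mix_channel`."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative sketch only: a channel-mixing variant of the configuration above.\n",
+ "# \"mix_channel\" additionally mixes information across input channels, which can\n",
+ "# help when the channels are strongly correlated; it is not used in the runs below.\n",
+ "if PRETRAIN_AGAIN:\n",
+ "    mix_channel_config = PatchTSMixerConfig(\n",
+ "        context_length=context_length,\n",
+ "        prediction_length=forecast_horizon,\n",
+ "        patch_length=patch_length,\n",
+ "        num_input_channels=len(forecast_columns),\n",
+ "        patch_stride=patch_length,\n",
+ "        d_model=16,\n",
+ "        num_layers=8,\n",
+ "        expansion_factor=2,\n",
+ "        dropout=0.2,\n",
+ "        head_dropout=0.2,\n",
+ "        mode=\"mix_channel\",  # instead of \"common_channel\"\n",
+ "        scaling=\"std\",\n",
+ "    )\n",
+ "    mix_channel_model = PatchTSMixerForPrediction(mix_channel_config)"
+ ]
+ },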
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ## Train model\n",
+ "\n",
+ " Trains the PatchTSMixer model based on the direct forecasting strategy."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "
\n",
+ " \n",
+ "
\n",
+ " [2450/7000 21:35 < 40:08, 1.89 it/s, Epoch 35/100]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.247100 | \n",
+ " 0.141067 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.168600 | \n",
+ " 0.127757 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.156500 | \n",
+ " 0.122327 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.150300 | \n",
+ " 0.118918 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.146000 | \n",
+ " 0.116496 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.143100 | \n",
+ " 0.114968 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.140800 | \n",
+ " 0.113678 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.139200 | \n",
+ " 0.113057 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.137900 | \n",
+ " 0.112405 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.136900 | \n",
+ " 0.112225 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.136100 | \n",
+ " 0.112087 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.135400 | \n",
+ " 0.112330 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.134700 | \n",
+ " 0.111778 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.134100 | \n",
+ " 0.111702 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.133700 | \n",
+ " 0.110964 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.133100 | \n",
+ " 0.111164 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.132800 | \n",
+ " 0.111063 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.132400 | \n",
+ " 0.111088 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.132100 | \n",
+ " 0.110905 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.131800 | \n",
+ " 0.110844 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.131300 | \n",
+ " 0.110831 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.131100 | \n",
+ " 0.110278 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.130700 | \n",
+ " 0.110591 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.130600 | \n",
+ " 0.110319 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.130300 | \n",
+ " 0.109900 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.130000 | \n",
+ " 0.109982 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.129900 | \n",
+ " 0.109975 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.129600 | \n",
+ " 0.110128 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.129300 | \n",
+ " 0.109995 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.129100 | \n",
+ " 0.109868 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.129000 | \n",
+ " 0.109928 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.128700 | \n",
+ " 0.109823 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.128500 | \n",
+ " 0.109863 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.128400 | \n",
+ " 0.109794 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.128100 | \n",
+ " 0.109945 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ }
+ ],
+ "source": [
+ "if PRETRAIN_AGAIN:\n",
+ " training_args = TrainingArguments(\n",
+ " output_dir=\"./checkpoint/patchtsmixer/electricity/pretrain/output/\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=0.001,\n",
+ " num_train_epochs=100, # For a quick test of this notebook, set it to 1\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=batch_size,\n",
+ " per_device_eval_batch_size=batch_size,\n",
+ " dataloader_num_workers=num_workers,\n",
+ " report_to=\"tensorboard\",\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=3,\n",
+ " logging_dir=\"./checkpoint/patchtsmixer/electricity/pretrain/logs/\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " label_names=[\"future_values\"],\n",
+ " # max_steps=20,\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0001, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ "\n",
+ " # define trainer\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " args=training_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=valid_dataset,\n",
+ " callbacks=[early_stopping_callback],\n",
+ " )\n",
+ "\n",
+ " # pretrain\n",
+ " trainer.train()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ## Evaluate model on the test set.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [21/21 00:03]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Test result:\n",
+ "{'eval_loss': 0.12884521484375, 'eval_runtime': 5.7532, 'eval_samples_per_second': 897.763, 'eval_steps_per_second': 3.65, 'epoch': 35.0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "if PRETRAIN_AGAIN:\n",
+ " results = trainer.evaluate(test_dataset)\n",
+ " print(\"Test result:\")\n",
+ " print(results)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "We get MSE score of 0.128 which is the SOTA result on the Electricity data."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " ## Save model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "if PRETRAIN_AGAIN:\n",
+ " save_dir = \"patchtsmixer/electricity/model/pretrain/\"\n",
+ " os.makedirs(save_dir, exist_ok=True)\n",
+ " trainer.save_model(save_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Part 2: Transfer Learning from Electicity to ETTH2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "In this section, we will demonstrate the transfer learning capability of the `PatchTSMixer` model.\n",
+ "We use the model pretrained on Electricity dataset to do zeroshot testing on ETTH2 dataset.\n",
+ "\n",
+ "\n",
+ "In Transfer Learning, we will pretrain the model for a forecasting task on a `source` dataset. Then, we will use the\n",
+ " pretrained model for zero-shot forecasting on a `target` dataset. The zero-shot forecasting\n",
+ " performance will denote the `test` performance of the model in the `target` domain, without any\n",
+ " training on the target domain. Subsequently, we will do linear probing and (then) finetuning of\n",
+ " the pretrained model on the `train` part of the target data, and will validate the forecasting\n",
+ " performance on the `test` part of the target data. In this example, the source dataset is the Electricity dataset and the target dataset is ETTH2"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Transfer Learing on `ETTh2` data. All evaluations are on the `test` part of the `ETTh2` data.\n",
+ "Step 1: Directly evaluate the electricity-pretrained model. This is the zero-shot performance. \n",
+ "Step 2: Evalute after doing linear probing. \n",
+ "Step 3: Evaluate after doing full finetuning. "
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Load ETTh2 data"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = \"ETTh2\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading target dataset: ETTh2\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(f\"Loading target dataset: {dataset}\")\n",
+ "dataset_path = f\"https://raw.githubusercontent.com/zhouhaoyi/ETDataset/main/ETT-small/{dataset}.csv\"\n",
+ "timestamp_column = \"date\"\n",
+ "id_columns = []\n",
+ "forecast_columns = [\"HUFL\", \"HULL\", \"MUFL\", \"MULL\", \"LUFL\", \"LULL\", \"OT\"]\n",
+ "train_start_index = None # None indicates beginning of dataset\n",
+ "train_end_index = 12 * 30 * 24\n",
+ "\n",
+ "# we shift the start of the evaluation period back by context length so that\n",
+ "# the first evaluation timestamp is immediately following the training data\n",
+ "valid_start_index = 12 * 30 * 24 - context_length\n",
+ "valid_end_index = 12 * 30 * 24 + 4 * 30 * 24\n",
+ "\n",
+ "test_start_index = 12 * 30 * 24 + 4 * 30 * 24 - context_length\n",
+ "test_end_index = 12 * 30 * 24 + 8 * 30 * 24"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "TimeSeriesPreprocessor {\n",
+ " \"context_length\": 64,\n",
+ " \"feature_extractor_type\": \"TimeSeriesPreprocessor\",\n",
+ " \"id_columns\": [],\n",
+ " \"input_columns\": [\n",
+ " \"HUFL\",\n",
+ " \"HULL\",\n",
+ " \"MUFL\",\n",
+ " \"MULL\",\n",
+ " \"LUFL\",\n",
+ " \"LULL\",\n",
+ " \"OT\"\n",
+ " ],\n",
+ " \"output_columns\": [\n",
+ " \"HUFL\",\n",
+ " \"HULL\",\n",
+ " \"MUFL\",\n",
+ " \"MULL\",\n",
+ " \"LUFL\",\n",
+ " \"LULL\",\n",
+ " \"OT\"\n",
+ " ],\n",
+ " \"prediction_length\": null,\n",
+ " \"processor_class\": \"TimeSeriesPreprocessor\",\n",
+ " \"scaler_dict\": {\n",
+ " \"0\": {\n",
+ " \"copy\": true,\n",
+ " \"feature_names_in_\": [\n",
+ " \"HUFL\",\n",
+ " \"HULL\",\n",
+ " \"MUFL\",\n",
+ " \"MULL\",\n",
+ " \"LUFL\",\n",
+ " \"LULL\",\n",
+ " \"OT\"\n",
+ " ],\n",
+ " \"mean_\": [\n",
+ " 41.53683496078959,\n",
+ " 12.273452896210882,\n",
+ " 46.60977329964991,\n",
+ " 10.526153112865156,\n",
+ " 1.1869920139097505,\n",
+ " -2.373217913729173,\n",
+ " 26.872023494265697\n",
+ " ],\n",
+ " \"n_features_in_\": 7,\n",
+ " \"n_samples_seen_\": 8640,\n",
+ " \"scale_\": [\n",
+ " 10.448841072588488,\n",
+ " 4.587112566531959,\n",
+ " 16.858190332598408,\n",
+ " 3.018605566682919,\n",
+ " 4.641011217319063,\n",
+ " 8.460910779279644,\n",
+ " 11.584718923414682\n",
+ " ],\n",
+ " \"var_\": [\n",
+ " 109.17827976021215,\n",
+ " 21.04160169803542,\n",
+ " 284.19858129011436,\n",
+ " 9.111979567209104,\n",
+ " 21.538985119281367,\n",
+ " 71.58701121493046,\n",
+ " 134.20571253452223\n",
+ " ],\n",
+ " \"with_mean\": true,\n",
+ " \"with_std\": true\n",
+ " }\n",
+ " },\n",
+ " \"scaling\": true,\n",
+ " \"time_series_task\": \"forecasting\",\n",
+ " \"timestamp_column\": \"date\"\n",
+ "}"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "data = pd.read_csv(\n",
+ " dataset_path,\n",
+ " parse_dates=[timestamp_column],\n",
+ ")\n",
+ "\n",
+ "train_data = select_by_index(\n",
+ " data,\n",
+ " id_columns=id_columns,\n",
+ " start_index=train_start_index,\n",
+ " end_index=train_end_index,\n",
+ ")\n",
+ "valid_data = select_by_index(\n",
+ " data,\n",
+ " id_columns=id_columns,\n",
+ " start_index=valid_start_index,\n",
+ " end_index=valid_end_index,\n",
+ ")\n",
+ "test_data = select_by_index(\n",
+ " data,\n",
+ " id_columns=id_columns,\n",
+ " start_index=test_start_index,\n",
+ " end_index=test_end_index,\n",
+ ")\n",
+ "\n",
+ "tsp = TimeSeriesPreprocessor(\n",
+ " timestamp_column=timestamp_column,\n",
+ " id_columns=id_columns,\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " scaling=True,\n",
+ ")\n",
+ "tsp.train(train_data)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_dataset = ForecastDFDataset(\n",
+ " tsp.preprocess(train_data),\n",
+ " id_columns=id_columns,\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ ")\n",
+ "valid_dataset = ForecastDFDataset(\n",
+ " tsp.preprocess(valid_data),\n",
+ " id_columns=id_columns,\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ ")\n",
+ "test_dataset = ForecastDFDataset(\n",
+ " tsp.preprocess(test_data),\n",
+ " id_columns=id_columns,\n",
+ " input_columns=forecast_columns,\n",
+ " output_columns=forecast_columns,\n",
+ " context_length=context_length,\n",
+ " prediction_length=forecast_horizon,\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Zero-shot forecasting on `ETTh2`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Loading pretrained model\n",
+ "Done\n"
+ ]
+ }
+ ],
+ "source": [
+ "print(\"Loading pretrained model\")\n",
+ "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
+ "print(\"Done\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "Doing zero-shot forecasting on target data\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [11/11 02:52]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Target data zero-shot forecasting result:\n",
+ "{'eval_loss': 0.3038313388824463, 'eval_runtime': 1.8364, 'eval_samples_per_second': 1516.562, 'eval_steps_per_second': 5.99}\n"
+ ]
+ }
+ ],
+ "source": [
+ "finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=\"./checkpoint/patchtsmixer/transfer/finetune/output/\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=0.0001,\n",
+ " num_train_epochs=100,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=batch_size,\n",
+ " per_device_eval_batch_size=batch_size,\n",
+ " dataloader_num_workers=num_workers,\n",
+ " report_to=\"tensorboard\",\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=3,\n",
+ " logging_dir=\"./checkpoint/patchtsmixer/transfer/finetune/logs/\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ ")\n",
+ "\n",
+ "# Create a new early stopping callback with faster convergence properties\n",
+ "early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=5, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.001, # Minimum improvement required to consider as improvement\n",
+ ")\n",
+ "\n",
+ "finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=valid_dataset,\n",
+ " callbacks=[early_stopping_callback],\n",
+ ")\n",
+ "\n",
+ "print(\"\\n\\nDoing zero-shot forecasting on target data\")\n",
+ "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
+ "print(\"Target data zero-shot forecasting result:\")\n",
+ "print(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "By a direct zeroshot, we get MSE of 0.3 which is near to the SOTA result. Lets see, how we can do a simple linear probing to match the SOTA results."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Target data `ETTh2` linear probing\n",
+ "We can do a quick linear probing on the `train` part of the target data to see any possible `test` performance improvement. "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "Linear probing on the target data\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 416/3200 01:01 < 06:53, 6.73 it/s, Epoch 13/100]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.447000 | \n",
+ " 0.216436 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.438600 | \n",
+ " 0.215667 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.429400 | \n",
+ " 0.215104 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.422500 | \n",
+ " 0.213820 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.418500 | \n",
+ " 0.213585 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.415000 | \n",
+ " 0.213016 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.412000 | \n",
+ " 0.213067 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.412400 | \n",
+ " 0.211993 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.405900 | \n",
+ " 0.212460 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.405300 | \n",
+ " 0.211772 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.406200 | \n",
+ " 0.212154 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.400600 | \n",
+ " 0.212082 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.405300 | \n",
+ " 0.211458 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [11/11 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Target data head/linear probing result:\n",
+ "{'eval_loss': 0.27119266986846924, 'eval_runtime': 1.7621, 'eval_samples_per_second': 1580.478, 'eval_steps_per_second': 6.242, 'epoch': 13.0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Freeze the backbone of the model\n",
+ "for param in finetune_forecast_trainer.model.model.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ "print(\"\\n\\nLinear probing on the target data\")\n",
+ "finetune_forecast_trainer.train()\n",
+ "print(\"Evaluating\")\n",
+ "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
+ "print(\"Target data head/linear probing result:\")\n",
+ "print(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "vscode": {
+ "languageId": "plaintext"
+ }
+ },
+ "source": [
+ "By doing a simple linear probing, MSE decreased from 0.3 to 0.271 achiving the SOTA results."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['patchtsmixer/electricity/model/transfer/ETTh2/preprocessor/preprocessor_config.json']"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/model/linear_probe/\"\n",
+ "os.makedirs(save_dir, exist_ok=True)\n",
+ "finetune_forecast_trainer.save_model(save_dir)\n",
+ "\n",
+ "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/preprocessor/\"\n",
+ "os.makedirs(save_dir, exist_ok=True)\n",
+ "tsp.save_pretrained(save_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Lets now see, if we get any more improvements by doing a full finetune."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Target data `ETTh2` full finetune\n",
+ "\n",
+ "We can do a full model finetune (instead of probing the last linear layer as shown above) on the `train` part of the target data to see a possible `test` performance improvement."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ "Finetuning on the target data\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 288/3200 00:44 < 07:34, 6.40 it/s, Epoch 9/100]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.432900 | \n",
+ " 0.215200 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.416700 | \n",
+ " 0.210919 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.401400 | \n",
+ " 0.209932 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.392900 | \n",
+ " 0.208808 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.388100 | \n",
+ " 0.209692 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.375900 | \n",
+ " 0.209546 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.370000 | \n",
+ " 0.210207 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.367000 | \n",
+ " 0.211601 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.359400 | \n",
+ " 0.211405 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Evaluating\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/hf/lib/python3.10/site-packages/torch/nn/parallel/_functions.py:68: UserWarning: Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
+ " warnings.warn('Was asked to gather along dimension 0, but all '\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [11/11 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Target data full finetune result:\n",
+ "{'eval_loss': 0.2734043300151825, 'eval_runtime': 1.5853, 'eval_samples_per_second': 1756.725, 'eval_steps_per_second': 6.939, 'epoch': 9.0}\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Reload the model\n",
+ "finetune_forecast_model = PatchTSMixerForPrediction.from_pretrained(\"patchtsmixer/electricity/model/pretrain/\")\n",
+ "finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=train_dataset,\n",
+ " eval_dataset=valid_dataset,\n",
+ " callbacks=[early_stopping_callback],\n",
+ ")\n",
+ "print(\"\\n\\nFinetuning on the target data\")\n",
+ "finetune_forecast_trainer.train()\n",
+ "print(\"Evaluating\")\n",
+ "result = finetune_forecast_trainer.evaluate(test_dataset)\n",
+ "print(\"Target data full finetune result:\")\n",
+ "print(result)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "There is not much improvement with ETTH2 dataset with full finetuning. Lets save the model anyway."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "save_dir = f\"patchtsmixer/electricity/model/transfer/{dataset}/model/fine_tuning/\"\n",
+ "os.makedirs(save_dir, exist_ok=True)\n",
+ "finetune_forecast_trainer.save_model(save_dir)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "\n",
+ "Summary: In this blog, we presented a step-by-step guide on leveraging PatchTSMixer for tasks related to forecasting and transfer learning. We intend to facilitate the seamless integration of the PatchTSMixer HF model for your forecasting use cases. We trust that this content serves as a useful resource to expedite your adoption of PatchTSMixer. Thank you for tuning in to our blog, and we hope you find this information beneficial for your projects.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.12.6"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_1024_96.ipynb b/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_1024_96.ipynb
new file mode 100644
index 00000000..2430f7d4
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_1024_96.ipynb
@@ -0,0 +1,2988 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " # TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+ " **Using TTM-1024-96 model with Frequency Tuning.**"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-04 09:09:31.338304: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-04 09:09:31.388332: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-04 09:09:33.367705: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import logging\n",
+ "import math\n",
+ "import warnings\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions\n",
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "\n",
+ "logging.basicConfig(level=logging.ERROR)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "SEED = 42\n",
+ "set_seed(SEED)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 1024\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "enable_prefix_tuning = True\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
+ "# Refer to the load_dataset() function\n",
+ "# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py\n",
+ "# to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm_v2_freq_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# TTM models for Only Research and Academic (Non-Commercial) Use are here: https://huggingface.co/ibm/ttm-research-r2\n",
+ "# Please provide the branch name properly based on context_len and forecast_len\n",
+ "\n",
+ "hf_model_path = \"ibm/ttm-research-r2\"\n",
+ "if context_length == 512:\n",
+ " hf_model_branch = \"main\"\n",
+ "elif context_length == 1024 or context_length == 1536:\n",
+ " hf_model_branch = f\"{context_length}_{forecast_length}_ft_r2\"\n",
+ "else:\n",
+ " raise ValueError(\"Valid context lengths are: 512, 1024, and 1536 for now. Stay tuned for more TTM models.\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
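+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The following cell runs the full benchmarking loop. For orientation only, the illustrative cell just below is a minimal, hedged sketch of the per-dataset procedure reflected in the logs: zero-shot evaluation of the pretrained TTM, then 5% few-shot fine-tuning with a frozen backbone, a learning-rate search, and a final test evaluation. It is not the notebook's exact code, and keyword arguments such as `fewshot_fraction` and `dataset_root_path` are assumptions about the `load_dataset()` helper."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative sketch only -- argument names for load_dataset() and optimal_lr_finder() are assumptions.\n",
+ "def sketch_zero_and_few_shot(dataset_name):\n",
+ "    # Zero-shot: load the full splits and evaluate the pretrained model on the test split.\n",
+ "    train_ds, valid_ds, test_ds = load_dataset(dataset_name, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n",
+ "    model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "    zs_args = TrainingArguments(output_dir=OUT_DIR, per_device_eval_batch_size=64, report_to=\"none\")\n",
+ "    zs_result = Trainer(model=model, args=zs_args).evaluate(test_ds)\n",
+ "\n",
+ "    # Few-shot 5%: reload the data with a 5% training subset and a fresh copy of the model.\n",
+ "    train_ds, valid_ds, test_ds = load_dataset(dataset_name, context_length, forecast_length, fewshot_fraction=0.05, dataset_root_path=DATA_ROOT_PATH)\n",
+ "    model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "    if freeze_backbone:\n",
+ "        # Freeze the backbone so only the decoder/head parameters are fine-tuned.\n",
+ "        for param in model.backbone.parameters():\n",
+ "            param.requires_grad = False\n",
+ "    lr, model = optimal_lr_finder(model, train_ds, batch_size=64)\n",
+ "    fs_args = TrainingArguments(output_dir=OUT_DIR, num_train_epochs=EPOCHS, learning_rate=lr, per_device_train_batch_size=64, dataloader_num_workers=NUM_WORKERS, report_to=\"none\")\n",
+ "    # The actual loop additionally uses AdamW + OneCycleLR, EarlyStoppingCallback and TrackingCallback (all imported above).\n",
+ "    trainer = Trainer(model=model, args=fs_args, train_dataset=train_ds, eval_dataset=valid_ds)\n",
+ "    trainer.train()\n",
+ "    fs_result = trainer.evaluate(test_ds)\n",
+ "    return zs_result[\"eval_loss\"], fs_result[\"eval_loss\"]"
+ ]
+ },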
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 7521, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "05912e0244824f5082e37db1d808b62b",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.51k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "811d9832d129491d8cf0f88d75b1499a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/12.5M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36299267411231995, 'eval_model_preparation_time': 0.0029, 'eval_runtime': 4.4622, 'eval_samples_per_second': 624.135, 'eval_steps_per_second': 9.861}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 285, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 60/250 00:33 < 01:48, 1.75 it/s, Epoch 12/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 1.452500 | \n",
+ " 0.679759 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1.376100 | \n",
+ " 0.679738 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1.264900 | \n",
+ " 0.679831 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.193600 | \n",
+ " 0.680443 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.955400 | \n",
+ " 0.682166 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.765700 | \n",
+ " 0.685723 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.667100 | \n",
+ " 0.690811 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.505500 | \n",
+ " 0.694354 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.423700 | \n",
+ " 0.695601 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.365500 | \n",
+ " 0.694273 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.331300 | \n",
+ " 0.691512 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.307800 | \n",
+ " 0.688851 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23177196345088:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:04 EDT)\" (scheduled at 2024-10-04 09:10:04.201164-04:00)\n",
+ "INFO:p-854016:t-23177196345088:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:19 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177196345088:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:34 EDT)\" (scheduled at 2024-10-04 09:10:19.201164-04:00)\n",
+ "INFO:p-854016:t-23177196345088:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:34 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0863359371821086 seconds, Total Train Time = 34.31828022003174\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36278820037841797, 'eval_runtime': 1.4201, 'eval_samples_per_second': 1961.099, 'eval_steps_per_second': 30.983, 'epoch': 12.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 7521, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2709115445613861, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 1.4158, 'eval_samples_per_second': 1967.019, 'eval_steps_per_second': 31.077}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 285, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.000298364724028334\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334\n",
+ "Using learning rate = 0.000298364724028334\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 55/250 00:31 < 01:56, 1.68 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 1.075600 | \n",
+ " 0.228444 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1.039400 | \n",
+ " 0.229198 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.916700 | \n",
+ " 0.230436 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.834900 | \n",
+ " 0.232362 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.668400 | \n",
+ " 0.235177 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.608800 | \n",
+ " 0.238780 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.487000 | \n",
+ " 0.243153 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.423200 | \n",
+ " 0.249117 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.374200 | \n",
+ " 0.259001 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.344200 | \n",
+ " 0.276216 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.317600 | \n",
+ " 0.295603 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23185549661952:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:46 EDT)\" (scheduled at 2024-10-04 09:10:46.693758-04:00)\n",
+ "INFO:p-854016:t-23185549661952:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.1077192696658047 seconds, Total Train Time = 32.47017168998718\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2712791860103607, 'eval_runtime': 1.5242, 'eval_samples_per_second': 1827.234, 'eval_steps_per_second': 28.868, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 33441, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "1 etth2 0.271 0.271\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:04]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.32694563269615173, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 4.6754, 'eval_samples_per_second': 2443.632, 'eval_steps_per_second': 38.285}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 1581, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00043287612810830566\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00043287612810830566\n",
+ "Using learning rate = 0.00043287612810830566\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:47 < 02:48, 5.78 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.840300 | \n",
+ " 0.408443 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.622600 | \n",
+ " 0.413902 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.449800 | \n",
+ " 0.420073 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.327900 | \n",
+ " 0.420382 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.286800 | \n",
+ " 0.412459 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.261300 | \n",
+ " 0.427272 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.243400 | \n",
+ " 0.439357 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.228300 | \n",
+ " 0.436092 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.214300 | \n",
+ " 0.454617 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.201800 | \n",
+ " 0.466312 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.195100 | \n",
+ " 0.472506 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23177226843904:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:30 EDT)\" (scheduled at 2024-10-04 09:11:30.859561-04:00)\n",
+ "INFO:p-854016:t-23177226843904:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:45 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177226843904:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:00 EDT)\" (scheduled at 2024-10-04 09:11:45.859561-04:00)\n",
+ "INFO:p-854016:t-23177226843904:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:00 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177226843904:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:15 EDT)\" (scheduled at 2024-10-04 09:12:00.859561-04:00)\n",
+ "INFO:p-854016:t-23177226843904:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:15 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.4592871015722102 seconds, Total Train Time = 48.18804407119751\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3278079628944397, 'eval_runtime': 2.4712, 'eval_samples_per_second': 4623.251, 'eval_steps_per_second': 72.434, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 33441, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "1 etth2 0.271 0.271\n",
+ "2 ettm1 0.327 0.328\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:04]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1779371052980423, 'eval_model_preparation_time': 0.0026, 'eval_runtime': 4.4828, 'eval_samples_per_second': 2548.632, 'eval_steps_per_second': 39.93}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 1581, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0002477076355991711\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0002477076355991711\n",
+ "Using learning rate = 0.0002477076355991711\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:46 < 02:47, 5.82 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.524500 | \n",
+ " 0.122229 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.414800 | \n",
+ " 0.123283 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.310200 | \n",
+ " 0.125309 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.212000 | \n",
+ " 0.128540 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.159700 | \n",
+ " 0.133482 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.142000 | \n",
+ " 0.138576 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.133700 | \n",
+ " 0.137888 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.128100 | \n",
+ " 0.140013 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.121600 | \n",
+ " 0.141608 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.116800 | \n",
+ " 0.148989 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.111300 | \n",
+ " 0.153411 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23185409140480:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:31 EDT)\" (scheduled at 2024-10-04 09:12:31.605868-04:00)\n",
+ "INFO:p-854016:t-23185409140480:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23185409140480:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:01 EDT)\" (scheduled at 2024-10-04 09:12:46.605868-04:00)\n",
+ "INFO:p-854016:t-23185409140480:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23185409140480:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:16 EDT)\" (scheduled at 2024-10-04 09:13:01.605868-04:00)\n",
+ "INFO:p-854016:t-23185409140480:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.4911205118352717 seconds, Total Train Time = 47.828335762023926\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1781630963087082, 'eval_runtime': 2.6238, 'eval_samples_per_second': 4354.33, 'eval_steps_per_second': 68.221, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "1 etth2 0.271 0.271\n",
+ "2 ettm1 0.327 0.328\n",
+ "3 ettm2 0.178 0.178\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 35768, val = 5175, test = 10444\n",
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:07]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.16557331383228302, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 7.5397, 'eval_samples_per_second': 1385.208, 'eval_steps_per_second': 21.752}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n",
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 1698, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 459/1350 01:25 < 02:47, 5.33 it/s, Epoch 17/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.161200 | \n",
+ " 0.385808 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.153800 | \n",
+ " 0.383190 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.145900 | \n",
+ " 0.382595 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.135000 | \n",
+ " 0.382253 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.123100 | \n",
+ " 0.385421 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.110500 | \n",
+ " 0.384698 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.101900 | \n",
+ " 0.380126 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.095600 | \n",
+ " 0.385159 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.090800 | \n",
+ " 0.389009 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.086400 | \n",
+ " 0.386302 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.083900 | \n",
+ " 0.386835 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.079600 | \n",
+ " 0.387808 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.076700 | \n",
+ " 0.390683 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.075000 | \n",
+ " 0.390224 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.072600 | \n",
+ " 0.390617 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.070900 | \n",
+ " 0.391976 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.068500 | \n",
+ " 0.394016 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:36 EDT)\" (scheduled at 2024-10-04 09:13:36.934582-04:00)\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:51 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:51 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:06 EDT)\" (scheduled at 2024-10-04 09:13:51.934582-04:00)\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:06 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:21 EDT)\" (scheduled at 2024-10-04 09:14:06.934582-04:00)\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:21 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:36 EDT)\" (scheduled at 2024-10-04 09:14:21.934582-04:00)\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:36 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:51 EDT)\" (scheduled at 2024-10-04 09:14:36.934582-04:00)\n",
+ "INFO:p-854016:t-23177466210048:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:51 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 2.039622292799108 seconds, Total Train Time = 86.71302151679993\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:03]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.16545218229293823, 'eval_runtime': 4.3782, 'eval_samples_per_second': 2385.457, 'eval_steps_per_second': 37.458, 'epoch': 17.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "1 etth2 0.271 0.271\n",
+ "2 ettm1 0.327 0.328\n",
+ "3 ettm2 0.178 0.178\n",
+ "4 weather 0.166 0.165\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 17293, val = 2537, test = 5165\n",
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:30]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15667933225631714, 'eval_model_preparation_time': 0.0026, 'eval_runtime': 32.0456, 'eval_samples_per_second': 161.177, 'eval_steps_per_second': 5.055}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 774, val = 2537, test = 5165\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 5.590810182512223e-05\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 5.590810182512223e-05\n",
+ "Using learning rate = 5.590810182512223e-05\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1250/1250 14:00, Epoch 50/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.154000 | \n",
+ " 0.132241 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.150400 | \n",
+ " 0.132691 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.148000 | \n",
+ " 0.132196 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.146900 | \n",
+ " 0.130720 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.144700 | \n",
+ " 0.130376 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.143600 | \n",
+ " 0.129266 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.141100 | \n",
+ " 0.128518 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.140600 | \n",
+ " 0.127543 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.138700 | \n",
+ " 0.126815 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.136300 | \n",
+ " 0.125934 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.135100 | \n",
+ " 0.125684 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.134500 | \n",
+ " 0.124737 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.132800 | \n",
+ " 0.123635 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.131600 | \n",
+ " 0.123778 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.130500 | \n",
+ " 0.122154 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.130000 | \n",
+ " 0.122326 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.128400 | \n",
+ " 0.122099 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.127500 | \n",
+ " 0.121983 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.126400 | \n",
+ " 0.121733 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.125600 | \n",
+ " 0.120991 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.125600 | \n",
+ " 0.120885 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.124900 | \n",
+ " 0.120421 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.123700 | \n",
+ " 0.119788 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.123200 | \n",
+ " 0.120046 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.122400 | \n",
+ " 0.119741 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.122700 | \n",
+ " 0.119550 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.121800 | \n",
+ " 0.120000 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.121700 | \n",
+ " 0.119300 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.121000 | \n",
+ " 0.119297 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.120900 | \n",
+ " 0.119253 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.120700 | \n",
+ " 0.118943 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.120700 | \n",
+ " 0.119013 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.120000 | \n",
+ " 0.119013 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.119900 | \n",
+ " 0.118858 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.119900 | \n",
+ " 0.118655 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.119500 | \n",
+ " 0.118519 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.119300 | \n",
+ " 0.118727 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.119700 | \n",
+ " 0.118654 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.119000 | \n",
+ " 0.118608 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.119400 | \n",
+ " 0.118434 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.119100 | \n",
+ " 0.118522 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.119600 | \n",
+ " 0.118490 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.119000 | \n",
+ " 0.118394 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.119200 | \n",
+ " 0.118425 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.118600 | \n",
+ " 0.118437 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.119100 | \n",
+ " 0.118353 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 0.119000 | \n",
+ " 0.118378 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.119100 | \n",
+ " 0.118386 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 0.119000 | \n",
+ " 0.118385 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 0.118500 | \n",
+ " 0.118385 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:01 EDT)\" (scheduled at 2024-10-04 09:16:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:31 EDT)\" (scheduled at 2024-10-04 09:16:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:46 EDT)\" (scheduled at 2024-10-04 09:16:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:01 EDT)\" (scheduled at 2024-10-04 09:16:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:16 EDT)\" (scheduled at 2024-10-04 09:17:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:31 EDT)\" (scheduled at 2024-10-04 09:17:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:46 EDT)\" (scheduled at 2024-10-04 09:17:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:01 EDT)\" (scheduled at 2024-10-04 09:17:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:16 EDT)\" (scheduled at 2024-10-04 09:18:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:31 EDT)\" (scheduled at 2024-10-04 09:18:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:46 EDT)\" (scheduled at 2024-10-04 09:18:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:01 EDT)\" (scheduled at 2024-10-04 09:18:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:16 EDT)\" (scheduled at 2024-10-04 09:19:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:31 EDT)\" (scheduled at 2024-10-04 09:19:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:46 EDT)\" (scheduled at 2024-10-04 09:19:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:01 EDT)\" (scheduled at 2024-10-04 09:19:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:16 EDT)\" (scheduled at 2024-10-04 09:20:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:31 EDT)\" (scheduled at 2024-10-04 09:20:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:46 EDT)\" (scheduled at 2024-10-04 09:20:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:01 EDT)\" (scheduled at 2024-10-04 09:20:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:16 EDT)\" (scheduled at 2024-10-04 09:21:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:31 EDT)\" (scheduled at 2024-10-04 09:21:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:46 EDT)\" (scheduled at 2024-10-04 09:21:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:01 EDT)\" (scheduled at 2024-10-04 09:21:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:16 EDT)\" (scheduled at 2024-10-04 09:22:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:31 EDT)\" (scheduled at 2024-10-04 09:22:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:46 EDT)\" (scheduled at 2024-10-04 09:22:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:01 EDT)\" (scheduled at 2024-10-04 09:22:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:01 EDT)\" (scheduled at 2024-10-04 09:22:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:16 EDT)\" (scheduled at 2024-10-04 09:23:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:31 EDT)\" (scheduled at 2024-10-04 09:23:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:46 EDT)\" (scheduled at 2024-10-04 09:23:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:01 EDT)\" (scheduled at 2024-10-04 09:23:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:16 EDT)\" (scheduled at 2024-10-04 09:24:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:31 EDT)\" (scheduled at 2024-10-04 09:24:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:46 EDT)\" (scheduled at 2024-10-04 09:24:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:01 EDT)\" (scheduled at 2024-10-04 09:24:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:16 EDT)\" (scheduled at 2024-10-04 09:25:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:31 EDT)\" (scheduled at 2024-10-04 09:25:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:46 EDT)\" (scheduled at 2024-10-04 09:25:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:01 EDT)\" (scheduled at 2024-10-04 09:25:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:16 EDT)\" (scheduled at 2024-10-04 09:26:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:31 EDT)\" (scheduled at 2024-10-04 09:26:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:46 EDT)\" (scheduled at 2024-10-04 09:26:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:01 EDT)\" (scheduled at 2024-10-04 09:26:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:16 EDT)\" (scheduled at 2024-10-04 09:27:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:31 EDT)\" (scheduled at 2024-10-04 09:27:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:46 EDT)\" (scheduled at 2024-10-04 09:27:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:01 EDT)\" (scheduled at 2024-10-04 09:27:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:16 EDT)\" (scheduled at 2024-10-04 09:28:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:31 EDT)\" (scheduled at 2024-10-04 09:28:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:46 EDT)\" (scheduled at 2024-10-04 09:28:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:01 EDT)\" (scheduled at 2024-10-04 09:28:46.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:01 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:16 EDT)\" (scheduled at 2024-10-04 09:29:01.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:16 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:31 EDT)\" (scheduled at 2024-10-04 09:29:16.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:31 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:46 EDT)\" (scheduled at 2024-10-04 09:29:31.931082-04:00)\n",
+ "INFO:p-854016:t-23177213200128:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:46 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 5.812682814598084 seconds, Total Train Time = 842.4421577453613\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:18]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.14788587391376495, 'eval_runtime': 20.3053, 'eval_samples_per_second': 254.367, 'eval_steps_per_second': 7.978, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "1 etth2 0.271 0.271\n",
+ "2 ettm1 0.327 0.328\n",
+ "3 ettm2 0.178 0.178\n",
+ "4 weather 0.166 0.165\n",
+ "5 electricity 0.157 0.148\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1024_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 11161, val = 1661, test = 3413\n",
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 01:04]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4755541682243347, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 64.835, 'eval_samples_per_second': 52.641, 'eval_steps_per_second': 6.586}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:data_handling.py:load_dataset:Data lengths: train = 467, val = 1661, test = 3413\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3126076\n",
+ "Number of params after freezing the backbone 980178\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-854016:t-23192246899456:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-854016:t-23192246899456:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00017073526474706903\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00017073526474706903\n",
+ "Using learning rate = 0.00017073526474706903\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23192246899456:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-854016:t-23192246899456:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [2891/2950 22:51 < 00:28, 2.11 it/s, Epoch 49/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.311200 | \n",
+ " 0.385815 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.296600 | \n",
+ " 0.384570 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.289100 | \n",
+ " 0.383350 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.279400 | \n",
+ " 0.380428 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.270800 | \n",
+ " 0.378043 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.264300 | \n",
+ " 0.376897 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.259800 | \n",
+ " 0.371791 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.255100 | \n",
+ " 0.367791 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.249100 | \n",
+ " 0.366351 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.247200 | \n",
+ " 0.360846 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.243400 | \n",
+ " 0.364740 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.240400 | \n",
+ " 0.361001 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.237600 | \n",
+ " 0.359136 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.235600 | \n",
+ " 0.359346 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.231600 | \n",
+ " 0.358251 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.229000 | \n",
+ " 0.353927 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.225800 | \n",
+ " 0.360125 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.224300 | \n",
+ " 0.353490 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.224400 | \n",
+ " 0.355529 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.220100 | \n",
+ " 0.357273 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.219300 | \n",
+ " 0.356624 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.216300 | \n",
+ " 0.352336 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.215600 | \n",
+ " 0.356624 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.214500 | \n",
+ " 0.350067 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.213500 | \n",
+ " 0.348890 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.210700 | \n",
+ " 0.356339 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.209100 | \n",
+ " 0.348824 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.209500 | \n",
+ " 0.350834 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.207700 | \n",
+ " 0.350435 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.207400 | \n",
+ " 0.348997 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.206500 | \n",
+ " 0.351628 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.205000 | \n",
+ " 0.349470 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.204500 | \n",
+ " 0.348043 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.202500 | \n",
+ " 0.345874 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.202700 | \n",
+ " 0.350634 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.201500 | \n",
+ " 0.347463 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.201000 | \n",
+ " 0.348373 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.200000 | \n",
+ " 0.349413 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.199800 | \n",
+ " 0.345183 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.199500 | \n",
+ " 0.348336 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.199100 | \n",
+ " 0.348488 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.198500 | \n",
+ " 0.346594 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.197900 | \n",
+ " 0.347064 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.197900 | \n",
+ " 0.346063 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.197500 | \n",
+ " 0.347178 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.197100 | \n",
+ " 0.346488 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 0.197100 | \n",
+ " 0.347005 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.196900 | \n",
+ " 0.346596 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 0.197000 | \n",
+ " 0.346689 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:50 EDT)\" (scheduled at 2024-10-04 09:31:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:20 EDT)\" (scheduled at 2024-10-04 09:32:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:35 EDT)\" (scheduled at 2024-10-04 09:32:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:50 EDT)\" (scheduled at 2024-10-04 09:32:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:05 EDT)\" (scheduled at 2024-10-04 09:32:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:20 EDT)\" (scheduled at 2024-10-04 09:33:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:35 EDT)\" (scheduled at 2024-10-04 09:33:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:50 EDT)\" (scheduled at 2024-10-04 09:33:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:05 EDT)\" (scheduled at 2024-10-04 09:33:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:20 EDT)\" (scheduled at 2024-10-04 09:34:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:35 EDT)\" (scheduled at 2024-10-04 09:34:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:50 EDT)\" (scheduled at 2024-10-04 09:34:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:05 EDT)\" (scheduled at 2024-10-04 09:34:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:20 EDT)\" (scheduled at 2024-10-04 09:35:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:35 EDT)\" (scheduled at 2024-10-04 09:35:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:50 EDT)\" (scheduled at 2024-10-04 09:35:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:05 EDT)\" (scheduled at 2024-10-04 09:35:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:20 EDT)\" (scheduled at 2024-10-04 09:36:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:35 EDT)\" (scheduled at 2024-10-04 09:36:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:50 EDT)\" (scheduled at 2024-10-04 09:36:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:36:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:05 EDT)\" (scheduled at 2024-10-04 09:36:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:20 EDT)\" (scheduled at 2024-10-04 09:37:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:35 EDT)\" (scheduled at 2024-10-04 09:37:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:50 EDT)\" (scheduled at 2024-10-04 09:37:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:05 EDT)\" (scheduled at 2024-10-04 09:37:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:20 EDT)\" (scheduled at 2024-10-04 09:38:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:35 EDT)\" (scheduled at 2024-10-04 09:38:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:50 EDT)\" (scheduled at 2024-10-04 09:38:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:05 EDT)\" (scheduled at 2024-10-04 09:38:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:20 EDT)\" (scheduled at 2024-10-04 09:39:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:35 EDT)\" (scheduled at 2024-10-04 09:39:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:50 EDT)\" (scheduled at 2024-10-04 09:39:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:05 EDT)\" (scheduled at 2024-10-04 09:39:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:20 EDT)\" (scheduled at 2024-10-04 09:40:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:35 EDT)\" (scheduled at 2024-10-04 09:40:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:50 EDT)\" (scheduled at 2024-10-04 09:40:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:05 EDT)\" (scheduled at 2024-10-04 09:40:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:20 EDT)\" (scheduled at 2024-10-04 09:41:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:35 EDT)\" (scheduled at 2024-10-04 09:41:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:50 EDT)\" (scheduled at 2024-10-04 09:41:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:05 EDT)\" (scheduled at 2024-10-04 09:41:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:20 EDT)\" (scheduled at 2024-10-04 09:42:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:35 EDT)\" (scheduled at 2024-10-04 09:42:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:50 EDT)\" (scheduled at 2024-10-04 09:42:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:05 EDT)\" (scheduled at 2024-10-04 09:42:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:20 EDT)\" (scheduled at 2024-10-04 09:43:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:35 EDT)\" (scheduled at 2024-10-04 09:43:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:50 EDT)\" (scheduled at 2024-10-04 09:43:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:05 EDT)\" (scheduled at 2024-10-04 09:43:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:20 EDT)\" (scheduled at 2024-10-04 09:44:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:35 EDT)\" (scheduled at 2024-10-04 09:44:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:50 EDT)\" (scheduled at 2024-10-04 09:44:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:44:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:05 EDT)\" (scheduled at 2024-10-04 09:44:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:20 EDT)\" (scheduled at 2024-10-04 09:45:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:35 EDT)\" (scheduled at 2024-10-04 09:45:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:50 EDT)\" (scheduled at 2024-10-04 09:45:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:45:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:05 EDT)\" (scheduled at 2024-10-04 09:45:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:20 EDT)\" (scheduled at 2024-10-04 09:46:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:35 EDT)\" (scheduled at 2024-10-04 09:46:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:50 EDT)\" (scheduled at 2024-10-04 09:46:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:46:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:05 EDT)\" (scheduled at 2024-10-04 09:46:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:20 EDT)\" (scheduled at 2024-10-04 09:47:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:35 EDT)\" (scheduled at 2024-10-04 09:47:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:50 EDT)\" (scheduled at 2024-10-04 09:47:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:47:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:05 EDT)\" (scheduled at 2024-10-04 09:47:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:20 EDT)\" (scheduled at 2024-10-04 09:48:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:35 EDT)\" (scheduled at 2024-10-04 09:48:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:50 EDT)\" (scheduled at 2024-10-04 09:48:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:48:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:05 EDT)\" (scheduled at 2024-10-04 09:48:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:20 EDT)\" (scheduled at 2024-10-04 09:49:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:35 EDT)\" (scheduled at 2024-10-04 09:49:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:50 EDT)\" (scheduled at 2024-10-04 09:49:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:49:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:05 EDT)\" (scheduled at 2024-10-04 09:49:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:20 EDT)\" (scheduled at 2024-10-04 09:50:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:35 EDT)\" (scheduled at 2024-10-04 09:50:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:50 EDT)\" (scheduled at 2024-10-04 09:50:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:50:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:05 EDT)\" (scheduled at 2024-10-04 09:50:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:20 EDT)\" (scheduled at 2024-10-04 09:51:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:35 EDT)\" (scheduled at 2024-10-04 09:51:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:50 EDT)\" (scheduled at 2024-10-04 09:51:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:51:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:05 EDT)\" (scheduled at 2024-10-04 09:51:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:20 EDT)\" (scheduled at 2024-10-04 09:52:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:35 EDT)\" (scheduled at 2024-10-04 09:52:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:50 EDT)\" (scheduled at 2024-10-04 09:52:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:52:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:05 EDT)\" (scheduled at 2024-10-04 09:52:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:20 EDT)\" (scheduled at 2024-10-04 09:53:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:35 EDT)\" (scheduled at 2024-10-04 09:53:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:50 EDT)\" (scheduled at 2024-10-04 09:53:35.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:53:50 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:54:05 EDT)\" (scheduled at 2024-10-04 09:53:50.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:54:05 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:54:20 EDT)\" (scheduled at 2024-10-04 09:54:05.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:54:20 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:54:35 EDT)\" (scheduled at 2024-10-04 09:54:20.993797-04:00)\n",
+ "INFO:p-854016:t-23177200080640:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:54:35 EDT)\" executed successfully\n",
+ "INFO:p-854016:t-23192246899456:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-854016:t-23192246899456:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 8.869987984092868 seconds, Total Train Time = 1373.7280249595642\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:35]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.41427138447761536, 'eval_runtime': 37.1621, 'eval_samples_per_second': 91.841, 'eval_steps_per_second': 11.49, 'epoch': 49.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.363\n",
+ "1 etth2 0.271 0.271\n",
+ "2 ettm1 0.327 0.328\n",
+ "3 ettm2 0.178 0.178\n",
+ "4 weather 0.166 0.165\n",
+ "5 electricity 0.157 0.148\n",
+ "6 traffic 0.476 0.414\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " use_frequency_token=enable_prefix_tuning,\n",
+ " )\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " seed=SEED,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+ " ## Use the pretrained model in few-shot 5% and 10% forecasting #\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5]:\n",
+ " # Set learning rate\n",
+ " learning_rate = None # `None` value indicates that the optimal_lr_finder() will be used\n",
+ "\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " use_frequency_token=enable_prefix_tuning,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " if learning_rate is None:\n",
+ " learning_rate, finetune_forecast_model = optimal_lr_finder(\n",
+ " finetune_forecast_model,\n",
+ " dset_train,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " enable_prefix_tuning=enable_prefix_tuning,\n",
+ " )\n",
+ " print(\"OPTIMAL SUGGESTED LEARNING RATE =\", learning_rate)\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " seed=SEED,\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\"]])\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+ "*Some slight differences in the results as compared to the TTM paper results is possible due to different training environments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " dataset | \n",
+ " zs_mse | \n",
+ " fs5_mse | \n",
+ " zs_eval_time | \n",
+ " fs5_mean_epoch_time | \n",
+ " fs5_total_train_time | \n",
+ " fs5_best_val_metric | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " etth1 | \n",
+ " 0.363 | \n",
+ " 0.363 | \n",
+ " 4.462 | \n",
+ " 1.086 | \n",
+ " 34.318 | \n",
+ " 0.680 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " etth2 | \n",
+ " 0.271 | \n",
+ " 0.271 | \n",
+ " 1.416 | \n",
+ " 1.108 | \n",
+ " 32.470 | \n",
+ " 0.228 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ettm1 | \n",
+ " 0.327 | \n",
+ " 0.328 | \n",
+ " 4.675 | \n",
+ " 1.459 | \n",
+ " 48.188 | \n",
+ " 0.408 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ettm2 | \n",
+ " 0.178 | \n",
+ " 0.178 | \n",
+ " 4.483 | \n",
+ " 1.491 | \n",
+ " 47.828 | \n",
+ " 0.122 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " weather | \n",
+ " 0.166 | \n",
+ " 0.165 | \n",
+ " 7.540 | \n",
+ " 2.040 | \n",
+ " 86.713 | \n",
+ " 0.380 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " electricity | \n",
+ " 0.157 | \n",
+ " 0.148 | \n",
+ " 32.046 | \n",
+ " 5.813 | \n",
+ " 842.442 | \n",
+ " 0.118 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " traffic | \n",
+ " 0.476 | \n",
+ " 0.414 | \n",
+ " 64.835 | \n",
+ " 8.870 | \n",
+ " 1373.728 | \n",
+ " 0.345 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.363 0.363 4.462 1.086 \n",
+ "1 etth2 0.271 0.271 1.416 1.108 \n",
+ "2 ettm1 0.327 0.328 4.675 1.459 \n",
+ "3 ettm2 0.178 0.178 4.483 1.491 \n",
+ "4 weather 0.166 0.165 7.540 2.040 \n",
+ "5 electricity 0.157 0.148 32.046 5.813 \n",
+ "6 traffic 0.476 0.414 64.835 8.870 \n",
+ "\n",
+ " fs5_total_train_time fs5_best_val_metric \n",
+ "0 34.318 0.680 \n",
+ "1 32.470 0.228 \n",
+ "2 48.188 0.408 \n",
+ "3 47.828 0.122 \n",
+ "4 86.713 0.380 \n",
+ "5 842.442 0.118 \n",
+ "6 1373.728 0.345 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ },
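+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As an optional sanity check, the cell below computes the relative change of the few-shot 5% MSE versus zero-shot from `df_out` (negative percentages mean fine-tuning improved over zero-shot). It only uses the `dataset`, `zs_mse`, and `fs5_mse` columns shown in the results table above."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Relative MSE change (%) of few-shot 5% fine-tuning vs. zero-shot (negative = improvement)\n",
+ "rel_change_pct = 100 * (df_out[\"fs5_mse\"] - df_out[\"zs_mse\"]) / df_out[\"zs_mse\"]\n",
+ "print(pd.concat([df_out[\"dataset\"], rel_change_pct.rename(\"fs5_vs_zs_pct\")], axis=1))"
+ ]
+ },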
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_1536_96.ipynb b/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_1536_96.ipynb
new file mode 100644
index 00000000..d7bf2cc5
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_1536_96.ipynb
@@ -0,0 +1,2688 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ " # TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+ " **Using TTM-1536-96 model with Frequency Tuning.**"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-04 09:11:04.868399: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-04 09:11:04.917034: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-04 09:11:05.640376: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import logging\n",
+ "import math\n",
+ "import warnings\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions\n",
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "\n",
+ "logging.basicConfig(level=logging.ERROR)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "SEED = 42\n",
+ "set_seed(SEED)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 1536\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "enable_prefix_tuning = True\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
+ "# Refer to the load_dataset() function\n",
+ "# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py\n",
+ "# to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm_v2_freq_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# TTM models for Only Research and Academic (Non-Commercial) Use are here: https://huggingface.co/ibm/ttm-research-r2\n",
+ "# Please provide the branch name properly based on context_len and forecast_len\n",
+ "\n",
+ "hf_model_path = \"ibm/ttm-research-r2\"\n",
+ "if context_length == 512:\n",
+ " hf_model_branch = \"main\"\n",
+ "elif context_length == 1024 or context_length == 1536:\n",
+ " hf_model_branch = f\"{context_length}_{forecast_length}_ft_r2\"\n",
+ "else:\n",
+ " raise ValueError(\"Valid context lengths are: 512, 1024, and 1536 for now. Stay tuned for more TTM models.\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: etth1, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 7009, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "5dd422436cfd43dd8cf226ddf999c210",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.51k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "0d8b6eaa0b024d6297b46248a189418a",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/13.0M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3588193356990814, 'eval_model_preparation_time': 0.0029, 'eval_runtime': 2.5439, 'eval_samples_per_second': 1094.782, 'eval_steps_per_second': 17.296}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: etth1, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 260, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0005214008287999684\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0005214008287999684\n",
+ "Using learning rate = 0.0005214008287999684\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 55/250 00:28 < 01:43, 1.88 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.706600 | \n",
+ " 0.676132 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.590100 | \n",
+ " 0.676201 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.489000 | \n",
+ " 0.676943 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.389600 | \n",
+ " 0.680147 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.344500 | \n",
+ " 0.689842 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.310300 | \n",
+ " 0.714360 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.294400 | \n",
+ " 0.751174 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.258300 | \n",
+ " 0.766458 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.248700 | \n",
+ " 0.809178 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.222200 | \n",
+ " 0.848143 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.215800 | \n",
+ " 0.868455 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23215233554176:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:31 EDT)\" (scheduled at 2024-10-04 09:11:31.009870-04:00)\n",
+ "INFO:p-861228:t-23215233554176:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:46 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0010949481617322 seconds, Total Train Time = 29.304269552230835\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.35881510376930237, 'eval_runtime': 1.3471, 'eval_samples_per_second': 2067.478, 'eval_steps_per_second': 32.664, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: etth2, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 7009, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2640553414821625, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 1.2594, 'eval_samples_per_second': 2211.296, 'eval_steps_per_second': 34.936}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: etth2, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 260, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 90/250 00:46 < 01:25, 1.88 it/s, Epoch 18/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.436100 | \n",
+ " 0.226110 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.402200 | \n",
+ " 0.226250 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.344900 | \n",
+ " 0.226580 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.311700 | \n",
+ " 0.226758 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.302500 | \n",
+ " 0.226712 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.235200 | \n",
+ " 0.226074 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.199100 | \n",
+ " 0.224492 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.176200 | \n",
+ " 0.224491 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.148300 | \n",
+ " 0.229641 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.133200 | \n",
+ " 0.238513 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.126300 | \n",
+ " 0.249538 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.127900 | \n",
+ " 0.257025 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.114400 | \n",
+ " 0.268801 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.103800 | \n",
+ " 0.276413 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.108000 | \n",
+ " 0.277767 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.101900 | \n",
+ " 0.283584 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.097500 | \n",
+ " 0.285750 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.097900 | \n",
+ " 0.283838 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23221903103744:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:08 EDT)\" (scheduled at 2024-10-04 09:12:08.603368-04:00)\n",
+ "INFO:p-861228:t-23221903103744:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:23 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23221903103744:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:38 EDT)\" (scheduled at 2024-10-04 09:12:23.603368-04:00)\n",
+ "INFO:p-861228:t-23221903103744:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:38 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.9902433421876695 seconds, Total Train Time = 47.71788811683655\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2693004012107849, 'eval_runtime': 1.6154, 'eval_samples_per_second': 1724.008, 'eval_steps_per_second': 27.237, 'epoch': 18.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 32929, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.264 0.269\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:04]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3177188038825989, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 4.7801, 'eval_samples_per_second': 2390.124, 'eval_steps_per_second': 37.447}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 1556, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 9.770099572992256e-05\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 9.770099572992256e-05\n",
+ "Using learning rate = 9.770099572992256e-05\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:49 < 02:55, 5.55 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.946500 | \n",
+ " 0.383166 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.807700 | \n",
+ " 0.385413 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.642300 | \n",
+ " 0.389591 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.530900 | \n",
+ " 0.397725 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.428000 | \n",
+ " 0.410418 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.365300 | \n",
+ " 0.428604 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.329900 | \n",
+ " 0.452348 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.303900 | \n",
+ " 0.461395 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.286200 | \n",
+ " 0.454643 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.272300 | \n",
+ " 0.454069 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.264200 | \n",
+ " 0.455587 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:08 EDT)\" (scheduled at 2024-10-04 09:13:08.581729-04:00)\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:23 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:38 EDT)\" (scheduled at 2024-10-04 09:13:23.581729-04:00)\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:38 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:53 EDT)\" (scheduled at 2024-10-04 09:13:38.581729-04:00)\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:53 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.433471766385165 seconds, Total Train Time = 50.059786319732666\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.31683409214019775, 'eval_runtime': 2.8235, 'eval_samples_per_second': 4046.381, 'eval_steps_per_second': 63.396, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 32929, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.264 0.269\n",
+ "2 ettm1 0.318 0.317\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:04]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.16930672526359558, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 4.8089, 'eval_samples_per_second': 2375.807, 'eval_steps_per_second': 37.223}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 1556, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 9.770099572992256e-05\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 9.770099572992256e-05\n",
+ "Using learning rate = 9.770099572992256e-05\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:48 < 02:52, 5.65 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.408900 | \n",
+ " 0.121544 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.352800 | \n",
+ " 0.121782 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.276300 | \n",
+ " 0.122294 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.195100 | \n",
+ " 0.123225 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.152900 | \n",
+ " 0.125274 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.123600 | \n",
+ " 0.132202 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.107600 | \n",
+ " 0.143401 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.100600 | \n",
+ " 0.152680 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.096500 | \n",
+ " 0.161098 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.093800 | \n",
+ " 0.165741 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.091600 | \n",
+ " 0.170623 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:12 EDT)\" (scheduled at 2024-10-04 09:14:12.055758-04:00)\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:27 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:42 EDT)\" (scheduled at 2024-10-04 09:14:27.055758-04:00)\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:42 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:42 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:57 EDT)\" (scheduled at 2024-10-04 09:14:42.055758-04:00)\n",
+ "INFO:p-861228:t-23217620121344:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:57 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.3992452404715798 seconds, Total Train Time = 49.18882489204407\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1694120466709137, 'eval_runtime': 2.8815, 'eval_samples_per_second': 3964.902, 'eval_steps_per_second': 62.12, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: weather, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.264 0.269\n",
+ "2 ettm1 0.318 0.317\n",
+ "3 ettm2 0.169 0.169\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 35256, val = 5175, test = 10444\n",
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:09]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15942735970020294, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 9.6351, 'eval_samples_per_second': 1083.958, 'eval_steps_per_second': 17.021}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: weather, context length: 1536, prediction length 96\n",
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 1672, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.000298364724028334\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334\n",
+ "Using learning rate = 0.000298364724028334\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 297/1350 00:57 < 03:24, 5.15 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.106200 | \n",
+ " 0.392984 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.100900 | \n",
+ " 0.393534 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.095400 | \n",
+ " 0.394598 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.088200 | \n",
+ " 0.397990 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.080900 | \n",
+ " 0.400138 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.074700 | \n",
+ " 0.409341 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.068600 | \n",
+ " 0.411072 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.064000 | \n",
+ " 0.413273 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.059300 | \n",
+ " 0.425014 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.055800 | \n",
+ " 0.413878 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.053400 | \n",
+ " 0.417557 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23215315343104:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:21 EDT)\" (scheduled at 2024-10-04 09:15:21.383105-04:00)\n",
+ "INFO:p-861228:t-23215315343104:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:36 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215315343104:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:51 EDT)\" (scheduled at 2024-10-04 09:15:36.383105-04:00)\n",
+ "INFO:p-861228:t-23215315343104:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:51 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215315343104:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:06 EDT)\" (scheduled at 2024-10-04 09:15:51.383105-04:00)\n",
+ "INFO:p-861228:t-23215315343104:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:06 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.972560167312622 seconds, Total Train Time = 58.1689395904541\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:04]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15510673820972443, 'eval_runtime': 5.2632, 'eval_samples_per_second': 1984.335, 'eval_steps_per_second': 31.16, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: electricity, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.264 0.269\n",
+ "2 ettm1 0.318 0.317\n",
+ "3 ettm2 0.169 0.169\n",
+ "4 weather 0.159 0.155\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 16781, val = 2537, test = 5165\n",
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:39]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15161459147930145, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 39.4884, 'eval_samples_per_second': 130.798, 'eval_steps_per_second': 4.102}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: electricity, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 748, val = 2537, test = 5165\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0002477076355991711\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0002477076355991711\n",
+ "Using learning rate = 0.0002477076355991711\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1200/1200 17:23, Epoch 50/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.140300 | \n",
+ " 0.126572 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.136400 | \n",
+ " 0.124738 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.134500 | \n",
+ " 0.123417 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.132500 | \n",
+ " 0.121974 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.130200 | \n",
+ " 0.120977 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.128100 | \n",
+ " 0.119957 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.126300 | \n",
+ " 0.119561 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.124500 | \n",
+ " 0.118178 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.123200 | \n",
+ " 0.117826 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.121200 | \n",
+ " 0.117279 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.119100 | \n",
+ " 0.116886 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.118700 | \n",
+ " 0.116810 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.116900 | \n",
+ " 0.116884 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.116500 | \n",
+ " 0.116189 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.115200 | \n",
+ " 0.116396 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.114900 | \n",
+ " 0.116871 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.113600 | \n",
+ " 0.115763 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.113100 | \n",
+ " 0.115339 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.112200 | \n",
+ " 0.115565 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.111100 | \n",
+ " 0.115062 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.110700 | \n",
+ " 0.114427 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.110100 | \n",
+ " 0.114737 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.109700 | \n",
+ " 0.114358 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.109300 | \n",
+ " 0.114306 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.109100 | \n",
+ " 0.114210 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.108600 | \n",
+ " 0.114578 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.108100 | \n",
+ " 0.114918 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.107800 | \n",
+ " 0.114290 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.107400 | \n",
+ " 0.113889 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.106800 | \n",
+ " 0.114318 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.106600 | \n",
+ " 0.114354 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.106400 | \n",
+ " 0.113839 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.106500 | \n",
+ " 0.113933 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.106100 | \n",
+ " 0.113754 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.106100 | \n",
+ " 0.113705 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.105500 | \n",
+ " 0.113696 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.105200 | \n",
+ " 0.113619 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.104900 | \n",
+ " 0.113821 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.105100 | \n",
+ " 0.113573 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.105400 | \n",
+ " 0.113611 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.105100 | \n",
+ " 0.113580 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.104700 | \n",
+ " 0.113523 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.105000 | \n",
+ " 0.113410 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.104500 | \n",
+ " 0.113375 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.105000 | \n",
+ " 0.113353 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.104700 | \n",
+ " 0.113399 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 0.104700 | \n",
+ " 0.113472 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.104600 | \n",
+ " 0.113407 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 0.104700 | \n",
+ " 0.113441 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 0.104700 | \n",
+ " 0.113444 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:30 EDT)\" (scheduled at 2024-10-04 09:17:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:00 EDT)\" (scheduled at 2024-10-04 09:17:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:15 EDT)\" (scheduled at 2024-10-04 09:18:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:30 EDT)\" (scheduled at 2024-10-04 09:18:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:45 EDT)\" (scheduled at 2024-10-04 09:18:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:00 EDT)\" (scheduled at 2024-10-04 09:18:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:15 EDT)\" (scheduled at 2024-10-04 09:19:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:30 EDT)\" (scheduled at 2024-10-04 09:19:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:45 EDT)\" (scheduled at 2024-10-04 09:19:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:00 EDT)\" (scheduled at 2024-10-04 09:19:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:15 EDT)\" (scheduled at 2024-10-04 09:20:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:30 EDT)\" (scheduled at 2024-10-04 09:20:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:45 EDT)\" (scheduled at 2024-10-04 09:20:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:00 EDT)\" (scheduled at 2024-10-04 09:20:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:15 EDT)\" (scheduled at 2024-10-04 09:21:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:30 EDT)\" (scheduled at 2024-10-04 09:21:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:45 EDT)\" (scheduled at 2024-10-04 09:21:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:00 EDT)\" (scheduled at 2024-10-04 09:21:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:15 EDT)\" (scheduled at 2024-10-04 09:22:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:30 EDT)\" (scheduled at 2024-10-04 09:22:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:45 EDT)\" (scheduled at 2024-10-04 09:22:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:00 EDT)\" (scheduled at 2024-10-04 09:22:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:15 EDT)\" (scheduled at 2024-10-04 09:23:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:30 EDT)\" (scheduled at 2024-10-04 09:23:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:45 EDT)\" (scheduled at 2024-10-04 09:23:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:00 EDT)\" (scheduled at 2024-10-04 09:23:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:15 EDT)\" (scheduled at 2024-10-04 09:24:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:30 EDT)\" (scheduled at 2024-10-04 09:24:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:45 EDT)\" (scheduled at 2024-10-04 09:24:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:00 EDT)\" (scheduled at 2024-10-04 09:24:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:15 EDT)\" (scheduled at 2024-10-04 09:25:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:30 EDT)\" (scheduled at 2024-10-04 09:25:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:45 EDT)\" (scheduled at 2024-10-04 09:25:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:00 EDT)\" (scheduled at 2024-10-04 09:25:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:15 EDT)\" (scheduled at 2024-10-04 09:26:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:30 EDT)\" (scheduled at 2024-10-04 09:26:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:45 EDT)\" (scheduled at 2024-10-04 09:26:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:00 EDT)\" (scheduled at 2024-10-04 09:26:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:15 EDT)\" (scheduled at 2024-10-04 09:27:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:30 EDT)\" (scheduled at 2024-10-04 09:27:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:45 EDT)\" (scheduled at 2024-10-04 09:27:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:00 EDT)\" (scheduled at 2024-10-04 09:27:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:15 EDT)\" (scheduled at 2024-10-04 09:28:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:30 EDT)\" (scheduled at 2024-10-04 09:28:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:45 EDT)\" (scheduled at 2024-10-04 09:28:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:00 EDT)\" (scheduled at 2024-10-04 09:28:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:15 EDT)\" (scheduled at 2024-10-04 09:29:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:30 EDT)\" (scheduled at 2024-10-04 09:29:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:45 EDT)\" (scheduled at 2024-10-04 09:29:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:00 EDT)\" (scheduled at 2024-10-04 09:29:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:15 EDT)\" (scheduled at 2024-10-04 09:30:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:30 EDT)\" (scheduled at 2024-10-04 09:30:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:45 EDT)\" (scheduled at 2024-10-04 09:30:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:00 EDT)\" (scheduled at 2024-10-04 09:30:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:15 EDT)\" (scheduled at 2024-10-04 09:31:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:30 EDT)\" (scheduled at 2024-10-04 09:31:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:45 EDT)\" (scheduled at 2024-10-04 09:31:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:00 EDT)\" (scheduled at 2024-10-04 09:31:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:15 EDT)\" (scheduled at 2024-10-04 09:32:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:30 EDT)\" (scheduled at 2024-10-04 09:32:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:45 EDT)\" (scheduled at 2024-10-04 09:32:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:00 EDT)\" (scheduled at 2024-10-04 09:32:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:15 EDT)\" (scheduled at 2024-10-04 09:33:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:30 EDT)\" (scheduled at 2024-10-04 09:33:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:45 EDT)\" (scheduled at 2024-10-04 09:33:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:00 EDT)\" (scheduled at 2024-10-04 09:33:45.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:00 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:15 EDT)\" (scheduled at 2024-10-04 09:34:00.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:15 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:30 EDT)\" (scheduled at 2024-10-04 09:34:15.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:30 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:45 EDT)\" (scheduled at 2024-10-04 09:34:30.668008-04:00)\n",
+ "INFO:p-861228:t-23215225161472:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:45 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 6.7941923379898075 seconds, Total Train Time = 1045.3023135662079\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:25]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.13990454375743866, 'eval_runtime': 26.7672, 'eval_samples_per_second': 192.96, 'eval_steps_per_second': 6.052, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: traffic, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.264 0.269\n",
+ "2 ettm1 0.318 0.317\n",
+ "3 ettm2 0.169 0.169\n",
+ "4 weather 0.159 0.155\n",
+ "5 electricity 0.152 0.140\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/1536_96_ft_r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 10649, val = 1661, test = 3413\n",
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 01:18]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4620630145072937, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 78.9449, 'eval_samples_per_second': 43.233, 'eval_steps_per_second': 5.409}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Dataset name: traffic, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:data_handling.py:load_dataset:Data lengths: train = 442, val = 1661, test = 3413\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3243996\n",
+ "Number of params after freezing the backbone 1079394\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-861228:t-23229240083200:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-861228:t-23229240083200:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 8.111308307896872e-05\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 8.111308307896872e-05\n",
+ "Using learning rate = 8.111308307896872e-05\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23229240083200:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-861228:t-23229240083200:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 616/2800 06:27 < 22:57, 1.59 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.295900 | \n",
+ " 0.390884 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.282800 | \n",
+ " 0.397275 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.272100 | \n",
+ " 0.398559 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.261400 | \n",
+ " 0.400883 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.251000 | \n",
+ " 0.404521 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.242200 | \n",
+ " 0.408792 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.233900 | \n",
+ " 0.412845 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.227200 | \n",
+ " 0.413990 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.221400 | \n",
+ " 0.415893 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.217000 | \n",
+ " 0.420565 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.213100 | \n",
+ " 0.419890 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:05 EDT)\" (scheduled at 2024-10-04 09:37:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:35 EDT)\" (scheduled at 2024-10-04 09:37:20.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:35 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:50 EDT)\" (scheduled at 2024-10-04 09:37:35.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:37:50 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:05 EDT)\" (scheduled at 2024-10-04 09:37:50.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:05 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:20 EDT)\" (scheduled at 2024-10-04 09:38:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:35 EDT)\" (scheduled at 2024-10-04 09:38:20.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:35 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:50 EDT)\" (scheduled at 2024-10-04 09:38:35.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:38:50 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:05 EDT)\" (scheduled at 2024-10-04 09:38:50.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:05 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:20 EDT)\" (scheduled at 2024-10-04 09:39:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:35 EDT)\" (scheduled at 2024-10-04 09:39:20.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:35 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:50 EDT)\" (scheduled at 2024-10-04 09:39:35.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:39:50 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:05 EDT)\" (scheduled at 2024-10-04 09:39:50.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:05 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:20 EDT)\" (scheduled at 2024-10-04 09:40:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:35 EDT)\" (scheduled at 2024-10-04 09:40:20.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:35 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:50 EDT)\" (scheduled at 2024-10-04 09:40:35.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:40:50 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:05 EDT)\" (scheduled at 2024-10-04 09:40:50.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:05 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:20 EDT)\" (scheduled at 2024-10-04 09:41:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:35 EDT)\" (scheduled at 2024-10-04 09:41:20.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:35 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:50 EDT)\" (scheduled at 2024-10-04 09:41:35.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:41:50 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:05 EDT)\" (scheduled at 2024-10-04 09:41:50.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:05 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:20 EDT)\" (scheduled at 2024-10-04 09:42:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:35 EDT)\" (scheduled at 2024-10-04 09:42:20.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:35 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:50 EDT)\" (scheduled at 2024-10-04 09:42:35.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:42:50 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:05 EDT)\" (scheduled at 2024-10-04 09:42:50.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:05 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:20 EDT)\" (scheduled at 2024-10-04 09:43:05.631304-04:00)\n",
+ "INFO:p-861228:t-23215218599680:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:43:20 EDT)\" executed successfully\n",
+ "INFO:p-861228:t-23229240083200:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-861228:t-23229240083200:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 10.461192239414562 seconds, Total Train Time = 388.9561378955841\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:47]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4685199558734894, 'eval_runtime': 48.9684, 'eval_samples_per_second': 69.698, 'eval_steps_per_second': 8.72, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.264 0.269\n",
+ "2 ettm1 0.318 0.317\n",
+ "3 ettm2 0.169 0.169\n",
+ "4 weather 0.159 0.155\n",
+ "5 electricity 0.152 0.140\n",
+ "6 traffic 0.462 0.469\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " use_frequency_token=enable_prefix_tuning,\n",
+ " )\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " seed=SEED,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+    "    ## Use the pretrained model in few-shot 5% forecasting ##\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5]:\n",
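+    "        # Only the 5% few-shot split is run here; to also benchmark 10%, extend the list\n",
+    "        # above to [5, 10] and add matching fs10_* keys to `all_results`.\n",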
+ " # Set learning rate\n",
+ " learning_rate = None # `None` value indicates that the optimal_lr_finder() will be used\n",
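+    "        # For example, set learning_rate = 1e-3 above to skip the LR finder and use a fixed rate.\n",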
+ "\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " use_frequency_token=enable_prefix_tuning,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " if learning_rate is None:\n",
+ " learning_rate, finetune_forecast_model = optimal_lr_finder(\n",
+ " finetune_forecast_model,\n",
+ " dset_train,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " enable_prefix_tuning=enable_prefix_tuning,\n",
+ " )\n",
+ " print(\"OPTIMAL SUGGESTED LEARNING RATE =\", learning_rate)\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " seed=SEED,\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
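+    "        # OneCycleLR ramps the learning rate up to the value chosen above and then anneals it\n",
+    "        # over EPOCHS * steps_per_epoch optimizer steps.\n",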
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\"]])\n",
+    "    df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+    "*Slight differences from the results reported in the TTM paper are possible due to different training environments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+       "  | dataset | zs_mse | fs5_mse | zs_eval_time | fs5_mean_epoch_time | fs5_total_train_time | fs5_best_val_metric\n",
+       "0 | etth1 | 0.359 | 0.359 | 2.544 | 1.001 | 29.304 | 0.676\n",
+       "1 | etth2 | 0.264 | 0.269 | 1.259 | 0.990 | 47.718 | 0.224\n",
+       "2 | ettm1 | 0.318 | 0.317 | 4.780 | 1.433 | 50.060 | 0.383\n",
+       "3 | ettm2 | 0.169 | 0.169 | 4.809 | 1.399 | 49.189 | 0.122\n",
+       "4 | weather | 0.159 | 0.155 | 9.635 | 1.973 | 58.169 | 0.393\n",
+       "5 | electricity | 0.152 | 0.140 | 39.488 | 6.794 | 1045.302 | 0.113\n",
+       "6 | traffic | 0.462 | 0.469 | 78.945 | 10.461 | 388.956 | 0.391"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.359 0.359 2.544 1.001 \n",
+ "1 etth2 0.264 0.269 1.259 0.990 \n",
+ "2 ettm1 0.318 0.317 4.780 1.433 \n",
+ "3 ettm2 0.169 0.169 4.809 1.399 \n",
+ "4 weather 0.159 0.155 9.635 1.973 \n",
+ "5 electricity 0.152 0.140 39.488 6.794 \n",
+ "6 traffic 0.462 0.469 78.945 10.461 \n",
+ "\n",
+ " fs5_total_train_time fs5_best_val_metric \n",
+ "0 29.304 0.676 \n",
+ "1 47.718 0.224 \n",
+ "2 50.060 0.383 \n",
+ "3 49.189 0.122 \n",
+ "4 58.169 0.393 \n",
+ "5 1045.302 0.113 \n",
+ "6 388.956 0.391 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_512_96.ipynb b/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_512_96.ipynb
new file mode 100644
index 00000000..99133365
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/research_use/ttm-r2_freq_benchmarking_512_96.ipynb
@@ -0,0 +1,2875 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+    "# TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+    "\n",
+    "**Using the TTM-512-96 model with frequency tuning.**"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-04 09:07:43.310254: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-04 09:07:44.089127: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-04 09:07:47.466446: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import logging\n",
+ "import math\n",
+ "import warnings\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions\n",
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")\n",
+ "\n",
+ "\n",
+ "logging.basicConfig(level=logging.ERROR)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "SEED = 42\n",
+ "set_seed(SEED)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 512\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "enable_prefix_tuning = True\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+    "# Make sure all the datasets listed in `list_datasets` below are saved in the\n",
+    "# `DATA_ROOT_PATH` folder, or change the path accordingly.\n",
+    "# Refer to the load_dataset() function in\n",
+    "# notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
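+    "# (with the defaults above, this resolves to \"ttm_v2_freq_results_benchmark_512_96/\")\n",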
+ "OUT_DIR = f\"ttm_v2_freq_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+    "# TTM models for research and academic (non-commercial) use only are available here: https://huggingface.co/ibm/ttm-research-r2\n",
+    "# Set the branch name below based on the chosen context_length and forecast_length.\n",
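+    "# For example, context_length = 1024 with forecast_length = 96 maps to the \"1024_96_ft_r2\" branch,\n",
+    "# while context_length = 512 uses the \"main\" branch.\n",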
+ "\n",
+ "hf_model_path = \"ibm/ttm-research-r2\"\n",
+ "if context_length == 512:\n",
+ " hf_model_branch = \"main\"\n",
+ "elif context_length == 1024 or context_length == 1536:\n",
+ " hf_model_branch = f\"{context_length}_{forecast_length}_ft_r2\"\n",
+ "else:\n",
+ " raise ValueError(\"Valid context lengths are: 512, 1024, and 1536 for now. Stay tuned for more TTM models.\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 8033, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "a7c75b8186484a0bb2613087e9263f36",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.51k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "68391ef558864bceb3b151f58d42c775",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/3.44M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "2d09a76631f14e649aa97380cd5486a8",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "generation_config.json: 0%| | 0.00/69.0 [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[44/44 00:00] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3641158640384674, 'eval_model_preparation_time': 0.0029, 'eval_runtime': 3.4329, 'eval_samples_per_second': 811.279, 'eval_steps_per_second': 12.817}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 311, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0006280291441834253\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0006280291441834253\n",
+ "Using learning rate = 0.0006280291441834253\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[ 70/250 00:28 < 01:14, 2.43 it/s, Epoch 14/50]\n",
+      "Epoch | Training Loss | Validation Loss\n",
+      "1 | 0.772100 | 0.677957\n",
+      "2 | 0.735000 | 0.677009\n",
+      "3 | 0.712900 | 0.676227\n",
+      "4 | 0.708100 | 0.675994\n",
+      "5 | 0.667200 | 0.677063\n",
+      "6 | 0.660600 | 0.679934\n",
+      "7 | 0.642500 | 0.685532\n",
+      "8 | 0.626100 | 0.695030\n",
+      "9 | 0.607000 | 0.696889\n",
+      "10 | 0.591100 | 0.710723\n",
+      "11 | 0.571500 | 0.725537\n",
+      "12 | 0.552800 | 0.723338\n",
+      "13 | 0.536200 | 0.747540\n",
+      "14 | 0.515100 | 0.759912"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23069104002816:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:08:21 EDT)\" (scheduled at 2024-10-04 09:08:21.237899-04:00)\n",
+ "INFO:p-849332:t-23069104002816:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:08:36 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.8126076459884644 seconds, Total Train Time = 29.085506439208984\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[44/44 00:00] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36364585161209106, 'eval_runtime': 1.0267, 'eval_samples_per_second': 2712.446, 'eval_steps_per_second': 42.854, 'epoch': 14.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: etth2, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 8033, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[44/44 00:00] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.27660802006721497, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 0.7318, 'eval_samples_per_second': 3805.928, 'eval_steps_per_second': 60.13}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: etth2, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 311, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0002477076355991711\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0002477076355991711\n",
+ "Using learning rate = 0.0002477076355991711\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[ 70/250 00:28 < 01:15, 2.39 it/s, Epoch 14/50]\n",
+      "Epoch | Training Loss | Validation Loss\n",
+      "1 | 0.295200 | 0.212050\n",
+      "2 | 0.274500 | 0.211893\n",
+      "3 | 0.245700 | 0.211690\n",
+      "4 | 0.236100 | 0.211553\n",
+      "5 | 0.232500 | 0.211572\n",
+      "6 | 0.223500 | 0.211775\n",
+      "7 | 0.205000 | 0.212336\n",
+      "8 | 0.198200 | 0.213646\n",
+      "9 | 0.193500 | 0.215267\n",
+      "10 | 0.185800 | 0.216570\n",
+      "11 | 0.181800 | 0.215838\n",
+      "12 | 0.174100 | 0.214250\n",
+      "13 | 0.170900 | 0.213147\n",
+      "14 | 0.168300 | 0.213569"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23070958139136:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:08:57 EDT)\" (scheduled at 2024-10-04 09:08:57.116463-04:00)\n",
+ "INFO:p-849332:t-23070958139136:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:09:12 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.7992050477436611 seconds, Total Train Time = 29.094688653945923\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[44/44 00:00] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.27722853422164917, 'eval_runtime': 1.0587, 'eval_samples_per_second': 2630.492, 'eval_steps_per_second': 41.559, 'epoch': 14.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: ettm1, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "1 etth2 0.277 0.277\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 33953, val = 11425, test = 11425\n",
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[179/179 00:02] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3216254413127899, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 3.0306, 'eval_samples_per_second': 3769.85, 'eval_steps_per_second': 59.064}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: ettm1, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 1607, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0010974987654930567\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0010974987654930567\n",
+ "Using learning rate = 0.0010974987654930567\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[ 286/1300 00:35 < 02:05, 8.05 it/s, Epoch 11/50]\n",
+      "Epoch | Training Loss | Validation Loss\n",
+      "1 | 0.288200 | 0.383428\n",
+      "2 | 0.268400 | 0.392796\n",
+      "3 | 0.256200 | 0.402290\n",
+      "4 | 0.248000 | 0.402057\n",
+      "5 | 0.243600 | 0.400347\n",
+      "6 | 0.233100 | 0.401490\n",
+      "7 | 0.224400 | 0.399640\n",
+      "8 | 0.218200 | 0.406397\n",
+      "9 | 0.214000 | 0.416378\n",
+      "10 | 0.208100 | 0.420510\n",
+      "11 | 0.204500 | 0.431284"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23070958139136:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:09:35 EDT)\" (scheduled at 2024-10-04 09:09:35.591012-04:00)\n",
+ "INFO:p-849332:t-23070958139136:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:09:50 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070958139136:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:05 EDT)\" (scheduled at 2024-10-04 09:09:50.591012-04:00)\n",
+ "INFO:p-849332:t-23070958139136:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:05 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.1328961849212646 seconds, Total Train Time = 35.978880167007446\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[179/179 00:01] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3108648955821991, 'eval_runtime': 1.9303, 'eval_samples_per_second': 5918.836, 'eval_steps_per_second': 92.733, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: ettm2, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "1 etth2 0.277 0.277\n",
+ "2 ettm1 0.322 0.311\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 33953, val = 11425, test = 11425\n",
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[179/179 00:02] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17139409482479095, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 3.0318, 'eval_samples_per_second': 3768.374, 'eval_steps_per_second': 59.041}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: ettm2, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 1607, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0013219411484660286\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0013219411484660286\n",
+ "Using learning rate = 0.0013219411484660286\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[ 286/1300 00:39 < 02:19, 7.24 it/s, Epoch 11/50]\n",
+      "Epoch | Training Loss | Validation Loss\n",
+      "1 | 0.183500 | 0.121340\n",
+      "2 | 0.163500 | 0.128155\n",
+      "3 | 0.155700 | 0.128910\n",
+      "4 | 0.151000 | 0.127212\n",
+      "5 | 0.145700 | 0.126899\n",
+      "6 | 0.146000 | 0.126764\n",
+      "7 | 0.138500 | 0.135505\n",
+      "8 | 0.134200 | 0.136716\n",
+      "9 | 0.130400 | 0.138214\n",
+      "10 | 0.128600 | 0.133596\n",
+      "11 | 0.126300 | 0.137331"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23069102954240:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:21 EDT)\" (scheduled at 2024-10-04 09:10:21.934245-04:00)\n",
+ "INFO:p-849332:t-23069102954240:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:36 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23069102954240:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:51 EDT)\" (scheduled at 2024-10-04 09:10:36.934245-04:00)\n",
+ "INFO:p-849332:t-23069102954240:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:10:51 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.2798414880579168 seconds, Total Train Time = 40.0526020526886\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[179/179 00:01] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17053768038749695, 'eval_runtime': 2.0308, 'eval_samples_per_second': 5625.762, 'eval_steps_per_second': 88.141, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: weather, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "1 etth2 0.277 0.277\n",
+ "2 ettm1 0.322 0.311\n",
+ "3 ettm2 0.171 0.171\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 36280, val = 5175, test = 10444\n",
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[164/164 00:04] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1578841358423233, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 4.7155, 'eval_samples_per_second': 2214.823, 'eval_steps_per_second': 34.779}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: weather, context length: 512, prediction length 96\n",
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 1723, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0013219411484660286\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0013219411484660286\n",
+ "Using learning rate = 0.0013219411484660286\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[ 297/1350 00:39 < 02:20, 7.50 it/s, Epoch 11/50]\n",
+      "Epoch | Training Loss | Validation Loss\n",
+      "1 | 0.164400 | 0.400965\n",
+      "2 | 0.156600 | 0.411972\n",
+      "3 | 0.151400 | 0.424562\n",
+      "4 | 0.146700 | 0.434680\n",
+      "5 | 0.141100 | 0.443567\n",
+      "6 | 0.135800 | 0.462677\n",
+      "7 | 0.132800 | 0.453966\n",
+      "8 | 0.130600 | 0.484906\n",
+      "9 | 0.129300 | 0.488840\n",
+      "10 | 0.124200 | 0.493310\n",
+      "11 | 0.119500 | 0.528248"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23071093761792:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:14 EDT)\" (scheduled at 2024-10-04 09:11:14.654216-04:00)\n",
+ "INFO:p-849332:t-23071093761792:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:29 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23071093761792:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:44 EDT)\" (scheduled at 2024-10-04 09:11:29.654216-04:00)\n",
+ "INFO:p-849332:t-23071093761792:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:11:44 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.4791956381364302 seconds, Total Train Time = 40.24339771270752\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[164/164 00:01] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1534864753484726, 'eval_runtime': 2.9993, 'eval_samples_per_second': 3482.165, 'eval_steps_per_second': 54.68, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: electricity, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "1 etth2 0.277 0.277\n",
+ "2 ettm1 0.322 0.311\n",
+ "3 ettm2 0.171 0.171\n",
+ "4 weather 0.158 0.153\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 17805, val = 2537, test = 5165\n",
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[162/162 00:17] (evaluation progress bar)"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.16642886400222778, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 17.5817, 'eval_samples_per_second': 293.771, 'eval_steps_per_second': 9.214}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: electricity, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 800, val = 2537, test = 5165\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00014174741629268049\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00014174741629268049\n",
+ "Using learning rate = 0.00014174741629268049\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+      "[1250/1250 08:34, Epoch 50/50]\n",
+      "Epoch | Training Loss | Validation Loss\n",
+      "1 | 0.200000 | 0.143117\n",
+      "2 | 0.198000 | 0.141135\n",
+      "3 | 0.195900 | 0.138918\n",
+      "4 | 0.193700 | 0.136844\n",
+      "5 | 0.190800 | 0.134736\n",
+      "6 | 0.188100 | 0.132839\n",
+      "7 | 0.185200 | 0.131205\n",
+      "8 | 0.182900 | 0.129960\n",
+      "9 | 0.180600 | 0.129008\n",
+      "10 | 0.178600 | 0.128143\n",
+      "11 | 0.176600 | 0.127417\n",
+      "12 | 0.175200 | 0.127078\n",
+      "13 | 0.173600 | 0.126531\n",
+      "14 | 0.172100 | 0.125496\n",
+      "15 | 0.170800 | 0.125140\n",
+      "16 | 0.169900 | 0.124904\n",
+      "17 | 0.168800 | 0.124422\n",
+      "18 | 0.168000 | 0.124320\n",
+      "19 | 0.166900 | 0.124035\n",
+      "20 | 0.166300 | 0.124053\n",
+      "21 | 0.165700 | 0.123601\n",
+      "22 | 0.164900 | 0.123492\n",
+      "23 | 0.164400 | 0.123469\n",
+      "24 | 0.163800 | 0.122978\n",
+      "25 | 0.163300 | 0.122880\n",
+      "26 | 0.162800 | 0.123042\n",
+      "27 | 0.162500 | 0.123237\n",
+      "28 | 0.161900 | 0.122722\n",
+      "29 | 0.161600 | 0.122502\n",
+      "30 | 0.161100 | 0.122304\n",
+      "31 | 0.161000 | 0.122278\n",
+      "32 | 0.160600 | 0.122204\n",
+      "33 | 0.160300 | 0.122023\n",
+      "34 | 0.160200 | 0.122520\n",
+      "35 | 0.160000 | 0.121933\n",
+      "36 | 0.159700 | 0.121884\n",
+      "37 | 0.159400 | 0.121899\n",
+      "38 | 0.159200 | 0.121840\n",
+      "39 | 0.159100 | 0.121886\n",
+      "40 | 0.159100 | 0.121761\n",
+      "41 | 0.158900 | 0.121728\n",
+      "42 | 0.159000 | 0.121783\n",
+      "43 | 0.158700 | 0.121812\n",
+      "44 | 0.158700 | 0.121707\n",
+      "45 | 0.158800 | 0.121747\n",
+      "46 | 0.158600 | 0.121713\n",
+      "47 | 0.158500 | 0.121719\n",
+      "48 | 0.158700 | 0.121724\n",
+      "49 | 0.158600 | 0.121715\n",
+      "50 | 0.158800 | 0.121715"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:30 EDT)\" (scheduled at 2024-10-04 09:12:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:12:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:00 EDT)\" (scheduled at 2024-10-04 09:12:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:15 EDT)\" (scheduled at 2024-10-04 09:13:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:30 EDT)\" (scheduled at 2024-10-04 09:13:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:45 EDT)\" (scheduled at 2024-10-04 09:13:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:13:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:00 EDT)\" (scheduled at 2024-10-04 09:13:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:15 EDT)\" (scheduled at 2024-10-04 09:14:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:30 EDT)\" (scheduled at 2024-10-04 09:14:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:45 EDT)\" (scheduled at 2024-10-04 09:14:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:14:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:00 EDT)\" (scheduled at 2024-10-04 09:14:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:15 EDT)\" (scheduled at 2024-10-04 09:15:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:30 EDT)\" (scheduled at 2024-10-04 09:15:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:45 EDT)\" (scheduled at 2024-10-04 09:15:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:15:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:00 EDT)\" (scheduled at 2024-10-04 09:15:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:15 EDT)\" (scheduled at 2024-10-04 09:16:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:30 EDT)\" (scheduled at 2024-10-04 09:16:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:45 EDT)\" (scheduled at 2024-10-04 09:16:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:16:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:00 EDT)\" (scheduled at 2024-10-04 09:16:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:15 EDT)\" (scheduled at 2024-10-04 09:17:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:30 EDT)\" (scheduled at 2024-10-04 09:17:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:45 EDT)\" (scheduled at 2024-10-04 09:17:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:17:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:00 EDT)\" (scheduled at 2024-10-04 09:17:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:15 EDT)\" (scheduled at 2024-10-04 09:18:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:30 EDT)\" (scheduled at 2024-10-04 09:18:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:45 EDT)\" (scheduled at 2024-10-04 09:18:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:18:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:00 EDT)\" (scheduled at 2024-10-04 09:18:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:15 EDT)\" (scheduled at 2024-10-04 09:19:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:30 EDT)\" (scheduled at 2024-10-04 09:19:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:45 EDT)\" (scheduled at 2024-10-04 09:19:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:19:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:00 EDT)\" (scheduled at 2024-10-04 09:19:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:15 EDT)\" (scheduled at 2024-10-04 09:20:00.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:15 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:30 EDT)\" (scheduled at 2024-10-04 09:20:15.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:30 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:45 EDT)\" (scheduled at 2024-10-04 09:20:30.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:20:45 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:00 EDT)\" (scheduled at 2024-10-04 09:20:45.751854-04:00)\n",
+ "INFO:p-849332:t-23070964442880:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:21:00 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 3.542292618751526 seconds, Total Train Time = 516.1418724060059\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      [162/162 00:11]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1458154022693634, 'eval_runtime': 12.4771, 'eval_samples_per_second': 413.958, 'eval_steps_per_second': 12.984, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: traffic, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "1 etth2 0.277 0.277\n",
+ "2 ettm1 0.322 0.311\n",
+ "3 ettm2 0.171 0.171\n",
+ "4 weather 0.158 0.153\n",
+ "5 electricity 0.166 0.146\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm/ttm-research-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 11673, val = 1661, test = 3413\n",
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      [427/427 00:31]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.5142123699188232, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 31.6561, 'eval_samples_per_second': 107.815, 'eval_steps_per_second': 13.489}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Dataset name: traffic, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:data_handling.py:load_dataset:Data lengths: train = 493, val = 1661, test = 3413\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 854972\n",
+ "Number of params after freezing the backbone 302162\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-849332:t-23083607622400:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-849332:t-23083607622400:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00011768119524349978\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00011768119524349978\n",
+ "Using learning rate = 0.00011768119524349978\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23083607622400:base.py:_real_add_job:Added job \"EmissionsTracker._measure_power\" to job store \"default\"\n",
+ "INFO:p-849332:t-23083607622400:base.py:start:Scheduler started\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      [3100/3100 13:56, Epoch 50/50]\n",
+ "    </div>\n",
+ "    <table>\n",
+ "  <thead>\n",
+ "    <tr>\n",
+ "      <th>Epoch</th>\n",
+ "      <th>Training Loss</th>\n",
+ "      <th>Validation Loss</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><td>1</td><td>0.300400</td><td>0.419105</td></tr>\n",
+ "    <tr><td>2</td><td>0.293100</td><td>0.412652</td></tr>\n",
+ "    <tr><td>3</td><td>0.288100</td><td>0.407203</td></tr>\n",
+ "    <tr><td>4</td><td>0.283400</td><td>0.403235</td></tr>\n",
+ "    <tr><td>5</td><td>0.279000</td><td>0.400708</td></tr>\n",
+ "    <tr><td>6</td><td>0.274600</td><td>0.394619</td></tr>\n",
+ "    <tr><td>7</td><td>0.270100</td><td>0.390383</td></tr>\n",
+ "    <tr><td>8</td><td>0.265900</td><td>0.388255</td></tr>\n",
+ "    <tr><td>9</td><td>0.261400</td><td>0.382568</td></tr>\n",
+ "    <tr><td>10</td><td>0.257300</td><td>0.376852</td></tr>\n",
+ "    <tr><td>11</td><td>0.252700</td><td>0.370713</td></tr>\n",
+ "    <tr><td>12</td><td>0.248900</td><td>0.369995</td></tr>\n",
+ "    <tr><td>13</td><td>0.245000</td><td>0.365612</td></tr>\n",
+ "    <tr><td>14</td><td>0.241800</td><td>0.360373</td></tr>\n",
+ "    <tr><td>15</td><td>0.238700</td><td>0.359754</td></tr>\n",
+ "    <tr><td>16</td><td>0.235900</td><td>0.357963</td></tr>\n",
+ "    <tr><td>17</td><td>0.233700</td><td>0.354857</td></tr>\n",
+ "    <tr><td>18</td><td>0.231700</td><td>0.353336</td></tr>\n",
+ "    <tr><td>19</td><td>0.229900</td><td>0.353722</td></tr>\n",
+ "    <tr><td>20</td><td>0.228100</td><td>0.348975</td></tr>\n",
+ "    <tr><td>21</td><td>0.226400</td><td>0.348115</td></tr>\n",
+ "    <tr><td>22</td><td>0.225300</td><td>0.347397</td></tr>\n",
+ "    <tr><td>23</td><td>0.224200</td><td>0.346221</td></tr>\n",
+ "    <tr><td>24</td><td>0.222900</td><td>0.345252</td></tr>\n",
+ "    <tr><td>25</td><td>0.221900</td><td>0.346918</td></tr>\n",
+ "    <tr><td>26</td><td>0.220900</td><td>0.344589</td></tr>\n",
+ "    <tr><td>27</td><td>0.220100</td><td>0.344082</td></tr>\n",
+ "    <tr><td>28</td><td>0.219600</td><td>0.344436</td></tr>\n",
+ "    <tr><td>29</td><td>0.218500</td><td>0.343222</td></tr>\n",
+ "    <tr><td>30</td><td>0.217700</td><td>0.343030</td></tr>\n",
+ "    <tr><td>31</td><td>0.217300</td><td>0.343605</td></tr>\n",
+ "    <tr><td>32</td><td>0.216700</td><td>0.341811</td></tr>\n",
+ "    <tr><td>33</td><td>0.216200</td><td>0.341363</td></tr>\n",
+ "    <tr><td>34</td><td>0.215800</td><td>0.341179</td></tr>\n",
+ "    <tr><td>35</td><td>0.215400</td><td>0.341348</td></tr>\n",
+ "    <tr><td>36</td><td>0.215000</td><td>0.340437</td></tr>\n",
+ "    <tr><td>37</td><td>0.214800</td><td>0.340590</td></tr>\n",
+ "    <tr><td>38</td><td>0.214300</td><td>0.340316</td></tr>\n",
+ "    <tr><td>39</td><td>0.214200</td><td>0.340021</td></tr>\n",
+ "    <tr><td>40</td><td>0.214000</td><td>0.340593</td></tr>\n",
+ "    <tr><td>41</td><td>0.213800</td><td>0.340408</td></tr>\n",
+ "    <tr><td>42</td><td>0.213600</td><td>0.339833</td></tr>\n",
+ "    <tr><td>43</td><td>0.213500</td><td>0.340246</td></tr>\n",
+ "    <tr><td>44</td><td>0.213400</td><td>0.339990</td></tr>\n",
+ "    <tr><td>45</td><td>0.213400</td><td>0.339910</td></tr>\n",
+ "    <tr><td>46</td><td>0.213200</td><td>0.340009</td></tr>\n",
+ "    <tr><td>47</td><td>0.213100</td><td>0.339965</td></tr>\n",
+ "    <tr><td>48</td><td>0.213200</td><td>0.339881</td></tr>\n",
+ "    <tr><td>49</td><td>0.213200</td><td>0.339873</td></tr>\n",
+ "    <tr><td>50</td><td>0.213100</td><td>0.339869</td></tr>\n",
+ "  </tbody>\n",
+ "</table>"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:07 EDT)\" (scheduled at 2024-10-04 09:22:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:37 EDT)\" (scheduled at 2024-10-04 09:22:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:52 EDT)\" (scheduled at 2024-10-04 09:22:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:22:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:07 EDT)\" (scheduled at 2024-10-04 09:22:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:22 EDT)\" (scheduled at 2024-10-04 09:23:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:37 EDT)\" (scheduled at 2024-10-04 09:23:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:52 EDT)\" (scheduled at 2024-10-04 09:23:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:23:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:07 EDT)\" (scheduled at 2024-10-04 09:23:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:22 EDT)\" (scheduled at 2024-10-04 09:24:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:37 EDT)\" (scheduled at 2024-10-04 09:24:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:52 EDT)\" (scheduled at 2024-10-04 09:24:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:24:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:07 EDT)\" (scheduled at 2024-10-04 09:24:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:22 EDT)\" (scheduled at 2024-10-04 09:25:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:37 EDT)\" (scheduled at 2024-10-04 09:25:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:52 EDT)\" (scheduled at 2024-10-04 09:25:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:25:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:07 EDT)\" (scheduled at 2024-10-04 09:25:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:22 EDT)\" (scheduled at 2024-10-04 09:26:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:37 EDT)\" (scheduled at 2024-10-04 09:26:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:52 EDT)\" (scheduled at 2024-10-04 09:26:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:26:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:07 EDT)\" (scheduled at 2024-10-04 09:26:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:22 EDT)\" (scheduled at 2024-10-04 09:27:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:37 EDT)\" (scheduled at 2024-10-04 09:27:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:52 EDT)\" (scheduled at 2024-10-04 09:27:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:27:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:07 EDT)\" (scheduled at 2024-10-04 09:27:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:22 EDT)\" (scheduled at 2024-10-04 09:28:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:37 EDT)\" (scheduled at 2024-10-04 09:28:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:52 EDT)\" (scheduled at 2024-10-04 09:28:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:28:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:07 EDT)\" (scheduled at 2024-10-04 09:28:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:22 EDT)\" (scheduled at 2024-10-04 09:29:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:37 EDT)\" (scheduled at 2024-10-04 09:29:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:52 EDT)\" (scheduled at 2024-10-04 09:29:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:29:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:07 EDT)\" (scheduled at 2024-10-04 09:29:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:22 EDT)\" (scheduled at 2024-10-04 09:30:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:37 EDT)\" (scheduled at 2024-10-04 09:30:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:52 EDT)\" (scheduled at 2024-10-04 09:30:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:30:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:07 EDT)\" (scheduled at 2024-10-04 09:30:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:22 EDT)\" (scheduled at 2024-10-04 09:31:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:37 EDT)\" (scheduled at 2024-10-04 09:31:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:52 EDT)\" (scheduled at 2024-10-04 09:31:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:31:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:07 EDT)\" (scheduled at 2024-10-04 09:31:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:22 EDT)\" (scheduled at 2024-10-04 09:32:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:37 EDT)\" (scheduled at 2024-10-04 09:32:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:52 EDT)\" (scheduled at 2024-10-04 09:32:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:32:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:07 EDT)\" (scheduled at 2024-10-04 09:32:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:22 EDT)\" (scheduled at 2024-10-04 09:33:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:37 EDT)\" (scheduled at 2024-10-04 09:33:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:52 EDT)\" (scheduled at 2024-10-04 09:33:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:33:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:07 EDT)\" (scheduled at 2024-10-04 09:33:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:22 EDT)\" (scheduled at 2024-10-04 09:34:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:37 EDT)\" (scheduled at 2024-10-04 09:34:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:52 EDT)\" (scheduled at 2024-10-04 09:34:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:34:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:07 EDT)\" (scheduled at 2024-10-04 09:34:52.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:07 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:22 EDT)\" (scheduled at 2024-10-04 09:35:07.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:22 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:37 EDT)\" (scheduled at 2024-10-04 09:35:22.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:37 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Running job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:52 EDT)\" (scheduled at 2024-10-04 09:35:37.068300-04:00)\n",
+ "INFO:p-849332:t-23068871223040:base.py:run_job:Job \"EmissionsTracker._measure_power (trigger: interval[0:00:15], next run at: 2024-10-04 09:35:52 EDT)\" executed successfully\n",
+ "INFO:p-849332:t-23083607622400:base.py:shutdown:Scheduler has been shut down\n",
+ "ERROR:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:Region: not found for Country with ISO CODE : USA\n",
+ "WARNING:p-849332:t-23083607622400:emissions.py:get_private_infra_emissions:CODECARBON : Regional emissions retrieval failed. Falling back on country emissions.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 5.444428825378418 seconds, Total Train Time = 837.8258655071259\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      [427/427 00:20]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.41814321279525757, 'eval_runtime': 21.6958, 'eval_samples_per_second': 157.311, 'eval_steps_per_second': 19.681, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.364 0.364\n",
+ "1 etth2 0.277 0.277\n",
+ "2 ettm1 0.322 0.311\n",
+ "3 ettm2 0.171 0.171\n",
+ "4 weather 0.158 0.153\n",
+ "5 electricity 0.166 0.146\n",
+ "6 traffic 0.514 0.418\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " use_frequency_token=enable_prefix_tuning,\n",
+ " )\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " seed=SEED,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+ "    ## Use the pretrained model in few-shot 5% forecasting #########\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5]:\n",
+ " # Set learning rate\n",
+ " learning_rate = None # `None` value indicates that the optimal_lr_finder() will be used\n",
+ "\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " use_frequency_token=enable_prefix_tuning,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " if learning_rate is None:\n",
+ " learning_rate, finetune_forecast_model = optimal_lr_finder(\n",
+ " finetune_forecast_model,\n",
+ " dset_train,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " enable_prefix_tuning=enable_prefix_tuning,\n",
+ " )\n",
+ " print(\"OPTIMAL SUGGESTED LEARNING RATE =\", learning_rate)\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " seed=SEED,\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\"]])\n",
+ "    df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+ "*Slight differences from the results reported in the TTM paper are possible due to differences in the training environment."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<table>\n",
+ "  <thead>\n",
+ "    <tr>\n",
+ "      <th></th>\n",
+ "      <th>dataset</th>\n",
+ "      <th>zs_mse</th>\n",
+ "      <th>fs5_mse</th>\n",
+ "      <th>zs_eval_time</th>\n",
+ "      <th>fs5_mean_epoch_time</th>\n",
+ "      <th>fs5_total_train_time</th>\n",
+ "      <th>fs5_best_val_metric</th>\n",
+ "    </tr>\n",
+ "  </thead>\n",
+ "  <tbody>\n",
+ "    <tr><th>0</th><td>etth1</td><td>0.364</td><td>0.364</td><td>3.433</td><td>0.813</td><td>29.086</td><td>0.676</td></tr>\n",
+ "    <tr><th>1</th><td>etth2</td><td>0.277</td><td>0.277</td><td>0.732</td><td>0.799</td><td>29.095</td><td>0.212</td></tr>\n",
+ "    <tr><th>2</th><td>ettm1</td><td>0.322</td><td>0.311</td><td>3.031</td><td>1.133</td><td>35.979</td><td>0.383</td></tr>\n",
+ "    <tr><th>3</th><td>ettm2</td><td>0.171</td><td>0.171</td><td>3.032</td><td>1.280</td><td>40.053</td><td>0.121</td></tr>\n",
+ "    <tr><th>4</th><td>weather</td><td>0.158</td><td>0.153</td><td>4.716</td><td>1.479</td><td>40.243</td><td>0.401</td></tr>\n",
+ "    <tr><th>5</th><td>electricity</td><td>0.166</td><td>0.146</td><td>17.582</td><td>3.542</td><td>516.142</td><td>0.122</td></tr>\n",
+ "    <tr><th>6</th><td>traffic</td><td>0.514</td><td>0.418</td><td>31.656</td><td>5.444</td><td>837.826</td><td>0.340</td></tr>\n",
+ "  </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.364 0.364 3.433 0.813 \n",
+ "1 etth2 0.277 0.277 0.732 0.799 \n",
+ "2 ettm1 0.322 0.311 3.031 1.133 \n",
+ "3 ettm2 0.171 0.171 3.032 1.280 \n",
+ "4 weather 0.158 0.153 4.716 1.479 \n",
+ "5 electricity 0.166 0.146 17.582 3.542 \n",
+ "6 traffic 0.514 0.418 31.656 5.444 \n",
+ "\n",
+ " fs5_total_train_time fs5_best_val_metric \n",
+ "0 29.086 0.676 \n",
+ "1 29.095 0.212 \n",
+ "2 35.979 0.383 \n",
+ "3 40.053 0.121 \n",
+ "4 40.243 0.401 \n",
+ "5 516.142 0.122 \n",
+ "6 837.826 0.340 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/ttm-r1_benchmarking_1024_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm-r1_benchmarking_1024_96.ipynb
new file mode 100644
index 00000000..808e4352
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/ttm-r1_benchmarking_1024_96.ipynb
@@ -0,0 +1,5175 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+    "**Using the TTM-1024-96 model.**\n",
+    "\n",
+    "Pre-trained TTM models will be fetched from the [Hugging Face TTM Model Repository](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1).\n",
+ "\n",
+ "1. TTM-R1 pre-trained models can be found here: [TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024_96_v1\"`\n",
+ "2. TTM-R2 pre-trained models can be found here: [TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024-96-r2\"`\n",
+ " 3. For 1536-96 model set `TTM_MODEL_REVISION=\"1536-96-r2\"`\n",
+ "\n",
+ "Details about the revisions (R1 and R2) can be found [here](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)."
+ ]
+ },
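+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a minimal illustration of the revision scheme above (separate from the benchmarking code in this notebook), a specific TTM variant can be selected by passing the branch name through the standard `revision` argument of `from_pretrained`. This is only a sketch; the loop below derives the same information from `context_length`.\n",
+ "\n",
+ "```python\n",
+ "from tsfm_public import TinyTimeMixerForPrediction\n",
+ "\n",
+ "# e.g., the 1024-96 variant from the TTM-R1 model card\n",
+ "model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ "    \"ibm-granite/granite-timeseries-ttm-r1\", revision=\"1024_96_v1\"\n",
+ ")\n",
+ "```"
+ ]
+ },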
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-10 07:33:33.458902: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-10 07:33:33.499290: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-10 07:33:34.206046: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "from transformers.integrations import INTEGRATION_TO_CALLBACK\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions"
+ ]
+ },
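+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "These imports are the building blocks of the benchmarking loop further below: `load_dataset` prepares the train/val/test splits, `TinyTimeMixerForPrediction` supplies the pre-trained model, and the Hugging Face `Trainer` (together with `EarlyStoppingCallback` and the `TrackingCallback` timing helper) drives evaluation and few-shot fine-tuning with an `AdamW` optimizer and a `OneCycleLR` schedule. The following is a rough, hypothetical sketch of how these pieces are typically wired together; the argument values are placeholders, not the exact ones used later in this notebook.\n",
+ "\n",
+ "```python\n",
+ "optimizer = AdamW(model.parameters(), lr=learning_rate)\n",
+ "scheduler = OneCycleLR(optimizer, learning_rate, epochs=EPOCHS, steps_per_epoch=math.ceil(len(dset_train) / 64))\n",
+ "trainer = Trainer(\n",
+ "    model=model,\n",
+ "    args=TrainingArguments(\n",
+ "        output_dir=OUT_DIR,\n",
+ "        num_train_epochs=EPOCHS,\n",
+ "        evaluation_strategy=\"epoch\",\n",
+ "        save_strategy=\"epoch\",\n",
+ "        load_best_model_at_end=True,\n",
+ "    ),\n",
+ "    train_dataset=dset_train,\n",
+ "    eval_dataset=dset_val,\n",
+ "    callbacks=[EarlyStoppingCallback(early_stopping_patience=10), TrackingCallback()],\n",
+ "    optimizers=(optimizer, scheduler),\n",
+ ")\n",
+ "trainer.train()\n",
+ "```"
+ ]
+ },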
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "set_seed(42)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 1024\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "learning_rate = 0.001\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder; otherwise, change `DATA_ROOT_PATH` accordingly.\n",
+ "# Refer to the load_dataset() function to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm-r1_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
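+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For reference, the loop below loads each dataset through the `load_dataset` helper imported from `tsfm_public`, pointing it at `DATA_ROOT_PATH`. The call sketched here is only illustrative: the keyword names `fewshot_fraction` and `dataset_root_path` are assumptions about the helper's signature, not taken from this notebook.\n",
+ "\n",
+ "```python\n",
+ "dset_train, dset_val, dset_test = load_dataset(\n",
+ "    \"etth1\",\n",
+ "    context_length,\n",
+ "    forecast_length,\n",
+ "    fewshot_fraction=1.0,  # assumed keyword; 0.05 / 0.10 would correspond to the few-shot runs\n",
+ "    dataset_root_path=DATA_ROOT_PATH,  # assumed keyword\n",
+ ")\n",
+ "```"
+ ]
+ },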
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_model_path = \"ibm-granite/granite-timeseries-ttm-r1\"\n",
+ "if context_length == 512:\n",
+ " hf_model_branch = \"main\"\n",
+ "elif context_length == 1024:\n",
+ " hf_model_branch = \"1024_96_v1\"\n",
+ "else:\n",
+ "    raise ValueError(\"Currently supported context lengths are 512 and 1024. Stay tuned for more TTMs!\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
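+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The cell below iterates over `list_datasets` and, for each dataset, first evaluates the pre-trained model zero-shot and then fine-tunes it on 5% and 10% of the training data, as the printed logs show. A simplified, hypothetical outline of the zero-shot evaluation step is given here; it is not the exact code of the loop. A per-device eval batch size of 64 would be consistent with the 44 evaluation steps reported for the 2785 test samples.\n",
+ "\n",
+ "```python\n",
+ "zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "zeroshot_trainer = Trainer(\n",
+ "    model=zeroshot_model,\n",
+ "    args=TrainingArguments(output_dir=OUT_DIR, per_device_eval_batch_size=64),\n",
+ ")\n",
+ "print(zeroshot_trainer.evaluate(dset_test))  # 'eval_loss' is the zero-shot test MSE\n",
+ "```"
+ ]
+ },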
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 7521, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "078dee5f93884e67b7b6e11fda57e305",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "7015b43db14848d1928f3217b78b62f0",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors:   0%|          | 0.00/3.80M [00:00<?, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='44' max='44' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [44/44 00:00]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3621068000793457, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 1.7776, 'eval_samples_per_second': 1566.691, 'eval_steps_per_second': 24.752}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 285, val = 2785, test = 2785\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='85' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [ 85/250 00:36 < 01:12, 2.27 it/s, Epoch 17/50]\n",
+ "    </div>\n",
+ "    <table border=\"1\" class=\"dataframe\">\n",
+ "      <thead>\n",
+ "        <tr style=\"text-align: left;\"><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th></tr>\n",
+ "      </thead>\n",
+ "      <tbody>\n",
+ "        <tr><td>1</td><td>0.584400</td><td>0.663804</td></tr>\n",
+ "        <tr><td>2</td><td>0.601700</td><td>0.663111</td></tr>\n",
+ "        <tr><td>3</td><td>0.585400</td><td>0.662300</td></tr>\n",
+ "        <tr><td>4</td><td>0.564100</td><td>0.661117</td></tr>\n",
+ "        <tr><td>5</td><td>0.527600</td><td>0.659781</td></tr>\n",
+ "        <tr><td>6</td><td>0.497200</td><td>0.658587</td></tr>\n",
+ "        <tr><td>7</td><td>0.455000</td><td>0.658336</td></tr>\n",
+ "        <tr><td>8</td><td>0.427900</td><td>0.660301</td></tr>\n",
+ "        <tr><td>9</td><td>0.399900</td><td>0.663791</td></tr>\n",
+ "        <tr><td>10</td><td>0.353700</td><td>0.670050</td></tr>\n",
+ "        <tr><td>11</td><td>0.312400</td><td>0.682261</td></tr>\n",
+ "        <tr><td>12</td><td>0.301000</td><td>0.700288</td></tr>\n",
+ "        <tr><td>13</td><td>0.266300</td><td>0.724918</td></tr>\n",
+ "        <tr><td>14</td><td>0.249500</td><td>0.753728</td></tr>\n",
+ "        <tr><td>15</td><td>0.233800</td><td>0.775392</td></tr>\n",
+ "        <tr><td>16</td><td>0.216400</td><td>0.788408</td></tr>\n",
+ "        <tr><td>17</td><td>0.214900</td><td>0.795417</td></tr>\n",
+ "      </tbody>\n",
+ "    </table>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.9880044320050407 seconds, Total Train Time = 38.51875972747803\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='44' max='44' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [44/44 00:00]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3614313006401062, 'eval_runtime': 1.2179, 'eval_samples_per_second': 2286.638, 'eval_steps_per_second': 36.126, 'epoch': 17.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 666, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n",
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='143' max='550' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [143/550 00:28 < 01:22, 4.91 it/s, Epoch 13/50]\n",
+ "    </div>\n",
+ "    <table border=\"1\" class=\"dataframe\">\n",
+ "      <thead>\n",
+ "        <tr style=\"text-align: left;\"><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th></tr>\n",
+ "      </thead>\n",
+ "      <tbody>\n",
+ "        <tr><td>1</td><td>0.642100</td><td>0.663602</td></tr>\n",
+ "        <tr><td>2</td><td>0.594700</td><td>0.662923</td></tr>\n",
+ "        <tr><td>3</td><td>0.545200</td><td>0.662525</td></tr>\n",
+ "        <tr><td>4</td><td>0.518200</td><td>0.663749</td></tr>\n",
+ "        <tr><td>5</td><td>0.464700</td><td>0.667679</td></tr>\n",
+ "        <tr><td>6</td><td>0.414400</td><td>0.674993</td></tr>\n",
+ "        <tr><td>7</td><td>0.371700</td><td>0.687040</td></tr>\n",
+ "        <tr><td>8</td><td>0.334700</td><td>0.708806</td></tr>\n",
+ "        <tr><td>9</td><td>0.308600</td><td>0.737532</td></tr>\n",
+ "        <tr><td>10</td><td>0.286700</td><td>0.751357</td></tr>\n",
+ "        <tr><td>11</td><td>0.268600</td><td>0.776843</td></tr>\n",
+ "        <tr><td>12</td><td>0.256200</td><td>0.794736</td></tr>\n",
+ "        <tr><td>13</td><td>0.245500</td><td>0.800167</td></tr>\n",
+ "      </tbody>\n",
+ "    </table>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0025265216827393 seconds, Total Train Time = 29.51686191558838\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='44' max='44' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [44/44 00:00]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36271077394485474, 'eval_runtime': 1.2248, 'eval_samples_per_second': 2273.762, 'eval_steps_per_second': 35.923, 'epoch': 13.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 7521, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='44' max='44' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [44/44 00:00]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.280693918466568, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 0.7537, 'eval_samples_per_second': 3695.306, 'eval_steps_per_second': 58.382}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 285, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='55' max='250' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [ 55/250 00:23 < 01:27, 2.23 it/s, Epoch 11/50]\n",
+ "    </div>\n",
+ "    <table border=\"1\" class=\"dataframe\">\n",
+ "      <thead>\n",
+ "        <tr style=\"text-align: left;\"><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th></tr>\n",
+ "      </thead>\n",
+ "      <tbody>\n",
+ "        <tr><td>1</td><td>0.565900</td><td>0.224047</td></tr>\n",
+ "        <tr><td>2</td><td>0.524400</td><td>0.224874</td></tr>\n",
+ "        <tr><td>3</td><td>0.560600</td><td>0.226318</td></tr>\n",
+ "        <tr><td>4</td><td>0.518900</td><td>0.229082</td></tr>\n",
+ "        <tr><td>5</td><td>0.532800</td><td>0.234455</td></tr>\n",
+ "        <tr><td>6</td><td>0.487500</td><td>0.244063</td></tr>\n",
+ "        <tr><td>7</td><td>0.437800</td><td>0.257846</td></tr>\n",
+ "        <tr><td>8</td><td>0.404400</td><td>0.275406</td></tr>\n",
+ "        <tr><td>9</td><td>0.368500</td><td>0.299849</td></tr>\n",
+ "        <tr><td>10</td><td>0.353900</td><td>0.332489</td></tr>\n",
+ "        <tr><td>11</td><td>0.320900</td><td>0.374024</td></tr>\n",
+ "      </tbody>\n",
+ "    </table>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.941808743910356 seconds, Total Train Time = 24.5705406665802\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='44' max='44' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [44/44 00:00]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2801705598831177, 'eval_runtime': 1.1876, 'eval_samples_per_second': 2345.043, 'eval_steps_per_second': 37.049, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 666, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n",
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='132' max='550' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [132/550 00:27 < 01:28, 4.72 it/s, Epoch 12/50]\n",
+ "    </div>\n",
+ "    <table border=\"1\" class=\"dataframe\">\n",
+ "      <thead>\n",
+ "        <tr style=\"text-align: left;\"><th>Epoch</th><th>Training Loss</th><th>Validation Loss</th></tr>\n",
+ "      </thead>\n",
+ "      <tbody>\n",
+ "        <tr><td>1</td><td>0.414400</td><td>0.223319</td></tr>\n",
+ "        <tr><td>2</td><td>0.420600</td><td>0.223185</td></tr>\n",
+ "        <tr><td>3</td><td>0.378100</td><td>0.223428</td></tr>\n",
+ "        <tr><td>4</td><td>0.359700</td><td>0.224102</td></tr>\n",
+ "        <tr><td>5</td><td>0.330100</td><td>0.225113</td></tr>\n",
+ "        <tr><td>6</td><td>0.298500</td><td>0.227125</td></tr>\n",
+ "        <tr><td>7</td><td>0.264500</td><td>0.229628</td></tr>\n",
+ "        <tr><td>8</td><td>0.242900</td><td>0.238282</td></tr>\n",
+ "        <tr><td>9</td><td>0.219100</td><td>0.251995</td></tr>\n",
+ "        <tr><td>10</td><td>0.201800</td><td>0.277940</td></tr>\n",
+ "        <tr><td>11</td><td>0.180300</td><td>0.317722</td></tr>\n",
+ "        <tr><td>12</td><td>0.171200</td><td>0.340438</td></tr>\n",
+ "      </tbody>\n",
+ "    </table>"
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0485320885976155 seconds, Total Train Time = 28.39580202102661\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='44' max='44' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [44/44 00:00]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ "<IPython.core.display.HTML object>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.28046759963035583, 'eval_runtime': 1.1773, 'eval_samples_per_second': 2365.604, 'eval_steps_per_second': 37.374, 'epoch': 12.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 33441, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "1 etth2 0.281 0.280 0.280\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "    <div>\n",
+ "      <progress value='179' max='179' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
+ "      [179/179 00:02]\n",
+ "    </div>\n",
+ "    "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.38726314902305603, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 3.0352, 'eval_samples_per_second': 3764.195, 'eval_steps_per_second': 58.975}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 1581, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 375/1250 00:52 < 02:02, 7.16 it/s, Epoch 15/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.504500 | \n",
+ " 0.422623 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.471700 | \n",
+ " 0.417156 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.424700 | \n",
+ " 0.412834 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.385000 | \n",
+ " 0.409597 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.340900 | \n",
+ " 0.409013 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.300900 | \n",
+ " 0.417046 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.273300 | \n",
+ " 0.429183 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.251200 | \n",
+ " 0.439041 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.232500 | \n",
+ " 0.448727 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.223100 | \n",
+ " 0.456104 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.214200 | \n",
+ " 0.460536 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.207600 | \n",
+ " 0.466538 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.203800 | \n",
+ " 0.476997 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.198100 | \n",
+ " 0.480505 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.194100 | \n",
+ " 0.488817 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.2938549836476645 seconds, Total Train Time = 53.02472805976868\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3715095520019531, 'eval_runtime': 2.1521, 'eval_samples_per_second': 5308.838, 'eval_steps_per_second': 83.176, 'epoch': 15.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 3258, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 714/2550 00:52 < 02:16, 13.45 it/s, Epoch 14/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.542500 | \n",
+ " 0.419067 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.470700 | \n",
+ " 0.414835 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.418600 | \n",
+ " 0.413820 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.366600 | \n",
+ " 0.407678 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.327400 | \n",
+ " 0.407747 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.301200 | \n",
+ " 0.413063 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.282900 | \n",
+ " 0.419441 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.271800 | \n",
+ " 0.437654 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.262200 | \n",
+ " 0.438877 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.254400 | \n",
+ " 0.450964 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.250900 | \n",
+ " 0.463638 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.246800 | \n",
+ " 0.465234 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.241700 | \n",
+ " 0.471252 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.239800 | \n",
+ " 0.479022 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.6105220488139562 seconds, Total Train Time = 53.82016968727112\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.37059730291366577, 'eval_runtime': 2.1347, 'eval_samples_per_second': 5351.977, 'eval_steps_per_second': 83.852, 'epoch': 14.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 33441, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "1 etth2 0.281 0.280 0.280\n",
+ "2 ettm1 0.387 0.372 0.371\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17503736913204193, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 2.9526, 'eval_samples_per_second': 3869.42, 'eval_steps_per_second': 60.624}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 1581, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:37 < 02:13, 7.32 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.280700 | \n",
+ " 0.121009 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.264600 | \n",
+ " 0.121268 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.226500 | \n",
+ " 0.123216 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.199600 | \n",
+ " 0.129200 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.169000 | \n",
+ " 0.141758 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.152200 | \n",
+ " 0.155555 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.138100 | \n",
+ " 0.163517 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.129300 | \n",
+ " 0.172492 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.122000 | \n",
+ " 0.183950 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.116400 | \n",
+ " 0.191413 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.112600 | \n",
+ " 0.197757 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.262766101143577 seconds, Total Train Time = 38.181320667266846\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17288224399089813, 'eval_runtime': 2.1253, 'eval_samples_per_second': 5375.828, 'eval_steps_per_second': 84.225, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 3258, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 561/2550 00:41 < 02:28, 13.40 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.314100 | \n",
+ " 0.121323 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.274900 | \n",
+ " 0.122920 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.243900 | \n",
+ " 0.126737 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.213000 | \n",
+ " 0.131092 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.194700 | \n",
+ " 0.134649 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.184000 | \n",
+ " 0.137388 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.175500 | \n",
+ " 0.139926 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.169700 | \n",
+ " 0.142911 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.164200 | \n",
+ " 0.151129 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.160800 | \n",
+ " 0.147594 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.157500 | \n",
+ " 0.151603 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.63915508443659 seconds, Total Train Time = 42.55753254890442\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1721040904521942, 'eval_runtime': 2.1601, 'eval_samples_per_second': 5289.139, 'eval_steps_per_second': 82.867, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 35768, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "1 etth2 0.281 0.280 0.280\n",
+ "2 ettm1 0.387 0.372 0.371\n",
+ "3 ettm2 0.175 0.173 0.172\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:05]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15184031426906586, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 5.133, 'eval_samples_per_second': 2034.658, 'eval_steps_per_second': 31.95}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 1698, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 297/1350 00:43 < 02:34, 6.83 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.152400 | \n",
+ " 0.419179 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.147600 | \n",
+ " 0.424661 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.142100 | \n",
+ " 0.439751 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.136300 | \n",
+ " 0.458828 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.127800 | \n",
+ " 0.483952 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.119200 | \n",
+ " 0.519423 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.110600 | \n",
+ " 0.522068 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.103600 | \n",
+ " 0.505524 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.097000 | \n",
+ " 0.515911 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.091300 | \n",
+ " 0.517793 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.086200 | \n",
+ " 0.485350 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.6853359612551602 seconds, Total Train Time = 44.3294575214386\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1506919413805008, 'eval_runtime': 3.4903, 'eval_samples_per_second': 2992.315, 'eval_steps_per_second': 46.988, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 3491, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 605/2750 00:51 < 03:04, 11.64 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.137700 | \n",
+ " 0.424454 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.133600 | \n",
+ " 0.436503 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.128500 | \n",
+ " 0.445798 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.123400 | \n",
+ " 0.456645 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.116700 | \n",
+ " 0.477022 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.111500 | \n",
+ " 0.483446 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.105500 | \n",
+ " 0.470728 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.099900 | \n",
+ " 0.470292 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.095400 | \n",
+ " 0.476556 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.090700 | \n",
+ " 0.458923 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.087200 | \n",
+ " 0.471447 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 2.416881713000211 seconds, Total Train Time = 52.82250738143921\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15016287565231323, 'eval_runtime': 3.5776, 'eval_samples_per_second': 2919.276, 'eval_steps_per_second': 45.841, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "1 etth2 0.281 0.280 0.280\n",
+ "2 ettm1 0.387 0.372 0.371\n",
+ "3 ettm2 0.175 0.173 0.172\n",
+ "4 weather 0.152 0.151 0.150\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 17293, val = 2537, test = 5165\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:24]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1555725336074829, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 24.5596, 'eval_samples_per_second': 210.305, 'eval_steps_per_second': 6.596}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 774, val = 2537, test = 5165\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1250/1250 12:32, Epoch 50/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.145900 | \n",
+ " 0.127765 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.140000 | \n",
+ " 0.124079 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.135400 | \n",
+ " 0.121057 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.131600 | \n",
+ " 0.118233 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.127300 | \n",
+ " 0.116960 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.124700 | \n",
+ " 0.115788 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.121300 | \n",
+ " 0.114265 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.119800 | \n",
+ " 0.113604 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.117900 | \n",
+ " 0.113588 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.115900 | \n",
+ " 0.114102 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.114200 | \n",
+ " 0.114330 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.114100 | \n",
+ " 0.114430 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.112600 | \n",
+ " 0.113078 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.111900 | \n",
+ " 0.114254 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.110900 | \n",
+ " 0.113203 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.110300 | \n",
+ " 0.116154 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.108800 | \n",
+ " 0.114116 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.108300 | \n",
+ " 0.114400 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.107600 | \n",
+ " 0.113790 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.107100 | \n",
+ " 0.112894 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.107000 | \n",
+ " 0.114230 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.107100 | \n",
+ " 0.113750 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.106600 | \n",
+ " 0.112724 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.105200 | \n",
+ " 0.112615 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.104600 | \n",
+ " 0.112540 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.104400 | \n",
+ " 0.114088 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.104300 | \n",
+ " 0.113155 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.104000 | \n",
+ " 0.113183 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.103300 | \n",
+ " 0.113108 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.103300 | \n",
+ " 0.112891 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.103000 | \n",
+ " 0.112966 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.102800 | \n",
+ " 0.112305 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.102100 | \n",
+ " 0.112232 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.101800 | \n",
+ " 0.112428 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.101800 | \n",
+ " 0.112281 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.101500 | \n",
+ " 0.112245 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.101300 | \n",
+ " 0.112165 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.101600 | \n",
+ " 0.112430 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.100800 | \n",
+ " 0.112168 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.101100 | \n",
+ " 0.112292 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.100700 | \n",
+ " 0.112243 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.101100 | \n",
+ " 0.112085 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.100500 | \n",
+ " 0.112192 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.100500 | \n",
+ " 0.112366 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.100200 | \n",
+ " 0.112117 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.100500 | \n",
+ " 0.112201 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 0.100400 | \n",
+ " 0.112215 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.100400 | \n",
+ " 0.112120 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 0.100400 | \n",
+ " 0.112149 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 0.100000 | \n",
+ " 0.112153 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 5.032508630752563 seconds, Total Train Time = 755.1325304508209\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:16]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.14543357491493225, 'eval_runtime': 18.4848, 'eval_samples_per_second': 279.418, 'eval_steps_per_second': 8.764, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 1643, val = 2537, test = 5165\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [2288/2600 13:55 < 01:54, 2.73 it/s, Epoch 44/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.142600 | \n",
+ " 0.123726 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.136800 | \n",
+ " 0.119504 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.132600 | \n",
+ " 0.116580 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.129200 | \n",
+ " 0.114476 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.126500 | \n",
+ " 0.112773 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.124100 | \n",
+ " 0.111989 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.122400 | \n",
+ " 0.111565 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.121000 | \n",
+ " 0.111765 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.119900 | \n",
+ " 0.112019 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.119100 | \n",
+ " 0.110442 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.118400 | \n",
+ " 0.111159 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.117700 | \n",
+ " 0.111936 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.116800 | \n",
+ " 0.110754 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.116500 | \n",
+ " 0.111346 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.115700 | \n",
+ " 0.111159 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.115600 | \n",
+ " 0.112541 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.115400 | \n",
+ " 0.110576 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.114600 | \n",
+ " 0.110490 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.114300 | \n",
+ " 0.110820 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.113700 | \n",
+ " 0.110099 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.113600 | \n",
+ " 0.111032 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.113100 | \n",
+ " 0.110572 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.112500 | \n",
+ " 0.110152 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.112200 | \n",
+ " 0.110462 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.112200 | \n",
+ " 0.110486 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.111500 | \n",
+ " 0.109890 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.111200 | \n",
+ " 0.109659 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.110900 | \n",
+ " 0.110145 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.111000 | \n",
+ " 0.110042 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.110500 | \n",
+ " 0.109693 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.110400 | \n",
+ " 0.109685 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.110300 | \n",
+ " 0.109534 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.109900 | \n",
+ " 0.109661 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.109800 | \n",
+ " 0.109107 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.109500 | \n",
+ " 0.109508 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.109200 | \n",
+ " 0.109286 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.109000 | \n",
+ " 0.109707 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.108600 | \n",
+ " 0.109372 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.108600 | \n",
+ " 0.109286 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.108500 | \n",
+ " 0.109232 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.108400 | \n",
+ " 0.109145 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.108300 | \n",
+ " 0.109114 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.108300 | \n",
+ " 0.109157 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.108300 | \n",
+ " 0.109180 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 8.928198153322393 seconds, Total Train Time = 837.7255702018738\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:17]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.13808377087116241, 'eval_runtime': 18.4352, 'eval_samples_per_second': 280.17, 'eval_steps_per_second': 8.788, 'epoch': 44.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "1 etth2 0.281 0.280 0.280\n",
+ "2 ettm1 0.387 0.372 0.371\n",
+ "3 ettm2 0.175 0.173 0.172\n",
+ "4 weather 0.152 0.151 0.150\n",
+ "5 electricity 0.156 0.145 0.138\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/1024_96_v1\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 11161, val = 1661, test = 3413\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:43]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4576044976711273, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 43.9397, 'eval_samples_per_second': 77.675, 'eval_steps_per_second': 9.718}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 467, val = 1661, test = 3413\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1652/2950 11:12 < 08:48, 2.45 it/s, Epoch 28/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.286700 | \n",
+ " 0.364595 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.261100 | \n",
+ " 0.354531 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.248300 | \n",
+ " 0.348675 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.237800 | \n",
+ " 0.347864 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.229100 | \n",
+ " 0.346862 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.223000 | \n",
+ " 0.355750 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.217900 | \n",
+ " 0.347379 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.215000 | \n",
+ " 0.351959 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.210800 | \n",
+ " 0.345832 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.208000 | \n",
+ " 0.353231 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.207100 | \n",
+ " 0.352474 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.204400 | \n",
+ " 0.359140 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.204100 | \n",
+ " 0.350371 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.202100 | \n",
+ " 0.366590 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.201200 | \n",
+ " 0.361391 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.198500 | \n",
+ " 0.349136 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.196100 | \n",
+ " 0.369769 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.195400 | \n",
+ " 0.345229 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.195100 | \n",
+ " 0.357952 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.192400 | \n",
+ " 0.357397 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.192000 | \n",
+ " 0.363344 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.190500 | \n",
+ " 0.355350 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.189100 | \n",
+ " 0.363749 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.188000 | \n",
+ " 0.354421 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.188700 | \n",
+ " 0.353886 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.186300 | \n",
+ " 0.361907 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.184700 | \n",
+ " 0.352133 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.184500 | \n",
+ " 0.356455 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 7.41969324861254 seconds, Total Train Time = 674.0370872020721\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:30]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4156947731971741, 'eval_runtime': 32.3569, 'eval_samples_per_second': 105.48, 'eval_steps_per_second': 13.197, 'epoch': 28.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3144901:t-23206823609088:data_handling.py:load_dataset:Data lengths: train = 1030, val = 1661, test = 3413\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3144901:t-23206823609088:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3144901:t-23206823609088:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 946336\n",
+ "Number of params after freezing the backbone 389984\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [2064/6450 08:16 < 17:35, 4.16 it/s, Epoch 16/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.270300 | \n",
+ " 0.362664 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.250300 | \n",
+ " 0.351346 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.241200 | \n",
+ " 0.348280 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.235000 | \n",
+ " 0.348637 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.230200 | \n",
+ " 0.346396 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.226200 | \n",
+ " 0.343058 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.223100 | \n",
+ " 0.351453 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.220700 | \n",
+ " 0.347172 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.218700 | \n",
+ " 0.349513 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.216700 | \n",
+ " 0.351570 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.215000 | \n",
+ " 0.353760 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.214200 | \n",
+ " 0.348613 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.213000 | \n",
+ " 0.353903 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.211300 | \n",
+ " 0.347297 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.210300 | \n",
+ " 0.354026 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.208300 | \n",
+ " 0.350533 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 14.462455168366432 seconds, Total Train Time = 497.8177742958069\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:30]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.41844481229782104, 'eval_runtime': 32.5373, 'eval_samples_per_second': 104.895, 'eval_steps_per_second': 13.123, 'epoch': 16.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.362 0.361 0.363\n",
+ "1 etth2 0.281 0.280 0.280\n",
+ "2 ettm1 0.387 0.372 0.371\n",
+ "3 ettm2 0.175 0.173 0.172\n",
+ "4 weather 0.152 0.151 0.150\n",
+ "5 electricity 0.156 0.145 0.138\n",
+ "6 traffic 0.458 0.416 0.418\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"fs10_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs10_mean_epoch_time\": [],\n",
+ " \"fs10_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ " \"fs10_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+ " ## Use the pretrained model in few-shot 5% and 10% forecasting #\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5, 10]:\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ " finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK[\"codecarbon\"])\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\", \"fs10_mse\"]])\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+ "*Some slight differences in the results as compared to the TTM paper results is possible due to different training environments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " dataset | \n",
+ " zs_mse | \n",
+ " fs5_mse | \n",
+ " fs10_mse | \n",
+ " zs_eval_time | \n",
+ " fs5_mean_epoch_time | \n",
+ " fs5_total_train_time | \n",
+ " fs10_mean_epoch_time | \n",
+ " fs10_total_train_time | \n",
+ " fs5_best_val_metric | \n",
+ " fs10_best_val_metric | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " etth1 | \n",
+ " 0.362 | \n",
+ " 0.361 | \n",
+ " 0.363 | \n",
+ " 1.778 | \n",
+ " 0.988 | \n",
+ " 38.519 | \n",
+ " 1.003 | \n",
+ " 29.517 | \n",
+ " 0.658 | \n",
+ " 0.663 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " etth2 | \n",
+ " 0.281 | \n",
+ " 0.280 | \n",
+ " 0.280 | \n",
+ " 0.754 | \n",
+ " 0.942 | \n",
+ " 24.571 | \n",
+ " 1.049 | \n",
+ " 28.396 | \n",
+ " 0.224 | \n",
+ " 0.223 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ettm1 | \n",
+ " 0.387 | \n",
+ " 0.372 | \n",
+ " 0.371 | \n",
+ " 3.035 | \n",
+ " 1.294 | \n",
+ " 53.025 | \n",
+ " 1.611 | \n",
+ " 53.820 | \n",
+ " 0.409 | \n",
+ " 0.408 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ettm2 | \n",
+ " 0.175 | \n",
+ " 0.173 | \n",
+ " 0.172 | \n",
+ " 2.953 | \n",
+ " 1.263 | \n",
+ " 38.181 | \n",
+ " 1.639 | \n",
+ " 42.558 | \n",
+ " 0.121 | \n",
+ " 0.121 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " weather | \n",
+ " 0.152 | \n",
+ " 0.151 | \n",
+ " 0.150 | \n",
+ " 5.133 | \n",
+ " 1.685 | \n",
+ " 44.329 | \n",
+ " 2.417 | \n",
+ " 52.823 | \n",
+ " 0.419 | \n",
+ " 0.424 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " electricity | \n",
+ " 0.156 | \n",
+ " 0.145 | \n",
+ " 0.138 | \n",
+ " 24.560 | \n",
+ " 5.033 | \n",
+ " 755.133 | \n",
+ " 8.928 | \n",
+ " 837.726 | \n",
+ " 0.112 | \n",
+ " 0.109 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " traffic | \n",
+ " 0.458 | \n",
+ " 0.416 | \n",
+ " 0.418 | \n",
+ " 43.940 | \n",
+ " 7.420 | \n",
+ " 674.037 | \n",
+ " 14.462 | \n",
+ " 497.818 | \n",
+ " 0.345 | \n",
+ " 0.343 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse fs10_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.362 0.361 0.363 1.778 0.988 \n",
+ "1 etth2 0.281 0.280 0.280 0.754 0.942 \n",
+ "2 ettm1 0.387 0.372 0.371 3.035 1.294 \n",
+ "3 ettm2 0.175 0.173 0.172 2.953 1.263 \n",
+ "4 weather 0.152 0.151 0.150 5.133 1.685 \n",
+ "5 electricity 0.156 0.145 0.138 24.560 5.033 \n",
+ "6 traffic 0.458 0.416 0.418 43.940 7.420 \n",
+ "\n",
+ " fs5_total_train_time fs10_mean_epoch_time fs10_total_train_time \\\n",
+ "0 38.519 1.003 29.517 \n",
+ "1 24.571 1.049 28.396 \n",
+ "2 53.025 1.611 53.820 \n",
+ "3 38.181 1.639 42.558 \n",
+ "4 44.329 2.417 52.823 \n",
+ "5 755.133 8.928 837.726 \n",
+ "6 674.037 14.462 497.818 \n",
+ "\n",
+ " fs5_best_val_metric fs10_best_val_metric \n",
+ "0 0.658 0.663 \n",
+ "1 0.224 0.223 \n",
+ "2 0.409 0.408 \n",
+ "3 0.121 0.121 \n",
+ "4 0.419 0.424 \n",
+ "5 0.112 0.109 \n",
+ "6 0.345 0.343 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/ttm-r1_benchmarking_512_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm-r1_benchmarking_512_96.ipynb
new file mode 100644
index 00000000..661972d5
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/ttm-r1_benchmarking_512_96.ipynb
@@ -0,0 +1,5409 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+ "**Using TTM-512-96 model.**\n",
+ "\n",
+ "Pre-trained TTM models will be fetched from the [Hugging Face TTM Model Repository](ibm-granite/granite-timeseries-ttm-r1).\n",
+ "\n",
+ "1. TTM-R1 pre-trained models can be found here: [TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024_96_v1\"`\n",
+ "2. TTM-R2 pre-trained models can be found here: [TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024-96-r2\"`\n",
+ " 3. For 1536-96 model set `TTM_MODEL_REVISION=\"1536-96-r2\"`\n",
+ "\n",
+ "Details about the revisions (R1 and R2) can be found [here](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)."
+ ]
+ },
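+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next cell is a minimal, optional sketch (not required by the benchmark below) of loading one of the revisions listed above. It only uses `TinyTimeMixerForPrediction.from_pretrained` with an explicit `revision`, the same call made later in the benchmarking loop; the repository id and revision strings come from the list above, so adjust them for other variants."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional illustration: fetch the TTM-R1 512-96 checkpoint (revision \"main\") from the Hugging Face Hub.\n",
+ "# Swap the revision string (e.g. \"1024_96_v1\") to load the 1024-96 variant listed above.\n",
+ "from tsfm_public import TinyTimeMixerForPrediction\n",
+ "\n",
+ "ttm_example = TinyTimeMixerForPrediction.from_pretrained(\n",
+ "    \"ibm-granite/granite-timeseries-ttm-r1\", revision=\"main\"\n",
+ ")\n",
+ "print(ttm_example)"
+ ]
+ },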
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-10 07:30:29.873090: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-10 07:30:29.910301: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-10 07:30:30.926289: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "from transformers.integrations import INTEGRATION_TO_CALLBACK\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "set_seed(42)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 512\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "learning_rate = 0.001\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
+ "# Refer to the load_datasets() function to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm-r1_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "hf_model_path = \"ibm-granite/granite-timeseries-ttm-r1\"\n",
+ "if context_length == 512:\n",
+ " hf_model_branch = \"main\"\n",
+ "elif context_length == 1024:\n",
+ " hf_model_branch = \"1024_96_v1\"\n",
+ "else:\n",
+ " raise ValueError(\"Current supported context lengths are 512 and 1024. Stay tuned for more TTMs!\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 8033, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36317431926727295, 'eval_model_preparation_time': 0.0026, 'eval_runtime': 1.7392, 'eval_samples_per_second': 1601.334, 'eval_steps_per_second': 25.299}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 311, val = 2785, test = 2785\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 55/250 00:18 < 01:09, 2.80 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 1.085700 | \n",
+ " 0.656020 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 1.086200 | \n",
+ " 0.656616 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 1.070400 | \n",
+ " 0.657144 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 1.093300 | \n",
+ " 0.658152 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.937400 | \n",
+ " 0.659537 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.865200 | \n",
+ " 0.661400 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.803600 | \n",
+ " 0.662929 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.748000 | \n",
+ " 0.664672 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.698600 | \n",
+ " 0.667598 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.666000 | \n",
+ " 0.674859 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.616600 | \n",
+ " 0.685595 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.876734278418801 seconds, Total Train Time = 20.82619071006775\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.363126665353775, 'eval_runtime': 0.9014, 'eval_samples_per_second': 3089.797, 'eval_steps_per_second': 48.815, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 717, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n",
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [204/600 00:31 < 01:01, 6.43 it/s, Epoch 17/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 1.043800 | \n",
+ " 0.655415 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.990900 | \n",
+ " 0.655896 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.884400 | \n",
+ " 0.657076 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.792000 | \n",
+ " 0.657461 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.665900 | \n",
+ " 0.657554 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.621100 | \n",
+ " 0.655823 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.527600 | \n",
+ " 0.655078 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.474300 | \n",
+ " 0.657213 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.443800 | \n",
+ " 0.662531 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.433600 | \n",
+ " 0.670480 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.410300 | \n",
+ " 0.681129 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.407100 | \n",
+ " 0.680766 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.393800 | \n",
+ " 0.694353 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.392600 | \n",
+ " 0.692552 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.375600 | \n",
+ " 0.702562 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.373900 | \n",
+ " 0.702306 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.369000 | \n",
+ " 0.706614 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.8558026341830983 seconds, Total Train Time = 32.004756927490234\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36420342326164246, 'eval_runtime': 0.9572, 'eval_samples_per_second': 2909.495, 'eval_steps_per_second': 45.967, 'epoch': 17.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: etth2, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 8033, val = 2785, test = 2785\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.28556713461875916, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 0.8802, 'eval_samples_per_second': 3163.949, 'eval_steps_per_second': 49.987}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: etth2, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 311, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 60/250 00:21 < 01:10, 2.70 it/s, Epoch 12/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.497100 | \n",
+ " 0.208019 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.439300 | \n",
+ " 0.207998 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.450200 | \n",
+ " 0.208099 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.424500 | \n",
+ " 0.208681 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.399700 | \n",
+ " 0.209764 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.336700 | \n",
+ " 0.211253 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.267400 | \n",
+ " 0.213202 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.247000 | \n",
+ " 0.215709 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.223300 | \n",
+ " 0.218617 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.187200 | \n",
+ " 0.222340 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.170400 | \n",
+ " 0.225701 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.159400 | \n",
+ " 0.230151 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.8081231911977133 seconds, Total Train Time = 22.086035013198853\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2842233180999756, 'eval_runtime': 0.9961, 'eval_samples_per_second': 2795.765, 'eval_steps_per_second': 44.17, 'epoch': 12.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: etth2, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 717, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n",
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [132/600 00:20 < 01:13, 6.39 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.694300 | \n",
+ " 0.208229 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.667200 | \n",
+ " 0.208902 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.684900 | \n",
+ " 0.210279 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.530900 | \n",
+ " 0.212758 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.471600 | \n",
+ " 0.216474 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.407100 | \n",
+ " 0.222424 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.366300 | \n",
+ " 0.230155 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.335900 | \n",
+ " 0.234342 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.310300 | \n",
+ " 0.233168 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.305700 | \n",
+ " 0.231881 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.290800 | \n",
+ " 0.239227 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 0.8696485215967352 seconds, Total Train Time = 20.990825176239014\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2839512526988983, 'eval_runtime': 1.0239, 'eval_samples_per_second': 2720.009, 'eval_steps_per_second': 42.973, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: ettm1, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "1 etth2 0.286 0.284 0.284\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 33953, val = 11425, test = 11425\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.41525664925575256, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 2.4875, 'eval_samples_per_second': 4592.904, 'eval_steps_per_second': 71.959}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: ettm1, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 1607, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 520/1300 00:57 < 01:27, 8.96 it/s, Epoch 20/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.550900 | \n",
+ " 0.463731 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.479900 | \n",
+ " 0.465929 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.454400 | \n",
+ " 0.473586 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.367000 | \n",
+ " 0.475486 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.315800 | \n",
+ " 0.475515 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.269300 | \n",
+ " 0.468186 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.253400 | \n",
+ " 0.460052 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.239900 | \n",
+ " 0.458469 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.233500 | \n",
+ " 0.453531 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.225800 | \n",
+ " 0.453469 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.222700 | \n",
+ " 0.455705 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.217800 | \n",
+ " 0.453836 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.213800 | \n",
+ " 0.456086 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.212700 | \n",
+ " 0.458392 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.208400 | \n",
+ " 0.456380 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.207400 | \n",
+ " 0.462406 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.204400 | \n",
+ " 0.465798 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.201600 | \n",
+ " 0.465260 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.199300 | \n",
+ " 0.473123 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.200500 | \n",
+ " 0.470573 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0739147901535033 seconds, Total Train Time = 58.48771643638611\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3644302189350128, 'eval_runtime': 1.8346, 'eval_samples_per_second': 6227.482, 'eval_steps_per_second': 97.568, 'epoch': 20.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: ettm1, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 3309, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 936/2600 01:05 < 01:56, 14.31 it/s, Epoch 18/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.653900 | \n",
+ " 0.460911 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.553200 | \n",
+ " 0.463849 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.452500 | \n",
+ " 0.466370 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.364200 | \n",
+ " 0.445985 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.320800 | \n",
+ " 0.436441 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.302200 | \n",
+ " 0.430455 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.293700 | \n",
+ " 0.430863 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.284700 | \n",
+ " 0.427922 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.279800 | \n",
+ " 0.434429 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.275000 | \n",
+ " 0.431091 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.270600 | \n",
+ " 0.431898 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.268600 | \n",
+ " 0.429764 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.265100 | \n",
+ " 0.439841 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.264000 | \n",
+ " 0.432602 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.261000 | \n",
+ " 0.434874 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.260600 | \n",
+ " 0.439803 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.256600 | \n",
+ " 0.444250 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.255100 | \n",
+ " 0.443020 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.6178780794143677 seconds, Total Train Time = 65.98268413543701\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.37092921137809753, 'eval_runtime': 2.0838, 'eval_samples_per_second': 5482.726, 'eval_steps_per_second': 85.9, 'epoch': 18.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: ettm2, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "1 etth2 0.286 0.284 0.284\n",
+ "2 ettm1 0.415 0.364 0.371\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 33953, val = 11425, test = 11425\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[179/179 00:02]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1860235333442688, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 2.6069, 'eval_samples_per_second': 4382.517, 'eval_steps_per_second': 68.663}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: ettm2, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 1607, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[ 338/1300 00:42 < 02:02, 7.87 it/s, Epoch 13/50]\n",
+ "\n",
+ "Epoch | Training Loss | Validation Loss\n",
+ "    1 |      0.403100 |        0.130643\n",
+ "    2 |      0.340000 |        0.129244\n",
+ "    3 |      0.283400 |        0.128597\n",
+ "    4 |      0.238700 |        0.130647\n",
+ "    5 |      0.197600 |        0.135873\n",
+ "    6 |      0.178500 |        0.141251\n",
+ "    7 |      0.160400 |        0.143489\n",
+ "    8 |      0.151500 |        0.143133\n",
+ "    9 |      0.144200 |        0.145625\n",
+ "   10 |      0.141300 |        0.146513\n",
+ "   11 |      0.138700 |        0.148491\n",
+ "   12 |      0.135700 |        0.151306\n",
+ "   13 |      0.132300 |        0.146737"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.3028539510873647 seconds, Total Train Time = 43.605464220047\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[179/179 00:01]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17499123513698578, 'eval_runtime': 2.0779, 'eval_samples_per_second': 5498.384, 'eval_steps_per_second': 86.145, 'epoch': 13.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: ettm2, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 3309, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[ 624/2600 00:42 < 02:16, 14.51 it/s, Epoch 12/50]\n",
+ "\n",
+ "Epoch | Training Loss | Validation Loss\n",
+ "    1 |      0.366700 |        0.129779\n",
+ "    2 |      0.267700 |        0.128715\n",
+ "    3 |      0.215200 |        0.129231\n",
+ "    4 |      0.169600 |        0.130869\n",
+ "    5 |      0.150000 |        0.131003\n",
+ "    6 |      0.139700 |        0.131113\n",
+ "    7 |      0.134100 |        0.130966\n",
+ "    8 |      0.129800 |        0.134528\n",
+ "    9 |      0.127100 |        0.132286\n",
+ "   10 |      0.124300 |        0.136354\n",
+ "   11 |      0.122800 |        0.130616\n",
+ "   12 |      0.120800 |        0.137425"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.621331552664439 seconds, Total Train Time = 43.78216910362244\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[179/179 00:01]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17638568580150604, 'eval_runtime': 2.1272, 'eval_samples_per_second': 5370.977, 'eval_steps_per_second': 84.149, 'epoch': 12.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: weather, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "1 etth2 0.286 0.284 0.284\n",
+ "2 ettm1 0.415 0.364 0.371\n",
+ "3 ettm2 0.186 0.175 0.176\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 36280, val = 5175, test = 10444\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[164/164 00:03]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1524711698293686, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 3.3764, 'eval_samples_per_second': 3093.197, 'eval_steps_per_second': 48.572}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: weather, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 1723, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[ 351/1350 00:40 < 01:56, 8.55 it/s, Epoch 13/50]\n",
+ "\n",
+ "Epoch | Training Loss | Validation Loss\n",
+ "    1 |      0.160100 |        0.425349\n",
+ "    2 |      0.155800 |        0.422991\n",
+ "    3 |      0.151400 |        0.422865\n",
+ "    4 |      0.146100 |        0.427230\n",
+ "    5 |      0.140200 |        0.434825\n",
+ "    6 |      0.133500 |        0.442507\n",
+ "    7 |      0.127200 |        0.453159\n",
+ "    8 |      0.120200 |        0.465943\n",
+ "    9 |      0.114300 |        0.465322\n",
+ "   10 |      0.109000 |        0.464073\n",
+ "   11 |      0.103900 |        0.479937\n",
+ "   12 |      0.098800 |        0.485888\n",
+ "   13 |      0.095600 |        0.479965"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.393344255594107 seconds, Total Train Time = 41.77857685089111\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[164/164 00:01]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15006662905216217, 'eval_runtime': 2.5414, 'eval_samples_per_second': 4109.619, 'eval_steps_per_second': 64.533, 'epoch': 13.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: weather, context length: 512, prediction length 96\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 3542, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[ 672/2800 00:43 < 02:18, 15.38 it/s, Epoch 12/50]\n",
+ "\n",
+ "Epoch | Training Loss | Validation Loss\n",
+ "    1 |      0.134900 |        0.422834\n",
+ "    2 |      0.131000 |        0.421728\n",
+ "    3 |      0.128000 |        0.422719\n",
+ "    4 |      0.123700 |        0.425492\n",
+ "    5 |      0.120500 |        0.428487\n",
+ "    6 |      0.116000 |        0.436083\n",
+ "    7 |      0.112200 |        0.438655\n",
+ "    8 |      0.106800 |        0.437371\n",
+ "    9 |      0.103000 |        0.436040\n",
+ "   10 |      0.100000 |        0.427018\n",
+ "   11 |      0.096600 |        0.435761\n",
+ "   12 |      0.093300 |        0.433628"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.9348897337913513 seconds, Total Train Time = 44.50661301612854\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[164/164 00:01]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.14866013824939728, 'eval_runtime': 2.442, 'eval_samples_per_second': 4276.86, 'eval_steps_per_second': 67.159, 'epoch': 12.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: electricity, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "1 etth2 0.286 0.284 0.284\n",
+ "2 ettm1 0.415 0.364 0.371\n",
+ "3 ettm2 0.186 0.175 0.176\n",
+ "4 weather 0.152 0.150 0.149\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 17805, val = 2537, test = 5165\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[162/162 00:13]"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17006558179855347, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 14.0713, 'eval_samples_per_second': 367.059, 'eval_steps_per_second': 11.513}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: electricity, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 800, val = 2537, test = 5165\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "[1250/1250 07:29, Epoch 50/50]\n",
+ "\n",
+ "Epoch | Training Loss | Validation Loss\n",
+ "    1 |      0.186000 |        0.136702\n",
+ "    2 |      0.179200 |        0.132026\n",
+ "    3 |      0.173600 |        0.128869\n",
+ "    4 |      0.167900 |        0.125446\n",
+ "    5 |      0.163200 |        0.123641\n",
+ "    6 |      0.159400 |        0.122560\n",
+ "    7 |      0.156500 |        0.121135\n",
+ "    8 |      0.153700 |        0.120255\n",
+ "    9 |      0.151700 |        0.119879\n",
+ "   10 |      0.149700 |        0.118841\n",
+ "   11 |      0.147900 |        0.119294\n",
+ "   12 |      0.146600 |        0.118377\n",
+ "   13 |      0.145100 |        0.119855\n",
+ "   14 |      0.144400 |        0.118071\n",
+ "   15 |      0.142800 |        0.118609\n",
+ "   16 |      0.142100 |        0.118664\n",
+ "   17 |      0.141100 |        0.118297\n",
+ "   18 |      0.140100 |        0.118825\n",
+ "   19 |      0.139000 |        0.117799\n",
+ "   20 |      0.138800 |        0.118162\n",
+ "   21 |      0.138300 |        0.118339\n",
+ "   22 |      0.137700 |        0.117534\n",
+ "   23 |      0.137200 |        0.117699\n",
+ "   24 |      0.136200 |        0.117654\n",
+ "   25 |      0.135300 |        0.117274\n",
+ "   26 |      0.135100 |        0.117221\n",
+ "   27 |      0.134600 |        0.117807\n",
+ "   28 |      0.134200 |        0.117367\n",
+ "   29 |      0.133800 |        0.117252\n",
+ "   30 |      0.133400 |        0.117081\n",
+ "   31 |      0.133000 |        0.117083\n",
+ "   32 |      0.132900 |        0.116850\n",
+ "   33 |      0.132400 |        0.116892\n",
+ "   34 |      0.132200 |        0.116912\n",
+ "   35 |      0.131800 |        0.117315\n",
+ "   36 |      0.131500 |        0.116783\n",
+ "   37 |      0.130900 |        0.116776\n",
+ "   38 |      0.130800 |        0.116731\n",
+ "   39 |      0.130600 |        0.116967\n",
+ "   40 |      0.130600 |        0.116730\n",
+ "   41 |      0.130300 |        0.116513\n",
+ "   42 |      0.130300 |        0.116554\n",
+ "   43 |      0.130000 |        0.116673\n",
+ "   44 |      0.130000 |        0.116653\n",
+ "   45 |      0.130000 |        0.116706\n",
+ "   46 |      0.130000 |        0.116553\n",
+ "   47 |      0.130000 |        0.116500\n",
+ "   48 |      0.129900 |        0.116503\n",
+ "   49 |      0.129700 |        0.116548\n",
+ "   50 |      0.129800 |        0.116546"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 3.216276993751526 seconds, Total Train Time = 450.9349133968353\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:09]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1425967961549759, 'eval_runtime': 10.3735, 'eval_samples_per_second': 497.905, 'eval_steps_per_second': 15.617, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: electricity, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 1695, val = 2537, test = 5165\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1325/2650 04:31 < 04:32, 4.86 it/s, Epoch 25/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.166400 | \n",
+ " 0.131775 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.156800 | \n",
+ " 0.126925 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.150300 | \n",
+ " 0.123428 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.145700 | \n",
+ " 0.121103 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.141900 | \n",
+ " 0.119786 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.138900 | \n",
+ " 0.118132 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.136400 | \n",
+ " 0.117050 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.134200 | \n",
+ " 0.116493 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.132600 | \n",
+ " 0.116092 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.131200 | \n",
+ " 0.115692 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.130500 | \n",
+ " 0.115982 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.129500 | \n",
+ " 0.115369 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.128500 | \n",
+ " 0.115938 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.128100 | \n",
+ " 0.115339 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.127300 | \n",
+ " 0.114844 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.126600 | \n",
+ " 0.115098 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.126100 | \n",
+ " 0.115571 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.125900 | \n",
+ " 0.115323 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.125100 | \n",
+ " 0.115411 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.124700 | \n",
+ " 0.114962 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.124200 | \n",
+ " 0.114975 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.123700 | \n",
+ " 0.114859 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.123400 | \n",
+ " 0.114951 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.122900 | \n",
+ " 0.115152 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.122600 | \n",
+ " 0.115177 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 5.291070413589478 seconds, Total Train Time = 272.9690537452698\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:08]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.13970844447612762, 'eval_runtime': 10.0501, 'eval_samples_per_second': 513.925, 'eval_steps_per_second': 16.119, 'epoch': 25.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: traffic, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "1 etth2 0.286 0.284 0.284\n",
+ "2 ettm1 0.415 0.364 0.371\n",
+ "3 ettm2 0.186 0.175 0.176\n",
+ "4 weather 0.152 0.150 0.149\n",
+ "5 electricity 0.170 0.143 0.140\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r1/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 11673, val = 1661, test = 3413\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:23]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.5094045996665955, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 23.857, 'eval_samples_per_second': 143.061, 'eval_steps_per_second': 17.898}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: traffic, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 493, val = 1661, test = 3413\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [3100/3100 11:42, Epoch 50/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.272700 | \n",
+ " 0.393278 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.253400 | \n",
+ " 0.375481 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.241100 | \n",
+ " 0.360526 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.230500 | \n",
+ " 0.351872 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.222200 | \n",
+ " 0.344429 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.214800 | \n",
+ " 0.339461 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.209600 | \n",
+ " 0.338062 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.205600 | \n",
+ " 0.336990 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.202900 | \n",
+ " 0.336078 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.200000 | \n",
+ " 0.334375 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.198000 | \n",
+ " 0.333791 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.197000 | \n",
+ " 0.333844 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.195100 | \n",
+ " 0.333792 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.193600 | \n",
+ " 0.333915 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.192700 | \n",
+ " 0.334478 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.191500 | \n",
+ " 0.333000 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.190800 | \n",
+ " 0.332865 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.189500 | \n",
+ " 0.334100 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.188800 | \n",
+ " 0.332967 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.188100 | \n",
+ " 0.331086 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.186900 | \n",
+ " 0.332582 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.186700 | \n",
+ " 0.331533 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.185800 | \n",
+ " 0.330423 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.185100 | \n",
+ " 0.331567 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.184600 | \n",
+ " 0.331676 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.184500 | \n",
+ " 0.330323 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.183900 | \n",
+ " 0.330532 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.183300 | \n",
+ " 0.329897 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.182400 | \n",
+ " 0.330098 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.181900 | \n",
+ " 0.330095 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.181800 | \n",
+ " 0.329849 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.181300 | \n",
+ " 0.329267 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.180700 | \n",
+ " 0.329384 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.180200 | \n",
+ " 0.329585 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.180200 | \n",
+ " 0.328754 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.179500 | \n",
+ " 0.328836 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.179400 | \n",
+ " 0.328085 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.178800 | \n",
+ " 0.328287 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.178700 | \n",
+ " 0.328173 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.178400 | \n",
+ " 0.328408 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.178100 | \n",
+ " 0.328306 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.177900 | \n",
+ " 0.327732 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.177700 | \n",
+ " 0.328101 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.177600 | \n",
+ " 0.327719 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.177700 | \n",
+ " 0.327562 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.177300 | \n",
+ " 0.327719 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 0.177100 | \n",
+ " 0.327573 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.177200 | \n",
+ " 0.327571 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 0.177200 | \n",
+ " 0.327563 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 0.177200 | \n",
+ " 0.327564 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 4.735059623718262 seconds, Total Train Time = 703.8364264965057\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:16]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3968665301799774, 'eval_runtime': 17.7251, 'eval_samples_per_second': 192.552, 'eval_steps_per_second': 24.09, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Dataset name: traffic, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 10% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "/dccstor/dnn_forecasting/arindam/FM/HF/public_tsfm/tsfm/tsfm_public/toolkit/dataset.py:186: PerformanceWarning: DataFrame is highly fragmented. This is usually the result of calling `frame.insert` many times, which has poor performance. Consider joining all columns at once using pd.concat(axis=1) instead. To get a de-fragmented frame, use `newframe = frame.copy()`\n",
+ " data_df[\"group\"] = 0 # create a artificial group\n",
+ "INFO:p-3134008:t-23177103872768:data_handling.py:load_dataset:Data lengths: train = 1081, val = 1661, test = 3413\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/transformers/training_args.py:1545: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead\n",
+ " warnings.warn(\n",
+ "WARNING:p-3134008:t-23177103872768:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3134008:t-23177103872768:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "Using learning rate = 0.001\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [4080/6800 09:01 < 06:01, 7.53 it/s, Epoch 30/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.294200 | \n",
+ " 0.380806 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.271200 | \n",
+ " 0.362084 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.258500 | \n",
+ " 0.351829 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.248300 | \n",
+ " 0.345643 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.240500 | \n",
+ " 0.340656 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.234500 | \n",
+ " 0.339494 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.229600 | \n",
+ " 0.335847 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.226500 | \n",
+ " 0.335783 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.225500 | \n",
+ " 0.338349 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.223200 | \n",
+ " 0.336193 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.222600 | \n",
+ " 0.343954 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.222100 | \n",
+ " 0.340995 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.219800 | \n",
+ " 0.339137 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.219100 | \n",
+ " 0.335982 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.217700 | \n",
+ " 0.344850 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.218600 | \n",
+ " 0.342654 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.218100 | \n",
+ " 0.333909 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.215300 | \n",
+ " 0.338186 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.213600 | \n",
+ " 0.342740 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.213100 | \n",
+ " 0.332170 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.213200 | \n",
+ " 0.335310 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.213000 | \n",
+ " 0.334148 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.211000 | \n",
+ " 0.337650 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.211000 | \n",
+ " 0.340426 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.210200 | \n",
+ " 0.339711 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.210900 | \n",
+ " 0.342000 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.209600 | \n",
+ " 0.339016 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.208400 | \n",
+ " 0.335918 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.207600 | \n",
+ " 0.332504 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.206700 | \n",
+ " 0.337658 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 8.79585785071055 seconds, Total Train Time = 542.8885765075684\n",
+ "++++++++++++++++++++ Test MSE after few-shot 10% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:16]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
+ " self.pid = os.fork()\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4039205312728882, 'eval_runtime': 17.9277, 'eval_samples_per_second': 190.376, 'eval_steps_per_second': 23.818, 'epoch': 30.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse fs10_mse\n",
+ "0 etth1 0.363 0.363 0.364\n",
+ "1 etth2 0.286 0.284 0.284\n",
+ "2 ettm1 0.415 0.364 0.371\n",
+ "3 ettm2 0.186 0.175 0.176\n",
+ "4 weather 0.152 0.150 0.149\n",
+ "5 electricity 0.170 0.143 0.140\n",
+ "6 traffic 0.509 0.397 0.404\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"fs10_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs10_mean_epoch_time\": [],\n",
+ " \"fs10_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ " \"fs10_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+ " ## Use the pretrained model in few-shot 5% and 10% forecasting #\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5, 10]:\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ " finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK[\"codecarbon\"])\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\", \"fs10_mse\"]])\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+ "*Some slight differences in the results as compared to the TTM paper results is possible due to different training environments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " dataset | \n",
+ " zs_mse | \n",
+ " fs5_mse | \n",
+ " fs10_mse | \n",
+ " zs_eval_time | \n",
+ " fs5_mean_epoch_time | \n",
+ " fs5_total_train_time | \n",
+ " fs10_mean_epoch_time | \n",
+ " fs10_total_train_time | \n",
+ " fs5_best_val_metric | \n",
+ " fs10_best_val_metric | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " etth1 | \n",
+ " 0.363 | \n",
+ " 0.363 | \n",
+ " 0.364 | \n",
+ " 1.739 | \n",
+ " 0.877 | \n",
+ " 20.826 | \n",
+ " 0.856 | \n",
+ " 32.005 | \n",
+ " 0.656 | \n",
+ " 0.655 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " etth2 | \n",
+ " 0.286 | \n",
+ " 0.284 | \n",
+ " 0.284 | \n",
+ " 0.880 | \n",
+ " 0.808 | \n",
+ " 22.086 | \n",
+ " 0.870 | \n",
+ " 20.991 | \n",
+ " 0.208 | \n",
+ " 0.208 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ettm1 | \n",
+ " 0.415 | \n",
+ " 0.364 | \n",
+ " 0.371 | \n",
+ " 2.488 | \n",
+ " 1.074 | \n",
+ " 58.488 | \n",
+ " 1.618 | \n",
+ " 65.983 | \n",
+ " 0.453 | \n",
+ " 0.428 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ettm2 | \n",
+ " 0.186 | \n",
+ " 0.175 | \n",
+ " 0.176 | \n",
+ " 2.607 | \n",
+ " 1.303 | \n",
+ " 43.605 | \n",
+ " 1.621 | \n",
+ " 43.782 | \n",
+ " 0.129 | \n",
+ " 0.129 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " weather | \n",
+ " 0.152 | \n",
+ " 0.150 | \n",
+ " 0.149 | \n",
+ " 3.376 | \n",
+ " 1.393 | \n",
+ " 41.779 | \n",
+ " 1.935 | \n",
+ " 44.507 | \n",
+ " 0.423 | \n",
+ " 0.422 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " electricity | \n",
+ " 0.170 | \n",
+ " 0.143 | \n",
+ " 0.140 | \n",
+ " 14.071 | \n",
+ " 3.216 | \n",
+ " 450.935 | \n",
+ " 5.291 | \n",
+ " 272.969 | \n",
+ " 0.116 | \n",
+ " 0.115 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " traffic | \n",
+ " 0.509 | \n",
+ " 0.397 | \n",
+ " 0.404 | \n",
+ " 23.857 | \n",
+ " 4.735 | \n",
+ " 703.836 | \n",
+ " 8.796 | \n",
+ " 542.889 | \n",
+ " 0.328 | \n",
+ " 0.332 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse fs10_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.363 0.363 0.364 1.739 0.877 \n",
+ "1 etth2 0.286 0.284 0.284 0.880 0.808 \n",
+ "2 ettm1 0.415 0.364 0.371 2.488 1.074 \n",
+ "3 ettm2 0.186 0.175 0.176 2.607 1.303 \n",
+ "4 weather 0.152 0.150 0.149 3.376 1.393 \n",
+ "5 electricity 0.170 0.143 0.140 14.071 3.216 \n",
+ "6 traffic 0.509 0.397 0.404 23.857 4.735 \n",
+ "\n",
+ " fs5_total_train_time fs10_mean_epoch_time fs10_total_train_time \\\n",
+ "0 20.826 0.856 32.005 \n",
+ "1 22.086 0.870 20.991 \n",
+ "2 58.488 1.618 65.983 \n",
+ "3 43.605 1.621 43.782 \n",
+ "4 41.779 1.935 44.507 \n",
+ "5 450.935 5.291 272.969 \n",
+ "6 703.836 8.796 542.889 \n",
+ "\n",
+ " fs5_best_val_metric fs10_best_val_metric \n",
+ "0 0.656 0.655 \n",
+ "1 0.208 0.208 \n",
+ "2 0.453 0.428 \n",
+ "3 0.129 0.129 \n",
+ "4 0.423 0.422 \n",
+ "5 0.116 0.115 \n",
+ "6 0.328 0.332 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_1024_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_1024_96.ipynb
new file mode 100644
index 00000000..99b4f495
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_1024_96.ipynb
@@ -0,0 +1,2384 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+ "**Using TTM-1024-96 model.**\n",
+ "\n",
+ "Pre-trained TTM models will be fetched from the [Hugging Face TTM Model Repository](ibm-granite/granite-timeseries-ttm-r2).\n",
+ "\n",
+ "1. TTM-R1 pre-trained models can be found here: [TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024_96_v1\"`\n",
+ "2. TTM-R2 pre-trained models can be found here: [TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024-96-r2\"`\n",
+ " 3. For 1536-96 model set `TTM_MODEL_REVISION=\"1536-96-r2\"`\n",
+ "\n",
+ "Details about the revisions (R1 and R2) can be found [here](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)."
+ ]
+ },
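+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For example, a specific context-length/forecast-length revision can be loaded as follows (a minimal sketch using `TinyTimeMixerForPrediction.from_pretrained`, which is imported below):\n",
+ "\n",
+ "```python\n",
+ "from tsfm_public import TinyTimeMixerForPrediction\n",
+ "\n",
+ "# Load the TTM-R2 1024-96 model from the Hugging Face Hub\n",
+ "model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ "    \"ibm-granite/granite-timeseries-ttm-r2\", revision=\"1024-96-r2\"\n",
+ ")\n",
+ "```"
+ ]
+ },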
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-10 07:15:39.622201: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-10 07:15:39.658868: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-10 07:15:40.389511: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "import warnings\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "from transformers.integrations import INTEGRATION_TO_CALLBACK\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions\n",
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "SEED = 42\n",
+ "set_seed(SEED)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 1024\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
+ "# Refer to the load_datasets() function\n",
+ "# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py\n",
+ "# to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm-r2_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Please provide the branch name properly based on context_len and forecast_len\n",
+ "hf_model_path = \"ibm-granite/granite-timeseries-ttm-r2\"\n",
+ "hf_model_branch = f\"{context_length}-{forecast_length}-r2\""
+ ]
+ },
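+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "With `context_length = 1024` and `forecast_length = 96` set above, `hf_model_branch` evaluates to `\"1024-96-r2\"`, one of the TTM-R2 revisions listed at the top of this notebook."
+ ]
+ },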
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 7521, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.35859495401382446, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 1.7201, 'eval_samples_per_second': 1619.12, 'eval_steps_per_second': 25.58}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: etth1, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 285, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.000298364724028334\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334\n",
+ "Using learning rate = 0.000298364724028334\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 55/250 00:25 < 01:33, 2.08 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.916600 | \n",
+ " 0.665669 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.888700 | \n",
+ " 0.665982 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.824300 | \n",
+ " 0.666453 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.886300 | \n",
+ " 0.667170 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.773700 | \n",
+ " 0.668418 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.695100 | \n",
+ " 0.669920 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.525000 | \n",
+ " 0.671401 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.475700 | \n",
+ " 0.673846 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.404700 | \n",
+ " 0.675814 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.374400 | \n",
+ " 0.677924 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.335900 | \n",
+ " 0.681410 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0783685120669277 seconds, Total Train Time = 27.66278910636902\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.35856103897094727, 'eval_runtime': 1.2591, 'eval_samples_per_second': 2211.927, 'eval_steps_per_second': 34.946, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 7521, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.269417405128479, 'eval_model_preparation_time': 0.0021, 'eval_runtime': 0.7466, 'eval_samples_per_second': 3730.016, 'eval_steps_per_second': 58.93}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: etth2, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 285, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.000298364724028334\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334\n",
+ "Using learning rate = 0.000298364724028334\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 55/250 00:25 < 01:35, 2.04 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.945200 | \n",
+ " 0.239151 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.861000 | \n",
+ " 0.239945 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.805900 | \n",
+ " 0.241062 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.724700 | \n",
+ " 0.242527 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.654900 | \n",
+ " 0.244388 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.576800 | \n",
+ " 0.246938 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.495700 | \n",
+ " 0.250335 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.457700 | \n",
+ " 0.256598 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.392900 | \n",
+ " 0.267042 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.357100 | \n",
+ " 0.283817 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.323500 | \n",
+ " 0.308233 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.018961863084273 seconds, Total Train Time = 26.78283667564392\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.26942315697669983, 'eval_runtime': 1.3916, 'eval_samples_per_second': 2001.34, 'eval_steps_per_second': 31.619, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 33441, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.269 0.269\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3369019627571106, 'eval_model_preparation_time': 0.0021, 'eval_runtime': 3.0592, 'eval_samples_per_second': 3734.593, 'eval_steps_per_second': 58.511}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 1581, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.0005214008287999684\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.0005214008287999684\n",
+ "Using learning rate = 0.0005214008287999684\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:38 < 02:18, 7.02 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.814100 | \n",
+ " 0.394550 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.607200 | \n",
+ " 0.395544 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.477900 | \n",
+ " 0.397824 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.380700 | \n",
+ " 0.397300 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.311600 | \n",
+ " 0.408491 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.268400 | \n",
+ " 0.428093 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.242600 | \n",
+ " 0.437327 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.223300 | \n",
+ " 0.456643 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.207600 | \n",
+ " 0.463043 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.197200 | \n",
+ " 0.468228 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.188800 | \n",
+ " 0.478250 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.3170426542108709 seconds, Total Train Time = 39.83929896354675\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.33640581369400024, 'eval_runtime': 2.3821, 'eval_samples_per_second': 4796.209, 'eval_steps_per_second': 75.144, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 33441, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.269 0.269\n",
+ "2 ettm1 0.337 0.336\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1764754354953766, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 3.0247, 'eval_samples_per_second': 3777.253, 'eval_steps_per_second': 59.18}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 1581, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.000298364724028334\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334\n",
+ "Using learning rate = 0.000298364724028334\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:39 < 02:19, 6.99 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.495700 | \n",
+ " 0.122071 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.399600 | \n",
+ " 0.122304 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.328300 | \n",
+ " 0.122963 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.242100 | \n",
+ " 0.124153 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.188300 | \n",
+ " 0.127375 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.150100 | \n",
+ " 0.135246 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.133100 | \n",
+ " 0.143912 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.122500 | \n",
+ " 0.151637 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.117400 | \n",
+ " 0.158312 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.111200 | \n",
+ " 0.164967 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.106800 | \n",
+ " 0.169490 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.3009986010464756 seconds, Total Train Time = 40.168370962142944\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.17645052075386047, 'eval_runtime': 2.2382, 'eval_samples_per_second': 5104.458, 'eval_steps_per_second': 79.974, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.269 0.269\n",
+ "2 ettm1 0.337 0.336\n",
+ "3 ettm2 0.176 0.176\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 35768, val = 5175, test = 10444\n",
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:05]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15011762082576752, 'eval_model_preparation_time': 0.0021, 'eval_runtime': 5.1462, 'eval_samples_per_second': 2029.447, 'eval_steps_per_second': 31.868}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: weather, context length: 1024, prediction length 96\n",
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 1698, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00035938136638046257\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00035938136638046257\n",
+ "Using learning rate = 0.00035938136638046257\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 297/1350 00:44 < 02:38, 6.64 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.153500 | \n",
+ " 0.393854 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.147700 | \n",
+ " 0.399079 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.140500 | \n",
+ " 0.407770 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.130300 | \n",
+ " 0.410832 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.115500 | \n",
+ " 0.407429 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.102600 | \n",
+ " 0.411830 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.092700 | \n",
+ " 0.409271 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.085700 | \n",
+ " 0.415379 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.080700 | \n",
+ " 0.414570 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.076900 | \n",
+ " 0.414594 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.073200 | \n",
+ " 0.414909 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.7708212462338535 seconds, Total Train Time = 45.57381844520569\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1500033736228943, 'eval_runtime': 3.8799, 'eval_samples_per_second': 2691.848, 'eval_steps_per_second': 42.27, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.269 0.269\n",
+ "2 ettm1 0.337 0.336\n",
+ "3 ettm2 0.176 0.176\n",
+ "4 weather 0.150 0.150\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 17293, val = 2537, test = 5165\n",
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:24]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15828542411327362, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 25.1925, 'eval_samples_per_second': 205.021, 'eval_steps_per_second': 6.43}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: electricity, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 774, val = 2537, test = 5165\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 8.111308307896872e-05\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 8.111308307896872e-05\n",
+ "Using learning rate = 8.111308307896872e-05\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1250/1250 12:33, Epoch 50/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.154100 | \n",
+ " 0.133550 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.150400 | \n",
+ " 0.133363 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.148100 | \n",
+ " 0.131550 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.147100 | \n",
+ " 0.129834 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.144600 | \n",
+ " 0.128791 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.143500 | \n",
+ " 0.127429 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.140500 | \n",
+ " 0.126259 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.139600 | \n",
+ " 0.125177 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.137400 | \n",
+ " 0.124556 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.134800 | \n",
+ " 0.123992 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.133200 | \n",
+ " 0.123508 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.132400 | \n",
+ " 0.122755 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.130200 | \n",
+ " 0.121776 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.129000 | \n",
+ " 0.121530 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.127400 | \n",
+ " 0.120715 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.127100 | \n",
+ " 0.120709 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.124800 | \n",
+ " 0.120068 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.124400 | \n",
+ " 0.119745 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.123300 | \n",
+ " 0.119724 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.122700 | \n",
+ " 0.119251 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.122700 | \n",
+ " 0.119369 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.122000 | \n",
+ " 0.118602 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.121000 | \n",
+ " 0.118696 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.120500 | \n",
+ " 0.118636 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.120000 | \n",
+ " 0.118489 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.120100 | \n",
+ " 0.118362 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.119200 | \n",
+ " 0.118232 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.119200 | \n",
+ " 0.117863 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.118600 | \n",
+ " 0.117886 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.118500 | \n",
+ " 0.118188 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.118100 | \n",
+ " 0.117668 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.118200 | \n",
+ " 0.117762 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.117600 | \n",
+ " 0.117765 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.117500 | \n",
+ " 0.117664 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.117400 | \n",
+ " 0.117472 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.117000 | \n",
+ " 0.117431 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.117000 | \n",
+ " 0.117465 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.117400 | \n",
+ " 0.117517 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.116500 | \n",
+ " 0.117558 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.116800 | \n",
+ " 0.117487 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.116600 | \n",
+ " 0.117426 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.117200 | \n",
+ " 0.117347 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.116500 | \n",
+ " 0.117497 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.116800 | \n",
+ " 0.117334 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.116200 | \n",
+ " 0.117415 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.116500 | \n",
+ " 0.117320 | \n",
+ "
\n",
+ " \n",
+ " 47 | \n",
+ " 0.116600 | \n",
+ " 0.117332 | \n",
+ "
\n",
+ " \n",
+ " 48 | \n",
+ " 0.116700 | \n",
+ " 0.117344 | \n",
+ "
\n",
+ " \n",
+ " 49 | \n",
+ " 0.116400 | \n",
+ " 0.117354 | \n",
+ "
\n",
+ " \n",
+ " 50 | \n",
+ " 0.116100 | \n",
+ " 0.117354 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 5.0357036304473874 seconds, Total Train Time = 755.3726332187653\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:16]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.14718736708164215, 'eval_runtime': 18.5622, 'eval_samples_per_second': 278.254, 'eval_steps_per_second': 8.727, 'epoch': 50.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.269 0.269\n",
+ "2 ettm1 0.337 0.336\n",
+ "3 ettm2 0.176 0.176\n",
+ "4 weather 0.150 0.150\n",
+ "5 electricity 0.158 0.147\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1024 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1024-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 11161, val = 1661, test = 3413\n",
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:43]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4737617075443268, 'eval_model_preparation_time': 0.0021, 'eval_runtime': 43.7457, 'eval_samples_per_second': 78.019, 'eval_steps_per_second': 9.761}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Dataset name: traffic, context length: 1024, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048561:t-22509185540864:data_handling.py:load_dataset:Data lengths: train = 467, val = 1661, test = 3413\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 2964960\n",
+ "Number of params after freezing the backbone 955424\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048561:t-22509185540864:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048561:t-22509185540864:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1652/2950 11:10 < 08:47, 2.46 it/s, Epoch 28/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.306300 | \n",
+ " 0.384197 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.290600 | \n",
+ " 0.380115 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.283100 | \n",
+ " 0.377606 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.275400 | \n",
+ " 0.375396 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.267800 | \n",
+ " 0.371779 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.262100 | \n",
+ " 0.370619 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.257600 | \n",
+ " 0.364189 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.253500 | \n",
+ " 0.361611 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.247700 | \n",
+ " 0.357288 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.244900 | \n",
+ " 0.354975 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.240500 | \n",
+ " 0.355310 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.236600 | \n",
+ " 0.355367 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.234100 | \n",
+ " 0.349950 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.232100 | \n",
+ " 0.353106 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.229900 | \n",
+ " 0.352745 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.226000 | \n",
+ " 0.347146 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.223000 | \n",
+ " 0.356564 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.221000 | \n",
+ " 0.345382 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.220800 | \n",
+ " 0.349640 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.217800 | \n",
+ " 0.355157 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.217000 | \n",
+ " 0.356424 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.213900 | \n",
+ " 0.349901 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.212700 | \n",
+ " 0.355637 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.212000 | \n",
+ " 0.349804 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.210700 | \n",
+ " 0.348401 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.208500 | \n",
+ " 0.356707 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.207000 | \n",
+ " 0.348334 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.207000 | \n",
+ " 0.350621 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 7.516440740653446 seconds, Total Train Time = 672.2011640071869\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:30]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4179241955280304, 'eval_runtime': 31.9674, 'eval_samples_per_second': 106.765, 'eval_steps_per_second': 13.357, 'epoch': 28.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.359 0.359\n",
+ "1 etth2 0.269 0.269\n",
+ "2 ettm1 0.337 0.336\n",
+ "3 ettm2 0.176 0.176\n",
+ "4 weather 0.150 0.150\n",
+ "5 electricity 0.158 0.147\n",
+ "6 traffic 0.474 0.418\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " seed=SEED,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+ " ## Use the pretrained model in few-shot 5% and 10% forecasting #\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5]:\n",
+ " # Set learning rate\n",
+ " learning_rate = None # `None` value indicates that the optimal_lr_finder() will be used\n",
+ "\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " if learning_rate is None:\n",
+ " learning_rate, finetune_forecast_model = optimal_lr_finder(\n",
+ " finetune_forecast_model,\n",
+ " dset_train,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " )\n",
+ " print(\"OPTIMAL SUGGESTED LEARNING RATE =\", learning_rate)\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " seed=SEED,\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ " finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK[\"codecarbon\"])\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\"]])\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+ "*Some slight differences in the results as compared to the TTM paper results is possible due to different training environments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " dataset | \n",
+ " zs_mse | \n",
+ " fs5_mse | \n",
+ " zs_eval_time | \n",
+ " fs5_mean_epoch_time | \n",
+ " fs5_total_train_time | \n",
+ " fs5_best_val_metric | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " etth1 | \n",
+ " 0.359 | \n",
+ " 0.359 | \n",
+ " 1.720 | \n",
+ " 1.078 | \n",
+ " 27.663 | \n",
+ " 0.666 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " etth2 | \n",
+ " 0.269 | \n",
+ " 0.269 | \n",
+ " 0.747 | \n",
+ " 1.019 | \n",
+ " 26.783 | \n",
+ " 0.239 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ettm1 | \n",
+ " 0.337 | \n",
+ " 0.336 | \n",
+ " 3.059 | \n",
+ " 1.317 | \n",
+ " 39.839 | \n",
+ " 0.395 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ettm2 | \n",
+ " 0.176 | \n",
+ " 0.176 | \n",
+ " 3.025 | \n",
+ " 1.301 | \n",
+ " 40.168 | \n",
+ " 0.122 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " weather | \n",
+ " 0.150 | \n",
+ " 0.150 | \n",
+ " 5.146 | \n",
+ " 1.771 | \n",
+ " 45.574 | \n",
+ " 0.394 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " electricity | \n",
+ " 0.158 | \n",
+ " 0.147 | \n",
+ " 25.192 | \n",
+ " 5.036 | \n",
+ " 755.373 | \n",
+ " 0.117 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " traffic | \n",
+ " 0.474 | \n",
+ " 0.418 | \n",
+ " 43.746 | \n",
+ " 7.516 | \n",
+ " 672.201 | \n",
+ " 0.345 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.359 0.359 1.720 1.078 \n",
+ "1 etth2 0.269 0.269 0.747 1.019 \n",
+ "2 ettm1 0.337 0.336 3.059 1.317 \n",
+ "3 ettm2 0.176 0.176 3.025 1.301 \n",
+ "4 weather 0.150 0.150 5.146 1.771 \n",
+ "5 electricity 0.158 0.147 25.192 5.036 \n",
+ "6 traffic 0.474 0.418 43.746 7.516 \n",
+ "\n",
+ " fs5_total_train_time fs5_best_val_metric \n",
+ "0 27.663 0.666 \n",
+ "1 26.783 0.239 \n",
+ "2 39.839 0.395 \n",
+ "3 40.168 0.122 \n",
+ "4 45.574 0.394 \n",
+ "5 755.373 0.117 \n",
+ "6 672.201 0.345 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ },
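+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "*Illustrative sketch (not part of the original run):* the loop above also writes the per-dataset metrics to `results_zero_few.csv` under `OUT_DIR`, so they can be reloaded later without re-running the benchmark. The cell below assumes that file exists; the column name `fs5_vs_zs_pct` is introduced here only for illustration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch: reload the saved benchmark results and compute the relative\n",
+    "# improvement of few-shot 5% fine-tuning over zero-shot (assumes the loop above has run).\n",
+    "import pandas as pd\n",
+    "\n",
+    "saved = pd.read_csv(f\"{OUT_DIR}/results_zero_few.csv\", index_col=0)\n",
+    "saved[\"fs5_vs_zs_pct\"] = 100 * (saved[\"zs_mse\"] - saved[\"fs5_mse\"]) / saved[\"zs_mse\"]\n",
+    "saved[[\"dataset\", \"zs_mse\", \"fs5_mse\", \"fs5_vs_zs_pct\"]]"
+   ]
+  },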
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_1536_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_1536_96.ipynb
new file mode 100644
index 00000000..19f40653
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_1536_96.ipynb
@@ -0,0 +1,2353 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+ "**Using TTM-1536-96 model.**\n",
+ "\n",
+ "Pre-trained TTM models will be fetched from the [Hugging Face TTM Model Repository](ibm-granite/granite-timeseries-ttm-r2).\n",
+ "\n",
+ "1. TTM-R1 pre-trained models can be found here: [TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024_96_v1\"`\n",
+ "2. TTM-R2 pre-trained models can be found here: [TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024-96-r2\"`\n",
+ " 3. For 1536-96 model set `TTM_MODEL_REVISION=\"1536-96-r2\"`\n",
+ "\n",
+ "Details about the revisions (R1 and R2) can be found [here](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)."
+ ]
+ },
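+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "*Illustrative sketch (not executed as part of this benchmark):* the revisions listed above are selected through the `revision` argument of `from_pretrained`, exactly as done in the benchmarking loop below. A minimal, self-contained example for the 1536-96 TTM-R2 branch follows; the variable name `ttm` is used only for illustration."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Minimal sketch: load a specific pre-trained TTM revision.\n",
+    "# The same from_pretrained(..., revision=...) call is used in the benchmarking loop below.\n",
+    "from tsfm_public import TinyTimeMixerForPrediction\n",
+    "\n",
+    "ttm = TinyTimeMixerForPrediction.from_pretrained(\n",
+    "    \"ibm-granite/granite-timeseries-ttm-r2\", revision=\"1536-96-r2\"\n",
+    ")"
+   ]
+  },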
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-10 07:15:38.441950: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-10 07:15:38.481580: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-10 07:15:39.205059: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "import warnings\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "from transformers.integrations import INTEGRATION_TO_CALLBACK\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions\n",
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "SEED = 42\n",
+ "set_seed(SEED)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 1536\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
+ "# Refer to the load_datasets() function\n",
+ "# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py\n",
+ "# to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm-r2_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Please provide the branch name properly based on context_len and forecast_len\n",
+ "hf_model_path = \"ibm-granite/granite-timeseries-ttm-r2\"\n",
+ "hf_model_branch = f\"{context_length}-{forecast_length}-r2\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: etth1, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 7009, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "87646c2e40c54efda572d0951f308a82",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "config.json: 0%| | 0.00/1.57k [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "3f9fbbef80c744659259bbf89636c232",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ "model.safetensors: 0%| | 0.00/12.3M [00:00, ?B/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3570095896720886, 'eval_model_preparation_time': 0.0024, 'eval_runtime': 1.9963, 'eval_samples_per_second': 1395.071, 'eval_steps_per_second': 22.041}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: etth1, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 260, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.000298364724028334\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.000298364724028334\n",
+ "Using learning rate = 0.000298364724028334\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 55/250 00:26 < 01:36, 2.02 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.612300 | \n",
+ " 0.655407 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.593900 | \n",
+ " 0.656050 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.519100 | \n",
+ " 0.656867 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.480800 | \n",
+ " 0.658155 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.431600 | \n",
+ " 0.659995 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.384700 | \n",
+ " 0.662317 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.355300 | \n",
+ " 0.668283 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.308800 | \n",
+ " 0.689046 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.265600 | \n",
+ " 0.715355 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.248000 | \n",
+ " 0.734134 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.228800 | \n",
+ " 0.771885 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0921787131916394 seconds, Total Train Time = 28.327817678451538\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3571341633796692, 'eval_runtime': 1.4299, 'eval_samples_per_second': 1947.631, 'eval_steps_per_second': 30.77, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: etth2, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 7009, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.2743358612060547, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 0.9901, 'eval_samples_per_second': 2812.989, 'eval_steps_per_second': 44.442}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: etth2, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 260, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 95/250 00:46 < 01:17, 2.00 it/s, Epoch 19/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.435300 | \n",
+ " 0.229630 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.388800 | \n",
+ " 0.230058 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.323200 | \n",
+ " 0.231052 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.385600 | \n",
+ " 0.232311 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.298400 | \n",
+ " 0.233664 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.243800 | \n",
+ " 0.234015 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.211400 | \n",
+ " 0.232407 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.186800 | \n",
+ " 0.228532 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.180600 | \n",
+ " 0.228105 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.137400 | \n",
+ " 0.232864 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.138800 | \n",
+ " 0.238103 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.124100 | \n",
+ " 0.240933 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.116900 | \n",
+ " 0.248530 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.107300 | \n",
+ " 0.249423 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.113300 | \n",
+ " 0.250719 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.103300 | \n",
+ " 0.255713 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.099300 | \n",
+ " 0.260282 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.097700 | \n",
+ " 0.261335 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.094600 | \n",
+ " 0.260480 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0002899671855725 seconds, Total Train Time = 47.401732206344604\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.27716049551963806, 'eval_runtime': 1.386, 'eval_samples_per_second': 2009.436, 'eval_steps_per_second': 31.747, 'epoch': 19.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 32929, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "1 etth2 0.274 0.277\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = ettm1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:03]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.32653480768203735, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 3.4627, 'eval_samples_per_second': 3299.436, 'eval_steps_per_second': 51.694}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: ettm1, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 1556, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00043287612810830566\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00043287612810830566\n",
+ "Using learning rate = 0.00043287612810830566\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:44 < 02:37, 6.20 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.715300 | \n",
+ " 0.400856 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.474800 | \n",
+ " 0.420347 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.359000 | \n",
+ " 0.452630 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.325600 | \n",
+ " 0.455598 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.297800 | \n",
+ " 0.474598 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.275900 | \n",
+ " 0.478588 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.262200 | \n",
+ " 0.467313 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.248000 | \n",
+ " 0.475465 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.234600 | \n",
+ " 0.459779 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.225500 | \n",
+ " 0.477715 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.215400 | \n",
+ " 0.466766 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.4104089736938477 seconds, Total Train Time = 44.97520208358765\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3312471807003021, 'eval_runtime': 2.5794, 'eval_samples_per_second': 4429.24, 'eval_steps_per_second': 69.395, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 32929, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "1 etth2 0.274 0.277\n",
+ "2 ettm1 0.327 0.331\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = ettm2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:03]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.16795998811721802, 'eval_model_preparation_time': 0.0018, 'eval_runtime': 3.518, 'eval_samples_per_second': 3247.549, 'eval_steps_per_second': 50.881}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: ettm2, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 1556, val = 11425, test = 11425\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00011768119524349978\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00011768119524349978\n",
+ "Using learning rate = 0.00011768119524349978\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 275/1250 00:44 < 02:37, 6.18 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.471800 | \n",
+ " 0.123267 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.337700 | \n",
+ " 0.124431 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.252800 | \n",
+ " 0.126874 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.176000 | \n",
+ " 0.131680 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.135800 | \n",
+ " 0.141091 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.117000 | \n",
+ " 0.147765 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.108400 | \n",
+ " 0.156903 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.103900 | \n",
+ " 0.162671 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.100000 | \n",
+ " 0.170844 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.097000 | \n",
+ " 0.176793 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.093700 | \n",
+ " 0.181367 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.3997448791157117 seconds, Total Train Time = 45.14118027687073\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [179/179 00:01]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.1680709272623062, 'eval_runtime': 2.6241, 'eval_samples_per_second': 4353.841, 'eval_steps_per_second': 68.213, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: weather, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 35256, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "1 etth2 0.274 0.277\n",
+ "2 ettm1 0.327 0.331\n",
+ "3 ettm2 0.168 0.168\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = weather, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:06]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.14976251125335693, 'eval_model_preparation_time': 0.0021, 'eval_runtime': 6.5327, 'eval_samples_per_second': 1598.717, 'eval_steps_per_second': 25.104}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: weather, context length: 1536, prediction length 96\n",
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 1672, val = 5175, test = 10444\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n",
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 297/1350 00:53 < 03:10, 5.53 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.097900 | \n",
+ " 0.393768 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.095700 | \n",
+ " 0.397849 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.092900 | \n",
+ " 0.404240 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.088900 | \n",
+ " 0.411644 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.084900 | \n",
+ " 0.410327 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.081700 | \n",
+ " 0.414159 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.077400 | \n",
+ " 0.414830 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.073400 | \n",
+ " 0.416132 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.068600 | \n",
+ " 0.428362 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.065000 | \n",
+ " 0.419456 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.062700 | \n",
+ " 0.418077 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.9716925404288552 seconds, Total Train Time = 54.368701219558716\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [164/164 00:03]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.14924383163452148, 'eval_runtime': 4.6955, 'eval_samples_per_second': 2224.257, 'eval_steps_per_second': 34.927, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: electricity, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "1 etth2 0.274 0.277\n",
+ "2 ettm1 0.327 0.331\n",
+ "3 ettm2 0.168 0.168\n",
+ "4 weather 0.150 0.149\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = electricity, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 16781, val = 2537, test = 5165\n",
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:34]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.15529614686965942, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 34.5318, 'eval_samples_per_second': 149.572, 'eval_steps_per_second': 4.691}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: electricity, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 748, val = 2537, test = 5165\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00020565123083486514\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00020565123083486514\n",
+ "Using learning rate = 0.00020565123083486514\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [1104/1200 16:02 < 01:23, 1.15 it/s, Epoch 46/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.143700 | \n",
+ " 0.129405 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.140000 | \n",
+ " 0.127710 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.137600 | \n",
+ " 0.126163 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.135500 | \n",
+ " 0.124611 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.133200 | \n",
+ " 0.123532 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.130900 | \n",
+ " 0.122066 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.129100 | \n",
+ " 0.121844 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.127300 | \n",
+ " 0.120507 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.125600 | \n",
+ " 0.119225 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.123300 | \n",
+ " 0.119105 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.121300 | \n",
+ " 0.117542 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.120600 | \n",
+ " 0.117430 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.118600 | \n",
+ " 0.116615 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.118100 | \n",
+ " 0.117184 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.116800 | \n",
+ " 0.115890 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.116400 | \n",
+ " 0.116175 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.115300 | \n",
+ " 0.115326 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.114800 | \n",
+ " 0.114901 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.114000 | \n",
+ " 0.114714 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.112800 | \n",
+ " 0.114350 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.112600 | \n",
+ " 0.114116 | \n",
+ "
\n",
+ " \n",
+ " 22 | \n",
+ " 0.111900 | \n",
+ " 0.113912 | \n",
+ "
\n",
+ " \n",
+ " 23 | \n",
+ " 0.111600 | \n",
+ " 0.113825 | \n",
+ "
\n",
+ " \n",
+ " 24 | \n",
+ " 0.111300 | \n",
+ " 0.113824 | \n",
+ "
\n",
+ " \n",
+ " 25 | \n",
+ " 0.111200 | \n",
+ " 0.113436 | \n",
+ "
\n",
+ " \n",
+ " 26 | \n",
+ " 0.110900 | \n",
+ " 0.113308 | \n",
+ "
\n",
+ " \n",
+ " 27 | \n",
+ " 0.110400 | \n",
+ " 0.114188 | \n",
+ "
\n",
+ " \n",
+ " 28 | \n",
+ " 0.110100 | \n",
+ " 0.113093 | \n",
+ "
\n",
+ " \n",
+ " 29 | \n",
+ " 0.109600 | \n",
+ " 0.113151 | \n",
+ "
\n",
+ " \n",
+ " 30 | \n",
+ " 0.109400 | \n",
+ " 0.113164 | \n",
+ "
\n",
+ " \n",
+ " 31 | \n",
+ " 0.109200 | \n",
+ " 0.113394 | \n",
+ "
\n",
+ " \n",
+ " 32 | \n",
+ " 0.109000 | \n",
+ " 0.113235 | \n",
+ "
\n",
+ " \n",
+ " 33 | \n",
+ " 0.109100 | \n",
+ " 0.113087 | \n",
+ "
\n",
+ " \n",
+ " 34 | \n",
+ " 0.108800 | \n",
+ " 0.113114 | \n",
+ "
\n",
+ " \n",
+ " 35 | \n",
+ " 0.108800 | \n",
+ " 0.112943 | \n",
+ "
\n",
+ " \n",
+ " 36 | \n",
+ " 0.108400 | \n",
+ " 0.112619 | \n",
+ "
\n",
+ " \n",
+ " 37 | \n",
+ " 0.108100 | \n",
+ " 0.113038 | \n",
+ "
\n",
+ " \n",
+ " 38 | \n",
+ " 0.107900 | \n",
+ " 0.113113 | \n",
+ "
\n",
+ " \n",
+ " 39 | \n",
+ " 0.108100 | \n",
+ " 0.112789 | \n",
+ "
\n",
+ " \n",
+ " 40 | \n",
+ " 0.108500 | \n",
+ " 0.112672 | \n",
+ "
\n",
+ " \n",
+ " 41 | \n",
+ " 0.108100 | \n",
+ " 0.112766 | \n",
+ "
\n",
+ " \n",
+ " 42 | \n",
+ " 0.107800 | \n",
+ " 0.112637 | \n",
+ "
\n",
+ " \n",
+ " 43 | \n",
+ " 0.108100 | \n",
+ " 0.112631 | \n",
+ "
\n",
+ " \n",
+ " 44 | \n",
+ " 0.107500 | \n",
+ " 0.112633 | \n",
+ "
\n",
+ " \n",
+ " 45 | \n",
+ " 0.108000 | \n",
+ " 0.112636 | \n",
+ "
\n",
+ " \n",
+ " 46 | \n",
+ " 0.107900 | \n",
+ " 0.112627 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 6.774247827737228 seconds, Total Train Time = 964.9241693019867\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [162/162 00:25]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.13803862035274506, 'eval_runtime': 26.8199, 'eval_samples_per_second': 192.581, 'eval_steps_per_second': 6.04, 'epoch': 46.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: traffic, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "1 etth2 0.274 0.277\n",
+ "2 ettm1 0.327 0.331\n",
+ "3 ettm2 0.168 0.168\n",
+ "4 weather 0.150 0.149\n",
+ "5 electricity 0.155 0.138\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-1536 on dataset = traffic, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/1536-96-r2\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 10649, val = 1661, test = 3413\n",
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 01:02]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.4634234607219696, 'eval_model_preparation_time': 0.0019, 'eval_runtime': 62.6042, 'eval_samples_per_second': 54.517, 'eval_steps_per_second': 6.821}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Dataset name: traffic, context length: 1536, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:data_handling.py:load_dataset:Data lengths: train = 442, val = 1661, test = 3413\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 3081120\n",
+ "Number of params after freezing the backbone 1054560\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048562:t-22362052518656:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 5.590810182512223e-05\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 5.590810182512223e-05\n",
+ "Using learning rate = 5.590810182512223e-05\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048562:t-22362052518656:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [ 616/2800 06:03 < 21:32, 1.69 it/s, Epoch 11/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.298800 | \n",
+ " 0.391451 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.286600 | \n",
+ " 0.393831 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.275300 | \n",
+ " 0.394873 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.264500 | \n",
+ " 0.396028 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.253800 | \n",
+ " 0.400728 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.245400 | \n",
+ " 0.404270 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.237900 | \n",
+ " 0.408866 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.231600 | \n",
+ " 0.409725 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.226200 | \n",
+ " 0.410739 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.221300 | \n",
+ " 0.412317 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.216800 | \n",
+ " 0.414294 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 9.786083113063466 seconds, Total Train Time = 365.34510469436646\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [427/427 00:44]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.46613699197769165, 'eval_runtime': 46.0615, 'eval_samples_per_second': 74.097, 'eval_steps_per_second': 9.27, 'epoch': 11.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n",
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.357 0.357\n",
+ "1 etth2 0.274 0.277\n",
+ "2 ettm1 0.327 0.331\n",
+ "3 ettm2 0.168 0.168\n",
+ "4 weather 0.150 0.149\n",
+ "5 electricity 0.155 0.138\n",
+ "6 traffic 0.463 0.466\n"
+ ]
+ }
+ ],
+ "source": [
+ "all_results = {\n",
+ " \"dataset\": [],\n",
+ " \"zs_mse\": [],\n",
+ " \"fs5_mse\": [],\n",
+ " \"zs_eval_time\": [],\n",
+ " \"fs5_mean_epoch_time\": [],\n",
+ " \"fs5_total_train_time\": [],\n",
+ " \"fs5_best_val_metric\": [],\n",
+ "}\n",
+ "# Loop over data\n",
+ "for DATASET in list_datasets:\n",
+ " print()\n",
+ " print(\"=\" * 100)\n",
+ " print(\n",
+ " f\"Running zero-shot/few-shot for TTM-{context_length} on dataset = {DATASET}, forecast_len = {forecast_length}\"\n",
+ " )\n",
+ " print(f\"Model will be loaded from {hf_model_path}/{hf_model_branch}\")\n",
+ " SUBDIR = f\"{OUT_DIR}/{DATASET}\"\n",
+ "\n",
+ " # Set batch size\n",
+ " if DATASET == \"traffic\":\n",
+ " BATCH_SIZE = 8\n",
+ " elif DATASET == \"electricity\":\n",
+ " BATCH_SIZE = 32\n",
+ " else:\n",
+ " BATCH_SIZE = 64\n",
+ "\n",
+ " # Data prep: Get dataset\n",
+ " _, _, dset_test = load_dataset(DATASET, context_length, forecast_length, dataset_root_path=DATA_ROOT_PATH)\n",
+ "\n",
+ " #############################################################\n",
+ " ##### Use the pretrained model in zero-shot forecasting #####\n",
+ " #############################################################\n",
+ " # Load model\n",
+ " zeroshot_model = TinyTimeMixerForPrediction.from_pretrained(hf_model_path, revision=hf_model_branch)\n",
+ "\n",
+ " # zeroshot_trainer\n",
+ " zeroshot_trainer = Trainer(\n",
+ " model=zeroshot_model,\n",
+ " args=TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/zeroshot\",\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " seed=SEED,\n",
+ " ),\n",
+ " eval_dataset=dset_test,\n",
+ " )\n",
+ "\n",
+ " # evaluate = zero-shot performance\n",
+ " print(\"+\" * 20, \"Test MSE zero-shot\", \"+\" * 20)\n",
+ " zeroshot_output = zeroshot_trainer.evaluate(dset_test)\n",
+ " print(zeroshot_output)\n",
+ " print(\"+\" * 60)\n",
+ " all_results[\"zs_eval_time\"].append(zeroshot_output[\"eval_runtime\"])\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=zeroshot_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=\"test_zeroshot\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[\"dataset\"].append(DATASET)\n",
+ " all_results[\"zs_mse\"].append(zeroshot_output[\"eval_loss\"])\n",
+ "\n",
+ " ################################################################\n",
+ " ## Use the pretrained model in few-shot 5% and 10% forecasting #\n",
+ " ################################################################\n",
+ " for fewshot_percent in [5]:\n",
+ " # Set learning rate\n",
+ " learning_rate = None # `None` value indicates that the optimal_lr_finder() will be used\n",
+ "\n",
+ " print(\"-\" * 20, f\"Running few-shot {fewshot_percent}%\", \"-\" * 20)\n",
+ " # Data prep: Get dataset\n",
+ " dset_train, dset_val, dset_test = load_dataset(\n",
+ " DATASET,\n",
+ " context_length,\n",
+ " forecast_length,\n",
+ " fewshot_fraction=fewshot_percent / 100,\n",
+ " dataset_root_path=DATA_ROOT_PATH,\n",
+ " )\n",
+ "\n",
+ " # change head dropout to 0.7 for ett datasets\n",
+ " if \"ett\" in DATASET:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch, head_dropout=0.7\n",
+ " )\n",
+ " else:\n",
+ " finetune_forecast_model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " hf_model_path, revision=hf_model_branch\n",
+ " )\n",
+ "\n",
+ " if freeze_backbone:\n",
+ " print(\n",
+ " \"Number of params before freezing backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " # Freeze the backbone of the model\n",
+ " for param in finetune_forecast_model.backbone.parameters():\n",
+ " param.requires_grad = False\n",
+ "\n",
+ " # Count params\n",
+ " print(\n",
+ " \"Number of params after freezing the backbone\",\n",
+ " count_parameters(finetune_forecast_model),\n",
+ " )\n",
+ "\n",
+ " if learning_rate is None:\n",
+ " learning_rate, finetune_forecast_model = optimal_lr_finder(\n",
+ " finetune_forecast_model,\n",
+ " dset_train,\n",
+ " batch_size=BATCH_SIZE,\n",
+ " )\n",
+ " print(\"OPTIMAL SUGGESTED LEARNING RATE =\", learning_rate)\n",
+ "\n",
+ " print(f\"Using learning rate = {learning_rate}\")\n",
+ " finetune_forecast_args = TrainingArguments(\n",
+ " output_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\",\n",
+ " overwrite_output_dir=True,\n",
+ " learning_rate=learning_rate,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " do_eval=True,\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " dataloader_num_workers=NUM_WORKERS,\n",
+ " report_to=None,\n",
+ " save_strategy=\"epoch\",\n",
+ " logging_strategy=\"epoch\",\n",
+ " save_total_limit=1,\n",
+ " logging_dir=f\"{SUBDIR}/fewshot_{fewshot_percent}\", # Make sure to specify a logging directory\n",
+ " load_best_model_at_end=True, # Load the best model when training ends\n",
+ " metric_for_best_model=\"eval_loss\", # Metric to monitor for early stopping\n",
+ " greater_is_better=False, # For loss\n",
+ " seed=SEED,\n",
+ " )\n",
+ "\n",
+ " # Create the early stopping callback\n",
+ " early_stopping_callback = EarlyStoppingCallback(\n",
+ " early_stopping_patience=10, # Number of epochs with no improvement after which to stop\n",
+ " early_stopping_threshold=0.0, # Minimum improvement required to consider as improvement\n",
+ " )\n",
+ " tracking_callback = TrackingCallback()\n",
+ "\n",
+ " # Optimizer and scheduler\n",
+ " optimizer = AdamW(finetune_forecast_model.parameters(), lr=learning_rate)\n",
+ " scheduler = OneCycleLR(\n",
+ " optimizer,\n",
+ " learning_rate,\n",
+ " epochs=EPOCHS,\n",
+ " steps_per_epoch=math.ceil(len(dset_train) / (BATCH_SIZE)),\n",
+ " )\n",
+ "\n",
+ " finetune_forecast_trainer = Trainer(\n",
+ " model=finetune_forecast_model,\n",
+ " args=finetune_forecast_args,\n",
+ " train_dataset=dset_train,\n",
+ " eval_dataset=dset_val,\n",
+ " callbacks=[early_stopping_callback, tracking_callback],\n",
+ " optimizers=(optimizer, scheduler),\n",
+ " )\n",
+ " finetune_forecast_trainer.remove_callback(INTEGRATION_TO_CALLBACK[\"codecarbon\"])\n",
+ "\n",
+ " # Fine tune\n",
+ " finetune_forecast_trainer.train()\n",
+ "\n",
+ " # Evaluation\n",
+ " print(\n",
+ " \"+\" * 20,\n",
+ " f\"Test MSE after few-shot {fewshot_percent}% fine-tuning\",\n",
+ " \"+\" * 20,\n",
+ " )\n",
+ " fewshot_output = finetune_forecast_trainer.evaluate(dset_test)\n",
+ " print(fewshot_output)\n",
+ " print(\"+\" * 60)\n",
+ "\n",
+ " # Plot\n",
+ " plot_predictions(\n",
+ " model=finetune_forecast_trainer.model,\n",
+ " dset=dset_test,\n",
+ " plot_dir=SUBDIR,\n",
+ " num_plots=10,\n",
+ " plot_prefix=f\"test_fewshot_{fewshot_percent}\",\n",
+ " channel=0,\n",
+ " )\n",
+ " plt.close()\n",
+ "\n",
+ " # write results\n",
+ " all_results[f\"fs{fewshot_percent}_mse\"].append(fewshot_output[\"eval_loss\"])\n",
+ " all_results[f\"fs{fewshot_percent}_mean_epoch_time\"].append(tracking_callback.mean_epoch_time)\n",
+ " all_results[f\"fs{fewshot_percent}_total_train_time\"].append(tracking_callback.total_train_time)\n",
+ " all_results[f\"fs{fewshot_percent}_best_val_metric\"].append(tracking_callback.best_eval_metric)\n",
+ "\n",
+ " df_out = pd.DataFrame(all_results).round(3)\n",
+ " print(df_out[[\"dataset\", \"zs_mse\", \"fs5_mse\"]])\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")\n",
+ " df_out.to_csv(f\"{OUT_DIR}/results_zero_few.csv\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Benchmarking results*\n",
+ "\n",
+ "*Some slight differences in the results as compared to the TTM paper results is possible due to different training environments."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "\n",
+ "
\n",
+ " \n",
+ " \n",
+ " | \n",
+ " dataset | \n",
+ " zs_mse | \n",
+ " fs5_mse | \n",
+ " zs_eval_time | \n",
+ " fs5_mean_epoch_time | \n",
+ " fs5_total_train_time | \n",
+ " fs5_best_val_metric | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 0 | \n",
+ " etth1 | \n",
+ " 0.357 | \n",
+ " 0.357 | \n",
+ " 1.996 | \n",
+ " 1.092 | \n",
+ " 28.328 | \n",
+ " 0.655 | \n",
+ "
\n",
+ " \n",
+ " 1 | \n",
+ " etth2 | \n",
+ " 0.274 | \n",
+ " 0.277 | \n",
+ " 0.990 | \n",
+ " 1.000 | \n",
+ " 47.402 | \n",
+ " 0.228 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " ettm1 | \n",
+ " 0.327 | \n",
+ " 0.331 | \n",
+ " 3.463 | \n",
+ " 1.410 | \n",
+ " 44.975 | \n",
+ " 0.401 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " ettm2 | \n",
+ " 0.168 | \n",
+ " 0.168 | \n",
+ " 3.518 | \n",
+ " 1.400 | \n",
+ " 45.141 | \n",
+ " 0.123 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " weather | \n",
+ " 0.150 | \n",
+ " 0.149 | \n",
+ " 6.533 | \n",
+ " 1.972 | \n",
+ " 54.369 | \n",
+ " 0.394 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " electricity | \n",
+ " 0.155 | \n",
+ " 0.138 | \n",
+ " 34.532 | \n",
+ " 6.774 | \n",
+ " 964.924 | \n",
+ " 0.113 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " traffic | \n",
+ " 0.463 | \n",
+ " 0.466 | \n",
+ " 62.604 | \n",
+ " 9.786 | \n",
+ " 365.345 | \n",
+ " 0.391 | \n",
+ "
\n",
+ " \n",
+ "
\n",
+ "
"
+ ],
+ "text/plain": [
+ " dataset zs_mse fs5_mse zs_eval_time fs5_mean_epoch_time \\\n",
+ "0 etth1 0.357 0.357 1.996 1.092 \n",
+ "1 etth2 0.274 0.277 0.990 1.000 \n",
+ "2 ettm1 0.327 0.331 3.463 1.410 \n",
+ "3 ettm2 0.168 0.168 3.518 1.400 \n",
+ "4 weather 0.150 0.149 6.533 1.972 \n",
+ "5 electricity 0.155 0.138 34.532 6.774 \n",
+ "6 traffic 0.463 0.466 62.604 9.786 \n",
+ "\n",
+ " fs5_total_train_time fs5_best_val_metric \n",
+ "0 28.328 0.655 \n",
+ "1 47.402 0.228 \n",
+ "2 44.975 0.401 \n",
+ "3 45.141 0.123 \n",
+ "4 54.369 0.394 \n",
+ "5 964.924 0.113 \n",
+ "6 365.345 0.391 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "df_out"
+ ]
+ },
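+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "As a small convenience sketch (not part of the original benchmark run), the summary table written by the loop above can be reloaded from `results_zero_few.csv` for later inspection. `index_col=0` is assumed because `to_csv` was called with its default settings, which write the DataFrame index as the first column.\n",
+ "\n",
+ "```python\n",
+ "import pandas as pd\n",
+ "\n",
+ "# Reload the benchmark summary saved by the loop above\n",
+ "df_saved = pd.read_csv(f\"{OUT_DIR}/results_zero_few.csv\", index_col=0)\n",
+ "print(df_saved[[\"dataset\", \"zs_mse\", \"fs5_mse\"]])\n",
+ "```"
+ ]
+ },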
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.12"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_512_96.ipynb b/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_512_96.ipynb
new file mode 100644
index 00000000..6b2217bd
--- /dev/null
+++ b/notebooks/hfdemo/tinytimemixer/ttm-r2_benchmarking_512_96.ipynb
@@ -0,0 +1,2579 @@
+{
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# TTM zero-shot and few-shot benchmarking on multiple datasets\n",
+ "\n",
+ "**Using TTM-512-96 model.**\n",
+ "\n",
+ "Pre-trained TTM models will be fetched from the [Hugging Face TTM Model Repository](ibm-granite/granite-timeseries-ttm-r2).\n",
+ "\n",
+ "1. TTM-R1 pre-trained models can be found here: [TTM-R1 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r1)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024_96_v1\"`\n",
+ "2. TTM-R2 pre-trained models can be found here: [TTM-R2 Model Card](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)\n",
+ " 1. For 512-96 model set `TTM_MODEL_REVISION=\"main\"`\n",
+ " 2. For 1024-96 model set `TTM_MODEL_REVISION=\"1024-96-r2\"`\n",
+ " 3. For 1536-96 model set `TTM_MODEL_REVISION=\"1536-96-r2\"`\n",
+ "\n",
+ "Details about the revisions (R1 and R2) can be found [here](https://huggingface.co/ibm-granite/granite-timeseries-ttm-r2)."
+ ]
+ },
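+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal loading sketch (for illustration only; this cell is not executed as part of the benchmark): the branch name listed above is passed as the `revision` argument of `TinyTimeMixerForPrediction.from_pretrained`, which is exactly how the benchmarking loop below loads the model. For example, the TTM-R2 1024-96 variant would be selected as follows.\n",
+ "\n",
+ "```python\n",
+ "from tsfm_public import TinyTimeMixerForPrediction\n",
+ "\n",
+ "# Select the desired variant by pointing `revision` at the corresponding branch name\n",
+ "model = TinyTimeMixerForPrediction.from_pretrained(\n",
+ " \"ibm-granite/granite-timeseries-ttm-r2\", revision=\"1024-96-r2\"\n",
+ ")\n",
+ "```"
+ ]
+ },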
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Imports"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2024-10-10 07:15:37.180528: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.\n",
+ "2024-10-10 07:15:37.217865: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n",
+ "To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+ "2024-10-10 07:15:37.936129: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Could not find TensorRT\n",
+ "/dccstor/dnn_forecasting/conda_envs/envs/fm/lib/python3.9/site-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: libtorch_cuda_cu.so: cannot open shared object file: No such file or directory\n",
+ " warn(f\"Failed to load image Python extension: {e}\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "import math\n",
+ "import warnings\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd\n",
+ "from torch.optim import AdamW\n",
+ "from torch.optim.lr_scheduler import OneCycleLR\n",
+ "from transformers import EarlyStoppingCallback, Trainer, TrainingArguments, set_seed\n",
+ "from transformers.integrations import INTEGRATION_TO_CALLBACK\n",
+ "\n",
+ "from tsfm_public import TinyTimeMixerForPrediction, TrackingCallback, count_parameters, load_dataset\n",
+ "from tsfm_public.toolkit.lr_finder import optimal_lr_finder\n",
+ "from tsfm_public.toolkit.visualization import plot_predictions\n",
+ "\n",
+ "\n",
+ "warnings.filterwarnings(\"ignore\")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Important arguments"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Set seed\n",
+ "SEED = 42\n",
+ "set_seed(SEED)\n",
+ "\n",
+ "# Specify model parameters\n",
+ "context_length = 512\n",
+ "forecast_length = 96\n",
+ "freeze_backbone = True\n",
+ "\n",
+ "# Other args\n",
+ "EPOCHS = 50\n",
+ "NUM_WORKERS = 16\n",
+ "\n",
+ "# Make sure all the datasets in the following `list_datasets` are\n",
+ "# saved in the `DATA_ROOT_PATH` folder. Or, change it accordingly.\n",
+ "# Refer to the load_datasets() function\n",
+ "# in notebooks/hfdemo/tinytimemixer/utils/ttm_utils.py\n",
+ "# to see how it is used.\n",
+ "DATA_ROOT_PATH = \"/dccstor/tsfm23/datasets/\"\n",
+ "\n",
+ "# This is where results will be saved\n",
+ "OUT_DIR = f\"ttm-r2_results_benchmark_{context_length}_{forecast_length}/\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## List of benchmark datasets (TTM was not pre-trained on any of these)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "list_datasets = [\n",
+ " \"etth1\",\n",
+ " \"etth2\",\n",
+ " \"ettm1\",\n",
+ " \"ettm2\",\n",
+ " \"weather\",\n",
+ " \"electricity\",\n",
+ " \"traffic\",\n",
+ "]"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Get model path"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Please provide the branch name properly based on context_len and forecast_len\n",
+ "hf_model_path = \"ibm-granite/granite-timeseries-ttm-r2\"\n",
+ "hf_model_branch = \"main\""
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Main benchmarking loop"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048548:t-23085973639936:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n",
+ "INFO:p-3048548:t-23085973639936:data_handling.py:load_dataset:Data lengths: train = 8033, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = etth1, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048548:t-23085973639936:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048548:t-23085973639936:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.3628121316432953, 'eval_model_preparation_time': 0.0025, 'eval_runtime': 1.5528, 'eval_samples_per_second': 1793.585, 'eval_steps_per_second': 28.337}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048548:t-23085973639936:data_handling.py:load_dataset:Dataset name: etth1, context length: 512, prediction length 96\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "-------------------- Running few-shot 5% --------------------\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048548:t-23085973639936:data_handling.py:load_dataset:Data lengths: train = 311, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Number of params before freezing backbone 805280\n",
+ "Number of params after freezing the backbone 289696\n",
+ "LR Finder: Running learning rate (LR) finder algorithm. If the suggested LR is very low, we suggest setting the LR manually.\n",
+ "LR Finder: Using GPU:0.\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048548:t-23085973639936:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048548:t-23085973639936:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "LR Finder: Suggested learning rate = 0.00017073526474706903\n",
+ "OPTIMAL SUGGESTED LEARNING RATE = 0.00017073526474706903\n",
+ "Using learning rate = 0.00017073526474706903\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [105/250 00:46 < 01:04, 2.23 it/s, Epoch 21/50]\n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " Epoch | \n",
+ " Training Loss | \n",
+ " Validation Loss | \n",
+ "
\n",
+ " \n",
+ " \n",
+ " \n",
+ " 1 | \n",
+ " 0.812700 | \n",
+ " 0.664259 | \n",
+ "
\n",
+ " \n",
+ " 2 | \n",
+ " 0.832200 | \n",
+ " 0.664153 | \n",
+ "
\n",
+ " \n",
+ " 3 | \n",
+ " 0.793800 | \n",
+ " 0.663970 | \n",
+ "
\n",
+ " \n",
+ " 4 | \n",
+ " 0.772200 | \n",
+ " 0.663760 | \n",
+ "
\n",
+ " \n",
+ " 5 | \n",
+ " 0.811200 | \n",
+ " 0.663474 | \n",
+ "
\n",
+ " \n",
+ " 6 | \n",
+ " 0.769200 | \n",
+ " 0.663127 | \n",
+ "
\n",
+ " \n",
+ " 7 | \n",
+ " 0.749600 | \n",
+ " 0.662820 | \n",
+ "
\n",
+ " \n",
+ " 8 | \n",
+ " 0.719200 | \n",
+ " 0.662412 | \n",
+ "
\n",
+ " \n",
+ " 9 | \n",
+ " 0.716600 | \n",
+ " 0.662103 | \n",
+ "
\n",
+ " \n",
+ " 10 | \n",
+ " 0.709800 | \n",
+ " 0.661821 | \n",
+ "
\n",
+ " \n",
+ " 11 | \n",
+ " 0.688200 | \n",
+ " 0.661788 | \n",
+ "
\n",
+ " \n",
+ " 12 | \n",
+ " 0.684000 | \n",
+ " 0.661836 | \n",
+ "
\n",
+ " \n",
+ " 13 | \n",
+ " 0.656500 | \n",
+ " 0.662756 | \n",
+ "
\n",
+ " \n",
+ " 14 | \n",
+ " 0.644400 | \n",
+ " 0.665279 | \n",
+ "
\n",
+ " \n",
+ " 15 | \n",
+ " 0.635400 | \n",
+ " 0.668289 | \n",
+ "
\n",
+ " \n",
+ " 16 | \n",
+ " 0.620200 | \n",
+ " 0.673172 | \n",
+ "
\n",
+ " \n",
+ " 17 | \n",
+ " 0.621300 | \n",
+ " 0.674407 | \n",
+ "
\n",
+ " \n",
+ " 18 | \n",
+ " 0.611000 | \n",
+ " 0.674804 | \n",
+ "
\n",
+ " \n",
+ " 19 | \n",
+ " 0.604800 | \n",
+ " 0.676390 | \n",
+ "
\n",
+ " \n",
+ " 20 | \n",
+ " 0.611800 | \n",
+ " 0.678676 | \n",
+ "
\n",
+ " \n",
+ " 21 | \n",
+ " 0.597600 | \n",
+ " 0.681201 | \n",
+ "
\n",
+ " \n",
+ "
"
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[TrackingCallback] Mean Epoch Time = 1.0237201736086892 seconds, Total Train Time = 48.076265811920166\n",
+ "++++++++++++++++++++ Test MSE after few-shot 5% fine-tuning ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ " \n",
+ " \n",
+ "
\n",
+ " [44/44 00:00]\n",
+ "
\n",
+ " "
+ ],
+ "text/plain": [
+ ""
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "{'eval_loss': 0.36197009682655334, 'eval_runtime': 1.2626, 'eval_samples_per_second': 2205.783, 'eval_steps_per_second': 34.849, 'epoch': 21.0}\n",
+ "++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "INFO:p-3048548:t-23085973639936:data_handling.py:load_dataset:Dataset name: etth2, context length: 512, prediction length 96\n",
+ "INFO:p-3048548:t-23085973639936:data_handling.py:load_dataset:Data lengths: train = 8033, val = 2785, test = 2785\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ " dataset zs_mse fs5_mse\n",
+ "0 etth1 0.363 0.362\n",
+ "\n",
+ "====================================================================================================\n",
+ "Running zero-shot/few-shot for TTM-512 on dataset = etth2, forecast_len = 96\n",
+ "Model will be loaded from ibm-granite/granite-timeseries-ttm-r2/main\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "WARNING:p-3048548:t-23085973639936:other.py:check_os_kernel:Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n",
+ "INFO:p-3048548:t-23085973639936:base.py:add_job:Adding job tentatively -- it will be properly scheduled when the scheduler starts\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "++++++++++++++++++++ Test MSE zero-shot ++++++++++++++++++++\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "\n",
+ "