{ "cells": [ { "cell_type": "markdown", "id": "e8a9f0b1", "metadata": {}, "source": [ "# Running TempoPFN on GIFT-Eval Benchmark\n", "\n", "This notebook evaluates the **TempoPFN** model on the GIFT-Eval benchmark. \n", "\n", "Make sure you download the gift-eval benchmark and set the `GIFT_EVAL_DATASET_STORAGE_PATH` environment variable correctly before running this notebook." ] }, { "cell_type": "markdown", "id": "f1d2e3c4", "metadata": {}, "source": [ "## 1. Setup and Dependencies\n", "\n", "First, install the required packages. \n", "\n", "**Note:** This notebook assumes that the core `TempoPFN` model code (e.g., `src.models.model`, `src.data.containers`) and dependencies are installed as a Python package or are otherwise available in the `PYTHONPATH`." ] }, { "cell_type": "markdown", "id": "b9c8d7e6", "metadata": {}, "source": [ "## 2. Imports\n", "\n", "Import all necessary libraries. " ] }, { "cell_type": "code", "execution_count": null, "id": "c7d8e9f0", "metadata": {}, "outputs": [], "source": [ "import csv\n", "import glob\n", "import json\n", "import logging\n", "import math\n", "import os\n", "import warnings\n", "from collections.abc import Iterable, Iterator\n", "from dataclasses import dataclass\n", "from enum import Enum\n", "from functools import cached_property\n", "from pathlib import Path\n", "\n", "# GluonTS and Data Handling\n", "import datasets\n", "\n", "# Plotting and Warnings\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import pandas as pd\n", "import pyarrow.compute as pc\n", "import torch\n", "import yaml\n", "from dotenv import load_dotenv\n", "from gluonts.dataset import DataEntry\n", "from gluonts.dataset.common import ProcessDataEntry\n", "from gluonts.dataset.split import TestData, TrainingDataset, split\n", "\n", "# GluonTS Evaluation\n", "from gluonts.ev.metrics import (\n", " MAE,\n", " MAPE,\n", " MASE,\n", " MSE,\n", " MSIS,\n", " ND,\n", " NRMSE,\n", " RMSE,\n", " SMAPE,\n", " MeanWeightedSumQuantileLoss,\n", ")\n", "from gluonts.itertools import Map\n", "from gluonts.model.evaluation import evaluate_model\n", "from gluonts.model.forecast import QuantileForecast\n", "from gluonts.model.predictor import Predictor\n", "from gluonts.time_feature import get_seasonality, norm_freq_str\n", "from gluonts.transform import Transformation\n", "from linear_operator.utils.cholesky import NumericalWarning\n", "from pandas.tseries.frequencies import to_offset\n", "\n", "# --- TempoPFN Core Model Imports ---\n", "# These are assumed to be installed or in the PYTHONPATH\n", "from src.data.containers import BatchTimeSeriesContainer\n", "from src.data.frequency import parse_frequency\n", "from src.data.scalers import RobustScaler\n", "from src.models.model import TimeSeriesModel\n", "from src.utils.utils import device\n", "from toolz import compose\n", "from torch.nn.parallel import DistributedDataParallel as DDP\n", "\n", "# --- Setup Logging ---\n", "logging.basicConfig(level=logging.INFO, format=\"%(asctime)s - %(levelname)s - %(message)s\")\n", "logging.getLogger(\"matplotlib\").setLevel(logging.WARNING)\n", "logging.getLogger(\"matplotlib.font_manager\").setLevel(logging.WARNING)\n", "logging.getLogger(\"PIL\").setLevel(logging.WARNING)\n", "logger = logging.getLogger(\"gift_eval_runner\")\n", "\n", "\n", "# Filter out specific gluonts warnings\n", "class WarningFilter(logging.Filter):\n", " def __init__(self, text_to_filter: str) -> None:\n", " super().__init__()\n", " self.text_to_filter = text_to_filter\n", "\n", " def filter(self, record: 
logging.LogRecord) -> bool:\n", " return self.text_to_filter not in record.getMessage()\n", "\n", "\n", "gts_logger = logging.getLogger(\"gluonts.model.forecast\")\n", "gts_logger.addFilter(WarningFilter(\"The mean prediction is not stored in the forecast data\"))\n", "\n", "# Filter out numerical warnings\n", "warnings.filterwarnings(\"ignore\", category=NumericalWarning)\n", "warnings.filterwarnings(\"ignore\", category=FutureWarning)\n", "warnings.filterwarnings(\"ignore\", category=DeprecationWarning)\n", "\n", "# Load environment variables (e.g., GIFT_EVAL_DATASET_STORAGE_PATH)\n", "load_dotenv()" ] }, { "cell_type": "markdown", "id": "d6e7f8a1", "metadata": {}, "source": [ "## 3. Constants and Configuration\n", "\n", "Define dataset lists, metrics, and other constants following GIFT-Eval standards." ] }, { "cell_type": "markdown", "id": "g4h5j6k7", "metadata": {}, "source": [ "### 3.1. Constants " ] }, { "cell_type": "code", "execution_count": null, "id": "h5j6k7l8", "metadata": {}, "outputs": [], "source": [ "# Environment setup\n", "os.environ[\"CUBLAS_WORKSPACE_CONFIG\"] = \":4096:8\"\n", "\n", "# Use absolute path relative to the project root\n", "_MODULE_DIR = Path.cwd().parent.parent # Assumes notebook is in `examples/gift_eval/`\n", "DATASET_PROPERTIES_PATH = _MODULE_DIR / \"data\" / \"dataset_properties.json\"\n", "\n", "try:\n", " with open(DATASET_PROPERTIES_PATH) as f:\n", " DATASET_PROPERTIES = json.load(f)\n", "except Exception as exc: # pragma: no cover - logging path\n", " DATASET_PROPERTIES = {}\n", " logger.warning(\n", " \"Could not load dataset properties from %s: %s. Domain and num_variates will fall back to defaults.\",\n", " DATASET_PROPERTIES_PATH,\n", " exc,\n", " )\n", "\n", "# Datasets\n", "SHORT_DATASETS = (\n", " \"m4_yearly\",\n", " \"m4_quarterly\",\n", " \"m4_monthly\",\n", " \"m4_weekly\",\n", " \"m4_daily\",\n", " \"m4_hourly\",\n", " \"electricity/15T\",\n", " \"electricity/H\",\n", " \"electricity/D\",\n", " \"electricity/W\",\n", " \"solar/10T\",\n", " \"solar/H\",\n", " \"solar/D\",\n", " \"solar/W\",\n", " \"hospital\",\n", " \"covid_deaths\",\n", " \"us_births/D\",\n", " \"us_births/M\",\n", " \"us_births/W\",\n", " \"saugeenday/D\",\n", " \"saugeenday/M\",\n", " \"saugeenday/W\",\n", " \"temperature_rain_with_missing\",\n", " \"kdd_cup_2018_with_missing/H\",\n", " \"kdd_cup_2018_with_missing/D\",\n", " \"car_parts_with_missing\",\n", " \"restaurant\",\n", " \"hierarchical_sales/D\",\n", " \"hierarchical_sales/W\",\n", " \"LOOP_SEATTLE/5T\",\n", " \"LOOP_SEATTLE/H\",\n", " \"LOOP_SEATTLE/D\",\n", " \"SZ_TAXI/15T\",\n", " \"SZ_TAXI/H\",\n", " \"M_DENSE/H\",\n", " \"M_DENSE/D\",\n", " \"ett1/15T\",\n", " \"ett1/H\",\n", " \"ett1/D\",\n", " \"ett1/W\",\n", " \"ett2/15T\",\n", " \"ett2/H\",\n", " \"ett2/D\",\n", " \"ett2/W\",\n", " \"jena_weather/10T\",\n", " \"jena_weather/H\",\n", " \"jena_weather/D\",\n", " \"bitbrains_fast_storage/5T\",\n", " \"bitbrains_fast_storage/H\",\n", " \"bitbrains_rnd/5T\",\n", " \"bitbrains_rnd/H\",\n", " \"bizitobs_application\",\n", " \"bizitobs_service\",\n", " \"bizitobs_l2c/5T\",\n", " \"bizitobs_l2c/H\",\n", ")\n", "\n", "MED_LONG_DATASETS = (\n", " \"electricity/15T\",\n", " \"electricity/H\",\n", " \"solar/10T\",\n", " \"solar/H\",\n", " \"kdd_cup_2018_with_missing/H\",\n", " \"LOOP_SEATTLE/5T\",\n", " \"LOOP_SEATTLE/H\",\n", " \"SZ_TAXI/15T\",\n", " \"M_DENSE/H\",\n", " \"ett1/15T\",\n", " \"ett1/H\",\n", " \"ett2/15T\",\n", " \"ett2/H\",\n", " \"jena_weather/10T\",\n", " \"jena_weather/H\",\n", " 
\"bitbrains_fast_storage/5T\",\n", " \"bitbrains_rnd/5T\",\n", " \"bizitobs_application\",\n", " \"bizitobs_service\",\n", " \"bizitobs_l2c/5T\",\n", " \"bizitobs_l2c/H\",\n", ")\n", "\n", "# Preserve insertion order\n", "ALL_DATASETS = list(dict.fromkeys(SHORT_DATASETS + MED_LONG_DATASETS))\n", "\n", "# Evaluation terms\n", "TERMS = (\"short\", \"medium\", \"long\")\n", "\n", "# Pretty names mapping\n", "PRETTY_NAMES = {\n", " \"saugeenday\": \"saugeen\",\n", " \"temperature_rain_with_missing\": \"temperature_rain\",\n", " \"kdd_cup_2018_with_missing\": \"kdd_cup_2018\",\n", " \"car_parts_with_missing\": \"car_parts\",\n", "}\n", "\n", "# Metrics\n", "METRICS = (\n", " MSE(forecast_type=\"mean\"),\n", " MSE(forecast_type=0.5),\n", " MAE(),\n", " MASE(),\n", " MAPE(),\n", " SMAPE(),\n", " MSIS(),\n", " RMSE(),\n", " NRMSE(),\n", " ND(),\n", " MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),\n", ")\n", "\n", "# Standard metric names for CSV header\n", "STANDARD_METRIC_NAMES = (\n", " \"MSE[mean]\",\n", " \"MSE[0.5]\",\n", " \"MAE[0.5]\",\n", " \"MASE[0.5]\",\n", " \"MAPE[0.5]\",\n", " \"sMAPE[0.5]\",\n", " \"MSIS\",\n", " \"RMSE[mean]\",\n", " \"NRMSE[mean]\",\n", " \"ND[0.5]\",\n", " \"mean_weighted_sum_quantile_loss\",\n", ")" ] }, { "cell_type": "markdown", "id": "i7j8k9l0", "metadata": {}, "source": [ "### 3.2. Core Data Structures " ] }, { "cell_type": "code", "execution_count": null, "id": "j8k9l0m1", "metadata": {}, "outputs": [], "source": [ "@dataclass\n", "class DatasetMetadata:\n", " \"\"\"Structured description of a dataset/term combination.\"\"\"\n", "\n", " full_name: str\n", " key: str\n", " freq: str\n", " term: str\n", " season_length: int\n", " target_dim: int\n", " to_univariate: bool\n", " prediction_length: int\n", " windows: int\n", "\n", "\n", "@dataclass\n", "class EvaluationItem:\n", " \"\"\"Container for evaluation results and optional figures.\"\"\"\n", "\n", " dataset_metadata: DatasetMetadata\n", " metrics: dict\n", " figures: list[tuple[object, str]]\n", "\n", "\n", "DatasetSelection = list[str] | tuple[str, ...] | str\n", "\n", "\n", "def expand_datasets_arg(datasets: DatasetSelection) -> list[str]:\n", " \"\"\"Normalize dataset selection strings to explicit lists.\"\"\"\n", "\n", " if isinstance(datasets, str):\n", " dataset_list = [datasets]\n", " else:\n", " dataset_list = list(datasets)\n", "\n", " if not dataset_list:\n", " return []\n", "\n", " if dataset_list[0] == \"all\":\n", " return list(ALL_DATASETS)\n", "\n", " for dataset in dataset_list:\n", " if dataset not in ALL_DATASETS:\n", " raise ValueError(f\"Invalid dataset: {dataset}. Use one of {ALL_DATASETS}\")\n", "\n", " return dataset_list" ] }, { "cell_type": "markdown", "id": "k9l0m1n2", "metadata": {}, "source": [ "### 3.3. GIFT-Eval Dataset Class (`data.py`)\n", "\n", "The `Dataset` class handles loading and preprocessing GIFT-Eval benchmark datasets. This implementation is adapted from the official GIFT-Eval repository." 
] }, { "cell_type": "code", "execution_count": null, "id": "l0m1n2o3", "metadata": {}, "outputs": [], "source": [ "TEST_SPLIT = 0.1\n", "MAX_WINDOW = 20\n", "\n", "M4_PRED_LENGTH_MAP = {\n", " \"A\": 6,\n", " \"Q\": 8,\n", " \"M\": 18,\n", " \"W\": 13,\n", " \"D\": 14,\n", " \"H\": 48,\n", " \"h\": 48,\n", " \"Y\": 6,\n", "}\n", "\n", "PRED_LENGTH_MAP = {\n", " \"M\": 12,\n", " \"W\": 8,\n", " \"D\": 30,\n", " \"H\": 48,\n", " \"h\": 48,\n", " \"T\": 48,\n", " \"S\": 60,\n", " \"s\": 60,\n", " \"min\": 48,\n", "}\n", "\n", "TFB_PRED_LENGTH_MAP = {\n", " \"A\": 6,\n", " \"Y\": 6,\n", " \"H\": 48,\n", " \"h\": 48,\n", " \"Q\": 8,\n", " \"D\": 14,\n", " \"M\": 18,\n", " \"W\": 13,\n", " \"U\": 8,\n", " \"T\": 8,\n", " \"min\": 8,\n", " \"us\": 8,\n", "}\n", "\n", "\n", "class Term(Enum):\n", " SHORT = \"short\"\n", " MEDIUM = \"medium\"\n", " LONG = \"long\"\n", "\n", " @property\n", " def multiplier(self) -> int:\n", " if self == Term.SHORT:\n", " return 1\n", " elif self == Term.MEDIUM:\n", " return 10\n", " elif self == Term.LONG:\n", " return 15\n", "\n", "\n", "def itemize_start(data_entry: DataEntry) -> DataEntry:\n", " data_entry[\"start\"] = data_entry[\"start\"].item()\n", " return data_entry\n", "\n", "\n", "class MultivariateToUnivariate(Transformation):\n", " def __init__(self, field):\n", " self.field = field\n", "\n", " def __call__(self, data_it: Iterable[DataEntry], is_train: bool = False) -> Iterator:\n", " for data_entry in data_it:\n", " item_id = data_entry[\"item_id\"]\n", " val_ls = list(data_entry[self.field])\n", " for id, val in enumerate(val_ls):\n", " univariate_entry = data_entry.copy()\n", " univariate_entry[self.field] = val\n", " univariate_entry[\"item_id\"] = item_id + \"_dim\" + str(id)\n", " yield univariate_entry\n", "\n", "\n", "class Dataset:\n", " def __init__(\n", " self,\n", " name: str,\n", " term: Term | str = Term.SHORT,\n", " to_univariate: bool = False,\n", " storage_path: str = None,\n", " max_windows: int | None = None,\n", " ):\n", " storage_path = Path(storage_path)\n", " self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format(\"numpy\")\n", " process = ProcessDataEntry(\n", " self.freq,\n", " one_dim_target=self.target_dim == 1,\n", " )\n", "\n", " self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)\n", " if to_univariate:\n", " self.gluonts_dataset = MultivariateToUnivariate(\"target\").apply(self.gluonts_dataset)\n", "\n", " self.term = Term(term)\n", " self.name = name\n", " self.max_windows = max_windows if max_windows is not None else MAX_WINDOW\n", "\n", " @cached_property\n", " def prediction_length(self) -> int:\n", " freq = norm_freq_str(to_offset(self.freq).name)\n", " if freq.endswith(\"E\"):\n", " freq = freq[:-1]\n", " pred_len = M4_PRED_LENGTH_MAP[freq] if \"m4\" in self.name else PRED_LENGTH_MAP[freq]\n", " return self.term.multiplier * pred_len\n", "\n", " @cached_property\n", " def freq(self) -> str:\n", " return self.hf_dataset[0][\"freq\"]\n", "\n", " @cached_property\n", " def target_dim(self) -> int:\n", " return target.shape[0] if len((target := self.hf_dataset[0][\"target\"]).shape) > 1 else 1\n", "\n", " @cached_property\n", " def past_feat_dynamic_real_dim(self) -> int:\n", " if \"past_feat_dynamic_real\" not in self.hf_dataset[0]:\n", " return 0\n", " elif len((past_feat_dynamic_real := self.hf_dataset[0][\"past_feat_dynamic_real\"]).shape) > 1:\n", " return past_feat_dynamic_real.shape[0]\n", " else:\n", " return 1\n", "\n", " @cached_property\n", " def windows(self) 
-> int:\n", " if \"m4\" in self.name:\n", " return 1\n", " w = math.ceil(TEST_SPLIT * self._min_series_length / self.prediction_length)\n", " return min(max(1, w), self.max_windows)\n", "\n", " @cached_property\n", " def _min_series_length(self) -> int:\n", " if self.hf_dataset[0][\"target\"].ndim > 1:\n", " lengths = pc.list_value_length(pc.list_flatten(pc.list_slice(self.hf_dataset.data.column(\"target\"), 0, 1)))\n", " else:\n", " lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n", " return min(lengths.to_numpy())\n", "\n", " @cached_property\n", " def sum_series_length(self) -> int:\n", " if self.hf_dataset[0][\"target\"].ndim > 1:\n", " lengths = pc.list_value_length(pc.list_flatten(self.hf_dataset.data.column(\"target\")))\n", " else:\n", " lengths = pc.list_value_length(self.hf_dataset.data.column(\"target\"))\n", " return sum(lengths.to_numpy())\n", "\n", " @property\n", " def training_dataset(self) -> TrainingDataset:\n", " training_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1))\n", " return training_dataset\n", "\n", " @property\n", " def validation_dataset(self) -> TrainingDataset:\n", " validation_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)\n", " return validation_dataset\n", "\n", " @property\n", " def test_data(self) -> TestData:\n", " _, test_template = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)\n", " test_data = test_template.generate_instances(\n", " prediction_length=self.prediction_length,\n", " windows=self.windows,\n", " distance=self.prediction_length,\n", " )\n", " return test_data" ] }, { "cell_type": "markdown", "id": "m1n2o3p4", "metadata": {}, "source": [ "### 3.4. Predictor Wrapper (`predictor.py`)\n", "\n", "This is the model-specific `TimeSeriesPredictor` class for `TempoPFN`. It wraps the core `TimeSeriesModel` and adapts it to the `gluonts`-style `Predictor` interface, which expects a `.predict()` method." 
] }, { "cell_type": "code", "execution_count": null, "id": "n2o3p4q5", "metadata": {}, "outputs": [], "source": [ "class TimeSeriesPredictor(Predictor):\n", " \"\"\"Unified predictor for TimeSeriesModel supporting flexible construction.\"\"\"\n", "\n", " def __init__(\n", " self,\n", " model: TimeSeriesModel,\n", " config: dict,\n", " ds_prediction_length: int,\n", " ds_freq: str,\n", " batch_size: int = 32,\n", " max_context_length: int | None = None,\n", " debug: bool = False,\n", " ) -> None:\n", " # Dataset-specific context (can be updated per dataset/term)\n", " self.ds_prediction_length = ds_prediction_length\n", " self.ds_freq = ds_freq\n", " self.batch_size = batch_size\n", " self.max_context_length = max_context_length\n", " self.debug = debug\n", "\n", " # Persistent model/config (unwrap DDP if needed)\n", " self.model = model.module if isinstance(model, DDP) else model\n", " self.model.eval()\n", " self.config = config\n", "\n", " # Initialize scaler (using same type as model)\n", " scaler_type = self.config.get(\"TimeSeriesModel\", {}).get(\"scaler\", \"custom_robust\")\n", " epsilon = self.config.get(\"TimeSeriesModel\", {}).get(\"epsilon\", 1e-3)\n", " if scaler_type == \"custom_robust\":\n", " self.scaler = RobustScaler(epsilon=epsilon)\n", " else:\n", " raise ValueError(f\"Unsupported scaler type: {scaler_type}\")\n", "\n", " def set_dataset_context(\n", " self,\n", " prediction_length: int | None = None,\n", " freq: str | None = None,\n", " batch_size: int | None = None,\n", " max_context_length: int | None = None,\n", " ) -> None:\n", " \"\"\"Update lightweight dataset-specific attributes without reloading the model.\"\"\"\n", "\n", " if prediction_length is not None:\n", " self.ds_prediction_length = prediction_length\n", " if freq is not None:\n", " self.ds_freq = freq\n", " if batch_size is not None:\n", " self.batch_size = batch_size\n", " if max_context_length is not None:\n", " self.max_context_length = max_context_length\n", "\n", " @classmethod\n", " def from_model(\n", " cls,\n", " model: TimeSeriesModel,\n", " config: dict,\n", " ds_prediction_length: int,\n", " ds_freq: str,\n", " batch_size: int = 32,\n", " max_context_length: int | None = None,\n", " debug: bool = False,\n", " ) -> \"TimeSeriesPredictor\":\n", " return cls(\n", " model=model,\n", " config=config,\n", " ds_prediction_length=ds_prediction_length,\n", " ds_freq=ds_freq,\n", " batch_size=batch_size,\n", " max_context_length=max_context_length,\n", " debug=debug,\n", " )\n", "\n", " @classmethod\n", " def from_paths(\n", " cls,\n", " model_path: str,\n", " config_path: str,\n", " ds_prediction_length: int,\n", " ds_freq: str,\n", " batch_size: int = 32,\n", " max_context_length: int | None = None,\n", " debug: bool = False,\n", " ) -> \"TimeSeriesPredictor\":\n", " with open(config_path) as f:\n", " config = yaml.safe_load(f)\n", " model = cls._load_model_from_path(config=config, model_path=model_path)\n", " return cls(\n", " model=model,\n", " config=config,\n", " ds_prediction_length=ds_prediction_length,\n", " ds_freq=ds_freq,\n", " batch_size=batch_size,\n", " max_context_length=max_context_length,\n", " debug=debug,\n", " )\n", "\n", " @staticmethod\n", " def _load_model_from_path(config: dict, model_path: str) -> TimeSeriesModel:\n", " try:\n", " model = TimeSeriesModel(**config[\"TimeSeriesModel\"]).to(device)\n", " checkpoint = torch.load(model_path, map_location=device)\n", " model.load_state_dict(checkpoint[\"model_state_dict\"])\n", " model.eval()\n", " logger.info(f\"Successfully 
loaded model from {model_path}\")\n", " return model\n", " except Exception as exc: # pragma: no cover - logging path\n", " logger.error(f\"Failed to load model from {model_path}: {exc}\")\n", " raise\n", "\n", " def predict(self, test_data_input) -> Iterator[QuantileForecast]:\n", " \"\"\"Generate forecasts for the test data.\"\"\"\n", "\n", " if hasattr(test_data_input, \"__iter__\") and not isinstance(test_data_input, list):\n", " test_data_input = list(test_data_input)\n", " logger.debug(f\"Processing {len(test_data_input)} time series\")\n", "\n", " # Group series by their effective length (after optional truncation),\n", " # then process each uniform-length group in sub-batches up to batch_size.\n", " def _effective_length(entry) -> int:\n", " target = entry[\"target\"]\n", " if target.ndim == 1:\n", " seq_len = len(target)\n", " else:\n", " # target shape is [num_channels, seq_len]\n", " seq_len = target.shape[1]\n", " if self.max_context_length is not None:\n", " seq_len = min(seq_len, self.max_context_length)\n", " return seq_len\n", "\n", " length_to_items: dict[int, list[tuple[int, object]]] = {}\n", " for idx, entry in enumerate(test_data_input):\n", " seq_len = _effective_length(entry)\n", " length_to_items.setdefault(seq_len, []).append((idx, entry))\n", "\n", " total = len(test_data_input)\n", " ordered_results: list[QuantileForecast | None] = [None] * total\n", "\n", " for _, items in length_to_items.items():\n", " for i in range(0, len(items), self.batch_size):\n", " chunk = items[i : i + self.batch_size]\n", " entries = [entry for (_orig_idx, entry) in chunk]\n", " batch_forecasts = self._predict_batch(entries)\n", " for forecast_idx, (orig_idx, _entry) in enumerate(chunk):\n", " ordered_results[orig_idx] = batch_forecasts[forecast_idx]\n", "\n", " return ordered_results # type: ignore[return-value]\n", "\n", " def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:\n", " \"\"\"Generate predictions for a batch of time series.\"\"\"\n", "\n", " logger.debug(f\"Processing batch of size: {len(test_data_batch)}\")\n", "\n", " try:\n", " batch_container = self._convert_to_batch_container(test_data_batch)\n", "\n", " if isinstance(device, torch.device):\n", " device_type = device.type\n", " else:\n", " device_type = \"cuda\" if \"cuda\" in str(device).lower() else \"cpu\"\n", " enable_autocast = device_type == \"cuda\"\n", "\n", " with torch.autocast(\n", " device_type=device_type,\n", " dtype=torch.bfloat16,\n", " enabled=enable_autocast,\n", " ):\n", " with torch.no_grad():\n", " model_output = self.model(batch_container, drop_enc_allow=False)\n", "\n", " forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)\n", "\n", " logger.debug(f\"Generated {len(forecasts)} forecasts\")\n", " return forecasts\n", " except Exception as exc: # pragma: no cover - logging path\n", " logger.error(f\"Error in batch prediction: {exc}\")\n", " raise\n", "\n", " def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:\n", " \"\"\"Convert gluonts test data to BatchTimeSeriesContainer.\"\"\"\n", "\n", " batch_size = len(test_data_batch)\n", " history_values_list = []\n", " start_dates = []\n", " frequencies = []\n", "\n", " for entry in test_data_batch:\n", " target = entry[\"target\"]\n", "\n", " if target.ndim == 1:\n", " target = target.reshape(-1, 1)\n", " else:\n", " target = target.T\n", "\n", " if self.max_context_length is not None and len(target) > self.max_context_length:\n", " target = 
target[-self.max_context_length :]\n", "\n", " history_values_list.append(target)\n", " start_dates.append(entry[\"start\"].to_timestamp().to_datetime64())\n", " frequencies.append(parse_frequency(entry[\"freq\"]))\n", "\n", " history_values_np = np.stack(history_values_list, axis=0)\n", " num_channels = history_values_np.shape[2]\n", "\n", " history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)\n", "\n", " future_values = torch.zeros(\n", " (batch_size, self.ds_prediction_length, num_channels),\n", " dtype=torch.float32,\n", " device=device,\n", " )\n", "\n", " return BatchTimeSeriesContainer(\n", " history_values=history_values,\n", " future_values=future_values,\n", " start=start_dates,\n", " frequency=frequencies,\n", " )\n", "\n", " def _convert_to_forecasts(\n", " self,\n", " model_output: dict,\n", " test_data_batch: list,\n", " batch_container: BatchTimeSeriesContainer,\n", " ) -> list[QuantileForecast]:\n", " \"\"\"Convert model predictions to QuantileForecast objects.\"\"\"\n", "\n", " predictions = model_output[\"result\"]\n", " scale_statistics = model_output[\"scale_statistics\"]\n", "\n", " if predictions.ndim == 4:\n", " predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)\n", " is_quantile = True\n", " quantile_levels = self.model.quantiles\n", " else:\n", " predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)\n", " is_quantile = False\n", " quantile_levels = [0.5]\n", "\n", " forecasts: list[QuantileForecast] = []\n", " for idx, entry in enumerate(test_data_batch):\n", " history_length = int(batch_container.history_values.shape[1])\n", " start_date = entry[\"start\"]\n", " forecast_start = start_date + history_length\n", "\n", " if is_quantile:\n", " pred_array = predictions_unscaled[idx].cpu().numpy()\n", "\n", " if pred_array.shape[1] == 1:\n", " pred_array = pred_array.squeeze(1)\n", " forecast_arrays = pred_array.T\n", " else:\n", " forecast_arrays = pred_array.transpose(2, 0, 1)\n", "\n", " forecast = QuantileForecast(\n", " forecast_arrays=forecast_arrays,\n", " forecast_keys=[str(q) for q in quantile_levels],\n", " start_date=forecast_start,\n", " )\n", " else:\n", " pred_array = predictions_unscaled[idx].cpu().numpy()\n", "\n", " if pred_array.shape[1] == 1:\n", " pred_array = pred_array.squeeze(1)\n", " forecast_arrays = pred_array.reshape(1, -1)\n", " else:\n", " forecast_arrays = pred_array.reshape(1, *pred_array.shape)\n", "\n", " forecast = QuantileForecast(\n", " forecast_arrays=forecast_arrays,\n", " forecast_keys=[\"0.5\"],\n", " start_date=forecast_start,\n", " )\n", "\n", " forecasts.append(forecast)\n", "\n", " return forecasts" ] }, { "cell_type": "markdown", "id": "o3p4q5r6", "metadata": {}, "source": [ "### 3.5. Result Handling \n", "\n", "These functions handle writing the per-dataset metrics to CSV files and aggregating all results into a single `all_results.csv` at the end." 
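, "\n", "\n", "Each per-dataset `results.csv` follows the GIFT-Eval leaderboard convention: one row per `dataset/freq/term` combination, with one `eval_metrics/<name>` column per entry in `STANDARD_METRIC_NAMES`, plus `domain` and `num_variates` looked up from `dataset_properties.json`. Once the aggregated file exists, it can be inspected directly with pandas. A small optional sketch (the path assumes the default `output_dir` used later in this notebook):\n", "\n", "```python\n", "import pandas as pd\n", "\n", "df = pd.read_csv('gift_eval_results/TempoPFN/all_results.csv')\n", "\n", "# One row per dataset/freq/term combination.\n", "print(df[['dataset', 'model', 'eval_metrics/MASE[0.5]']].head())\n", "\n", "# The composite key ends with the term, so split it off to compare terms.\n", "df['term'] = df['dataset'].str.rsplit('/', n=1).str[-1]\n", "print(df.groupby('term')['eval_metrics/mean_weighted_sum_quantile_loss'].mean())\n", "```"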
] }, { "cell_type": "code", "execution_count": null, "id": "p4q5r6s7", "metadata": {}, "outputs": [], "source": [ "def _ensure_results_csv(csv_file_path: Path) -> None:\n", " if not csv_file_path.exists():\n", " csv_file_path.parent.mkdir(parents=True, exist_ok=True)\n", " with open(csv_file_path, \"w\", newline=\"\") as csvfile:\n", " writer = csv.writer(csvfile)\n", " header = (\n", " [\"dataset\", \"model\"]\n", " + [f\"eval_metrics/{name}\" for name in STANDARD_METRIC_NAMES]\n", " + [\"domain\", \"num_variates\"]\n", " )\n", " writer.writerow(header)\n", "\n", "\n", "def write_results_to_disk(\n", " items: list[EvaluationItem],\n", " dataset_name: str,\n", " output_dir: Path,\n", " model_name: str,\n", " create_plots: bool,\n", ") -> None:\n", " output_dir = output_dir / dataset_name\n", " output_dir.mkdir(parents=True, exist_ok=True)\n", " output_csv_path = output_dir / \"results.csv\"\n", " _ensure_results_csv(output_csv_path)\n", "\n", " with open(output_csv_path, \"a\", newline=\"\") as csvfile:\n", " writer = csv.writer(csvfile)\n", " for item in items:\n", " md: DatasetMetadata = item.dataset_metadata\n", " metric_values: list[float | None] = []\n", " for metric_name in STANDARD_METRIC_NAMES:\n", " value = item.metrics.get(metric_name, None)\n", " if value is None:\n", " metric_values.append(None)\n", " else:\n", " if hasattr(value, \"__len__\") and not isinstance(value, (str, bytes)) and len(value) == 1:\n", " value = value[0]\n", " elif hasattr(value, \"item\"):\n", " value = value.item()\n", " metric_values.append(value)\n", "\n", " ds_key = md.key.lower()\n", " props = DATASET_PROPERTIES.get(ds_key, {})\n", " domain = props.get(\"domain\", \"unknown\")\n", " num_variates = props.get(\"num_variates\", 1 if md.to_univariate else md.target_dim)\n", "\n", " row = [md.full_name, model_name] + metric_values + [domain, num_variates]\n", " writer.writerow(row)\n", "\n", " if create_plots and item.figures and plt is not None:\n", " plots_dir = output_dir / \"plots\" / md.key / md.term\n", " plots_dir.mkdir(parents=True, exist_ok=True)\n", " for fig, filename in item.figures:\n", " filepath = plots_dir / filename\n", " fig.savefig(filepath, dpi=300, bbox_inches=\"tight\")\n", " plt.close(fig)\n", "\n", " logger.info(\n", " \"Evaluation complete for dataset '%s'. 
Results saved to %s\",\n", " dataset_name,\n", " output_csv_path,\n", " )\n", " if create_plots:\n", " logger.info(\"Plots saved under %s\", output_dir / \"plots\")\n", "\n", "\n", "def get_all_datasets_full_name() -> list[str]:\n", " \"\"\"Get all possible dataset full names for validation.\"\"\"\n", "\n", " terms = [\"short\", \"medium\", \"long\"]\n", " datasets_full_names: list[str] = []\n", "\n", " for name in ALL_DATASETS:\n", " for term in terms:\n", " if term in [\"medium\", \"long\"] and name not in MED_LONG_DATASETS:\n", " continue\n", "\n", " if \"/\" in name:\n", " ds_key, ds_freq = name.split(\"/\")\n", " ds_key = ds_key.lower()\n", " ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n", " else:\n", " ds_key = name.lower()\n", " ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n", " ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get(\"frequency\")\n", "\n", " datasets_full_names.append(f\"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}\")\n", "\n", " return datasets_full_names\n", "\n", "\n", "def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:\n", " \"\"\"Aggregate results from multiple CSV files into a single dataframe.\"\"\"\n", "\n", " result_root = Path(result_root_dir)\n", "\n", " logger.info(\"Aggregating results in: %s\", result_root)\n", "\n", " result_files = glob.glob(f\"{result_root}/**/results.csv\", recursive=True)\n", "\n", " if not result_files:\n", " logger.error(\"No result files found!\")\n", " return None\n", "\n", " dataframes: list[pd.DataFrame] = []\n", " for file in result_files:\n", " try:\n", " df = pd.read_csv(file)\n", " if len(df) > 0:\n", " dataframes.append(df)\n", " else:\n", " logger.warning(\"Empty file: %s\", file)\n", " except pd.errors.EmptyDataError:\n", " logger.warning(\"Skipping empty file: %s\", file)\n", " except Exception as exc:\n", " logger.error(\"Error reading %s: %s\", file, exc)\n", "\n", " if not dataframes:\n", " logger.warning(\"No valid CSV files found to combine\")\n", " return None\n", "\n", " combined_df = pd.concat(dataframes, ignore_index=True).sort_values(\"dataset\")\n", "\n", " if len(combined_df) != len(set(combined_df.dataset)):\n", " duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()\n", " logger.warning(\"Warning: Duplicate datasets found: %s\", duplicate_datasets)\n", " combined_df = combined_df.drop_duplicates(subset=[\"dataset\"], keep=\"first\")\n", " logger.info(\"Removed duplicates, %s unique datasets remaining\", len(combined_df))\n", "\n", " logger.info(\"Combined results: %s datasets\", len(combined_df))\n", "\n", " all_datasets_full_name = get_all_datasets_full_name()\n", " completed_experiments = combined_df.dataset.tolist()\n", "\n", " completed_experiments_clean = [exp for exp in completed_experiments if exp in all_datasets_full_name]\n", " missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_experiments_clean]\n", "\n", " logger.info(\"=== EXPERIMENT SUMMARY ===\")\n", " logger.info(\"Total expected datasets: %s\", len(all_datasets_full_name))\n", " logger.info(\"Completed experiments: %s\", len(completed_experiments_clean))\n", " logger.info(\"Missing/failed experiments: %s\", len(missing_or_failed_experiments))\n", "\n", " output_file = result_root / \"all_results.csv\"\n", " combined_df.to_csv(output_file, index=False)\n", " logger.info(\"Combined results saved to: %s\", output_file)\n", "\n", " return combined_df" ] }, { "cell_type": "markdown", "id": "q5r6s7t8", "metadata": {}, "source": [ "### 3.6. 
Evaluation Harness (`evaluate.py`)\n", "\n", "This is the main evaluation logic that iterates over dataset terms, prepares the data, calls the predictor, and gathers metrics." ] }, { "cell_type": "code", "execution_count": null, "id": "r6s7t8u9", "metadata": {}, "outputs": [], "source": [ "def construct_evaluation_data(\n", " dataset_name: str,\n", " dataset_storage_path: str,\n", " terms: list[str] | None = None,\n", " max_windows: int | None = None,\n", ") -> list[tuple[Dataset, DatasetMetadata]]:\n", " \"\"\"Build datasets and rich metadata per term for a dataset name.\"\"\"\n", " # Avoid mutable default argument\n", " if terms is None:\n", " terms = [\"short\", \"medium\", \"long\"]\n", "\n", " sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []\n", "\n", " if \"/\" in dataset_name:\n", " ds_key, ds_freq = dataset_name.split(\"/\")\n", " ds_key = ds_key.lower()\n", " ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n", " else:\n", " ds_key = dataset_name.lower()\n", " ds_key = PRETTY_NAMES.get(ds_key, ds_key)\n", " ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get(\"frequency\")\n", "\n", " for term in terms:\n", " # Skip medium/long terms for datasets that don't support them\n", " if (term == \"medium\" or term == \"long\") and dataset_name not in MED_LONG_DATASETS:\n", " continue\n", "\n", " # Probe once to determine dimensionality\n", " probe_dataset = Dataset(\n", " name=dataset_name,\n", " term=term,\n", " to_univariate=False,\n", " storage_path=dataset_storage_path,\n", " max_windows=max_windows,\n", " )\n", "\n", " to_univariate = probe_dataset.target_dim > 1\n", "\n", " dataset = Dataset(\n", " name=dataset_name,\n", " term=term,\n", " to_univariate=to_univariate,\n", " storage_path=dataset_storage_path,\n", " max_windows=max_windows,\n", " )\n", "\n", " # Compute metadata\n", " season_length = get_seasonality(dataset.freq)\n", " actual_freq = ds_freq if ds_freq else dataset.freq\n", "\n", " metadata = DatasetMetadata(\n", " full_name=f\"{ds_key}/{actual_freq}/{term}\",\n", " key=ds_key,\n", " freq=actual_freq,\n", " term=term,\n", " season_length=season_length,\n", " target_dim=probe_dataset.target_dim,\n", " to_univariate=to_univariate,\n", " prediction_length=dataset.prediction_length,\n", " windows=dataset.windows,\n", " )\n", "\n", " sub_datasets.append((dataset, metadata))\n", "\n", " return sub_datasets\n", "\n", "\n", "def evaluate_datasets(\n", " predictor: TimeSeriesPredictor,\n", " dataset: str,\n", " dataset_storage_path: str,\n", " terms: list[str] | None = None,\n", " max_windows: int | None = None,\n", " batch_size: int = 48,\n", " max_context_length: int | None = 1024,\n", " create_plots: bool = False,\n", " max_plots_per_dataset: int = 10,\n", ") -> list[EvaluationItem]:\n", " \"\"\"Evaluate predictor on one dataset across the requested terms.\"\"\"\n", " # Avoid mutable default argument\n", " if terms is None:\n", " terms = [\"short\", \"medium\", \"long\"]\n", "\n", " sub_datasets = construct_evaluation_data(\n", " dataset_name=dataset,\n", " dataset_storage_path=dataset_storage_path,\n", " terms=terms,\n", " max_windows=max_windows,\n", " )\n", "\n", " results: list[EvaluationItem] = []\n", " for i, (sub_dataset, metadata) in enumerate(sub_datasets):\n", " logger.info(f\"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}\")\n", " logger.info(f\" Dataset size: {len(sub_dataset.test_data)}\")\n", " logger.info(f\" Frequency: {sub_dataset.freq}\")\n", " logger.info(f\" Term: {metadata.term}\")\n", " logger.info(f\" Prediction length: 
{sub_dataset.prediction_length}\")\n", " logger.info(f\" Target dimensions: {sub_dataset.target_dim}\")\n", " logger.info(f\" Windows: {sub_dataset.windows}\")\n", "\n", " # Update context on the reusable predictor\n", " predictor.set_dataset_context(\n", " prediction_length=sub_dataset.prediction_length,\n", " freq=sub_dataset.freq,\n", " batch_size=batch_size,\n", " max_context_length=max_context_length,\n", " )\n", "\n", " res = evaluate_model(\n", " model=predictor,\n", " test_data=sub_dataset.test_data,\n", " metrics=METRICS,\n", " axis=None,\n", " mask_invalid_label=True,\n", " allow_nan_forecast=False,\n", " seasonality=metadata.season_length,\n", " )\n", "\n", " figs: list[tuple[object, str]] = []\n", " if create_plots:\n", " # We are missing `src.plotting.gift_eval_utils.create_plots_for_dataset`\n", " # As this was not provided, plotting will be skipped.\n", " logger.warning(\n", " \"Plotting is enabled but `create_plots_for_dataset` is not defined. Skipping plot generation.\"\n", " )\n", " pass\n", "\n", " results.append(EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs))\n", "\n", " return results" ] }, { "cell_type": "markdown", "id": "s7t8u9v0", "metadata": {}, "source": [ "## 4. Configuration\n", "\n", "Set the parameters for the evaluation run. The script will load the model from the local `models/` directory by default." ] }, { "cell_type": "code", "execution_count": null, "id": "t8u9v0w1", "metadata": {}, "outputs": [], "source": [ "# --- Parameters ---\n", "# Assumes the notebook is run from the root of the repo\n", "model_path = Path.cwd() / \"models/checkpoint_38M.pth\"\n", "config_path = Path.cwd() / \"configs/example.yaml\"\n", "\n", "# --- Datasets and evaluation controls ---\n", "# Use a small subset for testing, e.g., [\"m4_weekly\"]\n", "datasets_arg = [\"all\"] # list of dataset names or [\"all\"].\n", "terms = [\"short\", \"medium\", \"long\"]\n", "dataset_storage_path = os.getenv(\"GIFT_EVAL_DATASET_STORAGE_PATH\")\n", "max_windows = 20\n", "batch_size = 64\n", "max_context_length = 3072\n", "\n", "# --- Output ---\n", "after_each_dataset_flush = True # write CSV as each dataset completes\n", "model_name = \"TempoPFN\"\n", "output_dir = Path.cwd() / \"gift_eval_results\" / model_name\n", "\n", "\n", "# --- Helper Functions ---\n", "def _load_yaml(path: str) -> dict:\n", " with open(path) as f:\n", " return yaml.safe_load(f)" ] }, { "cell_type": "markdown", "id": "u9v0w1x2", "metadata": {}, "source": [ "## 5. Main Evaluation Loop\n", "\n", "This cell sets up the predictor and runs the main evaluation loop over all specified datasets." ] }, { "cell_type": "code", "execution_count": null, "id": "v0w1x2y3", "metadata": {}, "outputs": [], "source": [ "logger.info(\"Starting evaluation for model: %s\", model_name)\n", "\n", "# 1. Build predictor from a checkpoint\n", "resolved_model_path = Path(model_path)\n", "\n", "if not resolved_model_path.exists():\n", " logger.error(f\"Model checkpoint not found at: {resolved_model_path}\")\n", " logger.error(\"Please ensure the file exists and you've cloned the repo using Git LFS.\")\n", " raise FileNotFoundError(f\"No model checkpoint found. Set `model_path` correctly. 
Tried: {resolved_model_path}\")\n", "\n", "assert Path(config_path).exists(), f\"Config not found: {config_path}\"\n", "logger.info(\"Loading predictor from checkpoint: %s\", resolved_model_path)\n", "\n", "predictor = TimeSeriesPredictor.from_paths(\n", " model_path=str(resolved_model_path),\n", " config_path=str(config_path),\n", " ds_prediction_length=1, # placeholder; set per dataset\n", " ds_freq=\"D\", # placeholder; set per dataset\n", " batch_size=batch_size,\n", " max_context_length=max_context_length,\n", ")\n", "\n", "# 2. Run evaluation loop\n", "datasets_to_run = expand_datasets_arg(datasets_arg)\n", "results_root = Path(output_dir)\n", "\n", "for ds_name in datasets_to_run:\n", " try:\n", " items = evaluate_datasets(\n", " predictor=predictor,\n", " dataset=ds_name,\n", " dataset_storage_path=dataset_storage_path,\n", " terms=terms,\n", " max_windows=max_windows,\n", " batch_size=batch_size,\n", " max_context_length=max_context_length,\n", " create_plots=False, # Set to True if you implement plotting\n", " max_plots_per_dataset=0,\n", " )\n", " write_results_to_disk(\n", " items=items,\n", " dataset_name=ds_name,\n", " output_dir=results_root,\n", " model_name=model_name,\n", " create_plots=False,\n", " )\n", " if after_each_dataset_flush:\n", " logger.info(\"Flushed results for %s\", ds_name)\n", " except Exception as e:\n", " logger.error(f\"FAILED evaluation for dataset: {ds_name}. Error: {e} !!!\")\n", " logger.exception(e)\n", " continue # Continue to the next dataset\n", "\n", "print(f\"\\nEvaluation complete. See results under: {output_dir}\")" ] }, { "cell_type": "markdown", "id": "w1x2y3z4", "metadata": {}, "source": [ "## 6. Aggregate Results\n", "\n", "Finally, we'll aggregate the individual CSV files into a single `all_results.csv` file for easy analysis, following the `gift-eval` convention." ] }, { "cell_type": "code", "execution_count": null, "id": "x2y3z4a5", "metadata": {}, "outputs": [], "source": [ "logger.info(\"Aggregating results from all datasets...\")\n", "combined_df = aggregate_results(result_root_dir=output_dir)\n", "\n", "if combined_df is not None:\n", " agg_path = Path(output_dir) / \"all_results.csv\"\n", " logger.info(\"Successfully created aggregated results file: %s\", agg_path)\n", " print(f\"\\n✅ Aggregated results saved to: {agg_path}\")\n", " print(combined_df.head())\n", "else:\n", " logger.warning(\"No results to aggregate. Check that evaluation completed successfully.\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" } }, "nbformat": 4, "nbformat_minor": 5 }