# Running TempoPFN on GIFT-Eval Benchmark

This notebook evaluates the **TempoPFN** model on the GIFT-Eval benchmark. 

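Before running this notebook, download the GIFT-Eval benchmark data and set the `GIFT_EVAL_DATASET_STORAGE_PATH` environment variable to the directory that contains it. The import cell calls `load_dotenv()`, so the variable can also be defined in a `.env` file.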

## 1. Setup and Dependencies

First, install the required packages. 

**Note:** This notebook assumes that the core `TempoPFN` model code (e.g., `src.models.model`, `src.data.containers`) and dependencies are installed as a Python package or are otherwise available in the `PYTHONPATH`.

## 2. Imports

Import all necessary libraries. 

In [None]:
import csv
import glob
import json
import logging
import math
import os
import warnings
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from enum import Enum
from functools import cached_property
from pathlib import Path

# GluonTS and Data Handling
import datasets

# Plotting and Warnings
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow.compute as pc
import torch
import yaml
from dotenv import load_dotenv
from gluonts.dataset import DataEntry
from gluonts.dataset.common import ProcessDataEntry
from gluonts.dataset.split import TestData, TrainingDataset, split

# GluonTS Evaluation
from gluonts.ev.metrics import (
 MAE,
 MAPE,
 MASE,
 MSE,
 MSIS,
 ND,
 NRMSE,
 RMSE,
 SMAPE,
 MeanWeightedSumQuantileLoss,
)
from gluonts.itertools import Map
from gluonts.model.evaluation import evaluate_model
from gluonts.model.forecast import QuantileForecast
from gluonts.model.predictor import Predictor
from gluonts.time_feature import get_seasonality, norm_freq_str
from gluonts.transform import Transformation
from linear_operator.utils.cholesky import NumericalWarning
from pandas.tseries.frequencies import to_offset

# --- TempoPFN Core Model Imports ---
# These are assumed to be installed or in the PYTHONPATH
from src.data.containers import BatchTimeSeriesContainer
from src.data.frequency import parse_frequency
from src.data.scalers import RobustScaler
from src.models.model import TimeSeriesModel
from src.utils.utils import device
from toolz import compose
from torch.nn.parallel import DistributedDataParallel as DDP

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")

# Raise the threshold of third-party loggers so only WARNING and above get through.
for _noisy_logger_name in ("matplotlib", "matplotlib.font_manager", "PIL"):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)
del _noisy_logger_name

# Logger used by every cell of this notebook.
logger = logging.getLogger("gift_eval_runner")


# Filter out specific gluonts warnings
class WarningFilter(logging.Filter):
    """Logging filter that drops every record whose message contains a given substring."""

    def __init__(self, text_to_filter: str) -> None:
        """Remember the substring that marks a record as unwanted."""
        super().__init__()
        self.text_to_filter = text_to_filter

    def filter(self, record: logging.LogRecord) -> bool:
        """Return ``False`` (drop the record) when its formatted message contains the substring."""
        message = record.getMessage()
        is_suppressed = self.text_to_filter in message
        return not is_suppressed


# Drop the repeated GluonTS message about the missing mean prediction.
gts_logger = logging.getLogger("gluonts.model.forecast")
gts_logger.addFilter(WarningFilter("The mean prediction is not stored in the forecast data"))

# Filter out numerical warnings
for _ignored_category in (NumericalWarning, FutureWarning, DeprecationWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Load environment variables (e.g., GIFT_EVAL_DATASET_STORAGE_PATH)
load_dotenv()

## 3. Constants and Configuration

Define dataset lists, metrics, and other constants following GIFT-Eval standards.

### 3.1. Constants 

In [None]:
# Environment setup
# cuBLAS workspace setting; presumably required for deterministic CUDA kernels
# (see the PyTorch reproducibility notes) -- confirm it is still needed.
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# Use absolute path relative to the project root
_MODULE_DIR = Path.cwd().parent.parent  # Assumes notebook is in `examples/gift_eval/`
_PROPERTIES_RELATIVE_PATH = Path("data") / "dataset_properties.json"

# The configuration cell resolves the checkpoint relative to `Path.cwd()`, i.e. it
# assumes the notebook runs from the repo root, which contradicts the assumption
# above. To work from either location, look in the original location first and
# then in the working directory and each of its ancestors.
_properties_search_roots = [_MODULE_DIR, Path.cwd(), *Path.cwd().parents]
DATASET_PROPERTIES_PATH = next(
    (
        root / _PROPERTIES_RELATIVE_PATH
        for root in _properties_search_roots
        if (root / _PROPERTIES_RELATIVE_PATH).is_file()
    ),
    # Nothing found: keep the original location so the warning below names it.
    _MODULE_DIR / _PROPERTIES_RELATIVE_PATH,
)

try:
    with open(DATASET_PROPERTIES_PATH, encoding="utf-8") as f:
        DATASET_PROPERTIES = json.load(f)
except (OSError, ValueError) as exc:  # pragma: no cover - logging path
    # OSError: file missing or unreadable. ValueError: invalid JSON or bad encoding.
    # Best effort by design: the evaluation still runs with default metadata.
    DATASET_PROPERTIES = {}
    logger.warning(
        "Could not load dataset properties from %s: %s. Domain and num_variates will fall back to defaults.",
        DATASET_PROPERTIES_PATH,
        exc,
    )

# Datasets
# Dataset names as stored on disk. A name of the form "<key>/<freq>" carries a
# frequency suffix that construct_evaluation_data() and
# get_all_datasets_full_name() split off at the "/".
SHORT_DATASETS = (
    "m4_yearly",
    "m4_quarterly",
    "m4_monthly",
    "m4_weekly",
    "m4_daily",
    "m4_hourly",
    "electricity/15T",
    "electricity/H",
    "electricity/D",
    "electricity/W",
    "solar/10T",
    "solar/H",
    "solar/D",
    "solar/W",
    "hospital",
    "covid_deaths",
    "us_births/D",
    "us_births/M",
    "us_births/W",
    "saugeenday/D",
    "saugeenday/M",
    "saugeenday/W",
    "temperature_rain_with_missing",
    "kdd_cup_2018_with_missing/H",
    "kdd_cup_2018_with_missing/D",
    "car_parts_with_missing",
    "restaurant",
    "hierarchical_sales/D",
    "hierarchical_sales/W",
    "LOOP_SEATTLE/5T",
    "LOOP_SEATTLE/H",
    "LOOP_SEATTLE/D",
    "SZ_TAXI/15T",
    "SZ_TAXI/H",
    "M_DENSE/H",
    "M_DENSE/D",
    "ett1/15T",
    "ett1/H",
    "ett1/D",
    "ett1/W",
    "ett2/15T",
    "ett2/H",
    "ett2/D",
    "ett2/W",
    "jena_weather/10T",
    "jena_weather/H",
    "jena_weather/D",
    "bitbrains_fast_storage/5T",
    "bitbrains_fast_storage/H",
    "bitbrains_rnd/5T",
    "bitbrains_rnd/H",
    "bizitobs_application",
    "bizitobs_service",
    "bizitobs_l2c/5T",
    "bizitobs_l2c/H",
)

# Only these datasets are evaluated for the "medium" and "long" terms; the
# evaluation code skips those terms for any name that is not listed here.
MED_LONG_DATASETS = (
    "electricity/15T",
    "electricity/H",
    "solar/10T",
    "solar/H",
    "kdd_cup_2018_with_missing/H",
    "LOOP_SEATTLE/5T",
    "LOOP_SEATTLE/H",
    "SZ_TAXI/15T",
    "M_DENSE/H",
    "ett1/15T",
    "ett1/H",
    "ett2/15T",
    "ett2/H",
    "jena_weather/10T",
    "jena_weather/H",
    "bitbrains_fast_storage/5T",
    "bitbrains_rnd/5T",
    "bizitobs_application",
    "bizitobs_service",
    "bizitobs_l2c/5T",
    "bizitobs_l2c/H",
)

# Preserve insertion order
# dict.fromkeys() de-duplicates while keeping first-seen order. Every
# MED_LONG_DATASETS entry also appears in SHORT_DATASETS, so the result has the
# same names and order as SHORT_DATASETS.
ALL_DATASETS = list(dict.fromkeys(SHORT_DATASETS + MED_LONG_DATASETS))

# Evaluation terms
# The three term names; they match the values of the `Term` enum.
TERMS = ("short", "medium", "long")

# Pretty names mapping
# Maps a lower-cased dataset key (the part of the name before any "/") to the
# key used in result names; keys that are not listed are used unchanged.
PRETTY_NAMES = {
    "saugeenday": "saugeen",
    "temperature_rain_with_missing": "temperature_rain",
    "kdd_cup_2018_with_missing": "kdd_cup_2018",
    "car_parts_with_missing": "car_parts",
}

# Metrics
# GluonTS metric definitions handed to `evaluate_model` in evaluate_datasets().
METRICS = (
    MSE(forecast_type="mean"),
    MSE(forecast_type=0.5),
    MAE(),
    MASE(),
    MAPE(),
    SMAPE(),
    MSIS(),
    RMSE(),
    NRMSE(),
    ND(),
    MeanWeightedSumQuantileLoss(quantile_levels=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]),
)

# Standard metric names for CSV header
# write_results_to_disk() looks each name up in the evaluation result and
# _ensure_results_csv() writes it as an "eval_metrics/<name>" column.
# NOTE(review): presumably these are the column names GluonTS assigns to the
# METRICS entries above, in the same order -- confirm against the GluonTS version in use.
STANDARD_METRIC_NAMES = (
    "MSE[mean]",
    "MSE[0.5]",
    "MAE[0.5]",
    "MASE[0.5]",
    "MAPE[0.5]",
    "sMAPE[0.5]",
    "MSIS",
    "RMSE[mean]",
    "NRMSE[mean]",
    "ND[0.5]",
    "mean_weighted_sum_quantile_loss",
)

### 3.2. Core Data Structures 

In [None]:
@dataclass
class DatasetMetadata:
    """Structured description of a dataset/term combination.

    Instances are built by ``construct_evaluation_data`` and read by
    ``write_results_to_disk`` when a result row is written.
    """

    # Result identifier of the form "<key>/<freq>/<term>".
    full_name: str
    # Lower-cased dataset name (the part before any "/") after the PRETTY_NAMES mapping.
    key: str
    # Frequency used in `full_name`: the name suffix or the DATASET_PROPERTIES
    # entry when available, otherwise the frequency stored in the dataset.
    freq: str
    # Evaluation term: "short", "medium" or "long".
    term: str
    # Result of `get_seasonality` for the dataset's stored frequency.
    season_length: int
    # Number of target dimensions before any conversion to univariate series.
    target_dim: int
    # True when the dataset had more than one target dimension and was split per dimension.
    to_univariate: bool
    # Forecast horizon of the dataset for this term.
    prediction_length: int
    # Number of evaluation windows of the dataset for this term.
    windows: int


@dataclass
class EvaluationItem:
    """Container for evaluation results and optional figures."""

    # Metadata of the dataset/term combination that was evaluated.
    dataset_metadata: DatasetMetadata
    # Evaluation result as returned by `evaluate_model`; write_results_to_disk()
    # reads it through `.get(<metric name>)`.
    metrics: dict
    # (figure, file name) pairs; write_results_to_disk() saves each figure with
    # `fig.savefig` when plotting is enabled.
    figures: list[tuple[object, str]]


DatasetSelection = list[str] | tuple[str, ...] | str


def expand_datasets_arg(datasets: DatasetSelection) -> list[str]:
 """Normalize dataset selection strings to explicit lists."""

 if isinstance(datasets, str):
 dataset_list = [datasets]
 else:
 dataset_list = list(datasets)

 if not dataset_list:
 return []

 if dataset_list[0] == "all":
 return list(ALL_DATASETS)

 for dataset in dataset_list:
 if dataset not in ALL_DATASETS:
 raise ValueError(f"Invalid dataset: {dataset}. Use one of {ALL_DATASETS}")

 return dataset_list

### 3.3. GIFT-Eval Dataset Class (`data.py`)

The `Dataset` class handles loading and preprocessing GIFT-Eval benchmark datasets. This implementation is adapted from the official GIFT-Eval repository.

In [None]:
# Fraction of the shortest series' length that `Dataset.windows` spreads over
# the evaluation windows.
TEST_SPLIT = 0.1
# Default upper bound on the number of evaluation windows (used when
# `max_windows` is not passed to `Dataset`).
MAX_WINDOW = 20

# Base prediction length per normalized frequency string, used by
# `Dataset.prediction_length` for datasets whose name contains "m4".
M4_PRED_LENGTH_MAP = {
    "A": 6,
    "Q": 8,
    "M": 18,
    "W": 13,
    "D": 14,
    "H": 48,
    "h": 48,
    "Y": 6,
}

# Base prediction length per normalized frequency string for all other datasets.
PRED_LENGTH_MAP = {
    "M": 12,
    "W": 8,
    "D": 30,
    "H": 48,
    "h": 48,
    "T": 48,
    "S": 60,
    "s": 60,
    "min": 48,
}

# NOTE(review): not referenced anywhere in the visible code; presumably carried
# over from the upstream GIFT-Eval implementation -- confirm before removing.
TFB_PRED_LENGTH_MAP = {
    "A": 6,
    "Y": 6,
    "H": 48,
    "h": 48,
    "Q": 8,
    "D": 14,
    "M": 18,
    "W": 13,
    "U": 8,
    "T": 8,
    "min": 8,
    "us": 8,
}


class Term(Enum):
    """Evaluation term; selects how far the base prediction length is stretched."""

    SHORT = "short"
    MEDIUM = "medium"
    LONG = "long"

    @property
    def multiplier(self) -> int:
        """Factor by which the base prediction length is multiplied for this term."""
        factor_by_value = {"short": 1, "medium": 10, "long": 15}
        return factor_by_value[self.value]


def itemize_start(data_entry: DataEntry) -> DataEntry:
    """Replace the entry's ``"start"`` value with ``start.item()`` and return the same entry.

    The entry is modified in place.
    """
    # presumably a zero-dimensional numpy value produced by the "numpy" dataset format -- confirm
    raw_start = data_entry["start"]
    data_entry["start"] = raw_start.item()
    return data_entry


class MultivariateToUnivariate(Transformation):
    """Transformation that emits one entry per element of a multi-valued field."""

    def __init__(self, field):
        """Store the name of the field whose elements are split into separate entries."""
        self.field = field

    def __call__(self, data_it: Iterable[DataEntry], is_train: bool = False) -> Iterator:
        """Yield a copy of each entry for every element of ``entry[self.field]``.

        Each copy holds a single element in ``self.field`` and gets the suffix
        ``"_dim<index>"`` appended to its ``item_id``. ``is_train`` is not used.
        """
        for data_entry in data_it:
            base_item_id = data_entry["item_id"]
            per_dim_values = list(data_entry[self.field])
            for dim_index, dim_values in enumerate(per_dim_values):
                # Shallow copy: all other fields are shared with the source entry.
                univariate_entry = data_entry.copy()
                univariate_entry[self.field] = dim_values
                univariate_entry["item_id"] = base_item_id + "_dim" + str(dim_index)
                yield univariate_entry


class Dataset:
    """A GIFT-Eval dataset loaded from disk and exposed as GluonTS splits.

    The dataset is read with ``datasets.load_from_disk`` from
    ``<storage_path>/<name>`` and wrapped so that each record is passed
    through ``itemize_start`` and GluonTS' ``ProcessDataEntry``.
    """

    def __init__(
        self,
        name: str,
        term: Term | str = Term.SHORT,
        to_univariate: bool = False,
        storage_path: str | Path | None = None,
        max_windows: int | None = None,
    ):
        """Load the dataset and prepare the GluonTS view.

        Args:
            name: Dataset directory name below ``storage_path`` (may contain "/").
            term: Evaluation term, as a ``Term`` or its string value.
            to_univariate: When true, split the "target" field into one entry
                per dimension via ``MultivariateToUnivariate``.
            storage_path: Root directory of the benchmark data. When omitted,
                the ``GIFT_EVAL_DATASET_STORAGE_PATH`` environment variable is used.
            max_windows: Upper bound for ``windows``; defaults to ``MAX_WINDOW``.

        Raises:
            ValueError: If no storage path is given and the environment
                variable is not set, or if ``term`` is not a valid term.
        """
        # `Path(None)` would raise an opaque TypeError, so resolve the default
        # explicitly and fail with an actionable message instead.
        if storage_path is None:
            storage_path = os.getenv("GIFT_EVAL_DATASET_STORAGE_PATH")
        if storage_path is None:
            raise ValueError(
                "No dataset storage path given: pass `storage_path` or set the "
                "GIFT_EVAL_DATASET_STORAGE_PATH environment variable."
            )

        storage_path = Path(storage_path)
        self.hf_dataset = datasets.load_from_disk(str(storage_path / name)).with_format("numpy")
        # `freq` and `target_dim` are cached properties that only read `hf_dataset`.
        process = ProcessDataEntry(
            self.freq,
            one_dim_target=self.target_dim == 1,
        )

        # `compose` applies right to left: itemize_start first, then ProcessDataEntry.
        self.gluonts_dataset = Map(compose(process, itemize_start), self.hf_dataset)
        if to_univariate:
            self.gluonts_dataset = MultivariateToUnivariate("target").apply(self.gluonts_dataset)

        self.term = Term(term)
        self.name = name
        self.max_windows = max_windows if max_windows is not None else MAX_WINDOW

    @cached_property
    def prediction_length(self) -> int:
        """Base prediction length for the frequency, multiplied by the term multiplier."""
        freq = norm_freq_str(to_offset(self.freq).name)
        # Drop a trailing "E" so the string matches the keys of the length maps.
        if freq.endswith("E"):
            freq = freq[:-1]
        pred_len = M4_PRED_LENGTH_MAP[freq] if "m4" in self.name else PRED_LENGTH_MAP[freq]
        return self.term.multiplier * pred_len

    @cached_property
    def freq(self) -> str:
        """Frequency string stored in the first record of the dataset."""
        return self.hf_dataset[0]["freq"]

    @cached_property
    def target_dim(self) -> int:
        """Size of the first axis of the first record's target if it has more than one axis, else 1."""
        return target.shape[0] if len((target := self.hf_dataset[0]["target"]).shape) > 1 else 1

    @cached_property
    def past_feat_dynamic_real_dim(self) -> int:
        """Number of ``past_feat_dynamic_real`` rows in the first record (0 when the field is absent)."""
        if "past_feat_dynamic_real" not in self.hf_dataset[0]:
            return 0
        elif len((past_feat_dynamic_real := self.hf_dataset[0]["past_feat_dynamic_real"]).shape) > 1:
            return past_feat_dynamic_real.shape[0]
        else:
            return 1

    @cached_property
    def windows(self) -> int:
        """Number of evaluation windows: 1 for "m4" datasets, else a value in ``[1, max_windows]``."""
        if "m4" in self.name:
            return 1
        w = math.ceil(TEST_SPLIT * self._min_series_length / self.prediction_length)
        return min(max(1, w), self.max_windows)

    @cached_property
    def _min_series_length(self) -> int:
        """Length of the shortest target series, computed on the Arrow column."""
        if self.hf_dataset[0]["target"].ndim > 1:
            # Multivariate target: measure only the first dimension of each record.
            lengths = pc.list_value_length(pc.list_flatten(pc.list_slice(self.hf_dataset.data.column("target"), 0, 1)))
        else:
            lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
        return min(lengths.to_numpy())

    @cached_property
    def sum_series_length(self) -> int:
        """Total number of target values, summed over all series and all dimensions."""
        if self.hf_dataset[0]["target"].ndim > 1:
            lengths = pc.list_value_length(pc.list_flatten(self.hf_dataset.data.column("target")))
        else:
            lengths = pc.list_value_length(self.hf_dataset.data.column("target"))
        return sum(lengths.to_numpy())

    @property
    def training_dataset(self) -> TrainingDataset:
        """Data before the last ``windows + 1`` prediction lengths of each series."""
        training_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * (self.windows + 1))
        return training_dataset

    @property
    def validation_dataset(self) -> TrainingDataset:
        """Data before the last ``windows`` prediction lengths of each series."""
        validation_dataset, _ = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)
        return validation_dataset

    @property
    def test_data(self) -> TestData:
        """Test instances: ``windows`` windows of ``prediction_length`` steps, spaced one prediction length apart."""
        _, test_template = split(self.gluonts_dataset, offset=-self.prediction_length * self.windows)
        test_data = test_template.generate_instances(
            prediction_length=self.prediction_length,
            windows=self.windows,
            distance=self.prediction_length,
        )
        return test_data

### 3.4. Predictor Wrapper (`predictor.py`)

This is the model-specific `TimeSeriesPredictor` class for `TempoPFN`. It wraps the core `TimeSeriesModel` and adapts it to the `gluonts`-style `Predictor` interface, which expects a `.predict()` method.

In [None]:
class TimeSeriesPredictor(Predictor):
    """Unified predictor for TimeSeriesModel supporting flexible construction.

    Wraps a ``TimeSeriesModel`` so that it can be passed to GluonTS evaluation
    code: ``predict`` takes the test inputs and returns one
    ``QuantileForecast`` per input, in input order.
    """

    def __init__(
        self,
        model: TimeSeriesModel,
        config: dict,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> None:
        """Store the model, its config and the dataset-specific settings.

        Args:
            model: The model to run; a ``DistributedDataParallel`` wrapper is unwrapped.
            config: Parsed configuration; the "TimeSeriesModel" section supplies
                the scaler type and epsilon.
            ds_prediction_length: Number of future steps to forecast.
            ds_freq: Dataset frequency (stored; the per-entry "freq" is what is
                passed to the model).
            batch_size: Maximum number of series per forward pass.
            max_context_length: If set, only the last this-many history steps are used.
            debug: Stored flag; not read by the visible code.

        Raises:
            ValueError: If the configured scaler type is not "custom_robust".
        """
        # Dataset-specific context (can be updated per dataset/term)
        self.ds_prediction_length = ds_prediction_length
        self.ds_freq = ds_freq
        self.batch_size = batch_size
        self.max_context_length = max_context_length
        self.debug = debug

        # Persistent model/config (unwrap DDP if needed)
        self.model = model.module if isinstance(model, DDP) else model
        self.model.eval()
        self.config = config

        # Initialize scaler (using same type as model)
        model_config = self.config.get("TimeSeriesModel", {})
        scaler_type = model_config.get("scaler", "custom_robust")
        epsilon = model_config.get("epsilon", 1e-3)
        if scaler_type == "custom_robust":
            self.scaler = RobustScaler(epsilon=epsilon)
        else:
            raise ValueError(f"Unsupported scaler type: {scaler_type}")

    def set_dataset_context(
        self,
        prediction_length: int | None = None,
        freq: str | None = None,
        batch_size: int | None = None,
        max_context_length: int | None = None,
    ) -> None:
        """Update lightweight dataset-specific attributes without reloading the model.

        Arguments left as ``None`` keep their current value, so this method
        cannot reset ``max_context_length`` back to ``None``.
        """

        if prediction_length is not None:
            self.ds_prediction_length = prediction_length
        if freq is not None:
            self.ds_freq = freq
        if batch_size is not None:
            self.batch_size = batch_size
        if max_context_length is not None:
            self.max_context_length = max_context_length

    @classmethod
    def from_model(
        cls,
        model: TimeSeriesModel,
        config: dict,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> "TimeSeriesPredictor":
        """Build a predictor around an already constructed model."""
        return cls(
            model=model,
            config=config,
            ds_prediction_length=ds_prediction_length,
            ds_freq=ds_freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
            debug=debug,
        )

    @classmethod
    def from_paths(
        cls,
        model_path: str,
        config_path: str,
        ds_prediction_length: int,
        ds_freq: str,
        batch_size: int = 32,
        max_context_length: int | None = None,
        debug: bool = False,
    ) -> "TimeSeriesPredictor":
        """Build a predictor from a YAML config file and a checkpoint file."""
        with open(config_path, encoding="utf-8") as f:
            config = yaml.safe_load(f)
        model = cls._load_model_from_path(config=config, model_path=model_path)
        return cls(
            model=model,
            config=config,
            ds_prediction_length=ds_prediction_length,
            ds_freq=ds_freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
            debug=debug,
        )

    @staticmethod
    def _load_model_from_path(config: dict, model_path: str) -> TimeSeriesModel:
        """Instantiate the model from ``config["TimeSeriesModel"]`` and load its weights.

        Any failure is logged and re-raised.
        """
        try:
            model = TimeSeriesModel(**config["TimeSeriesModel"]).to(device)
            # SECURITY: torch.load unpickles the file; only load checkpoints from a trusted source.
            checkpoint = torch.load(model_path, map_location=device)
            model.load_state_dict(checkpoint["model_state_dict"])
            model.eval()
            logger.info(f"Successfully loaded model from {model_path}")
            return model
        except Exception as exc:  # pragma: no cover - logging path
            logger.error(f"Failed to load model from {model_path}: {exc}")
            raise

    def predict(self, test_data_input) -> Iterator[QuantileForecast]:
        """Generate forecasts for the test data.

        Entries are grouped by their effective history length so each group can
        be stacked into one tensor, then processed in chunks of at most
        ``batch_size``. The returned list has one forecast per entry, in the
        order of ``test_data_input``.
        """

        if hasattr(test_data_input, "__iter__") and not isinstance(test_data_input, list):
            test_data_input = list(test_data_input)
        logger.debug(f"Processing {len(test_data_input)} time series")

        # Group series by their effective length (after optional truncation),
        # then process each uniform-length group in sub-batches up to batch_size.
        def _effective_length(entry) -> int:
            target = entry["target"]
            if target.ndim == 1:
                seq_len = len(target)
            else:
                # target shape is [num_channels, seq_len]
                seq_len = target.shape[1]
            if self.max_context_length is not None:
                seq_len = min(seq_len, self.max_context_length)
            return seq_len

        length_to_items: dict[int, list[tuple[int, object]]] = {}
        for idx, entry in enumerate(test_data_input):
            length_to_items.setdefault(_effective_length(entry), []).append((idx, entry))

        ordered_results: list[QuantileForecast | None] = [None] * len(test_data_input)

        for items in length_to_items.values():
            for chunk_start in range(0, len(items), self.batch_size):
                chunk = items[chunk_start : chunk_start + self.batch_size]
                batch_forecasts = self._predict_batch([entry for _orig_idx, entry in chunk])
                # Put every forecast back at the position of its input entry.
                for (orig_idx, _entry), forecast in zip(chunk, batch_forecasts):
                    ordered_results[orig_idx] = forecast

        return ordered_results  # type: ignore[return-value]

    def _predict_batch(self, test_data_batch: list) -> list[QuantileForecast]:
        """Generate predictions for a batch of time series.

        All entries must have the same effective history length. Errors are
        logged and re-raised.
        """

        logger.debug(f"Processing batch of size: {len(test_data_batch)}")

        try:
            batch_container = self._convert_to_batch_container(test_data_batch)

            if isinstance(device, torch.device):
                device_type = device.type
            else:
                device_type = "cuda" if "cuda" in str(device).lower() else "cpu"
            # bfloat16 autocast is only switched on for CUDA.
            enable_autocast = device_type == "cuda"

            with (
                torch.autocast(
                    device_type=device_type,
                    dtype=torch.bfloat16,
                    enabled=enable_autocast,
                ),
                torch.no_grad(),
            ):
                model_output = self.model(batch_container, drop_enc_allow=False)

            forecasts = self._convert_to_forecasts(model_output, test_data_batch, batch_container)

            logger.debug(f"Generated {len(forecasts)} forecasts")
            return forecasts
        except Exception as exc:  # pragma: no cover - logging path
            logger.error(f"Error in batch prediction: {exc}")
            raise

    def _convert_to_batch_container(self, test_data_batch: list) -> BatchTimeSeriesContainer:
        """Convert gluonts test data to BatchTimeSeriesContainer.

        Histories are arranged as [batch, seq_len, num_channels] and truncated
        to the last ``max_context_length`` steps when that limit is set. The
        future values are a zero tensor of ``ds_prediction_length`` steps.
        """

        batch_size = len(test_data_batch)
        history_values_list = []
        start_dates = []
        frequencies = []

        for entry in test_data_batch:
            target = entry["target"]

            # Bring the target to [seq_len, num_channels].
            if target.ndim == 1:
                target = target.reshape(-1, 1)
            else:
                target = target.T

            if self.max_context_length is not None and len(target) > self.max_context_length:
                target = target[-self.max_context_length :]

            history_values_list.append(target)
            # NOTE(review): this is the start of the untruncated series, even
            # when the history above was truncated -- confirm the model expects that.
            start_dates.append(entry["start"].to_timestamp().to_datetime64())
            frequencies.append(parse_frequency(entry["freq"]))

        history_values_np = np.stack(history_values_list, axis=0)
        num_channels = history_values_np.shape[2]

        history_values = torch.tensor(history_values_np, dtype=torch.float32, device=device)

        future_values = torch.zeros(
            (batch_size, self.ds_prediction_length, num_channels),
            dtype=torch.float32,
            device=device,
        )

        return BatchTimeSeriesContainer(
            history_values=history_values,
            future_values=future_values,
            start=start_dates,
            frequency=frequencies,
        )

    def _convert_to_forecasts(
        self,
        model_output: dict,
        test_data_batch: list,
        batch_container: BatchTimeSeriesContainer,
    ) -> list[QuantileForecast]:
        """Convert model predictions to QuantileForecast objects.

        Args:
            model_output: Model result with the keys "result" (scaled
                predictions) and "scale_statistics".
            test_data_batch: The entries the predictions belong to, in batch order.
            batch_container: The container the model was run on (kept for
                interface compatibility; not needed for the conversion).

        Returns:
            One forecast per entry. A 4-dimensional "result" is treated as
            quantile output labelled with ``self.model.quantiles``; anything
            else is stored as a single "0.5" quantile.
        """

        predictions = model_output["result"]
        scale_statistics = model_output["scale_statistics"]

        # The inverse scaling is the same for point and quantile predictions.
        predictions_unscaled = self.scaler.inverse_scale(predictions, scale_statistics)
        is_quantile = predictions.ndim == 4
        forecast_keys = [str(q) for q in self.model.quantiles] if is_quantile else ["0.5"]

        forecasts: list[QuantileForecast] = []
        for idx, entry in enumerate(test_data_batch):
            # The forecast begins right after the full history of the entry.
            # entry["start"] refers to the untruncated series, so the offset
            # must be the original target length, not the (possibly truncated
            # to max_context_length) length of the model input.
            full_history_length = int(entry["target"].shape[-1])
            forecast_start = entry["start"] + full_history_length

            # presumably [pred_len, channels, quantiles] for quantile output and
            # [pred_len, channels] otherwise -- TODO confirm against the model
            pred_array = predictions_unscaled[idx].cpu().numpy()
            single_channel = pred_array.shape[1] == 1

            if is_quantile:
                if single_channel:
                    forecast_arrays = pred_array.squeeze(1).T
                else:
                    forecast_arrays = pred_array.transpose(2, 0, 1)
            else:
                if single_channel:
                    forecast_arrays = pred_array.squeeze(1).reshape(1, -1)
                else:
                    forecast_arrays = pred_array.reshape(1, *pred_array.shape)

            forecasts.append(
                QuantileForecast(
                    forecast_arrays=forecast_arrays,
                    forecast_keys=forecast_keys,
                    start_date=forecast_start,
                )
            )

        return forecasts

### 3.5. Result Handling 

These functions handle writing the per-dataset metrics to CSV files and aggregating all results into a single `all_results.csv` at the end.

In [None]:
def _ensure_results_csv(csv_file_path: Path) -> None:
    """Create the results CSV with its header row unless the file already exists.

    Missing parent directories are created. An existing file is left untouched.
    """
    if csv_file_path.exists():
        return

    csv_file_path.parent.mkdir(parents=True, exist_ok=True)
    with open(csv_file_path, "w", newline="") as csvfile:
        metric_columns = [f"eval_metrics/{name}" for name in STANDARD_METRIC_NAMES]
        header = ["dataset", "model", *metric_columns, "domain", "num_variates"]
        csv.writer(csvfile).writerow(header)


def write_results_to_disk(
    items: list[EvaluationItem],
    dataset_name: str,
    output_dir: Path,
    model_name: str,
    create_plots: bool,
) -> None:
    """Append one CSV row per evaluation item and optionally save its figures.

    Rows go to ``<output_dir>/<dataset_name>/results.csv``; the file is created
    with a header when it does not exist yet. Figures are written below
    ``plots/<key>/<term>`` in the same directory and closed afterwards.
    """

    def _as_scalar(raw):
        """Unwrap a one-element container or an object with ``.item()``; pass other values through."""
        if raw is None:
            return None
        if hasattr(raw, "__len__") and not isinstance(raw, (str, bytes)) and len(raw) == 1:
            return raw[0]
        if hasattr(raw, "item"):
            return raw.item()
        return raw

    output_dir = output_dir / dataset_name
    output_dir.mkdir(parents=True, exist_ok=True)
    output_csv_path = output_dir / "results.csv"
    _ensure_results_csv(output_csv_path)

    with open(output_csv_path, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        for item in items:
            md: DatasetMetadata = item.dataset_metadata
            metric_values: list[float | None] = [
                _as_scalar(item.metrics.get(metric_name, None)) for metric_name in STANDARD_METRIC_NAMES
            ]

            # Domain and variate count come from the properties file when it has
            # an entry for this dataset key, else from fallbacks.
            props = DATASET_PROPERTIES.get(md.key.lower(), {})
            domain = props.get("domain", "unknown")
            num_variates = props.get("num_variates", 1 if md.to_univariate else md.target_dim)

            writer.writerow([md.full_name, model_name, *metric_values, domain, num_variates])

            if not (create_plots and item.figures and plt is not None):
                continue
            plots_dir = output_dir / "plots" / md.key / md.term
            plots_dir.mkdir(parents=True, exist_ok=True)
            for fig, filename in item.figures:
                fig.savefig(plots_dir / filename, dpi=300, bbox_inches="tight")
                plt.close(fig)

    logger.info(
        "Evaluation complete for dataset '%s'. Results saved to %s",
        dataset_name,
        output_csv_path,
    )
    if create_plots:
        logger.info("Plots saved under %s", output_dir / "plots")


def get_all_datasets_full_name() -> list[str]:
    """Return every expected result name in the form ``<key>/<freq>/<term>``.

    Each dataset contributes its "short" name; "medium" and "long" names are
    added only for datasets listed in ``MED_LONG_DATASETS``. A dataset without
    a known frequency uses the literal ``unknown``.
    """
    datasets_full_names: list[str] = []

    for name in ALL_DATASETS:
        # Key and frequency do not depend on the term, so resolve them once per dataset.
        if "/" in name:
            raw_key, ds_freq = name.split("/")
            ds_key = PRETTY_NAMES.get(raw_key.lower(), raw_key.lower())
        else:
            ds_key = PRETTY_NAMES.get(name.lower(), name.lower())
            ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")

        supports_med_long = name in MED_LONG_DATASETS
        for term in ("short", "medium", "long"):
            if term != "short" and not supports_med_long:
                continue
            datasets_full_names.append(f"{ds_key}/{ds_freq if ds_freq else 'unknown'}/{term}")

    return datasets_full_names


def aggregate_results(result_root_dir: str | Path) -> pd.DataFrame | None:
    """Combine every ``results.csv`` below a directory into ``all_results.csv``.

    The non-empty files are concatenated, sorted by the "dataset" column and
    reduced to one row per dataset. A summary of completed versus expected
    experiments is logged. Returns the combined frame, or ``None`` when there
    is nothing to combine.
    """
    result_root = Path(result_root_dir)
    logger.info("Aggregating results in: %s", result_root)

    result_files = glob.glob(f"{result_root}/**/results.csv", recursive=True)
    if not result_files:
        logger.error("No result files found!")
        return None

    dataframes: list[pd.DataFrame] = []
    for file in result_files:
        try:
            df = pd.read_csv(file)
        except pd.errors.EmptyDataError:
            logger.warning("Skipping empty file: %s", file)
            continue
        except Exception as exc:
            logger.error("Error reading %s: %s", file, exc)
            continue
        if len(df) > 0:
            dataframes.append(df)
        else:
            logger.warning("Empty file: %s", file)

    if not dataframes:
        logger.warning("No valid CSV files found to combine")
        return None

    combined_df = pd.concat(dataframes, ignore_index=True).sort_values("dataset")

    has_duplicates = len(combined_df) != len(set(combined_df.dataset))
    if has_duplicates:
        duplicate_datasets = combined_df.dataset[combined_df.dataset.duplicated()].tolist()
        logger.warning("Warning: Duplicate datasets found: %s", duplicate_datasets)
        combined_df = combined_df.drop_duplicates(subset=["dataset"], keep="first")
        logger.info("Removed duplicates, %s unique datasets remaining", len(combined_df))

    logger.info("Combined results: %s datasets", len(combined_df))

    # Compare what was produced against the full list of expected result names.
    all_datasets_full_name = get_all_datasets_full_name()
    expected_names = set(all_datasets_full_name)
    completed_experiments_clean = [exp for exp in combined_df.dataset.tolist() if exp in expected_names]
    completed_names = set(completed_experiments_clean)
    missing_or_failed_experiments = [exp for exp in all_datasets_full_name if exp not in completed_names]

    logger.info("=== EXPERIMENT SUMMARY ===")
    logger.info("Total expected datasets: %s", len(all_datasets_full_name))
    logger.info("Completed experiments: %s", len(completed_experiments_clean))
    logger.info("Missing/failed experiments: %s", len(missing_or_failed_experiments))

    output_file = result_root / "all_results.csv"
    combined_df.to_csv(output_file, index=False)
    logger.info("Combined results saved to: %s", output_file)

    return combined_df

### 3.6. Evaluation Harness (`evaluate.py`)

This is the main evaluation logic that iterates over dataset terms, prepares the data, calls the predictor, and gathers metrics.

In [None]:
def construct_evaluation_data(
    dataset_name: str,
    dataset_storage_path: str,
    terms: list[str] | None = None,
    max_windows: int | None = None,
) -> list[tuple[Dataset, DatasetMetadata]]:
    """Build datasets and rich metadata per term for a dataset name.

    Args:
        dataset_name: Dataset name, optionally with a "/<freq>" suffix.
        dataset_storage_path: Root directory of the benchmark data.
        terms: Terms to build; defaults to short, medium and long. Medium and
            long are skipped for datasets not listed in ``MED_LONG_DATASETS``.
        max_windows: Optional cap on the number of evaluation windows.

    Returns:
        One ``(dataset, metadata)`` pair per term that was built. Datasets with
        more than one target dimension are converted to univariate series.
    """
    # Avoid mutable default argument
    if terms is None:
        terms = ["short", "medium", "long"]

    sub_datasets: list[tuple[Dataset, DatasetMetadata]] = []

    if "/" in dataset_name:
        ds_key, ds_freq = dataset_name.split("/")
        ds_key = ds_key.lower()
        ds_key = PRETTY_NAMES.get(ds_key, ds_key)
    else:
        ds_key = dataset_name.lower()
        ds_key = PRETTY_NAMES.get(ds_key, ds_key)
        ds_freq = DATASET_PROPERTIES.get(ds_key, {}).get("frequency")

    for term in terms:
        # Skip medium/long terms for datasets that don't support them
        if term in ("medium", "long") and dataset_name not in MED_LONG_DATASETS:
            continue

        # Probe once to determine dimensionality
        probe_dataset = Dataset(
            name=dataset_name,
            term=term,
            to_univariate=False,
            storage_path=dataset_storage_path,
            max_windows=max_windows,
        )

        to_univariate = probe_dataset.target_dim > 1

        if to_univariate:
            dataset = Dataset(
                name=dataset_name,
                term=term,
                to_univariate=True,
                storage_path=dataset_storage_path,
                max_windows=max_windows,
            )
        else:
            # The probe was built with exactly the arguments a univariate dataset
            # needs, so reuse it instead of loading the data from disk again.
            dataset = probe_dataset

        # Compute metadata
        season_length = get_seasonality(dataset.freq)
        # Prefer the frequency from the name/properties; fall back to the stored one.
        actual_freq = ds_freq if ds_freq else dataset.freq

        metadata = DatasetMetadata(
            full_name=f"{ds_key}/{actual_freq}/{term}",
            key=ds_key,
            freq=actual_freq,
            term=term,
            season_length=season_length,
            target_dim=probe_dataset.target_dim,
            to_univariate=to_univariate,
            prediction_length=dataset.prediction_length,
            windows=dataset.windows,
        )

        sub_datasets.append((dataset, metadata))

    return sub_datasets


def evaluate_datasets(
    predictor: TimeSeriesPredictor,
    dataset: str,
    dataset_storage_path: str,
    terms: list[str] | None = None,
    max_windows: int | None = None,
    batch_size: int = 48,
    max_context_length: int | None = 1024,
    create_plots: bool = False,
    max_plots_per_dataset: int = 10,
) -> list[EvaluationItem]:
    """Evaluate predictor on one dataset across the requested terms.

    For every term that applies to the dataset, the predictor is pointed at the
    term's prediction length and frequency and then scored with
    ``evaluate_model`` on ``METRICS``. Plot generation is not implemented:
    with ``create_plots`` set, a warning is logged and no figures are produced.
    ``max_plots_per_dataset`` is accepted but not used.
    """
    # A fresh list per call instead of a mutable default argument.
    terms = ["short", "medium", "long"] if terms is None else terms

    sub_datasets = construct_evaluation_data(
        dataset_name=dataset,
        dataset_storage_path=dataset_storage_path,
        terms=terms,
        max_windows=max_windows,
    )

    results: list[EvaluationItem] = []
    for i, (sub_dataset, metadata) in enumerate(sub_datasets):
        logger.info(f"Evaluating {i + 1}/{len(sub_datasets)}: {metadata.full_name}")
        logger.info(f" Dataset size: {len(sub_dataset.test_data)}")
        logger.info(f" Frequency: {sub_dataset.freq}")
        logger.info(f" Term: {metadata.term}")
        logger.info(f" Prediction length: {sub_dataset.prediction_length}")
        logger.info(f" Target dimensions: {sub_dataset.target_dim}")
        logger.info(f" Windows: {sub_dataset.windows}")

        # The predictor is shared across datasets; only its context changes.
        predictor.set_dataset_context(
            prediction_length=sub_dataset.prediction_length,
            freq=sub_dataset.freq,
            batch_size=batch_size,
            max_context_length=max_context_length,
        )

        res = evaluate_model(
            model=predictor,
            test_data=sub_dataset.test_data,
            metrics=METRICS,
            axis=None,
            mask_invalid_label=True,
            allow_nan_forecast=False,
            seasonality=metadata.season_length,
        )

        # No plotting helper is available in this notebook, so the list stays empty.
        figs: list[tuple[object, str]] = []
        if create_plots:
            logger.warning(
                "Plotting is enabled but `create_plots_for_dataset` is not defined. Skipping plot generation."
            )

        evaluation_item = EvaluationItem(dataset_metadata=metadata, metrics=res, figures=figs)
        results.append(evaluation_item)

    return results

## 4. Configuration

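Set the parameters for the evaluation run. By default the notebook loads the checkpoint from `models/checkpoint_38M.pth` and the model config from `configs/example.yaml`, both resolved relative to the current working directory.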

In [None]:
# --- Parameters ---
# Assumes the notebook is run from the root of the repo
# NOTE(review): the constants cell resolves `dataset_properties.json` from
# `Path.cwd().parent.parent` (notebook assumed to live in `examples/gift_eval/`),
# which contradicts the assumption above -- confirm the intended working directory.
model_path = Path.cwd() / "models/checkpoint_38M.pth"
config_path = Path.cwd() / "configs/example.yaml"

# --- Datasets and evaluation controls ---
# Use a small subset for testing, e.g., ["m4_weekly"]
datasets_arg = ["all"] # list of dataset names or ["all"].
# Terms to evaluate; medium and long are skipped for datasets that do not support them.
terms = ["short", "medium", "long"]
# None when the environment variable is not set.
dataset_storage_path = os.getenv("GIFT_EVAL_DATASET_STORAGE_PATH")
# Cap on the number of evaluation windows per dataset.
max_windows = 20
# Maximum number of series per model forward pass.
batch_size = 64
# Only the last `max_context_length` steps of each history are fed to the model.
max_context_length = 3072

# --- Output ---
# Results are written after every dataset regardless of this flag; in the main
# loop it only controls an extra log message.
after_each_dataset_flush = True # write CSV as each dataset completes
model_name = "TempoPFN"
output_dir = Path.cwd() / "gift_eval_results" / model_name


# --- Helper Functions ---
def _load_yaml(path: str) -> dict:
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path) as yaml_file:
        parsed = yaml.safe_load(yaml_file)
    return parsed

## 5. Main Evaluation Loop

This cell sets up the predictor and runs the main evaluation loop over all specified datasets.

In [None]:
logger.info("Starting evaluation for model: %s", model_name)

# 1. Validate inputs and build the predictor from a checkpoint
resolved_model_path = Path(model_path)
resolved_config_path = Path(config_path)

if not resolved_model_path.exists():
    logger.error(f"Model checkpoint not found at: {resolved_model_path}")
    logger.error("Please ensure the file exists and you've cloned the repo using Git LFS.")
    raise FileNotFoundError(f"No model checkpoint found. Set `model_path` correctly. Tried: {resolved_model_path}")

# Raise instead of `assert`: assertions are stripped when Python runs with -O.
if not resolved_config_path.exists():
    raise FileNotFoundError(f"Config not found: {config_path}")

# Without a storage path every dataset would fail one by one inside the loop
# below, so stop here with a single, actionable error.
if dataset_storage_path is None:
    raise ValueError(
        "`dataset_storage_path` is not set. Set the GIFT_EVAL_DATASET_STORAGE_PATH "
        "environment variable (a .env file works) or assign `dataset_storage_path` directly."
    )

logger.info("Loading predictor from checkpoint: %s", resolved_model_path)

predictor = TimeSeriesPredictor.from_paths(
    model_path=str(resolved_model_path),
    config_path=str(resolved_config_path),
    ds_prediction_length=1,  # placeholder; set per dataset
    ds_freq="D",  # placeholder; set per dataset
    batch_size=batch_size,
    max_context_length=max_context_length,
)

# 2. Run evaluation loop
datasets_to_run = expand_datasets_arg(datasets_arg)
results_root = Path(output_dir)
failed_datasets: list[str] = []

for ds_name in datasets_to_run:
    try:
        items = evaluate_datasets(
            predictor=predictor,
            dataset=ds_name,
            dataset_storage_path=dataset_storage_path,
            terms=terms,
            max_windows=max_windows,
            batch_size=batch_size,
            max_context_length=max_context_length,
            create_plots=False,  # Set to True if you implement plotting
            max_plots_per_dataset=0,
        )
        write_results_to_disk(
            items=items,
            dataset_name=ds_name,
            output_dir=results_root,
            model_name=model_name,
            create_plots=False,
        )
        if after_each_dataset_flush:
            logger.info("Flushed results for %s", ds_name)
    except Exception:
        # Deliberately broad: one failing dataset must not abort the whole run.
        # `logger.exception` records the message and the traceback in one entry.
        logger.exception("FAILED evaluation for dataset: %s !!!", ds_name)
        failed_datasets.append(ds_name)
        continue  # Continue to the next dataset

if failed_datasets:
    logger.warning(
        "%s of %s datasets failed: %s",
        len(failed_datasets),
        len(datasets_to_run),
        failed_datasets,
    )

print(f"\nEvaluation complete. See results under: {output_dir}")

## 6. Aggregate Results

Finally, we'll aggregate the individual CSV files into a single `all_results.csv` file for easy analysis, following the `gift-eval` convention.

In [None]:
logger.info("Aggregating results from all datasets...")
combined_df = aggregate_results(result_root_dir=output_dir)

# aggregate_results() returns None when there was nothing to combine.
if combined_df is None:
    logger.warning("No results to aggregate. Check that evaluation completed successfully.")
else:
    agg_path = Path(output_dir) / "all_results.csv"
    logger.info("Successfully created aggregated results file: %s", agg_path)
    print(f"\n✅ Aggregated results saved to: {agg_path}")
    print(combined_df.head())