# Source code for torchgeo_bench.main

"""Benchmark script for torchgeo-bench."""

import fcntl
import io
import logging
import os
import time
from collections.abc import Sequence
from dataclasses import dataclass
from statistics import median

import hydra
import numpy as np
import pandas as pd
import torch
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from rich.progress import track
from sklearn.metrics import accuracy_score, average_precision_score
from torch.utils.data import ConcatDataset, DataLoader
from torchgeo.datasets import DatasetNotFoundError

from torchgeo_bench.calibration import (
    apply_temperature,
    compute_calibration_metrics,
    fit_temperature,
)
from torchgeo_bench.datasets import (
    get_bench_dataset_class,
    get_datasets,
    list_datasets,
)
from torchgeo_bench.intrinsic_dim import DegenerateManifoldError, compute_intrinsic_dim
from torchgeo_bench.knn import KNNClassifier
from torchgeo_bench.linear import LogisticRegression
from torchgeo_bench.model_profile import measure_profile
from torchgeo_bench.models.interface import BenchModel
from torchgeo_bench.segmentation_probe import (
    SegmentationProbe,
)
from torchgeo_bench.segmentation_task import SegmentationSolver, SegMetrics
from torchgeo_bench.segmentation_viz import save_segmentation_viz
from torchgeo_bench.utils import extract_features

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Utilities
# ---------------------------------------------------------------------------


def _expand_dataset_list(names: str | Sequence[str]) -> list[str]:
    """Expand dataset names to a flat list.

    Args:
        names: Dataset name(s) — ``"all"``, comma-separated string, or sequence.

    Returns:
        List of individual dataset name strings.
    """
    if isinstance(names, str):
        if names == "all":
            return list_datasets()
        return [n.strip() for n in names.split(",") if n.strip()]
    return list(names)


def _normalize_bands_value(bands: object) -> str:
    """Reduce ``cfg.dataset.bands`` to one stable string.

    The config value may be a preset name such as ``"rgb"`` or ``"all"``,
    an explicit list of band names (``ListConfig`` or ``list[str]``), or
    ``None``.  Collapsing all of these to a single string keeps the resume
    key and the CSV column comparable between runs.

    Args:
        bands: The raw ``cfg.dataset.bands`` value.

    Returns:
        ``"all"`` for ``None``; the string itself for a string; the
        comma-joined ``str()`` of each element for an iterable (e.g.
        ``"red,green,blue,nir"``); ``str(bands)`` for anything that cannot
        be iterated.
    """
    if bands is None:
        return "all"
    if isinstance(bands, str):
        return bands
    try:
        parts = list(map(str, bands))
    except TypeError:
        # Not iterable: fall back to the plain string form.
        text = str(bands)
    else:
        text = ",".join(parts)
    return text


def _completed_run_keys(
    existing_df: pd.DataFrame,
    key_cols: Sequence[str],
    metric_name: str | None = None,
) -> set[tuple[str, ...]]:
    """Build resume keys from existing rows, optionally requiring a metric."""
    df = existing_df
    if metric_name is not None:
        if "metric_name" not in df.columns:
            return set()
        df = df[df["metric_name"].fillna("").astype(str) == metric_name]
    return set(map(tuple, df[list(key_cols)].fillna("").astype(str).to_numpy()))


def _row_key(row: dict, key_cols: Sequence[str]) -> tuple[str, ...]:
    """Build a normalized resume key tuple from a result row dict.

    Missing values (an absent column, ``None``, or a float NaN) are mapped
    to the empty string, which is the same representation
    ``_completed_run_keys`` produces for empty cells via ``fillna("")``.
    Without this, a row holding ``None`` in a key column would yield
    ``"None"`` and never match the key read back from the output table.

    Args:
        row: A flat result row.
        key_cols: Columns whose values together identify one run.

    Returns:
        One string per key column, in ``key_cols`` order.
    """

    def _as_key_part(value: object) -> str:
        """Stringify one cell, treating None/NaN as an empty cell."""
        if value is None:
            return ""
        if isinstance(value, (float, np.floating)) and np.isnan(value):
            return ""
        return str(value)

    return tuple(_as_key_part(row.get(col, "")) for col in key_cols)


def _filter_completed_metric_rows(
    rows: list[dict],
    completed_metrics: dict[str, set[tuple[str, ...]]],
    key_cols: Sequence[str],
) -> list[dict]:
    """Keep only the rows that are not yet recorded for their metric.

    Args:
        rows: Candidate result rows.
        completed_metrics: For each metric name, the resume keys that are
            already present in the output CSV.
        key_cols: Columns that make up the resume key.

    Returns:
        The rows, in their original order, whose resume key is not listed
        under the row's own ``metric_name``.
    """
    nothing_done: set[tuple[str, ...]] = set()

    def _already_recorded(row: dict) -> bool:
        recorded = completed_metrics.get(str(row.get("metric_name", "")), nothing_done)
        return _row_key(row, key_cols) in recorded

    return [row for row in rows if not _already_recorded(row)]


def _profile_metric_names(profile_cfg: DictConfig | None) -> list[str]:
    """List the profile metrics a run must have to count as complete.

    Args:
        profile_cfg: The profile configuration section, or ``None``.

    Returns:
        The three base metric names, followed by the two ``_cpu`` metric
        names when ``profile_cfg.cpu_throughput.enabled`` is truthy.
    """
    required = [
        "throughput_samples_per_sec",
        "latency_ms_per_batch_p50",
        "params_m",
    ]
    if profile_cfg:
        cpu_section = profile_cfg.get("cpu_throughput", {})
    else:
        cpu_section = {}
    if cpu_section.get("enabled", False):
        required += ["throughput_samples_per_sec_cpu", "latency_ms_per_batch_p50_cpu"]
    return required


[docs] def bootstrap_accuracy( y_true: np.ndarray, y_pred: np.ndarray, n_boot: int = 1000, ci: float = 95.0, seed: int | None = None, ) -> tuple[float, float, float]: """Bootstrapped accuracy with confidence interval. Returns (mean, ci_lower, ci_upper).""" rng = np.random.default_rng(seed) n = len(y_true) idx = rng.integers(0, n, size=(n_boot, n)) accs = (y_true[idx] == y_pred[idx]).mean(axis=1).astype(np.float32) acc_mean = float((y_true == y_pred).mean()) lo = (100 - ci) / 2 hi = 100 - lo lower = float(np.percentile(accs, lo)) upper = float(np.percentile(accs, hi)) return acc_mean, lower, upper
def bootstrap_map(
    y_true: np.ndarray,
    y_scores: np.ndarray,
    n_boot: int = 1000,
    ci: float = 95.0,
    seed: int | None = None,
) -> tuple[float, float, float]:
    """Micro-averaged average precision with a percentile-bootstrap interval.

    Args:
        y_true: Ground-truth label indicators, one row per sample.
        y_scores: Predicted scores, aligned with ``y_true``.
        n_boot: Number of bootstrap resamples to draw.
        ci: Confidence level in percent.
        seed: Seed for ``np.random.default_rng``.

    Returns:
        ``(mAP, ci_lower, ci_upper)``. ``mAP`` is computed on the full data.
        If every resample is skipped (or ``n_boot`` is 0), both bounds equal
        the point estimate.
    """
    rng = np.random.default_rng(seed)
    n_samples = len(y_true)
    point_estimate = float(average_precision_score(y_true, y_scores, average="micro"))

    resampled_scores: list[float] = []
    for _ in range(n_boot):
        picks = rng.integers(0, n_samples, size=n_samples)
        labels = y_true[picks]
        # A resample without a single positive label is degenerate, so it
        # contributes nothing to the interval.
        if labels.sum() != 0:
            resampled_scores.append(
                average_precision_score(labels, y_scores[picks], average="micro")
            )

    if len(resampled_scores) == 0:
        return point_estimate, point_estimate, point_estimate

    distribution = np.array(resampled_scores, dtype=np.float32)
    tail = (100 - ci) / 2
    lower = float(np.percentile(distribution, tail))
    upper = float(np.percentile(distribution, 100 - tail))
    return point_estimate, lower, upper
[docs] @dataclass class EvaluationResult: """Container for a single evaluation result row.""" dataset: str method: str # 'knn5', 'linear', or seg head type metric_name: str # 'accuracy', 'micro_mAP', or 'mIoU' (primary metric) metric_value: float ci_lower: float ci_upper: float feature_dim: int best_c: float | None best_lr: float | None best_batch_size: int | None n_train: int n_val: int n_test: int seed: int model: str name: str normalization: str image_size: int | None interpolation: str partition: str bands: str c_range_start: float c_range_stop: float c_range_num: int merge_val: bool bootstrap: int # Segmentation-only metrics (None for classification rows) fw_iou: float | None = None precision: float | None = None recall: float | None = None f1: float | None = None # Calibration metrics for KNN / Linear Probing (None for segmentation rows) ece: float | None = None rms_ce: float | None = None mce: float | None = None # Post temperature-scaling calibration (Linear Probing only; None for KNN/seg) ece_ts: float | None = None rms_ce_ts: float | None = None mce_ts: float | None = None temperature: float | None = None calibration_n_bins: int | None = None
[docs] def to_row(self) -> dict: """Convert to a flat dictionary suitable for CSV/DataFrame export.""" return self.__dict__.copy()
def embed_split(
    model: BenchModel,
    dataloader: DataLoader,
    device: torch.device,
    verbose: bool,
) -> tuple[np.ndarray, np.ndarray]:
    """Run one data split through the model and collect features and labels.

    Thin wrapper over ``extract_features`` that always passes
    ``transforms=None``.

    Args:
        model: Backbone to embed with.
        dataloader: Loader for the split.
        device: Device handed to ``extract_features``.
        verbose: Forwarded to ``extract_features``.

    Returns:
        Whatever ``extract_features`` returns; per the annotation, a
        ``(features, labels)`` pair of arrays.
    """
    embedded = extract_features(
        model,
        dataloader,
        device,
        transforms=None,
        verbose=verbose,
    )
    return embedded
def evaluate_knn(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_test: np.ndarray,
    y_test: np.ndarray,
    seed: int,
    n_bootstrap: int,
    verbose: bool = False,
    device: str = "cpu",
    n_neighbors: int = 5,
    calibration_n_bins: int | None = None,
) -> tuple[float, float, float, dict[str, float], int]:
    """Fit a KNN classifier on train features and score it on test features.

    The task is treated as multi-label when ``y_train`` is 2-D (metric:
    micro mAP) and as single-label otherwise (metric: accuracy).

    Args:
        x_train: Training features.
        y_train: Training labels; its dimensionality selects the task type.
        x_test: Test features.
        y_test: Test labels.
        seed: Seed for the bootstrap resampling.
        n_bootstrap: Number of bootstrap resamples for the CI.
        verbose: Emit progress log lines.
        device: Device passed to ``KNNClassifier``.
        n_neighbors: Number of neighbours ``k``.
        calibration_n_bins: Bin count for the calibration metrics;
            defaults to ``n_neighbors + 1`` when ``None``.

    Returns:
        ``(metric, ci_lower, ci_upper, calibration, n_bins)`` where
        ``calibration`` holds ``ece`` / ``rms_ce`` / ``mce`` computed from
        ``predict_proba`` on the test split and ``n_bins`` is the bin count
        actually used.
    """
    n_bins = n_neighbors + 1 if calibration_n_bins is None else calibration_n_bins
    multi_label = y_train.ndim == 2
    # Log the k that is really in use; the messages used to say "KNN5"
    # regardless of ``n_neighbors``.
    knn_tag = f"KNN{n_neighbors}"

    clf = KNNClassifier(n_neighbors=n_neighbors, device=device, use_fp16=False)
    clf.fit(x_train, y_train)

    if multi_label:
        if verbose:
            logger.info(
                f"[KNN] Fit {knn_tag} multilabel (train={len(x_train)}, test={len(x_test)})"
            )
        y_scores = clf.predict_proba(x_test)
        metric, lo, hi = bootstrap_map(y_test, y_scores, n_boot=n_bootstrap, seed=seed)
        if verbose:
            logger.info(f"[KNN] Test micro_mAP={metric:.4f} (CI {lo:.4f}-{hi:.4f})")
    else:
        if verbose:
            logger.info(
                f"[KNN] Fit {knn_tag} (train={len(x_train)}, test={len(x_test)}, "
                f"boot={n_bootstrap})"
            )
        preds = clf.predict(x_test)
        y_scores = clf.predict_proba(x_test)
        metric, lo, hi = bootstrap_accuracy(y_test, preds, n_boot=n_bootstrap, seed=seed)
        if verbose:
            logger.info(f"[KNN] Test accuracy={metric:.4f} (CI {lo:.4f}-{hi:.4f})")

    calibration = compute_calibration_metrics(
        y_test, y_scores, multi_label=multi_label, n_bins=n_bins
    )
    if verbose:
        logger.info(
            f"[KNN] Calibration (n_bins={n_bins}) ECE={calibration['ece']:.4f} "
            f"RMS-CE={calibration['rms_ce']:.4f} MCE={calibration['mce']:.4f}"
        )
    return metric, lo, hi, calibration, n_bins
def evaluate_logistic(
    x_train: np.ndarray,
    y_train: np.ndarray,
    x_val: np.ndarray,
    y_val: np.ndarray,
    x_test: np.ndarray,
    y_test: np.ndarray,
    c_values: Sequence[float],
    seed: int,
    n_bootstrap: int,
    merge_val: bool,
    device: str,
    verbose: bool = False,
    calibration_n_bins: int = 15,
    temp_scale: bool = True,
) -> tuple[float, float, float, float, dict[str, float], dict[str, float | None]]:
    """Pick ``C`` on the validation split, refit, and score the test split.

    The task is treated as multi-label when ``y_train`` is 2-D and as
    single-label otherwise.

    Args:
        x_train, y_train: Training features and labels.
        x_val, y_val: Validation features and labels, used to select ``C``
            and to fit the temperature.
        x_test, y_test: Test features and labels.
        c_values: Candidate ``C`` values, tried in order; the first one with
            the strictly highest validation score wins.
        seed: Random state for the classifiers and the bootstrap.
        n_bootstrap: Number of bootstrap resamples for the CI.
        merge_val: Refit the final model on train + val instead of train.
        device: Device passed to ``LogisticRegression``.
        verbose: Emit progress log lines and show a progress bar.
        calibration_n_bins: Bin count for the calibration metrics.
        temp_scale: Also compute temperature-scaled calibration.

    Returns:
        ``(metric, ci_lower, ci_upper, best_c, calibration, calibration_ts)``.
        ``calibration`` comes from raw ``predict_proba`` on the test split;
        ``calibration_ts`` holds the temperature-scaled metrics plus the
        fitted ``temperature`` (all ``None`` when ``temp_scale=False``).
    """
    multi_label = y_train.ndim == 2
    label_tag = "LogReg-ML" if multi_label else "LogReg"

    def _labels_to_tensor(labels: np.ndarray) -> torch.Tensor:
        """Float targets for multi-label, integer class ids otherwise."""
        tensor = torch.from_numpy(labels)
        return tensor.float() if multi_label else tensor.long()

    def _new_classifier(c: float, max_iter: int):
        """Build a classifier that differs only in ``C`` and ``max_iter``."""
        return LogisticRegression(
            C=c,
            max_iter=max_iter,
            tol=1e-6,
            random_state=seed,
            device=device,
            multi_label=multi_label,
        )

    x_train_tensor = torch.from_numpy(x_train)
    x_val_tensor = torch.from_numpy(x_val)
    x_test_tensor = torch.from_numpy(x_test)
    y_train_tensor = _labels_to_tensor(y_train)

    candidates = c_values
    if verbose:
        logger.info(
            f"[{label_tag}] C sweep start over {len(c_values)} values "
            f"(train={len(x_train)}, val={len(x_val)})"
        )
        candidates = track(c_values, description="C values")

    # --- C sweep: fit on train, score on val --------------------------------
    best_c: float | None = None
    best_val_score = -1.0
    for idx, c in enumerate(candidates):
        sweep_model = _new_classifier(c, 2000)
        sweep_model.fit(x_train_tensor, y_train_tensor)
        if multi_label:
            val_scores = sweep_model.predict_proba(x_val_tensor)
            val_metric = float(average_precision_score(y_val, val_scores, average="micro"))
        else:
            val_metric = accuracy_score(y_val, sweep_model.predict(x_val_tensor))
        # Log the first ten candidates, then every fiftieth.
        if verbose and (idx < 10 or idx % 50 == 0):
            logger.info(f"[{label_tag}] C={c:.4g} val_score={val_metric:.4f}")
        if val_metric > best_val_score:
            best_val_score, best_c = val_metric, c

    assert best_c is not None, "C sweep failed to select a value"
    if verbose:
        logger.info(f"[{label_tag}] Best C={best_c:.4g} val_score={best_val_score:.4f}")

    # --- Final fit with the selected C ---------------------------------------
    if merge_val:
        x_final = torch.from_numpy(np.concatenate([x_train, x_val], axis=0))
        y_final = _labels_to_tensor(np.concatenate([y_train, y_val], axis=0))
    else:
        x_final = x_train_tensor
        y_final = y_train_tensor

    final_model = _new_classifier(best_c, 4000)
    final_model.fit(x_final, y_final)

    if multi_label:
        test_scores = final_model.predict_proba(x_test_tensor)
        metric, lo, hi = bootstrap_map(y_test, test_scores, n_boot=n_bootstrap, seed=seed)
    else:
        test_preds = final_model.predict(x_test_tensor)
        test_scores = final_model.predict_proba(x_test_tensor)
        metric, lo, hi = bootstrap_accuracy(y_test, test_preds, n_boot=n_bootstrap, seed=seed)

    calibration = compute_calibration_metrics(
        y_test, test_scores, multi_label=multi_label, n_bins=calibration_n_bins
    )

    # --- Optional temperature scaling ----------------------------------------
    if temp_scale:
        # The temperature is fitted on validation logits and then applied to
        # the test logits before calibration is recomputed. With
        # merge_val=True the final model has already trained on val, but T
        # is a single scalar so the resulting leakage is minimal.
        val_logits = final_model.decision_function(x_val_tensor)
        test_logits = final_model.decision_function(x_test_tensor)
        temperature = fit_temperature(val_logits, y_val, multi_label=multi_label)
        scaled_scores = apply_temperature(test_logits, temperature, multi_label=multi_label)
        scaled_cal = compute_calibration_metrics(
            y_test, scaled_scores, multi_label=multi_label, n_bins=calibration_n_bins
        )
        calibration_ts: dict[str, float | None] = {
            "ece_ts": scaled_cal["ece"],
            "rms_ce_ts": scaled_cal["rms_ce"],
            "mce_ts": scaled_cal["mce"],
            "temperature": temperature,
        }
    else:
        calibration_ts = {
            "ece_ts": None,
            "rms_ce_ts": None,
            "mce_ts": None,
            "temperature": None,
        }

    if verbose:
        logger.info(
            f"[{label_tag}] Test score={metric:.4f} (CI {lo:.4f}-{hi:.4f}) "
            f"using C={best_c:.4g}; train_final={len(x_final)} test={len(x_test)}"
        )
        logger.info(
            f"[{label_tag}] Calibration (n_bins={calibration_n_bins}) "
            f"ECE={calibration['ece']:.4f} "
            f"RMS-CE={calibration['rms_ce']:.4f} MCE={calibration['mce']:.4f}"
        )
        if temp_scale:
            logger.info(
                f"[{label_tag}] Post-TS T={calibration_ts['temperature']:.3f} "
                f"ECE={calibration_ts['ece_ts']:.4f} "
                f"RMS-CE={calibration_ts['rms_ce_ts']:.4f} "
                f"MCE={calibration_ts['mce_ts']:.4f}"
            )
    return metric, lo, hi, float(best_c), calibration, calibration_ts
def _make_seg_dataloaders(
    train_dataset: torch.utils.data.Dataset,
    val_dataset: torch.utils.data.Dataset,
    test_loader: DataLoader,
    batch_size: int,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build the train, val, and train+val loaders for segmentation probing.

    ``num_workers`` and ``pin_memory`` are copied from ``test_loader``; the
    train and train+val loaders shuffle, the val loader does not.

    Returns:
        ``(train_loader, val_loader, train_val_loader)``.
    """
    shared = dict(
        batch_size=batch_size,
        num_workers=test_loader.num_workers,
        pin_memory=test_loader.pin_memory,
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **shared)
    val_loader = DataLoader(val_dataset, shuffle=False, **shared)
    merged_dataset = ConcatDataset([train_dataset, val_dataset])
    train_val_loader = DataLoader(merged_dataset, shuffle=True, **shared)
    return train_loader, val_loader, train_val_loader


def _build_seg_probe_and_solver(
    model: torch.nn.Module,
    num_classes: int,
    eval_cfg: DictConfig,
    device: torch.device,
    lr: float,
) -> tuple[SegmentationProbe, SegmentationSolver]:
    """Create the frozen-backbone segmentation probe and its solver.

    Args:
        model: Backbone to probe.
        num_classes: Number of segmentation classes.
        eval_cfg: Evaluation config; its ``segmentation`` section supplies
            the layers, head type, criterion, and LR scheduler.
        device: Device the solver runs on (passed as a string).
        lr: Learning rate for the solver.

    Returns:
        ``(probe, solver)``.

    Raises:
        ValueError: If ``eval.segmentation.layers`` is empty.
    """
    seg_cfg = eval_cfg.segmentation
    layer_names = list(seg_cfg.layers)
    if len(layer_names) == 0:
        raise ValueError(
            "Segmentation evaluation requires eval.segmentation.layers to name "
            "spatial backbone layers. Refusing to probe the global backbone output."
        )

    probe = SegmentationProbe(
        backbone=model,
        layer_names=layer_names,
        num_classes=num_classes,
        head_type=seg_cfg.head_type,
        freeze_backbone=True,
    )
    criterion = instantiate(seg_cfg.criterion)
    solver = SegmentationSolver(
        model=probe,
        num_classes=num_classes,
        lr=lr,
        device=str(device),
        criterion=criterion,
        lr_scheduler=seg_cfg.get("lr_scheduler", "cosine"),
        ignore_index=_resolve_segmentation_ignore_index(seg_cfg, criterion),
    )
    return probe, solver


def _resolve_segmentation_ignore_index(seg_cfg: DictConfig, criterion: torch.nn.Module) -> int:
    """Resolve the ignore index shared by segmentation loss and metrics.

    Precedence: an explicit ``seg_cfg.ignore_index`` wins, then the
    criterion's own ``ignore_index`` attribute, then ``255``.

    Raises:
        ValueError: If both the config and the criterion specify an ignore
            index and the two disagree.
    """
    configured = seg_cfg.get("ignore_index", None)
    from_criterion = getattr(criterion, "ignore_index", None)

    if configured is not None:
        if from_criterion is not None and int(from_criterion) != int(configured):
            raise ValueError(
                "Segmentation ignore_index mismatch: "
                f"eval.segmentation.ignore_index={configured} but "
                f"criterion.ignore_index={from_criterion}."
            )
        return int(configured)

    if from_criterion is None:
        return 255
    return int(from_criterion)
[docs] def evaluate_intrinsic_dim( splits: dict[str, np.ndarray], estimators: Sequence[str], selected_splits: Sequence[str], device: str | None, max_samples: int | None, seed: int, common_meta: dict, feature_dim: int, n_counts: dict[str, int], verbose: bool = False, ) -> list[dict]: """Compute intrinsic-dimension metrics over selected splits and return CSV rows. Each (split, estimator) yields one row with ``method="intrinsic_dim"`` and ``metric_name=f"id_{estimator}_{split}"``. """ rows: list[dict] = [] for split_name in selected_splits: if split_name not in splits: logger.warning(f"[intrinsic-dim] unknown split '{split_name}', skipping") continue X = splits[split_name] if verbose: logger.info( f"[intrinsic-dim] split={split_name} X{X.shape} " f"estimators={list(estimators)} device={device}" ) # Per-estimator isolation: compute_intrinsic_dim raises on the # *first* non-finite dimension (by design — surfaces fp32 bugs). # During a long sweep that aborts the whole task and we lose KNN # /linear/profile rows too. Run each estimator separately so a # genuinely-degenerate feature manifold (e.g. terramind features # with d1==d2 collapsing TwoNN's log-ratio) only loses *that* # estimator's row, not the rest of the task. dims: dict[str, float] = {} for est_name in estimators: try: dims.update( compute_intrinsic_dim( X, estimators=[est_name], device=device, max_samples=max_samples, seed=seed, ) ) except DegenerateManifoldError as exc: logger.warning( f"[intrinsic-dim] {est_name} split={split_name} model={common_meta.get('model')} " f"dataset={common_meta.get('dataset')} bands={common_meta.get('bands')} " f"norm={common_meta.get('normalization')}: degenerate features, writing NaN. 
" f"Diagnostic: {exc}" ) dims[est_name] = float("nan") for est_name, dim in dims.items(): rows.append( EvaluationResult( **common_meta, method="intrinsic_dim", metric_name=f"id_{est_name}_{split_name}", metric_value=float(dim), ci_lower=0.0, ci_upper=0.0, feature_dim=feature_dim, best_c=None, best_lr=None, best_batch_size=None, n_train=n_counts.get("train", 0), n_val=n_counts.get("val", 0), n_test=n_counts.get("test", 0), ).to_row() ) return rows
def evaluate_profile(
    model: BenchModel,
    sample_loader: DataLoader,
    device: torch.device,
    n_warmup: int,
    n_measure: int,
    common_meta: dict,
    feature_dim: int,
    n_counts: dict[str, int],
    cpu_throughput_enabled: bool = False,
    cpu_batch_size: int = 8,
    cpu_n_warmup: int = 1,
    cpu_n_measure: int = 5,
    cpu_time_budget_s: float = 300.0,
) -> list[dict]:
    """Profile the backbone and return one CSV row per metric.

    Every row has ``method="profile"``. The metrics are whatever
    ``measure_profile`` returns for the first batch of ``sample_loader``.
    When ``cpu_throughput_enabled`` is set, a short CPU pass is run as well
    and its throughput / latency are added under the same names with a
    ``_cpu`` suffix. That pass is wall-clock-budgeted via
    ``cpu_time_budget_s`` so heavyweight backbones cannot stall the job.

    Args:
        model: Backbone to profile.
        sample_loader: Loader whose first batch (key ``"image"``) is the
            profiling input.
        device: Device the sample is moved to for the main measurement.
        n_warmup: Warmup iterations forwarded to ``measure_profile``.
        n_measure: Timed iterations forwarded to ``measure_profile``.
        common_meta: Fields shared by every emitted row.
        feature_dim: Feature dimensionality recorded on every row.
        n_counts: Sample counts keyed by ``"train"`` / ``"val"`` / ``"test"``;
            missing keys are recorded as 0.
        cpu_throughput_enabled: Whether to run the extra CPU pass.
        cpu_batch_size: Maximum batch size for the CPU pass.
        cpu_n_warmup: Warmup iterations for the CPU pass.
        cpu_n_measure: Timed iterations for the CPU pass.
        cpu_time_budget_s: Wall-clock budget (seconds) for the CPU pass.

    Returns:
        One row per metric whose value is not ``None``.
    """
    # If the loader is broken there's nothing meaningful to profile; let the
    # error propagate so the failure surfaces in SLURM logs instead of
    # silently appending zero rows and "succeeding" the task.
    sample = next(iter(sample_loader))["image"].to(device)
    metrics = measure_profile(model, sample, device, n_warmup=n_warmup, n_measure=n_measure)

    if cpu_throughput_enabled:
        cpu_metrics = _measure_cpu_throughput(
            model,
            sample,
            cpu_batch_size=cpu_batch_size,
            n_warmup=cpu_n_warmup,
            n_measure=cpu_n_measure,
            time_budget_s=cpu_time_budget_s,
        )
        for key, value in cpu_metrics.items():
            metrics[key + "_cpu"] = value

    rows: list[dict] = []
    for name, value in metrics.items():
        if value is None:
            # value is None only when the underlying probe is structurally
            # unavailable (e.g. CPU device → no peak_gpu_mem, or the CPU
            # pass aborted via the wall-clock budget). Logged inside the
            # measurement helpers; skip the row.
            continue
        rows.append(
            EvaluationResult(
                **common_meta,
                method="profile",
                metric_name=name,
                metric_value=float(value),
                ci_lower=0.0,
                ci_upper=0.0,
                feature_dim=feature_dim,
                best_c=None,
                best_lr=None,
                best_batch_size=None,
                n_train=n_counts.get("train", 0),
                n_val=n_counts.get("val", 0),
                n_test=n_counts.get("test", 0),
            ).to_row()
        )
    return rows


def _measure_cpu_throughput(
    model: BenchModel,
    sample: torch.Tensor,
    *,
    cpu_batch_size: int,
    n_warmup: int,
    n_measure: int,
    time_budget_s: float,
) -> dict[str, float | None]:
    """Run a wall-clock-budgeted CPU pass and return throughput and latency.

    The model and the first ``cpu_batch_size`` rows of ``sample`` are moved to
    CPU for the duration; the model is always moved back afterwards so the
    rest of the pipeline can keep using its original device.

    If the warmup exceeds ``time_budget_s`` both metrics are returned as
    ``None`` with a warning rather than wasting cluster hours — a documented
    soft-fail keyed on a specific named condition, not a generic swallow.
    The measurement loop stops early once the budget is spent and reports on
    the batches it managed to time.

    Args:
        model: Backbone to time.
        sample: Input batch; only its first ``cpu_batch_size`` rows are used.
        cpu_batch_size: Maximum number of rows per timed forward pass.
        n_warmup: Untimed forward passes before measuring.
        n_measure: Maximum number of timed forward passes.
        time_budget_s: Wall-clock budget shared by warmup and measurement.

    Returns:
        ``{"throughput_samples_per_sec": ..., "latency_ms_per_batch_p50": ...}``
        with both values ``None`` when nothing could be measured.
    """
    unavailable: dict[str, float | None] = {
        "throughput_samples_per_sec": None,
        "latency_ms_per_batch_p50": None,
    }

    cpu_dev = torch.device("cpu")
    # rcf/imagestats baselines have no parameters; use the input sample's
    # device as the restoration target since model.to() is a no-op anyway.
    first_param = next(iter(model.parameters()), None)
    orig_dev = first_param.device if first_param is not None else sample.device

    cpu_sample = sample[:cpu_batch_size].detach().to(cpu_dev)
    # The slice can be shorter than cpu_batch_size when the loader batch is
    # small; throughput must count the rows actually pushed through the model.
    rows_per_batch = int(cpu_sample.shape[0])

    try:
        # Inside the try so a failed/partial move is still undone below.
        model.to(cpu_dev)
        t0 = time.perf_counter()
        with torch.inference_mode():
            for _ in range(n_warmup):
                model(cpu_sample)
                # Checked after every pass so a slow model cannot burn
                # n_warmup times the budget before we give up.
                if time.perf_counter() - t0 > time_budget_s:
                    logger.warning(
                        f"[profile] CPU warmup exceeded {time_budget_s}s budget on "
                        f"{type(model).__name__}; skipping CPU throughput."
                    )
                    return dict(unavailable)

            per_batch_ms: list[float] = []
            t_loop = time.perf_counter()
            for _ in range(n_measure):
                tb = time.perf_counter()
                model(cpu_sample)
                per_batch_ms.append((time.perf_counter() - tb) * 1000.0)
                if time.perf_counter() - t0 > time_budget_s:
                    break
            elapsed = time.perf_counter() - t_loop

        seen = len(per_batch_ms)
        if seen == 0 or elapsed <= 0.0:
            return dict(unavailable)
        return {
            "throughput_samples_per_sec": (rows_per_batch * seen) / elapsed,
            "latency_ms_per_batch_p50": median(per_batch_ms),
        }
    finally:
        model.to(orig_dev)
def evaluate_segmentation(
    model: torch.nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    test_loader: DataLoader,
    cfg: DictConfig,
    num_classes: int,
    device: torch.device,
    collect_preds: bool = False,
) -> "tuple[SegMetrics, int, float | None, int | None, torch.Tensor | None]":
    """Train a segmentation probe on the backbone and score it on the test split.

    The probe and its solver come from ``_build_seg_probe_and_solver``. When
    feature caching is enabled in the config and the probe reports a frozen
    backbone, features for all three splits are extracted once up front and
    the solver trains/evaluates on the cached tensors; otherwise it consumes
    the loaders directly.

    Args:
        model: Backbone model handed to the probe builder.
        train_loader: Training DataLoader.
        val_loader: Validation DataLoader.
        test_loader: Test DataLoader.
        cfg: Full Hydra config; ``cfg.model.eval`` (if set) is merged over
            ``cfg.eval``.
        num_classes: Number of segmentation classes.
        device: Torch device.
        collect_preds: If True, the solver is asked to also return test
            predictions.

    Returns:
        Tuple ``(metrics, feature_dim, None, None, preds_or_None)`` where
        ``feature_dim`` is ``sum(probe.channels_list)`` and ``preds_or_None``
        is ``None`` when ``collect_preds`` is False.

    Raises:
        ValueError: If the merged eval config has no ``segmentation`` section.
    """
    # Model-specific eval settings, when present, override the global ones.
    has_model_eval = "eval" in cfg.model and cfg.model.eval is not None
    eval_cfg = OmegaConf.merge(cfg.eval, cfg.model.eval) if has_model_eval else cfg.eval

    if "segmentation" not in eval_cfg:
        raise ValueError("Segmentation evaluation config missing for the model.")
    seg_cfg = eval_cfg.segmentation

    epochs = seg_cfg.epochs
    use_cache = seg_cfg.get("cache_features", True)
    # Anything other than the literal "float16" falls back to float32.
    if seg_cfg.get("cache_dtype", "float16") == "float16":
        cache_dtype = torch.float16
    else:
        cache_dtype = torch.float32

    probe, solver = _build_seg_probe_and_solver(model, num_classes, eval_cfg, device, seg_cfg.lr)

    if not (use_cache and probe.freeze_backbone):
        # Uncached path: the solver iterates the loaders itself.
        solver.fit(
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=epochs,
            verbose=cfg.verbose,
        )
        eval_result = solver.evaluate(test_loader, collect_preds=collect_preds)
    else:
        logger.info("Caching backbone features for train and val splits...")
        train_cache = probe.extract_segmentation_features(train_loader, cache_dtype=cache_dtype)
        val_cache = probe.extract_segmentation_features(val_loader, cache_dtype=cache_dtype)
        test_cache = probe.extract_segmentation_features(test_loader, cache_dtype=cache_dtype)

        cached_batch_size = seg_cfg.get("batch_size", 64)
        solver.fit_cached(
            train_cache=train_cache,
            val_cache=val_cache,
            batch_size=cached_batch_size,
            epochs=epochs,
            verbose=cfg.verbose,
        )
        eval_result = solver.evaluate_cached(
            test_cache,
            batch_size=cached_batch_size,
            collect_preds=collect_preds,
        )

    # With collect_preds the solver returns a (metrics, preds) pair, otherwise
    # just the metrics.
    metrics, preds = eval_result if collect_preds else (eval_result, None)
    return metrics, sum(probe.channels_list), None, None, preds
# --------------------------------------------------------------------------- # Main # ---------------------------------------------------------------------------
def append_rows_atomic(path: str, rows: list[dict]) -> None:
    """Append rows to a CSV under an exclusive advisory lock, healing schema drift.

    Behavior:

    - ``rows`` empty: no-op; the file is neither created nor touched.
    - Empty/missing file: writes the header derived from ``rows`` and the rows.
    - Existing file whose header matches the new rows' columns exactly
      (names and order): appends the rows without rewriting the header.
      Only the header line of the existing file is parsed on this fast path.
    - Existing file with a different schema (e.g. ``EvaluationResult`` gained
      a field since the file was first written): the file is rewritten with
      the unioned schema — new columns first, then any columns only the old
      file had — so every value lives under a named column instead of being
      silently stuffed into an unnamed position.

    The write is flushed and fsynced before the lock is released. The schema
    rewrite truncates the file in place, so it is protected against other
    lock-respecting writers but not against a crash mid-rewrite.

    Args:
        path: Output CSV path; created if missing.
        rows: List of dicts to append. All dicts should share the same keys.
    """
    if not rows:
        return

    df_local = pd.DataFrame(rows)
    new_columns = list(df_local.columns)

    # Explicit 0o666 (still subject to the umask): os.open's default mode is
    # 0o777, which would create the CSV with execute bits set.
    fd = os.open(path, os.O_RDWR | os.O_CREAT, 0o666)
    with os.fdopen(fd, "r+", closefd=True) as f:
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)
        try:
            buf = io.StringIO()
            f.seek(0, os.SEEK_END)
            if f.tell() == 0:
                # Fresh file: header plus rows.
                df_local.to_csv(buf, header=True, index=False)
            else:
                # Parse only the header to decide which path to take; reading
                # the whole file on every append is wasted work when the
                # schema matches.
                f.seek(0)
                existing_columns = list(pd.read_csv(f, nrows=0).columns)
                if existing_columns == new_columns:
                    df_local.to_csv(buf, header=False, index=False)
                    f.seek(0, os.SEEK_END)
                else:
                    f.seek(0)
                    existing_df = pd.read_csv(f)
                    extra = [c for c in existing_df.columns if c not in df_local.columns]
                    ordered = new_columns + extra
                    combined = pd.concat(
                        [existing_df, df_local], ignore_index=True, sort=False
                    ).reindex(columns=ordered)
                    logger.warning(
                        "CSV schema drift detected at %s: existing columns %s, "
                        "new columns %s. Rewriting with unioned schema %s.",
                        path,
                        list(existing_df.columns),
                        new_columns,
                        ordered,
                    )
                    f.seek(0)
                    f.truncate()
                    combined.to_csv(buf, header=True, index=False)
            f.write(buf.getvalue())
            f.flush()
            os.fsync(f.fileno())
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)
@hydra.main(config_path="conf", config_name="config", version_base=None)
def main(cfg: DictConfig) -> None:
    """Run the benchmark pipeline for all configured datasets and models.

    For every dataset named in ``cfg.dataset.names`` the configured model is
    instantiated and evaluated: segmentation datasets get a segmentation
    probe (one ``mIoU`` row, optionally with visualisations); classification
    datasets get KNN, linear-probe, intrinsic-dimension and profiling rows,
    each individually skippable. Rows are appended to ``cfg.output`` after
    each dataset so finished work survives a later failure.

    With ``cfg.resume`` set, (dataset, method, model, config) combinations
    already present in the output CSV are skipped; intrinsic-dim and profile
    rows are additionally filtered per metric name.

    Args:
        cfg: Hydra config supplied by the ``@hydra.main`` decorator.
    """
    torch.manual_seed(cfg.seed)
    np.random.seed(cfg.seed)

    dataset_names = _expand_dataset_list(cfg.dataset.names)
    device = torch.device(cfg.device)
    output_path = cfg.output
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    all_rows: list[dict] = []

    # A model-specific c_range, when provided, overrides the global one.
    model_eval = cfg.model.get("eval", None) if "eval" in cfg.model else None
    if model_eval is not None and model_eval.get("c_range", None) is not None:
        c_start, c_stop, c_num = model_eval.c_range
    else:
        c_start, c_stop, c_num = cfg.eval.c_range
    c_values = 10 ** np.linspace(float(c_start), float(c_stop), int(c_num))
    c_values_list = [float(v) for v in c_values.tolist()]

    key_cols = (
        "dataset",
        "method",
        "model",
        "name",
        "normalization",
        "image_size",
        "interpolation",
        "partition",
        "bands",
    )

    completed_runs: set[tuple[str, ...]] = set()
    completed_metrics: dict[str, set[tuple[str, ...]]] = {}
    if cfg.resume and os.path.exists(output_path):
        existing_df = pd.read_csv(output_path)
        # Older CSVs may predate some key columns; fill them so keys line up.
        for col in key_cols:
            if col not in existing_df.columns:
                existing_df[col] = ""
        completed_runs = _completed_run_keys(existing_df, key_cols)
        if "metric_name" in existing_df.columns:
            completed_metrics = {
                str(metric): _completed_run_keys(existing_df, key_cols, str(metric))
                for metric in existing_df["metric_name"].dropna().unique()
            }
        logger.info(f"Resume mode: Found {len(completed_runs)} existing results in {cfg.output}")
        logger.info("Will skip already-computed (dataset, method, model, config) combinations.")

    # Selectable input-normalisation strategy; recorded in the CSV so
    # ablations across strategies are distinguishable.
    normalization = str(getattr(cfg.dataset, "normalization", "bandspec_zscore"))
    # Read the band selection once, with one default, so the CSV value, the
    # dataset loader and the BandSpec resolution below can never disagree.
    bands_cfg = getattr(cfg.dataset, "bands", "rgb")
    bands_value = _normalize_bands_value(bands_cfg)

    # Everything below is independent of the dataset, so build it once.
    config_tuple = (
        normalization,
        str(getattr(cfg.dataset, "image_size", None)),
        getattr(cfg.dataset, "interpolation", "bilinear"),
        cfg.dataset.partition,
        bands_value,
    )
    # Merge model-specific eval config early so resume key and result rows
    # reflect the actual head_type used, not the global default.
    eval_cfg_merged = OmegaConf.merge(
        cfg.eval,
        cfg.model.eval if "eval" in cfg.model and cfg.model.eval is not None else {},
    )
    knn_k = int(getattr(eval_cfg_merged, "knn_k", 5))
    seg_method = f"seg-{eval_cfg_merged.segmentation.head_type}"

    def run_key(dataset: str, method: str) -> tuple:
        """Build the resume key for one (dataset, method) under this config."""
        return (dataset, method, cfg.model._target_, cfg.model.name, *config_tuple)

    for ds_name in track(dataset_names, description="Datasets"):
        try:
            ds_cls = get_bench_dataset_class(ds_name)
        except KeyError:
            logger.warning(f"Skipping dataset {ds_name} (not in registry)")
            continue

        knn_key = run_key(ds_name, f"knn{knn_k}")
        linear_key = run_key(ds_name, "linear")
        seg_key = run_key(ds_name, seg_method)
        id_key = run_key(ds_name, "intrinsic_dim")
        profile_key = run_key(ds_name, "profile")

        is_segmentation = ds_cls.task == "segmentation"
        is_multilabel = ds_cls.multilabel
        num_classes = ds_cls.num_classes

        # Resume check for segmentation — done before any data is loaded so
        # an already-finished dataset costs nothing.
        if is_segmentation and cfg.resume and seg_key in completed_runs:
            if cfg.verbose:
                logger.info(f"[{ds_name}] Skipping segmentation (already computed)")
            continue

        try:
            result = get_datasets(
                dataset_name=ds_name,
                partition_name=cfg.dataset.partition,
                batch_size=cfg.dataset.batch_size,
                num_workers=int(cfg.dataset.get("num_workers", 8)),
                return_val=True,
                image_size=getattr(cfg.dataset, "image_size", None),
                interpolation=getattr(cfg.dataset, "interpolation", "bilinear"),
                bands=bands_cfg,
            )
        except (FileNotFoundError, DatasetNotFoundError) as exc:
            logger.warning(f"Skipping dataset {ds_name} (data not found: {exc})")
            continue

        if result is None or not isinstance(result, tuple) or len(result) != 4:
            logger.warning(f"Skipping dataset {ds_name} (unexpected return)")
            continue

        train_dataset, train_loader, val_loader, test_loader = result
        num_channels = train_dataset[0]["image"].shape[0]

        # Build the BandSpec list that matches the actual loaded channels.
        bench_for_bands = ds_cls()
        if bands_cfg == "rgb":
            bands_resolved = tuple(bench_for_bands.rgb_bands)
        elif bands_cfg in ("all", None):
            bands_resolved = None
        else:
            bands_resolved = tuple(bands_cfg)
        bands_list = bench_for_bands.select_band_specs(bands_resolved)
        assert len(bands_list) == num_channels, (
            f"BandSpec count {len(bands_list)} != tensor channel count {num_channels} "
            f"for dataset {ds_name}; sample-level canonicalization may have changed shape."
        )

        # Instantiate Backbone — pass `bands` post-hoc so Hydra never tries
        # to OmegaConf-ify the BandSpec list. `_convert_="object"` keeps
        # the rest of the model config as plain Python primitives.
        is_rcf_empirical = (
            hasattr(cfg.model, "mode")
            and str(cfg.model._target_).endswith("RCFBench")
            and str(cfg.model.mode) == "empirical"
        )
        instantiate_kwargs: dict = {
            "bands": bands_list,
            "normalization": normalization,
            "_convert_": "object",
        }
        if is_rcf_empirical:
            instantiate_kwargs["dataset"] = train_dataset
        model: BenchModel = instantiate(cfg.model, **instantiate_kwargs)
        model.to(device).eval()

        common_meta = {
            "dataset": ds_name,
            "seed": cfg.seed,
            "model": cfg.model._target_,
            "name": cfg.model.name,
            "normalization": normalization,
            "image_size": getattr(cfg.dataset, "image_size", None),
            "interpolation": getattr(cfg.dataset, "interpolation", "bilinear"),
            "partition": cfg.dataset.partition,
            "bands": bands_value,
            "c_range_start": c_start,
            "c_range_stop": c_stop,
            "c_range_num": c_num,
            "merge_val": cfg.eval.merge_val,
            "bootstrap": cfg.eval.bootstrap,
        }

        if is_segmentation:
            seg_cfg_merged = eval_cfg_merged.segmentation
            save_viz = seg_cfg_merged.get("save_viz", False)

            metrics, feat_dim, best_lr, best_bs, preds = evaluate_segmentation(
                model,
                train_loader,
                val_loader,
                test_loader,
                cfg,
                num_classes,
                device,
                collect_preds=save_viz,
            )

            all_rows.append(
                EvaluationResult(
                    **common_meta,
                    method=seg_method,
                    metric_name="mIoU",
                    metric_value=metrics.get("mIoU", float("nan")),
                    ci_lower=0.0,
                    ci_upper=0.0,
                    feature_dim=feat_dim,
                    best_c=None,
                    best_lr=best_lr,
                    best_batch_size=best_bs,
                    n_train=len(train_dataset),
                    n_val=len(val_loader.dataset),
                    n_test=len(test_loader.dataset),
                    fw_iou=metrics.get("fw_IoU"),
                    precision=metrics.get("precision"),
                    recall=metrics.get("recall"),
                    f1=metrics.get("f1"),
                ).to_row()
            )

            if save_viz and preds is not None:
                rgb_indices = ds_cls().rgb_indices or [0, 1, 2]
                # Collect images and GT masks from test_loader (cheap pass, no backbone)
                test_imgs, test_gts = [], []
                for _batch in test_loader:
                    if isinstance(_batch, dict):
                        test_imgs.append(_batch["image"])
                        _m = _batch["mask"]
                    else:
                        test_imgs.append(_batch[0])
                        _m = _batch[1]
                    if _m.ndim == 4:
                        _m = _m.squeeze(1)
                    test_gts.append(_m.long())
                test_imgs_t = torch.cat(test_imgs, dim=0)
                test_gts_t = torch.cat(test_gts, dim=0)
                ignore_idx = seg_cfg_merged.get("ignore_index", 255)
                n_viz = seg_cfg_merged.get("n_viz_samples", 8)
                viz_dir = seg_cfg_merged.get("viz_dir", "viz")
                _class_names = list(getattr(train_dataset, "classes", None) or []) or None
                save_segmentation_viz(
                    out_dir=viz_dir,
                    model_name=cfg.model.name,
                    dataset_name=ds_name,
                    images=test_imgs_t,
                    gt_masks=test_gts_t,
                    pred_masks=preds,
                    num_classes=num_classes,
                    rgb_indices=rgb_indices,
                    ignore_index=ignore_idx,
                    n_samples=n_viz,
                    class_names=_class_names,
                )
        else:
            # Classification (single-label or multi-label)
            metric_name = "micro_mAP" if is_multilabel else "accuracy"
            skip_knn = cfg.resume and knn_key in completed_runs
            skip_linear = (cfg.resume and linear_key in completed_runs) or getattr(
                cfg.eval, "skip_linear", False
            )

            # Intrinsic-dim and profile are resumed per metric name: the stage
            # is skipped only when every expected metric is already present.
            id_cfg = getattr(cfg.eval, "intrinsic_dim", None)
            id_enabled = bool(id_cfg and id_cfg.get("enabled", False))
            id_metric_names = (
                [f"id_{est}_{split}" for split in id_cfg.splits for est in id_cfg.estimators]
                if id_enabled
                else []
            )
            skip_id = (not id_enabled) or (
                cfg.resume
                and id_metric_names
                and all(
                    id_key in completed_metrics.get(metric, set()) for metric in id_metric_names
                )
            )

            profile_cfg = getattr(cfg.eval, "profile", None)
            profile_enabled = bool(profile_cfg and profile_cfg.get("enabled", False))
            profile_metric_names = _profile_metric_names(profile_cfg) if profile_enabled else []
            skip_profile = (not profile_enabled) or (
                cfg.resume
                and profile_metric_names
                and all(
                    profile_key in completed_metrics.get(metric, set())
                    for metric in profile_metric_names
                )
            )

            if skip_knn and skip_linear and skip_id and skip_profile:
                continue

            x_train, y_train = embed_split(model, train_loader, device, verbose=cfg.verbose)
            x_val, y_val = embed_split(model, val_loader, device, verbose=cfg.verbose)
            x_test, y_test = embed_split(model, test_loader, device, verbose=cfg.verbose)
            feature_dim = x_train.shape[1]
            split_counts = {
                "train": len(x_train),
                "val": len(x_val),
                "test": len(x_test),
            }

            cal_cfg = cfg.eval.get("calibration", {}) or {}
            cal_n_bins_knn = cal_cfg.get("n_bins_knn", None)
            cal_n_bins_linear = int(cal_cfg.get("n_bins_linear", 15))
            cal_temp_scale = bool(cal_cfg.get("temp_scale", True))

            if not skip_knn:
                knn_device = cfg.eval.get("knn_device") or cfg.device
                knn_score, knn_lo, knn_hi, knn_cal, knn_n_bins = evaluate_knn(
                    x_train,
                    y_train,
                    x_test,
                    y_test,
                    cfg.seed,
                    cfg.eval.bootstrap,
                    verbose=cfg.verbose,
                    device=knn_device,
                    n_neighbors=knn_k,
                    calibration_n_bins=cal_n_bins_knn,
                )
                all_rows.append(
                    EvaluationResult(
                        **common_meta,
                        method=f"knn{knn_k}",
                        metric_name=metric_name,
                        metric_value=knn_score,
                        ci_lower=knn_lo,
                        ci_upper=knn_hi,
                        feature_dim=feature_dim,
                        best_c=None,
                        best_lr=None,
                        best_batch_size=None,
                        n_train=len(x_train),
                        n_val=len(x_val),
                        n_test=len(x_test),
                        ece=knn_cal["ece"],
                        rms_ce=knn_cal["rms_ce"],
                        mce=knn_cal["mce"],
                        calibration_n_bins=knn_n_bins,
                    ).to_row()
                )

            if not skip_linear:
                lin_score, lin_lo, lin_hi, best_c, lin_cal, lin_cal_ts = evaluate_logistic(
                    x_train,
                    y_train,
                    x_val,
                    y_val,
                    x_test,
                    y_test,
                    c_values_list,
                    cfg.seed,
                    cfg.eval.bootstrap,
                    cfg.eval.merge_val,
                    cfg.device,
                    cfg.verbose,
                    calibration_n_bins=cal_n_bins_linear,
                    temp_scale=cal_temp_scale,
                )
                all_rows.append(
                    EvaluationResult(
                        **common_meta,
                        method="linear",
                        metric_name=metric_name,
                        metric_value=lin_score,
                        ci_lower=lin_lo,
                        ci_upper=lin_hi,
                        feature_dim=feature_dim,
                        best_c=best_c,
                        best_lr=None,
                        best_batch_size=None,
                        n_train=len(x_train),
                        n_val=len(x_val),
                        n_test=len(x_test),
                        ece=lin_cal["ece"],
                        rms_ce=lin_cal["rms_ce"],
                        mce=lin_cal["mce"],
                        ece_ts=lin_cal_ts["ece_ts"],
                        rms_ce_ts=lin_cal_ts["rms_ce_ts"],
                        mce_ts=lin_cal_ts["mce_ts"],
                        temperature=lin_cal_ts["temperature"],
                        calibration_n_bins=cal_n_bins_linear,
                    ).to_row()
                )

            if not skip_id:
                id_rows = evaluate_intrinsic_dim(
                    splits={"train": x_train, "val": x_val, "test": x_test},
                    estimators=list(id_cfg.estimators),
                    selected_splits=list(id_cfg.splits),
                    device=id_cfg.get("device", None) or cfg.device,
                    max_samples=id_cfg.get("max_samples", None),
                    seed=cfg.seed,
                    common_meta=common_meta,
                    feature_dim=feature_dim,
                    n_counts=dict(split_counts),
                    verbose=cfg.verbose,
                )
                if cfg.resume:
                    id_rows = _filter_completed_metric_rows(id_rows, completed_metrics, key_cols)
                all_rows.extend(id_rows)

            if not skip_profile:
                # `or {}` also covers an explicit `cpu_throughput: null`.
                cpu_cfg = (profile_cfg.get("cpu_throughput", None) if profile_cfg else None) or {}
                profile_rows = evaluate_profile(
                    model=model,
                    sample_loader=train_loader,
                    device=device,
                    n_warmup=int(profile_cfg.get("n_warmup", 3)),
                    n_measure=int(profile_cfg.get("n_measure", 20)),
                    common_meta=common_meta,
                    feature_dim=feature_dim,
                    n_counts=dict(split_counts),
                    cpu_throughput_enabled=bool(cpu_cfg.get("enabled", False)),
                    cpu_batch_size=int(cpu_cfg.get("batch_size", 8)),
                    cpu_n_warmup=int(cpu_cfg.get("n_warmup", 1)),
                    cpu_n_measure=int(cpu_cfg.get("n_measure", 5)),
                    cpu_time_budget_s=float(cpu_cfg.get("time_budget_s", 300.0)),
                )
                if cfg.resume:
                    profile_rows = _filter_completed_metric_rows(
                        profile_rows, completed_metrics, key_cols
                    )
                all_rows.extend(profile_rows)

        # Flush after every dataset so completed work is on disk even if a
        # later dataset fails; the buffer is then reused for the next one.
        append_rows_atomic(output_path, all_rows)
        all_rows.clear()

    logger.info(f"Benchmark complete. Results appended to {output_path}")
if __name__ == "__main__":  # pragma: no cover
    # Script entry point. ``main`` is decorated with ``@hydra.main``, so Hydra
    # provides ``cfg`` automatically; the zero-argument call is correct and
    # the resulting type-checker complaint is silenced on the call itself.
    main()  # type: ignore[misc]