Source code for torchgeo_bench.intrinsic_dim

"""Intrinsic dimension (ID) estimation over feature embeddings.

Thin wrapper around ``torchid`` (https://github.com/isaaccorley/torchid).
Provides a single entry point to compute one or more global ID estimates on a
feature matrix and return scalar values per estimator.

ID is computed on raw embeddings (no L2-normalization) to match the distance
geometry used by KNN/linear probes elsewhere in this package.
"""

import logging
from typing import Any

import numpy as np
import torch

logger = logging.getLogger(__name__)


class DegenerateManifoldError(ValueError):
    """Feature manifold is degenerate; the estimator returned a non-finite dimension."""


SUPPORTED_ESTIMATORS: tuple[str, ...] = (
    "lPCA",
    "TwoNN",
    "MLE",
    "CorrInt",
    "MiND_ML",
    "KNN",
    "DANCo",
    "FisherS",
)


def _load_estimator(name: str) -> type:
    """Lazy-import a torchid global estimator class by name."""
    try:
        from torchid import estimators as _est
    except ImportError as e:
        raise ImportError(
            "torchid is required for intrinsic-dimension metrics. "
            "Install with `pip install 'torchgeo-bench[id]'` "
            "(requires Python >=3.13)."
        ) from e
    if not hasattr(_est, name):
        raise ValueError(
            f"Unknown torchid estimator '{name}'. Supported: {', '.join(SUPPORTED_ESTIMATORS)}."
        )
    return getattr(_est, name)


def _resolve_device(device: str | torch.device | None) -> torch.device:
    """Resolve the requested device, falling back to CPU when CUDA unavailable."""
    if device is None:
        dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        dev = torch.device(device)
    if dev.type == "cuda" and not torch.cuda.is_available():
        logger.warning("CUDA requested for intrinsic-dim but unavailable; using CPU.")
        dev = torch.device("cpu")
    return dev


def _subsample(X: np.ndarray, max_samples: int | None, seed: int) -> np.ndarray:
    """Deterministically subsample rows of X if it exceeds max_samples."""
    if max_samples is None or X.shape[0] <= max_samples:
        return X
    rng = np.random.default_rng(seed)
    idx = rng.choice(X.shape[0], size=max_samples, replace=False)
    return X[idx]


def _two_nearest_distances(X: torch.Tensor) -> torch.Tensor:
    """Pairwise (d1, d2) for each row, matching torchid's knn precision.

    We deliberately replicate torchid's exact squared-distance formula
    (``x_sq + y_sq − 2·x·y.T`` then ``clamp_(min=0)``) instead of using
    ``torch.cdist``.  ``cdist`` is more numerically stable on CUDA, so
    its distances disagree with torchid's at the underflow boundary —
    that mismatch was hiding a TwoNN nan we just debugged (sweep 88205,
    Prithvi v1_100): the dedup said ``d1.min = 9.96e-3, zeros = 0`` but
    torchid's internal knn produced ``d1 == 0`` for the same rows
    because its squared-distance formula cancels to a tiny negative,
    gets clamped to 0, and underflows to 0 in fp32 after ``.sqrt()``.

    Replicating the formula keeps dedup and the estimator agreeing on
    which rows are degenerate.
    """
    x_sq = (X * X).sum(dim=1, keepdim=True)
    y_sq = x_sq.squeeze(1)
    d_sq = (x_sq + y_sq.unsqueeze(0) - 2.0 * (X @ X.T)).clamp_(min=0.0)
    d_sq.fill_diagonal_(float("inf"))
    top2_sq = d_sq.topk(k=2, largest=False).values
    return top2_sq.sqrt()


def _drop_zero_distance_rows(X_tensor: torch.Tensor) -> torch.Tensor:
    """Drop rows whose computed nearest-neighbour distance underflows to zero.

    TwoNN's slope is ``sum(x * y) / sum(x * x)`` over ``x = log(mu)`` where
    ``mu = d2 / d1``.  When two rows are close enough that their fp32 squared
    distance underflows, ``d1 == 0``; the estimator's inner ``clamp_min``
    leaves ``mu = 0``, and ``log(0) = -inf`` poisons the slope to ``nan`` —
    observed in the wild on Prithvi / Clay CLS-token embeddings.

    Bit-exact dedup doesn't catch this case because the rows differ in
    their last few bits; only the *distance* underflows.  Drop the rows
    where ``d1 == 0`` or ``d2 == 0`` so the remaining set has well-defined
    distance ratios.
    """
    d = _two_nearest_distances(X_tensor)
    keep = (d[:, 0] > 0) & (d[:, 1] > 0)
    n_drop = int((~keep).sum().item())
    if n_drop > 0:
        logger.info(
            f"[intrinsic-dim] dropped {n_drop} rows with zero-distance neighbours "
            f"({X_tensor.shape[0]} -> {int(keep.sum().item())}) before estimation."
        )
        return X_tensor[keep]
    return X_tensor


[docs] def compute_intrinsic_dim( X: np.ndarray, estimators: list[str], device: str | torch.device | None = None, max_samples: int | None = 10_000, seed: int = 0, ) -> dict[str, float]: """Compute intrinsic dimension of X for each requested estimator. Args: X: Feature matrix of shape ``(n_samples, n_features)``. estimators: Names of torchid global estimators (see ``SUPPORTED_ESTIMATORS``). device: ``"cuda"``, ``"cpu"``, a ``torch.device``, or ``None`` to auto-select (CUDA when available, otherwise CPU). max_samples: Cap row count via random subsampling for speed/memory. ``None`` disables subsampling. seed: RNG seed for subsampling determinism. Returns: Mapping ``{estimator_name: dimension}``. Estimator-internal exceptions propagate; we no longer swallow them as NaN, because doing so previously hid the TwoNN/fp32-zero-distance bug. """ if X.ndim != 2: raise ValueError(f"X must be 2D, got shape {X.shape}") if not estimators: return {} dev = _resolve_device(device) Xs = _subsample(X, max_samples, seed) X_tensor = torch.from_numpy(np.ascontiguousarray(Xs)).to(dev, dtype=torch.float32) X_tensor = _drop_zero_distance_rows(X_tensor) out: dict[str, float] = {} for name in estimators: # Estimator-internal exceptions propagate so we actually fix the # bug instead of silently emitting NaN to the CSV. The only # tolerated "soft" failure path is a numerical NaN/inf in # dimension_ after a clean fit — even there we raise with a full # diagnostic dump so the next person debugging has a real lead. cls = _load_estimator(name) est: Any = cls().fit(X_tensor) value = float(est.dimension_) if not np.isfinite(value): d = _two_nearest_distances(X_tensor) d1, d2 = d[:, 0], d[:, 1] raise DegenerateManifoldError( f"[intrinsic-dim] {name} returned non-finite dimension ({value}) on " f"X{tuple(X_tensor.shape)} after dedup. " f"d1[min={d1.min():.3e} median={d1.median():.3e} zeros={(d1 == 0).sum().item()}] " f"d2[min={d2.min():.3e} zeros={(d2 == 0).sum().item()}] " f"X[norm_min={X_tensor.norm(dim=1).min():.3e} " f"norm_max={X_tensor.norm(dim=1).max():.3e} std={X_tensor.std():.3e}]. " f"Investigate before writing this to the CSV." ) out[name] = value return out