Source code for torchgeo_bench.utils
"""Feature extraction utilities for model benchmarking."""
import numpy as np
import torch
from rich.progress import track
from torch.utils.data import DataLoader
[docs]
def extract_features(
model: torch.nn.Module,
dataloader: DataLoader,
device: str | torch.device,
transforms: object | None = None,
verbose: bool = True,
) -> tuple[np.ndarray, np.ndarray]:
"""Extract feature embeddings and labels from a dataloader.
Args:
model: Model to use for feature extraction.
dataloader: DataLoader yielding dicts with ``"image"`` and ``"label"`` keys.
device: Device to run inference on.
transforms: Optional transform applied to images before the model.
verbose: Whether to display a progress bar.
Returns:
Tuple of (features, labels) as NumPy arrays.
"""
x_all = []
y_all = []
iterator = (
track(dataloader, total=len(dataloader), description="Extracting")
if verbose
else dataloader
)
for batch in iterator:
images = batch["image"].to(device)
if "label" not in batch:
raise KeyError(
"Batch is missing 'label' key. extract_features() is a classification "
"utility; for segmentation use "
"SegmentationProbe.extract_segmentation_features() instead."
)
labels = batch["label"].numpy()
if transforms is not None:
images = transforms(images)
with torch.no_grad(), torch.inference_mode():
features = model(images)
if isinstance(features, torch.Tensor):
features = features.cpu().numpy()
else:
if "norm" in features:
features = features["norm"].cpu().numpy()
elif "global_pool" in features:
features = features["global_pool"].cpu().numpy()
elif "head.global_pool" in features:
features = features["head.global_pool"].cpu().numpy()
if features.ndim == 3 and features.shape[1] == 1:
features = features[:, 0, :]
else:
raise ValueError(f"Unexpected features format: {features.keys()}")
if features.ndim == 1:
features = features[np.newaxis, :]
if features.ndim == 3:
features = np.mean(features, axis=1, keepdims=False)
x_all.append(features)
y_all.append(labels)
x_all = np.concatenate(x_all, axis=0)
y_all = np.concatenate(y_all, axis=0)
return x_all, y_all