Source code for torchgeo_bench.segmentation_task

"""Segmentation Training Task Logic."""

import logging
import math

import torch
import torch.nn as nn
from rich.progress import track
from torch.utils.data import DataLoader
from torchmetrics.classification import (
    MulticlassF1Score,
    MulticlassJaccardIndex,
    MulticlassPrecision,
    MulticlassRecall,
)

from .segmentation_probe import (
    CachedFeaturesDataset,
    GPUTensorCache,
    SegmentationProbe,
)

logger = logging.getLogger(__name__)

SegMetrics = dict[str, float]


[docs] class SegmentationSolver: """A lightweight trainer for the SegmentationProbe.""" def __init__( self, model: SegmentationProbe, num_classes: int, lr: float = 1e-3, weight_decay: float = 0.0, device: str = "cuda", criterion: nn.Module | None = None, lr_scheduler: str = "cosine", ignore_index: int = 255, ) -> None: """Initialize the SegmentationSolver. Args: model: The SegmentationProbe model to train. num_classes: Number of segmentation classes. lr: Learning rate for the optimizer. weight_decay: Weight decay for the optimizer. device: Device to run training on ('cuda' or 'cpu'). criterion: Loss module. Defaults to CrossEntropyLoss with ignore_index. lr_scheduler: LR schedule: "cosine" (CosineAnnealingLR) or "none" (constant LR). ignore_index: Label value to ignore in loss and metrics (default: 255). """ self.model = model.to(device) self.num_classes = num_classes self.device = device self.lr_scheduler_type = lr_scheduler self.ignore_index = ignore_index self.optimizer = torch.optim.AdamW( filter(lambda p: p.requires_grad, self.model.parameters()), lr=lr, weight_decay=weight_decay, ) self.criterion = ( criterion if criterion is not None else nn.CrossEntropyLoss(ignore_index=self.ignore_index) ) self.metric = MulticlassJaccardIndex( num_classes=self.num_classes, ignore_index=self.ignore_index, average="macro", ) self.metric_fw_iou = MulticlassJaccardIndex( num_classes=self.num_classes, ignore_index=self.ignore_index, average="weighted", ) self.metric_precision = MulticlassPrecision( num_classes=self.num_classes, ignore_index=self.ignore_index, average="macro", ) self.metric_recall = MulticlassRecall( num_classes=self.num_classes, ignore_index=self.ignore_index, average="macro", ) self.metric_f1 = MulticlassF1Score( num_classes=self.num_classes, ignore_index=self.ignore_index, average="macro", ) self._all_metrics = [ self.metric, self.metric_fw_iou, self.metric_precision, self.metric_recall, self.metric_f1, ] self.use_amp = device.startswith("cuda") and torch.cuda.is_available() self.scaler = torch.amp.GradScaler("cuda", enabled=self.use_amp) self.device_type = torch.device(device).type def _make_scheduler(self, epochs: int) -> torch.optim.lr_scheduler.LRScheduler | None: """Return a CosineAnnealingLR scheduler, or None for constant LR.""" if self.lr_scheduler_type == "cosine": return torch.optim.lr_scheduler.CosineAnnealingLR( self.optimizer, T_max=epochs, eta_min=1e-6 ) if self.lr_scheduler_type == "none": return None raise ValueError( f"Unknown lr_scheduler {self.lr_scheduler_type!r}. Expected 'cosine' or 'none'." )
[docs] def fit( self, train_loader: DataLoader, val_loader: DataLoader | None = None, epochs: int = 10, verbose: bool = True, ) -> float | None: """Train the segmentation probe. Args: train_loader: Training data loader. val_loader: Optional validation data loader for per-epoch mIoU logging. epochs: Number of training epochs. verbose: Whether to show progress bars and epoch logs. Returns: Val mIoU from the final epoch if val_loader is given, else None. """ scheduler = self._make_scheduler(epochs) last_val_miou: float | None = None for epoch in range(epochs): self.model.train() if self.model.freeze_backbone: self.model.backbone.eval() total_loss = 0.0 desc = f"Epoch {epoch + 1}/{epochs}" batches = track(train_loader, description=desc) if verbose else train_loader for _num_batches, batch in enumerate(batches, start=1): if isinstance(batch, dict): images = batch["image"].to(self.device) masks = batch["mask"].to(self.device).long() else: images, masks = batch[0].to(self.device), batch[1].to(self.device).long() if masks.ndim == 4: masks = masks.squeeze(1) self.optimizer.zero_grad() with torch.autocast(device_type=self.device_type, enabled=self.use_amp): logits = self.model(images) loss = self.criterion(logits, masks) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() total_loss += loss.item() if scheduler is not None: scheduler.step() if val_loader: val_metrics = self.evaluate(val_loader) last_val_miou = val_metrics["mIoU"] if verbose: logger.info(f"Epoch {epoch + 1} Val mIoU: {last_val_miou:.4f}") return last_val_miou
[docs] @torch.no_grad() def evaluate( self, dataloader: DataLoader, collect_preds: bool = False, ) -> "SegMetrics | tuple[SegMetrics, torch.Tensor]": """Evaluate the model on a dataloader and return segmentation metrics. Args: dataloader: Evaluation data loader. collect_preds: If True, also return predicted class maps (N, H, W) int64. Returns: Dict of metric name → value, or (metrics_dict, preds_tensor) when collect_preds=True. """ self.model.eval() for m in self._all_metrics: m.reset() m.to(self.device) pred_list: list[torch.Tensor] = [] for batch in dataloader: if isinstance(batch, dict): images = batch["image"].to(self.device) masks = batch["mask"].to(self.device) else: images, masks = batch[0].to(self.device), batch[1].to(self.device) # Ensure masks are (B, H, W) if masks.ndim == 4: masks = masks.squeeze(1) masks = masks.long() with torch.autocast(device_type=self.device_type, enabled=self.use_amp): logits = self.model(images) for m in self._all_metrics: m.update(logits, masks) if collect_preds: pred_list.append(logits.argmax(dim=1).cpu()) metrics = self._compute_metrics() if collect_preds: return metrics, torch.cat(pred_list, dim=0) return metrics
[docs] def fit_cached( self, train_cache: CachedFeaturesDataset, val_cache: CachedFeaturesDataset | None = None, batch_size: int = 64, epochs: int = 10, verbose: bool = True, gpu_train: "GPUTensorCache | None" = None, gpu_val: "GPUTensorCache | None" = None, ) -> float | None: """Train the segmentation head on pre-cached backbone features. The backbone is **not** called during training — cached features are fed directly to ``self.model.head``, which is the only component that runs a forward/backward pass. The entire feature cache is pre-moved to the GPU as contiguous tensors (:class:`GPUTensorCache`), eliminating per-batch CPU→GPU DMA transfers and ``torch.stack`` calls. If ``gpu_train`` is provided, that pre-built cache is used directly, allowing callers (e.g. an HPO loop) to transfer the cache once and reuse it across many calls. Args: train_cache: Pre-extracted training features from :meth:`SegmentationProbe.extract_segmentation_features`. val_cache: Optional validation cache for per-epoch mIoU logging. batch_size: Batch size for iterating over cached data. epochs: Number of training epochs. verbose: Whether to show progress bars and epoch logs. gpu_train: Optional pre-built GPU cache for training. If provided, the GPU transfer is skipped. gpu_val: Optional pre-built GPU cache for validation. Used only when ``gpu_train`` is also provided. Returns: Val mIoU from the final epoch if val_cache is given, else None. """ if gpu_train is None: gpu_train = GPUTensorCache.from_cached(train_cache, self.device) if val_cache is not None: gpu_val = GPUTensorCache.from_cached(val_cache, self.device) # Fast path: GPU tensor cache — no DataLoader, no host→device transfer per batch scheduler = self._make_scheduler(epochs) input_hw: tuple[int, int] = (gpu_train.masks.shape[-2], gpu_train.masks.shape[-1]) last_val_miou: float | None = None num_batches = math.ceil(len(gpu_train) / batch_size) for epoch in range(epochs): self.model.train() if self.model.freeze_backbone: self.model.backbone.eval() total_loss = 0.0 desc = f"Epoch {epoch + 1}/{epochs}" batches = gpu_train.shuffled_batches(batch_size) batches = track(batches, total=num_batches, description=desc) if verbose else batches for features, masks in batches: self.optimizer.zero_grad() with torch.autocast(device_type=self.device_type, enabled=self.use_amp): logits = self.model.head(features, *input_hw) loss = self.criterion(logits, masks) self.scaler.scale(loss).backward() self.scaler.step(self.optimizer) self.scaler.update() total_loss += loss.item() if scheduler is not None: scheduler.step() if gpu_val is not None: val_metrics = self._evaluate_gpu_cache(gpu_val, batch_size) last_val_miou = val_metrics["mIoU"] if verbose: logger.info(f"Epoch {epoch + 1} Val mIoU: {last_val_miou:.4f}") return last_val_miou
[docs] def evaluate_cached( self, cache: CachedFeaturesDataset, batch_size: int = 64, collect_preds: bool = False, ) -> "SegMetrics | tuple[SegMetrics, torch.Tensor]": """Evaluate on a CachedFeaturesDataset. The cache is moved to GPU as a :class:`GPUTensorCache` for zero per-batch host→device transfers. Args: cache: Pre-extracted features (output of :meth:`SegmentationProbe.extract_segmentation_features`). batch_size: Batch size for iterating over the cache. collect_preds: If True, also return predicted class maps (N, H, W) int64. Returns: Dict of metric name → value, or (metrics_dict, preds_tensor) when collect_preds=True. """ gpu_cache = GPUTensorCache.from_cached(cache, self.device) return self._evaluate_gpu_cache(gpu_cache, batch_size, collect_preds=collect_preds)
def _compute_metrics(self) -> "SegMetrics": """Compute and return all metrics as a dict.""" return { "mIoU": self.metric.compute().item(), "fw_IoU": self.metric_fw_iou.compute().item(), "precision": self.metric_precision.compute().item(), "recall": self.metric_recall.compute().item(), "f1": self.metric_f1.compute().item(), } @torch.no_grad() def _evaluate_gpu_cache( self, gpu_cache: GPUTensorCache, batch_size: int, collect_preds: bool = False, ) -> "SegMetrics | tuple[SegMetrics, torch.Tensor]": """Evaluate on a :class:`GPUTensorCache` and return segmentation metrics.""" self.model.eval() for m in self._all_metrics: m.reset() m.to(self.device) pred_list: list[torch.Tensor] = [] input_hw = (gpu_cache.masks.shape[-2], gpu_cache.masks.shape[-1]) for features, masks in gpu_cache.ordered_batches(batch_size): with torch.autocast(device_type=self.device_type, enabled=self.use_amp): logits = self.model.head(features, *input_hw) for m in self._all_metrics: m.update(logits, masks) if collect_preds: pred_list.append(logits.argmax(dim=1).cpu()) metrics = self._compute_metrics() if collect_preds: return metrics, torch.cat(pred_list, dim=0) return metrics