Source code for torchgeo_bench.models.interface
"""Model interface for torchgeo-bench.
Defines :class:`BenchModel`, the abstract base class that every benchmarkable
model inherits from. The contract is split into two halves:
1. **Construction**: subclasses receive a
:class:`list[~torchgeo_bench.datasets.base.BandSpec]` describing the input
channels. Per-channel mean/std/min/max statistics are available on each
:class:`~torchgeo_bench.datasets.base.BandSpec` and are used by
:meth:`normalize_inputs` to z-score the raw input tensor.
2. **Forward path**:
* The public :meth:`forward_patch_features` is **sealed** — subclasses
do not override it. It always applies :meth:`normalize_inputs` before
dispatching to :meth:`_forward_patch_features`, so normalization can't
be silently forgotten.
* Subclasses implement :meth:`_forward_patch_features` (the abstract
hook) which receives the **already-normalized** ``(B, C, H, W)`` tensor
and returns ``(B, K)`` embeddings.
Models whose backbones do their own normalization (e.g. OlmoEarth) override
:meth:`normalize_inputs` to identity. Models that need a different
normalization strategy (ImageNet-style for pretrained RGB CNNs, weights-bound
``Normalize`` transforms for torchgeo wrappers, etc.) override
:meth:`normalize_inputs` with their own policy.
"""
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from torchgeo_bench.datasets.base import BandSpec
from ._input_units import InputUnit
from ._normalization import NormalizationStrategy, build_normalizer
[docs]
class BenchModel(nn.Module, ABC):
"""Abstract base interface for benchmarkable models.
Args:
bands: Ordered list of :class:`BandSpec` describing the input
channels. Length determines :attr:`num_channels`.
normalization: Input-normalisation strategy name (one of
``bandspec_zscore`` / ``model_native`` / ``minmax`` /
``minmax_zscore`` / ``identity``). Defaults to
``"bandspec_zscore"``.
Subclasses may declare:
* ``expected_input_unit`` — what scale the pretrained backbone was
fed at training (e.g. ``s2_dn``, ``reflectance_0_1``, ``uint8``).
Used by the ``model_native`` strategy.
* ``pretrain_mean`` / ``pretrain_std`` — per-channel normalisation
applied *after* unit conversion under ``model_native``.
"""
expected_input_unit: InputUnit | None = None
pretrain_mean: list[float] | None = None
pretrain_std: list[float] | None = None
def __init__(
    self,
    bands: list[BandSpec],
    normalization: NormalizationStrategy | str = NormalizationStrategy.BANDSPEC_ZSCORE,
    **_: object,
) -> None:
    """Record the band layout and build the input normalizer.

    Args:
        bands: Ordered :class:`BandSpec` entries, one per input channel.
            The sequence is copied, so later mutation of the caller's
            list does not affect the model.
        normalization: Strategy member or its string value; coerced with
            ``NormalizationStrategy(normalization)``.
        **_: Extra keyword arguments are accepted and ignored.

    Raises:
        ValueError: If ``bands`` is ``None`` or contains no entries.
    """
    super().__init__()
    # Materialize before validating: one-shot iterables such as generators
    # are always truthy, so testing ``bands`` itself would let an empty one
    # through and yield a model with ``num_channels == 0``.
    band_list: list[BandSpec] = list(bands) if bands else []
    if not band_list:
        raise ValueError("BenchModel requires a non-empty list of BandSpec.")
    self.bands: list[BandSpec] = band_list
    self.num_channels: int = len(band_list)
    self.normalization = NormalizationStrategy(normalization)
    # The class-level ``expected_input_unit`` / ``pretrain_mean`` /
    # ``pretrain_std`` declarations (possibly overridden by a subclass) are
    # forwarded so the normalizer can use them.
    self._normalizer = build_normalizer(
        self.normalization,
        bands=self.bands,
        expected_input_unit=self.expected_input_unit,
        pretrain_mean=self.pretrain_mean,
        pretrain_std=self.pretrain_std,
    )
@abstractmethod
def _forward_patch_features(self, images: torch.Tensor) -> torch.Tensor:
    """Backbone hook that every concrete model must implement.

    When reached through the public :meth:`forward_patch_features`,
    :meth:`normalize_inputs` has already been applied, so an
    implementation only needs to run its backbone.

    Args:
        images: Normalized batch of shape ``(B, C, H, W)``.

    Returns:
        Per-sample embeddings of shape ``(B, K)``.

    Raises:
        NotImplementedError: Always, in this base implementation.
    """
    raise NotImplementedError()
[docs]
def forward_patch_features(self, images: torch.Tensor) -> torch.Tensor:
    """Embed a batch of raw inputs as ``(B, K)`` feature vectors.

    This method is sealed and is not meant to be overridden. It runs two
    steps in a fixed order: :meth:`normalize_inputs` on the raw batch,
    then :meth:`_forward_patch_features` on the result. To customise
    behaviour, override :meth:`normalize_inputs` (normalization policy)
    or :meth:`_forward_patch_features` (backbone forward) instead.
    """
    normalized = self.normalize_inputs(images)
    return self._forward_patch_features(normalized)
[docs]
def forward(self, images: torch.Tensor) -> torch.Tensor:
    """Delegate to :meth:`forward_patch_features` and return its result."""
    features = self.forward_patch_features(images)
    return features