torchgeo_bench.models

This module provides the abstract BenchModel interface and a collection of concrete backbones that can be benchmarked across the torchgeo_bench.datasets registry.

Interface

class torchgeo_bench.models.BenchModel(bands, normalization=NormalizationStrategy.BANDSPEC_ZSCORE, **_)

Bases: Module, ABC

Abstract base interface for benchmarkable models.

Parameters:
  • bands (list[BandSpec]) – Ordered list of BandSpec describing the input channels. Length determines num_channels.

  • normalization (NormalizationStrategy | str) – Input-normalisation strategy name (one of bandspec_zscore / model_native / minmax / minmax_zscore / identity). Defaults to "bandspec_zscore".

Subclasses may declare:

  • expected_input_unit — what scale the pretrained backbone was fed at training (e.g. s2_dn, reflectance_0_1, uint8). Used by the model_native strategy.

  • pretrain_mean / pretrain_std — per-channel normalisation applied after unit conversion under model_native.
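As a rough illustration of what the strategy names suggest, the sketch below applies a per-channel z-score and a min-max rescale to a flat list of pixel values. The function names and the plain-list representation are assumptions made for this example; the real strategies operate on tensors and read their statistics from each BandSpec, and their exact arithmetic is not shown on this page.

```python
from typing import List


def zscore(values: List[float], mean: float, std: float) -> List[float]:
    """Per-channel z-score: subtract the channel mean, divide by its std."""
    return [(v - mean) / std for v in values]


def minmax(values: List[float], lo: float, hi: float) -> List[float]:
    """Rescale one channel to [0, 1] using its (assumed known) min and max."""
    return [(v - lo) / (hi - lo) for v in values]


def identity(values: List[float]) -> List[float]:
    """Leave the channel untouched."""
    return list(values)


# Hypothetical statistics for a single channel.
channel = [0.0, 10.0, 20.0]
standardized = zscore(channel, mean=10.0, std=10.0)
rescaled = minmax(channel, lo=0.0, hi=20.0)
```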

normalize_inputs(images)

Apply the configured normalisation strategy.

forward_patch_features(images)

Return a batch of vector embeddings (B, K) from raw inputs.

Sealed: applies normalize_inputs() then dispatches to _forward_patch_features(). Override normalize_inputs() to change the normalization policy and _forward_patch_features() to change the backbone forward.

forward(images)

Alias for forward_patch_features().
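The sealed call order described above is a template-method pattern. The following is a minimal sketch of that pattern in plain Python, not the library's code: SketchModel and MeanModel are hypothetical names, and nested lists stand in for tensors.

```python
from abc import ABC, abstractmethod
from typing import List

Batch = List[List[float]]


class SketchModel(ABC):
    """Hypothetical stand-in that mirrors the documented call order."""

    def normalize_inputs(self, images: Batch) -> Batch:
        # Default policy in this sketch: pass inputs through unchanged.
        return images

    @abstractmethod
    def _forward_patch_features(self, images: Batch) -> Batch:
        """Backbone forward; subclasses implement this hook."""

    def forward_patch_features(self, images: Batch) -> Batch:
        # Sealed entry point: normalize first, then call the backbone hook.
        return self._forward_patch_features(self.normalize_inputs(images))

    def forward(self, images: Batch) -> Batch:
        # Alias, as in the documented interface.
        return self.forward_patch_features(images)


class MeanModel(SketchModel):
    """Toy subclass: scales uint8-like values, then averages each image."""

    def normalize_inputs(self, images: Batch) -> Batch:
        return [[v / 255.0 for v in img] for img in images]

    def _forward_patch_features(self, images: Batch) -> Batch:
        return [[sum(img) / len(img)] for img in images]


features = MeanModel().forward([[0.0, 255.0], [255.0, 255.0]])
```

A subclass changes the normalization policy or the backbone forward by overriding the two hooks, while callers keep using forward() or forward_patch_features().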

Backbones

Random Convolutional Features

class torchgeo_bench.models.RCFBench(bands, features=512, kernel_size=3, mode='gaussian', stats_mode='mean', seed=None, dataset=None, **_kwargs)

Bases: BenchModel

Wrapper for the existing RCF implementation.

Modes:

  • mode="gaussian": filters are drawn from a Gaussian; default BenchModel.normalize_inputs() (per-channel z-score) is applied to inference inputs.

  • mode="empirical": filters are sampled from dataset. To keep the filter bank and inference inputs in the same distribution, the passed dataset is wrapped so its samples are pre-normalized with the same per-channel z-score this RCFBench will use at inference.

Image statistics baseline

class torchgeo_bench.models.ImageStatsBench(bands, normalization=NormalizationStrategy.BANDSPEC_ZSCORE, **_)

Bases: BenchModel

BenchModel that returns per-image statistics (mean, std, min, max).

Returns raw sensor statistics: normalize_inputs() is overridden to identity so the per-band magnitudes are preserved. Downstream KNN distances and the LogisticRegression sweep see large, unscaled per-channel values; widen eval.c_range if the default sweep saturates.

normalize_inputs(images)

Identity — this model intentionally exposes raw sensor statistics.
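To make the feature layout concrete, here is a small sketch that computes mean, std, min, and max for each band of one image. The per-band grouping of the output and the use of the population standard deviation are assumptions for this example; the actual ordering and estimator used by ImageStatsBench are not stated on this page.

```python
import statistics
from typing import List


def image_stats(image: List[List[float]]) -> List[float]:
    """Return [mean, std, min, max] for each band, concatenated.

    `image` is a list of bands, each a flat list of raw pixel values.
    """
    features: List[float] = []
    for band in image:
        features.extend(
            [
                statistics.fmean(band),
                statistics.pstdev(band),  # population std (assumption)
                min(band),
                max(band),
            ]
        )
    return features


# Two bands with two pixels each, left at their raw scale.
stats = image_stats([[1.0, 3.0], [2.0, 2.0]])
```

Because no normalization is applied, the feature magnitudes follow the raw sensor scale, which is why the regularization sweep may need widening.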

timm encoders

class torchgeo_bench.models.TimmPatchBenchModel(bands, *, model_name, pretrained=True, normalize=False, global_pool='avg', auto_resize=False, target_size=None, use_cls_token=False, input_normalization='bands_zscore', **_kwargs)

Bases: BenchModel

BenchModel wrapper for any timm backbone.

Parameters:
  • bands (list[BandSpec]) – Ordered BandSpec list (channel count = len(bands)).

  • model_name (str) – Any timm model name (e.g. "resnet50", "convnext_small", "vit_base_patch16_224").

  • pretrained (bool) – Load pretrained weights when available.

  • normalize (bool) – If True, L2-normalize the output embedding.

  • global_pool (str | None) – Global pooling strategy for timm headless models.

  • use_cls_token (bool) – For ViT-family models, use the CLS token instead of averaging spatial tokens.

  • auto_resize (bool) – If True, bilinearly resize each batch to target_size.

  • target_size (int | None) – Square target size; auto-inferred from backbone.default_cfg["input_size"] when not given.

  • input_normalization (str) – One of "bands_zscore" (default; per-channel z-score using BandSpec statistics — correct for raw remote-sensing data of any channel count), "imagenet" (per-channel rescale to [0, 1] using BandSpec.{min, max}, then (x - mean) / std with ImageNet RGB stats; refuses to instantiate when len(bands) != 3), "timm_default" (same shape as "imagenet" but reads backbone.default_cfg’s mean/std; also requires len(bands) == 3), or "none" (identity).

normalize_inputs(images)

Apply the configured normalization policy.
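The "imagenet" policy described under input_normalization can be sketched for a single RGB pixel as follows. The function name and the tuple-based inputs are hypothetical; only the two-step arithmetic (min-max rescale, then standardization with the usual ImageNet RGB statistics) and the three-channel requirement come from the description above.

```python
from typing import List, Sequence

# Widely used ImageNet RGB statistics for inputs scaled to [0, 1].
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def imagenet_normalize(
    pixel: Sequence[float],
    band_min: Sequence[float],
    band_max: Sequence[float],
) -> List[float]:
    """Rescale each channel to [0, 1], then apply (x - mean) / std."""
    if len(pixel) != 3:
        raise ValueError("ImageNet statistics only apply to 3-channel input")
    out = []
    for v, lo, hi, mean, std in zip(
        pixel, band_min, band_max, IMAGENET_MEAN, IMAGENET_STD
    ):
        unit = (v - lo) / (hi - lo)  # per-channel rescale to [0, 1]
        out.append((unit - mean) / std)
    return out


# A black uint8 pixel, assuming per-band min 0 and max 255.
normalized = imagenet_normalize((0.0, 0.0, 0.0), (0.0,) * 3, (255.0,) * 3)
```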

torchgeo encoders

class torchgeo_bench.models.TorchGeoResNetBench(bands, *, factory='resnet50', weights_class='ResNet50_Weights', weights_member='SENTINEL2_RGB_MOCO', auto_resize=False, target_size=224, input_unit_check='warn', **_kwargs)

Bases: _TorchGeoBackboneBench

Wrapper for torchgeo ResNet models (resnet18 / resnet50 / resnet152).

These return timm.models.resnet.ResNet instances. We replace .fc with Identity() to get headless (B, K) feature vectors.

Defaults match the SeCo / MoCo Sentinel-2 RGB pretrained weights, whose Normalize transform expects raw Sentinel-2 DN values divided by a single global scale.

class torchgeo_bench.models.TorchGeoSwinBench(bands, *, factory='swin_v2_b', weights_class='Swin_V2_B_Weights', weights_member='NAIP_RGB_MI_SATLAS', auto_resize=True, target_size=256, input_unit_check='warn', **_kwargs)

Bases: _TorchGeoBackboneBench

Wrapper for torchgeo Swin-V2 models (NAIP / Sentinel-2 SatLAS variants).

class torchgeo_bench.models.TorchGeoScaleMAEBench(bands, *, factory='scalemae_large_patch16', weights_class='ScaleMAELarge16_Weights', weights_member='FMOW_RGB', auto_resize=True, target_size=224, input_unit_check='warn', pool='mean', **_kwargs)

Bases: _TorchGeoBackboneBench

Wrapper for torchgeo ScaleMAE-Large.

forward_features() returns (B, N+1, D) tokens; pool selects between CLS, mean-pooled patch tokens, or their concatenation.
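The three pooling choices can be sketched on the token sequence of a single image. This example assumes the CLS token sits at index 0, and the option names "cls" and "cls_mean" are hypothetical; only "mean" appears in the signature above.

```python
from typing import List

Tokens = List[List[float]]  # shape (N + 1, D) for one image


def pool_tokens(tokens: Tokens, pool: str = "mean") -> List[float]:
    """Reduce (N + 1, D) tokens to a single embedding vector."""
    cls_token, patches = tokens[0], tokens[1:]  # CLS first (assumption)
    # Average each feature dimension over the N patch tokens.
    patch_mean = [sum(column) / len(patches) for column in zip(*patches)]
    if pool == "cls":
        return list(cls_token)
    if pool == "mean":
        return patch_mean
    if pool == "cls_mean":  # hypothetical name for the concatenation
        return list(cls_token) + patch_mean
    raise ValueError(f"unknown pool mode: {pool!r}")


tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]  # CLS + 2 patch tokens, D = 2
mean_embedding = pool_tokens(tokens, "mean")
concat_embedding = pool_tokens(tokens, "cls_mean")
```

Note that concatenation doubles the embedding width from D to 2D.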

class torchgeo_bench.models.TorchGeoDOFABench(bands, *, factory='dofa_base_patch16_224', weights_class='DOFABase16_Weights', weights_member='DOFA_MAE', wavelengths=None, auto_resize=True, target_size=224, input_unit_check='warn', **_kwargs)

Bases: _TorchGeoBackboneBench

Wrapper for torchgeo DOFA models (dofa_base / dofa_large).

DOFA requires a list of wavelengths (one per input channel in µm). forward_features(x, wavelengths) returns (B, D).
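A wavelengths list must line up with the channel order of the input. The sketch below builds one from band names; the lookup table holds approximate Sentinel-2 central wavelengths for the visible bands, and the helper name and the use of plain band-name strings are assumptions for illustration only.

```python
from typing import Dict, List

# Approximate Sentinel-2 central wavelengths in micrometres (visible bands).
S2_WAVELENGTHS_UM: Dict[str, float] = {
    "B02": 0.49,   # blue
    "B03": 0.56,   # green
    "B04": 0.665,  # red
}


def wavelengths_for(band_names: List[str]) -> List[float]:
    """Return one wavelength per input channel, in channel order."""
    return [S2_WAVELENGTHS_UM[name] for name in band_names]


# RGB channel order gives red, green, blue wavelengths.
rgb_wavelengths = wavelengths_for(["B04", "B03", "B02"])
```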

class torchgeo_bench.models.TorchGeoEarthLocBench(bands, *, factory='earthloc', weights_class='EarthLoc_Weights', weights_member='SENTINEL2_RESNET50', auto_resize=True, target_size=320, input_unit_check='warn', **_kwargs)

Bases: _TorchGeoBackboneBench

Wrapper for torchgeo EarthLoc.

forward(x) returns a (B, 4096) global descriptor.

OlmoEarth

class torchgeo_bench.models.OlmoEarthBenchModel(bands, *, model_size='base', version='v1', patch_size=4, input_res=None, time_steps=1, std_multiplier=2.0, normalize=False, sar_log_scale=False, landsat_scale_factor=None, sensor_remap=None, min_image_size=None, **_kwargs)

Bases: BenchModel

BenchModel wrapper for OlmoEarth geospatial foundation models.

OlmoEarth is a multi-modal ViT trained on Sentinel-2, Sentinel-1, Landsat, NAIP, and other Earth-observation streams by AI2. The wrapper picks the right modality (or modalities) from bands[0].sensor and constructs a properly-shaped batch for OlmoEarth’s encoder.

Supported modalities (auto-detected from BandSpec.sensor):

  • "s2" -> Modality.SENTINEL2_L2A (12 channels, 3 band-sets)

  • "landsat" -> Modality.LANDSAT (11 channels, 2 band-sets)

  • "sar" -> Modality.SENTINEL1 (2 channels, 1 band-set)

  • "aerial" / "naip" -> S2 path with RGB zero-fill

Mixed-sensor inputs (e.g. ["s2", "sar"]) are handled by building separate tensor branches and populating multiple MaskedOlmoEarthSample fields simultaneously.

Channels missing from the input are zero-filled at the corresponding OlmoEarth position; the mask stays all-visible so pool_spatially can still produce embeddings.
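Zero-filling amounts to scattering the available channels into a fixed target layout. The following sketch shows the idea with plain lists; the helper name and the three-band target layout are hypothetical and much shorter than the real modality layouts.

```python
from typing import List


def zero_fill(
    channels: List[List[float]],
    input_names: List[str],
    target_names: List[str],
) -> List[List[float]]:
    """Place each input channel at its target position; fill gaps with zeros."""
    by_name = dict(zip(input_names, channels))
    num_pixels = len(channels[0])
    return [
        list(by_name[name]) if name in by_name else [0.0] * num_pixels
        for name in target_names
    ]


# Input provides red and blue only, in that order; green is missing.
filled = zero_fill(
    channels=[[1.0, 2.0], [3.0, 4.0]],
    input_names=["B04", "B02"],
    target_names=["B02", "B03", "B04"],  # hypothetical target layout
)
```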

The wrapper overrides normalize_inputs to identity — OlmoEarth’s internal Normalizer consumes raw values directly. Input scale (DN / reflectance / uint8) is auto-detected per sensor group and rescaled to S2 DN before normalisation (SAR values are passed as-is).

input_res is auto-detected from the primary sensor’s GSD: 10 m for S2/SAR, 30 m for Landsat. Pass input_res explicitly to override.
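The resolution rule can be written as a small lookup with an explicit override. This is a sketch under the defaults stated above; the function name is hypothetical, and sensors not listed in the sentence above are left out rather than guessed.

```python
from typing import Optional

# Ground sampling distance in metres per primary sensor, as stated above.
DEFAULT_GSD_M = {"s2": 10, "sar": 10, "landsat": 30}


def resolve_input_res(primary_sensor: str, input_res: Optional[int] = None) -> int:
    """Return the explicit override if given, else the sensor default."""
    if input_res is not None:
        return input_res
    return DEFAULT_GSD_M[primary_sensor]


landsat_res = resolve_input_res("landsat")
overridden_res = resolve_input_res("landsat", input_res=10)
```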

Parameters:
  • bands (list[BandSpec]) – Ordered BandSpec list describing the input channels.

  • model_size (Literal['nano', 'tiny', 'base', 'large']) – One of "nano", "tiny", "base", "large". "large" is only available for version="v1".

  • version (Literal['v1', 'v1_1']) – Model version — "v1" (default) or "v1_1". v1.1 ships Nano/Tiny/Base with improved accuracy and ~25% more parameters; no Large variant yet.

  • patch_size (int) – Patch size for the encoder (default 4).

  • input_res (int | None) – Input resolution in meters. None (default) lets the wrapper auto-detect from the primary sensor GSD.

  • time_steps (int) – Temporal slots in the input. Default 1.

  • std_multiplier (float) – Std multiplier passed to Normalizer.

  • normalize (bool) – If True, L2-normalize output embeddings.

  • sar_log_scale (bool) – If True, convert SAR values to dB via 10·log10(max(v, 1e-6)) before feeding OlmoEarth’s S1 normalizer, which was trained on σ⁰ dB values.

  • landsat_scale_factor (float | None) – Optional multiplier applied to Landsat values after the standard uint8→DN conversion. Use to compensate for mis-matched scales between GeoBench’s uint8 composites and OlmoEarth’s pretraining DN range (~10 000).

  • sensor_remap (dict[str, str] | None) – Optional dict mapping sensor names to alternate routing keys, e.g. {"landsat": "landsat_as_s2"} to route Landsat bands through the S2 normalizer (+6.6 pp on m-forestnet).

  • min_image_size (int | None) – If set, upsample inputs smaller than this value to min_image_size × min_image_size via bilinear interpolation. Useful for datasets with small native images (e.g. m-so2sat at 32 px) where the patch grid would otherwise be too sparse.
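The decibel conversion quoted for sar_log_scale is simple to check by hand. A minimal sketch of the formula for a single value, with a hypothetical function name:

```python
import math


def sar_to_db(value: float, floor: float = 1e-6) -> float:
    """Convert a linear SAR value to dB: 10 * log10(max(value, floor))."""
    return 10.0 * math.log10(max(value, floor))


unit_db = sar_to_db(1.0)     # linear 1.0 maps to 0 dB
strong_db = sar_to_db(100.0)  # two decades above 1.0
floor_db = sar_to_db(0.0)    # clamped to the floor before the log
```

The clamp keeps zero or negative inputs from producing an undefined logarithm; with a floor of 1e-6 the lowest possible output is about -60 dB.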

normalize_inputs(images)

Identity — OlmoEarth’s internal Normalizer handles raw values.

SAM 3

class torchgeo_bench.models.SAM3Encoder(bands, *, checkpoint_path=None, model_name_or_path='facebook/sam3', **_kwargs)

Bases: BenchModel

Frozen SAM3 vision encoder (ViT-H + FPN neck) as a benchmark backbone.

The full SAM3 model is loaded but only the vision encoder is retained. The text encoder, geometry encoder, DETR encoder/decoder, and mask decoder are discarded to save memory.

On the first forward pass the RoPE buffers are reset to match the actual input resolution. If the image dimensions are not multiples of patch_size=14, images are cropped to the nearest valid size and a warning is logged. Only 3-channel RGB input is supported.

Parameters:
  • bands (list[BandSpec]) – Ordered BandSpec list. Must have exactly 3 entries (RGB only).

  • checkpoint_path (str | None) – Path to a local HuggingFace-format checkpoint directory containing model.safetensors and config.json.

  • model_name_or_path (str) – HuggingFace Hub model ID. Used only if checkpoint_path is not set.
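To anticipate the cropping warning, the valid size for a given image can be computed ahead of time. This sketch assumes that "nearest valid size" means rounding each dimension down to a multiple of the patch size, since a crop can only shrink an image; the function name is hypothetical.

```python
from typing import Tuple


def valid_size(height: int, width: int, patch_size: int = 14) -> Tuple[int, int]:
    """Largest (height, width) not exceeding the input that divides evenly."""
    return (
        (height // patch_size) * patch_size,
        (width // patch_size) * patch_size,
    )


unchanged = valid_size(224, 224)  # 224 = 16 * 14, so nothing is cropped
cropped = valid_size(230, 100)    # both dimensions lose a few pixels
```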

Segmentation heads

These heads attach to a frozen BenchModel backbone to produce dense per-pixel predictions. See SegmentationProbe for the wiring layer, and Segmentation backbone layer reference for the verified eval.segmentation.layers values for each supported timm backbone family.

class torchgeo_bench.models.LinearHead(channels_list, num_classes)

Bases: Module

Per-layer BN + 1×1 conv heads with learned scale-weighted fusion.

For a single layer the output is returned directly. For multiple layers, each head’s logits are upsampled to the input resolution and combined via learned scalar weights.

Parameters:
  • channels_list (list[int]) – Channel count for each hooked feature layer.

  • num_classes (int) – Number of segmentation output classes.

forward(features, input_h, input_w)

Upsample per-layer logits and combine them as a weighted sum using the learned scalar weights.

class torchgeo_bench.models.ConvBlockHead(channels_list, num_classes, hidden_dim=256)

Bases: Module

Per-layer 1×1 projection to hidden_dim, aligned concat, 1×1 classification head.

All feature maps are projected to the same channel count, upsampled to the finest spatial resolution in the batch, concatenated, and classified with a single 1×1 conv.

Parameters:
  • channels_list (list[int]) – Channel count for each hooked feature layer.

  • num_classes (int) – Number of segmentation output classes.

  • hidden_dim (int) – Projection dimension (default 256).

forward(features, input_h, input_w)

Project, upsample, concat, and classify features.

class torchgeo_bench.models.FPNHead(channels_list, num_classes, hidden_dim=256)

Bases: Module

Feature Pyramid Network decoder head.

Applies lateral 1×1 convs, a top-down merging pathway, 3×3 refinement convs, then upsamples all levels to the finest resolution, concatenates, and classifies with a 1×1 conv.

Layers must be provided in coarse-to-fine order (deepest / lowest-resolution first). Example for ResNet: ["layer4", "layer3", "layer2", "layer1"].

Parameters:
  • channels_list (list[int]) – Channel count for each hooked feature layer (coarse-to-fine).

  • num_classes (int) – Number of segmentation output classes.

  • hidden_dim (int) – Feature dimension used throughout the FPN (default 256).

forward(features, input_h, input_w)

Top-down FPN forward pass.

Parameters:
  • features (list[Tensor]) – Feature maps in coarse-to-fine order (index 0 = coarsest).

  • input_h (int) – Target output height (input image height).

  • input_w (int) – Target output width (input image width).
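The coarse-to-fine convention affects both channels_list and the order of the feature maps. The sketch below first builds a channels_list for ResNet-50 (whose four stages output 256, 512, 1024, and 2048 channels), then illustrates the top-down merge on one-dimensional toy maps. The merge is a simplification: it uses nearest-neighbour upsampling as an assumption and omits the lateral and refinement convolutions, and the helper names are hypothetical.

```python
from typing import List

# ResNet-50 stage output channels, keyed by layer name.
RESNET50_CHANNELS = {"layer1": 256, "layer2": 512, "layer3": 1024, "layer4": 2048}

# Coarse-to-fine: deepest / lowest-resolution layer first.
layers = ["layer4", "layer3", "layer2", "layer1"]
channels_list = [RESNET50_CHANNELS[name] for name in layers]


def upsample_nearest(row: List[float], factor: int) -> List[float]:
    """Repeat each value `factor` times (1-D nearest-neighbour upsampling)."""
    return [value for value in row for _ in range(factor)]


def top_down(levels: List[List[float]]) -> List[List[float]]:
    """Merge 1-D maps coarse-to-fine: upsample the running result and add."""
    merged = [levels[0]]
    for finer in levels[1:]:
        factor = len(finer) // len(merged[-1])
        upsampled = upsample_nearest(merged[-1], factor)
        merged.append([a + b for a, b in zip(finer, upsampled)])
    return merged


# Three toy levels of length 1, 2, and 4 (index 0 = coarsest).
pyramid = top_down([[1.0], [1.0, 2.0], [0.0, 0.0, 0.0, 0.0]])
```

Passing the layers in fine-to-coarse order would pair each projection with the wrong channel count, so the order of channels_list and of the feature list must match.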

class torchgeo_bench.models.DPTHead(channels_list, num_classes, hidden_dim=256, kernel_size=3)

Bases: Module

DPT-style decoder head (adapted from probe3d at mbanani/probe3d, single-view).

Requires exactly 4 feature layers in coarse-to-fine order (same convention as FPN, e.g. ["layer4", "layer3", "layer2", "layer1"] for ResNet). The forward pass processes features from coarsest to finest through a cascade of FeatureFusionBlock modules.

Upsampling chain (mirroring probe3d):
  1. 1×1 project each map to hidden_dim

  2. 2× upsample all projected maps

  3. Top-down fusion cascade (coarse → fine)

  4. 4× upsample the fused result

  5. 3×3 → ReLU → 3×3 output conv to num_classes

  6. Final resize to input resolution

Parameters:
  • channels_list (list[int]) – Channel count for each hooked feature layer (coarse-to-fine). Must have exactly 4 entries.

  • num_classes (int) – Number of segmentation output classes.

  • hidden_dim (int) – Hidden channel dimension (default 256).

  • kernel_size (int) – Conv kernel size for residual units (default 3).

forward(features, input_h, input_w)

DPT forward pass.

Parameters:
  • features (list[Tensor]) – Feature maps in coarse-to-fine order (index 0 = coarsest).

  • input_h (int) – Target output height.

  • input_w (int) – Target output width.