torchgeo_bench.models
This module provides the abstract BenchModel interface and a
collection of concrete backbones that can be benchmarked across the
torchgeo_bench.datasets registry.
Interface
- class torchgeo_bench.models.BenchModel(bands, normalization=NormalizationStrategy.BANDSPEC_ZSCORE, **_)
Abstract base interface for benchmarkable models.
- Parameters:
bands (list[BandSpec]) – Ordered list of BandSpec describing the input channels. Its length determines num_channels.
normalization (NormalizationStrategy | str) – Input-normalisation strategy name (one of bandspec_zscore / model_native / minmax / minmax_zscore / identity). Defaults to "bandspec_zscore".
Subclasses may declare:
- expected_input_unit – the scale at which the pretrained backbone was fed during training (e.g. s2_dn, reflectance_0_1, uint8). Used by the model_native strategy.
- pretrain_mean / pretrain_std – per-channel normalisation applied after unit conversion under model_native.
- forward_patch_features(images)
Return a batch of vector embeddings (B, K) from raw inputs.
Sealed: applies normalize_inputs() then dispatches to _forward_patch_features(). Override normalize_inputs() to change the normalization policy and _forward_patch_features() to change the backbone forward.
- forward(images)
Alias for forward_patch_features().
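The sealed-method contract described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: SketchBenchModel, IdentityBackbone and RawIdentityBackbone are hypothetical names, nested lists stand in for tensors, and the per-channel statistics are made-up values.

```python
class SketchBenchModel:
    """Hypothetical stand-in for BenchModel; works on nested lists, not tensors."""

    def __init__(self, mean, std):
        self.mean = mean  # per-channel mean (illustrative values)
        self.std = std    # per-channel std (illustrative values)

    def normalize_inputs(self, images):
        # Default policy in this sketch: per-channel z-score.
        return [
            [(v - m) / s for v, m, s in zip(sample, self.mean, self.std)]
            for sample in images
        ]

    def _forward_patch_features(self, images):
        raise NotImplementedError  # subclasses supply the backbone forward

    def forward_patch_features(self, images):
        # Sealed entry point: normalise first, then dispatch to the backbone hook.
        return self._forward_patch_features(self.normalize_inputs(images))

    # forward() is an alias for the sealed method.
    forward = forward_patch_features


class IdentityBackbone(SketchBenchModel):
    def _forward_patch_features(self, images):
        return images  # the "embedding" is just the normalised input


class RawIdentityBackbone(IdentityBackbone):
    def normalize_inputs(self, images):
        return images  # change the normalisation policy, keep the backbone


batch = [[10.0, 200.0], [30.0, 400.0]]  # 2 samples, 2 "channels"
model = IdentityBackbone(mean=[20.0, 300.0], std=[10.0, 100.0])
raw_model = RawIdentityBackbone(mean=[20.0, 300.0], std=[10.0, 100.0])
```

Subclasses touch only the two hooks, so every caller of forward() or forward_patch_features() gets normalisation and the backbone forward in the same fixed order.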
Backbones
Random Convolutional Features
- class torchgeo_bench.models.RCFBench(bands, features=512, kernel_size=3, mode='gaussian', stats_mode='mean', seed=None, dataset=None, **_kwargs)
Bases: BenchModel
Wrapper for the existing RCF implementation.
Modes:
- mode="gaussian": filters are drawn from a Gaussian; the default BenchModel.normalize_inputs() (per-channel z-score) is applied to inference inputs.
- mode="empirical": filters are sampled from dataset. To keep the filter bank and inference inputs in the same distribution, the passed dataset is wrapped so its samples are pre-normalized with the same per-channel z-score this RCFBench will use at inference.
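As a rough illustration of the gaussian mode, the sketch below computes random convolutional features for one single-channel image in pure Python: random Gaussian filters, a ReLU, and mean pooling. It is a simplified sketch under stated assumptions, not the RCF implementation that RCFBench wraps; rcf_features and its arguments are hypothetical, and the real module may differ in details such as bias handling or how positive and negative responses are pooled.

```python
import random


def rcf_features(image, num_features=4, kernel_size=3, seed=0):
    """Random convolutional features for one single-channel image (list of rows)."""
    rng = random.Random(seed)  # fixed seed gives a reproducible filter bank
    height, width = len(image), len(image[0])
    out_h, out_w = height - kernel_size + 1, width - kernel_size + 1
    features = []
    for _ in range(num_features):
        # One random Gaussian filter and bias per output feature.
        kernel = [[rng.gauss(0.0, 1.0) for _ in range(kernel_size)]
                  for _ in range(kernel_size)]
        bias = rng.gauss(0.0, 1.0)
        total = 0.0
        for i in range(out_h):
            for j in range(out_w):
                response = bias + sum(
                    kernel[a][b] * image[i + a][j + b]
                    for a in range(kernel_size)
                    for b in range(kernel_size)
                )
                total += max(response, 0.0)  # ReLU
        features.append(total / (out_h * out_w))  # mean over valid positions
    return features


# A small synthetic 5x5 image with values in 0..6.
image = [[float((r * 5 + c) % 7) for c in range(5)] for r in range(5)]
embedding = rcf_features(image, num_features=4, seed=0)
```

The filters are never trained, so the embedding depends only on the seed and the input scale, which is why the filter bank and the inference inputs need to share a normalisation.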
Image statistics baseline
- class torchgeo_bench.models.ImageStatsBench(bands, normalization=NormalizationStrategy.BANDSPEC_ZSCORE, **_)
Bases: BenchModel
BenchModel that returns per-image statistics (mean, std, min, max).
Returns raw sensor statistics: normalize_inputs() is overridden to identity so the per-band magnitudes are preserved. Downstream KNN distances and the LogisticRegression sweep see large, unscaled per-channel values; widen eval.c_range if the default sweep saturates.
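To make the scale issue concrete, here is a minimal sketch of a per-band statistics embedding. The function name image_stats, the feature ordering (mean, std, min, max per band) and the use of the population standard deviation are assumptions for illustration; the class may order or compute them differently.

```python
import statistics


def image_stats(image):
    """image: list of bands, each a flat list of pixel values; 4 values per band."""
    features = []
    for band in image:
        features.extend([
            statistics.fmean(band),
            statistics.pstdev(band),  # population std (an assumption of this sketch)
            min(band),
            max(band),
        ])
    return features


# Two bands on very different scales, as with unnormalised sensor data.
image = [[1.0, 2.0, 3.0, 4.0], [1000.0, 1000.0, 3000.0, 3000.0]]
stats = image_stats(image)
```

The second band's features are roughly three orders of magnitude larger than the first band's, which is the kind of gap that dominates KNN distances and can saturate a narrow regularisation sweep.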
timm encoders
- class torchgeo_bench.models.TimmPatchBenchModel(bands, *, model_name, pretrained=True, normalize=False, global_pool='avg', auto_resize=False, target_size=None, use_cls_token=False, input_normalization='bands_zscore', **_kwargs)
Bases: BenchModel
BenchModel wrapper for any timm backbone.
- Parameters:
bands (list[BandSpec]) – Ordered BandSpec list (channel count = len(bands)).
model_name (str) – Any timm model name (e.g. "resnet50", "convnext_small", "vit_base_patch16_224").
pretrained (bool) – Load pretrained weights when available.
normalize (bool) – If True, L2-normalize the output embedding.
global_pool (str | None) – Global pooling strategy for timm headless models.
use_cls_token (bool) – For ViT-family models, use the CLS token instead of averaging spatial tokens.
auto_resize (bool) – If True, bilinearly resize each batch to target_size.
target_size (int | None) – Square target size; auto-inferred from backbone.default_cfg["input_size"] when not given.
input_normalization (str) – One of:
- "bands_zscore" (default): per-channel z-score using BandSpec statistics; correct for raw remote-sensing data of any channel count.
- "imagenet": per-channel rescale to [0, 1] using BandSpec.{min, max}, then (x - mean) / std with ImageNet RGB stats; refuses to instantiate when len(bands) != 3.
- "timm_default": same form as "imagenet" but reads mean / std from backbone.default_cfg; also requires len(bands) == 3.
- "none": identity.
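The "imagenet" option described above amounts to a min/max rescale followed by the standard ImageNet standardisation. The sketch below shows that arithmetic for a single pixel; imagenet_normalize is a hypothetical helper, and it raises at call time, whereas the wrapper is described as refusing at construction.

```python
# Standard ImageNet RGB statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def imagenet_normalize(pixel, band_min, band_max):
    """Normalise one pixel (one value per band) as the "imagenet" mode is described."""
    if len(pixel) != 3:
        # ImageNet statistics only exist for three RGB channels.
        raise ValueError("imagenet normalisation needs exactly 3 bands")
    out = []
    for v, lo, hi, mean, std in zip(pixel, band_min, band_max,
                                    IMAGENET_MEAN, IMAGENET_STD):
        unit = (v - lo) / (hi - lo)      # rescale to [0, 1] with per-band min/max
        out.append((unit - mean) / std)  # then standardise with ImageNet stats
    return out


# Illustrative band range of 0..10000 for all three bands.
normalized = imagenet_normalize(
    [0.0, 5000.0, 10000.0],
    band_min=[0.0, 0.0, 0.0],
    band_max=[10000.0, 10000.0, 10000.0],
)
```

For inputs with more than three bands, the per-channel z-score default is the only option listed above that does not depend on RGB-specific statistics.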
torchgeo encoders
- class torchgeo_bench.models.TorchGeoResNetBench(bands, *, factory='resnet50', weights_class='ResNet50_Weights', weights_member='SENTINEL2_RGB_MOCO', auto_resize=False, target_size=224, input_unit_check='warn', **_kwargs)
Bases: _TorchGeoBackboneBench
Wrapper for torchgeo ResNet models (resnet18 / resnet50 / resnet152).
These return timm.models.resnet.ResNet instances. We replace .fc with Identity() to get headless (B, K) feature vectors.
Defaults match the SeCo / MoCo Sentinel-2 RGB pretrained weights, whose Normalize transform expects raw Sentinel-2 DN values divided by a single global scale.
- class torchgeo_bench.models.TorchGeoSwinBench(bands, *, factory='swin_v2_b', weights_class='Swin_V2_B_Weights', weights_member='NAIP_RGB_MI_SATLAS', auto_resize=True, target_size=256, input_unit_check='warn', **_kwargs)
Bases: _TorchGeoBackboneBench
Wrapper for torchgeo Swin-V2 models (NAIP / Sentinel-2 SatLAS variants).
- class torchgeo_bench.models.TorchGeoScaleMAEBench(bands, *, factory='scalemae_large_patch16', weights_class='ScaleMAELarge16_Weights', weights_member='FMOW_RGB', auto_resize=True, target_size=224, input_unit_check='warn', pool='mean', **_kwargs)
Bases: _TorchGeoBackboneBench
Wrapper for torchgeo ScaleMAE-Large.
forward_features() returns (B, N+1, D) tokens; pool selects between CLS, mean-pooled patch tokens, or their concatenation.
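The three pooling choices can be sketched for a single sample's (N+1, D) token list. This is an illustration only: pool_tokens is a hypothetical helper, it assumes the CLS token sits at index 0, and apart from "mean" (the documented default) the option names "cls" and "cls+mean" are placeholders rather than the wrapper's actual values.

```python
def pool_tokens(tokens, pool="mean"):
    """tokens: list of N+1 vectors for one sample (CLS assumed first, then N patches)."""
    cls_token, patch_tokens = tokens[0], tokens[1:]
    dim = len(cls_token)
    # Average each embedding dimension over the N patch tokens.
    mean_token = [
        sum(t[d] for t in patch_tokens) / len(patch_tokens) for d in range(dim)
    ]
    if pool == "cls":
        return cls_token
    if pool == "mean":
        return mean_token
    if pool == "cls+mean":  # hypothetical name for the concatenation option
        return cls_token + mean_token  # output dimension becomes 2 * D
    raise ValueError(f"unknown pool mode: {pool!r}")


# One CLS token followed by two patch tokens, D = 2.
tokens = [[9.0, 9.0], [1.0, 2.0], [3.0, 4.0]]
```

Note that the concatenation doubles the embedding width, which matters for any downstream probe sized to D.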
- class torchgeo_bench.models.TorchGeoDOFABench(bands, *, factory='dofa_base_patch16_224', weights_class='DOFABase16_Weights', weights_member='DOFA_MAE', wavelengths=None, auto_resize=True, target_size=224, input_unit_check='warn', **_kwargs)
Bases: _TorchGeoBackboneBench
Wrapper for torchgeo DOFA models (dofa_base / dofa_large).
DOFA requires a list of wavelengths (one per input channel, in µm).
forward_features(x, wavelengths) returns (B, D).
- class torchgeo_bench.models.TorchGeoEarthLocBench(bands, *, factory='earthloc', weights_class='EarthLoc_Weights', weights_member='SENTINEL2_RESNET50', auto_resize=True, target_size=320, input_unit_check='warn', **_kwargs)
Bases: _TorchGeoBackboneBench
Wrapper for torchgeo EarthLoc.
forward(x) returns a (B, 4096) global descriptor.
OlmoEarth
- class torchgeo_bench.models.OlmoEarthBenchModel(bands, *, model_size='base', version='v1', patch_size=4, input_res=None, time_steps=1, std_multiplier=2.0, normalize=False, sar_log_scale=False, landsat_scale_factor=None, sensor_remap=None, min_image_size=None, **_kwargs)
Bases: BenchModel
BenchModel wrapper for OlmoEarth geospatial foundation models.
OlmoEarth is a multi-modal ViT trained by AI2 on Sentinel-2, Sentinel-1, Landsat, NAIP, and other Earth-observation streams. The wrapper picks the right modality (or modalities) from bands[0].sensor and constructs a properly-shaped batch for OlmoEarth’s encoder.
Supported modalities (auto-detected from BandSpec.sensor):
- "s2" -> Modality.SENTINEL2_L2A (12 channels, 3 band-sets)
- "landsat" -> Modality.LANDSAT (11 channels, 2 band-sets)
- "sar" -> Modality.SENTINEL1 (2 channels, 1 band-set)
- "aerial" / "naip" -> S2 path with RGB zero-fill
Mixed-sensor inputs (e.g. ["s2", "sar"]) are handled by building separate tensor branches and populating multiple MaskedOlmoEarthSample fields simultaneously.
Channels missing from the input are zero-filled at the corresponding OlmoEarth position; the mask stays all-visible so pool_spatially can still produce embeddings.
The wrapper overrides normalize_inputs to identity, because OlmoEarth’s internal Normalizer consumes raw values directly. Input scale (DN / reflectance / uint8) is auto-detected per sensor group and rescaled to S2 DN before normalisation (SAR values are passed as-is).
input_res is auto-detected from the primary sensor’s GSD: 10 m for S2/SAR, 30 m for Landsat. Pass input_res explicitly to override.
- Parameters:
bands (list[BandSpec]) – Ordered BandSpec list describing the input channels.
model_size (Literal['nano', 'tiny', 'base', 'large']) – One of "nano", "tiny", "base", "large". "large" is only available for version="v1".
version (Literal['v1', 'v1_1']) – Model version: "v1" (default) or "v1_1". v1.1 ships Nano/Tiny/Base with improved accuracy and ~25% more parameters; no Large variant yet.
patch_size (int) – Patch size for the encoder (default 4).
input_res (int | None) – Input resolution in meters. None (default) lets the wrapper auto-detect from the primary sensor GSD.
time_steps (int) – Temporal slots in the input (default 1).
std_multiplier (float) – Std multiplier passed to Normalizer.
normalize (bool) – If True, L2-normalize output embeddings.
sar_log_scale (bool) – If True, convert SAR values to dB via 10·log10(max(v, 1e-6)) before feeding OlmoEarth’s S1 normalizer, which was trained on σ⁰ dB values.
landsat_scale_factor (float | None) – Optional multiplier applied to Landsat values after the standard uint8→DN conversion. Use to compensate for mismatched scales between GeoBench’s uint8 composites and OlmoEarth’s pretraining DN range (~10 000).
sensor_remap (dict[str, str] | None) – Optional dict mapping sensor names to alternate routing keys, e.g. {"landsat": "landsat_as_s2"} to route Landsat bands through the S2 normalizer (+6.6 pp on m-forestnet).
min_image_size (int | None) – If set, upsample inputs smaller than this value to min_image_size × min_image_size via bilinear interpolation. Useful for datasets with small native images (e.g. m-so2sat at 32 px) where the patch grid would otherwise be too sparse.
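The sar_log_scale conversion is the formula given in the parameter description, 10·log10(max(v, 1e-6)). A minimal scalar sketch follows; sar_to_db is a hypothetical helper name, and the wrapper presumably applies the same operation element-wise to tensors.

```python
import math


def sar_to_db(value, floor=1e-6):
    """Convert a linear-scale SAR backscatter value to decibels.

    The floor clamps zero or negative inputs so the logarithm stays defined;
    with floor=1e-6 the smallest possible output is -60 dB.
    """
    return 10.0 * math.log10(max(value, floor))


# A few linear-scale values, including a zero that hits the floor.
db_values = [sar_to_db(v) for v in (1.0, 0.1, 0.0)]
```

Values that are already in dB should not be passed through this conversion a second time, which is one reason the option defaults to False.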
SAM 3
- class torchgeo_bench.models.SAM3Encoder(bands, *, checkpoint_path=None, model_name_or_path='facebook/sam3', **_kwargs)
Bases: BenchModel
Frozen SAM3 vision encoder (ViT-H + FPN neck) as a benchmark backbone.
The full SAM3 model is loaded but only the vision encoder is retained. The text encoder, geometry encoder, DETR encoder/decoder, and mask decoder are discarded to save memory.
On the first forward pass the RoPE buffers are reset to match the actual input resolution. If the image dimensions are not multiples of patch_size=14, images are cropped to the nearest valid size and a warning is logged. Only 3-channel RGB input is supported.
- Parameters:
bands (list[BandSpec]) – Ordered BandSpec list. Must have exactly 3 entries (RGB only).
checkpoint_path (str | None) – Path to a local HuggingFace-format checkpoint directory containing model.safetensors and config.json.
model_name_or_path (str) – HuggingFace Hub model ID. Used only if checkpoint_path is not set.
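The cropping rule can be sketched as rounding each dimension down to a multiple of the patch size. This assumes "nearest valid size" means the largest multiple of 14 that fits, since a crop can only shrink the image; crop_to_patch_multiple is a hypothetical helper, and which side of the image the wrapper trims is not specified here.

```python
def crop_to_patch_multiple(height, width, patch_size=14):
    """Return (new_height, new_width, was_cropped) for a patch-aligned crop."""
    new_h = (height // patch_size) * patch_size  # round down to a multiple
    new_w = (width // patch_size) * patch_size
    if new_h == 0 or new_w == 0:
        raise ValueError("image is smaller than a single patch")
    was_cropped = (new_h, new_w) != (height, width)
    return new_h, new_w, was_cropped


aligned = crop_to_patch_multiple(224, 224)    # 224 = 16 * 14, already valid
unaligned = crop_to_patch_multiple(230, 300)  # both sides need trimming
```

When was_cropped is true, the sketch's caller would be the place to log the warning mentioned above.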
Segmentation heads
These heads attach to a frozen BenchModel backbone to produce
dense per-pixel predictions. See SegmentationProbe
for the wiring layer, and Segmentation backbone layer reference for the
verified eval.segmentation.layers values for each supported timm
backbone family.
- class torchgeo_bench.models.LinearHead(channels_list, num_classes)
Bases: Module
Per-layer BN + 1×1 conv heads with learned scale-weighted fusion.
For a single layer the output is returned directly. For multiple layers, each head’s logits are upsampled to the input resolution and combined via learned scalar weights.
- Parameters:
- class torchgeo_bench.models.ConvBlockHead(channels_list, num_classes, hidden_dim=256)
Bases: Module
Per-layer 1×1 projection to hidden_dim, aligned concat, 1×1 classification head.
All feature maps are projected to the same channel count, upsampled to the finest spatial resolution in the batch, concatenated, and classified with a single 1×1 conv.
- Parameters:
- class torchgeo_bench.models.FPNHead(channels_list, num_classes, hidden_dim=256)
Bases: Module
Feature Pyramid Network decoder head.
Applies lateral 1×1 convs, a top-down merging pathway, 3×3 refinement convs, then upsamples all levels to the finest resolution, concatenates, and classifies with a 1×1 conv.
Layers must be provided in coarse-to-fine order (deepest / lowest-resolution first). Example for ResNet: ["layer4", "layer3", "layer2", "layer1"].
- Parameters:
- class torchgeo_bench.models.DPTHead(channels_list, num_classes, hidden_dim=256, kernel_size=3)
Bases: Module
DPT-style decoder head (adapted from probe3d at mbanani/probe3d, single-view).
Requires exactly 4 feature layers in coarse-to-fine order (same convention as FPN, e.g. ["layer4", "layer3", "layer2", "layer1"] for ResNet). The forward pass processes features from coarsest to finest through a cascade of FeatureFusionBlock modules.
- Upsampling chain (mirroring probe3d):
1×1 project each map to hidden_dim
2× upsample all projected maps
Top-down fusion cascade (coarse → fine)
4× upsample the fused result
3×3 → ReLU → 3×3 output conv to num_classes
Final resize to input resolution
- Parameters:
channels_list (list[int]) – Channel count for each hooked feature layer (coarse-to-fine). Must have exactly 4 entries.
num_classes (int) – Number of segmentation output classes.
hidden_dim (int) – Hidden channel dimension (default 256).
kernel_size (int) – Conv kernel size for residual units (default 3).