Source code for torchgeo_bench.models.olmoearth

"""OlmoEarth model wrapper for torchgeo-bench.

Wraps the OlmoEarth geospatial foundation model (AI2) for use with the
BenchModel interface.  Multi-modal: the wrapper auto-selects OlmoEarth's
``Modality.SENTINEL2_L2A`` / ``SENTINEL1`` / ``LANDSAT`` / ``NAIP`` based
on the input ``BandSpec.sensor`` field and builds the right channel layout,
band-set mask, and ``MaskedOlmoEarthSample`` field for each modality.

Mixed-sensor inputs (e.g. m-so2sat with Sentinel-2 + SAR) are handled by
splitting bands into per-sensor groups and populating the corresponding
``MaskedOlmoEarthSample`` fields simultaneously.

Reference implementations (canonical first):
    https://github.com/allenai/olmoearth_pretrain/blob/main/docs/Inference-Quickstart.md
    https://github.com/isaaccorley/geopool/blob/main/scripts/embed_olmoearth.py
"""

import logging
from collections import defaultdict
from typing import Literal

import numpy as np
import torch
import torch.nn.functional as F

from torchgeo_bench.datasets.base import BandSpec

from ._input_units import InputUnit, _detect_band_group_unit, to_s2_dn
from .interface import BenchModel

logger = logging.getLogger(__name__)


# Canonical S2 band order kept as a module constant for backward
# compatibility with tests and external callers.
OLMOEARTH_S2_BANDS = (
    "B02",
    "B03",
    "B04",
    "B08",
    "B05",
    "B06",
    "B07",
    "B8A",
    "B11",
    "B12",
    "B01",
    "B09",
)


# Per-modality layout — kept as a single source of truth so the build
# steps below (channel layout, mask shape, sample field) all agree.
#
# ``name_to_idx`` maps any ``BandSpec.name`` we expect to see in our
# datasets (lower-cased) to the corresponding position in OlmoEarth's
# expected band_order for the modality.  Both semantic names (``blue``,
# ``red_edge_1``) and source-style names (``b02``, ``b04``) are
# accepted so the wrapper works with either GeoBench V1 or V2 datasets.
_MODALITY_INFO: dict[str, dict] = {
    "s2": {
        "modality_name": "SENTINEL2_L2A",
        "sample_field": "sentinel2_l2a",
        "channels": 12,
        "num_band_sets": 3,
        # OlmoEarth band_order: B02, B03, B04, B08, B05, B06, B07, B8A,
        # B11, B12, B01, B09.
        "name_to_idx": {
            "blue": 0,
            "b02": 0,
            "green": 1,
            "b03": 1,
            "red": 2,
            "b04": 2,
            "nir": 3,
            "b08": 3,
            "red_edge_1": 4,
            "b05": 4,
            "red_edge_2": 5,
            "b06": 5,
            "red_edge_3": 6,
            "b07": 6,
            "red_edge_4": 7,
            "b8a": 7,
            "swir_1": 8,
            "b11": 8,
            "swir_2": 9,
            "b12": 9,
            "coastal_aerosol": 10,
            "b01": 10,
            "water_vapour": 11,
            "b09": 11,
            # B10 (cirrus) — OlmoEarth has no cirrus slot; skip gracefully.
            "swir_cirrus": None,
            "b10": None,
        },
    },
    "landsat": {
        "modality_name": "LANDSAT",
        "sample_field": "landsat",
        "channels": 11,
        "num_band_sets": 2,
        # OlmoEarth band_order: B8 (pan), B1, B2, B3, B4, B5, B6, B7,
        # B9, B10, B11.  GeoBench m-forestnet (Landsat-8) typically
        # ships only B2/B3/B4/B5/B6/B7 under semantic names.
        "name_to_idx": {
            "panchromatic": 0,
            "pan": 0,
            "b8": 0,
            "coastal_aerosol": 1,
            "coastal": 1,
            "b1": 1,
            "blue": 2,
            "b2": 2,
            "green": 3,
            "b3": 3,
            "red": 4,
            "b4": 4,
            "nir": 5,
            "b5": 5,
            "swir_1": 6,
            "b6": 6,
            "swir_2": 7,
            "b7": 7,
            "cirrus": 8,
            "b9": 8,
            "tirs_1": 9,
            "thermal_1": 9,
            "b10": 9,
            "tirs_2": 10,
            "thermal_2": 10,
            "b11": 10,
        },
    },
    # Sentinel-1 SAR: two channels — vv (0) and vh (1).  OlmoEarth
    # BandSet(["vv", "vh"], 16) with is_multitemporal=True.
    # m-so2sat ships 8 SAR-derived bands (real/imag + Lee-filtered
    # components); all are routed to the nearest vv/vh slot.  When
    # multiple bands land on the same slot (e.g. vv_real and vv_lee both
    # map to 0) the later-indexed source band wins — for m-so2sat that
    # means the Lee-filtered imaginary component overwrites, which is
    # reasonable since all variants carry the same polarisation signal.
    "sar": {
        "modality_name": "SENTINEL1",
        "sample_field": "sentinel1",
        "channels": 2,
        "num_band_sets": 1,
        "name_to_idx": {
            "vv": 0,
            "vv_real": 0,
            "vv_imag": 0,
            "vv_lee": 0,
            "vv_lee_real": 0,
            "vv_lee_imag": 0,
            "vh": 1,
            "vh_real": 1,
            "vh_imag": 1,
            "vh_lee": 1,
            "vh_lee_real": 1,
            "vh_lee_imag": 1,
        },
    },
    # Landsat routed through the S2 normalizer.  Wavelengths align well:
    # B-G-R-NIR-SWIR1-SWIR2 ↔ B02-B03-B04-B08-B11-B12.  Use via
    # sensor_remap={"landsat": "landsat_as_s2"}.
    "landsat_as_s2": {
        "modality_name": "SENTINEL2_L2A",
        "sample_field": "sentinel2_l2a",
        "channels": 12,
        "num_band_sets": 3,
        "name_to_idx": {
            "blue": 0,
            "b2": 0,
            "green": 1,
            "b3": 1,
            "red": 2,
            "b4": 2,
            "nir": 3,
            "b5": 5,
            "swir_1": 8,
            "b6": 8,
            "swir_2": 9,
            "b7": 9,
        },
    },
    # NAIP / aerial: no dedicated OlmoEarth modality, route RGB through
    # the S2 path with non-RGB positions zero-filled.
    "aerial": {
        "modality_name": "SENTINEL2_L2A",
        "sample_field": "sentinel2_l2a",
        "channels": 12,
        "num_band_sets": 3,
        "name_to_idx": {
            "red": 2,
            "r": 2,
            "green": 1,
            "g": 1,
            "blue": 0,
            "b": 0,
            "nir": 3,
            "ir": 3,
        },
    },
}
# Aliases.
_MODALITY_INFO["naip"] = _MODALITY_INFO["aerial"]


# Canonical GSD (meters) per sensor for OlmoEarth's positional encodings.
# Landsat pixels are 30 m; S2/S1 are 10 m.
_SENSOR_INPUT_RES: dict[str, int] = {
    "s2": 10,
    "sar": 10,  # S1 coregistered to S2 10 m grid in OlmoEarth pretraining
    "landsat": 30,
    "aerial": 1,
    "naip": 1,
}

# Sensors whose raw values should NOT be rescaled to S2 DN — pass as-is to
# OlmoEarth's modality-specific normalizer.  SAR values can be large
# (Lee-filtered max ~10 000) and the S1 Normalizer expects the original scale.
# detect_input_unit returns S2_DN for them, making to_s2_dn a no-op anyway,
# but being explicit avoids surprises if the heuristic ever changes.
_PASSTHROUGH_SENSORS: frozenset[str] = frozenset({"sar"})


def _build_sensor_groups(bands: list[BandSpec]) -> list[dict]:
    """Group bands by sensor and resolve per-group modality metadata.

    Returns a list of dicts (one per unique sensor) with keys:
        sensor, modality_name, sample_field, channels, num_band_sets,
        src_indices (indices into the original ``bands`` list),
        dst_indices (target channel positions inside OlmoEarth's layout).
    Preserves the order sensors first appear in ``bands``.
    """
    order: list[str] = []
    grouped: dict[str, list[tuple[int, BandSpec]]] = defaultdict(list)
    for i, b in enumerate(bands):
        key = b.sensor.lower()
        if key not in grouped:
            order.append(key)
        grouped[key].append((i, b))

    result = []
    for sensor in order:
        if sensor not in _MODALITY_INFO:
            supported = sorted(set(_MODALITY_INFO))
            raise ValueError(
                f"OlmoEarth wrapper has no layout for sensor '{sensor}'.  Supported: {supported}."
            )
        info = _MODALITY_INFO[sensor]
        name_to_idx = info["name_to_idx"]
        src_indices: list[int] = []
        dst_indices: list[int] = []
        unknown: list[str] = []
        for src_idx, b in grouped[sensor]:
            key_name = b.name.lower()
            if key_name not in name_to_idx:
                unknown.append(b.name)
            elif name_to_idx[key_name] is not None:
                src_indices.append(src_idx)
                dst_indices.append(name_to_idx[key_name])
            # else: known-but-skippable band (e.g. swir_cirrus / B10 for S2) — zero-filled
        if unknown:
            raise ValueError(
                f"OlmoEarth wrapper can't map BandSpec names {unknown} for "
                f"sensor '{sensor}'.  Add them to "
                f"_MODALITY_INFO['{sensor}']['name_to_idx'] "
                f"with the correct OlmoEarth band index."
            )
        group_bands = [bands[i] for i in src_indices]
        input_unit: InputUnit | None = (
            None if sensor in _PASSTHROUGH_SENSORS else _detect_band_group_unit(group_bands)
        )
        result.append(
            {
                "sensor": sensor,
                "modality_name": info["modality_name"],
                "sample_field": info["sample_field"],
                "channels": info["channels"],
                "num_band_sets": info["num_band_sets"],
                "src_indices": src_indices,
                "dst_indices": dst_indices,
                "input_unit": input_unit,
            }
        )
    return result


def _resolve_modality(bands: list[BandSpec]) -> dict:
    """Pick the OlmoEarth modality from the input ``BandSpec.sensor`` field.

    Single-sensor convenience wrapper around ``_build_sensor_groups``.
    Raises if the sensor isn't one we have a layout for, or if the bands
    span multiple sensors (use ``_build_sensor_groups`` directly instead).
    """
    sensor = bands[0].sensor.lower()
    if not all(b.sensor.lower() == sensor for b in bands):
        sensors = sorted({b.sensor.lower() for b in bands})
        raise ValueError(
            f"OlmoEarth wrapper expects a single sensor per call; got mixed sensors {sensors}."
        )
    if sensor not in _MODALITY_INFO:
        raise ValueError(
            f"OlmoEarth wrapper has no layout for sensor '{sensor}'.  "
            f"Supported: {sorted(set(_MODALITY_INFO))}."
        )
    return _MODALITY_INFO[sensor]


[docs] class OlmoEarthBenchModel(BenchModel): """BenchModel wrapper for OlmoEarth geospatial foundation models. OlmoEarth is a multi-modal ViT trained on Sentinel-2, Sentinel-1, Landsat, NAIP, and other Earth-observation streams by AI2. The wrapper picks the right modality (or modalities) from ``bands[0].sensor`` and constructs a properly-shaped batch for OlmoEarth's encoder. Supported modalities (auto-detected from ``BandSpec.sensor``): * ``"s2"`` -> ``Modality.SENTINEL2_L2A`` (12 channels, 3 band-sets) * ``"landsat"`` -> ``Modality.LANDSAT`` (11 channels, 2 band-sets) * ``"sar"`` -> ``Modality.SENTINEL1`` (2 channels, 1 band-set) * ``"aerial"`` / ``"naip"`` -> S2 path with RGB zero-fill Mixed-sensor inputs (e.g. ``["s2", "sar"]``) are handled by building separate tensor branches and populating multiple ``MaskedOlmoEarthSample`` fields simultaneously. Channels missing from the input are zero-filled at the corresponding OlmoEarth position; the mask stays all-visible so ``pool_spatially`` can still produce embeddings. The wrapper overrides ``normalize_inputs`` to identity — OlmoEarth's internal ``Normalizer`` consumes raw values directly. Input scale (DN / reflectance / uint8) is auto-detected per sensor group and rescaled to S2 DN before normalisation (SAR values are passed as-is). ``input_res`` is auto-detected from the primary sensor's GSD: 10 m for S2/SAR, 30 m for Landsat. Pass ``input_res`` explicitly to override. Args: bands: Ordered ``BandSpec`` list describing the input channels. model_size: One of ``"nano"``, ``"tiny"``, ``"base"``, ``"large"``. ``"large"`` is only available for ``version="v1"``. version: Model version — ``"v1"`` (default) or ``"v1_1"``. v1.1 ships Nano/Tiny/Base with improved accuracy and ~25% more parameters; no Large variant yet. patch_size: Patch size for the encoder (default 8). input_res: Input resolution in meters. ``None`` (default) lets the wrapper auto-detect from the primary sensor GSD. time_steps: Temporal slots in the input. Default 3. std_multiplier: Std multiplier passed to ``Normalizer``. normalize: If True, L2-normalize output embeddings. sar_log_scale: If True, convert SAR values to dB via ``10·log10(max(v, 1e-6))`` before feeding OlmoEarth's S1 normalizer, which was trained on σ⁰ dB values. landsat_scale_factor: Optional multiplier applied to Landsat values *after* the standard uint8→DN conversion. Use to compensate for mis-matched scales between GeoBench's uint8 composites and OlmoEarth's pretraining DN range (~10 000). sensor_remap: Optional dict mapping sensor names to alternate routing keys, e.g. ``{"landsat": "landsat_as_s2"}`` to route Landsat bands through the S2 normalizer (+6.6 pp on m-forestnet). min_image_size: If set, upsample inputs smaller than this value to ``min_image_size × min_image_size`` via bilinear interpolation. Useful for datasets with small native images (e.g. m-so2sat at 32 px) where the patch grid would otherwise be too sparse. """ def __init__( self, bands: list[BandSpec], *, model_size: Literal["nano", "tiny", "base", "large"] = "base", version: Literal["v1", "v1_1"] = "v1", patch_size: int = 4, input_res: int | None = None, time_steps: int = 1, std_multiplier: float = 2.0, normalize: bool = False, sar_log_scale: bool = False, landsat_scale_factor: float | None = None, sensor_remap: dict[str, str] | None = None, min_image_size: int | None = None, **_kwargs, ) -> None: super().__init__(bands=bands, **_kwargs) # Lazy imports so the package is only needed when this model is used. from olmoearth_pretrain_minimal import ModelID, Normalizer, load_model_from_id from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.constants import Modality # Quiet down the package's INFO-level ModalitySpec dumps. for logger_name in ( "olmoearth_pretrain_minimal", "olmoearth_pretrain_minimal.olmoearth_pretrain_v1", ): logging.getLogger(logger_name).setLevel(logging.WARNING) # Optionally remap sensor names before routing. bands_for_routing = self.bands if sensor_remap: from dataclasses import replace as dc_replace bands_for_routing = [ dc_replace(b, sensor=sensor_remap.get(b.sensor, b.sensor)) for b in self.bands ] # Build per-sensor groups (handles both single and mixed sensors). sensor_groups = _build_sensor_groups(bands_for_routing) for g in sensor_groups: g["modality"] = getattr(Modality, g["modality_name"]) self._sensor_groups = sensor_groups # Auto-detect input_res from primary sensor unless explicitly set. if input_res is None: sensors_present = {g["sensor"] for g in sensor_groups} # For mixed s2+sar, S2 10 m is the OlmoEarth pretraining grid. primary = "s2" if "s2" in sensors_present else sensor_groups[0]["sensor"] input_res = _SENSOR_INPUT_RES.get(primary, 10) self.model_size = model_size self.patch_size = patch_size self.input_res = input_res self.time_steps = time_steps self.do_normalize = normalize self.sar_log_scale = sar_log_scale self.landsat_scale_factor = landsat_scale_factor self.min_image_size = min_image_size model_id = getattr(ModelID, f"OLMOEARTH_{version.upper()}_{model_size.upper()}") self.encoder_model = load_model_from_id(model_id, load_weights=True) self.normalizer = Normalizer(std_multiplier=std_multiplier)
[docs] def normalize_inputs(self, images: torch.Tensor) -> torch.Tensor: """Identity — OlmoEarth's internal Normalizer handles raw values.""" return images
def _pad_group( self, g_images: torch.Tensor, dst_indices: list[int], target_channels: int, ) -> torch.Tensor: """Place each input channel at its OlmoEarth modality position. Returns ``(B, target_channels, H, W)`` with zeros for missing bands. """ B, _, H, W = g_images.shape out = torch.zeros(B, target_channels, H, W, device=g_images.device, dtype=g_images.dtype) for local_idx, dst_idx in enumerate(dst_indices): out[:, dst_idx] = g_images[:, local_idx] return out def _build_mask( self, B: int, H: int, W: int, num_band_sets: int, device: torch.device, ) -> torch.Tensor: """Per-band-set ``(B, H, W, T, num_band_sets)`` mask — all visible. ``MaskValue.ONLINE_ENCODER`` is 0 so zeros == visible. Shape must be 5-D because the encoder's ``apply_embedding_to_modality`` does ``mask[..., idx]`` for each band-set index. """ return torch.zeros( B, H, W, self.time_steps, num_band_sets, dtype=torch.float32, device=device ) @torch.no_grad() def _forward_patch_features( self, images: torch.Tensor, ) -> torch.Tensor: """Extract image-level embeddings from raw inputs.""" from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.nn.flexi_vit import PoolingType from olmoearth_pretrain_minimal.olmoearth_pretrain_v1.utils.datatypes import ( MaskedOlmoEarthSample, ) device = images.device B, _, H, W = images.shape if self.min_image_size is not None and (self.min_image_size > H or self.min_image_size > W): new_h = max(H, self.min_image_size) new_w = max(W, self.min_image_size) images = F.interpolate( images, size=(new_h, new_w), mode="bilinear", align_corners=False ) B, _, H, W = images.shape # v1.1 uses linear patch embed which requires H,W divisible by patch_size. # Pad to the next multiple if needed (zero-padding is mask-safe). pad_h = (self.patch_size - H % self.patch_size) % self.patch_size pad_w = (self.patch_size - W % self.patch_size) % self.patch_size if pad_h > 0 or pad_w > 0: images = F.pad(images, (0, pad_w, 0, pad_h)) B, _, H, W = images.shape timestamps = torch.zeros(B, self.time_steps, 3, dtype=torch.long, device=device) timestamps[:, :, 0] = 15 timestamps[:, :, 1] = 6 timestamps[:, :, 2] = 2020 sample_kwargs: dict = {} for group in self._sensor_groups: # Extract this sensor's channels from the full input tensor. g_images = images[:, group["src_indices"]] # (B, Csensor, H, W) # Rescale to S2 DN unless the sensor is a passthrough type. input_unit = group["input_unit"] if input_unit is not None: g_images = to_s2_dn(g_images, input_unit) if group["sensor"] == "landsat" and self.landsat_scale_factor is not None: g_images = g_images * self.landsat_scale_factor if group["sensor"] == "sar" and self.sar_log_scale: g_images = 10.0 * torch.log10(g_images.clamp(min=1e-6)) # Map input channels -> OlmoEarth modality layout. g_images = self._pad_group(g_images, group["dst_indices"], group["channels"]) # Normalize with modality-specific statistics. g_nhwc = g_images.permute(0, 2, 3, 1).cpu().numpy() g_nhwtc = g_nhwc[:, :, :, None, :] if self.time_steps > 1: g_nhwtc = np.repeat(g_nhwtc, self.time_steps, axis=3) g_nhwtc = self.normalizer.normalize(group["modality"], g_nhwtc) field = group["sample_field"] mask = self._build_mask(B, H, W, group["num_band_sets"], device) sample_kwargs[field] = torch.from_numpy(g_nhwtc).float().to(device) sample_kwargs[f"{field}_mask"] = mask sample = MaskedOlmoEarthSample(timestamps=timestamps, **sample_kwargs) outputs = self.encoder_model.encoder( sample, patch_size=self.patch_size, input_res=self.input_res, fast_pass=True, ) pooled = outputs["tokens_and_masks"].pool_spatially(PoolingType.MEAN) embeddings = pooled.mean(dim=(1, 2)) if self.do_normalize: embeddings = F.normalize(embeddings, p=2, dim=-1) return embeddings