torchgeo_bench.datasets

Every benchmark dataset is a subclass of BenchDataset that declares its metadata (bands, number of classes, task type, split sizes) and knows how to produce a PyTorch Dataset for each split. Datasets are registered automatically on import so that get_bench_dataset_class() can resolve them by their CLI name (e.g. "m-eurosat" or "benv2").

Base classes

class torchgeo_bench.datasets.BenchDataset

Bases: ABC

Abstract base class for benchmark datasets.

Subclasses must define the class-level metadata attributes listed below and implement get_dataset() and data_root().

name

Dataset identifier used on the command line (e.g. "m-eurosat").

Type:

str

task

"classification" or "segmentation".

Type:

Literal['classification', 'segmentation']

num_classes

Number of output classes.

Type:

int

bands

Ordered list of all available spectral bands with statistics.

Type:

list[torchgeo_bench.datasets.base.BandSpec]

rgb_bands

Short names of the bands to use for RGB-only mode.

Type:

list[str]

split_sizes

Number of samples per split for the default partition, keyed by "train", "val", "test".

Type:

dict[str, int]

multilabel

Whether labels are multi-hot (e.g. BigEarthNet).

Type:

bool

supports_partitions

Whether the dataset honours a non-default partition argument (V1 GeoBench datasets do; V2 does not).

Type:

bool

property num_channels: int

Total number of spectral bands.

property rgb_indices: list[int]

Indices into bands for the RGB subset.

abstractmethod classmethod data_root()

Return the directory the upstream loader expects.

For V1/V2 wrappers this is the parent directory containing per-dataset subdirectories (e.g. data/classification_v1.0); for torchgeo wrappers it is the dataset’s own root (e.g. data/eurosat).

select_band_specs(bands)

Return the BandSpec entries matching bands.

Preserves the order given by bands. Raises ValueError if any requested band is not declared on the dataset.
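The lookup behaviour of select_band_specs() and the rgb_indices property can be illustrated with a small standalone sketch. The Band class and both functions below are hypothetical stand-ins written for this example, not the library's implementation; they only mirror the documented contract (order follows the request, unknown names raise ValueError).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Band:
    """Hypothetical stand-in for BandSpec; only the short name matters here."""

    name: str


def select_band_specs(declared, requested):
    """Return the declared specs matching `requested`, in the requested order."""
    by_name = {spec.name: spec for spec in declared}
    missing = [name for name in requested if name not in by_name]
    if missing:
        raise ValueError(f"Bands not declared on the dataset: {missing}")
    return [by_name[name] for name in requested]


def rgb_indices(declared, rgb_bands):
    """Return the positions of the RGB bands within the ordered band list."""
    names = [spec.name for spec in declared]
    return [names.index(name) for name in rgb_bands]


declared = [Band("blue"), Band("green"), Band("red"), Band("nir")]
selected = select_band_specs(declared, ("nir", "red"))
indices = rgb_indices(declared, ["red", "green", "blue"])
```

Here the result order is driven by the request rather than by the declaration order, which is what lets a caller ask for bands in whatever channel order a model expects.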

abstractmethod get_dataset(split, *, partition='default', bands=None, transform=None)

Return a PyTorch Dataset for a split.

Datasets always emit raw float32 values; normalization is the BenchModel’s job.

Parameters:
  • split (str) – "train", "val", or "test".

  • partition (str) – Partition name (V1 only, e.g. "0.01x_train"). Ignored by datasets where supports_partitions is False.

  • bands (tuple[str, ...] | None) – Tuple of canonical band names to load. None loads all.

  • transform (Callable | None) – Optional sample transform callable.

get_dataloader(split, *, batch_size=32, num_workers=8, shuffle=None, pin_memory=True, **dataset_kwargs)

Convenience wrapper: build a DataLoader.

class torchgeo_bench.datasets.BandSpec(sensor, name, source_name, mean, std, min, max, wavelength_um=None)

Bases: object

Metadata for a single spectral band in a dataset.

Parameters:
  • sensor (str) – Sensor family identifier (e.g. "s2", "landsat", "aerial", "sar", "planet", "worldview").

  • name (str) – Canonical short band name used in the public API (e.g. "red", "b02", "nir", "vv").

  • source_name (str) – Band key as it appears in the data files. For V1 HDF5 files this is the long form ("04 - Red"); for V2 datasets this is typically the uppercase band code ("B04").

  • mean (float) – Train-split mean pixel value (raw units, no normalization).

  • std (float) – Train-split standard deviation.

  • min (float) – Train-split minimum pixel value.

  • max (float) – Train-split maximum pixel value.

  • wavelength_um (float | None) – Approximate centre wavelength in micrometres. None for non-optical bands (SAR, DEM, elevation).
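Because datasets emit raw values and leave normalization to the model, the per-band statistics are what a model would consume. The sketch below is an assumption-laden illustration: BandSpecSketch is a hypothetical dataclass that copies the documented field names, the statistics are made-up placeholder numbers, and standardize() is one plausible way to use mean and std, not necessarily what BenchModel does.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class BandSpecSketch:
    """Hypothetical mirror of the documented BandSpec fields."""

    sensor: str
    name: str
    source_name: str
    mean: float
    std: float
    min: float
    max: float
    wavelength_um: Optional[float] = None


def standardize(value, spec):
    """Zero-mean, unit-variance scaling using the train-split statistics."""
    return (value - spec.mean) / spec.std


# Placeholder statistics, chosen only to make the arithmetic easy to follow.
red = BandSpecSketch(
    "s2", "red", "04 - Red",
    mean=1000.0, std=500.0, min=0.0, max=10000.0, wavelength_um=0.665,
)
# Non-optical band: wavelength_um keeps its default of None.
# The source_name "VV" is a placeholder for this example.
vv = BandSpecSketch("sar", "vv", "VV", mean=-12.0, std=4.0, min=-50.0, max=10.0)

z = standardize(1500.0, red)
```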

Loading API

torchgeo_bench.datasets.get_datasets(dataset_name='m-forestnet', partition_name='default', batch_size=32, return_val=False, num_workers=8, image_size=None, interpolation='bilinear', bands='rgb')

Load benchmark dataset splits and dataloaders.

Datasets always emit raw float32 values; per-channel normalization is the model’s responsibility (see BenchModel).

Parameters:
  • dataset_name (str) – Identifier registered in _REGISTRY.

  • partition_name (str) – Partition name (only honoured by datasets where supports_partitions is True).

  • batch_size (int) – Batch size for the returned dataloaders.

  • return_val (bool) – If True, also return a validation dataloader.

  • num_workers (int) – Number of dataloader worker processes.

  • image_size (int | None) – If set, resize images (and masks, with nearest) to this square size at sample time.

  • interpolation (str) – Resize interpolation for images ("bicubic", "bilinear", "nearest").

  • bands (str | Iterable[str] | None) – "rgb" (use the dataset’s rgb_bands), "all" / None (load all bands), or an explicit iterable of band names.

Returns:

Either (train_dataset, train_loader, test_loader) or, when return_val=True, (train_dataset, train_loader, val_loader, test_loader).

Raises:

KeyError – If dataset_name is not registered.

Return type:

tuple
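The three accepted forms of the bands argument can be sketched as a small resolver. resolve_bands() is a hypothetical helper written for this example; in particular, rejecting any other bare string is an assumption, since the behaviour for that case is not documented here.

```python
def resolve_bands(bands, all_bands, rgb_bands):
    """Map the documented `bands` argument forms onto a tuple of band names."""
    if bands is None or bands == "all":
        return tuple(all_bands)
    if bands == "rgb":
        return tuple(rgb_bands)
    if isinstance(bands, str):
        # Assumption: any other bare string is treated as a mistake.
        raise ValueError(f"Unsupported bands keyword: {bands!r}")
    return tuple(bands)


all_bands = ["blue", "green", "red", "nir"]
rgb_bands = ["red", "green", "blue"]

rgb_only = resolve_bands("rgb", all_bands, rgb_bands)
everything = resolve_bands(None, all_bands, rgb_bands)
explicit = resolve_bands(["nir", "red"], all_bands, rgb_bands)
```

On the calling side, the length of the returned tuple depends on return_val: three items by default (train_dataset, train_loader, test_loader) and four when return_val=True, with val_loader inserted before test_loader.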

torchgeo_bench.datasets.get_bench_dataset_class(name)

Look up a dataset by name and return its BenchDataset class.

Parameters:

name (str) – Dataset identifier (e.g. "m-eurosat", "burn_scars").

Returns:

The registered BenchDataset subclass.

Raises:

KeyError – If name is not in the registry.

Return type:

type[BenchDataset]

torchgeo_bench.datasets.list_datasets()

Return sorted names of all registered benchmark datasets.
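One common way to get "registered automatically on import" behaviour is a class-level hook that fills a name-to-class dictionary. The sketch below shows that pattern with hypothetical toy classes; whether the library uses __init_subclass__, a decorator, or explicit registration calls is not stated in this reference, so treat the mechanism as an assumption.

```python
_REGISTRY = {}


class BenchDatasetSketch:
    """Hypothetical base class that registers every named subclass."""

    name = None

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # Runs when the subclass body is executed, i.e. at import time.
        if cls.name is not None:
            _REGISTRY[cls.name] = cls


class ToyEurosat(BenchDatasetSketch):
    name = "m-eurosat"


class ToyBurnScars(BenchDatasetSketch):
    name = "burn_scars"


def get_bench_dataset_class(name):
    """Return the registered class; an unknown name raises KeyError."""
    return _REGISTRY[name]


def list_datasets():
    """Return the registered names in sorted order."""
    return sorted(_REGISTRY)


names = list_datasets()
resolved = get_bench_dataset_class("m-eurosat")
```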

GeoBench V1 (classification)

V1 datasets use the m- prefix on the command line. They wrap the original GeoBench HDF5 distributions and expose the standard train/val/test splits plus alternative partitions where available.

class torchgeo_bench.datasets.MBigEarthNet

Bases: _V1Dataset

Sentinel-2 multi-label land-cover classification (43 classes).

Based on the BigEarthNet dataset with 12 Sentinel-2 spectral bands. Uses multi-hot label encoding.
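For readers unfamiliar with multi-hot encoding, a minimal sketch follows. The multi_hot() helper and the class indices are hypothetical and chosen only for illustration; the actual label tensors come from the upstream loader.

```python
def multi_hot(active_classes, num_classes):
    """Build a vector with 1.0 at every active class index and 0.0 elsewhere."""
    vector = [0.0] * num_classes
    for index in active_classes:
        if not 0 <= index < num_classes:
            raise ValueError(f"Class index {index} is out of range")
        vector[index] = 1.0
    return vector


# A patch tagged with three of the 43 land-cover classes (indices are made up).
label = multi_hot([2, 5, 40], num_classes=43)
```

Unlike a single class index, such a vector can mark several classes as present at once, which is why these datasets set multilabel to True.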

class torchgeo_bench.datasets.MBrickKiln

Bases: _V1Dataset

Sentinel-2 brick kiln detection (2 classes).

Based on the Brick-Kiln dataset with 13 Sentinel-2 spectral bands.

class torchgeo_bench.datasets.MEurosat

Bases: _V1Dataset

Sentinel-2 land-use classification (10 classes).

Based on the EuroSAT dataset with 13 Sentinel-2 spectral bands.

class torchgeo_bench.datasets.MForestnet

Bases: _V1Dataset

Landsat forest-change classification (12 classes).

Based on the ForestNet dataset with 6 Landsat spectral bands.

class torchgeo_bench.datasets.MPv4ger

Bases: _V1Dataset

Aerial solar panel detection (2 classes).

Based on the PV4GER dataset with 3 aerial RGB bands.

class torchgeo_bench.datasets.MSo2Sat

Bases: _V1Dataset

Sentinel-2 + SAR local climate zone classification (17 classes).

Based on the So2Sat dataset with 10 Sentinel-2 and 8 SAR bands.

GeoBench V2 — classification

class torchgeo_bench.datasets.BENV2

Bases: _V2Dataset

Sentinel-2 + SAR multi-label classification (19 classes).

BigEarthNet V2 with 12 Sentinel-2 optical bands and 2 SAR bands.

class torchgeo_bench.datasets.Forestnet

Bases: _V2Dataset

Sentinel-2 forest-change classification (12 classes).

GeoBench V2 version with 6 Sentinel-2 spectral bands.

class torchgeo_bench.datasets.So2Sat

Bases: _V2Dataset

Sentinel-2 + SAR local climate zone classification (17 classes).

GeoBench V2 version with 10 Sentinel-2 and 2 SAR bands.

class torchgeo_bench.datasets.TreeSatAI

Bases: _V2Dataset

Aerial + Sentinel-2 + SAR tree species classification (15 classes).

Multi-sensor dataset with aerial RGB+NIR, 12 Sentinel-2 bands, and 3 SAR bands. Class indices follow the upstream GeoBenchTreeSatAI.classes ordering: Abies, Acer, Alnus, Betula, Cleared, Fagus, Fraxinus, Larix, Picea, Pinus, Populus, Prunus, Pseudotsuga, Quercus, Tilia.

GeoBench V2 — segmentation

class torchgeo_bench.datasets.BurnScars

Bases: _V2Dataset

Sentinel-2 burn scar segmentation (3 classes).

Classes: background, burn, cloud.

class torchgeo_bench.datasets.CaFFe

Bases: _V2Dataset

Aerial grayscale calving-front segmentation (4 classes).

The upstream GeoBench V2 dataset returns (image, mask) pairs, so this wrapper exposes it as a segmentation task even though the dataset name historically suggested classification.

class torchgeo_bench.datasets.CloudSEN12

Bases: _V2Dataset

Sentinel-2 cloud segmentation (4 classes).

class torchgeo_bench.datasets.DynamicEarthNet

Bases: _V2Dataset

Planet + Sentinel-2 land-cover change segmentation (7 classes).

class torchgeo_bench.datasets.FLAIR2

Bases: _V2Dataset

Aerial land-cover segmentation (13 classes).

French aerial imagery with RGB, NIR, and elevation bands. The upstream GeoBenchFLAIR2 accepts a flat band_order list and returns a single stacked image tensor, so this wrapper does not use the multi-modality dict shape.

class torchgeo_bench.datasets.FieldsOfTheWorld

Bases: _V2Dataset

Sentinel-2 field boundary segmentation (4 classes).

Classes: background, field, boundary, other. Upstream returns image_a / image_b change-detection pairs; canonicalize_sample() keeps the later acquisition (image_b).

canonicalize_sample(sample)

Pick the later acquisition (image_b) and surface it as image.
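As a rough illustration of that step, the sketch below renames image_b to image on a plain dictionary. It is a hypothetical standalone function, not the method's source; dropping image_a and returning a copy rather than mutating in place are both assumptions.

```python
def canonicalize_sample(sample):
    """Surface the later acquisition (image_b) under the `image` key."""
    out = dict(sample)  # shallow copy so the caller's dict is left untouched
    out["image"] = out.pop("image_b")
    # Assumption: the earlier acquisition is discarded rather than kept.
    out.pop("image_a", None)
    return out


# Strings stand in for image tensors to keep the example dependency-free.
raw = {"image_a": "earlier", "image_b": "later", "mask": "field-mask"}
canonical = canonicalize_sample(raw)
```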

class torchgeo_bench.datasets.KuroSiwo

Bases: _V2Dataset

SAR flood mapping segmentation (4 classes).

Upstream emits multi-temporal SAR (image_pre_1 / image_pre_2 / image_post) plus a static DEM (image_dem). Its built-in return_stacked_image=True path stacks per-timestep tensors along a new temporal axis, which (a) leaves the result 4-D (C, T, H, W) and (b) hits an assertion when SAR and DEM channel counts differ.

To produce a clean 3-D (C, H, W) image we bypass that path altogether: we ask upstream for the post-event SAR only (time_step=["post"]), then concatenate optional DEM along the channel dimension ourselves in canonicalize_sample().

get_dataset(split, *, partition='default', bands=None, transform=None)

Return a GeoBenchv2 configured for single-timestep SAR + DEM.

Forces return_stacked_image=False (so the upstream emits per-modality keys we can stack ourselves) and time_step=["post"] (so only the post-event SAR acquisition is loaded). canonicalize_sample() then folds the per-modality tensors into a single 3-D image key.

canonicalize_sample(sample)

Fold per-modality keys into a single 3-D (C, H, W) image tensor.

Upstream emits image_post for SAR (we only request the post-event timestep) and/or image_dem depending on the requested band order. Both are 3-D (C, H, W) so we can simply concatenate them along the channel dimension. Per-modality keys are removed from the sample once merged.
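The folding step can be sketched without any tensor library by treating a (C, H, W) image as a list of C channel grids. This is a hypothetical re-creation for illustration only: with real tensors the concatenation would more likely be torch.cat(parts, dim=0), and placing SAR before DEM is an assumption.

```python
def canonicalize_sample(sample):
    """Merge image_post and/or image_dem into one channel-first `image`."""
    out = dict(sample)
    channels = []
    # Assumed order: SAR channels first, then DEM.
    for key in ("image_post", "image_dem"):
        if key in out:
            # Each value is a list of channels, so extending the list
            # concatenates along the channel dimension.
            channels.extend(out.pop(key))
    if not channels:
        raise KeyError("sample has neither image_post nor image_dem")
    out["image"] = channels
    return out


sar = [[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]]  # 2 channels, 2x2
dem = [[[120.0, 121.0], [119.0, 118.0]]]  # 1 channel, 2x2

merged = canonicalize_sample(
    {"image_post": sar, "image_dem": dem, "mask": [[0, 1], [1, 0]]}
)
sar_only = canonicalize_sample({"image_post": sar})
```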

class torchgeo_bench.datasets.PASTIS

Bases: _V2Dataset

Sentinel-2 + SAR crop type segmentation (20 classes).

Includes ascending and descending SAR orbit passes (s1_asc, s1_desc).

class torchgeo_bench.datasets.SpaceNet2

Bases: _V2Dataset

WorldView building footprint segmentation (3 classes).

8 multispectral + 1 panchromatic band from WorldView satellite.

class torchgeo_bench.datasets.SpaceNet7

Bases: _V2Dataset

Planet building change segmentation (3 classes).

RGB imagery from Planet satellites.

torchgeo wrappers

class torchgeo_bench.datasets.EuroSAT

Bases: BenchDataset

Sentinel-2 land-use classification (10 classes), via torchgeo.

13 Sentinel-2 spectral bands. Same task and class set as MEurosat (GeoBench V1), but the data is loaded through torchgeo.datasets.EuroSAT, so file layout and download behaviour are managed by torchgeo.

classmethod data_root()

Return Path("data/eurosat") (torchgeo manages its own layout below).

get_dataset(split, *, partition='default', bands=None, transform=None, normalize='mean_stdev')

Return a torchgeo.datasets.EuroSAT for the given split.

class torchgeo_bench.datasets.EuroSATSpatial

Bases: EuroSAT

EuroSAT with longitude-based 60/20/20 train/val/test splits.

Uses torchgeo.datasets.EuroSATSpatial, which partitions tiles by longitude so that the train/val/test regions are spatially disjoint. The 27,000 images, classes, bands, and statistics are the same as in EuroSAT; only the split assignment differs. This gives a stronger test of spatial generalization than the default random split.
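The idea of a longitude-based split can be sketched in a few lines. This is a simplified, hypothetical illustration: tiles are sorted by longitude and cut at 60% and 80%; the exact boundaries and direction torchgeo uses may differ, and the tile names and coordinates are made up.

```python
def longitude_split(tiles, fractions=(0.6, 0.2, 0.2)):
    """Split (tile_id, longitude) pairs into spatially disjoint bands."""
    ordered = sorted(tiles, key=lambda tile: tile[1])
    n_train = round(len(ordered) * fractions[0])
    n_val = round(len(ordered) * fractions[1])
    return {
        "train": ordered[:n_train],
        "val": ordered[n_train:n_train + n_val],
        "test": ordered[n_train + n_val:],
    }


# Ten toy tiles with made-up longitudes in arbitrary order.
longitudes = [9, 3, 7, 1, 5, 0, 8, 2, 6, 4]
tiles = [(f"tile_{i}", float(lon)) for i, lon in enumerate(longitudes)]
splits = longitude_split(tiles)
```

Because every training tile lies on one side of every validation and test tile, a model cannot score well simply by memorising the appearance of neighbouring locations.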

classmethod data_root()

Return Path("data/eurosat") — shares the archive with EuroSAT.

get_dataset(split, *, partition='default', bands=None, transform=None, normalize='mean_stdev')

Return a torchgeo.datasets.EuroSATSpatial for the given split.