torchgeo_bench.datasets#
Every benchmark dataset is a subclass of BenchDataset that declares
its metadata (bands, number of classes, task type, split sizes) and knows how
to produce a PyTorch Dataset for each split.
Datasets are registered automatically on import so that
get_bench_dataset_class() can resolve them by their CLI name
(e.g. "m-eurosat" or "benv2").
Base classes#
- class torchgeo_bench.datasets.BenchDataset
Bases: ABC
Abstract base class for benchmark datasets.
Subclasses must define the class-level metadata attributes listed below and implement get_dataset() and data_root().
- task
"classification" or "segmentation".
Type: Literal['classification', 'segmentation']
- bands
Ordered list of all available spectral bands with statistics.
- split_sizes
Number of samples per split for the default partition, keyed by "train", "val", "test".
- supports_partitions
Whether the dataset honours a non-default partition argument (V1 GeoBench datasets do; V2 does not).
- abstractmethod classmethod data_root()
Return the directory the upstream loader expects.
For V1/V2 wrappers this is the parent directory containing per-dataset subdirectories (e.g. data/classification_v1.0); for torchgeo wrappers it is the dataset’s own root (e.g. data/eurosat).
- select_band_specs(bands)
Return the BandSpec entries matching bands.
Preserves the order given by bands. Raises ValueError if any requested band is not declared on the dataset.
- abstractmethod get_dataset(split, *, partition='default', bands=None, transform=None)
Return a PyTorch Dataset for a split.
Datasets always emit raw float32 values; normalization is the BenchModel’s job.
- Parameters:
split (str) – "train", "val", or "test".
partition (str) – Partition name (V1 only, e.g. "0.01x_train"). Ignored by datasets where supports_partitions is False.
bands (tuple[str, ...] | None) – Tuple of canonical band names to load. None loads all.
transform (Callable | None) – Optional sample transform callable.
- get_dataloader(split, *, batch_size=32, num_workers=8, shuffle=None, pin_memory=True, **dataset_kwargs)
Convenience wrapper: build a DataLoader.
- class torchgeo_bench.datasets.BandSpec(sensor, name, source_name, mean, std, min, max, wavelength_um=None)
Bases: object
Metadata for a single spectral band in a dataset.
- Parameters:
sensor (str) – Sensor family identifier (e.g. "s2", "landsat", "aerial", "sar", "planet", "worldview").
name (str) – Canonical short band name used in the public API (e.g. "red", "b02", "nir", "vv").
source_name (str) – Band key as it appears in the data files. For V1 HDF5 files this is the long form ("04 - Red"); for V2 datasets this is typically the uppercase band code ("B04").
mean (float) – Train-split mean pixel value (raw units, no normalization).
std (float) – Train-split standard deviation.
min (float) – Train-split minimum pixel value.
max (float) – Train-split maximum pixel value.
wavelength_um (float | None) – Approximate centre wavelength in micrometres. None for non-optical bands (SAR, DEM, elevation).
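To make the band metadata concrete, here is a small self-contained sketch of how select_band_specs() and model-side normalization could fit together. The dataclass mirrors the documented BandSpec fields, but the free functions, the statistics, and the "02 - Blue" / "03 - Green" source names are illustrative assumptions rather than values taken from the library.

```python
from dataclasses import dataclass
from typing import Optional, Sequence


@dataclass(frozen=True)
class BandSpec:
    """Stand-in that mirrors the documented BandSpec fields."""

    sensor: str
    name: str
    source_name: str
    mean: float
    std: float
    min: float
    max: float
    wavelength_um: Optional[float] = None


def select_band_specs(declared: Sequence[BandSpec], bands: Sequence[str]) -> list[BandSpec]:
    """Return specs in the order requested; unknown names raise ValueError."""
    by_name = {spec.name: spec for spec in declared}
    missing = [b for b in bands if b not in by_name]
    if missing:
        raise ValueError(f"bands not declared on the dataset: {missing}")
    return [by_name[b] for b in bands]


def standardize(values: Sequence[float], spec: BandSpec) -> list[float]:
    """Per-band (x - mean) / std, the step the docs leave to the model."""
    return [(v - spec.mean) / spec.std for v in values]


# Made-up statistics for illustration only.
DECLARED = [
    BandSpec("s2", "blue", "02 - Blue", 1000.0, 500.0, 0.0, 10000.0, 0.49),
    BandSpec("s2", "green", "03 - Green", 1100.0, 550.0, 0.0, 10000.0, 0.56),
    BandSpec("s2", "red", "04 - Red", 1200.0, 600.0, 0.0, 10000.0, 0.665),
]

rgb = select_band_specs(DECLARED, ("red", "green", "blue"))
blue_normalized = standardize([1000.0, 1500.0], rgb[2])
```

Because the dataset emits raw values, the same BandSpec list can feed whichever normalization a given model expects.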
Loading API#
- torchgeo_bench.datasets.get_datasets(dataset_name='m-forestnet', partition_name='default', batch_size=32, return_val=False, num_workers=8, image_size=None, interpolation='bilinear', bands='rgb')
Load benchmark dataset splits and dataloaders.
Datasets always emit raw float32 values; per-channel normalization is the model’s responsibility (see BenchModel).
- Parameters:
dataset_name (str) – Identifier registered in _REGISTRY.
partition_name (str) – Partition name (only honoured by datasets where supports_partitions is True).
batch_size (int) – Batch size for the returned dataloaders.
return_val (bool) – If True, also return a validation dataloader.
num_workers (int) – Number of dataloader worker processes.
image_size (int | None) – If set, resize images (and masks, with nearest) to this square size at sample time.
interpolation (str) – Resize interpolation for images ("bicubic", "bilinear", "nearest").
bands (str | Iterable[str] | None) – "rgb" (use the dataset’s rgb_bands), "all" / None (load all bands), or an explicit iterable of band names.
- Returns:
Either (train_dataset, train_loader, test_loader) or, when return_val=True, (train_dataset, train_loader, val_loader, test_loader).
- Raises:
KeyError – If dataset_name is not registered.
- torchgeo_bench.datasets.get_bench_dataset_class(name)
Look up a dataset by name and return its BenchDataset class.
- Parameters:
name (str) – Dataset identifier (e.g. "m-eurosat", "burn_scars").
- Returns:
The registered BenchDataset subclass.
- Raises:
KeyError – If name is not in the registry.
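The three accepted forms of the bands argument to get_datasets() can be sketched as a small resolver. The function name resolve_bands and the error handling for unrecognised values are assumptions made for illustration; the library's own handling is not documented here.

```python
from typing import Iterable, Sequence, Union


def resolve_bands(
    bands: Union[str, Iterable[str], None],
    all_bands: Sequence[str],
    rgb_bands: Sequence[str],
) -> tuple[str, ...]:
    """Map "rgb", "all"/None, or an explicit iterable to a tuple of band names."""
    if bands is None or bands == "all":
        return tuple(all_bands)
    if bands == "rgb":
        return tuple(rgb_bands)
    if isinstance(bands, str):
        # Assumption: any other bare string is treated as a mistake.
        raise ValueError(f"unrecognised bands shortcut: {bands!r}")
    selected = tuple(bands)
    unknown = [b for b in selected if b not in all_bands]
    if unknown:
        # Assumption: mirror select_band_specs() and reject undeclared bands.
        raise ValueError(f"bands not declared on the dataset: {unknown}")
    return selected


# Illustrative band names only.
ALL_BANDS = ("blue", "green", "red", "nir")
RGB_BANDS = ("red", "green", "blue")

default_choice = resolve_bands("rgb", ALL_BANDS, RGB_BANDS)
everything = resolve_bands(None, ALL_BANDS, RGB_BANDS)
explicit = resolve_bands(["nir", "red"], ALL_BANDS, RGB_BANDS)
```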
GeoBench V1 (classification)#
V1 datasets use the m- prefix on the command line. They wrap the original
GeoBench HDF5 distributions and expose the standard train/val/test
splits plus alternative partitions where available.
- class torchgeo_bench.datasets.MBigEarthNet
Bases: _V1Dataset
Sentinel-2 multi-label land-cover classification (43 classes).
Based on the BigEarthNet dataset with 12 Sentinel-2 spectral bands. Uses multi-hot label encoding.
- class torchgeo_bench.datasets.MBrickKiln
Bases: _V1Dataset
Sentinel-2 brick kiln detection (2 classes).
Based on the Brick-Kiln dataset with 13 Sentinel-2 spectral bands.
- class torchgeo_bench.datasets.MEurosat
Bases: _V1Dataset
Sentinel-2 land-use classification (10 classes).
Based on the EuroSAT dataset with 13 Sentinel-2 spectral bands.
- class torchgeo_bench.datasets.MForestnet
Bases: _V1Dataset
Landsat forest-change classification (12 classes).
Based on the ForestNet dataset with 6 Landsat spectral bands.
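As a reminder of what the multi-hot label encoding used by MBigEarthNet looks like, the sketch below turns a set of active class indices into a fixed-length 0/1 vector. The helper name and the example indices are hypothetical; only the class count of 43 comes from the entry above.

```python
def multi_hot(active_classes, num_classes):
    """Return a list of length num_classes with 1.0 at every active index."""
    vector = [0.0] * num_classes
    for index in active_classes:
        if not 0 <= index < num_classes:
            raise ValueError(f"class index {index} outside [0, {num_classes})")
        vector[index] = 1.0
    return vector


# A sample tagged with three land-cover classes out of 43 (indices made up).
label = multi_hot([2, 17, 40], num_classes=43)
```

Unlike a single class index, such a vector can mark several classes at once, which is what a multi-label task requires.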
GeoBench V2 — classification#
- class torchgeo_bench.datasets.BENV2
Bases: _V2Dataset
Sentinel-2 + SAR multi-label classification (19 classes).
BigEarthNet V2 with 12 Sentinel-2 optical bands and 2 SAR bands.
- class torchgeo_bench.datasets.Forestnet
Bases: _V2Dataset
Sentinel-2 forest-change classification (12 classes).
GeoBench V2 version with 6 Sentinel-2 spectral bands.
- class torchgeo_bench.datasets.So2Sat
Bases: _V2Dataset
Sentinel-2 + SAR local climate zone classification (17 classes).
GeoBench V2 version with 10 Sentinel-2 and 2 SAR bands.
- class torchgeo_bench.datasets.TreeSatAI
Bases: _V2Dataset
Aerial + Sentinel-2 + SAR tree species classification (15 classes).
Multi-sensor dataset with aerial RGB+NIR, 12 Sentinel-2 bands, and 3 SAR bands. Class indices follow the upstream GeoBenchTreeSatAI.classes ordering: Abies, Acer, Alnus, Betula, Cleared, Fagus, Fraxinus, Larix, Picea, Pinus, Populus, Prunus, Pseudotsuga, Quercus, Tilia.
GeoBench V2 — segmentation#
- class torchgeo_bench.datasets.BurnScars
Bases: _V2Dataset
Sentinel-2 burn scar segmentation (3 classes).
Classes: background, burn, cloud.
- class torchgeo_bench.datasets.CaFFe
Bases: _V2Dataset
Aerial grayscale calving-front segmentation (4 classes).
The upstream GeoBench V2 dataset returns (image, mask) pairs, so this wrapper exposes it as a segmentation task even though the dataset name historically suggested classification.
- class torchgeo_bench.datasets.CloudSEN12
Bases: _V2Dataset
Sentinel-2 cloud segmentation (4 classes).
- class torchgeo_bench.datasets.DynamicEarthNet
Bases: _V2Dataset
Planet + Sentinel-2 land-cover change segmentation (7 classes).
- class torchgeo_bench.datasets.FLAIR2
Bases: _V2Dataset
Aerial land-cover segmentation (13 classes).
French aerial imagery with RGB, NIR, and elevation bands. The upstream GeoBenchFLAIR2 accepts a flat band_order list and returns a single stacked image tensor, so this wrapper does not use the multi-modality dict shape.
- class torchgeo_bench.datasets.FieldsOfTheWorld
Bases: _V2Dataset
Sentinel-2 field boundary segmentation (4 classes).
Classes: background, field, boundary, other. Upstream returns image_a/image_b change-detection pairs; canonicalize_sample() keeps the later acquisition (image_b).
- class torchgeo_bench.datasets.KuroSiwo
Bases: _V2Dataset
SAR flood mapping segmentation (4 classes).
Upstream emits multi-temporal SAR (image_pre_1/image_pre_2/image_post) plus a static DEM (image_dem). Its built-in return_stacked_image=True path stacks per-timestep tensors along a new temporal axis, which (a) leaves the result 4-D (C, T, H, W) and (b) hits an assertion when SAR and DEM channel counts differ.
To produce a clean 3-D (C, H, W) image we bypass that path altogether: we ask upstream for the post-event SAR only (time_step=["post"]), then concatenate optional DEM along the channel dimension ourselves in canonicalize_sample().
- get_dataset(split, *, partition='default', bands=None, transform=None)
Return a GeoBenchv2 configured for single-timestep SAR + DEM.
Forces return_stacked_image=False (so the upstream emits per-modality keys we can stack ourselves) and time_step=["post"] (so only the post-event SAR acquisition is loaded). canonicalize_sample() then folds the per-modality tensors into a single 3-D image key.
- canonicalize_sample(sample)
Fold per-modality keys into a single 3-D (C, H, W) image tensor.
Upstream emits image_post for SAR (we only request the post-event timestep) and/or image_dem depending on the requested band order. Both are 3-D (C, H, W) so we can simply concatenate them along the channel dimension. Per-modality keys are removed from the sample once merged.
- class torchgeo_bench.datasets.PASTIS
Bases: _V2Dataset
Sentinel-2 + SAR crop type segmentation (20 classes).
Includes ascending and descending SAR orbit passes (s1_asc, s1_desc).
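The key-folding step described for KuroSiwo.canonicalize_sample() can be illustrated without any tensor library. In this sketch each image is a nested list shaped (C, H, W), so list concatenation stands in for concatenating tensors along the channel dimension. Placing SAR before DEM is an assumption, and the real method operates on PyTorch tensors.

```python
def canonicalize_sample(sample: dict) -> dict:
    """Merge image_post and/or image_dem into a single channels-first image."""
    merged = dict(sample)  # shallow copy so the caller's dict is left intact
    channels = []
    # Assumed order: SAR first, then DEM. The real order may follow the
    # requested band order instead.
    for key in ("image_post", "image_dem"):
        if key in merged:
            # Extending a list of channels plays the role of a channel-wise
            # concatenation; the per-modality key is removed once merged.
            channels.extend(merged.pop(key))
    if not channels:
        raise KeyError("sample contains neither image_post nor image_dem")
    merged["image"] = channels
    return merged


# Toy 2x2 sample: two SAR channels, one DEM channel, and a mask.
sample = {
    "image_post": [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],
    "image_dem": [[[9.0, 9.0], [9.0, 9.0]]],
    "mask": [[0, 1], [1, 0]],
}
folded = canonicalize_sample(sample)
sar_only = canonicalize_sample({"image_post": sample["image_post"]})
```

The result has one image key with three channels, which matches the 3-D (C, H, W) shape that the stacked upstream path cannot provide.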
torchgeo wrappers#
- class torchgeo_bench.datasets.EuroSAT
Bases: BenchDataset
Sentinel-2 land-use classification (10 classes), via torchgeo.
13 Sentinel-2 spectral bands. Same task and class set as MEurosat (GeoBench V1), but loads data through torchgeo.datasets.EuroSAT, so file layout and download behaviour are managed by torchgeo.
- class torchgeo_bench.datasets.EuroSATSpatial
Bases: EuroSAT
EuroSAT with longitude-based 60/20/20 train/val/test splits.
Uses torchgeo.datasets.EuroSATSpatial, which partitions tiles by longitude so train/val/test regions are spatially disjoint. Same 27000 images, classes, bands, and stats as EuroSAT; only the split assignment differs. This gives a stronger generalization signal than the default random split.
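A minimal sketch of a longitude-based 60/20/20 split, assuming tiles are ordered by longitude and cut at the 60% and 80% marks. The function name, the (name, longitude) tuples, and the west-to-east assignment of train, val, and test are illustrative assumptions; torchgeo's EuroSATSpatial may assign regions differently.

```python
def longitude_split(tiles, percents=(60, 20, 20)):
    """Split (name, longitude) tuples into spatially disjoint bands."""
    if sum(percents) != 100:
        raise ValueError("percents must sum to 100")
    ordered = sorted(tiles, key=lambda tile: tile[1])  # west to east
    n = len(ordered)
    n_train = n * percents[0] // 100  # integer maths avoids float rounding
    n_val = n * percents[1] // 100
    return {
        "train": ordered[:n_train],
        "val": ordered[n_train:n_train + n_val],
        "test": ordered[n_train + n_val:],
    }


# Ten made-up tiles with longitudes in degrees.
TILES = [
    ("t0", 13.4), ("t1", -9.1), ("t2", 2.3), ("t3", 24.9), ("t4", -3.7),
    ("t5", 8.6), ("t6", 19.0), ("t7", 0.1), ("t8", 16.4), ("t9", 5.5),
]
splits = longitude_split(TILES)
```

Because every training tile lies west of every validation tile, and those in turn west of every test tile, no split shares a region with another, which is the property a random split lacks.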