"""High-level dataset loading helpers and registry for torchgeo-bench.
This module owns the public ``get_datasets`` API used by
``torchgeo_bench.main`` and the registry that maps dataset names to their
:class:`~.base.BenchDataset` subclass. All band resolution, resize
transforms and DataLoader construction live here so the per-dataset wrappers
stay focused on declaring metadata.
"""
import logging
import warnings
from collections.abc import Callable, Iterable
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from .base import BenchDataset
from .benv2 import BENV2
from .burn_scars import BurnScars
from .caffe import CaFFe
from .cloudsen12 import CloudSEN12
from .dynamic_earthnet import DynamicEarthNet
from .eurosat import EuroSAT, EuroSATSpatial
from .flair2 import FLAIR2
from .forestnet import Forestnet
from .fotw import FieldsOfTheWorld
from .kuro_siwo import KuroSiwo
from .m_bigearthnet import MBigEarthNet
from .m_brick_kiln import MBrickKiln
from .m_eurosat import MEurosat
from .m_forestnet import MForestnet
from .m_pv4ger import MPv4ger
from .m_so2sat import MSo2Sat
from .pastis import PASTIS
from .so2sat import So2Sat
from .spacenet2 import SpaceNet2
from .spacenet7 import SpaceNet7
from .treesatai import TreeSatAI
logger = logging.getLogger(__name__)

# Lookup table from each wrapper's ``name`` attribute to the wrapper class.
# Classes are listed by benchmark family; if two classes ever declared the
# same ``name``, the one listed later would replace the earlier entry.
_REGISTRY: dict[str, type[BenchDataset]] = {
    dataset_cls.name: dataset_cls
    for dataset_cls in (
        # V1 classification
        MEurosat, MForestnet, MSo2Sat, MPv4ger, MBrickKiln, MBigEarthNet,
        # V2 classification
        BENV2, TreeSatAI, So2Sat, Forestnet,
        # V2 segmentation
        CaFFe, BurnScars, CloudSEN12, DynamicEarthNet, FLAIR2,
        FieldsOfTheWorld, KuroSiwo, PASTIS, SpaceNet2, SpaceNet7,
        # torchgeo template
        EuroSAT, EuroSATSpatial,
    )
}
[docs]
def get_bench_dataset_class(name: str) -> type[BenchDataset]:
    """Resolve a dataset identifier to its registered wrapper class.

    Args:
        name: Registry key of the dataset (e.g. ``"m-eurosat"``, ``"burn_scars"``).

    Returns:
        The :class:`BenchDataset` subclass stored under *name*.

    Raises:
        KeyError: If *name* has no registry entry. The message lists every
            registered name in sorted order.
    """
    # Happy path first: a known name returns its class straight away.
    if name in _REGISTRY:
        return _REGISTRY[name]

    available = ", ".join(sorted(_REGISTRY))
    raise KeyError(f"Unknown dataset '{name}'. Available: {available}")
[docs]
def list_datasets() -> list[str]:
    """List the names of every registered benchmark dataset, sorted ascending."""
    names = list(_REGISTRY)
    names.sort()
    return names
def _make_resize_transform(
image_size: int | None,
interpolation: str,
) -> Callable[[dict], dict] | None:
"""Build a sample-level transform that resizes ``image`` (and ``mask``)."""
if image_size is None:
return None
valid_modes = ("bicubic", "bilinear", "nearest")
if interpolation not in valid_modes:
raise ValueError(f"interpolation must be one of {valid_modes}, got {interpolation!r}.")
interp_mode = interpolation
align_corners = False if interp_mode in ("bicubic", "bilinear") else None
def _resize(sample: dict) -> dict:
img: torch.Tensor = sample["image"]
h, w = img.shape[-2], img.shape[-1]
if h != image_size or w != image_size:
img = F.interpolate(
img.unsqueeze(0),
size=(image_size, image_size),
mode=interp_mode,
align_corners=align_corners,
).squeeze(0)
sample["image"] = img
if "mask" in sample:
mask: torch.Tensor = sample["mask"].float()
h_m, w_m = mask.shape[-2], mask.shape[-1]
if h_m != image_size or w_m != image_size:
mask = (
F.interpolate(
mask.unsqueeze(0).unsqueeze(0),
size=(image_size, image_size),
mode="nearest",
)
.squeeze(0)
.squeeze(0)
.long()
)
sample["mask"] = mask
return sample
return _resize
def _make_loader(ds: Dataset, *, batch_size: int, shuffle: bool, num_workers: int) -> DataLoader:
    """Wrap *ds* in a :class:`~torch.utils.data.DataLoader` with shared settings.

    Args:
        ds: Dataset to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader should shuffle the samples.
        num_workers: Number of worker processes used for loading.

    Returns:
        The configured loader; ``pin_memory`` is always requested.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    return DataLoader(ds, **loader_options)
[docs]
def get_datasets(
    dataset_name: str = "m-forestnet",
    partition_name: str = "default",
    batch_size: int = 32,
    return_val: bool = False,
    num_workers: int = 8,
    image_size: int | None = None,
    interpolation: str = "bilinear",
    bands: str | Iterable[str] | None = "rgb",
) -> tuple:
    """Load benchmark dataset splits and dataloaders.

    Datasets always emit raw float32 values; per-channel normalization is
    the model's responsibility (see :class:`~torchgeo_bench.models.interface.BenchModel`).

    Args:
        dataset_name: Identifier registered in :data:`_REGISTRY`.
        partition_name: Partition name (only honoured by datasets where
            :attr:`~.base.BenchDataset.supports_partitions` is ``True``).
            It is applied to the train split only; val and test always use
            the ``"default"`` partition.
        batch_size: Batch size for the returned dataloaders.
        return_val: If ``True``, also return a validation dataloader. The
            validation split is only constructed when this is set.
        num_workers: Number of dataloader worker processes.
        image_size: If set, resize images (and masks, with nearest) to this
            square size at sample time.
        interpolation: Resize interpolation for images (``"bicubic"``,
            ``"bilinear"``, ``"nearest"``).
        bands: ``"rgb"`` (use the dataset's ``rgb_bands``), ``"all"`` /
            ``None`` (load all bands), or an explicit iterable of band names.

    Returns:
        Either ``(train_dataset, train_loader, test_loader)`` or, when
        ``return_val=True``, ``(train_dataset, train_loader, val_loader,
        test_loader)``.

    Raises:
        KeyError: If ``dataset_name`` is not registered.
        ValueError: If ``bands`` is a string other than ``"rgb"`` or
            ``"all"``, or if ``interpolation`` is invalid while
            ``image_size`` is set.
    """
    cls = get_bench_dataset_class(dataset_name)
    bench = cls()

    if partition_name != "default" and not bench.supports_partitions:
        warnings.warn(
            f"Dataset '{dataset_name}' does not support custom partitions. "
            f"Ignoring partition '{partition_name}'.",
            UserWarning,
            stacklevel=2,
        )

    # Resolve the band selection. The string comparisons are gated on
    # ``isinstance`` so that array-like inputs are never compared to a string
    # (which would yield an element-wise, non-boolean result).
    bands_tuple: tuple[str, ...] | None
    if bands is None:
        bands_tuple = None
    elif isinstance(bands, str):
        if bands == "rgb":
            bands_tuple = tuple(bench.rgb_bands)
        elif bands == "all":
            bands_tuple = None
        else:
            raise ValueError(
                f"Invalid bands parameter: {bands!r}. Use 'rgb', 'all', None, "
                "or an iterable of band names."
            )
    else:
        bands_tuple = tuple(bands)

    transform = _make_resize_transform(image_size, interpolation)
    train_partition = partition_name if bench.supports_partitions else "default"
    common = {"bands": bands_tuple, "transform": transform}

    train_ds = bench.get_dataset("train", partition=train_partition, **common)
    # Skip the validation split entirely unless the caller asked for it.
    val_ds = bench.get_dataset("val", partition="default", **common) if return_val else None
    test_ds = bench.get_dataset("test", partition="default", **common)

    train_loader = _make_loader(
        train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers
    )
    test_loader = _make_loader(
        test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )
    if not return_val:
        return train_ds, train_loader, test_loader

    val_loader = _make_loader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_ds, train_loader, val_loader, test_loader
# Public API of this module: the names exported by a star-import.
# Underscore-prefixed helpers and the registry itself are deliberately omitted.
__all__ = [
    "get_bench_dataset_class",
    "get_datasets",
    "list_datasets",
]