Add a Dataset#
This page explains how to wire a new geospatial dataset into torchgeo-bench so that any registered model can be evaluated on it automatically.
Prerequisites#
Clone the repository and install the development dependencies:
$ git clone https://github.com/torchgeo/torchgeo-bench.git
$ cd torchgeo-bench
$ conda activate torchgeo-bench
$ uv sync --extra dev
This is the same setup used in Evaluate your own model (Stage 1). Download the dataset files so you can test loading locally:
$ torchgeo-bench download <dataset_name>
Implement BenchDataset#
Create a new module under src/torchgeo_bench/datasets/ and subclass
BenchDataset:
import torch

from torchgeo_bench.datasets.base import BandSpec, BenchDataset


class MyDataset(BenchDataset):
    name = "my_dataset"
    task = "classification"  # or "segmentation"
    num_classes = 10
    bands: list[BandSpec] = [...]
    split_sizes = {"train": 5000, "val": 1000, "test": 2000}

    def get_dataset(self, split: str, bands) -> torch.utils.data.Dataset:
        ...  # return a Dataset yielding (image_tensor, label) pairs
Required class-level attributes:
- name: unique string identifier used by the Hydra registry and CLI
- task: "classification" or "segmentation"
- num_classes: integer label count
- bands: list of BandSpec objects supplying per-channel sensor / wavelength / normalisation stats
- split_sizes: dict with train, val, and test keys
The get_dataset method must accept split ("train", "val", or
"test") and bands (the subset of bands requested by the model), and
return a torch.utils.data.Dataset whose __getitem__ yields
(image_tensor, label) pairs.
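One way to honour the bands argument is to map the requested bands to channel indices and slice the stored image with them. The sketch below assumes bands can be identified by a name string; the actual BandSpec fields are not shown on this page, so the plain strings, the band names, and the helper name band_indices are illustrative only.

```python
def band_indices(available: list[str], requested: list[str]) -> list[int]:
    """Map requested band names to channel indices, preserving request order."""
    position = {name: i for i, name in enumerate(available)}
    missing = [name for name in requested if name not in position]
    if missing:
        raise ValueError(f"dataset does not provide bands: {missing}")
    return [position[name] for name in requested]


# Hypothetical band names stored by the dataset, in channel order.
available = ["B02", "B03", "B04", "B08"]

# A model asking for red, green, blue gets indices that can be used to
# select channels, e.g. image[idx] on a (C, H, W) tensor.
idx = band_indices(available, ["B04", "B03", "B02"])
print(idx)  # [2, 1, 0]
```

Failing loudly on an unknown band is a deliberate choice here: silently dropping a channel would change the model input shape and is harder to debug later.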
Note
V1 vs V2 loader patterns. V1 datasets (m- prefix) read images directly
from HDF5 files via _V1Dataset. V2 datasets use torchgeo dataset classes
as the underlying loader and inherit from _V2Dataset. When adding a
genuinely new dataset, prefer the V2 torchgeo pattern so the loader can
participate in torchgeo’s transform pipeline.
Register and configure#
1. Export the class from src/torchgeo_bench/datasets/__init__.py:
from .my_dataset import MyDataset
2. Add a Hydra dataset config YAML under
src/torchgeo_bench/conf/dataset/ named after your dataset
(e.g. my_dataset.yaml):
# @package _global_
defaults:
  - base_dataset
dataset:
  name: my_dataset
  num_classes: 10
  task: classification
Adjust keys as needed; see existing configs in that directory for reference.
Run the smoke test#
With the dataset on disk, run a quick benchmark to verify the dataset loads and produces sensible results:
$ torchgeo-bench run model=timm/resnet50 dataset.names=[my_dataset] \
    eval.skip_linear=true eval.bootstrap=10
Once results look sensible, follow the PR workflow described in Contribute a model (Stage 2) to open a pull request.