Source code for torchgeo_bench.models.segmentation_heads

"""Segmentation decoder heads for use with SegmentationProbe."""

import torch
import torch.nn as nn
import torch.nn.functional as F


[docs] class LinearHead(nn.Module): """Per-layer BN + 1×1 conv heads with learned scale-weighted fusion. For a single layer the output is returned directly. For multiple layers, each head's logits are upsampled to the input resolution and combined via learned scalar weights. Args: channels_list: Channel count for each hooked feature layer. num_classes: Number of segmentation output classes. """ def __init__(self, channels_list: list[int], num_classes: int) -> None: super().__init__() self.num_classes = num_classes self.heads = nn.ModuleList( [ nn.Sequential(nn.BatchNorm2d(c), nn.Conv2d(c, num_classes, kernel_size=1)) for c in channels_list ] ) if len(channels_list) > 1: self.scale_weights = nn.Parameter(torch.ones(len(channels_list)))
[docs] def forward(self, features: list[torch.Tensor], input_h: int, input_w: int) -> torch.Tensor: """Upsample and sum per-layer logits.""" total_logits: torch.Tensor | int = 0 for idx, (feat, head) in enumerate(zip(features, self.heads)): logits = head(feat) if logits.shape[-2:] != (input_h, input_w): logits = F.interpolate( logits, size=(input_h, input_w), mode="bilinear", align_corners=False ) if len(self.heads) == 1: return logits total_logits = total_logits + logits * self.scale_weights[idx] return total_logits # type: ignore[return-value]
class PatchLinearHead(nn.Module): """ViT-specific patch decoder using a per-token linear projection. The first hooked feature map is interpreted as a token grid. Each token is independently projected from ``D`` channels to ``C * P^2`` logits via a 1x1 conv, then rearranged to pixel space with ``pixel_shuffle(P)``. Args: channels_list: Channel count for each hooked feature layer. Only the first entry is used. num_classes: Number of segmentation output classes. """ def __init__(self, channels_list: list[int], num_classes: int) -> None: super().__init__() self.num_classes = num_classes self.in_channels = channels_list[0] self.norm = ChannelLayerNorm(self.in_channels) self.conv: nn.Conv2d | None = None self.patch_size: int | None = None def _ensure_conv( self, patch_size: int, device: torch.device, dtype: torch.dtype, ) -> nn.Conv2d: if patch_size <= 0: raise ValueError(f"Patch size must be positive, got {patch_size}.") if self.conv is None: conv = nn.Conv2d(self.in_channels, self.num_classes * patch_size**2, kernel_size=1) self.conv = conv.to(device=device, dtype=dtype) self.patch_size = patch_size return self.conv if self.patch_size != patch_size: raise ValueError( f"PatchLinearHead was initialized for patch_size={self.patch_size}, " f"but got patch_size={patch_size}." ) if self.conv.weight.device != device or self.conv.weight.dtype != dtype: self.conv = self.conv.to(device=device, dtype=dtype) return self.conv def forward(self, features: list[torch.Tensor], input_h: int, input_w: int) -> torch.Tensor: """Project the token grid into pixel logits.""" feat = features[0] grid_h, grid_w = feat.shape[-2:] patch_h = round(input_h / grid_h) patch_w = round(input_w / grid_w) if patch_h != patch_w: raise ValueError( "PatchLinearHead requires square patch geometry. " f"Got input=({input_h}, {input_w}) and grid=({grid_h}, {grid_w})." ) conv = self._ensure_conv(patch_h, feat.device, feat.dtype) logits = conv(self.norm(feat)) logits = F.pixel_shuffle(logits, patch_h) if logits.shape[-2:] != (input_h, input_w): logits = F.interpolate( logits, size=(input_h, input_w), mode="bilinear", align_corners=False ) return logits
[docs] class ConvBlockHead(nn.Module): """Per-layer 1×1 projection to hidden_dim, aligned concat, 1×1 classification head. All feature maps are projected to the same channel count, upsampled to the finest spatial resolution in the batch, concatenated, and classified with a single 1×1 conv. Args: channels_list: Channel count for each hooked feature layer. num_classes: Number of segmentation output classes. hidden_dim: Projection dimension (default 256). """ def __init__(self, channels_list: list[int], num_classes: int, hidden_dim: int = 256) -> None: super().__init__() self.projectors = nn.ModuleList( [ nn.Sequential( nn.Conv2d(c, hidden_dim, kernel_size=1, bias=False), nn.BatchNorm2d(hidden_dim), nn.SiLU(inplace=True), ) for c in channels_list ] ) self.head = nn.Conv2d(hidden_dim * len(channels_list), num_classes, kernel_size=1)
[docs] def forward(self, features: list[torch.Tensor], input_h: int, input_w: int) -> torch.Tensor: """Project, upsample, concat, and classify features.""" proj_feats = [proj(f) for f, proj in zip(features, self.projectors)] target_h, target_w = 0, 0 for f in proj_feats: if f.shape[-2] > target_h: target_h, target_w = f.shape[-2:] if target_h <= 1: target_h, target_w = 16, 16 aligned = [] for f in proj_feats: if f.shape[-2:] != (target_h, target_w): f = F.interpolate( f, size=(target_h, target_w), mode="bilinear", align_corners=False ) aligned.append(f) logits = self.head(torch.cat(aligned, dim=1)) if logits.shape[-2:] != (input_h, input_w): logits = F.interpolate( logits, size=(input_h, input_w), mode="bilinear", align_corners=False ) return logits
[docs] class FPNHead(nn.Module): """Feature Pyramid Network decoder head. Applies lateral 1×1 convs, a top-down merging pathway, 3×3 refinement convs, then upsamples all levels to the finest resolution, concatenates, and classifies with a 1×1 conv. Layers must be provided in **coarse-to-fine order** (deepest / lowest- resolution first). Example for ResNet: ``["layer4", "layer3", "layer2", "layer1"]``. Args: channels_list: Channel count for each hooked feature layer (coarse-to-fine). num_classes: Number of segmentation output classes. hidden_dim: Feature dimension used throughout the FPN (default 256). """ def __init__(self, channels_list: list[int], num_classes: int, hidden_dim: int = 256) -> None: super().__init__() # Normalise raw CNN features before projection. BN is appropriate here: # CNN channels have per-filter semantics and batch stats are stable. self.input_norms = nn.ModuleList([nn.BatchNorm2d(c) for c in channels_list]) self.laterals = nn.ModuleList( [nn.Conv2d(c, hidden_dim, kernel_size=1, bias=False) for c in channels_list] ) self.fpn_convs = nn.ModuleList( [ nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(hidden_dim), nn.ReLU(inplace=True), ) for _ in channels_list ] ) self.fpn_head = nn.Conv2d(hidden_dim * len(channels_list), num_classes, kernel_size=1)
[docs] def forward(self, features: list[torch.Tensor], input_h: int, input_w: int) -> torch.Tensor: """Top-down FPN forward pass. Args: features: Feature maps in coarse-to-fine order (index 0 = coarsest). input_h: Target output height (input image height). input_w: Target output width (input image width). """ laterals = [lat(norm(f)) for f, norm, lat in zip(features, self.input_norms, self.laterals)] # Top-down merging: from coarsest (0) to finest (-1) for i in range(len(laterals) - 1): target_size = laterals[i + 1].shape[-2:] laterals[i + 1] = laterals[i + 1] + F.interpolate( laterals[i], size=target_size, mode="bilinear", align_corners=False ) fpn_outs = [conv(p) for p, conv in zip(laterals, self.fpn_convs)] finest_size = fpn_outs[-1].shape[-2:] aligned = [] for f in fpn_outs: if f.shape[-2:] != finest_size: f = F.interpolate(f, size=finest_size, mode="bilinear", align_corners=False) aligned.append(f) logits = self.fpn_head(torch.cat(aligned, dim=1)) if logits.shape[-2:] != (input_h, input_w): logits = F.interpolate( logits, size=(input_h, input_w), mode="bilinear", align_corners=False ) return logits
# --------------------------------------------------------------------------- # DPT helper modules (adapted from probe3d — mbanani/probe3d) # --------------------------------------------------------------------------- class ChannelLayerNorm(nn.Module): """LayerNorm over the channel dimension of a (B, C, H, W) feature map. Normalises each spatial position independently across channels — equivalent to the LayerNorm inside a ViT block. This is the natural choice before projecting ViT intermediate features, where residual-stream outliers can cause large inter-layer scale differences that BatchNorm handles poorly (sample-wise norm is immune to per-batch outlier corruption). """ def __init__(self, num_channels: int) -> None: super().__init__() self.norm = nn.LayerNorm(num_channels) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply layer norm over channels.""" # x: (B, C, H, W) → permute to (B, H, W, C) → LN → back return self.norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2).contiguous() class ResidualConvUnit(nn.Module): """Two conv+ReLU layers with a residual skip connection.""" def __init__(self, features: int, kernel_size: int = 3) -> None: super().__init__() padding = kernel_size // 2 self.conv = nn.Sequential( nn.Conv2d(features, features, kernel_size, padding=padding), nn.ReLU(inplace=True), nn.Conv2d(features, features, kernel_size, padding=padding), nn.ReLU(inplace=True), ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Apply residual conv block.""" return self.conv(x) + x class FeatureFusionBlock(nn.Module): """Fuses a feature map with an optional skip connection via residual conv units.""" def __init__(self, features: int, kernel_size: int = 3, with_skip: bool = True) -> None: super().__init__() self.with_skip = with_skip if with_skip: self.resConfUnit1 = ResidualConvUnit(features, kernel_size) self.resConfUnit2 = ResidualConvUnit(features, kernel_size) def forward(self, x: torch.Tensor, skip_x: torch.Tensor | None = None) -> torch.Tensor: """Fuse skip connection and refine features.""" if skip_x is not None: if skip_x.shape[-2:] != x.shape[-2:]: skip_x = F.interpolate( skip_x, size=x.shape[-2:], mode="bilinear", align_corners=False ) x = self.resConfUnit1(x) + skip_x return self.resConfUnit2(x)
[docs] class DPTHead(nn.Module): """DPT-style decoder head (adapted from probe3d at mbanani/probe3d, single-view). Requires exactly **4** feature layers in **coarse-to-fine order** (same convention as FPN, e.g. ``["layer4", "layer3", "layer2", "layer1"]`` for ResNet). The forward pass processes features from coarsest to finest through a cascade of :class:`FeatureFusionBlock` modules. Upsampling chain (mirroring probe3d): 1. 1×1 project each map to ``hidden_dim`` 2. 2× upsample all projected maps 3. Top-down fusion cascade (coarse → fine) 4. 4× upsample the fused result 5. 3×3 → ReLU → 3×3 output conv to ``num_classes`` 6. Final resize to input resolution Args: channels_list: Channel count for each hooked feature layer (coarse-to-fine). Must have exactly 4 entries. num_classes: Number of segmentation output classes. hidden_dim: Hidden channel dimension (default 256). kernel_size: Conv kernel size for residual units (default 3). """ def __init__( self, channels_list: list[int], num_classes: int, hidden_dim: int = 256, kernel_size: int = 3, ) -> None: super().__init__() if len(channels_list) != 4: raise ValueError( f"DPTHead requires exactly 4 feature layers, got {len(channels_list)}. " "Specify exactly 4 layer names in coarse-to-fine order in the model config." ) # Normalise ViT residual-stream features before projection. LayerNorm # over channels (per spatial position) matches the ViT's own internal # normalisation and is sample-wise — robust to the per-layer outlier # activations common in specialist ViTs (e.g. DOFA). self.input_norms = nn.ModuleList([ChannelLayerNorm(c) for c in channels_list]) # 1×1 projection — index 0 = coarsest self.convs = nn.ModuleList( [nn.Conv2d(c, hidden_dim, kernel_size=1, padding=0) for c in channels_list] ) # Fusion blocks: index 0 = coarsest (no skip), 1-3 receive skip from previous level self.ref = nn.ModuleList( [ FeatureFusionBlock(hidden_dim, kernel_size, with_skip=False), # coarsest FeatureFusionBlock(hidden_dim, kernel_size), FeatureFusionBlock(hidden_dim, kernel_size), FeatureFusionBlock(hidden_dim, kernel_size), # finest ] ) self.out_conv = nn.Sequential( nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(hidden_dim, num_classes, kernel_size=3, padding=1), )
[docs] def forward(self, features: list[torch.Tensor], input_h: int, input_w: int) -> torch.Tensor: """DPT forward pass. Args: features: Feature maps in coarse-to-fine order (index 0 = coarsest). input_h: Target output height. input_w: Target output width. """ # Normalise → project → 2× upsample projected = [ F.interpolate(conv(norm(f)), scale_factor=2, mode="bilinear", align_corners=True) for norm, conv, f in zip(self.input_norms, self.convs, features) ] # Top-down cascade: coarsest (0) → finest (3) out = self.ref[0](projected[0], None) out = self.ref[1](projected[1], out) out = self.ref[2](projected[2], out) out = self.ref[3](projected[3], out) # 4× upsample → output conv → resize to input resolution out = F.interpolate(out, scale_factor=4, mode="bilinear", align_corners=True) out = self.out_conv(out) if out.shape[-2:] != (input_h, input_w): out = F.interpolate(out, size=(input_h, input_w), mode="bilinear", align_corners=False) return out