collate_samples
- hrtfpykit.datasets.collate_samples(batch)
Collate hrtfpykit dataset samples for PyTorch data loaders.
collate_samples is the batching function used with torch.utils.data.DataLoader. The data loader receives individual samples from dataset[index] and passes a list of those samples to collate_samples. hrtfpykit samples are dictionaries with inputs and target entries. Each entry is either None or a dictionary whose keys come from the selected specs.

The values inside inputs and target follow the spec that produced them. HRTFSpec returns an IR or TF array extracted from the selected HRTF version. ITDSpec, ILDSpec, and SHSpec return values calculated from the selected HRTF version. For these acoustic specs, transform is an HRTF transform applied before extraction or calculation. AnthropometrySpec and MetadataSpec return the selected subject values from their loaded resources, commonly a dictionary for CSV resources or a NumPy slice for matrix resources. MeshSpec returns a mesh path string unless its transform loads that path. ImageSpec and VideoSpec return a path string when one file matches the sample, a list of path strings when several files match, or the values produced by their transform. With ImageSpec(concatenate=True), transformed image arrays are concatenated along axis zero before collation.

collate_samples converts homogeneous numeric values into PyTorch tensors so a training loop can use batch["inputs"] and batch["target"] directly. Floating-point numeric values are converted to torch.float32, which matches the default dtype of standard PyTorch model parameters. Integer indices and boolean values keep their natural tensor dtypes. Arrays and tensors with matching shapes are stacked along a leading batch axis. Numeric dictionaries with the same keys are converted to batch x features tensors. Lists with the same length are collated by position; for example, a subject with nine transformed RGB images of shape 3 x 224 x 224 becomes batch x 9 x 3 x 224 x 224. Strings, paths, ragged values, mixed None values, and non-numeric objects are kept as Python lists.

- Parameters:
batch (sequence) – Sequence of samples returned by a map-style dataset. In normal PyTorch usage, this is the list built internally from calls such as dataset[index] before DataLoader yields a batch.
- Returns:
Collated batch. For standard hrtfpykit dataset samples, the returned object is a dictionary containing collated inputs and target entries. Homogeneous numeric values are returned as PyTorch tensors.
- Return type:
object
- Raises:
TypeError – If batch is not a sequence of dataset samples.
ValueError – If batch is empty.
ImportError – If PyTorch is not installed.
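The three documented error cases can be illustrated with a hypothetical pre-flight check. check_batch and require_torch are invented names, not part of hrtfpykit. Treating a sample as valid when it is a mapping with inputs and target keys is an assumption based on the sample layout described above, and the library's own checks and messages may differ.

```python
import importlib.util
from collections.abc import Mapping, Sequence
from typing import Any, List


def require_torch() -> None:
    """Hypothetical helper: raise ImportError when PyTorch cannot be imported."""
    if importlib.util.find_spec("torch") is None:
        raise ImportError("collate_samples requires PyTorch (the 'torch' package)")


def check_batch(batch: Any) -> List[Any]:
    """Hypothetical checks mirroring the documented TypeError and ValueError cases."""
    # str and bytes are Sequences too, but they are never a batch of samples.
    if not isinstance(batch, Sequence) or isinstance(batch, (str, bytes)):
        raise TypeError(f"batch must be a sequence of samples, got {type(batch).__name__}")
    if len(batch) == 0:
        raise ValueError("batch is empty")
    # Assumption: a dataset sample is a mapping with 'inputs' and 'target' keys.
    for position, sample in enumerate(batch):
        if not isinstance(sample, Mapping) or not all(k in sample for k in ("inputs", "target")):
            raise TypeError(f"batch[{position}] is not an hrtfpykit dataset sample")
    return list(batch)


# Usage: each of these inputs is rejected by one of the documented error cases.
for bad in ([], {"inputs": None, "target": None}, [1, 2]):
    try:
        check_batch(bad)
    except (TypeError, ValueError) as error:
        print(type(error).__name__, error)
```

The empty list raises ValueError. A single sample dictionary raises TypeError because a mapping is not a sequence. [1, 2] also raises TypeError because its items are not samples.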
Notes
Dataset indexing stays framework-neutral. Image, video, and mesh specs return paths until their transforms load those paths into arrays or tensors. Tensor conversion happens here, at DataLoader collation time. Floating tensors are returned as torch.float32, so common training loops do not need to cast NumPy float64 values manually.
Examples
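The collation rules described above can be modelled in plain Python before PyTorch is involved. This is a hypothetical, simplified sketch and not hrtfpykit's implementation. Nested Python lists stand in for arrays, ToyTensor stands in for torch.Tensor, and dtype names are plain strings. Mapping Python ints to torch.int64 and keeping an all-None entry as None are assumptions. toy_loader, the invented samples, and every helper name are illustrative only. toy_loader walks the same path as DataLoader: it indexes the dataset, groups the samples into a list, then passes that list to the collate function. Because nested lists play the role of arrays here, equal-length lists of same-shape arrays simply stack.

```python
import struct
from dataclasses import dataclass
from numbers import Integral, Real
from typing import Any, Callable, Iterator, List, Optional, Sequence, Tuple


@dataclass
class ToyTensor:
    """Hypothetical stand-in for torch.Tensor: nested lists plus a shape and a dtype name."""
    data: list
    shape: Tuple[int, ...]
    dtype: str


def shape_of(value: Any) -> Optional[Tuple[int, ...]]:
    """Shape of a rectangular numeric nested list; None if ragged, empty or non-numeric."""
    if isinstance(value, Real):  # bool counts as numeric too
        return ()
    if isinstance(value, list) and value:
        inner = [shape_of(item) for item in value]
        if inner[0] is None or any(s != inner[0] for s in inner):
            return None
        return (len(value),) + inner[0]
    return None


def leaves(value: Any) -> List[Any]:
    """Every scalar inside a nested list, flattened."""
    return [s for item in value for s in leaves(item)] if isinstance(value, list) else [value]


def to_float32(value: Any) -> Any:
    """Round every scalar in a nested list to the nearest float32 value."""
    if isinstance(value, list):
        return [to_float32(item) for item in value]
    return struct.unpack("f", struct.pack("f", float(value)))[0]


def make_tensor(values: List[Any]) -> ToyTensor:
    """'Stack' same-shape numeric values along a new leading batch axis."""
    scalars = leaves(values)
    if all(isinstance(s, bool) for s in scalars):
        dtype = "torch.bool"
    elif all(isinstance(s, Integral) for s in scalars):  # assumption: Python ints -> int64
        dtype = "torch.int64"
    else:  # any float present: apply the documented float32 rule
        dtype = "torch.float32"
        values = [to_float32(value) for value in values]
    return ToyTensor(list(values), (len(values),) + shape_of(values[0]), dtype)


def collate_values(values: List[Any]) -> Any:
    """Apply the documented rules to one spec entry, e.g. all 'magnitude' values."""
    first = values[0]
    # Numeric dictionaries with the same keys -> batch x features (first sample's key order).
    if (isinstance(first, dict) and first
            and all(isinstance(v, dict) and v.keys() == first.keys() for v in values)
            and all(isinstance(x, Real) for v in values for x in v.values())):
        return make_tensor([[v[key] for key in first] for v in values])
    # Same-shape numeric values (scalars, arrays, equal-length lists of arrays) -> stacked.
    shapes = [shape_of(v) for v in values]
    if shapes[0] is not None and all(s == shapes[0] for s in shapes):
        return make_tensor(values)
    # Strings, paths, ragged values, mixed None and other objects stay a Python list.
    return list(values)


def sketch_collate_samples(samples: Sequence[dict]) -> dict:
    """Collate samples shaped like {'inputs': {...} or None, 'target': {...} or None}."""
    batch = {}
    for entry in ("inputs", "target"):
        entries = [sample[entry] for sample in samples]
        if all(e is None for e in entries):
            batch[entry] = None  # assumption: an unused entry stays None
        elif any(e is None for e in entries):
            batch[entry] = entries  # mixed None values are kept as a list
        else:
            batch[entry] = {name: collate_values([e[name] for e in entries]) for name in entries[0]}
    return batch


def toy_loader(dataset: Sequence[dict], batch_size: int, collate_fn: Callable) -> Iterator[Any]:
    """Simplified DataLoader: sequential indices, no workers, last partial batch kept."""
    for start in range(0, len(dataset), batch_size):
        yield collate_fn([dataset[i] for i in range(start, min(start + batch_size, len(dataset)))])


# Invented samples: a 2 x 2 "magnitude" array, a mesh path and an anthropometry dict.
dataset = [
    {
        "inputs": {"magnitude": [[0.1 * i, 0.2], [0.3, 0.4]], "mesh": f"subject_{i}.ply"},
        "target": {"anthropometry": {"x1": 1.0 + i, "x2": 2.0}},
    }
    for i in range(3)
]
batch = next(toy_loader(dataset, batch_size=2, collate_fn=sketch_collate_samples))
print(batch["inputs"]["magnitude"].shape, batch["inputs"]["mesh"], batch["target"]["anthropometry"].shape)
```

With batch_size=2, the first batch stacks the two magnitude arrays into shape (2, 2, 2) with dtype torch.float32. It keeps the mesh paths as ['subject_0.ply', 'subject_1.ply'] and turns the anthropometry dictionaries into a (2, 2) batch x features value. The third sample arrives alone in a second, partial batch, which matches DataLoader's default drop_last=False.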
This example builds a PyTorch training batch from HRTF magnitudes and spherical harmonic targets.
HRTFSpec returns one subject-level magnitude map, SHSpec returns the coefficient target, and collate_samples stacks both values as tensors. Floating tensors are already returned as torch.float32, so the training loop does not need .float() casts.
>>> import torch
>>> from math import prod
>>> from torch import nn
>>> from torch.utils.data import DataLoader
>>> from hrtfpykit.datasets import HUTUBS, HRTFSpec, SHSpec, collate_samples
>>> train_dataset = HUTUBS(
...     root="datasets/hutubs",
...     inputs=HRTFSpec(
...         domain="frequency",
...         signal="tf_magnitude_db",
...         ears="left",
...         index_by=("subject",),
...         name="magnitude",
...     ),
...     target=SHSpec(
...         sh_order=9,
...         ears="left",
...         index_by=("subject",),
...         name="sh",
...     ),
...     split="train",
... )
>>> train_loader = DataLoader(
...     train_dataset,
...     batch_size=8,
...     collate_fn=collate_samples,
... )
>>> batch = next(iter(train_loader))
>>> print(batch["inputs"]["magnitude"].shape)
torch.Size([8, 440, 129])
>>> print(batch["inputs"]["magnitude"].dtype)
torch.float32
>>> print(batch["target"]["sh"].shape)
torch.Size([8, 100, 129])
>>> print(batch["target"]["sh"].dtype)
torch.float32
>>> class MagnitudeToSHModel(nn.Module):
...     def __init__(self, target_shape):
...         super().__init__()
...         self.target_shape = tuple(target_shape)
...         # Treat each magnitude map as a one-channel image and pool it to 32 features.
...         self.encoder = nn.Sequential(
...             nn.Conv2d(1, 32, kernel_size=3, padding=1),
...             nn.ReLU(),
...             nn.AdaptiveAvgPool2d((1, 1)),
...             nn.Flatten(),
...         )
...         self.head = nn.Linear(32, prod(self.target_shape))
...     def forward(self, magnitude):
...         features = self.encoder(magnitude.unsqueeze(1))
...         # Reshape the flat head output back to (batch, coefficients, frequencies).
...         return self.head(features).reshape(magnitude.shape[0], *self.target_shape)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> target_shape = batch["target"]["sh"].shape[1:]
>>> model = MagnitudeToSHModel(target_shape).to(device)
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
>>> loss_fn = nn.MSELoss()
>>> # Loss values vary between runs, so the loop is skipped when run as a doctest.
>>> for epoch in range(10):  # doctest: +SKIP
...     total_loss = 0.0
...     num_batches = 0
...     for batch in train_loader:
...         magnitude = batch["inputs"]["magnitude"].to(device)
...         target = batch["target"]["sh"].to(device)
...         prediction = model(magnitude)
...         loss = loss_fn(prediction, target)
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
...         total_loss += loss.item()
...         num_batches += 1
...     print(f"epoch {epoch + 1:02d} loss={total_loss / num_batches:.6f}")
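The shapes printed in the example can be checked by hand. This assumes SHSpec lays out all spherical harmonic coefficients up to sh_order along the second axis, which is consistent with the printed torch.Size([8, 100, 129]). Under that assumption, order 9 gives (9 + 1)^2 = 100 coefficients, and the linear head of MagnitudeToSHModel must produce prod(target_shape) values per sample. The helper names below are hypothetical. The parameter formulas are those of nn.Linear and nn.Conv2d with their default bias=True. The 129 frequency bins are read from the printed shape, not derived.

```python
from math import prod


def sh_coefficient_count(sh_order: int) -> int:
    """Number of spherical harmonic coefficients for orders 0..sh_order."""
    # Each order n contributes 2n + 1 coefficients (m = -n..n); the total is (N + 1)**2.
    return sum(2 * n + 1 for n in range(sh_order + 1))


def linear_parameter_count(in_features: int, out_features: int) -> int:
    """Weights plus biases of nn.Linear with its default bias=True."""
    return in_features * out_features + out_features


def conv2d_parameter_count(in_channels: int, out_channels: int, kernel_size: int) -> int:
    """Weights plus biases of nn.Conv2d with a square kernel and bias=True."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


# 129 is read from the printed torch.Size([8, 100, 129]); it is not derived here.
target_shape = (sh_coefficient_count(9), 129)
head_outputs = prod(target_shape)
encoder_params = conv2d_parameter_count(1, 32, 3)
head_params = linear_parameter_count(32, head_outputs)
print(target_shape, head_outputs, encoder_params, head_params, encoder_params + head_params)
```

This gives a (100, 129) target, 12,900 head outputs, and 320 + 425,700 = 426,020 trainable parameters, almost all of them in the linear head. ReLU, AdaptiveAvgPool2d, and Flatten have no parameters. In PyTorch the same total can be read with sum(p.numel() for p in model.parameters()).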