hrtf_loss

hrtfpykit.datasets.torch.hrtf_loss(prediction, target, metric='rmse', input_scale='db', reduction_method='mean', epsilon=1e-12)

Compute a scalar PyTorch loss for HRTF and HRIR model outputs.

hrtf_loss accepts PyTorch tensors or tensor-convertible values, compares prediction and target with RMSE, MAE, or LSD, and returns a single scalar tensor that can be backpropagated with loss.backward(). The last tensor axis is the acoustic axis being scored: HRIR samples for time-domain targets, and frequency bins for HRTF magnitude targets and for LSD. The metric is first reduced over that axis, and the resulting values over the remaining axes (commonly batch, positions, and ears) are then reduced to a scalar with reduction_method.
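The two-stage reduction can be illustrated in plain PyTorch. This is a minimal sketch rather than the library's implementation: the helper name reduce_metric_values is hypothetical, and the "rms" branch assumes a plain root mean square without any epsilon term.

```python
import torch


def reduce_metric_values(per_item: torch.Tensor, reduction_method: str = "mean") -> torch.Tensor:
    """Collapse per-item metric values to a scalar (hypothetical helper)."""
    if reduction_method == "mean":
        return per_item.mean()
    if reduction_method == "rms":
        # Assumption: plain root mean square, no epsilon stabilization.
        return per_item.pow(2).mean().sqrt()
    raise ValueError(f"unsupported reduction_method: {reduction_method!r}")


# Tensors shaped (batch, positions, ears, frequency).
prediction = torch.zeros(2, 3, 2, 8)
target = torch.ones(2, 3, 2, 8)
target[0] *= 3.0  # first batch item is off by 3, second by 1

# Stage 1: RMSE over the final (acoustic) axis, leaving shape (2, 3, 2).
per_item = (prediction - target).pow(2).mean(dim=-1).sqrt()

# Stage 2: reduce the remaining axes to a scalar.
loss_mean = reduce_metric_values(per_item, "mean")  # 2.0
loss_rms = reduce_metric_values(per_item, "rms")    # sqrt(5), about 2.236
```

Under these assumptions, "rms" weights the items with larger error more heavily than "mean" does, which is why the two scalars differ here.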

The function supports three metrics:

  • metric="rmse" computes root mean squared error over the final tensor axis.

  • metric="mae" computes mean absolute error over the final tensor axis.

  • metric="lsd" computes log-spectral distance over the final tensor axis.

rmse and mae measure direct tensor error in the representation passed to the function. For HRIR targets, the final axis contains time samples. For HRTF magnitude or dB-magnitude targets, the final axis contains frequency bins and the loss measures magnitude error in that representation.
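As a small worked example of how the two metrics differ on the final axis, the sketch below computes both by hand in plain PyTorch. It does not call hrtf_loss, and the tensor values are made up for illustration.

```python
import torch

# One item with four values on the final axis; only the last value is wrong.
prediction = torch.tensor([[0.0, 0.0, 0.0, 0.0]])
target = torch.tensor([[0.0, 0.0, 0.0, 4.0]])

error = prediction - target

# Root mean squared error over the final axis: sqrt(16 / 4) = 2.
rmse = error.pow(2).mean(dim=-1).sqrt()

# Mean absolute error over the final axis: 4 / 4 = 1.
mae = error.abs().mean(dim=-1)
```

A single large deviation raises RMSE more than MAE, so RMSE penalizes isolated errors such as a mismatched spectral notch more strongly.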

lsd measures spectral magnitude error in decibels. With input_scale="db", prediction and target are dB magnitudes. With input_scale="linear", values are linear magnitudes and are converted to dB with 20 * log10(clamp(value, min=epsilon)). Complex linear tensors are converted to magnitudes with abs before the dB conversion.
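The conversion described above can be sketched in plain PyTorch. The helper names to_db and lsd_reference are hypothetical, and the last step assumes the common definition of LSD as the root mean square of the dB difference across frequency bins; hrtf_loss may apply epsilon or the square root differently.

```python
import torch


def to_db(value: torch.Tensor, epsilon: float = 1e-12) -> torch.Tensor:
    """Convert linear magnitudes (or complex spectra) to dB (hypothetical helper)."""
    if torch.is_complex(value):
        value = value.abs()  # complex spectrum -> magnitude
    return 20.0 * torch.log10(torch.clamp(value, min=epsilon))


def lsd_reference(prediction, target, input_scale="db", epsilon=1e-12):
    """Per-item LSD over the final axis (assumed definition, not the library code)."""
    if input_scale == "linear":
        prediction = to_db(prediction, epsilon)
        target = to_db(target, epsilon)
    elif input_scale != "db":
        raise ValueError(f"unsupported input_scale: {input_scale!r}")
    # Assumption: RMS of the dB difference across frequency bins.
    return (prediction - target).pow(2).mean(dim=-1).sqrt()


# A prediction with twice the target magnitude is about 6.02 dB off in every bin.
target = torch.tensor([[1.0, 0.5, 0.25]])
prediction = 2.0 * target
lsd = lsd_reference(prediction, target, input_scale="linear")  # about 6.02

# A zero magnitude is clamped to epsilon, giving a floor of -240 dB.
floor_db = to_db(torch.zeros(1))
```

The clamp keeps log10 finite for silent bins, at the cost of treating every value below epsilon as the same floor level.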

Parameters:
  • prediction (torch.Tensor or tensor-convertible) – Model output tensor. For full HRTF magnitude training this is commonly shaped (batch, positions, ears, frequency). For datasets indexed by subject, position, and ear, this may be shaped (batch, frequency). HRIR targets commonly use samples on the final axis.

  • target (torch.Tensor or tensor-convertible) – Target tensor with the same shape as prediction.

  • metric ({"rmse", "mae", "lsd"}, default="rmse") – Loss metric.

  • input_scale ({"db", "linear"}, default="db") – Scale used by metric="lsd" to interpret prediction and target before computing the spectral distance.

  • reduction_method ({"mean", "rms"}, default="mean") – Reduction applied after the final tensor axis has been reduced. "mean" averages metric values across remaining axes. "rms" computes a root mean square over the remaining metric values.

  • epsilon (float, default=1e-12) – Positive numerical floor used by LSD dB conversion and square-root stabilization.

Returns:

Scalar loss tensor.

Return type:

torch.Tensor

Raises:
  • ImportError – If PyTorch is unavailable in the current environment.

  • ValueError – If metric, input_scale, or reduction_method is not a supported value, if prediction and target have different shapes, if the inputs are scalar, or if epsilon is not finite and positive.

Examples

This example downloads the first ten measured HUTUBS HRTFs if they are not already present, builds a dataset that pairs left-ear HRIR samples with left-ear HRTF magnitudes in dB, and trains a small PyTorch model with LSD as the loss. The model receives tensors shaped (batch, positions, samples) and predicts tensors shaped (batch, positions, frequency).

>>> import torch
>>> from torch import nn
>>> from torch.utils.data import DataLoader
>>> from hrtfpykit.datasets import HUTUBS, HRTFSpec
>>> from hrtfpykit.datasets.torch import collate_samples, hrtf_loss
>>> selected_subject_ids = tuple(f"pp{i}" for i in range(1, 11))
>>> train_dataset = HUTUBS(
...     root="datasets/hutubs",
...     download=True,
...     download_resources="hrtf",
...     download_hrtf_variant="measured",
...     download_server="sofacoustics",
...     dataset_hrtf_variant="measured",
...     download_subject_ids=selected_subject_ids,
...     subject_ids=selected_subject_ids,
...     verify_checksum=True,
...     inputs=HRTFSpec(
...         domain="time",
...         signal="ir",
...         ears="left",
...         index_by=("subject",),
...         name="hrir",
...     ),
...     target=HRTFSpec(
...         domain="frequency",
...         signal="tf_magnitude_db",
...         ears="left",
...         index_by=("subject",),
...         name="magnitude_db",
...     ),
...     split="train",
... )
>>> train_loader = DataLoader(
...     train_dataset,
...     batch_size=8,
...     collate_fn=collate_samples,
... )
>>> batch = next(iter(train_loader))
>>> class HRIRToMagnitudeModel(nn.Module):
...     def __init__(self, num_samples, num_frequencies):
...         super().__init__()
...         self.network = nn.Sequential(
...             nn.Linear(num_samples, 256),
...             nn.ReLU(),
...             nn.Linear(256, num_frequencies),
...         )
...     def forward(self, hrir):
...         return self.network(hrir)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> num_samples = batch["inputs"]["hrir"].shape[-1]
>>> num_frequencies = batch["target"]["magnitude_db"].shape[-1]
>>> model = HRIRToMagnitudeModel(num_samples, num_frequencies).to(device)
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
>>> for epoch in range(10):
...     total_loss = 0.0
...     num_batches = 0
...     for batch in train_loader:
...         hrir = batch["inputs"]["hrir"].to(device)
...         target = batch["target"]["magnitude_db"].to(device)
...         prediction = model(hrir)
...         loss = hrtf_loss(
...             prediction,
...             target,
...             metric="lsd",
...             input_scale="db",
...         )
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
...         total_loss += loss.item()
...         num_batches += 1
...     print(f"epoch {epoch + 1:02d} lsd={total_loss / num_batches:.6f}")