hrtf_loss
- hrtfpykit.datasets.torch.hrtf_loss(prediction, target, metric='rmse', input_scale='db', reduction_method='mean', epsilon=1e-12)
Compute a scalar PyTorch loss for HRTF and HRIR model outputs.
hrtf_loss accepts PyTorch tensors or tensor-convertible values, compares prediction and target with RMSE, MAE, or LSD, and returns one scalar tensor for backpropagation with loss.backward(). The last tensor axis is the acoustic axis being scored: HRIR samples for time-domain targets, and frequency bins for HRTF magnitude targets and for LSD. After that axis is reduced, the remaining axes (commonly batch, positions, and ears) are reduced with reduction_method.
The function supports three metrics:
- metric="rmse" computes root mean squared error over the final tensor axis.
- metric="mae" computes mean absolute error over the final tensor axis.
- metric="lsd" computes log-spectral distance over the final tensor axis.
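The per-vector metrics can be sketched in plain Python to make the final-axis computation concrete. This is a framework-free illustration, not the hrtfpykit implementation: the helper names (db_from_linear, rmse, mae, lsd) are hypothetical, and the LSD formula assumes the common definition, the root mean square of dB differences across frequency bins, which this page does not spell out. The linear-to-dB step follows the documented 20 * log10(clamp(value, min=epsilon)) conversion, with abs applied only to complex values.

```python
import math


def db_from_linear(value, epsilon=1e-12):
    """Convert one linear magnitude to dB: 20 * log10(clamp(value, min=epsilon))."""
    if isinstance(value, complex):
        value = abs(value)  # complex values become magnitudes before the dB step
    return 20.0 * math.log10(max(value, epsilon))


def rmse(prediction, target):
    """Root mean squared error over one final-axis vector."""
    squared = [(p - t) ** 2 for p, t in zip(prediction, target)]
    return math.sqrt(sum(squared) / len(squared))


def mae(prediction, target):
    """Mean absolute error over one final-axis vector."""
    absolute = [abs(p - t) for p, t in zip(prediction, target)]
    return sum(absolute) / len(absolute)


def lsd(prediction, target, input_scale="db", epsilon=1e-12):
    """Log-spectral distance, ASSUMED here to be the RMS of dB differences."""
    if input_scale == "linear":
        prediction = [db_from_linear(v, epsilon) for v in prediction]
        target = [db_from_linear(v, epsilon) for v in target]
    return rmse(prediction, target)


# Usage: four bins, one of which is off by 2 units.
prediction = [1.0, 2.0, 3.0, 4.0]
target = [1.0, 2.0, 3.0, 6.0]
print(rmse(prediction, target))  # 1.0
print(mae(prediction, target))   # 0.5
# Linear magnitudes 10.0 and 1.0 are 20 dB and 0 dB, so this is about 14.14 dB.
print(lsd([10.0, 1.0], [1.0, 1.0], input_scale="linear"))
```

RMSE penalizes the single 2-unit error more than MAE does, which is why the two values differ for the same vectors.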
rmse and mae measure direct tensor error in the representation passed to the function. For HRIR targets, the final axis contains time samples. For HRTF magnitude or dB-magnitude targets, the final axis contains frequency bins, and the loss measures magnitude error in that representation. lsd measures spectral magnitude error in decibels. With input_scale="db", prediction and target are dB magnitudes. With input_scale="linear", values are linear magnitudes and are converted to dB with 20 * log10(clamp(value, min=epsilon)). Complex linear tensors are converted to magnitudes with abs before the dB conversion.
- Parameters:
prediction (torch.Tensor or tensor-convertible) – Model output tensor. For full HRTF magnitude training this is commonly shaped (batch, positions, ears, frequency). For datasets indexed by subject, position, and ear, it may be shaped (batch, frequency). HRIR targets commonly use samples on the final axis.
target (torch.Tensor or tensor-convertible) – Target tensor with the same shape as prediction.
metric ({"rmse", "mae", "lsd"}, default="rmse") – Loss metric.
input_scale ({"db", "linear"}, default="db") – Scale used by metric="lsd" to interpret prediction and target before computing the spectral distance.
reduction_method ({"mean", "rms"}, default="mean") – Reduction applied after the final tensor axis has been reduced. "mean" averages the metric values across the remaining axes. "rms" computes the root mean square of the remaining metric values.
epsilon (float, default=1e-12) – Positive numerical floor used by the LSD dB conversion and for square-root stabilization.
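The two-stage reduction (score each final-axis vector, then reduce the remaining metric values with reduction_method) can be sketched the same way. This is a plain-Python illustration on nested lists, not the library code; final_axis_vectors, mean_absolute_error, and reduce_metric_values are hypothetical names, and the epsilon square-root stabilization is left out because its exact form is not documented on this page.

```python
import math


def final_axis_vectors(values):
    """Yield each final-axis vector of a nested list.

    A (batch, positions, ears, frequency) nested list yields
    batch * positions * ears vectors of length frequency.
    """
    if values and isinstance(values[0], list):
        for inner in values:
            yield from final_axis_vectors(inner)
    else:
        yield values


def mean_absolute_error(prediction, target):
    """Per-vector MAE, used here as the final-axis metric."""
    return sum(abs(p - t) for p, t in zip(prediction, target)) / len(target)


def reduce_metric_values(prediction, target, metric=mean_absolute_error,
                         reduction_method="mean"):
    """Score every final-axis vector, then reduce the remaining values."""
    values = [
        metric(p, t)
        for p, t in zip(final_axis_vectors(prediction), final_axis_vectors(target))
    ]
    if reduction_method == "mean":
        return sum(values) / len(values)
    if reduction_method == "rms":
        return math.sqrt(sum(v * v for v in values) / len(values))
    raise ValueError(f"unsupported reduction_method: {reduction_method!r}")


# Shape (batch=1, positions=2, frequency=2): per-position MAE is 1.0 and 3.0.
prediction = [[[0.0, 0.0], [0.0, 0.0]]]
target = [[[1.0, 1.0], [3.0, 3.0]]]
print(reduce_metric_values(prediction, target, reduction_method="mean"))  # 2.0
print(reduce_metric_values(prediction, target, reduction_method="rms"))   # sqrt(5), about 2.236
```

Because "rms" squares the per-position values before averaging, it is never smaller than "mean" for the same values and weights the worst positions more heavily.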
- Returns:
Scalar loss tensor.
- Return type:
torch.Tensor
- Raises:
ImportError – If PyTorch is unavailable in the current environment.
ValueError – If options are unsupported, input shapes differ, inputs are scalar, or epsilon is not finite and positive.
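The documented error conditions can be expressed as a small pre-flight check. The sketch below uses hypothetical helpers (check_hrtf_loss_arguments and require_torch are illustrative names, not part of hrtfpykit) that work on shapes rather than tensors so they run without PyTorch; whether the library performs its ImportError check at call time or at import time is an assumption here.

```python
import importlib.util
import math

SUPPORTED_METRICS = ("rmse", "mae", "lsd")
SUPPORTED_INPUT_SCALES = ("db", "linear")
SUPPORTED_REDUCTIONS = ("mean", "rms")


def require_torch():
    """Raise ImportError when PyTorch cannot be found (assumed call-time check)."""
    if importlib.util.find_spec("torch") is None:
        raise ImportError("hrtf_loss requires PyTorch")


def check_hrtf_loss_arguments(prediction_shape, target_shape, metric="rmse",
                              input_scale="db", reduction_method="mean",
                              epsilon=1e-12):
    """Mirror the documented ValueError conditions for the given shapes and options."""
    if metric not in SUPPORTED_METRICS:
        raise ValueError(f"unsupported metric: {metric!r}")
    if input_scale not in SUPPORTED_INPUT_SCALES:
        raise ValueError(f"unsupported input_scale: {input_scale!r}")
    if reduction_method not in SUPPORTED_REDUCTIONS:
        raise ValueError(f"unsupported reduction_method: {reduction_method!r}")
    if tuple(prediction_shape) != tuple(target_shape):
        raise ValueError(
            f"shape mismatch: {tuple(prediction_shape)} vs {tuple(target_shape)}"
        )
    if len(prediction_shape) == 0:
        # A scalar has no final axis to score.
        raise ValueError("inputs must have at least one axis; scalars are not supported")
    if not (math.isfinite(epsilon) and epsilon > 0):
        raise ValueError(f"epsilon must be finite and positive, got {epsilon!r}")


# Usage: an illustrative (batch, positions, ears, frequency) pair with matching shapes passes.
check_hrtf_loss_arguments((8, 16, 2, 128), (8, 16, 2, 128), metric="lsd")
```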
Examples
This example downloads the measured HRTFs of the first ten HUTUBS subjects when needed, builds a dataset that pairs left-ear HRIR samples with left-ear HRTF magnitudes in dB, and trains a small PyTorch model with LSD as the loss. The model receives tensors shaped (batch, positions, samples) and predicts tensors shaped (batch, positions, frequency).
>>> import torch
>>> from torch import nn
>>> from torch.utils.data import DataLoader
>>> from hrtfpykit.datasets import HUTUBS, HRTFSpec
>>> from hrtfpykit.datasets.torch import collate_samples, hrtf_loss
>>> selected_subject_ids = tuple(f"pp{i}" for i in range(1, 11))
>>> train_dataset = HUTUBS(
...     root="datasets/hutubs",
...     download=True,
...     download_resources="hrtf",
...     download_hrtf_variant="measured",
...     download_server="sofacoustics",
...     dataset_hrtf_variant="measured",
...     download_subject_ids=selected_subject_ids,
...     subject_ids=selected_subject_ids,
...     verify_checksum=True,
...     inputs=HRTFSpec(
...         domain="time",
...         signal="ir",
...         ears="left",
...         index_by=("subject",),
...         name="hrir",
...     ),
...     target=HRTFSpec(
...         domain="frequency",
...         signal="tf_magnitude_db",
...         ears="left",
...         index_by=("subject",),
...         name="magnitude_db",
...     ),
...     split="train",
... )
>>> train_loader = DataLoader(
...     train_dataset,
...     batch_size=8,
...     collate_fn=collate_samples,
... )
>>> batch = next(iter(train_loader))
>>> class HRIRToMagnitudeModel(nn.Module):
...     def __init__(self, num_samples, num_frequencies):
...         super().__init__()
...         self.network = nn.Sequential(
...             nn.Linear(num_samples, 256),
...             nn.ReLU(),
...             nn.Linear(256, num_frequencies),
...         )
...     def forward(self, hrir):
...         return self.network(hrir)
>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> num_samples = batch["inputs"]["hrir"].shape[-1]
>>> num_frequencies = batch["target"]["magnitude_db"].shape[-1]
>>> model = HRIRToMagnitudeModel(num_samples, num_frequencies).to(device)
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
>>> for epoch in range(10):
...     total_loss = 0.0
...     num_batches = 0
...     for batch in train_loader:
...         hrir = batch["inputs"]["hrir"].to(device)
...         target = batch["target"]["magnitude_db"].to(device)
...         prediction = model(hrir)
...         loss = hrtf_loss(
...             prediction,
...             target,
...             metric="lsd",
...             input_scale="db",
...         )
...         optimizer.zero_grad()
...         loss.backward()
...         optimizer.step()
...         total_loss += float(loss.detach().cpu())
...         num_batches += 1
...     print(f"epoch {epoch + 1:02d} lsd={total_loss / num_batches:.6f}")