Loss Functions (stengression.Losses)
- stengression.Losses.EnergyMSELoss(y_true, y_pred_samples, lambda_mse=0.5)
Combines the Energy Score and the Mean Squared Error (MSE) loss.
This hybrid loss regularizes the probabilistic Energy Score with a deterministic MSE term computed on the sample mean. The MSE term is often used to stabilize training and to improve the accuracy of the point estimate.
- Parameters:
y_true (torch.Tensor) – Ground truth tensor of shape \((B, t_{pred}, N, D)\).
y_pred_samples (torch.Tensor) – Predicted samples of shape \((M, B, t_{pred}, N, D)\).
lambda_mse (float, optional) – Weighting factor for the MSE component. Defaults to 0.5.
- Returns:
- Combined loss value calculated as:
\(L = ES(y, \hat{y}) + \lambda_{mse} \cdot MSE(y, \mathbb{E}[\hat{y}])\).
- Return type:
torch.Tensor
Note
The MSE component is computed by first averaging the \(M\) samples to obtain a point forecast.
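The averaging step in the note above, and the weighted sum it feeds into, can be sketched in a few lines. This is a minimal illustration, not the library's implementation: the helper names `mse_on_sample_mean` and `combine_energy_mse` are hypothetical, the MSE is assumed to use mean reduction over every element of the \((B, t_{pred}, N, D)\) tensor, and the Energy Score value is assumed to be computed separately and passed in.

```python
import torch
import torch.nn.functional as F


def mse_on_sample_mean(y_true: torch.Tensor, y_pred_samples: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: MSE between y_true and the mean of the M samples.

    y_true: (B, t_pred, N, D); y_pred_samples: (M, B, t_pred, N, D).
    """
    # Average the M stochastic samples (dim 0) into a single point forecast.
    point_forecast = y_pred_samples.mean(dim=0)  # (B, t_pred, N, D)
    # Assumption: mean reduction over every element of the tensor.
    return F.mse_loss(point_forecast, y_true)


def combine_energy_mse(
    energy_score: torch.Tensor,
    y_true: torch.Tensor,
    y_pred_samples: torch.Tensor,
    lambda_mse: float = 0.5,
) -> torch.Tensor:
    """Hypothetical helper: L = ES + lambda_mse * MSE(y, E[y_hat])."""
    return energy_score + lambda_mse * mse_on_sample_mean(y_true, y_pred_samples)
```

Because the point forecast is the sample mean, the MSE term only sees the location of the predictive distribution, while the Energy Score also accounts for its spread.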
- stengression.Losses.energy_score_loss(y_true, y_pred_samples)
Calculates the Energy Score loss for multivariate probabilistic forecasting.
The Energy Score is a proper scoring rule for evaluating multivariate probabilistic forecasts. It is negatively oriented (lower is better), so it can be minimized directly as a training loss.
- Parameters:
y_true (torch.Tensor) – Ground truth tensor of shape \((B, t_{pred}, N, D)\).
y_pred_samples (torch.Tensor) – Predicted samples or trajectories of shape \((M, B, t_{pred}, N, D)\), where \(M\) is the number of stochastic samples.
- Returns:
A scalar representing the mean Energy Score loss across the batch.
- Return type:
torch.Tensor
Note
The loss is calculated as:
\[ES(F, y) = E_F[\|X - y\|] - \frac{1}{2}E_F[\|X - X'\|]\]
where \(X\) and \(X'\) are independent samples from the forecast distribution \(F\). In practice, both expectations are estimated from the \(M\) predicted samples.
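A Monte Carlo estimate of this formula from the \(M\) samples might look like the sketch below. It is an illustration under stated assumptions, not the library's code: the function name `energy_score_sketch` is hypothetical, each forecast is assumed to be flattened over \((t_{pred}, N, D)\) before taking the Euclidean norm, and \(E_F[\|X - X'\|]\) is estimated by averaging over all \(M^2\) sample pairs (diagonal pairs included), which is simple but slightly biased.

```python
import torch


def energy_score_sketch(y_true: torch.Tensor, y_pred_samples: torch.Tensor) -> torch.Tensor:
    """Hypothetical Monte Carlo Energy Score, averaged over the batch.

    y_true: (B, t_pred, N, D); y_pred_samples: (M, B, t_pred, N, D).
    Assumption: the norm is taken over the flattened (t_pred, N, D) dims.
    """
    m, b = y_pred_samples.shape[0], y_pred_samples.shape[1]
    # Flatten each sampled trajectory into one vector: (B, M, t_pred*N*D).
    x = y_pred_samples.reshape(m, b, -1).transpose(0, 1)
    # Flatten the observation to (B, 1, t_pred*N*D) so it broadcasts over M.
    y = y_true.reshape(b, 1, -1)

    # First term: E_F ||X - y||, averaged over the M samples.
    term1 = torch.linalg.vector_norm(x - y, dim=-1).mean(dim=1)  # (B,)

    # Second term: E_F ||X - X'||, averaged over all M*M sample pairs.
    # The exact (non-matmul) mode keeps diagonal distances at zero.
    pairwise = torch.cdist(x, x, p=2.0, compute_mode="donot_use_mm_for_euclid_dist")  # (B, M, M)
    term2 = pairwise.mean(dim=(1, 2))  # (B,)

    # Energy Score per batch element, then the mean over the batch.
    return (term1 - 0.5 * term2).mean()
```

Averaging over all \(M^2\) pairs keeps the estimate non-negative by the triangle inequality; an unbiased alternative divides the off-diagonal sum by \(M(M-1)\) instead, which requires \(M \ge 2\).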