Loss Functions (stengression.Losses)

stengression.Losses.EnergyMSELoss(y_true, y_pred_samples, lambda_mse=0.5)

Combines the Energy Score and the Mean Squared Error (MSE) into a single loss.

This hybrid loss regularizes the probabilistic Energy Score with a deterministic MSE term computed on the sample mean of the predictions. The extra term is often used to stabilize training and to improve the accuracy of the point estimate.

Parameters:
  • y_true (torch.Tensor) – Ground truth tensor of shape \((B, t_{pred}, N, D)\).

  • y_pred_samples (torch.Tensor) – Predicted samples of shape \((M, B, t_{pred}, N, D)\).

  • lambda_mse (float, optional) – Weighting factor for the MSE component. Defaults to 0.5.

Returns:

The combined loss value, computed as:

\(L = ES(\hat{y}, y) + \lambda_{mse} \cdot MSE(y, \mathbb{E}[\hat{y}])\).

Return type:

torch.Tensor

Note

The MSE component is computed by first averaging the \(M\) samples to obtain a point forecast.
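The combination described above can be illustrated with a minimal PyTorch sketch. It is not the stengression source: the name `energy_mse_sketch` is hypothetical, and it assumes that each forecast is flattened over \((t_{pred}, N, D)\) and compared with the Euclidean norm, that \(E_F[\|X - X'\|]\) is estimated over distinct sample pairs, and that the MSE is averaged over every element of the point forecast.

```python
import torch
import torch.nn.functional as F


def energy_mse_sketch(y_true: torch.Tensor, y_pred_samples: torch.Tensor,
                      lambda_mse: float = 0.5) -> torch.Tensor:
    """Hypothetical sketch of L = ES + lambda_mse * MSE(y, sample mean).

    Assumptions (not taken from stengression): forecasts are flattened over
    (t_pred, N, D), distances are Euclidean, E||X - X'|| is averaged over
    distinct sample pairs, and both terms are averaged over the batch.
    """
    M, B = y_pred_samples.shape[:2]
    if M < 2:
        raise ValueError("at least two samples are needed to estimate E||X - X'||")
    x = y_pred_samples.reshape(M, B, -1).transpose(0, 1)  # (B, M, F), F = t_pred*N*D
    y = y_true.reshape(B, 1, -1)                          # (B, 1, F)

    # Energy Score term: E||X - y|| - 0.5 * E||X - X'||, then mean over the batch
    term1 = torch.linalg.vector_norm(x - y, dim=-1).mean(dim=1)           # (B,)
    diffs = x.unsqueeze(2) - x.unsqueeze(1)                               # (B, M, M, F)
    # The diagonal (i == j) is exactly zero, so dividing by M*(M-1) averages distinct pairs
    term2 = torch.linalg.vector_norm(diffs, dim=-1).sum(dim=(1, 2)) / (M * (M - 1))
    es = (term1 - 0.5 * term2).mean()

    # MSE term: first average the M samples into a point forecast
    point_forecast = y_pred_samples.mean(dim=0)           # (B, t_pred, N, D)
    mse = F.mse_loss(point_forecast, y_true)

    return es + lambda_mse * mse
```

As a worked check, with \(y = 0\) and the two scalar samples 1 and 3, the sketch gives \(ES = 2 - \tfrac{1}{2} \cdot 2 = 1\), a sample mean of 2 and \(MSE = 4\), so \(L = 1 + 0.5 \cdot 4 = 3\).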

stengression.Losses.energy_score_loss(y_true, y_pred_samples)

Calculates the Energy Score loss for multivariate probabilistic forecasting.

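The Energy Score is a strictly proper scoring rule for evaluating probabilistic forecasts: its expected value is minimized when the forecast distribution matches the distribution of the observations, which makes it suitable as a training loss.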

Parameters:
  • y_true (torch.Tensor) – Ground truth tensor of shape \((B, t_{pred}, N, D)\).

  • y_pred_samples (torch.Tensor) – Predicted samples or trajectories of shape \((M, B, t_{pred}, N, D)\), where \(M\) is the number of stochastic samples.

Returns:

A scalar representing the mean Energy Score loss across the batch.

Return type:

torch.Tensor

Note

The loss is calculated as:

\[ES(F, y) = E_F[\|X - y\|] - \frac{1}{2}E_F[\|X - X'\|]\]

where \(X\) and \(X'\) are independent samples from the forecast distribution \(F\).
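The formula above can be estimated by Monte Carlo from the \(M\) samples. The following PyTorch sketch uses a hypothetical helper, `energy_score_sketch`, and is not the library's code; it assumes each forecast is flattened over \((t_{pred}, N, D)\) and compared with the Euclidean norm, and that \(E_F[\|X - X'\|]\) is averaged over distinct sample pairs. The `__main__` block compares a forecast centred on the observations with the same samples shifted by 3, which a proper scoring rule should penalize.

```python
import torch


def energy_score_sketch(y_true: torch.Tensor, y_pred_samples: torch.Tensor) -> torch.Tensor:
    """Monte Carlo estimate of ES(F, y) = E||X - y|| - 0.5 * E||X - X'||, batch-averaged.

    Hypothetical helper, not the stengression implementation. Assumes each
    forecast is flattened over (t_pred, N, D) and compared with the Euclidean
    norm, and that E||X - X'|| is estimated over distinct pairs (i != j).
    """
    M, B = y_pred_samples.shape[:2]
    if M < 2:
        raise ValueError("at least two samples are needed to estimate E||X - X'||")
    x = y_pred_samples.reshape(M, B, -1).transpose(0, 1)  # (B, M, F), F = t_pred*N*D
    y = y_true.reshape(B, 1, -1)                          # (B, 1, F)

    # E_F ||X - y||: mean distance from each sample to the observation -> (B,)
    term1 = torch.linalg.vector_norm(x - y, dim=-1).mean(dim=1)

    # E_F ||X - X'||: all M x M pairwise distances; the diagonal is exactly zero,
    # so dividing the sum by M * (M - 1) averages over the distinct pairs -> (B,)
    diffs = x.unsqueeze(2) - x.unsqueeze(1)               # (B, M, M, F)
    term2 = torch.linalg.vector_norm(diffs, dim=-1).sum(dim=(1, 2)) / (M * (M - 1))

    return (term1 - 0.5 * term2).mean()                   # scalar tensor


if __name__ == "__main__":
    torch.manual_seed(0)
    B, t_pred, N, D, M = 2, 3, 4, 1, 200
    y_true = torch.zeros(B, t_pred, N, D)
    centred = torch.randn(M, B, t_pred, N, D)             # samples centred on y_true
    shifted = centred + 3.0                               # systematically biased forecast
    # The centred forecast should receive the lower (better) score.
    print("centred:", energy_score_sketch(y_true, centred).item())
    print("shifted:", energy_score_sketch(y_true, shifted).item())
```

Excluding the \(i = j\) pairs gives an unbiased estimate of \(E_F[\|X - X'\|]\); including them (dividing by \(M^2\)) is another common choice that shrinks the second term slightly. Materializing all pairwise differences costs \(O(B M^2 t_{pred} N D)\) memory, so large \(M\) may call for a chunked computation.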