Spatio-Temporal Engression Network (STEN)

class stengression.Models.STEN(in_feat_dim, num_nodes, embedding_dim, max_spatial_lag, lstm_hidden_dim, lstm_num_layers, lstm_dropout, p_lag, t_pred, noise_encode='add', noise_dist='gaussian', noise_dim=2, noise_std=1, temporal_seed=9)

Spatio-Temporal Engression Network.

Combines the STAR-Layer and Engression-LSTM for probabilistic spatiotemporal forecasting.

Parameters:
  • in_feat_dim (int) – Number of input features for each node (D).

  • embedding_dim (int) – The dimension of the spatial embedding (D’).

  • max_spatial_lag (int) – The maximum spatial lag (L) for the STAR-Layer.

  • p_lag (int) – Number of past timesteps used as input (input sequence length).

  • t_pred (int) – Number of future timesteps predicted (forecast horizon).

  • lstm_hidden_dim (int) – Hidden dimension size of the LSTM temporal module.

  • lstm_num_layers (int) – Number of layers in the LSTM.

  • lstm_dropout (float) – Dropout probability in the LSTM.

  • noise_dist (str) – Distribution type for noise injection ('gaussian' or 'uniform'). Defaults to 'gaussian'.

  • noise_encode (str) – Method for noise injection: 'add' (additive) or 'concat' (concatenation). Defaults to 'add'.

  • noise_dim (int) – Dimension of the noise features when noise_encode is 'concat'. Defaults to 2.

  • noise_std (int) – Standard deviation/scaling of the noise. Defaults to 1.

  • num_nodes (int) – Number of nodes in the graph.

  • temporal_seed (int) – Seed for reproducibility in LSTM initialization. Defaults to 9.
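
The two noise_encode modes can be pictured with a small, library-independent sketch. It assumes that 'add' perturbs each existing feature with one noise draw, while 'concat' appends noise_dim extra noise features, and that 'uniform' means a symmetric range scaled by noise_std. The helper inject_noise and its arguments are hypothetical and not part of stengression, whose internals may differ.

```python
import random

def inject_noise(features, noise_encode="add", noise_dist="gaussian",
                 noise_dim=2, noise_std=1.0, rng=None):
    """Hypothetical helper: perturb one feature vector (a list of floats)."""
    rng = rng or random.Random()

    def draw():
        # Assumption: 'uniform' draws from [-noise_std, noise_std].
        if noise_dist == "gaussian":
            return rng.gauss(0.0, noise_std)
        if noise_dist == "uniform":
            return rng.uniform(-noise_std, noise_std)
        raise ValueError(f"unknown noise_dist: {noise_dist!r}")

    if noise_encode == "add":
        # Same dimensionality: one noise draw per existing feature.
        return [value + draw() for value in features]
    if noise_encode == "concat":
        # Dimensionality grows by noise_dim.
        return list(features) + [draw() for _ in range(noise_dim)]
    raise ValueError(f"unknown noise_encode: {noise_encode!r}")

embedding = [0.5, -1.0, 2.0]
added = inject_noise(embedding, "add", rng=random.Random(0))
concatenated = inject_noise(embedding, "concat", noise_dim=2, rng=random.Random(0))
print(len(added), len(concatenated))  # 3 5
```

Under this reading, noise_dim only matters for 'concat', which matches the parameter description above.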

Example

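>>> # num_nodes and max_spatial_lag are required; set them to match your graph and W_list
>>> model = STEN(
...     in_feat_dim=1, num_nodes=num_nodes, embedding_dim=8,
...     max_spatial_lag=max_spatial_lag, lstm_hidden_dim=32,
...     lstm_num_layers=2, lstm_dropout=0.1, p_lag=12, t_pred=4
... )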
>>> model.fit(
...     data_loader, optimizer, loss_fn=energy_score_loss,
...     num_epochs=100, m_samples=2, device=device,
...     visualize=True, W_list=W_list
... )
>>> # forward() expects a leading batch dimension: (B, T_in, N, D_in)
>>> output = model(history_tensor.unsqueeze(0), W_list=W_list)
>>> forecast_samples = model.predict(
...     history_tensor, m_samples=100, device=device, W_list=W_list
... )
>>> # In-sample analysis
>>> in_sample_preds = model.predict_in_sample(
...     train_data, m_samples=100, method="q_step", W_list=W_list
... )
>>> residuals = model.get_residuals(
...     in_sample_preds, train_data, point_method="median"
... )
>>> model.plot_residuals(residuals, plots_per_row=4)
>>> # Out-of-sample evaluation
>>> metrics_df = model.evaluate_forecasts(
...     history_tensor, y_test, y_train,
...     point_method="median", W_list=W_list
... )
>>> print(metrics_df)

evaluate_forecasts(history: Tensor, y_true: Tensor, y_train: Tensor, W_list, m_samples=100, n_repeats=50, point_method: str | float = 'median', unstandardize=None, device=None)

Repeatedly generates probabilistic forecasts from a single trained model and returns summary metrics.

This method performs a Monte Carlo-style evaluation: it generates n_repeats independent forecast ensembles to account for the stochastic nature of the engression model.

Parameters:
  • history (torch.Tensor) – Last p_lag observations of shape \((T_{in}, N, D_{in})\).

  • y_true (torch.Tensor) – Ground truth for the forecast horizon of shape \((T_{out}, N, D_{in})\).

  • y_train (torch.Tensor) – In-sample training data of shape \((T_{train}, N, D_{in})\). Required for scaling MASE and RMSSE.

  • W_list (list) – List of spatial weights matrices.

  • m_samples (int, optional) – Number of samples per forecast ensemble. Defaults to 100.

  • n_repeats (int, optional) – Number of times to repeat the ensemble generation to calculate metric stability. Defaults to 50.

  • point_method (str or float, optional) – Method to extract a point forecast from the ensemble. Options: "median", "mean", or a float quantile (e.g., 0.75). Defaults to "median".

  • unstandardize (list, optional) – A list [mean, std] to reverse data normalization. Defaults to None.

  • device (torch.device or str, optional) – Device for evaluation. Defaults to None.

Returns:

A DataFrame containing the mean and standard deviation across repeats for the following metrics:

  • Point: SMAPE, MAE, RMSE, MASE, RMSSE.

  • Probabilistic: Pinball (80%, 95%), Rho-risk (0.5, 0.9), CRPS.

  • Calibration: Empirical Coverage, Winkler Score.

Return type:

pd.DataFrame
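
To illustrate why y_train is needed, the sketch below computes MASE and RMSSE for a single series using their common textbook definitions, in which forecast errors are divided by the in-sample error of a one-step naive forecast. The function name mase_and_rmsse is hypothetical, and the non-seasonal scale (period 1) is an assumption; the library's implementation may aggregate over nodes and features differently.

```python
import math

def mase_and_rmsse(y_true, y_pred, y_train):
    """Scaled errors for one series, using a one-step naive in-sample forecast as the scale."""
    if len(y_train) < 2:
        raise ValueError("y_train needs at least two observations")
    # In-sample errors of the naive forecast y_hat[t] = y[t - 1].
    naive_errors = [b - a for a, b in zip(y_train, y_train[1:])]
    mae_scale = sum(abs(e) for e in naive_errors) / len(naive_errors)
    mse_scale = sum(e * e for e in naive_errors) / len(naive_errors)

    # Out-of-sample errors of the point forecast.
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / len(errors)
    mse = sum(e * e for e in errors) / len(errors)
    return mae / mae_scale, math.sqrt(mse / mse_scale)

mase, rmsse = mase_and_rmsse(
    y_true=[8.0, 10.0], y_pred=[7.0, 12.0], y_train=[1.0, 2.0, 4.0, 7.0]
)
print(round(mase, 3), round(rmsse, 3))  # 0.75 0.732
```

Values below 1 indicate that the forecast beats the naive in-sample benchmark on that scale.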

evaluate_in_sample_fit(data: Tensor, W_list, m_samples: int = 100, n_repeats: int = 10, method: str = '1_step', batch_size: int = 32, point_method: str | float = 'median', unstandardize: list | None = None, device: str | None = None)

Repeatedly generates in-sample probabilistic forecasts and returns summary metrics.

This method assesses the model’s ability to reconstruct the historical training sequence by performing multiple stochastic passes over the data. It accounts for the “Engression” noise injection by repeating the process n_repeats times and reporting the stability of the metrics.

Parameters:
  • data (torch.Tensor) – The full historical dataset of shape \((T, N, D_{in})\), where \(T\) is total time steps.

  • W_list (list) – List of spatial weights matrices.

  • m_samples (int, optional) – Number of stochastic samples per forecast ensemble window. Defaults to 100.

  • n_repeats (int, optional) – Number of full evaluation cycles to perform to calculate mean/std of metrics. Defaults to 10.

  • method (str, optional) – The sliding window strategy: "1_step" or "q_step". Defaults to "1_step".

  • batch_size (int, optional) – Number of windows processed in a single forward pass. Defaults to 32.

  • point_method (str or float, optional) – Metric for extracting a single forecast from the ensemble to compute point errors (SMAPE, MAE, RMSE, MASE, RMSSE). Accepts "median", "mean", or a float quantile (e.g., 0.75). Defaults to "median".

  • unstandardize (list, optional) – A list [mean, std] used to rescale predictions and ground truth to original units. Defaults to None.

  • device (str or torch.device, optional) – Computation device. Defaults to None.

Returns:

A summary table containing the mean and standard deviation for the following metrics across all repeats:

  • Point Metrics: SMAPE, MAE, RMSE, MASE, RMSSE.

  • Probabilistic Metrics: Pinball (80%, 95%), Rho-risk (0.5, 0.9), CRPS.

  • Calibration Metrics: Empirical Coverage, Winkler Score.

Return type:

pd.DataFrame

Note

As in predict_in_sample(), metrics are calculated only for the time steps after the initial p_lag warm-up period.
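
The repeat-and-summarize idea can be sketched without the library: run one stochastic evaluation cycle n_repeats times and report the mean and standard deviation of the resulting metric. Both summarize_repeats and noisy_mae are hypothetical stand-ins, and the use of the sample standard deviation is an assumption.

```python
import random
import statistics

def summarize_repeats(metric_fn, n_repeats=10):
    """Run a stochastic metric n_repeats times; report mean and sample std."""
    values = [metric_fn() for _ in range(n_repeats)]
    return {"mean": statistics.fmean(values), "std": statistics.stdev(values)}

rng = random.Random(0)
truth = 1.0

def noisy_mae(m_samples=100):
    # Hypothetical stand-in for one evaluation cycle: draw an ensemble,
    # take its median as the point forecast, and score it against the truth.
    ensemble = [truth + rng.gauss(0.0, 0.5) for _ in range(m_samples)]
    return abs(statistics.median(ensemble) - truth)

summary = summarize_repeats(noisy_mae, n_repeats=10)
print(summary)
```

A small std relative to the mean suggests that m_samples is large enough for the metric to be stable across repeats.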

fit(data_loader, optimizer, loss_fn, W_list, num_epochs=100, m_samples=2, device='cpu', monitor=True, visualize=True, verbose=False)

Trains the STEN model using a dataloader, optimizer, and probabilistic loss function.

Parameters:
  • data_loader (torch.utils.data.DataLoader) – PyTorch DataLoader yielding (x_batch, y_batch).

  • optimizer (torch.optim.Optimizer) – PyTorch optimizer instance.

  • loss_fn (callable) – Loss function accepting (target, pred_samples).

  • W_list (list) – List of spatial weights matrices.

  • num_epochs (int, optional) – Number of training epochs. Defaults to 100.

  • m_samples (int, optional) – Number of stochastic samples per batch to estimate the probabilistic loss. Defaults to 2.

  • device (str, optional) – Device to run training on ('cpu' or 'cuda'). Defaults to 'cpu'.

  • monitor (bool, optional) – If True, shows a progress bar. Defaults to True.

  • visualize (bool, optional) – If True, plots the loss curve after training. Defaults to True.

  • verbose (bool, optional) – If True, prints periodic loss updates. Defaults to False.
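
As a rough picture of what a probabilistic loss with the (target, pred_samples) signature might compute, here is a plain-Python sketch of the standard sample-based energy score for one target vector. The name energy_score_sketch is hypothetical; the package's own energy_score_loss may use a different normalization or operate on batched tensors.

```python
import math

def energy_score_sketch(target, pred_samples):
    """Energy score for one target vector and a list of sampled forecast vectors."""
    m = len(pred_samples)
    # Accuracy term: average distance between each sample and the target.
    accuracy = sum(math.dist(target, s) for s in pred_samples) / m
    # Spread term: average pairwise distance between samples (i == j contributes 0).
    spread = sum(math.dist(a, b) for a in pred_samples for b in pred_samples) / (m * m)
    return accuracy - 0.5 * spread

target = [0.0, 0.0]
pred_samples = [[3.0, 4.0], [0.0, 0.0]]  # m_samples = 2
print(energy_score_sketch(target, pred_samples))  # 1.25
```

With a single sample the spread term vanishes, so at least two samples per batch are needed for the loss to reward a well-dispersed ensemble.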

forward(x, W_list)

Forward pass of STEN for generating a single stochastic forecast.

Parameters:
  • x (torch.Tensor) – Input tensor of shape \((B, T_{in}, N, D_{in})\), where \(B\) is batch size, \(T_{in}\) is p_lag, \(N\) is num_nodes, and \(D_{in}\) is in_feat_dim.

  • W_list (list) – List of spatial weights matrices.

Returns:

Forecasted tensor of shape \((B, T_{out}, N, D_{in})\), where \(T_{out}\) is t_pred.

Return type:

torch.Tensor

get_residuals(in_sample_preds: Tensor, original_data: Tensor, point_method: str | float = 'median') → Tensor

Computes the residual matrix from in-sample predictions.

This method reduces the stochastic ensemble into a single point forecast using the specified point_method and calculates the error: \(\text{Residual} = \text{Actual} - \text{Predicted}\).

Parameters:
  • in_sample_preds (torch.Tensor) – Predictions from predict_in_sample() with shape \((M, T, N, D)\), where \(M\) is the number of samples.

  • original_data (torch.Tensor) – The ground truth dataset of shape \((T, N, D)\).

  • point_method (str or float, optional) – Method to extract a point forecast from the ensemble. Options: "median", "mean", or a float representing a quantile (e.g., 0.75). Defaults to "median".

Returns:

Residual tensor of shape \((T, N, D)\).

Return type:

torch.Tensor

Note

Following the structure of the in-sample predictions, the first p_lag time steps will contain NaN values. These represent the “warm-up” period where no forecasts were generated.
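
The reduction and subtraction described above can be sketched for a single node and feature. The helpers point_forecast and residuals_for_series are hypothetical, and the linear interpolation used for float quantiles is an assumption about how a quantile level would be applied.

```python
import math
import statistics

def point_forecast(samples, point_method="median"):
    """Reduce the M ensemble values for one time step to a single number."""
    if any(math.isnan(s) for s in samples):
        return math.nan  # warm-up step: no forecast was generated
    if point_method == "median":
        return statistics.median(samples)
    if point_method == "mean":
        return statistics.fmean(samples)
    # Otherwise treat point_method as a quantile level in [0, 1],
    # using linear interpolation between order statistics.
    ordered = sorted(samples)
    pos = float(point_method) * (len(ordered) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])

def residuals_for_series(in_sample_preds, original, point_method="median"):
    """in_sample_preds: M lists of length T; original: list of length T."""
    per_step = zip(*in_sample_preds)  # regroup as T tuples of M samples
    return [actual - point_forecast(list(s), point_method)
            for actual, s in zip(original, per_step)]

nan = math.nan
preds = [[nan, 1.0, 2.0], [nan, 2.0, 4.0], [nan, 3.0, 9.0]]  # M=3, T=3, p_lag=1
actual = [0.5, 2.5, 4.0]
print(residuals_for_series(preds, actual))          # [nan, 0.5, 0.0]
print(residuals_for_series(preds, actual, "mean"))  # [nan, 0.5, -1.0]
```

The NaN in the first position mirrors the warm-up behaviour described in the note.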

plot_in_sample_fit(in_sample_preds: Tensor, original_data: Tensor, plots_per_row: int = 4, confidence_level: float = 0.95, node_names: list | None = None, title: str = 'STEN: In-Sample Forecast Fit vs. Actual', savefig: bool = False, filename: str | None = None)

Plots in-sample forecasted time series against actual ground truth values.

This method generates a grid of subplots (one per node) showing the median forecast, the ground truth, and a shaded prediction interval based on the stochastic ensemble.

Parameters:
  • in_sample_preds (torch.Tensor) – Prediction ensemble from predict_in_sample() of shape \((M, T, N, D)\).

  • original_data (torch.Tensor) – Ground truth dataset of shape \((T, N, D)\).

  • plots_per_row (int, optional) – Number of subplots to display per row in the figure grid. Defaults to 4.

  • confidence_level (float, optional) – The width of the shaded prediction interval (e.g., 0.95 for a 95% interval). Defaults to 0.95.

  • node_names (list of str, optional) – Custom labels for each node. If None, nodes are labeled by index. Defaults to None.

  • title (str, optional) – The main title for the entire figure. Defaults to 'STEN: In-Sample Forecast Fit vs. Actual'.

  • savefig (bool, optional) – If True, the resulting figure is exported to a file. Defaults to False.

  • filename (str, optional) – The file path/name for saving the figure (e.g., 'fit_plot.png'). Required if savefig is True. Defaults to None.

Note

The shaded area represents the uncertainty captured by the Engression noise injection. The bounds are calculated as the \((1 - confidence\_level)/2\) and \((1 + confidence\_level)/2\) quantiles of the ensemble.

Returns:

This method displays the plot using plt.show() or saves it to disk.

Return type:

None
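
The interval bounds described in the note can be reproduced in a few lines. This sketch maps confidence_level to the two quantile levels and applies them to a toy ensemble at one time step; the helper names and the linear-interpolation quantile are assumptions, not the library's code.

```python
import math

def interval_quantile_levels(confidence_level=0.95):
    """Lower and upper quantile levels of a central prediction interval."""
    if not 0.0 < confidence_level < 1.0:
        raise ValueError("confidence_level must lie strictly between 0 and 1")
    return (1.0 - confidence_level) / 2.0, (1.0 + confidence_level) / 2.0

def quantile(samples, level):
    """Linearly interpolated quantile of a list of numbers."""
    ordered = sorted(samples)
    pos = level * (len(ordered) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])

lower_level, upper_level = interval_quantile_levels(0.80)
ensemble_at_t = [float(v) for v in range(11)]  # 0.0, 1.0, ..., 10.0
band = (quantile(ensemble_at_t, lower_level), quantile(ensemble_at_t, upper_level))
print(band)
```

For an 80% interval the levels are approximately 0.1 and 0.9, so the band covers the central 80% of the ensemble values.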

plot_residuals(residuals: Tensor, plots_per_row: int = 4, node_names: list | None = None, title: str = 'STEN: In-Sample Residuals per Node', savefig: bool = False, filename: str | None = None)

Plots the time series of residuals (errors) for each node in a grid.

This visualization helps in identifying systematic biases or patterns in the model’s errors across different spatial nodes. A horizontal line at zero is included for reference.

Parameters:
  • residuals (torch.Tensor) – Residual matrix obtained from get_residuals(), with shape \((T, N, D)\).

  • plots_per_row (int, optional) – Number of subplots to display in each row of the grid. Defaults to 4.

  • node_names (list of str, optional) – Labels for each node subplot. If None, node indices are used. Defaults to None.

  • title (str, optional) – The main title for the figure. Defaults to ‘STEN: In-Sample Residuals per Node’.

  • savefig (bool, optional) – If True, the plot will be saved to the specified filename. Defaults to False.

  • filename (str, optional) – Path where the figure should be saved. Required if savefig is True. Defaults to None.

Note

Any NaN values present in the residuals (typically the first p_lag steps) are automatically handled by the plotting backend and will appear as gaps in the time series.

Returns:

Displays the plot using plt.show() or saves the file.

Return type:

None

predict(history: Tensor, W_list, m_samples: int = 100, unstandardize=None, device=None) → Tensor

Generates an ensemble of forecasts from a single history window of p_lag past observations.

Parameters:
  • history (torch.Tensor) – Past observations of shape \((T_{in}, N, D_{in})\).

  • W_list (list) – List of spatial weights matrices.

  • m_samples (int, optional) – Number of ensemble members to generate. Defaults to 100.

  • unstandardize (list, optional) – A list [mean, std] to reverse data normalization. Defaults to None.

  • device (str, optional) – Computation device. Defaults to None.

Returns:

Ensemble of forecasts of shape \((M, T_{out}, N, D_{in})\), where \(M\) is m_samples.

Return type:

torch.Tensor
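
Conceptually, an ensemble of this kind can be built by calling a stochastic forward pass m_samples times and stacking the results, then reversing the normalization if [mean, std] is supplied. The sketch below uses plain lists and a toy forecaster; predict_sketch and toy_forward are hypothetical, and the rescaling x = z * std + mean is the usual reading of "reverse data normalization", assumed here.

```python
import random

def predict_sketch(stochastic_forward, history, m_samples=100, unstandardize=None):
    """Stack m_samples independent stochastic forecasts; optionally rescale them."""
    ensemble = [stochastic_forward(history) for _ in range(m_samples)]
    if unstandardize is not None:
        mean, std = unstandardize
        # Reverse z-score normalization: x = z * std + mean.
        ensemble = [[z * std + mean for z in forecast] for forecast in ensemble]
    return ensemble  # M forecasts, each of length t_pred

rng = random.Random(0)

def toy_forward(history, t_pred=4):
    # Hypothetical stand-in for one stochastic forward pass:
    # persistence forecast plus fresh noise on every call.
    return [history[-1] + rng.gauss(0.0, 0.1) for _ in range(t_pred)]

samples = predict_sketch(toy_forward, history=[0.2, 0.1, 0.3],
                         m_samples=5, unstandardize=[10.0, 2.0])
print(len(samples), len(samples[0]))  # 5 4
```

The leading dimension of the result plays the role of \(M\) in the return shape above.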

predict_in_sample(data: Tensor, W_list, m_samples: int = 100, method: str = '1_step', batch_size: int = 64, unstandardize: list | None = None, device: str = 'cpu') → Tensor

Generates in-sample predictions for the entire training dataset.

This method applies a sliding window across the provided historical data to produce stochastic forecasts for every possible time step.

Parameters:
  • data (torch.Tensor) – The original dataset of shape \((T, N, D)\), where \(T\) is the total time steps.

  • W_list (list) – List of spatial weights matrices.

  • m_samples (int, optional) – Number of stochastic forecast samples to generate per window. Defaults to 100.

  • method (str, optional) –

    Strategy for in-sample forecasting. Options include:

    • "1_step": Slides the input window by 1 step, recording only the 1-step ahead forecast for each position.

    • "q_step": Slides by t_pred (q) steps, recording full non-overlapping forecast horizons.

    Defaults to "1_step".

  • batch_size (int, optional) – Number of windows to process in a single forward pass; smaller values reduce peak memory use. Defaults to 64.

  • unstandardize (list, optional) – A list [mean, std] to reverse data normalization. Defaults to None.

  • device (str, optional) – Device to perform computations on. Defaults to "cpu".

Returns:

In-sample forecast ensemble of shape \((M, T, N, D)\), where \(M\) is m_samples.

Return type:

torch.Tensor

Note

The first p_lag time steps in the output tensor will contain NaN values because there is insufficient historical context to generate a forecast for the beginning of the sequence.
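
The difference between the two sliding-window strategies is easiest to see on indices. The sketch below lists, for each window, the half-open input range and the target range it fills. The function in_sample_windows is hypothetical, and truncating a final partial horizon at the end of the data is an assumption about edge handling.

```python
def in_sample_windows(total_steps, p_lag, t_pred, method="1_step"):
    """Return (input_range, target_range) index pairs for each window."""
    if method == "1_step":
        stride, horizon = 1, 1
    elif method == "q_step":
        stride, horizon = t_pred, t_pred
    else:
        raise ValueError(f"unknown method: {method!r}")
    windows = []
    for start in range(p_lag, total_steps, stride):
        # Assumption: a final partial horizon is truncated at the end of the data.
        stop = min(start + horizon, total_steps)
        windows.append(((start - p_lag, start), (start, stop)))
    return windows

one_step = in_sample_windows(total_steps=10, p_lag=3, t_pred=4, method="1_step")
q_step = in_sample_windows(total_steps=10, p_lag=3, t_pred=4, method="q_step")
print(len(one_step), one_step[0])  # 7 ((0, 3), (3, 4))
print(q_step)                      # [((0, 3), (3, 7)), ((4, 7), (7, 10))]
```

In both cases no target range starts before index p_lag, which is why the first p_lag positions of the output stay NaN.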

class stengression.Models.STARLayer(in_features, out_features, max_spatial_lag, spatial_seed=21)

A differentiable neural layer inspired by the STARMA model’s spatial component.

This layer computes a spatial embedding by aggregating features from neighbors at multiple spatial lags (distances). It transforms node features through a combination of learnable weights and fixed spatial weight matrices.

Parameters:
  • in_features (int) – Number of input features for each node (D).

  • out_features (int) – Number of output features for each node (D’).

  • max_spatial_lag (int) – The maximum spatial lag (L) to consider.

  • spatial_seed (int) – Seed for reproducibility in STAR-Layer initialization. Defaults to 21.

Note

The forward pass expects a list of spatial weight matrices (W_list), where each matrix represents a specific spatial lag.

forward(x, W_list)

Forward pass for the STAR-Layer.

Parameters:
  • x (torch.Tensor) – Input tensor of shape (batch, seq_len, num_nodes, in_features).

  • W_list (list of torch.Tensor) – A list of spatial weight matrices.

Returns:

The spatial embedding tensor of shape (batch, seq_len, num_nodes, out_features).

Return type:

torch.Tensor

Shapes:
  • Input: \((B, T, N, D_{in})\) where \(B\) is batch_size, \(T\) is sequence_length, \(N\) is num_nodes, and \(D_{in}\) is in_features.

  • Output: \((B, T, N, D_{out})\) where \(D_{out}\) is out_features (\(D'\)).
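
One plausible reading of the multi-lag aggregation is \(H = \sum_{l} W_l X \Theta_l\), with one learnable weight matrix \(\Theta_l\) per spatial lag. The sketch below evaluates that sum for a single time step in plain Python. It is an illustration only: the function star_embedding, the inclusion of an identity matrix for lag 0 in W_list, and the absence of a bias or activation are all assumptions rather than the layer's confirmed behaviour.

```python
def matmul(a, b):
    """Plain-Python matrix product of two lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def star_embedding(x, W_list, theta_list):
    """x: (N, D_in) node features; W_list[l]: (N, N); theta_list[l]: (D_in, D_out)."""
    n_nodes, d_out = len(x), len(theta_list[0][0])
    out = [[0.0] * d_out for _ in range(n_nodes)]
    for W, theta in zip(W_list, theta_list):
        term = matmul(matmul(W, x), theta)  # aggregate neighbours, then project
        out = [[o + t for o, t in zip(o_row, t_row)] for o_row, t_row in zip(out, term)]
    return out

x = [[1.0], [3.0]]                      # N=2 nodes, D_in=1
W_list = [[[1.0, 0.0], [0.0, 1.0]],     # lag 0: each node itself (assumed identity)
          [[0.0, 1.0], [1.0, 0.0]]]     # lag 1: the other node
theta_list = [[[2.0]], [[0.5]]]         # hypothetical learned weights, D_out=1
print(star_embedding(x, W_list, theta_list))  # [[3.5], [6.5]]
```

Each node's output mixes its own feature (lag 0) with its neighbour's feature (lag 1), which is the sense in which the embedding is spatial.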