Graph Convolutional Engression Network (GCEN)

class stengression.Models.GCEN(in_feat_dim, gcn_out_feat, lstm_hidden_dim, lstm_num_layers, lstm_dropout, p_lag, t_pred, graph_info, noise_encode='add', noise_dist='gaussian', noise_dim=2, noise_std=1, graph_conv_params: dict | None = None, gcn_seed=21, temporal_seed=9)

Graph Convolutional Engression Network (GCEN) for probabilistic spatiotemporal forecasting.

This model combines a GraphConv spatial module with an LSTM-based temporal module to capture both spatial dependencies and temporal dynamics. It follows the engression principle: noise is injected into the network, so repeated forward passes on the same input yield different samples from the forecast distribution.

Parameters:
  • in_feat_dim (int) – Input feature dimension per node.

  • gcn_out_feat (int) – Output feature dimension from the GCN spatial module.

  • lstm_hidden_dim (int) – Hidden dimension size of the LSTM temporal module.

  • lstm_num_layers (int) – Number of layers in the LSTM.

  • lstm_dropout (float) – Dropout probability in the LSTM.

  • p_lag (int) – Number of past timesteps used as input (lookback window).

  • t_pred (int) – Number of future timesteps predicted (forecast horizon).

  • graph_info (GraphInfo) – An instance of GraphInfo containing graph structure.

  • noise_encode (str, optional) – Method for noise injection: 'add' or 'concat'. Defaults to 'add'.

  • noise_dist (str, optional) – Distribution for noise: 'gaussian' or 'uniform'. Defaults to 'gaussian'.

  • noise_dim (int, optional) – Number of noise features appended when noise_encode='concat'. Defaults to 2.

  • noise_std (float, optional) – Scaling factor for the noise. Defaults to 1.

  • graph_conv_params (dict, optional) – Additional keyword arguments for the GraphConv spatial module (e.g., aggregation_type, combination_type, activation). Defaults to None.

  • gcn_seed (int, optional) – Seed for GCN weight initialization. Defaults to 21.

  • temporal_seed (int, optional) – Seed for LSTM initialization. Defaults to 9.
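To make the two noise_encode options concrete, here is a minimal NumPy sketch of the noise-injection step. It assumes noise is applied to a tensor of shape (B, T, N, D); where GCEN actually injects the noise, and the range used for uniform noise (taken here as [-1, 1)), are assumptions, and inject_noise is a hypothetical helper rather than part of stengression.

```python
import numpy as np


def inject_noise(x, noise_encode="add", noise_dist="gaussian",
                 noise_dim=2, noise_std=1.0, rng=None):
    """Hypothetical sketch of engression-style noise injection.

    x: array of shape (B, T, N, D). Where GCEN injects noise is not
    specified here; this only illustrates the 'add' / 'concat' options.
    """
    rng = np.random.default_rng() if rng is None else rng

    # 'add' needs noise with the same shape as x; 'concat' needs noise_dim
    # extra channels on the last axis.
    if noise_encode == "add":
        shape = x.shape
    elif noise_encode == "concat":
        shape = x.shape[:-1] + (noise_dim,)
    else:
        raise ValueError(f"unknown noise_encode: {noise_encode!r}")

    if noise_dist == "gaussian":
        eps = rng.standard_normal(shape)
    elif noise_dist == "uniform":
        eps = rng.uniform(-1.0, 1.0, size=shape)  # assumed range [-1, 1)
    else:
        raise ValueError(f"unknown noise_dist: {noise_dist!r}")
    eps = noise_std * eps  # noise_std scales the raw draws

    if noise_encode == "add":
        return x + eps
    return np.concatenate([x, eps], axis=-1)


# Usage: a batch of 3 windows, 12 lags, 5 nodes, 1 feature
x = np.zeros((3, 12, 5, 1))
print(inject_noise(x).shape)                                      # (3, 12, 5, 1)
print(inject_noise(x, noise_encode="concat", noise_dim=2).shape)  # (3, 12, 5, 3)
```

With 'add' the feature width is unchanged, so downstream layers need no resizing; 'concat' widens the last axis by noise_dim, giving the network dedicated channels from which to read the noise.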

gcn

The spatial processing module.

Type:

GraphConv

lstm

The temporal processing module.

Type:

nn.LSTM

output_layer

Fully connected layer mapping the LSTM hidden state to the forecast horizon.

Type:

nn.Linear

Example

>>> import torch
>>> from stengression.Models import GCEN
>>> # Initialize the model; graph_data is a GraphInfo instance
>>> model = GCEN(
...     in_feat_dim=1,
...     gcn_out_feat=16,
...     lstm_hidden_dim=32,
...     lstm_num_layers=2,
...     lstm_dropout=0.1,
...     p_lag=12,
...     t_pred=4,
...     graph_info=graph_data
... )
>>> optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
>>> # Train using a probabilistic loss accepting (target, pred_samples)
>>> model.fit(data_loader, optimizer, loss_fn=energy_score_loss,
...           num_epochs=100, m_samples=2, device=device, visualize=True)
>>> # Generate a single stochastic forward pass; forward() expects a
>>> # batch axis, so add one to the (T_in, N, D_in) history
>>> output = model(history_tensor.unsqueeze(0))
>>> # Generate an ensemble of 100 out-of-sample forecast samples
>>> forecast_samples = model.predict(history_tensor, m_samples=100, device=device)
>>> # In-sample analysis and residual diagnostics
>>> in_sample_preds = model.predict_in_sample(train_data, m_samples=100, method="q_step")
>>> residuals = model.get_residuals(in_sample_preds, train_data, point_method="median")
>>> model.plot_residuals(residuals, plots_per_row=4)
>>> # Out-of-sample performance evaluation
>>> metrics_df = model.evaluate_forecasts(history_tensor, y_test, y_train, point_method="median")
>>> print(metrics_df)

evaluate_forecasts(history: Tensor, y_true: Tensor, y_train: Tensor, m_samples=100, n_repeats=50, point_method: str | float = 'median', unstandardize=None, device=None)

Repeatedly generate probabilistic forecasts from a single trained model and return summary metrics.

This method performs a Monte Carlo-style evaluation, generating multiple ensembles to account for the stochastic nature of the engression model.
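The repeat-and-summarize loop can be sketched as below for a single point metric (MAE). sample_fn is a hypothetical stand-in for drawing an m_samples ensemble (for example, a wrapper around predict()); the library computes many more metrics and may use a different standard-deviation convention.

```python
import numpy as np


def point_forecast(ens, point_method="median"):
    """Reduce an ensemble (M, ...) to a point forecast (...)."""
    if point_method == "median":
        return np.median(ens, axis=0)
    if point_method == "mean":
        return np.mean(ens, axis=0)
    if isinstance(point_method, float):
        return np.quantile(ens, point_method, axis=0)
    raise ValueError(f"unsupported point_method: {point_method!r}")


def monte_carlo_mae(sample_fn, y_true, m_samples=100, n_repeats=50,
                    point_method="median"):
    """Mean and std of MAE across n_repeats freshly drawn ensembles.

    sample_fn(m) is a hypothetical callable returning an ensemble of shape
    (m, T_out, N, D); y_true has shape (T_out, N, D).
    """
    scores = np.empty(n_repeats)
    for r in range(n_repeats):
        ens = np.asarray(sample_fn(m_samples))
        point = point_forecast(ens, point_method)
        scores[r] = np.mean(np.abs(y_true - point))
    # Population std (ddof=0) across repeats; the library's choice may differ
    return {"MAE_mean": float(scores.mean()), "MAE_std": float(scores.std())}
```

A large MAE_std relative to MAE_mean signals that m_samples is too small for the ensemble summary to be stable.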

Parameters:
  • history (torch.Tensor) – Last p_lag observations of shape \((T_{in}, N, D_{in})\).

  • y_true (torch.Tensor) – Ground truth for the forecast horizon of shape \((T_{out}, N, D_{in})\).

  • y_train (torch.Tensor) – In-sample training data of shape \((T_{train}, N, D_{in})\). Required for scaling MASE and RMSSE.

  • m_samples (int, optional) – Number of samples per forecast ensemble. Defaults to 100.

  • n_repeats (int, optional) – Number of times to repeat the ensemble generation, used to estimate the variability of each metric. Defaults to 50.

  • point_method (str or float, optional) – Method to extract a point forecast from the ensemble. Options: "median", "mean", or a float quantile (e.g., 0.75). Defaults to "median".

  • unstandardize (list, optional) – A list [mean, std] to reverse data normalization. Defaults to None.

  • device (torch.device or str, optional) – Device for evaluation. Defaults to None.

Returns:

A DataFrame containing the mean and standard deviation across repeats for the following metrics:

  • Point: SMAPE, MAE, RMSE, MASE, RMSSE.

  • Probabilistic: Pinball (80%, 95%), Rho-risk (0.5, 0.9), CRPS.

  • Calibration: Empirical Coverage, Winkler Score.

Return type:

pd.DataFrame
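y_train enters only through the scaling denominators of MASE and RMSSE. A minimal sketch, assuming the usual non-seasonal (lag-1) naive scaling computed per series (node and feature) and then averaged; the library may pool series or use a seasonal lag instead, and mase_rmsse is a hypothetical helper.

```python
import numpy as np


def mase_rmsse(y_true, y_pred, y_train):
    """Lag-1 naive-scaled errors, computed per series and averaged.

    y_true, y_pred: (T_out, N, D); y_train: (T_train, N, D).
    """
    diff = np.diff(y_train, axis=0)            # one-step naive errors in-sample
    mae_scale = np.mean(np.abs(diff), axis=0)  # (N, D)
    mse_scale = np.mean(diff ** 2, axis=0)     # (N, D)
    err = y_true - y_pred
    mase = np.mean(np.abs(err), axis=0) / mae_scale
    rmsse = np.sqrt(np.mean(err ** 2, axis=0) / mse_scale)
    return float(mase.mean()), float(rmsse.mean())
```

A value below 1 means the forecast beats the in-sample naive forecast on average. A constant training series gives a zero denominator, so such series need special handling.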

evaluate_in_sample_fit(data: Tensor, m_samples: int = 100, n_repeats: int = 10, method: str = '1_step', batch_size: int = 32, point_method: str | float = 'median', unstandardize: list | None = None, device: str | None = None)

Repeatedly generates in-sample probabilistic forecasts and returns summary metrics.

This method assesses the model’s ability to reconstruct the historical training sequence by performing multiple stochastic passes over the data. It accounts for the “Engression” noise injection by repeating the process n_repeats times and reporting the stability of the metrics.

Parameters:
  • data (torch.Tensor) – The full historical dataset of shape \((T, N, D_{in})\), where \(T\) is total time steps.

  • m_samples (int, optional) – Number of stochastic samples per forecast ensemble window. Defaults to 100.

  • n_repeats (int, optional) – Number of full evaluation cycles to perform to calculate mean/std of metrics. Defaults to 10.

  • method (str, optional) – The sliding window strategy: "1_step" or "q_step". Defaults to "1_step".

  • batch_size (int, optional) – Number of windows processed in a single forward pass. Defaults to 32.

  • point_method (str or float, optional) – Method for extracting a single forecast from the ensemble to compute point errors (SMAPE, MAE, RMSE, MASE, RMSSE). Accepts "median", "mean", or a float quantile (e.g., 0.75). Defaults to "median".

  • unstandardize (list, optional) – A list [mean, std] used to rescale predictions and ground truth to original units. Defaults to None.

  • device (str or torch.device, optional) – Computation device. Defaults to None.

Returns:

A summary table containing the mean and standard deviation for the following metrics across all repeats:

  • Point Metrics: SMAPE, MAE, RMSE, MASE, RMSSE.

  • Probabilistic Metrics: Pinball (80%, 95%), Rho-risk (0.5, 0.9), CRPS.

  • Calibration Metrics: Empirical Coverage, Winkler Score.

Return type:

pd.DataFrame

Note

As with predict_in_sample(), metrics are computed only for the time steps after the initial p_lag warm-up period.
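CRPS, one of the probabilistic metrics listed above, can be estimated directly from an ensemble as \(\frac{1}{M}\sum_i |x_i - y| - \frac{1}{2M^2}\sum_{i,j} |x_i - x_j|\). The sketch below uses the sorted-sample identity to avoid an \(M \times M\) pairwise array and skips NaN warm-up entries. crps_ensemble is a hypothetical helper; the library's estimator may differ (for example, by using an \(M(M-1)\) normalization).

```python
import numpy as np


def crps_ensemble(ens, y):
    """Sample-based CRPS averaged over all entries, ignoring NaN entries.

    ens: (M, ...) ensemble; y: (...) observations.
    """
    ens = np.sort(np.asarray(ens, dtype=float), axis=0)  # sort members
    y = np.asarray(y, dtype=float)
    m = ens.shape[0]
    accuracy = np.mean(np.abs(ens - y), axis=0)
    # mean_{i,j} |x_i - x_j| = (2 / M^2) * sum_k (2k - M - 1) x_(k), k = 1..M
    k = np.arange(1, m + 1).reshape((m,) + (1,) * (ens.ndim - 1))
    spread = 2.0 / m**2 * np.sum((2 * k - m - 1) * ens, axis=0)
    return float(np.nanmean(accuracy - 0.5 * spread))
```

For a single-member ensemble the spread term vanishes and CRPS reduces to the absolute error, which is a handy sanity check.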

fit(data_loader, optimizer, loss_fn, num_epochs=100, m_samples=2, device='cpu', monitor=True, visualize=True, verbose=False)

Trains the GCEN model using a probabilistic loss function.

Parameters:
  • data_loader (torch.utils.data.DataLoader) – PyTorch DataLoader yielding (x_batch, y_batch).

  • optimizer (torch.optim.Optimizer) – PyTorch optimizer instance.

  • loss_fn (callable) – Loss function accepting (target, pred_samples).

  • num_epochs (int, optional) – Number of training epochs. Defaults to 100.

  • m_samples (int, optional) – Number of stochastic samples per batch to estimate the probabilistic loss. Defaults to 2.

  • device (str, optional) – Device to run training on ('cpu' or 'cuda'). Defaults to 'cpu'.

  • monitor (bool, optional) – If True, shows a progress bar. Defaults to True.

  • visualize (bool, optional) – If True, plots the loss curve after training. Defaults to True.

  • verbose (bool, optional) – If True, prints periodic loss updates. Defaults to False.
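The Example passes an energy_score_loss to fit() without defining it. Below is a minimal NumPy sketch of an energy-score loss with the (target, pred_samples) signature, assuming pred_samples stacks the draws on a leading axis, shape (M, B, ...), with M at least 2. It illustrates the formula only: a loss used for training must be written with torch operations so gradients flow, and the loss actually used with GCEN may differ.

```python
import numpy as np


def energy_score_loss(target, pred_samples):
    """Energy-score estimate (lower is better).

    target: (B, ...) observations; pred_samples: (M, B, ...), M >= 2.
    """
    m, b = pred_samples.shape[0], pred_samples.shape[1]
    x = pred_samples.reshape(m, b, -1)  # flatten each forecast to a vector
    y = target.reshape(1, b, -1)
    # Accuracy term: mean over samples of ||x_i - y||, per batch item
    fit = np.linalg.norm(x - y, axis=-1).mean(axis=0)
    # Spread term: mean over ordered pairs i != j of ||x_i - x_j||
    # (diagonal distances are zero, so summing all pairs is safe)
    pair = np.linalg.norm(x[:, None] - x[None, :], axis=-1).sum(axis=(0, 1))
    spread = pair / (m * (m - 1))
    return float(np.mean(fit - 0.5 * spread))
```

The first term pulls samples toward the target; the second rewards spread between samples, which keeps the model from collapsing to a point forecast. With m_samples=2, as in the fit() default, the spread term is a single pairwise distance.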

forward(x: Tensor)

Forward pass for generating a single stochastic forecast.

Parameters:

x (torch.Tensor) – Input tensor of shape \((B, T_{in}, N, D_{in})\), where \(B\) is batch size, \(T_{in}\) is p_lag, \(N\) is num_nodes, and \(D_{in}\) is in_feat_dim.

Returns:

Forecasted tensor of shape \((B, T_{out}, N, D_{in})\), where \(T_{out}\) is t_pred.

Return type:

torch.Tensor

get_residuals(in_sample_preds: Tensor, original_data: Tensor, point_method: str | float = 'median') → Tensor

Computes the residual matrix from in-sample predictions.

This method reduces the stochastic ensemble to a single point forecast using the specified point_method and computes the error \(\text{Residual} = \text{Actual} - \text{Predicted}\).

Parameters:
  • in_sample_preds (torch.Tensor) – Predictions from predict_in_sample() with shape \((M, T, N, D)\), where \(M\) is the number of samples.

  • original_data (torch.Tensor) – The ground truth dataset of shape \((T, N, D)\).

  • point_method (str or float, optional) – Method to extract a point forecast from the ensemble. Options: "median", "mean", or a float representing a quantile (e.g., 0.75). Defaults to "median".

Returns:

Residual tensor of shape \((T, N, D)\).

Return type:

torch.Tensor

Note

Following the structure of the in-sample predictions, the first p_lag time steps will contain NaN values. These represent the “warm-up” period where no forecasts were generated.
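The reduction described above can be sketched in NumPy as follows, with get_residuals_sketch as a hypothetical stand-in for the method. NaN warm-up rows in the ensemble propagate to the residuals, matching the note.

```python
import numpy as np


def get_residuals_sketch(in_sample_preds, original_data, point_method="median"):
    """Residual = Actual - Predicted, with the ensemble reduced to a point.

    in_sample_preds: (M, T, N, D), NaN in the warm-up rows;
    original_data: (T, N, D).
    """
    if point_method == "median":
        point = np.median(in_sample_preds, axis=0)
    elif point_method == "mean":
        point = np.mean(in_sample_preds, axis=0)
    elif isinstance(point_method, float):
        point = np.quantile(in_sample_preds, point_method, axis=0)
    else:
        raise ValueError(f"unsupported point_method: {point_method!r}")
    return original_data - point  # NaN propagates through warm-up rows
```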

plot_in_sample_fit(in_sample_preds: Tensor, original_data: Tensor, plots_per_row: int = 4, confidence_level: float = 0.95, node_names: list | None = None, title: str = 'GCEN: In-Sample Forecast Fit vs. Actual', savefig: bool = False, filename: str | None = None)

Plots the in-sample forecasts against the ground truth values.

This method generates a grid of subplots (one per node) showing the median forecast, the ground truth, and a shaded prediction interval based on the stochastic ensemble.

Parameters:
  • in_sample_preds (torch.Tensor) – Prediction ensemble from predict_in_sample() of shape \((M, T, N, D)\).

  • original_data (torch.Tensor) – Ground truth dataset of shape \((T, N, D)\).

  • plots_per_row (int, optional) – Number of subplots to display per row in the figure grid. Defaults to 4.

  • confidence_level (float, optional) – Nominal coverage of the shaded prediction interval (e.g., 0.95 for a 95% interval). Defaults to 0.95.

  • node_names (list of str, optional) – Custom labels for each node. If None, nodes are labeled by index. Defaults to None.

  • title (str, optional) – The main title for the entire figure. Defaults to ‘GCEN: In-Sample Forecast Fit vs. Actual’.

  • savefig (bool, optional) – If True, the resulting figure is exported to a file. Defaults to False.

  • filename (str, optional) – The file path/name for saving the figure (e.g., ‘fit_plot.png’). Required if savefig is True.

Note

The shaded area represents the uncertainty captured by the Engression noise injection. The bounds are calculated as the \((1 - confidence\_level)/2\) and \((1 + confidence\_level)/2\) quantiles of the ensemble.

Returns:

This method displays the plot using plt.show() or saves it to disk.

Return type:

None
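The interval bounds from the note above, together with the two calibration metrics reported by the evaluation methods, can be computed from an ensemble as sketched here. interval_metrics is a hypothetical helper; the Winkler score is written in its common form (interval width plus a 2/α penalty per unit by which the observation falls outside), which may not match the library's exact normalization.

```python
import numpy as np


def interval_metrics(ens, y, confidence_level=0.95):
    """Central prediction interval from ensemble quantiles, plus calibration.

    ens: (M, ...) ensemble; y: (...) observations.
    """
    alpha = 1.0 - confidence_level
    lower = np.quantile(ens, alpha / 2, axis=0)      # (1 - c) / 2 quantile
    upper = np.quantile(ens, 1 - alpha / 2, axis=0)  # (1 + c) / 2 quantile
    inside = (y >= lower) & (y <= upper)
    coverage = float(np.mean(inside))                # empirical coverage
    # Winkler score: width plus a 2/alpha penalty per unit of miss
    winkler = (upper - lower
               + (2 / alpha) * np.maximum(lower - y, 0)
               + (2 / alpha) * np.maximum(y - upper, 0))
    return lower, upper, coverage, float(np.mean(winkler))
```

A well-calibrated model has empirical coverage close to confidence_level; among models with similar coverage, a lower Winkler score indicates sharper intervals.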

plot_residuals(residuals: Tensor, plots_per_row: int = 4, node_names: list | None = None, title: str = 'GCEN: In-Sample Residuals per Node', savefig: bool = False, filename: str | None = None)

Plots the time series of residuals (errors) for each node in a grid.

This visualization helps in identifying systematic biases or patterns in the model’s errors across different spatial nodes. A horizontal line at zero is included for reference.

Parameters:
  • residuals (torch.Tensor) – Residual matrix obtained from get_residuals(), with shape \((T, N, D)\).

  • plots_per_row (int, optional) – Number of subplots to display in each row of the grid. Defaults to 4.

  • node_names (list of str, optional) – Labels for each node subplot. If None, node indices are used. Defaults to None.

  • title (str, optional) – The main title for the figure. Defaults to ‘GCEN: In-Sample Residuals per Node’.

  • savefig (bool, optional) – If True, the plot will be saved to the specified filename. Defaults to False.

  • filename (str, optional) – Path where the figure should be saved. Required if savefig is True. Defaults to None.

Note

Any NaN values present in the residuals (typically the first p_lag steps) are automatically handled by the plotting backend and will appear as gaps in the time series.

Returns:

Displays the plot using plt.show() or saves the file.

Return type:

None

predict(history: Tensor, m_samples: int = 100, unstandardize: list | None = None, device='cpu') → Tensor

Generates an ensemble of forecasts from a single history window.

Parameters:
  • history (torch.Tensor) – Past observations of shape \((T_{in}, N, D_{in})\).

  • m_samples (int, optional) – Number of ensemble members to generate. Defaults to 100.

  • unstandardize (list, optional) – A list [mean, std] to reverse data normalization. Defaults to None.

  • device (str, optional) – Computation device. Defaults to 'cpu'.

Returns:

Ensemble of forecasts of shape \((M, T_{out}, N, D_{in})\), where \(M\) is m_samples.

Return type:

torch.Tensor
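One plausible way for predict() to turn a single history into an ensemble is to replicate the window m_samples times into a batch, run one stochastic forward pass, and then undo the standardization. The sketch assumes that strategy and that unstandardize=[mean, std] inverts a z-score as y * std + mean; forward is a stand-in callable, not the real module, and predict_sketch is hypothetical.

```python
import numpy as np


def predict_sketch(forward, history, m_samples=100, unstandardize=None):
    """Ensemble forecast from one history window.

    forward: stand-in for GCEN.forward, mapping (B, T_in, N, D) to
             (B, T_out, N, D) with fresh noise for every batch row.
    history: (T_in, N, D).
    """
    batch = np.repeat(history[None], m_samples, axis=0)  # (M, T_in, N, D)
    samples = forward(batch)                              # (M, T_out, N, D)
    if unstandardize is not None:
        mean, std = unstandardize
        samples = samples * std + mean                    # undo z-scoring
    return samples
```

Because the noise differs per batch row, the M identical inputs yield M distinct forecasts in a single call, which is cheaper than M separate passes.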

predict_in_sample(data: Tensor, m_samples: int = 100, method: str = '1_step', batch_size: int = 64, unstandardize: list | None = None, device: str = 'cpu') → Tensor

Generates in-sample predictions for the entire training dataset.

This method slides a window across the provided historical data and produces stochastic forecasts for the time steps after the initial p_lag warm-up period.

Parameters:
  • data (torch.Tensor) – The original dataset of shape \((T, N, D)\), where \(T\) is the total time steps.

  • m_samples (int, optional) – Number of stochastic forecast samples to generate per window. Defaults to 100.

  • method (str, optional) –

    Strategy for in-sample forecasting. Options include:

    • "1_step": Slides the input window by 1 step, recording only the 1-step-ahead forecast for each position.

    • "q_step": Slides by t_pred (q) steps, recording full non-overlapping forecast horizons.

    Defaults to "1_step".

  • batch_size (int, optional) – Number of windows to process simultaneously, which bounds peak memory usage. Defaults to 64.

  • unstandardize (list, optional) – A list [mean, std] to reverse data normalization. Defaults to None.

  • device (str, optional) – Device to perform computations on. Defaults to "cpu".

Returns:

In-sample forecast ensemble of shape \((M, T, N, D)\), where \(M\) is m_samples.

Return type:

torch.Tensor

Note

The first p_lag time steps in the output tensor will contain NaN values because there is insufficient historical context to generate a forecast for the beginning of the sequence.
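The two sliding-window strategies differ only in stride and in how much of each forecast is kept. The hypothetical helper below enumerates windows under the assumption that a final q_step horizon running past the end of the data is truncated; the library may instead drop that window.

```python
def in_sample_windows(T, p_lag, t_pred, method="1_step"):
    """Hypothetical enumeration of in-sample forecast windows.

    Returns (input_start, first_target, n_kept) tuples: the model reads
    data[input_start : input_start + p_lag] and the forecasts for
    data[first_target : first_target + n_kept] are recorded.
    """
    if method == "1_step":
        stride, keep = 1, 1            # keep only the 1-step-ahead forecast
    elif method == "q_step":
        stride, keep = t_pred, t_pred  # keep the full horizon, no overlap
    else:
        raise ValueError(f"unknown method: {method!r}")
    windows = []
    for start in range(0, T - p_lag, stride):
        target = start + p_lag
        # Truncate a final horizon that would run past the data (assumption)
        windows.append((start, target, min(keep, T - target)))
    return windows


print(in_sample_windows(10, p_lag=3, t_pred=3, method="q_step"))
# [(0, 3, 3), (3, 6, 3), (6, 9, 1)]
```

Either way, the recorded targets start at index p_lag, which is why the first p_lag rows of the output are NaN. "1_step" needs roughly t_pred times as many windows as "q_step".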

class stengression.Models.GraphConv(in_feat: int, out_feat: int, graph_info, gcn_seed: int = 21, aggregation_type: str = 'mean', combination_type: str = 'concat', activation: str | None = None)

Graph convolution layer for learning node representations in a graph.

Parameters:
  • in_feat (int) – Input feature dimension per node.

  • out_feat (int) – Output feature dimension per node after convolution.

  • graph_info (GraphInfo) – Graph structure metadata containing edges and node counts.

  • gcn_seed (int, optional) – Random seed for weight initialization. Defaults to 21.

  • aggregation_type (str, optional) – How to aggregate neighbor messages. Options: ‘mean’, ‘sum’, ‘max’. Defaults to ‘mean’.

  • combination_type (str, optional) – How to combine node features and aggregated messages. Options: ‘concat’, ‘add’. Defaults to ‘concat’.

  • activation (str, optional) – Name of the activation function from torch.nn.functional (e.g., ‘relu’). If None, no activation is applied. Defaults to None.

Note

This layer computes spatial graph convolutions by aggregating neighbor features, combining them with the transformed node features, and optionally applying a non-linear activation. The layer performs a message-passing operation:

  1. Aggregate: Collects features from the neighbors defined in graph_info.

  2. Combine: Merges the aggregated features with the node’s own features.

  3. Activate: Applies the specified non-linear function, if activation is not None.

Shapes:
  • Input: \((N, B, T, D_{in})\) where \(N\) is num_nodes, \(B\) is batch_size, \(T\) is sequence_length, and \(D_{in}\) is in_feat.

  • Output: \((N, B, T, D_{out})\) where \(D_{out}\) is out_feat.
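The aggregation step can be sketched with NumPy scatter operations on the (N, B, T, D) layout above. The sketch assumes edges holds (node, neighbor) index pairs, so that node edges[0][k] receives the features of node edges[1][k], and that nodes without neighbors get zeros; both are assumptions about GraphInfo, and aggregate_neighbors is a hypothetical helper.

```python
import numpy as np


def aggregate_neighbors(features, edges, num_nodes, aggregation_type="mean"):
    """Aggregate neighbor features for every node.

    features: (N, B, T, D). edges: (node_indices, neighbor_indices), read
    here as "node edges[0][k] receives the features of node edges[1][k]"
    (an assumed convention). Nodes with no neighbors get zeros.
    """
    nodes = np.asarray(edges[0])
    neighbors = np.asarray(edges[1])
    messages = features[neighbors]                # (E, B, T, D)

    if aggregation_type in ("sum", "mean"):
        out = np.zeros_like(features, dtype=float)
        np.add.at(out, nodes, messages)           # scatter-add per node
        if aggregation_type == "mean":
            deg = np.bincount(nodes, minlength=num_nodes)
            out /= np.maximum(deg, 1).reshape(-1, 1, 1, 1)
    elif aggregation_type == "max":
        out = np.full(features.shape, -np.inf)
        np.maximum.at(out, nodes, messages)       # scatter-max per node
        out[np.isneginf(out)] = 0.0               # isolated nodes
    else:
        raise ValueError(f"unknown aggregation_type: {aggregation_type!r}")
    return out
```

np.add.at and np.maximum.at are used instead of fancy-index assignment because they accumulate correctly when the same node index appears several times.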

compute_nodes_representation(features: Tensor)

Compute node representations via a linear projection on the last dimension.

This method performs a matrix multiplication between the input features and the layer’s internal weight matrix.

Parameters:

features (torch.Tensor) – Input feature tensor of shape \((N, B, T, D_{in})\), where \(N\) is the number of nodes, \(B\) is batch size, \(T\) is sequence length, and \(D_{in}\) is the input feature dimension.

Returns:

The projected representations of shape \((N, B, T, D_{out})\), where \(D_{out}\) is the output feature dimension.

Return type:

torch.Tensor
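Because the projection acts only on the last axis, a single matmul handles all (N, B, T) positions at once. A short NumPy illustration, with a randomly drawn matrix W standing in for the layer's weight parameter:

```python
import numpy as np

rng = np.random.default_rng(0)
N, B, T, D_in, D_out = 5, 2, 12, 1, 16
W = rng.standard_normal((D_in, D_out))        # stand-in weight matrix
features = rng.standard_normal((N, B, T, D_in))

# matmul contracts the last axis of features with the first axis of W and
# treats the leading (N, B, T) axes as batch dimensions.
reps = features @ W
print(reps.shape)  # (5, 2, 12, 16)

# The same projection spelled out with einsum
same = np.einsum("nbtd,de->nbte", features, W)
print(np.allclose(reps, same))  # True
```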

forward(features: Tensor)

Perform the graph convolution forward pass.

This method implements the three-step message-passing process:

  1. Projection: Computes node-wise representations.

  2. Aggregation: Gathers messages from neighboring nodes.

  3. Update: Combines local and neighborhood information.

Parameters:

features (torch.Tensor) – Input spatiotemporal features of shape \((N, B, T, D_{in})\), where \(N\) is number of nodes, \(B\) is batch size, \(T\) is sequence length, and \(D_{in}\) is input feature dimension.

Returns:

The updated node representations of shape \((N, B, T, D_{out})\), where \(D_{out}\) is the output feature dimension.

Return type:

torch.Tensor
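The update step (step 3) can be sketched as below, assuming the node and neighborhood representations have already been computed. combine is a hypothetical helper, and activation is any callable standing in for the function looked up in torch.nn.functional. In this sketch 'concat' doubles the last axis; whether the library projects the result back to out_feat is not stated above.

```python
import numpy as np


def combine(node_repr, aggregated, combination_type="concat", activation=None):
    """Merge each node's own representation with its neighborhood summary.

    Both inputs: (N, B, T, D). 'concat' doubles the last axis; 'add' keeps it.
    activation: optional callable applied elementwise to the result.
    """
    if combination_type == "concat":
        h = np.concatenate([node_repr, aggregated], axis=-1)
    elif combination_type == "add":
        h = node_repr + aggregated
    else:
        raise ValueError(f"unknown combination_type: {combination_type!r}")
    return h if activation is None else activation(h)


relu = lambda z: np.maximum(z, 0.0)  # stand-in for torch.nn.functional.relu
a = np.ones((3, 2, 4, 8))
b = -2.0 * np.ones((3, 2, 4, 8))
print(combine(a, b).shape)  # (3, 2, 4, 16)
```

'concat' keeps the node's own signal separate from its neighborhood, so the next layer can weight them independently; 'add' is cheaper and keeps the width fixed.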

class stengression.Models.GraphInfo(edges: Tuple[list, list], num_nodes: int)

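A class to hold graph structure information: edges, a pair of index lists describing the graph's connections, and num_nodes, the total number of nodes.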
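A GraphInfo is often built from a dense adjacency matrix. The hypothetical helper below assumes the (node, neighbor) reading of edges, so that a nonzero adj[i, j] becomes the pair (i, j); check this convention against the library before relying on it.

```python
import numpy as np


def adjacency_to_edges(adj):
    """Convert a dense adjacency matrix into a (nodes, neighbors) list pair.

    A nonzero adj[i, j] is read as "node i has neighbor j" (an assumed
    convention). Pairs come out in row-major order.
    """
    nodes, neighbors = np.nonzero(np.asarray(adj))
    return nodes.tolist(), neighbors.tolist()


adj = np.array([[0, 1, 0],
                [0, 0, 1],
                [1, 0, 0]])
print(adjacency_to_edges(adj))  # ([0, 1, 2], [1, 2, 0])
```

The result could then be passed as GraphInfo(edges=adjacency_to_edges(adj), num_nodes=adj.shape[0]).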