A Dual-Stage Attention-Based Recurrent Neural Network

 

 

Succinct summary: An older but still strong-performing RNN-based method that incorporates an attention mechanism to improve forecasting. It can only forecast one step at a time.

 

Metadata incorporation and results:

We have modified DA-RNN to use meta-data by including a merging layer in the encoder_layer. This module can be configured with several different merging methods:

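import torch
from torch import nn


class MetaMerger(nn.Module):
    def __init__(self, meta_params, meta_method, embed_shape, in_shape):
        super().__init__()
        self.method_layer = meta_method
        if meta_method == "down_sample":
            # Project the meta-data embedding down to the time-series feature size.
            self.initial_layer = torch.nn.Linear(embed_shape, in_shape)
        elif meta_method == "up_sample":
            # Project the time-series features up to the embedding size.
            self.initial_layer = torch.nn.Linear(in_shape, embed_shape)
        # MergingModel is defined elsewhere in the codebase (not shown here).
        self.model_merger = MergingModel(meta_params["method"], meta_params["params"])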

Let's review the methods you can use:

“down_sample”: A linear layer maps the embedding representation to a tensor of size (batch_size, seq_len, n_feature_time_series-1). The Bilinear layer then operates along the seq_len dimension, as the tensor is then resized to (batch_size, 1, n_feature_time_series-1).

“up_sample”: The reverse of down_sample. A linear layer scales the time-series features up to embedding_dim.

“repeat”: Repeats the meta-data across the temporal dimension. For instance, if the temporal data has dimension (batch_size, seq_len, n_feature_time_series), the meta-data is repeated to have dimension (batch_size, seq_len, embedding_dim).
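The shape changes behind the three methods can be sketched as follows. This is a minimal illustration, not the module's actual forward pass: the sizes are made up, the names embed_shape and in_shape are borrowed from the MetaMerger constructor, and applying the down_sample layer to an already repeated embedding is an assumption made here only to obtain the seq_len axis. The merging step itself (MergingModel and the Bilinear layer) is not reproduced.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len = 4, 10
in_shape = 5       # width of the time-series features
embed_shape = 16   # width of the meta-data embedding (embedding_dim)

temporal = torch.randn(batch_size, seq_len, in_shape)
meta = torch.randn(batch_size, embed_shape)  # one embedding per series

# "repeat": copy the embedding across every time step.
meta_repeated = meta.unsqueeze(1).repeat(1, seq_len, 1)
# -> (batch_size, seq_len, embed_shape)

# "down_sample": a linear layer maps the embedding down to the
# time-series feature width. nn.Linear acts on the last axis only.
# Assumption: the embedding already carries a seq_len axis.
down = nn.Linear(embed_shape, in_shape)
meta_down = down(meta_repeated)
# -> (batch_size, seq_len, in_shape)

# "up_sample": the reverse direction, applied to the temporal data.
up = nn.Linear(in_shape, embed_shape)
temporal_up = up(temporal)
# -> (batch_size, seq_len, embed_shape)

print(meta_repeated.shape, meta_down.shape, temporal_up.shape)
```

In each case the goal is the same: bring the temporal tensor and the meta-data tensor to compatible shapes so that a merging layer can combine them.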

COVID-19 Forecasting Results (no pre-training):

| County | Vanilla (no meta data) | down_sample | repeat | Notes |
| --- | --- | --- | --- | --- |
| Palm Beach County FL | 662.233 MSE | 1860 MSE | 1741 MSE | Meta-data does not seem to help. |
| Albany County NY | | | | |
| Middlesex County MA | 240.994 MSE | 250.495 MSE | ? | |

COVID-19 Forecasting Results (with pre-training first on 240 other counties):

| County | Vanilla (no meta data) | down_sample | repeat | |
| --- | --- | --- | --- | --- |
| Palm Beach County | | | 99.941 MSE | |
| Middlesex County | | | 33.745 MSE | |
| Albany County | | | 0.1058 MSE | |
| Harris County | | | 35988.954 MSE | |

River flow forecasting (no pre-training)

| Gage ID | Vanilla (no meta data included) | down_sample | Self-Attention | repeat |
| --- | --- | --- | --- | --- |