Succinct summary: An older but still strong-performing RNN-based method that incorporates an attention mechanism to improve forecasting. It can only forecast one step at a time.
Metadata incorporation and results:
We have modified DA-RNN to use metadata by including a merging layer in the encoder_layer. This module can be constructed with several different methods:
```python
import torch
from torch import nn


class MetaMerger(nn.Module):
    def __init__(self, meta_params, meta_method, embed_shape, in_shape):
        super().__init__()
        self.method_layer = meta_method
        if meta_method == "down_sample":
            # Project the metadata embedding down to the temporal feature width.
            self.initial_layer = torch.nn.Linear(embed_shape, in_shape)
        elif meta_method == "up_sample":
            # Project the temporal features up to the embedding width.
            self.initial_layer = torch.nn.Linear(in_shape, embed_shape)
        # MergingModel is defined elsewhere in the codebase.
        self.model_merger = MergingModel(meta_params["method"], meta_params["params"])
```
Let's review the methods you can use:
“down_sample”: In this method we use a linear layer to map the embedding representation to a tensor of size (batch_size, seq_len, n_feature_time_series-1). The Bilinear layer then operates along the seq_len dimension, and the result is resized to (batch_size, 1, n_feature_time_series-1).
“up_sample”: The reverse of down_sample: a linear layer scales the features up to embedding_dim.
“repeat”: Repeats the metadata across the temporal dimension. For instance, if the temporal data has shape (batch_size, seq_len, n_feature_time_series), the metadata is repeated to shape (batch_size, seq_len, embedding_dim).
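
The shapes described above can be checked with a small PyTorch sketch. This is only an illustration under stated assumptions, not the module's actual forward pass: the concrete sizes, the variable names, and the choice to repeat the embedding before down-sampling are all hypothetical, and the merging step itself (the Bilinear layer) is left out.

```python
import torch
from torch import nn

# Hypothetical sizes, chosen only to make the shapes visible.
batch_size, seq_len = 4, 10
n_feature_time_series = 6
in_shape = n_feature_time_series - 1   # assumed width of the temporal input
embedding_dim = 16                     # assumed width of the metadata embedding

temporal = torch.randn(batch_size, seq_len, in_shape)
meta = torch.randn(batch_size, embedding_dim)  # assumed: one embedding per series

# "repeat": copy the metadata embedding to every time step.
meta_repeated = meta.unsqueeze(1).repeat(1, seq_len, 1)

# "up_sample": a linear layer scales the temporal features up to embedding_dim.
up_layer = nn.Linear(in_shape, embedding_dim)
temporal_up = up_layer(temporal)

# "down_sample": a linear layer maps the embedding down to the temporal width.
# Applying it to the repeated embedding is an assumption made for this sketch.
down_layer = nn.Linear(embedding_dim, in_shape)
meta_down = down_layer(meta_repeated)

print(meta_repeated.shape, temporal_up.shape, meta_down.shape)
```

In every case the goal is the same: bring the metadata and the temporal data to a matching feature width so that a merging layer can combine them.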
COVID-19 Forecasting Results (no pre-training):
County | Vanilla (no metadata) | down_sample | repeat | Notes |
---|---|---|---|---|
Palm Beach County FL | 662.233 MSE | 1860 MSE | 1741 MSE | Metadata does not seem to help. |
Albany County NY | | | | |
Middlesex County MA | 250.495 MSE | ? | | |
COVID-19 Forecasting Results (with pre-training first on 240 other counties):
County | Vanilla (no metadata) | down_sample | repeat |
---|---|---|---|
Palm Beach County | 99.941 MSE | | |
Middlesex County | 33.745 MSE | | |
Albany County | 0.1058 MSE | | |
Harris County | 35988.954 MSE | | |
River flow forecasting (no pre-training):
Gage ID | Vanilla (no meta data included) | down_sample | Self-Attention | repeat |
---|---|---|---|---|