Succinct summary: DA-RNN is an older but still strong-performing RNN-based method that uses a dual-stage attention mechanism (input attention in the encoder, temporal attention in the decoder) to improve forecasting. It can only forecast one step at a time.
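Because the model predicts a single step, longer horizons have to be produced recursively by feeding each prediction back in as input. A minimal PyTorch sketch of that rollout, assuming a hypothetical one-step model that maps a window of shape (batch_size, seq_len, n_features) to the next target value, with the target stored in the last feature column and the other features held at their last known values. `rollout` and `LastValueModel` are illustrative names, not flow-forecast APIs.

```python
import torch
from torch import nn


class LastValueModel(nn.Module):
    """Hypothetical stand-in for a one-step model: predicts the last observed target."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x[:, -1, -1:]  # (batch_size, 1)


def rollout(one_step_model: nn.Module, history: torch.Tensor, n_steps: int) -> torch.Tensor:
    """Forecast n_steps ahead by repeatedly calling a one-step model.

    history: (batch_size, seq_len, n_features), target assumed in the last column.
    """
    window = history.clone()
    preds = []
    with torch.no_grad():
        for _ in range(n_steps):
            y_next = one_step_model(window)  # (batch_size, 1)
            preds.append(y_next)
            # Next input row: copy the last known features (an assumption for this
            # sketch) and overwrite the target column with the new prediction
            next_row = window[:, -1:, :].clone()
            next_row[:, 0, -1] = y_next[:, 0]
            # Slide the window forward by one time step
            window = torch.cat([window[:, 1:, :], next_row], dim=1)
    return torch.cat(preds, dim=1)  # (batch_size, n_steps)


history = torch.arange(30, dtype=torch.float32).reshape(2, 5, 3)
forecast = rollout(LastValueModel(), history, n_steps=4)
print(forecast.shape)  # torch.Size([2, 4])
```

Since every step consumes the previous prediction, errors can compound over the horizon.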

Metadata incorporation and results:

We have modified DA-RNN to use meta-data by adding a merging layer, MetaMerger, to the encoder (encoder_layer). The module can be configured with several different merge methods:

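```python
import torch
from torch import nn
# Assumed import path: MergingModel is defined in flow-forecast's meta_models package
from flood_forecast.meta_models.merging_model import MergingModel


class MetaMerger(nn.Module):
    def __init__(self, meta_params, meta_method, embed_shape, in_shape):
        super().__init__()
        # Merge strategy: "down_sample", "up_sample" or "repeat"
        self.method_layer = meta_method
        if meta_method == "down_sample":
            # Project the meta-data embedding (embed_shape) down to the time-series feature size
            self.initial_layer = torch.nn.Linear(embed_shape, in_shape)
        elif meta_method == "up_sample":
            # Project the time-series features (in_shape) up to the embedding size
            self.initial_layer = torch.nn.Linear(in_shape, embed_shape)
        # Model that merges the (projected) meta-data with the temporal data
        self.model_merger = MergingModel(meta_params["method"], meta_params["params"])
```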

Let's review the methods you can use:

“down_sample”: A linear layer maps the meta-data embedding to a tensor of size (batch_size, seq_len, n_feature_time_series-1). The Bilinear merging layer then operates along the seq_len dimension, and its output is resized to (batch_size, 1, n_feature_time_series-1).

“up_sample”: The reverse of down_sample: a linear layer maps the time-series features up to embedding_dim.

“repeat”: Repeats the meta-data across the temporal dimension. For instance, if the temporal data has shape (batch_size, seq_len, n_feature_time_series), the meta-data is repeated to shape (batch_size, seq_len, embedding_dim).
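The three methods differ mainly in the tensor shapes they produce. The PyTorch sketch below walks through them with made-up sizes. It is one reading of the descriptions above, not the flow-forecast implementation: the sizes are hypothetical, and the bilinear step (an `nn.Bilinear(seq_len, seq_len, 1)` applied after moving seq_len to the last axis) is an assumption about how the Bilinear layer "operates along the seq_len dimension".

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only
batch_size, seq_len, n_feature_time_series, embedding_dim = 4, 10, 6, 16
in_shape = n_feature_time_series - 1

temporal = torch.randn(batch_size, seq_len, in_shape)     # time-series features
meta_embedding = torch.randn(batch_size, embedding_dim)   # one meta-data vector per sample

# "repeat": copy the meta-data vector across the temporal dimension
meta_repeated = meta_embedding.unsqueeze(1).repeat(1, seq_len, 1)
# -> (batch_size, seq_len, embedding_dim)

# "down_sample": project the embedding to the time-series feature size
down = nn.Linear(embedding_dim, in_shape)
meta_down = down(meta_repeated)  # (batch_size, seq_len, in_shape)

# Assumed bilinear merge along seq_len: move seq_len to the last axis so the
# layer combines the two sequences and collapses them to a single step
bilinear = nn.Bilinear(seq_len, seq_len, 1)
merged = bilinear(temporal.permute(0, 2, 1), meta_down.permute(0, 2, 1))
merged = merged.permute(0, 2, 1)  # (batch_size, in_shape, 1) -> (batch_size, 1, in_shape)

# "up_sample": project the time-series features up to embedding_dim instead
up = nn.Linear(in_shape, embedding_dim)
temporal_up = up(temporal)  # (batch_size, seq_len, embedding_dim)

print(meta_repeated.shape, meta_down.shape, merged.shape, temporal_up.shape)
```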

COVID-19 Forecasting Results (no pre-training):

| County | Vanilla (no meta data) | down_sample | repeat | Notes |
|---|---|---|---|---|
| Palm Beach County FL | 662.233 MSE | 1860 MSE | 1741 MSE | Meta-data does not seem to help. |
| Albany County NY | | | | |
| Middlesex County MA | 240.994 MSE | 250.495 MSE | ? | |
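The results are reported as mean squared error (MSE). As a quick check on the note that meta-data does not seem to help, this Python snippet defines MSE and expresses each meta-data variant's Palm Beach County error as a multiple of the vanilla error, using only the numbers in the results above. `mse` and `relative_to_vanilla` are illustrative helpers, not part of the codebase.

```python
# MSE values for Palm Beach County FL, copied from the results above (lower is better)
palm_beach_mse = {"vanilla": 662.233, "down_sample": 1860.0, "repeat": 1741.0}


def mse(predictions, targets):
    """Mean squared error: the average of the squared prediction errors."""
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)


def relative_to_vanilla(method):
    """How many times larger a method's MSE is than the vanilla model's."""
    return palm_beach_mse[method] / palm_beach_mse["vanilla"]


for method in ("down_sample", "repeat"):
    print(f"{method}: {relative_to_vanilla(method):.2f}x the vanilla MSE")
# down_sample: 2.81x the vanilla MSE
# repeat: 2.63x the vanilla MSE
```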

COVID-19 Forecasting Results (with pre-training on 240 other counties first):

| County | Vanilla (no meta data) | down_sample | repeat |
|---|---|---|---|
| Palm Beach County | 99.941 MSE | | |
| Middlesex County | 33.745 MSE | | |
| Albany County | 0.1058 MSE | | |
| Harris County | 35988.954 MSE | | |
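These models were first trained on data from 240 other counties and then trained on the target county. A minimal PyTorch sketch of that two-stage pre-train/fine-tune procedure follows. The synthetic data, the flattened linear stand-in for DA-RNN, the Adam optimizer, the epoch counts and the lower fine-tuning learning rate are all assumptions for illustration, not the settings used for these results; `train` and `make_batches` are hypothetical helpers.

```python
import torch
from torch import nn


def train(model, batches, epochs, lr):
    """Plain MSE training loop over (x, y) batches."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.MSELoss()
    model.train()
    for _ in range(epochs):
        for x, y in batches:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()
    return model


def make_batches(n_batches, seq_len=10, n_features=3, seed=0):
    """Hypothetical synthetic data standing in for one group of counties."""
    g = torch.Generator().manual_seed(seed)
    return [
        (torch.randn(8, seq_len, n_features, generator=g), torch.randn(8, 1, generator=g))
        for _ in range(n_batches)
    ]


# Hypothetical stand-in for DA-RNN: flatten the window and apply a linear layer
model = nn.Sequential(nn.Flatten(), nn.Linear(10 * 3, 1))

# 1) Pre-train on pooled data from the other counties
other_counties = make_batches(20, seed=1)
model = train(model, other_counties, epochs=2, lr=1e-3)

# 2) Fine-tune the same weights on the target county with a smaller learning rate
target_county = make_batches(5, seed=2)
model = train(model, target_county, epochs=2, lr=1e-4)
```

Fine-tuning reuses the pre-trained weights rather than starting from a fresh initialization, which is what lets the data from other counties help.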

River Flow Forecasting Results (no pre-training):

| Gage ID | Vanilla (no meta data) | down_sample | Self-Attention | repeat |
|---|---|---|---|---|