This is an implementation of the methodology in "RNN with Particle Flow for Probabilistic Spatio-temporal Forecasting". The arXiv version is available at https://arxiv.org/pdf/2106.06064.pdf
You can:
- reproduce the results for the AGCGRU+flow, DCGRU+flow, and GRU+flow algorithms in our paper and supplementary material using the checkpoints of our trained models.
- re-train the models with the same hyperparameter settings and evaluate the forecasts.
To reproduce the results with the pretrained checkpoints:
- download the preprocessed datasets and checkpoints from https://drive.google.com/file/d/1tnaTTBVxUOdSpc5hxYUpCqUriYVVr5BP/view?usp=sharing
- unzip the downloaded zip file inside the folder rnn_flow
- run the script:
python inference_for_dir.py --search_dir ckpt/pretrained --output_dir ./result_pretrained --gpu_id 0
The results will then be saved in ./result_pretrained.
To re-train the models and evaluate them:
- download the preprocessed datasets and checkpoints, and unzip the downloaded zip file inside the folder rnn_flow
- train a model with the script, e.g.:
python training.py --dataset PEMS07 --rnn_type agcgru --cost mae --gpu_id 0 --trial_id 0 --ckpt_dir ./ckpt --data_dir ./data
- evaluate with script:
python inference.py --dataset PEMS07 --rnn_type agcgru --cost mae --gpu_id 0 --trial_id 0 --ckpt_dir ./ckpt --output_dir ./result
Note that:
- the settings of --dataset, --rnn_type, --cost, --trial_id and --ckpt_dir should be consistent between training and evaluation.
- --trial_id controls the random seed; changing it gives multiple trials of results with different random seeds.
- --rnn_type: agcgru, dcgru, gru
- --cost: mae (for MAE, MAPE and RMSE computation), nll (for CRPS, P10QL and P90QL computation)
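One way to keep the training and evaluation arguments consistent across several trials is to generate both commands from a single set of values. The sketch below is only an illustration: the helper `build_commands` is hypothetical and not part of this repository, and it uses only the flags and option values listed above.

```python
def build_commands(dataset, rnn_type, cost, trial_id,
                   ckpt_dir="./ckpt", data_dir="./data",
                   output_dir="./result", gpu_id=0):
    """Return (train_cmd, eval_cmd) argument lists that share the same settings.

    Hypothetical helper: it only assembles the command lines shown in this
    README and does not run anything.
    """
    if rnn_type not in ("agcgru", "dcgru", "gru"):
        raise ValueError("unknown --rnn_type: %r" % (rnn_type,))
    if cost not in ("mae", "nll"):
        raise ValueError("unknown --cost: %r" % (cost,))

    # Arguments that must match between training and evaluation.
    shared = [
        "--dataset", dataset,
        "--rnn_type", rnn_type,
        "--cost", cost,
        "--gpu_id", str(gpu_id),
        "--trial_id", str(trial_id),
        "--ckpt_dir", ckpt_dir,
    ]
    train_cmd = ["python", "training.py"] + shared + ["--data_dir", data_dir]
    eval_cmd = ["python", "inference.py"] + shared + ["--output_dir", output_dir]
    return train_cmd, eval_cmd


if __name__ == "__main__":
    # Print matched command pairs for three trials (three random seeds).
    for trial_id in range(3):
        train_cmd, eval_cmd = build_commands("PEMS07", "agcgru", "mae", trial_id)
        print(" ".join(train_cmd))
        print(" ".join(eval_cmd))
```

The returned lists can be passed to `subprocess.run` if you prefer to launch the runs from Python rather than copy the printed commands.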
The adjacency matrices of the graphs of the datasets are stored in ./data/sensor_graph (--graph_dir).
The model configs are stored in ./data/model_config (--config_dir).
If the folders are changed, please make the arguments of training.py and inference.py consistent.
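If you move these folders, a quick check that every path you intend to pass actually exists can catch a mismatch before a run starts. This is a minimal sketch: `missing_dirs` is a hypothetical helper, not part of this repository, and the paths are simply the locations named in this README.

```python
import os

# Locations named in this README; edit the values if the folders are moved.
DATA_DIRS = {
    "--data_dir": "./data",
    "--graph_dir": "./data/sensor_graph",
    "--config_dir": "./data/model_config",
}


def missing_dirs(dirs):
    """Return, in sorted order, the flags whose directory does not exist."""
    return sorted(flag for flag, path in dirs.items() if not os.path.isdir(path))


if __name__ == "__main__":
    for flag in missing_dirs(DATA_DIRS):
        print("{}: {} not found".format(flag, DATA_DIRS[flag]))
```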
The code has been tested with Python 3.7.0 with the following packages installed (along with their dependencies):
tensorflow-gpu==1.15.0
numpy==1.19.2
networkx==2.5
pandas==1.1.3
scikit-learn==0.23.2
PyYAML==5.3.1
scipy==1.5.4
In addition, for running with GPU, the code has been tested with CUDA 10.0 and cuDNN 7.4.
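To compare an existing environment against the versions listed above, a small check along the following lines may help. It is a sketch under stated assumptions: `find_mismatches` and `installed_versions` are hypothetical helpers, not part of this repository, and on Python 3.7 the lookup assumes the `importlib_metadata` backport is installed (the standard-library `importlib.metadata` is only available from Python 3.8).

```python
# Versions listed in this README as the tested configuration.
PINNED = {
    "tensorflow-gpu": "1.15.0",
    "numpy": "1.19.2",
    "networkx": "2.5",
    "pandas": "1.1.3",
    "scikit-learn": "0.23.2",
    "PyYAML": "5.3.1",
    "scipy": "1.5.4",
}


def find_mismatches(pinned, installed):
    """Return {package: (expected, found)} for packages that differ or are missing.

    `found` is None when the package is not present in `installed`.
    """
    return {
        name: (want, installed.get(name))
        for name, want in pinned.items()
        if installed.get(name) != want
    }


def installed_versions(names):
    """Look up the installed version of each named package, skipping absent ones."""
    try:
        from importlib.metadata import version, PackageNotFoundError  # Python 3.8+
    except ImportError:
        # Assumption: the backport package is available on Python 3.7.
        from importlib_metadata import version, PackageNotFoundError
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            pass
    return found


if __name__ == "__main__":
    mismatches = find_mismatches(PINNED, installed_versions(PINNED))
    for name, (want, got) in sorted(mismatches.items()):
        print("{}: expected {}, found {}".format(name, want, got or "not installed"))
```

A mismatch does not necessarily mean the code will fail; it only indicates that the environment differs from the tested one.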
Please cite our paper if you use this code in your own work:
@inproceedings{pal2021,
author={S. Pal and L. Ma and Y. Zhang and M. Coates},
booktitle={Proc. Int. Conf. Machine Learning},
title={{RNN} with Particle Flow for Probabilistic Spatio-temporal Forecasting},
month={Jul.},
year={2021},
address={Virtual Conference}}