
Feature-aligned N-BEATS with Sinkhorn divergence (ICLR 2024)


leejoonhun/fan-beats


Feature-aligned N-BEATS

Official PyTorch implementation of Feature-aligned N-BEATS with Sinkhorn divergence.

Data

Each domain's data should be stored as data/$SUPERDOMAIN/$DOMAIN.csv, a CSV file with three columns:

  • time denotes the time index.
  • series denotes the series index.
  • value denotes the value of the time series at the given time index.
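A minimal sketch of reading a file in this long format (one row per observation) with Python's standard csv module. It groups rows by series and orders each series by its time index. The toy CSV content and the helper name load_domain are hypothetical, and the sketch assumes that time holds integer indices. The repository's own loader may work differently.

```python
import csv
import io
from collections import defaultdict

# Hypothetical toy domain: two series ("a", "b") with three time steps each,
# in the long format described above. Rows for series "b" are deliberately
# out of time order to show that sorting by `time` handles this.
CSV_TEXT = """time,series,value
0,a,1.0
1,a,2.0
2,a,3.0
2,b,30.0
0,b,10.0
1,b,20.0
"""


def load_domain(fileobj):
    """Return {series: [values ordered by time]} from a long-format CSV.

    Assumes (not confirmed by the repository) that `time` is an integer index.
    """
    rows = defaultdict(list)
    for row in csv.DictReader(fileobj):
        rows[row["series"]].append((int(row["time"]), float(row["value"])))
    # Sort each series by its time index so rows may appear in any order.
    return {s: [v for _, v in sorted(obs)] for s, obs in rows.items()}


series = load_domain(io.StringIO(CSV_TEXT))
```

For a real file, pass an open handle instead, e.g. `with open(path, newline="") as f: load_domain(f)`, where path follows the data/$SUPERDOMAIN/$DOMAIN.csv layout.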

Source

The data used in the paper were obtained from the following sources:

Usage

python main.py --source-domains $SOURCE_DOMAIN1 $SOURCE_DOMAIN2 ... \
               --target-domain $TARGET_DOMAIN \
               --forecast-horizon $FORECAST_HORIZON \
               --lookback-multiple $LOOKBACK_MULTIPLE \
               --model $MODEL \
               --loss $LOSS \
               --regularizer $REGULARIZER \
               --temperature $TEMPERATURE \
               --scaler $SCALER \
               --metric $METRIC \
               --learning-rate $LEARNING_RATE \
               --num-lr-cycles $NUM_LR_CYCLES \
               --batch-size $BATCH_SIZE \
               --num-iters $NUM_ITERS \
               --seed $SEED \
               --dtype $DTYPE \
               --data-size $DATA_SIZE
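The --learning-rate, --num-lr-cycles and --num-iters flags interact through a cyclic learning-rate schedule; the argument table notes that torch.optim.lr_scheduler.CyclicLR(mode="triangular2") is used. The plain-Python sketch below reproduces the closed form of a symmetric triangular2 cycle, in which the peak amplitude halves after every cycle. How the repository actually maps the flags onto CyclicLR is an assumption here: this sketch treats --learning-rate as the peak rate, derives the half-cycle length as num_iters / (2 * num_lr_cycles), and uses a hypothetical floor BASE_LR = 0.0.

```python
import math


def triangular2_lr(iteration, base_lr, max_lr, step_size):
    """Learning rate at `iteration` for a symmetric triangular2 cyclic schedule.

    Each cycle lasts 2 * step_size iterations: the rate climbs linearly from
    base_lr to the cycle's peak, then falls back. The peak amplitude
    (max_lr - base_lr) is halved after every completed cycle.
    """
    cycle = math.floor(1 + iteration / (2 * step_size))
    x = abs(iteration / step_size - 2 * cycle + 1)
    amplitude = (max_lr - base_lr) / (2 ** (cycle - 1))
    return base_lr + amplitude * max(0.0, 1.0 - x)


# Assumed mapping of the CLI defaults onto the schedule (not taken from the repo):
NUM_ITERS = 1000        # --num-iters default
NUM_LR_CYCLES = 50      # --num-lr-cycles default
MAX_LR = 2e-5           # --learning-rate default, assumed to be the peak rate
BASE_LR = 0.0           # hypothetical floor of each cycle
STEP_SIZE = NUM_ITERS // (2 * NUM_LR_CYCLES)  # 10 iterations up, 10 down

schedule = [triangular2_lr(i, BASE_LR, MAX_LR, STEP_SIZE) for i in range(NUM_ITERS)]
```

Under these assumptions the rate peaks at 2e-5 at iteration 10, returns to the floor at iteration 20, and peaks at 1e-5 at iteration 30. In PyTorch the same shape would come from CyclicLR(optimizer, base_lr=..., max_lr=..., step_size_up=..., mode="triangular2"), stepped once per iteration.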

The arguments are described below (each command-line flag such as --source-domains corresponds to the underscored name such as source_domains):

| Argument | Description | Default |
|---|---|---|
| source_domains | Source domains $\{\mathcal{D}^k\}_k$ | |
| target_domain | Target domain $\mathcal{D}^T$ | |
| forecast_horizon | Forecast horizon $\alpha$ | 10 |
| lookback_multiple | Lookback multiple $\beta/\alpha$ | 5 |
| model | Model architecture $\mathfrak{F}$ | "NHiTS" |
| loss | Forecasting loss function $\mathcal{L}$ | "SMAPE" |
| regularizer | Regularizer measure $\mathcal{L}_\mathrm{align}$. NOTE: "None" for the vanilla model | "Sinkhorn" |
| temperature | Regularizing temperature $\lambda$ | 1.0 |
| scaler | Normalizing function $\sigma$ | "softmax" |
| metric | Evaluation metric for validation and test | "SMAPE" |
| learning_rate | Learning rate $\eta$ | 2e-5 |
| num_lr_cycles | Number of learning rate cycles. NOTE: torch.optim.lr_scheduler.CyclicLR(mode="triangular2") is used | 50 |
| batch_size | Batch size $B$ | 2**12 |
| num_iters | Number of iterations | 1000 |
| seed | Random seed | 0 |
| dtype | Data type used for torch and numpy | "float32" |
| data_size | Fixed data size for each domain. NOTE: "None" to use all data | 75000 |
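The default regularizer, "Sinkhorn", penalizes the discrepancy between feature distributions from different domains. The scaler $\sigma$ (softmax by default) presumably normalizes those features first, and the temperature $\lambda$ presumably weights the alignment term against the forecasting loss. The pure-Python sketch below illustrates these ingredients on toy data: a softmax scaler, Sinkhorn scaling iterations for an entropic optimal-transport plan, and the debiased divergence $S_\varepsilon(\mu,\nu) = \mathrm{OT}_\varepsilon(\mu,\nu) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\mu,\mu) - \tfrac{1}{2}\mathrm{OT}_\varepsilon(\nu,\nu)$. It is not the repository's implementation. The squared-Euclidean cost, eps, the iteration count, the use of the plan's transport cost $\langle P, C\rangle$ as $\mathrm{OT}_\varepsilon$, the reading of $\lambda$ as a loss weight, and the toy features are all assumptions.

```python
import math


def softmax(v):
    """Numerically stable softmax of a list of floats (the assumed 'softmax' scaler)."""
    m = max(v)
    exps = [math.exp(x - m) for x in v]
    total = sum(exps)
    return [e / total for e in exps]


def sq_dist(u, v):
    """Squared Euclidean distance (assumed ground cost)."""
    return sum((a - b) ** 2 for a, b in zip(u, v))


def entropic_ot_cost(xs, ys, eps=0.1, n_iters=200):
    """Transport cost <P, C> of the entropic OT plan between two uniform point clouds."""
    n, m = len(xs), len(ys)
    cost = [[sq_dist(x, y) for y in ys] for x in xs]
    kernel = [[math.exp(-c / eps) for c in row] for row in cost]
    a, b = [1.0 / n] * n, [1.0 / m] * m
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Sinkhorn scaling: alternately fit the row sums to a and column sums to b.
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Plan P_ij = u_i * K_ij * v_j; return sum_ij P_ij * C_ij.
    return sum(u[i] * kernel[i][j] * v[j] * cost[i][j] for i in range(n) for j in range(m))


def sinkhorn_divergence(xs, ys, eps=0.1, n_iters=200):
    """Debiased divergence: OT(x, y) - OT(x, x) / 2 - OT(y, y) / 2."""
    return (entropic_ot_cost(xs, ys, eps, n_iters)
            - 0.5 * entropic_ot_cost(xs, xs, eps, n_iters)
            - 0.5 * entropic_ot_cost(ys, ys, eps, n_iters))


# Hypothetical raw features from two source domains (4 samples x 3 dims each),
# normalized with the softmax scaler before alignment.
feats_a = [softmax(f) for f in [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [1.1, 0.0, 0.1], [1.0, 0.2, 0.0]]]
feats_b = [softmax(f) for f in [[0.0, 1.0, 2.0], [0.1, 1.2, 2.0], [0.0, 0.9, 2.1], [0.2, 1.0, 1.9]]]

lam = 1.0  # --temperature default, assumed here to weight the alignment term
align = sinkhorn_divergence(feats_a, feats_b)
# Sketch of a regularized objective: total_loss = forecast_loss + lam * align
```

The debiasing terms make the divergence exactly zero when a feature set is compared with itself, which the raw entropic cost alone does not guarantee. Practical implementations usually run the iterations in the log domain (or rely on a library such as GeomLoss) to avoid kernel underflow for small eps; in this toy setting softmax outputs keep every squared distance at most 2, so the plain kernel stays well away from zero.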

Citation

@inproceedings{lee2024fanbeats,
  title={Feature-aligned N-BEATS with Sinkhorn divergence},
  author={Lee, Joonhun and Jeon, Myeongho and Kang, Myungjoo and Park, Kyunghyun},
  booktitle={The Twelfth International Conference on Learning Representations},
  year={2024}
}

Acknowledgement

We acknowledge the significant contribution of the official N-BEATS implementation to this work: our models are built on its codebase.