💡 In this work, we propose annealed Winner-Takes-All (aWTA), a plug-and-play loss that replaces the widely used Winner-Takes-All (WTA) loss and improves the training of motion forecasting models.
🔥 Powered by Hydra, PyTorch Lightning, and WandB, the framework is easy to configure, train, and log.
- Create a new conda environment
conda create -n unitraj python=3.9
conda activate unitraj
- Install ScenarioNet: https://scenarionet.readthedocs.io/en/latest/install.html
pip --no-cache-dir install "metadrive-simulator>=0.4.1.1"
python -m metadrive.examples.profile_metadrive # test your installation
git clone https://github.com/metadriverse/scenarionet.git
cd scenarionet
sudo apt-get update
sudo apt install libspatialindex-dev
pip --no-cache-dir install -e .
pip install --no-cache-dir av2 --upgrade
- Install UniTraj:
git clone https://github.com/valeoai/MF_aWTA
pip install -r requirements.txt
pip install --no-cache-dir torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
python setup.py develop # only needed for the first install
export PYTHONPATH=$PYTHONPATH:UniTraj
export PYTHONPATH=$PYTHONPATH:UniTraj/unitraj/models/mtr/ops
Known issues and solutions:
- Make sure the compiled knn_cuda.cpython-39-x86_64-linux-gnu.so is present in /UniTraj/unitraj/models/mtr/ops/knn. If it is missing, the command 'python setup.py develop' did not succeed when installing UniTraj.
- If you have path issues when running 'train.py' or 'predict.py', try inserting the absolute paths of UniTraj and /UniTraj/unitraj/models/mtr/ops at the beginning of 'train.py' and 'predict.py':
import sys
sys.path.append("/Path/TO/UniTraj/unitraj/models/mtr/ops/")
sys.path.append("/Path/TO/UniTraj/")
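Both known issues can also be checked programmatically. The sketch below is hypothetical: the helper names find_knn_extension and add_unitraj_paths are not part of the repository, and it assumes that 'train.py' and 'predict.py' sit directly in UniTraj/unitraj/ (as in the directory tree below) and that the compiled extension follows the knn_cuda*.so naming shown above.

```python
import sys
from pathlib import Path


def find_knn_extension(unitraj_root):
    """Return the compiled knn_cuda extension files found under UniTraj/unitraj/models/mtr/ops/knn.

    unitraj_root: path to the UniTraj folder (hypothetical argument).
    An empty list suggests that 'python setup.py develop' did not build the extension.
    """
    knn_dir = Path(unitraj_root) / "unitraj" / "models" / "mtr" / "ops" / "knn"
    if not knn_dir.is_dir():
        return []
    # The suffix depends on the Python version and platform,
    # e.g. knn_cuda.cpython-39-x86_64-linux-gnu.so for Python 3.9 on x86_64 Linux.
    return sorted(knn_dir.glob("knn_cuda*.so"))


def add_unitraj_paths(script_file):
    """Prepend UniTraj and its MTR ops folder to sys.path, derived from the script location.

    Assumes script_file (e.g. __file__ of train.py) lives in UniTraj/unitraj/.
    Returns the paths that were actually inserted (already present ones are skipped).
    """
    unitraj_root = Path(script_file).resolve().parent.parent
    ops_dir = unitraj_root / "unitraj" / "models" / "mtr" / "ops"
    inserted = []
    for path in (str(ops_dir), str(unitraj_root)):
        if path not in sys.path:
            sys.path.insert(0, path)
            inserted.append(path)
    return inserted
```

For instance, calling add_unitraj_paths(__file__) at the top of 'train.py' avoids hard-coding /Path/TO/UniTraj; sys.path.insert(0, ...) gives these folders priority over other installed packages, while the sys.path.append version above works as well when there is no name clash.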
You can verify the installation of UniTraj by running the training script:
cd unitraj
python train.py --config-name=mtr_av2_awta
The model will be trained on the provided sample data.
There are three main components in UniTraj: dataset, model, and config. The structure of the code is as follows:
unitraj
├── configs
│ ├── config.yaml
│ ├── method
│ │ ├── autobot.yaml
│ │ ├── MTR.yaml
│ │ ├── wayformer.yaml
├── datasets
│ ├── base_dataset.py
│ ├── autobot_dataset.py
│ ├── wayformer_dataset.py
│ ├── MTR_dataset.py
├── models
│ ├── autobot
│ ├── mtr
│ ├── wayformer
│ ├── base_model
├── utils
There is a base config, dataset, and model class, and each model has its own config, dataset, and model class that inherit from the base class.
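The inheritance pattern can be pictured with a toy example. The class and method names below (BaseDataset, MTRDataset, preprocess, postprocess, process) are hypothetical and do not reproduce UniTraj's actual code; the sketch only illustrates how a method-specific subclass reuses the shared pipeline and overrides a single hook.

```python
class BaseDataset:
    """Hypothetical stand-in for the shared logic of datasets/base_dataset.py."""

    def __init__(self, config):
        self.config = config

    def preprocess(self, scenario):
        # Shared step reused by every method (illustrative: copy agents and map elements).
        return {"agents": list(scenario["agents"]), "map": list(scenario.get("map", []))}

    def postprocess(self, sample):
        # Method-specific hook: the base class leaves the sample unchanged.
        return sample

    def process(self, scenario):
        # Template method: shared preprocessing first, then the subclass hook.
        return self.postprocess(self.preprocess(scenario))


class MTRDataset(BaseDataset):
    """Hypothetical stand-in for datasets/MTR_dataset.py: only the hook is overridden."""

    def postprocess(self, sample):
        # Illustrative method-specific field read from the method config.
        sample["num_modes"] = self.config["num_modes"]
        return sample
```

Under this pattern, adding a new method means adding one config file, one dataset subclass, and one model subclass, while the shared logic stays in the base classes.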
The code is adapted from UniTraj, which takes ScenarioNet-format data as input, so the data must be processed with ScenarioNet in advance.
- Download Argoverse 2 and the Waymo Open Motion Dataset (WOMD).
- Convert the data into ScenarioNet format:
- For Argoverse 2:
python -m scenarionet.convert_argoverse2 -d /path/to/your/database/train/ --raw_data_path /path/to/your/raw_data/train/
python -m scenarionet.convert_argoverse2 -d /path/to/your/database/val/ --raw_data_path /path/to/your/raw_data/val/
- For WOMD:
python -m scenarionet.convert_waymo -d /path/to/your/database/training/ --raw_data_path /path/to/your/raw_data/train/ --num_workers 64
python -m scenarionet.convert_waymo -d /path/to/your/database/validation/ --raw_data_path /path/to/your/raw_data/val/ --num_workers 64
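When converting several splits, the commands can be generated from a single function. This is a hypothetical convenience (build_convert_command is not part of ScenarioNet or this repository); it uses only the module names and flags that appear in the commands above, and all paths are placeholders.

```python
import shlex

# Converter modules taken from the commands above.
CONVERTERS = {
    "av2": "scenarionet.convert_argoverse2",
    "womd": "scenarionet.convert_waymo",
}


def build_convert_command(dataset, raw_dir, database_dir, num_workers=None):
    """Return the argv list for one ScenarioNet conversion of one split."""
    cmd = ["python", "-m", CONVERTERS[dataset], "-d", database_dir, "--raw_data_path", raw_dir]
    if num_workers is not None:
        cmd += ["--num_workers", str(num_workers)]
    return cmd


def as_shell(cmd):
    """Quote an argv list so it can be pasted into a terminal."""
    return shlex.join(cmd)
```

Each returned list can be executed with subprocess.run(cmd, check=True), which stops at the first failing conversion; replacing "python" with sys.executable guarantees that the interpreter of the active conda environment is used.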
UniTraj uses Hydra to manage configuration files.
The universal configuration file is located in unitraj/configs/config.yaml.
Each model has its own configuration file in unitraj/configs/method/, for example, unitraj/configs/method/autobot.yaml.
The configuration is hierarchical: each method's configuration is composed with, and inherits from, the universal configuration file.
Please refer to config.yaml and method/mtr.yaml for more details.
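Hydra performs this composition itself. Purely to illustrate the override semantics, the plain-Python sketch below merges a method configuration into the universal one, with the method's values winning and untouched keys inherited. The keys and values are hypothetical and are not taken from config.yaml or autobot.yaml.

```python
import copy


def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`; values from `override` win."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            # Nested section present in both: merge it key by key.
            merged[key] = deep_merge(merged[key], value)
        else:
            # New key or leaf value: the method configuration takes precedence.
            merged[key] = copy.deepcopy(value)
    return merged


# Hypothetical universal and method-level settings.
universal = {"seed": 42, "method": {"learning_rate": 1e-4, "max_epochs": 40}}
autobot = {"method": {"learning_rate": 7.5e-4}}
merged = deep_merge(universal, autobot)
```

Here merged keeps seed and max_epochs from the universal file and takes learning_rate from the method file, while universal itself is left unmodified.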
The configurations for each method and dataset are provided in ./configs. The top 5 models by minFDE are saved under ./lightning_logs, together with the TensorBoard logs (losses, metrics, and some visualizations during training).
For example, to train MTR with aWTA on Argoverse 2, run the following (you may need to set the paths to the Argoverse 2 scenario data in ./configs/mtr_av2_awta.yaml):
cd unitraj
python train.py --config-name=mtr_av2_awta
By default, the model is trained on 8 GPUs. You can change the number of GPUs in the corresponding config file (e.g., mtr_av2_awta.yaml) and the batch size in configs/method/MTR_wo_anchor.yaml.
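Checkpoints are ranked by minFDE. For reference, a minimal PyTorch sketch of minADE and minFDE over K hypotheses is given below, using the same tensor layout as the loss functions in this README. The function name min_ade_fde is hypothetical, and this version takes each minimum independently; benchmarks may define the metrics differently (for example, by selecting the hypothesis from its final displacement), so the official evaluation code should be preferred for reported numbers.

```python
import torch


def min_ade_fde(prediction, gt, gt_valid_mask):
    """minADE and minFDE over hypotheses, averaged over the batch.

    prediction: [batch, hypotheses, timesteps, 2]
    gt: [batch, timesteps, 2]
    gt_valid_mask: [batch, timesteps], 1 for valid future steps (at least one per sample assumed)
    """
    mask = gt_valid_mask.float()
    # Euclidean error of every hypothesis at every step: [batch, hypotheses, timesteps]
    err = torch.linalg.norm(prediction - gt.unsqueeze(1), dim=-1)
    # ADE: mean error over the valid steps only: [batch, hypotheses]
    ade = (err * mask.unsqueeze(1)).sum(-1) / mask.sum(-1, keepdim=True).clamp(min=1.0)
    # Index of the last valid step of each sample: [batch]
    steps = torch.arange(mask.shape[1], device=mask.device)
    last_idx = (mask * steps).argmax(dim=-1)
    # FDE: error at that last valid step: [batch, hypotheses]
    gather_idx = last_idx.view(-1, 1, 1).expand(-1, err.shape[1], 1)
    fde = err.gather(-1, gather_idx).squeeze(-1)
    # Best hypothesis per sample, then average over the batch.
    return ade.min(dim=-1).values.mean(), fde.min(dim=-1).values.mean()
```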
- Download the checkpoints from the release tagged model_weights and put them into ./model_zoo/.
- Run the evaluation. For example, to evaluate MTR on Argoverse 2, run:
cd unitraj
python predict.py --config-name=mtr_av2_awta_predict
aWTA is a standalone loss compatible with any motion forecasting model trained with the WTA loss. You only need to replace the WTA loss with aWTA. Here is an example:
From WTA loss:
import torch

def wta_loss(prediction, gt, gt_valid_mask):
    '''
    prediction: predicted forecasts, of shape [batch, hypotheses, timesteps, 2]
    gt: ground-truth forecasting trajectory, of shape [batch, timesteps, 2]
    gt_valid_mask: ground-truth forecasting mask indicating the valid future steps, of shape [batch, timesteps]
    '''
    # distance between each hypothesis and the ground truth, e.g., ADE: [batch, hypotheses]
    # (compute_ade is not defined in this snippet; any per-hypothesis distance works)
    distance = compute_ade(prediction, gt, gt_valid_mask)
    # select the hypothesis with the minimum distance to the ground truth
    nearest_hypothesis_idxs = distance.argmin(dim=-1)  # [batch]
    nearest_hypothesis_bs_idxs = torch.arange(nearest_hypothesis_idxs.shape[0]).type_as(nearest_hypothesis_idxs)  # [batch]
    # keep only the distance of the selected hypothesis
    loss_reg = distance[nearest_hypothesis_bs_idxs, nearest_hypothesis_idxs]  # [batch]
    return loss_reg.mean()  # mean over the batch
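compute_ade is called but not defined in the loss snippets. A minimal sketch of such a helper is given below; it is an assumption, not the repository's implementation, and only needs to return one distance per hypothesis with shape [batch, hypotheses], averaging the Euclidean error over the valid future steps.

```python
import torch


def compute_ade(prediction, gt, gt_valid_mask):
    """Hypothetical helper: average displacement error of every hypothesis.

    prediction: [batch, hypotheses, timesteps, 2]
    gt: [batch, timesteps, 2]
    gt_valid_mask: [batch, timesteps], 1 for valid future steps
    returns: [batch, hypotheses]
    """
    mask = gt_valid_mask.float().unsqueeze(1)  # [batch, 1, timesteps]
    # Euclidean error at each step: [batch, hypotheses, timesteps]
    step_error = torch.linalg.norm(prediction - gt.unsqueeze(1), dim=-1)
    # Average over valid steps only; the clamp avoids dividing by zero for fully masked samples.
    return (step_error * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)
```

With a helper of this shape, distance.argmin(dim=-1) in the WTA loss picks one hypothesis per sample, and the softmin in the aWTA loss spreads weights across the hypothesis dimension.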
To aWTA loss:
import torch

def awta_loss(prediction, gt, gt_valid_mask, cur_temperature):
    '''
    prediction: predicted forecasts, of shape [batch, hypotheses, timesteps, 2]
    gt: ground-truth forecasting trajectory, of shape [batch, timesteps, 2]
    gt_valid_mask: ground-truth forecasting mask indicating the valid future steps, of shape [batch, timesteps]
    cur_temperature: the current temperature for aWTA
    '''
    # distance between each hypothesis and the ground truth, e.g., ADE: [batch, hypotheses]
    distance = compute_ade(prediction, gt, gt_valid_mask)
    # weights q(t): softmin of the distances at the current temperature, detached so no gradient flows through them
    awta_weights = torch.softmax(-1.0 * distance / cur_temperature, dim=-1).detach()  # [batch, hypotheses]
    # weight each hypothesis' distance by its aWTA weight
    loss_reg = distance * awta_weights  # [batch, hypotheses]
    return loss_reg.sum(-1).mean()  # sum over weighted hypotheses and average over the batch
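The role of the temperature can be seen by isolating the weighting step of awta_loss on a toy distance tensor (the values are arbitrary). At a high temperature the weights are nearly uniform, so every hypothesis receives gradient; as the temperature approaches zero, the softmin collapses to a one-hot vector on the closest hypothesis and the weighted sum reduces to the WTA loss. The helper name awta_weights below is hypothetical and simply repeats the softmax line of the loss.

```python
import torch


def awta_weights(distance, temperature):
    """Softmin weights over hypotheses, as computed inside awta_loss (detached there)."""
    return torch.softmax(-distance / temperature, dim=-1)


distance = torch.tensor([[1.0, 2.0, 4.0]])  # toy [batch=1, hypotheses=3] distances
for temperature in (100.0, 1.0, 0.01):
    weights = awta_weights(distance, temperature)
    weighted_loss = (distance * weights).sum(-1)  # per-sample aWTA loss before batch averaging
    print(f"T={temperature}: weights={weights.tolist()}, loss={weighted_loss.item():.4f}")
```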
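The temperature is annealed across epochs with an exponential scheduler:
def temperature_scheduler(init_temperature, cur_epoch, exp_base):
    '''
    init_temperature: initial temperature
    cur_epoch: current epoch index
    exp_base: base of the exponential schedule (0 < exp_base < 1 decreases the temperature)
    '''
    return init_temperature * exp_base ** cur_epoch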
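Because the schedule is geometric, T(e) = T0 * b^e, its base b can be derived from a desired start temperature T0, end temperature, and number of epochs E as b = (T_end / T0)^(1/E). The helpers below are a hypothetical convenience, not part of the repository, and the example values in the test are illustrative rather than the settings used in the paper.

```python
import math


def exp_base_for(init_temperature, final_temperature, num_epochs):
    """Base b such that init_temperature * b**num_epochs == final_temperature."""
    return (final_temperature / init_temperature) ** (1.0 / num_epochs)


def epochs_to_reach(init_temperature, final_temperature, exp_base):
    """Smallest epoch e with init_temperature * exp_base**e <= final_temperature.

    Assumes 0 < exp_base < 1 and final_temperature < init_temperature.
    """
    return math.ceil(math.log(final_temperature / init_temperature) / math.log(exp_base))
```

The resulting base can be passed as exp_base to temperature_scheduler, whose output at each epoch is then used as cur_temperature in awta_loss.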
This work is released under the Apache 2.0 license.
@article{xu2025awta,
title={Annealed Winner-Takes-All for Motion Forecasting},
author = {Yihong Xu and
Victor Letzelter and
Mickaël Chen and
\'{E}loi Zablocki and
Matthieu Cord},
journal = {under review},
year = {2025}
}