-
Notifications
You must be signed in to change notification settings - Fork 78
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add dataclass for AV2 MF challenge submissions (#41)
- Loading branch information
Showing
3 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# <Copyright 2022, Argo AI, LLC. Released under the MIT license.> | ||
|
||
"""Constants used throughout the AV2 motion forecasting API.""" | ||
|
||
from typing import Final | ||
|
||
AV2_SCENARIO_STEP_HZ: Final[int] = 10 | ||
AV2_SCENARIO_OBS_TIMESTEPS: Final[int] = 50 | ||
AV2_SCENARIO_PRED_TIMESTEPS: Final[int] = 60 | ||
AV2_SCENARIO_TOTAL_TIMESTEPS: Final[int] = 110 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# <Copyright 2022, Argo AI, LLC. Released under the MIT license.> | ||
|
||
"""Classes and utilities used to build submissions for the AV2 motion forecasting challenge.""" | ||
|
||
from __future__ import annotations | ||
|
||
from collections import defaultdict | ||
from dataclasses import dataclass | ||
from pathlib import Path | ||
from typing import Dict, Final, List, Tuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
|
||
from av2.datasets.motion_forecasting.constants import AV2_SCENARIO_PRED_TIMESTEPS | ||
from av2.utils.typing import NDArrayNumber | ||
|
||
# Define type aliases used for submission | ||
PredictedTrajectories = NDArrayNumber # (K, AV2_SCENARIO_PRED_TIMESTEPS, 2) | ||
PredictionProbabilities = NDArrayNumber # (K,) | ||
TrackPredictions = Tuple[PredictedTrajectories, PredictionProbabilities] | ||
ScenarioPredictions = Dict[str, TrackPredictions] # Mapping from track ID to track predictions | ||
PredictionRow = Tuple[str, str, float, PredictedTrajectories, PredictionProbabilities] | ||
|
||
SUBMISSION_COL_NAMES: Final[List[str]] = [ | ||
"scenario_id", | ||
"track_id", | ||
"probability", | ||
"predicted_trajectory_x", | ||
"predicted_trajectory_y", | ||
] | ||
EXPECTED_PREDICTION_SHAPE: Final[Tuple[int, ...]] = (AV2_SCENARIO_PRED_TIMESTEPS, 2) | ||
|
||
|
||
@dataclass(frozen=True) | ||
class ChallengeSubmission: | ||
"""Representation used to build submission for the AV2 motion forecasting challenge. | ||
Args: | ||
predictions: Container for all predictions to score - mapping from scenario ID to scenario-level predictions. | ||
""" | ||
|
||
predictions: Dict[str, ScenarioPredictions] | ||
|
||
def __post_init__(self) -> None: | ||
"""Validate that each of the submitted predictions has the appropriate shape and normalized probabilities. | ||
Raises: | ||
ValueError: If predictions for at least one track are not of shape (*, AV2_SCENARIO_PRED_TIMESTEPS, 2). | ||
ValueError: If for any track, prediction probabilities doesn't match the number of predicted trajectories. | ||
ValueError: If prediction probabilities for at least one track do not sum to 1. | ||
""" | ||
for scenario_id, scenario_predictions in self.predictions.items(): | ||
for track_id, (predicted_trajectories, prediction_probabilities) in scenario_predictions.items(): | ||
# Validate that predicted trajectories are of the correct shape | ||
if predicted_trajectories[0].shape[-2:] != EXPECTED_PREDICTION_SHAPE: | ||
raise ValueError( | ||
f"Prediction for track {track_id} in {scenario_id} found with invalid shape " | ||
f"{predicted_trajectories.shape}, expected (*, {AV2_SCENARIO_PRED_TIMESTEPS}, 2)." | ||
) | ||
|
||
# Validate that the number of predicted trajectories and prediction probabilities matches | ||
if len(predicted_trajectories) != len(prediction_probabilities): | ||
raise ValueError( | ||
f"Prediction for track {track_id} in {scenario_id} has " | ||
f"{len(predicted_trajectories)} predicted trajectories, but " | ||
f"{len(prediction_probabilities)} probabilities." | ||
) | ||
|
||
# Validate that prediction probabilities for each track are normalized | ||
prediction_probability_sum = np.sum(prediction_probabilities) | ||
probability_is_normalized = np.isclose(1, prediction_probability_sum) | ||
if not probability_is_normalized: | ||
raise ValueError( | ||
f"Track probabilities must sum to 1, but probabilities for track {track_id} in {scenario_id} " | ||
f"sum up to {prediction_probability_sum}." | ||
) | ||
|
||
@classmethod | ||
def from_parquet(cls, submission_file_path: Path) -> ChallengeSubmission: | ||
"""Load challenge submission from serialized parquet representation on disk. | ||
Args: | ||
submission_file_path: Path to the serialized submission file (in parquet format). | ||
Returns: | ||
Challenge submission object initialized using the loaded data. | ||
""" | ||
# Load submission data and sort rows by descending probability | ||
submission_df = pd.read_parquet(submission_file_path) | ||
submission_df.sort_values(by="probability", inplace=True, ascending=False) | ||
|
||
# From serialized data, build scenario-track mapping for predictions | ||
submission_dict: Dict[str, ScenarioPredictions] = defaultdict(lambda: defaultdict(dict)) # type: ignore | ||
for (scenario_id, track_id), track_df in submission_df.groupby(["scenario_id", "track_id"]): | ||
predicted_trajectories_x = np.stack(track_df.loc[:, "predicted_trajectory_x"].values.tolist()) | ||
predicted_trajectories_y = np.stack(track_df.loc[:, "predicted_trajectory_y"].values.tolist()) | ||
predicted_trajectories = np.stack((predicted_trajectories_x, predicted_trajectories_y), axis=-1) | ||
prediction_probabilities = np.array(track_df.loc[:, "probability"].values.tolist()) | ||
|
||
submission_dict[scenario_id][track_id] = (predicted_trajectories, prediction_probabilities) | ||
|
||
return cls(predictions=submission_dict) | ||
|
||
def to_parquet(self, submission_file_path: Path) -> None: | ||
"""Serialize and save challenge submission on disk using parquet representation. | ||
Args: | ||
submission_file_path: Path to the desired location for serialized submission file. | ||
""" | ||
prediction_rows: List[PredictionRow] = [] | ||
|
||
# Build list of rows for the submission dataframe | ||
for scenario_id, scenario_predictions in self.predictions.items(): | ||
for track_id, (predicted_trajectories, prediction_probabilities) in scenario_predictions.items(): | ||
for prediction_idx in range(len(predicted_trajectories)): | ||
prediction_rows.append( | ||
( | ||
scenario_id, | ||
track_id, | ||
prediction_probabilities[prediction_idx], | ||
predicted_trajectories[prediction_idx, :, 0], | ||
predicted_trajectories[prediction_idx, :, 1], | ||
) | ||
) | ||
|
||
# Build submission dataframe and serialize as parquet file | ||
submission_df = pd.DataFrame(prediction_rows, columns=SUBMISSION_COL_NAMES) | ||
submission_df.to_parquet(submission_file_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
# <Copyright 2022, Argo AI, LLC. Released under the MIT license.> | ||
|
||
"""Unit tests for AV2 motion forecasting challenge-related utilities.""" | ||
|
||
from contextlib import AbstractContextManager | ||
from contextlib import nullcontext as does_not_raise | ||
from pathlib import Path | ||
from typing import Dict | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
from av2.datasets.motion_forecasting.constants import AV2_SCENARIO_PRED_TIMESTEPS | ||
from av2.datasets.motion_forecasting.eval.submission import ChallengeSubmission, ScenarioPredictions, TrackPredictions | ||
|
||
# Build valid submission with predictions for a single track in a single scenario | ||
valid_track_predictions: TrackPredictions = (np.zeros((2, AV2_SCENARIO_PRED_TIMESTEPS, 2)), np.array([0.6, 0.4])) | ||
valid_scenario_predictions: ScenarioPredictions = {"valid_track_id": valid_track_predictions} | ||
valid_submission_predictions = {"valid_scenario_id": valid_scenario_predictions} | ||
|
||
# Build invalid track submission with incorrect prediction length | ||
too_short_track_predictions: TrackPredictions = (np.zeros((1, AV2_SCENARIO_PRED_TIMESTEPS - 1, 2)), np.array([1.0])) | ||
too_short_scenario_predictions = {"invalid_track_id": too_short_track_predictions} | ||
too_short_submission_predictions = {"invaild_scenario_id": too_short_scenario_predictions} | ||
|
||
# Build invalid track submission with mismatched predicted trajectories and probabilities | ||
mismatched_track_predictions: TrackPredictions = (np.zeros((1, AV2_SCENARIO_PRED_TIMESTEPS, 2)), np.array([0.5, 0.5])) | ||
mismatched_scenario_predictions = {"invalid_track_id": mismatched_track_predictions} | ||
mismatched_submission_predictions = {"invaild_scenario_id": mismatched_scenario_predictions} | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"test_submission_dict, expectation", | ||
[ | ||
(valid_submission_predictions, does_not_raise()), | ||
(too_short_submission_predictions, pytest.raises(ValueError)), | ||
(mismatched_submission_predictions, pytest.raises(ValueError)), | ||
], | ||
ids=["valid", "wrong_shape_trajectory", "mismatched_trajectory_probability_shape"], | ||
) | ||
def test_challenge_submission_data_validation( | ||
test_submission_dict: Dict[str, ScenarioPredictions], expectation: AbstractContextManager # type: ignore | ||
) -> None: | ||
"""Test that validation of submitted trajectories works as expected during challenge submission initialization. | ||
Args: | ||
test_submission_dict: Scenario-level predictions used to initialize challenge submission. | ||
expectation: Context manager to capture the appropriate exception for each test case. | ||
""" | ||
with expectation: | ||
ChallengeSubmission(predictions=test_submission_dict) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"test_submission_dict", | ||
[(valid_submission_predictions)], | ||
ids=["valid_submission"], | ||
) | ||
def test_challenge_submission_serialization(tmpdir: Path, test_submission_dict: Dict[str, ScenarioPredictions]) -> None: | ||
"""Test that challenge submissions can be serialized/deserialized without changes to internal state. | ||
Args: | ||
tmpdir: tmpdir: Temp directory used in the test (provided via built-in fixture). | ||
test_submission_dict: Scenario-level predictions used to initialize challenge submission. | ||
""" | ||
# Serialize submission to parquet file and load data back from disk | ||
submission_file_path = tmpdir / "submission.parquet" | ||
submission = ChallengeSubmission(predictions=test_submission_dict) | ||
submission.to_parquet(submission_file_path) | ||
deserialized_submission = ChallengeSubmission.from_parquet(submission_file_path) | ||
|
||
# Check that deserialized data matches original data exactly | ||
for scenario_id, scenario_predictions in submission.predictions.items(): | ||
for track_id, (expected_trajectories, expected_probabilities) in scenario_predictions.items(): | ||
deserialized_predictions = deserialized_submission.predictions[scenario_id][track_id] | ||
(deserialized_trajectories, deserialized_probabilities) = deserialized_predictions | ||
assert np.array_equal(deserialized_trajectories, expected_trajectories) | ||
assert np.array_equal(deserialized_probabilities, expected_probabilities) |