-
Notifications
You must be signed in to change notification settings - Fork 78
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add dataclass for AV2 MF challenge submissions #41
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks really nice. Left a few minor comments.
NUM_PRED_TIMESTEPS: Final[int] = 60 | ||
|
||
|
||
@dataclass |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can this be frozen?
""" | ||
for scenario_predictions in self.predictions.values(): | ||
for track_predictions in scenario_predictions.values(): | ||
if track_predictions.shape[-2:] != (NUM_PRED_TIMESTEPS, 2): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it'd be nice to assign track_predictions.shape[-2:]
to a descriptive variable name.
|
||
TrackPredictions = NDArrayNumber # Track predictions are expected to be of shape (*, NUM_PRED_TIMESTEPS, 2) | ||
ScenarioPredictions = Dict[str, TrackPredictions] # Mapping from track ID to track predictions | ||
PredictionRow = Tuple[str, str, int, NDArrayNumber, NDArrayNumber] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] A docstring can be useful here as well
SUBMISSION_COL_NAMES: Final[List[str]] = [ | ||
"scenario_id", | ||
"track_id", | ||
"prediction_rank", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this mean we don't expect probabilities?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated to add probabilities as an input.
"predicted_trajectory_x", | ||
"predicted_trajectory_y", | ||
] | ||
NUM_PRED_TIMESTEPS: Final[int] = 60 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Can we make this more visible? Seems like an important constant. Probably a config shared across the API as well.
ce28356
to
eb6d202
Compare
PR Summary
This PR adds a dataclass to help researchers build submission files for the AV2 motion forecasting challenge.
Challenge submissions are initialized using a mapping from scenario and track IDs to a numpy array of predicted trajectories. The data dict is then serialized to a parquet file, which will be submittable to an evaluation server.
Testing
Added unit tests for data validation and (de)serialization.
In order to ensure this PR works as intended, it is:
Compliance with Standards
As the author, I certify that this PR conforms to the following standards: