Commit

Add units to variables (#170)
* Distance Threshold and ROI Pruning

* Distance Threshold and ROI Pruning

* Hierarchical AP

* Update eval.py

* Fix partial lint + typing.

* Reformat.

* fix some linting and pytest issues

* PR Changes

* Fixed Linting

* Black Autoformat

* Update src/av2/evaluation/detection/constants.py

Co-authored-by: Benjamin Wilson <[email protected]>

* Update src/av2/evaluation/detection/eval.py

Co-authored-by: Benjamin Wilson <[email protected]>

* PR Refinements

* Minor Fix

* tidy code, fix forecasting evaluation

* fix imports, return tuned metric values from tracking evaluation

* fix linting

* fix typing

* fix ruff

* Test revert reformatting.

* Undo formatting.

* Undo additional formatting.

* Revert a few more formatting changes.

* Simplify expressions.

* Refactor eval.

* Fix imports.

* Fix lint.

* Fix lint.

* Consolidate constants.

* Clean up.

* Add unit test stubs.

* Update typing.

* Fix typing + fix lint in detection eval.

* Change detection args names.

* Fix typing.

* Fix lint.

* Fix mypy.

* Reduce number of conversions.

* Make lca columns a constant.

* fix group_frames bug

* Add Units

* Minor Fix

* Fix duplicate imports.

* Remove additional duplicate lines.

* Add files via upload

* Remove PYC

* Remove __pycache__

* Fix mypy

* Run black.

---------

Co-authored-by: Neehar Peri <[email protected]>
Co-authored-by: Neehar Peri <[email protected]>
Co-authored-by: Neehar Peri <[email protected]>
Co-authored-by: Benjamin Wilson <[email protected]>
Co-authored-by: Redrew <[email protected]>
Co-authored-by: Neehar Peri <[email protected]>
Co-authored-by: Neehar Peri <[email protected]>
Co-authored-by: Neehar Peri <[email protected]>
9 people authored Apr 25, 2023
1 parent 45d7026 commit 98eb457
Showing 11 changed files with 102 additions and 273 deletions.
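At a glance, the diffs below attach explicit unit suffixes to physical quantities (metres, metres per second, nanoseconds). A minimal sketch of that convention, using an illustrative mapping and helper name that are not part of the repository:

```python
# Minimal sketch of the unit-suffix convention this commit applies; the mapping and
# the helper name are illustrative, not code from the repository.
UNIT_SUFFIX_RENAMES = {
    "translation": "translation_m",
    "current_translation": "current_translation_m",
    "future_translation": "future_translation_m",
    "ego_translation": "ego_translation_m",
    "velocity": "velocity_m_per_s",
    "timestamp": "timestamp_ns",
}

def add_unit_suffixes(agent: dict) -> dict:
    """Return a copy of `agent` whose keys carry explicit unit suffixes."""
    return {UNIT_SUFFIX_RENAMES.get(key, key): value for key, value in agent.items()}

print(add_unit_suffixes({"translation": [1.0, 2.0], "timestamp": 315967376899927216}))
# {'translation_m': [1.0, 2.0], 'timestamp_ns': 315967376899927216}
```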
1 change: 1 addition & 0 deletions .gitignore
@@ -25,6 +25,7 @@ share/python-wheels/
.installed.cfg
*.egg
MANIFEST
+*.pyc

# PyInstaller
# Usually these files are written by a python script from a template
16 changes: 9 additions & 7 deletions src/av2/evaluation/detection/eval.py
@@ -58,6 +58,7 @@

import numpy as np
import pandas as pd
+from tqdm import tqdm
from av2.evaluation.detection.constants import (
HIERARCHY,
LCA,
@@ -442,14 +443,14 @@ def evaluate_hierarchy(
gts_categories_list.append(sweep_gts_categories)

num_dts = len(sweep_dts)
-num_gts = len(sweep_dts)
+num_gts = len(sweep_gts)
dts_uuids_list.extend(num_dts * [uuid])
gts_uuids_list.extend(num_gts * [uuid])

-dts_npy = np.concatenate(dts).astype(np.float64)
-gts_npy = np.concatenate(gts).astype(np.float64)
-dts_categories_npy = np.concatenate(dts_categories).astype(np.object_)
-gts_categories_npy = np.concatenate(gts_categories).astype(np.object_)
+dts_npy = np.concatenate(dts_list).astype(np.float64)
+gts_npy = np.concatenate(gts_list).astype(np.float64)
+dts_categories_npy = np.concatenate(dts_categories_list).astype(np.object_)
+gts_categories_npy = np.concatenate(gts_categories_list).astype(np.object_)
dts_uuids_npy = np.array(dts_uuids_list)
gts_uuids_npy = np.array(gts_uuids_list)
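The hunk above carries two fixes: ground-truth counts now come from sweep_gts instead of sweep_dts, and the final arrays are concatenated from the accumulated per-sweep lists rather than the raw inputs. A self-contained toy sketch of that accumulation pattern (the sweep data here are placeholders, not repository code):

```python
import numpy as np

# Toy stand-ins for per-sweep detections and ground-truth boxes.
sweeps = [(np.zeros((2, 3)), np.zeros((1, 3))), (np.ones((4, 3)), np.ones((5, 3)))]

dts_list, gts_list, dts_uuids_list, gts_uuids_list = [], [], [], []
for uuid, (sweep_dts, sweep_gts) in enumerate(sweeps):
    dts_list.append(sweep_dts)
    gts_list.append(sweep_gts)
    num_dts = len(sweep_dts)
    num_gts = len(sweep_gts)  # the fix: count the ground-truth sweep, not the detections
    dts_uuids_list.extend(num_dts * [uuid])
    gts_uuids_list.extend(num_gts * [uuid])

# Concatenate the accumulated per-sweep lists (the fix), not the raw inputs.
dts_npy = np.concatenate(dts_list).astype(np.float64)
gts_npy = np.concatenate(gts_list).astype(np.float64)
print(dts_npy.shape, gts_npy.shape, len(dts_uuids_list), len(gts_uuids_list))  # (6, 3) (6, 3) 6 6
```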

@@ -487,8 +488,9 @@
)

logger.info("Starting evaluation ...")
-with mp.get_context("spawn").Pool(processes=n_jobs) as p:
-accumulate_outputs: Any = p.starmap(accumulate_hierarchy, accumulate_hierarchy_args_list)
+accumulate_outputs = []
+for accumulate_args in tqdm(accumulate_hierarchy_args_list):
+accumulate_outputs.append(accumulate_hierarchy(*accumulate_args))

super_categories = list(HIERARCHY.keys())
metrics = np.zeros((len(cfg.categories), len(HIERARCHY.keys())))
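The evaluation loop above drops the multiprocessing spawn pool in favour of a plain sequential loop with a tqdm progress bar. A minimal sketch of that pattern, with work and args_list standing in for accumulate_hierarchy and its argument tuples (both are placeholder names):

```python
from tqdm import tqdm

def work(x: int, y: int) -> int:
    # Placeholder for accumulate_hierarchy; returns a toy result.
    return x + y

args_list = [(1, 2), (3, 4), (5, 6)]
outputs = []
for args in tqdm(args_list):  # starmap-style call, one argument tuple at a time
    outputs.append(work(*args))
print(outputs)  # [3, 7, 11]
```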
14 changes: 9 additions & 5 deletions src/av2/evaluation/detection/utils.py
@@ -279,7 +279,7 @@ def accumulate_hierarchy(
fp: Dict[int, Any] = {}
gt_name: Dict[int, List[Any]] = {}
pred_name: Dict[int, List[Any]] = {}
-taken: Dict[int, Set[Tuple[Any, Any]]] = {}
+taken: Dict[int, Set[Tuple[Any, Any, Any]]] = {}
for i in range(len(cfg.affinity_thresholds_m)):
tp[i] = []
fp[i] = []
@@ -292,7 +292,11 @@
min_dist = len(cfg.affinity_thresholds_m) * [np.inf]
match_gt_idx = len(cfg.affinity_thresholds_m) * [None]

-keep_sweep = gts_uuids == np.array([gts.shape[0] * [pred_uuid]]).squeeze()
+if len(gts_uuids) > 0:
+keep_sweep = np.all(gts_uuids == np.array([gts.shape[0] * [pred_uuid]]).squeeze(), axis=1)
+else:
+keep_sweep = []
+
gt_ind_sweep = np.arange(gts.shape[0])[keep_sweep]
gts_sweep = gts[keep_sweep]
gts_cats_sweep = gts_cats[keep_sweep]
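The keep_sweep fix matters because each uuid is a composite (log id, timestamp) pair, so the comparison yields a 2-D boolean array that must be reduced row-wise; the empty-input branch avoids squeezing a zero-length array. A toy illustration with assumed uuid values:

```python
import numpy as np

# Each uuid is a (log_id, timestamp) pair, so equality must be reduced across columns.
gts_uuids = np.array([["log_a", "100"], ["log_b", "100"], ["log_a", "200"]])
pred_uuid = ("log_a", "100")
keep_sweep = np.all(gts_uuids == np.array([len(gts_uuids) * [pred_uuid]]).squeeze(), axis=1)
print(keep_sweep)  # [ True False False]
```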
@@ -303,7 +307,7 @@

# Find closest match among ground truth boxes
for i in range(len(cfg.affinity_thresholds_m)):
-if gt_cat == cat and not (pred_uuid, gt_idx) in taken[i]:
+if gt_cat == cat and not (pred_uuid[0], pred_uuid[1], gt_idx) in taken[i]:
this_distance = dist_mat[pred_idx][gt_idx]
if this_distance < min_dist[i]:
min_dist[i] = this_distance
@@ -316,7 +320,7 @@
# Find closest match among ground truth boxes

for i in range(len(cfg.affinity_thresholds_m)):
-if not is_match[i] and not (pred_uuid, gt_idx) in taken[i]:
+if not is_match[i] and not (pred_uuid[0], pred_uuid[1], gt_idx) in taken[i]:
this_distance = dist_mat[pred_idx][gt_idx]
if this_distance < min_dist[i]:
min_dist[i] = this_distance
@@ -330,7 +334,7 @@

for i in range(len(cfg.affinity_thresholds_m)):
if is_match[i]:
-taken[i].add((pred_uuid, gt_idx))
+taken[i].add((pred_uuid[0], pred_uuid[1], gt_idx))
tp[i].append(1)
fp[i].append(0)

94 changes: 0 additions & 94 deletions src/av2/evaluation/forecasting/SUBMISSION_FORMAT.md

This file was deleted.

76 changes: 38 additions & 38 deletions src/av2/evaluation/forecasting/eval.py
@@ -52,7 +52,7 @@ def evaluate(
ground_truth: ForecastSequences = convert_forecast_labels(raw_ground_truth)
ground_truth = filter_max_dist(ground_truth, max_range_m)

-utils.annotate_frame_metadata(predictions, ground_truth, ["ego_translation"])
+utils.annotate_frame_metadata(predictions, ground_truth, ["ego_translation_m"])
predictions = filter_max_dist(predictions, max_range_m)

if dataset_dir is not None:
@@ -77,20 +77,20 @@
pred = []

for agent in gt:
if agent["future_translation"].shape[0] < 1:
if agent["future_translation_m"].shape[0] < 1:
continue

agent["seq_id"] = seq_id
agent["timestamp"] = timestamp_ns
agent["velocity"] = utils.agent_velocity(agent)
agent["timestamp_ns"] = timestamp_ns
agent["velocity_m_per_s"] = utils.agent_velocity_m_per_s(agent)
agent["trajectory_type"] = utils.trajectory_type(agent, CATEGORY_TO_VELOCITY_M_PER_S)

gt_agents.append(agent)

for agent in pred:
agent["seq_id"] = seq_id
agent["timestamp"] = timestamp_ns
agent["velocity"] = utils.agent_velocity(agent)
agent["timestamp_ns"] = timestamp_ns
agent["velocity_m_per_s"] = utils.agent_velocity_m_per_s(agent)
agent["trajectory_type"] = utils.trajectory_type(agent, CATEGORY_TO_VELOCITY_M_PER_S)

pred_agents.append(agent)
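utils.agent_velocity_m_per_s itself is not shown in this diff; purely as an assumption for illustration, a finite-difference estimate over a fixed waypoint spacing (0.5 s assumed here, not taken from the repository) would look like:

```python
import numpy as np

# Assumed sketch only: estimate speed from the first future waypoint.
TIMESTEP_S = 0.5  # assumed spacing between forecast waypoints
agent = {
    "current_translation_m": np.array([0.0, 0.0]),
    "future_translation_m": np.array([[1.0, 0.5], [2.0, 1.0]]),
}
velocity_m_per_s = (agent["future_translation_m"][0] - agent["current_translation_m"]) / TIMESTEP_S
print(velocity_m_per_s)  # [2. 1.]
```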
@@ -138,8 +138,8 @@ def accumulate(
"""Perform matching between predicted and ground truth trajectories.
Args:
-pred_agents: List of predicted trajectories for a given log_id and timestamp.
-gt_agents: List of ground truth trajectories for a given log_id and timestamp.
+pred_agents: List of predicted trajectories for a given log_id and timestamp_ns.
+gt_agents: List of ground truth trajectories for a given log_id and timestamp_ns.
top_k: Number of future trajectories to consider when evaluating Forecasting AP, ADE and FDE (K=5 by default).
class_name: Match class name (e.g. car, pedestrian, bicycle) to determine if a trajectory is included
in evaluation.
@@ -165,7 +165,7 @@ def match(gt: str, pred: str, profile: str) -> bool:
sortind = [i for (v, i) in sorted((v, i) for (i, v) in enumerate(conf))][::-1]
gt_agents_by_frame = defaultdict(list)
for agent in gt:
-gt_agents_by_frame[(agent["seq_id"], agent["timestamp"])].append(agent)
+gt_agents_by_frame[(agent["seq_id"], agent["timestamp_ns"])].append(agent)

npos = len(gt)
# ---------------------------------------------
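Grouping ground-truth agents by their (seq_id, timestamp_ns) frame key keeps each prediction's search confined to its own frame. A self-contained toy example of that grouping:

```python
from collections import defaultdict

# Ground-truth agents grouped by their (seq_id, timestamp_ns) frame key.
gt = [
    {"seq_id": "log_a", "timestamp_ns": 100, "name": "REGULAR_VEHICLE"},
    {"seq_id": "log_a", "timestamp_ns": 100, "name": "PEDESTRIAN"},
    {"seq_id": "log_a", "timestamp_ns": 200, "name": "PEDESTRIAN"},
]
gt_agents_by_frame = defaultdict(list)
for agent in gt:
    gt_agents_by_frame[(agent["seq_id"], agent["timestamp_ns"])].append(agent)
print(len(gt_agents_by_frame[("log_a", 100)]))  # 2
```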
@@ -179,12 +179,12 @@
min_dist = np.inf
match_gt_idx = None

-gt_agents_in_frame = gt_agents_by_frame[(pred_agent["seq_id"], pred_agent["timestamp"])]
+gt_agents_in_frame = gt_agents_by_frame[(pred_agent["seq_id"], pred_agent["timestamp_ns"])]
for gt_idx, gt_agent in enumerate(gt_agents_in_frame):
-if not (pred_agent["seq_id"], pred_agent["timestamp"], gt_idx) in taken:
+if not (pred_agent["seq_id"], pred_agent["timestamp_ns"], gt_idx) in taken:
# Find closest match among ground truth boxes
this_distance = utils.center_distance(
gt_agent["current_translation"], pred_agent["current_translation"]
gt_agent["current_translation_m"], pred_agent["current_translation_m"]
)
if this_distance < min_dist:
min_dist = this_distance
@@ -194,18 +194,18 @@
is_match = min_dist < threshold

if is_match and match_gt_idx is not None:
-taken.add((pred_agent["seq_id"], pred_agent["timestamp"], match_gt_idx))
+taken.add((pred_agent["seq_id"], pred_agent["timestamp_ns"], match_gt_idx))
gt_match_agent = gt_agents_in_frame[match_gt_idx]

-gt_len = gt_match_agent["future_translation"].shape[0]
+gt_len = gt_match_agent["future_translation_m"].shape[0]
forecast_match_th = [threshold + constants.FORECAST_SCALAR[i] * velocity for i in range(gt_len + 1)]

if top_k == 1:
ind = cast(int, np.argmax(pred_agent["score"]))
forecast_dist = [
utils.center_distance(
gt_match_agent["future_translation"][i],
pred_agent["prediction"][ind][i],
gt_match_agent["future_translation_m"][i],
pred_agent["prediction_m"][ind][i],
)
for i in range(gt_len)
]
@@ -221,8 +221,8 @@
for ind in range(top_k):
curr_forecast_dist = [
utils.center_distance(
gt_match_agent["future_translation"][i],
pred_agent["prediction"][ind][i],
gt_match_agent["future_translation_m"][i],
pred_agent["prediction_m"][ind][i],
)
for i in range(gt_len)
]
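The per-timestep distance lists assembled above feed the displacement metrics. As a hedged illustration only (the actual reduction lies outside this hunk), ADE averages the per-step distance and FDE takes the final step:

```python
import numpy as np

# Illustrative reduction only; the repository's exact computation is outside this hunk.
curr_forecast_dist = [0.4, 0.9, 1.6]  # metres at each future timestep
curr_ade_m = float(np.mean(curr_forecast_dist))  # average displacement error
curr_fde_m = curr_forecast_dist[-1]  # final displacement error
print(curr_ade_m, curr_fde_m)  # 0.9666666666666667 1.6
```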
@@ -303,27 +303,27 @@ def convert_forecast_labels(labels: Any) -> Any:
frame_dict = {}
for frame_idx, frame in enumerate(frames):
forecast_instances = []
-for instance in utils.array_dict_iterator(frame, len(frame["translation"])):
+for instance in utils.array_dict_iterator(frame, len(frame["translation_m"])):
future_translations: Any = []
for future_frame in frames[frame_idx + 1 : frame_idx + 1 + constants.NUM_TIMESTEPS]:
if instance["track_id"] not in future_frame["track_id"]:
break
future_translations.append(
future_frame["translation"][future_frame["track_id"] == instance["track_id"]][0]
future_frame["translation_m"][future_frame["track_id"] == instance["track_id"]][0]
)

if len(future_translations) == 0:
continue

forecast_instances.append(
{
"current_translation": instance["translation"][:2],
"ego_translation": instance["ego_translation"][:2],
"future_translation": np.array(future_translations)[:, :2],
"current_translation_m": instance["translation_m"][:2],
"ego_translation_m": instance["ego_translation_m"][:2],
"future_translation_m": np.array(future_translations)[:, :2],
"name": instance["name"],
"size": instance["size"],
"yaw": instance["yaw"],
"velocity": instance["velocity"][:2],
"velocity_m_per_s": instance["velocity_m_per_s"][:2],
"label": instance["label"],
}
)
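The loop above walks forward from each instance, collecting its translation from every subsequent frame until its track_id disappears. A toy, self-contained version of that assembly (the frame contents are made up):

```python
import numpy as np

# Collect an instance's translation from each subsequent frame until its track_id vanishes.
frames = [
    {"track_id": np.array([1, 2]), "translation_m": np.array([[0.0, 0.0], [5.0, 5.0]])},
    {"track_id": np.array([1]), "translation_m": np.array([[1.0, 0.0]])},
    {"track_id": np.array([1]), "translation_m": np.array([[2.0, 0.0]])},
]
track_id = 1
future_translations = []
for future_frame in frames[1:]:
    if track_id not in future_frame["track_id"]:
        break
    future_translations.append(future_frame["translation_m"][future_frame["track_id"] == track_id][0])
print(np.array(future_translations))  # [[1. 0.] [2. 0.]]
```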
@@ -346,14 +346,14 @@ def filter_max_dist(forecasts: ForecastSequences, max_range_m: int) -> ForecastS
Dictionary of tracks.
"""
for seq_id in forecasts.keys():
-for timestamp in forecasts[seq_id].keys():
+for timestamp_ns in forecasts[seq_id].keys():
keep_forecasts = [
agent
-for agent in forecasts[seq_id][timestamp]
-if "ego_translation" in agent
-and np.linalg.norm(agent["current_translation"] - agent["ego_translation"]) < max_range_m
+for agent in forecasts[seq_id][timestamp_ns]
+if "ego_translation_m" in agent
+and np.linalg.norm(agent["current_translation_m"] - agent["ego_translation_m"]) < max_range_m
]
-forecasts[seq_id][timestamp] = keep_forecasts
+forecasts[seq_id][timestamp_ns] = keep_forecasts

return forecasts
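The range gate above keeps an agent only if its current position lies within max_range_m of the ego position recorded for that frame. A minimal sketch with toy coordinates:

```python
import numpy as np

# Keep agents whose current position is within max_range_m of the ego position (toy values).
max_range_m = 50.0
agent = {"current_translation_m": np.array([30.0, 10.0]), "ego_translation_m": np.array([0.0, 0.0])}
keep = "ego_translation_m" in agent and np.linalg.norm(
    agent["current_translation_m"] - agent["ego_translation_m"]
) < max_range_m
print(keep)  # True
```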

@@ -386,23 +386,23 @@ def filter_drivable_area(forecasts: ForecastSequences, dataset_dir: str) -> Fore
for log_id in log_ids:
avm = log_id_to_avm[log_id]

-for timestamp in forecasts[log_id]:
-city_SE3_ego = log_id_to_timestamped_poses[log_id][int(timestamp)]
+for timestamp_ns in forecasts[log_id]:
+city_SE3_ego = log_id_to_timestamped_poses[log_id][int(timestamp_ns)]

-translation, size, quat = [], [], []
+translation_m, size, quat = [], [], []

-if len(forecasts[log_id][timestamp]) == 0:
+if len(forecasts[log_id][timestamp_ns]) == 0:
continue

-for box in forecasts[log_id][timestamp]:
-translation.append(box["current_translation"] - box["ego_translation"])
+for box in forecasts[log_id][timestamp_ns]:
+translation_m.append(box["current_translation_m"] - box["ego_translation_m"])
size.append(box["size"])
quat.append(yaw_to_quaternion3d(box["yaw"]))

-score = np.ones((len(translation), 1))
+score = np.ones((len(translation_m), 1))
boxes = np.concatenate(
[
-np.array(translation),
+np.array(translation_m),
np.array(size),
np.array(quat),
np.array(score),
@@ -411,7 +411,7 @@ def filter_drivable_area(forecasts: ForecastSequences, dataset_dir: str) -> Fore
)

is_evaluated = compute_objects_in_roi_mask(boxes, city_SE3_ego, avm)
-forecasts[log_id][timestamp] = list(np.array(forecasts[log_id][timestamp])[is_evaluated])
+forecasts[log_id][timestamp_ns] = list(np.array(forecasts[log_id][timestamp_ns])[is_evaluated])

return forecasts
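The box array handed to compute_objects_in_roi_mask is built by column-wise concatenation of translation, size, quaternion, and score. A toy shape check (the column counts here are placeholder assumptions, not the library's required layout):

```python
import numpy as np

# One row per box, columns = translation + size + quaternion + score (placeholder shapes).
translation_m = [np.zeros(3), np.zeros(3)]
size = [np.ones(3), np.ones(3)]
quat = [np.array([1.0, 0.0, 0.0, 0.0]), np.array([1.0, 0.0, 0.0, 0.0])]
score = np.ones((len(translation_m), 1))
boxes = np.concatenate([np.array(translation_m), np.array(size), np.array(quat), score], axis=-1)
print(boxes.shape)  # (2, 11)
```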

