Skip to content

Commit

Permalink
Merged commit includes the following changes:
Browse files Browse the repository at this point in the history
348524296  by Waymo Research:

    Fix tracking metrics build

--

PiperOrigin-RevId: 348524296
  • Loading branch information
Waymo Research authored and peisun1115 committed Dec 21, 2020
1 parent 7c503b8 commit 1fb8f2c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 5 deletions.
1 change: 1 addition & 0 deletions waymo_open_dataset/metrics/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ cc_library(
"//waymo_open_dataset/protos:metrics_cc_proto",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:optional",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
Expand Down
61 changes: 56 additions & 5 deletions waymo_open_dataset/metrics/ops/py_metrics_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,22 @@
tf.compat.v1.resource_loader.get_path_to_datafile(
'detection_metrics_ops.so'))

tracking_metrics_module = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile(
'tracking_metrics_ops.so'))


def detection_metrics(prediction_bbox, prediction_type, prediction_score,
prediction_frame_id, prediction_overlap_nlz,
ground_truth_bbox, ground_truth_type,
ground_truth_frame_id, ground_truth_difficulty,
config, ground_truth_speed=None):
def detection_metrics(prediction_bbox,
prediction_type,
prediction_score,
prediction_frame_id,
prediction_overlap_nlz,
ground_truth_bbox,
ground_truth_type,
ground_truth_frame_id,
ground_truth_difficulty,
config,
ground_truth_speed=None):
"""Wraps detection_metrics. See metrics_ops.cc for full documentation."""
if ground_truth_speed is None:
num_gt_boxes = tf.shape(ground_truth_bbox)[0]
Expand All @@ -47,3 +57,44 @@ def detection_metrics(prediction_bbox, prediction_type, prediction_score,
ground_truth_difficulty=ground_truth_difficulty,
ground_truth_speed=ground_truth_speed,
config=config)


def tracking_metrics(prediction_bbox,
prediction_type,
prediction_score,
prediction_frame_id,
prediction_sequence_id,
prediction_object_id,
ground_truth_bbox,
ground_truth_type,
ground_truth_frame_id,
ground_truth_sequence_id,
ground_truth_object_id,
ground_truth_difficulty,
config,
prediction_overlap_nlz=None,
ground_truth_speed=None):
"""Wraps tracking_metrics. See metrics_ops.cc for full documentation."""
if ground_truth_speed is None:
num_gt_boxes = tf.shape(ground_truth_bbox)[0]
ground_truth_speed = tf.zeros((num_gt_boxes, 2), dtype=tf.float32)

if prediction_overlap_nlz is None:
prediction_overlap_nlz = tf.zeros_like(prediction_frame_id, dtype=tf.bool)

return tracking_metrics_module.tracking_metrics(
prediction_bbox=prediction_bbox,
prediction_type=prediction_type,
prediction_score=prediction_score,
prediction_frame_id=prediction_frame_id,
prediction_sequence_id=prediction_sequence_id,
prediction_object_id=prediction_object_id,
prediction_overlap_nlz=prediction_overlap_nlz,
ground_truth_bbox=ground_truth_bbox,
ground_truth_type=ground_truth_type,
ground_truth_frame_id=ground_truth_frame_id,
ground_truth_sequence_id=ground_truth_sequence_id,
ground_truth_object_id=ground_truth_object_id,
ground_truth_difficulty=ground_truth_difficulty,
ground_truth_speed=ground_truth_speed,
config=config)

0 comments on commit 1fb8f2c

Please sign in to comment.