From 1fb8f2c682026519700c63f4c8ddb86e38ea8bb6 Mon Sep 17 00:00:00 2001 From: Waymo Research Date: Mon, 21 Dec 2020 14:02:44 -0800 Subject: [PATCH] Merged commit includes the following changes: 348524296 by Waymo Research: Fix tracking metrics build -- PiperOrigin-RevId: 348524296 --- waymo_open_dataset/metrics/ops/BUILD | 1 + .../metrics/ops/py_metrics_ops.py | 61 +++++++++++++++++-- 2 files changed, 57 insertions(+), 5 deletions(-) diff --git a/waymo_open_dataset/metrics/ops/BUILD b/waymo_open_dataset/metrics/ops/BUILD index 3867fdb..6e76d89 100644 --- a/waymo_open_dataset/metrics/ops/BUILD +++ b/waymo_open_dataset/metrics/ops/BUILD @@ -17,6 +17,7 @@ cc_library( "//waymo_open_dataset/protos:metrics_cc_proto", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/types:optional", "@local_config_tf//:libtensorflow_framework", "@local_config_tf//:tf_header_lib", diff --git a/waymo_open_dataset/metrics/ops/py_metrics_ops.py b/waymo_open_dataset/metrics/ops/py_metrics_ops.py index 61982c5..d19c0b6 100644 --- a/waymo_open_dataset/metrics/ops/py_metrics_ops.py +++ b/waymo_open_dataset/metrics/ops/py_metrics_ops.py @@ -24,12 +24,22 @@ tf.compat.v1.resource_loader.get_path_to_datafile( 'detection_metrics_ops.so')) +tracking_metrics_module = tf.load_op_library( + tf.compat.v1.resource_loader.get_path_to_datafile( + 'tracking_metrics_ops.so')) + -def detection_metrics(prediction_bbox, prediction_type, prediction_score, - prediction_frame_id, prediction_overlap_nlz, - ground_truth_bbox, ground_truth_type, - ground_truth_frame_id, ground_truth_difficulty, - config, ground_truth_speed=None): +def detection_metrics(prediction_bbox, + prediction_type, + prediction_score, + prediction_frame_id, + prediction_overlap_nlz, + ground_truth_bbox, + ground_truth_type, + ground_truth_frame_id, + ground_truth_difficulty, + config, + ground_truth_speed=None): """Wraps detection_metrics. See metrics_ops.cc for full documentation.""" if ground_truth_speed is None: num_gt_boxes = tf.shape(ground_truth_bbox)[0] @@ -47,3 +57,44 @@ def detection_metrics(prediction_bbox, prediction_type, prediction_score, ground_truth_difficulty=ground_truth_difficulty, ground_truth_speed=ground_truth_speed, config=config) + + +def tracking_metrics(prediction_bbox, + prediction_type, + prediction_score, + prediction_frame_id, + prediction_sequence_id, + prediction_object_id, + ground_truth_bbox, + ground_truth_type, + ground_truth_frame_id, + ground_truth_sequence_id, + ground_truth_object_id, + ground_truth_difficulty, + config, + prediction_overlap_nlz=None, + ground_truth_speed=None): + """Wraps tracking_metrics. See metrics_ops.cc for full documentation.""" + if ground_truth_speed is None: + num_gt_boxes = tf.shape(ground_truth_bbox)[0] + ground_truth_speed = tf.zeros((num_gt_boxes, 2), dtype=tf.float32) + + if prediction_overlap_nlz is None: + prediction_overlap_nlz = tf.zeros_like(prediction_frame_id, dtype=tf.bool) + + return tracking_metrics_module.tracking_metrics( + prediction_bbox=prediction_bbox, + prediction_type=prediction_type, + prediction_score=prediction_score, + prediction_frame_id=prediction_frame_id, + prediction_sequence_id=prediction_sequence_id, + prediction_object_id=prediction_object_id, + prediction_overlap_nlz=prediction_overlap_nlz, + ground_truth_bbox=ground_truth_bbox, + ground_truth_type=ground_truth_type, + ground_truth_frame_id=ground_truth_frame_id, + ground_truth_sequence_id=ground_truth_sequence_id, + ground_truth_object_id=ground_truth_object_id, + ground_truth_difficulty=ground_truth_difficulty, + ground_truth_speed=ground_truth_speed, + config=config)