Skip to content

Commit

Permalink
Fix tracking metrics build
Browse files Browse the repository at this point in the history
  • Loading branch information
Pei Sun committed Dec 22, 2020
1 parent 1fb8f2c commit 71faf43
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 36 deletions.
30 changes: 3 additions & 27 deletions waymo_open_dataset/metrics/ops/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,40 +28,15 @@ py_library(
name = "py_metrics_ops",
srcs = ["py_metrics_ops.py"],
data = [
":detection_metrics_ops.so",
":tracking_metrics_ops.so",
":metrics_ops.so",
],
)

cc_binary(
name = "detection_metrics_ops.so",
name = "metrics_ops.so",
srcs = [
"detection_metrics_ops.cc",
"metrics_ops.cc",
],
copts = [
"-pthread",
],
linkshared = 1,
deps = [
":utils",
"//waymo_open_dataset:label_cc_proto",
"//waymo_open_dataset/metrics:config_util",
"//waymo_open_dataset/metrics:detection_metrics",
"//waymo_open_dataset/protos:breakdown_cc_proto",
"//waymo_open_dataset/protos:metrics_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@local_config_tf//:libtensorflow_framework",
"@local_config_tf//:tf_header_lib",
],
)

cc_binary(
name = "tracking_metrics_ops.so",
srcs = [
"metrics_ops.cc",
"tracking_metrics_ops.cc",
],
copts = [
Expand All @@ -72,6 +47,7 @@ cc_binary(
":utils",
"//waymo_open_dataset:label_cc_proto",
"//waymo_open_dataset/metrics:config_util",
"//waymo_open_dataset/metrics:detection_metrics",
"//waymo_open_dataset/metrics:tracking_metrics",
"//waymo_open_dataset/protos:breakdown_cc_proto",
"//waymo_open_dataset/protos:metrics_cc_proto",
Expand Down
13 changes: 4 additions & 9 deletions waymo_open_dataset/metrics/ops/py_metrics_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,8 @@

import tensorflow as tf

detection_metrics_module = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile(
'detection_metrics_ops.so'))

tracking_metrics_module = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile(
'tracking_metrics_ops.so'))
metrics_module = tf.load_op_library(
tf.compat.v1.resource_loader.get_path_to_datafile('metrics_ops.so'))


def detection_metrics(prediction_bbox,
Expand All @@ -45,7 +40,7 @@ def detection_metrics(prediction_bbox,
num_gt_boxes = tf.shape(ground_truth_bbox)[0]
ground_truth_speed = tf.zeros((num_gt_boxes, 2), dtype=tf.float32)

return detection_metrics_module.detection_metrics(
return metrics_module.detection_metrics(
prediction_bbox=prediction_bbox,
prediction_type=prediction_type,
prediction_score=prediction_score,
Expand Down Expand Up @@ -82,7 +77,7 @@ def tracking_metrics(prediction_bbox,
if prediction_overlap_nlz is None:
prediction_overlap_nlz = tf.zeros_like(prediction_frame_id, dtype=tf.bool)

return tracking_metrics_module.tracking_metrics(
return metrics_module.tracking_metrics(
prediction_bbox=prediction_bbox,
prediction_type=prediction_type,
prediction_score=prediction_score,
Expand Down

0 comments on commit 71faf43

Please sign in to comment.