-
Notifications
You must be signed in to change notification settings - Fork 133
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Qdtrack #183
Qdtrack #183
Changes from all commits
57ff355
fd40877
c264290
e6edf81
1680e01
d4da9f5
5fc8be1
e7fe4e5
021e20d
c54a88d
5a2a550
14581b8
61d85da
f4ad20f
d28e735
25d4b7d
14d7d14
f226257
0ea61c7
a0b771c
ee9b532
a6cac05
503dc42
fa035c1
6d3f771
b235b95
7b9624f
4b2ef4d
cba1d1a
eb503be
9998488
c8d89e5
3dcfbfb
c971388
cb69924
fcde4aa
e0d12f3
6418774
66b361c
16b3e46
a6ba4f5
aaf9002
b4c756a
2c9bcbd
2b76f41
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -84,6 +84,15 @@ | |
['kitti_tracking', 'coco', 'mot', 'nuscenes'], | ||
'CenterTrack available models') | ||
|
||
# QDTrack tracking flags. | ||
flags.DEFINE_string( | ||
'qd_track_model_path', 'dependencies/models/tracking/qd_track/' + | ||
'qdtrack-frcnn_r50_fpn_12e_bdd100k-13328aed.pth', 'Path to the model') | ||
flags.DEFINE_string( | ||
'qd_track_config_path', | ||
'dependencies/qdtrack/configs/qdtrack-frcnn_r50_fpn_12e_bdd100k.py', | ||
'Path to the model') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess this is the Path to the model configuration. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yea, it's the qd track model path but it's currently the path to the bdd model, rather than the waymo model |
||
|
||
# Lane detection flags. | ||
flags.DEFINE_float('lane_detection_gpu_memory_fraction', 0.3, | ||
'GPU memory fraction allocated to Lanenet') | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import time | ||
|
||
import erdos | ||
|
||
from pylot.perception.detection.obstacle import Obstacle | ||
from pylot.perception.detection.utils import BoundingBox2D, \ | ||
OBSTACLE_LABELS | ||
from pylot.perception.messages import ObstaclesMessage | ||
|
||
|
||
class QdTrackOperator(erdos.Operator): | ||
def __init__(self, camera_stream, obstacle_tracking_stream, flags, | ||
camera_setup): | ||
from qdtrack.apis import init_model | ||
|
||
camera_stream.add_callback(self.on_frame_msg, | ||
[obstacle_tracking_stream]) | ||
self._flags = flags | ||
self._logger = erdos.utils.setup_logging(self.config.name, | ||
self.config.log_file_name) | ||
self._csv_logger = erdos.utils.setup_csv_logging( | ||
self.config.name + '-csv', self.config.csv_log_file_name) | ||
self._camera_setup = camera_setup | ||
self.model = init_model(self._flags.qd_track_config_path, | ||
checkpoint=self._flags.qd_track_model_path, | ||
device='cuda:0', | ||
cfg_options=None) | ||
self.classes = ('pedestrian', 'rider', 'car', 'bus', 'truck', | ||
'bicycle', 'motorcycle', 'train') | ||
self.frame_id = 0 | ||
|
||
@staticmethod | ||
def connect(camera_stream): | ||
obstacle_tracking_stream = erdos.WriteStream() | ||
return [obstacle_tracking_stream] | ||
|
||
def destroy(self): | ||
self._logger.warn('destroying {}'.format(self.config.name)) | ||
|
||
@erdos.profile_method() | ||
def on_frame_msg(self, msg, obstacle_tracking_stream): | ||
"""Invoked when a FrameMessage is received on the camera stream.""" | ||
from qdtrack.apis import inference_model | ||
|
||
self._logger.debug('@{}: {} received frame'.format( | ||
msg.timestamp, self.config.name)) | ||
assert msg.frame.encoding == 'BGR', 'Expects BGR frames' | ||
start_time = time.time() | ||
image_np = msg.frame.as_bgr_numpy_array() | ||
results = inference_model(self.model, image_np, self.frame_id) | ||
self.frame_id += 1 | ||
|
||
bbox_result, track_result = results.values() | ||
obstacles = [] | ||
for k, v in track_result.items(): | ||
track_id = k | ||
bbox = v['bbox'][None, :] | ||
score = bbox[4] | ||
label_id = v['label'] | ||
label = self.classes[label_id] | ||
if label in ['pedestrian', 'rider']: | ||
label = 'person' | ||
if label in OBSTACLE_LABELS: | ||
bounding_box_2D = BoundingBox2D(bbox[0], bbox[2], bbox[1], | ||
bbox[3]) | ||
obstacles.append( | ||
Obstacle(bounding_box_2D, | ||
score, | ||
label, | ||
track_id, | ||
bounding_box_2D=bounding_box_2D)) | ||
runtime = (time.time() - start_time) * 1000 | ||
obstacle_tracking_stream.send( | ||
ObstaclesMessage(msg.timestamp, obstacles, runtime)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How reliable is mmdetection's installation? Did it work out of the box for you?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unfortunately it is not reliable and installation depends on torch and cuda version