Skip to content

Commit

Permalink
support for logging tracker data
Browse files Browse the repository at this point in the history
  • Loading branch information
Kovelja009 committed Apr 7, 2024
1 parent 1a7bfc1 commit 24cb5f9
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
40 changes: 31 additions & 9 deletions rivian.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import cv2
import torch
import numpy as np
import time
from inference import run
from models.common import DetectMultiBackend
from sort import Sort
from video_utils import save_video_tracking_data
from ultralytics.utils.plotting import Annotator, colors

def run_object_detection():
Expand All @@ -18,11 +20,21 @@ def run_object_detection():
# camera setup
oak_d = oak.OAK_D(fps=60, width=1920, height=1080)

############################################################
# SORT tracker
############################################################
######################## SORT tracker ######################
# Initialize SORT tracker
mot_tracker = Sort(min_hits=5, max_age=20)

############################################################
##################### Logging data #########################
# Initialize variables
ids_list = [] # List to keep track of unique IDs
frame_count = 0 # Counter to keep track of frame number
wave = 1 # indicator whether video should record wave or not
file_name = 'tracking_data.csv' # File name to save tracking data
video_name = 'video1.mp4' # Video name to save tracking data
elapsed_time = -1 # Time in seconds
should_save = False
# Initialize dictionary to store tracking information
tracks = {
"frame": [],
Expand All @@ -31,18 +43,19 @@ def run_object_detection():
"y1": [],
"x2": [],
"y2": [],
"xc": [],
"yc": [],
"class": [],
"wave": [],
"video": []
}

# Initialize variables
ids_list = [] # List to keep track of unique IDs
frame_count = 0 # Counter to keep track of frame number
############################################################
start_time = time.time()

while True:
frame = oak_d.get_color_frame(show_fps=True)
# Object detection
img, annotated_img, bbox_coord_conf_cls = run(frame=frame, classes=[1], model=model)
img, annotated_img, bbox_coord_conf_cls = run(frame=frame, classes=[0,1,2], model=model)
annotator = Annotator(img, line_width=3, example=str(model.names))

# Update tracker
Expand Down Expand Up @@ -77,14 +90,23 @@ def run_object_detection():
tracks['y1'].append(y1)
tracks['x2'].append(x2)
tracks['y2'].append(y2)
tracks['xc'].append(int((x1 + x2) / 2))
tracks['yc'].append(int((y1 + y2) / 2))
tracks['class'].append(names[i])
tracks['wave'].append(wave)
tracks['video'].append(video_name)

cv2.imshow("Levi", img)
frame_count += 1
cv2.imshow("Levi", img)
current_time = time.time()

# Break the loop if 'q' key is pressed
if cv2.waitKey(1) & 0xFF == ord('q'):
if cv2.waitKey(1) & 0xFF == ord('q') or (current_time - start_time > elapsed_time and elapsed_time > 0):
cv2.destroyAllWindows()
if should_save:
# Save tracking data to CSV file
save_video_tracking_data(tracks, file_name)
print("Data appended to CSV file successfully!")
break

if __name__ == '__main__':
Expand Down
19 changes: 19 additions & 0 deletions video_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import csv

def save_video_tracking_data(tracks, file_name):
"""Save tracking data to a CSV file."""
fieldnames = tracks.keys()

# Write the dictionary to the CSV file in append mode
with open(file_name, 'a', newline='') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

# Check if the file is empty (i.e., if it needs a header row)
# If the file is empty, write the header row
if csvfile.tell() == 0:
writer.writeheader()

# Write each row of the dictionary to the CSV file
for i in range(len(tracks['frame'])):
row = {field: tracks[field][i] for field in fieldnames}
writer.writerow(row)

0 comments on commit 24cb5f9

Please sign in to comment.