Commit

Merge pull request google-ai-edge#171 from googlesamples/mediapipe-samples-py-livestream

Added Object Detector Live Stream sample and updated existing sample
PaulTR authored Jun 19, 2023
2 parents 845d432 + 3553b8c commit c48f257
Showing 6 changed files with 177 additions and 138 deletions.

This file was deleted.

@@ -15,9 +15,9 @@

 import argparse
 import time
-import audio_record

 from mediapipe.tasks import python
+from mediapipe.tasks.python.audio.core import audio_record
 from mediapipe.tasks.python.components import containers
 from mediapipe.tasks.python import audio
 from utils import Plotter
@@ -93,7 +93,6 @@ def save_result(result: audio.AudioClassifierResult, timestamp_ms: int):
     audio_data.load_from_array(data)
     classifier.classify_async(audio_data, round(last_inference_time * 1000))

-    # print(classification_result_list)
     # # Plot the classification results.
     if classification_result_list:
       print(classification_result_list)
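For context, classify_async only works when the classifier was created in stream mode with a result callback. A minimal sketch of that setup, using the relocated audio_record import from this diff; the model path and the save_result body are assumptions, not part of this commit:

# Stream-mode classifier setup sketch (model path and callback body assumed).
from mediapipe.tasks import python
from mediapipe.tasks.python import audio

classification_result_list = []

def save_result(result: audio.AudioClassifierResult, timestamp_ms: int):
  # Collect results as they arrive; the capture loop prints and clears them.
  classification_result_list.append(result)

options = audio.AudioClassifierOptions(
    base_options=python.BaseOptions(model_asset_path='classifier.tflite'),
    running_mode=audio.RunningMode.AUDIO_STREAM,
    result_callback=save_result)
classifier = audio.AudioClassifier.create_from_options(options)
# classify_async() can now be fed AudioData chunks, as in the hunk above.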
requirements.txt (1 deletion)
@@ -1,2 +1 @@
-sounddevice
 mediapipe
examples/object_detection/python/object_detector_live_stream/detect.py (135 additions, 0 deletions)
@@ -0,0 +1,135 @@
import argparse
import sys
import time

import cv2
import mediapipe as mp

from mediapipe.tasks import python
from mediapipe.tasks.python import vision

from utils import visualize


def run(model: str, camera_id: int, width: int, height: int) -> None:
  """Continuously run inference on images acquired from the camera.

  Args:
    model: Name of the TFLite object detection model.
    camera_id: The camera id to be passed to OpenCV.
    width: The width of the frame captured from the camera.
    height: The height of the frame captured from the camera.
  """

  # Variables to calculate FPS
  counter, fps = 0, 0
  start_time = time.time()

  # Start capturing video input from the camera
  cap = cv2.VideoCapture(camera_id)
  cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
  cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

  # Visualization parameters
  row_size = 20  # pixels
  left_margin = 24  # pixels
  text_color = (0, 0, 255)  # red
  font_size = 1
  font_thickness = 1
  fps_avg_frame_count = 10

  detection_result_list = []

  def visualize_callback(result: vision.ObjectDetectorResult,
                         output_image: mp.Image, timestamp_ms: int):
    result.timestamp_ms = timestamp_ms
    detection_result_list.append(result)

  # Initialize the object detection model
  base_options = python.BaseOptions(model_asset_path=model)
  options = vision.ObjectDetectorOptions(base_options=base_options,
                                         running_mode=vision.RunningMode.LIVE_STREAM,
                                         score_threshold=0.5,
                                         result_callback=visualize_callback)
  detector = vision.ObjectDetector.create_from_options(options)

  # Continuously capture images from the camera and run inference
  while cap.isOpened():
    success, image = cap.read()
    if not success:
      sys.exit(
          'ERROR: Unable to read from webcam. Please verify your webcam settings.'
      )

    counter += 1
    image = cv2.flip(image, 1)

    # Convert the image from BGR to RGB as required by the TFLite model.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_image)

    # Run object detection using the model.
    detector.detect_async(mp_image, counter)
    current_frame = mp_image.numpy_view()
    current_frame = cv2.cvtColor(current_frame, cv2.COLOR_RGB2BGR)

    # Calculate the FPS
    if counter % fps_avg_frame_count == 0:
      end_time = time.time()
      fps = fps_avg_frame_count / (end_time - start_time)
      start_time = time.time()

    # Show the FPS
    fps_text = 'FPS = {:.1f}'.format(fps)
    text_location = (left_margin, row_size)
    cv2.putText(current_frame, fps_text, text_location, cv2.FONT_HERSHEY_PLAIN,
                font_size, text_color, font_thickness)

    if detection_result_list:
      print(detection_result_list)
      vis_image = visualize(current_frame, detection_result_list[0])
      cv2.imshow('object_detector', vis_image)
      detection_result_list.clear()
    else:
      cv2.imshow('object_detector', current_frame)

    # Stop the program if the ESC key is pressed.
    if cv2.waitKey(1) == 27:
      break

  detector.close()
  cap.release()
  cv2.destroyAllWindows()


def main():
  parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
  parser.add_argument(
      '--model',
      help='Path of the object detection model.',
      required=False,
      default='efficientdet.tflite')
  parser.add_argument(
      '--cameraId', help='Id of camera.', required=False, type=int, default=0)
  parser.add_argument(
      '--frameWidth',
      help='Width of frame to capture from camera.',
      required=False,
      type=int,
      default=1280)
  parser.add_argument(
      '--frameHeight',
      help='Height of frame to capture from camera.',
      required=False,
      type=int,
      default=720)
  args = parser.parse_args()

  run(args.model, int(args.cameraId), args.frameWidth, args.frameHeight)


if __name__ == '__main__':
  main()
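For a quick check of the new sample, the argparse defaults above imply the following invocation; this is a sketch that assumes efficientdet.tflite sits in the working directory and camera 0 is available:

# Equivalent to running `python detect.py` with all default arguments.
from detect import run

run(model='efficientdet.tflite', camera_id=0, width=1280, height=720)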
requirements.txt (1 addition, new file)
@@ -0,0 +1 @@
mediapipe
utils.py (40 additions, new file)
@@ -0,0 +1,40 @@
import cv2
import numpy as np


MARGIN = 10 # pixels
ROW_SIZE = 10 # pixels
FONT_SIZE = 1
FONT_THICKNESS = 1
TEXT_COLOR = (255, 0, 0)  # red in RGB order (renders blue on the BGR frames detect.py passes in)


def visualize(
    image,
    detection_result
) -> np.ndarray:
  """Draws bounding boxes on the input image and returns it.

  Args:
    image: The input RGB image.
    detection_result: The list of all "Detection" entities to be visualized.

  Returns:
    Image with bounding boxes.
  """
  for detection in detection_result.detections:
    # Draw bounding_box
    bbox = detection.bounding_box
    start_point = bbox.origin_x, bbox.origin_y
    end_point = bbox.origin_x + bbox.width, bbox.origin_y + bbox.height
    cv2.rectangle(image, start_point, end_point, TEXT_COLOR, 3)

    # Draw label and score
    category = detection.categories[0]
    category_name = category.category_name
    probability = round(category.score, 2)
    result_text = category_name + ' (' + str(probability) + ')'
    text_location = (MARGIN + bbox.origin_x,
                     MARGIN + ROW_SIZE + bbox.origin_y)
    cv2.putText(image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,
                FONT_SIZE, TEXT_COLOR, FONT_THICKNESS)

  return image
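visualize() is not tied to the live-stream path; as a sketch, the same helper works with a detector created in the default IMAGE running mode. The file names below are assumptions, not part of this commit:

# Single-image use of visualize() (input/output file names assumed).
import cv2
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from utils import visualize

options = vision.ObjectDetectorOptions(
    base_options=python.BaseOptions(model_asset_path='efficientdet.tflite'),
    score_threshold=0.5)  # running_mode defaults to IMAGE
detector = vision.ObjectDetector.create_from_options(options)

frame = cv2.imread('input.jpg')  # BGR frame, matching what detect.py passes in
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB,
                    data=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
cv2.imwrite('output.jpg', visualize(frame, detector.detect(mp_image)))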
