forked from google-ai-edge/mediapipe
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request google-ai-edge#171 from googlesamples/mediapipe-samples-py-livestream
Added Object Detector Live Stream sample and updated existing sample
- Loading branch information
Showing
6 changed files
with
177 additions
and
138 deletions.
There are no files selected for viewing
135 changes: 0 additions & 135 deletions
135
examples/audio_classifier/python/audio_classification_live_stream/audio_record.py
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 0 additions & 1 deletion
1
examples/audio_classifier/python/audio_classification_live_stream/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1 @@ | ||
sounddevice | ||
mediapipe |
135 changes: 135 additions & 0 deletions
135
examples/object_detection/python/object_detector_live_stream/detect.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
import argparse | ||
import sys | ||
import time | ||
|
||
import cv2 | ||
import mediapipe as mp | ||
|
||
from mediapipe.tasks import python | ||
from mediapipe.tasks.python import vision | ||
|
||
from utils import visualize | ||
|
||
|
||
def run(model: str, camera_id: int, width: int, height: int) -> None:
  """Continuously run inference on images acquired from the camera.

  Args:
    model: Name of the TFLite object detection model.
    camera_id: The camera id to be passed to OpenCV.
    width: The width of the frame captured from the camera.
    height: The height of the frame captured from the camera.
  """

  # Variables to calculate FPS
  counter, fps = 0, 0
  start_time = time.time()

  # Start capturing video input from the camera
  cap = cv2.VideoCapture(camera_id)
  cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
  cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

  # Visualization parameters
  row_size = 20  # pixels
  left_margin = 24  # pixels
  text_color = (0, 0, 255)  # red, in BGR order (OpenCV frames are BGR)
  font_size = 1
  font_thickness = 1
  fps_avg_frame_count = 10

  # In LIVE_STREAM mode results arrive asynchronously via the callback below;
  # the main loop drains this list on each iteration.
  detection_result_list = []

  def visualize_callback(result: vision.ObjectDetectorResult,
                         output_image: mp.Image, timestamp_ms: int):
    # Stash the timestamp on the result so it can be correlated with the
    # frame that produced it.
    result.timestamp_ms = timestamp_ms
    detection_result_list.append(result)

  # Initialize the object detection model
  base_options = python.BaseOptions(model_asset_path=model)
  options = vision.ObjectDetectorOptions(
      base_options=base_options,
      running_mode=vision.RunningMode.LIVE_STREAM,
      score_threshold=0.5,
      result_callback=visualize_callback)
  detector = vision.ObjectDetector.create_from_options(options)

  # try/finally guarantees the detector, camera, and windows are released
  # even when the loop exits via sys.exit() on a failed read.
  try:
    # Continuously capture images from the camera and run inference
    while cap.isOpened():
      success, image = cap.read()
      if not success:
        sys.exit(
            'ERROR: Unable to read from webcam. Please verify your webcam settings.'
        )

      counter += 1
      image = cv2.flip(image, 1)

      # Convert the image from BGR to RGB as required by the TFLite model.
      rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
      mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb_image)

      # Run object detection using the model. detect_async expects a
      # monotonically increasing timestamp in *milliseconds*; the original
      # code passed the raw frame counter, whose units are frames, not ms.
      detector.detect_async(mp_image, time.time_ns() // 1_000_000)
      current_frame = mp_image.numpy_view()
      current_frame = cv2.cvtColor(current_frame, cv2.COLOR_RGB2BGR)

      # Calculate the FPS, averaged over the last fps_avg_frame_count frames.
      if counter % fps_avg_frame_count == 0:
        end_time = time.time()
        fps = fps_avg_frame_count / (end_time - start_time)
        start_time = time.time()

      # Show the FPS
      fps_text = 'FPS = {:.1f}'.format(fps)
      text_location = (left_margin, row_size)
      cv2.putText(current_frame, fps_text, text_location,
                  cv2.FONT_HERSHEY_PLAIN, font_size, text_color,
                  font_thickness)

      # Draw the most recent detection result, if one has arrived.
      # (The original left a debug print of the whole result list here.)
      if detection_result_list:
        vis_image = visualize(current_frame, detection_result_list[0])
        cv2.imshow('object_detector', vis_image)
        detection_result_list.clear()
      else:
        cv2.imshow('object_detector', current_frame)

      # Stop the program if the ESC key is pressed.
      if cv2.waitKey(1) == 27:
        break
  finally:
    detector.close()
    cap.release()
    cv2.destroyAllWindows()
|
||
|
||
def main():
  """Parse command-line flags and launch the live detection loop."""
  arg_parser = argparse.ArgumentParser(
      formatter_class=argparse.ArgumentDefaultsHelpFormatter)

  # Flag table: (flag name, add_argument keyword arguments), registered in
  # the order they should appear in --help.
  flag_specs = (
      ('--model',
       dict(help='Path of the object detection model.',
            required=False,
            default='efficientdet.tflite')),
      ('--cameraId',
       dict(help='Id of camera.', required=False, type=int, default=0)),
      ('--frameWidth',
       dict(help='Width of frame to capture from camera.',
            required=False,
            type=int,
            default=1280)),
      ('--frameHeight',
       dict(help='Height of frame to capture from camera.',
            required=False,
            type=int,
            default=720)),
  )
  for flag_name, flag_kwargs in flag_specs:
    arg_parser.add_argument(flag_name, **flag_kwargs)

  parsed = arg_parser.parse_args()
  run(parsed.model, int(parsed.cameraId), parsed.frameWidth,
      parsed.frameHeight)
|
||
|
||
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
  main()
1 change: 1 addition & 0 deletions
1
examples/object_detection/python/object_detector_live_stream/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
mediapipe |
40 changes: 40 additions & 0 deletions
40
examples/object_detection/python/object_detector_live_stream/utils.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import cv2 | ||
import numpy as np | ||
|
||
|
||
# Drawing parameters shared by visualize().
MARGIN = 10  # pixels between the bbox origin and the label text
ROW_SIZE = 10  # pixels of vertical offset for the label row
FONT_SIZE = 1
FONT_THICKNESS = 1
# (255, 0, 0) is red in RGB order, but renders *blue* on a BGR (OpenCV)
# frame. NOTE(review): the original comment said "red" — confirm the channel
# order of the images callers pass in.
TEXT_COLOR = (255, 0, 0)
|
||
|
||
def visualize(
    image,
    detection_result
) -> np.ndarray:
  """Draws bounding boxes and labels on the input image and returns it.

  The image is modified in place and also returned for convenience.

  Args:
    image: The input image as a numpy array. NOTE(review): the drawing color
      TEXT_COLOR assumes a specific channel order — confirm whether callers
      pass RGB or BGR frames.
    detection_result: The detection result whose "Detection" entities are
      visualized.

  Returns:
    Image with bounding boxes (and, where available, labels) drawn.
  """
  for detection in detection_result.detections:
    # Draw bounding_box. Cast to int: OpenCV rejects float coordinates.
    bbox = detection.bounding_box
    start_point = int(bbox.origin_x), int(bbox.origin_y)
    end_point = (int(bbox.origin_x + bbox.width),
                 int(bbox.origin_y + bbox.height))
    cv2.rectangle(image, start_point, end_point, TEXT_COLOR, 3)

    # Draw label and score for the top category, if any. The original code
    # indexed categories[0] unconditionally (IndexError on an empty list).
    if not detection.categories:
      continue
    category = detection.categories[0]
    # category_name can be None when the model ships no label map; fall back
    # to '' so the string concatenation below cannot raise.
    category_name = category.category_name or ''
    probability = round(category.score, 2)
    result_text = category_name + ' (' + str(probability) + ')'
    text_location = (MARGIN + int(bbox.origin_x),
                     MARGIN + ROW_SIZE + int(bbox.origin_y))
    cv2.putText(image, result_text, text_location, cv2.FONT_HERSHEY_PLAIN,
                FONT_SIZE, TEXT_COLOR, FONT_THICKNESS)

  return image