feat: implement PoseLandmarker
homuler committed Sep 2, 2023
1 parent e89f0b4 commit 9ebc72f
Showing 7 changed files with 426 additions and 0 deletions.

@@ -0,0 +1,208 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

using System.Collections.Generic;

namespace Mediapipe.Tasks.Vision.PoseLandmarker
{
public sealed class PoseLandmarker : Core.BaseVisionTaskApi
{
private const string _IMAGE_IN_STREAM_NAME = "image_in";
private const string _IMAGE_OUT_STREAM_NAME = "image_out";
private const string _IMAGE_TAG = "IMAGE";
private const string _NORM_RECT_STREAM_NAME = "norm_rect_in";
private const string _NORM_RECT_TAG = "NORM_RECT";
private const string _SEGMENTATION_MASK_STREAM_NAME = "segmentation_mask";
private const string _SEGMENTATION_MASK_TAG = "SEGMENTATION_MASK";
private const string _NORM_LANDMARKS_STREAM_NAME = "norm_landmarks";
private const string _NORM_LANDMARKS_TAG = "NORM_LANDMARKS";
private const string _POSE_WORLD_LANDMARKS_STREAM_NAME = "world_landmarks";
private const string _POSE_WORLD_LANDMARKS_TAG = "WORLD_LANDMARKS";
private const string _TASK_GRAPH_NAME = "mediapipe.tasks.vision.pose_landmarker.PoseLandmarkerGraph";

private const int _MICRO_SECONDS_PER_MILLISECOND = 1000;

#pragma warning disable IDE0052 // Remove unread private members
/// <remarks>
/// Keep a reference to prevent the GC from collecting the callback instance.
/// </remarks>
private readonly Tasks.Core.TaskRunner.PacketsCallback _packetCallback;
#pragma warning restore IDE0052

private PoseLandmarker(
CalculatorGraphConfig graphConfig,
Core.RunningMode runningMode,
Tasks.Core.TaskRunner.PacketsCallback packetCallback) : base(graphConfig, runningMode, packetCallback)
{
_packetCallback = packetCallback;
}

/// <summary>
/// Creates a <see cref="PoseLandmarker" /> object from a TensorFlow Lite model and the default <see cref="PoseLandmarkerOptions" />.
///
/// Note that the created <see cref="PoseLandmarker" /> instance is in image mode,
/// for detecting pose landmarks on single image inputs.
/// </summary>
/// <param name="modelPath">Path to the model.</param>
/// <returns>
/// <see cref="PoseLandmarker" /> object that's created from the model and the default <see cref="PoseLandmarkerOptions" />.
/// </returns>
public static PoseLandmarker CreateFromModelPath(string modelPath)
{
var baseOptions = new Tasks.Core.BaseOptions(modelAssetPath: modelPath);
var options = new PoseLandmarkerOptions(baseOptions, runningMode: Core.RunningMode.IMAGE);
return CreateFromOptions(options);
}

/// <summary>
/// Creates the <see cref="PoseLandmarker" /> object from <paramref name="options" />.
/// </summary>
/// <param name="options">Options for the pose landmarker task.</param>
/// <returns>
/// <see cref="PoseLandmarker" /> object that's created from <paramref name="options" />.
/// </returns>
public static PoseLandmarker CreateFromOptions(PoseLandmarkerOptions options)
{
var outputStreams = new List<string> {
string.Join(":", _NORM_LANDMARKS_TAG, _NORM_LANDMARKS_STREAM_NAME),
string.Join(":", _POSE_WORLD_LANDMARKS_TAG, _POSE_WORLD_LANDMARKS_STREAM_NAME),
string.Join(":", _IMAGE_TAG, _IMAGE_OUT_STREAM_NAME),
};
if (options.outputSegmentationMasks)
{
outputStreams.Add(string.Join(":", _SEGMENTATION_MASK_TAG, _SEGMENTATION_MASK_STREAM_NAME));
}
var taskInfo = new Tasks.Core.TaskInfo<PoseLandmarkerOptions>(
taskGraph: _TASK_GRAPH_NAME,
inputStreams: new List<string> {
string.Join(":", _IMAGE_TAG, _IMAGE_IN_STREAM_NAME),
string.Join(":", _NORM_RECT_TAG, _NORM_RECT_STREAM_NAME),
},
outputStreams: outputStreams,
taskOptions: options);

return new PoseLandmarker(
taskInfo.GenerateGraphConfig(options.runningMode == Core.RunningMode.LIVE_STREAM),
options.runningMode,
BuildPacketsCallback(options.resultCallback));
}

/// <summary>
/// Performs pose landmarks detection on the provided MediaPipe Image.
///
/// Only use this method when the <see cref="PoseLandmarker" /> is created with the image running mode.
/// The image can be of any size with format RGB or RGBA.
/// </summary>
/// <param name="image">MediaPipe Image.</param>
/// <param name="imageProcessingOptions">Options for image processing.</param>
/// <returns>
/// The pose landmarks detection results.
/// </returns>
public PoseLandmarkerResult Detect(Image image, Core.ImageProcessingOptions? imageProcessingOptions = null)
{
var normalizedRect = ConvertToNormalizedRect(imageProcessingOptions, image, roiAllowed: false);

var packetMap = new PacketMap();
packetMap.Emplace(_IMAGE_IN_STREAM_NAME, new ImagePacket(image));
packetMap.Emplace(_NORM_RECT_STREAM_NAME, new NormalizedRectPacket(normalizedRect));
var outputPackets = ProcessImageData(packetMap);

return BuildPoseLandmarkerResult(outputPackets);
}

/// <summary>
/// Performs pose landmarks detection on the provided video frames.
///
/// Only use this method when the PoseLandmarker is created with the video
/// running mode. It's required to provide the video frame's timestamp (in
/// milliseconds) along with the video frame. The input timestamps should be
/// monotonically increasing for adjacent calls of this method.
/// </summary>
/// <returns>
/// The pose landmarks detection results.
/// </returns>
public PoseLandmarkerResult DetectForVideo(Image image, int timestampMs, Core.ImageProcessingOptions? imageProcessingOptions = null)
{
var normalizedRect = ConvertToNormalizedRect(imageProcessingOptions, image, roiAllowed: false);

PacketMap outputPackets = null;
using (var timestamp = new Timestamp(timestampMs * _MICRO_SECONDS_PER_MILLISECOND))
{
var packetMap = new PacketMap();
packetMap.Emplace(_IMAGE_IN_STREAM_NAME, new ImagePacket(image, timestamp));
packetMap.Emplace(_NORM_RECT_STREAM_NAME, new NormalizedRectPacket(normalizedRect).At(timestamp));
outputPackets = ProcessVideoData(packetMap);
}

return BuildPoseLandmarkerResult(outputPackets);
}
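
// A video-mode usage sketch (illustrative only): `ReadFrame` and `frameRate` are
// hypothetical stand-ins for however the caller decodes video frames; the
// timestamps passed to DetectForVideo must be monotonically increasing.
//
//   var options = new PoseLandmarkerOptions(baseOptions, runningMode: Core.RunningMode.VIDEO);
//   var landmarker = PoseLandmarker.CreateFromOptions(options);
//   for (var i = 0; i < frameCount; i++)
//   {
//     var image = ReadFrame(i);                       // hypothetical frame source
//     var timestampMs = (int)(i * 1000 / frameRate);  // milliseconds, strictly increasing
//     var result = landmarker.DetectForVideo(image, timestampMs);
//   }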

/// <summary>
/// Sends live image data to perform pose landmarks detection.
///
/// Only use this method when the PoseLandmarker is created with the live stream
/// running mode. The input timestamps should be monotonically increasing for
/// adjacent calls of this method. This method will return immediately after the
/// input image is accepted. The results will be available via the
/// <see cref="PoseLandmarkerOptions.ResultCallback" /> provided in the <see cref="PoseLandmarkerOptions" />.
/// The <see cref="DetectAsync" /> method is designed to process live stream data such as camera
/// input. To lower the overall latency, pose landmarker may drop the input
/// images if needed. In other words, it's not guaranteed to have output per
/// input image.
/// </summary>
public void DetectAsync(Image image, int timestampMs, Core.ImageProcessingOptions? imageProcessingOptions = null)
{
var normalizedRect = ConvertToNormalizedRect(imageProcessingOptions, image, roiAllowed: false);

using (var timestamp = new Timestamp(timestampMs * _MICRO_SECONDS_PER_MILLISECOND))
{
var packetMap = new PacketMap();
packetMap.Emplace(_IMAGE_IN_STREAM_NAME, new ImagePacket(image, timestamp));
packetMap.Emplace(_NORM_RECT_STREAM_NAME, new NormalizedRectPacket(normalizedRect).At(timestamp));

SendLiveStreamData(packetMap);
}
}

private static Tasks.Core.TaskRunner.PacketsCallback BuildPacketsCallback(PoseLandmarkerOptions.ResultCallback resultCallback)
{
if (resultCallback == null)
{
return null;
}

return (PacketMap outputPackets) =>
{
var outImagePacket = outputPackets.At<ImagePacket, Image>(_IMAGE_OUT_STREAM_NAME);
if (outImagePacket == null || outImagePacket.IsEmpty())
{
return;
}

var image = outImagePacket.Get();
var poseLandmarkerResult = BuildPoseLandmarkerResult(outputPackets);
var timestamp = outImagePacket.Timestamp().Microseconds() / _MICRO_SECONDS_PER_MILLISECOND;

resultCallback(poseLandmarkerResult, image, (int)timestamp);
};
}

private static PoseLandmarkerResult BuildPoseLandmarkerResult(PacketMap outputPackets)
{
var poseLandmarksProtoPacket =
outputPackets.At<NormalizedLandmarkListVectorPacket, List<NormalizedLandmarkList>>(_NORM_LANDMARKS_STREAM_NAME);
if (poseLandmarksProtoPacket.IsEmpty())
{
return PoseLandmarkerResult.Empty();
}

var poseLandmarksProto = poseLandmarksProtoPacket.Get();
var poseWorldLandmarksProto = outputPackets.At<LandmarkListVectorPacket, List<LandmarkList>>(_POSE_WORLD_LANDMARKS_STREAM_NAME).Get();
var segmentationMasks = outputPackets.At<ImageVectorPacket, List<Image>>(_SEGMENTATION_MASK_STREAM_NAME)?.Get();

return PoseLandmarkerResult.CreateFrom(poseLandmarksProto, poseWorldLandmarksProto, segmentationMasks);
}
}
}
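
A minimal image-mode usage sketch of the class above. The model file name, the way the input Image is obtained, and the result field name are assumptions for illustration, not part of this commit:

// Image mode: create the landmarker from a model file and run detection on a single image.
var landmarker = PoseLandmarker.CreateFromModelPath("pose_landmarker_lite.task"); // hypothetical model path
var result = landmarker.Detect(image); // image: a Mediapipe.Image (RGB or RGBA) obtained elsewhere
// The result is assumed to expose the detected landmarks, e.g. result.poseLandmarks.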

@@ -0,0 +1,122 @@
// Copyright (c) 2023 homuler
//
// Use of this source code is governed by an MIT-style
// license that can be found in the LICENSE file or at
// https://opensource.org/licenses/MIT.

namespace Mediapipe.Tasks.Vision.PoseLandmarker
{
/// <summary>
/// Options for the pose landmarker task.
/// </summary>
public sealed class PoseLandmarkerOptions : Tasks.Core.ITaskOptions
{
/// <param name="poseLandmarksResult">
/// The pose landmarker detection results.
/// </param>
/// <param name="image">
/// The input image that the pose landmarker runs on.
/// </param>
/// <param name="timestampMs">
/// The input timestamp in milliseconds.
/// </param>
public delegate void ResultCallback(PoseLandmarkerResult poseLandmarksResult, Image image, int timestampMs);

/// <summary>
/// Base options for the pose landmarker task.
/// </summary>
public Tasks.Core.BaseOptions baseOptions { get; }
/// <summary>
/// The running mode of the task. Defaults to the image mode.
/// PoseLandmarker has three running modes:
/// <list type="number">
/// <item>
/// <description>The image mode for detecting pose landmarks on single image inputs.</description>
/// </item>
/// <item>
/// <description>The video mode for detecting pose landmarks on the decoded frames of a video.</description>
/// </item>
/// <item>
/// <description>
/// The live stream mode for detecting pose landmarks on the live stream of input data, such as from a camera.
/// In this mode, the <see cref="resultCallback" /> below must be specified to receive the detection results asynchronously.
/// </description>
/// </item>
/// </list>
/// </summary>
public Core.RunningMode runningMode { get; }
/// <summary>
/// The maximum number of poses that can be detected by the pose landmarker.
/// </summary>
public int numPoses { get; }
/// <summary>
/// The minimum confidence score for the pose detection to be considered successful.
/// </summary>
public float minPoseDetectionConfidence { get; }
/// <summary>
/// The minimum confidence score of pose presence in the pose landmark detection.
/// </summary>
public float minPosePresenceConfidence { get; }
/// <summary>
/// The minimum confidence score for the pose tracking to be considered successful.
/// </summary>
public float minTrackingConfidence { get; }
/// <summary>
/// Whether to output segmentation masks.
/// </summary>
public bool outputSegmentationMasks { get; }
/// <summary>
/// The user-defined result callback for processing live stream data.
/// The result callback should only be specified when the running mode is set to the live stream mode.
/// </summary>
public ResultCallback resultCallback { get; }

public PoseLandmarkerOptions(
Tasks.Core.BaseOptions baseOptions,
Core.RunningMode runningMode = Core.RunningMode.IMAGE,
int numPoses = 1,
float minPoseDetectionConfidence = 0.5f,
float minPosePresenceConfidence = 0.5f,
float minTrackingConfidence = 0.5f,
bool outputSegmentationMasks = false,
ResultCallback resultCallback = null)
{
this.baseOptions = baseOptions;
this.runningMode = runningMode;
this.numPoses = numPoses;
this.minPoseDetectionConfidence = minPoseDetectionConfidence;
this.minPosePresenceConfidence = minPosePresenceConfidence;
this.minTrackingConfidence = minTrackingConfidence;
this.outputSegmentationMasks = outputSegmentationMasks;
this.resultCallback = resultCallback;
}

internal Proto.PoseLandmarkerGraphOptions ToProto()
{
var baseOptionsProto = baseOptions.ToProto();
baseOptionsProto.UseStreamMode = runningMode != Core.RunningMode.IMAGE;

return new Proto.PoseLandmarkerGraphOptions
{
BaseOptions = baseOptionsProto,
PoseDetectorGraphOptions = new PoseDetector.Proto.PoseDetectorGraphOptions
{
NumPoses = numPoses,
MinDetectionConfidence = minPoseDetectionConfidence,
},
PoseLandmarksDetectorGraphOptions = new Proto.PoseLandmarksDetectorGraphOptions
{
MinDetectionConfidence = minPosePresenceConfidence,
},
MinTrackingConfidence = minTrackingConfidence,
};
}

CalculatorOptions Tasks.Core.ITaskOptions.ToCalculatorOptions()
{
var options = new CalculatorOptions();
options.SetExtension(Proto.PoseLandmarkerGraphOptions.Extensions.Ext, ToProto());
return options;
}
}
}
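
A rough sketch of how these options might be configured for the live stream mode. The model path, the callback body, and the UnityEngine.Debug.Log call are assumptions for illustration:

var options = new PoseLandmarkerOptions(
    new Tasks.Core.BaseOptions(modelAssetPath: "pose_landmarker_full.task"), // hypothetical model path
    runningMode: Core.RunningMode.LIVE_STREAM,
    numPoses: 1,
    outputSegmentationMasks: true,
    resultCallback: (result, image, timestampMs) =>
    {
        // Invoked from the task's packet callback; keep the handler lightweight.
        UnityEngine.Debug.Log($"pose landmarks received at {timestampMs} ms");
    });
var landmarker = PoseLandmarker.CreateFromOptions(options);

The callback is only invoked in the live stream mode; in the image and video modes, Detect and DetectForVideo return their results synchronously instead.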
