Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update frozen graph import to support both ragged and unragged tensors #21

Merged
merged 6 commits into from
May 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions src/Bonsai.Sleap/PredictCentroids.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ private IObservable<CentroidCollection> Process(IObservable<IplImage[]> source)
TFSession.Runner runner = null;
var graph = TensorHelper.ImportModel(ModelFileName, out TFSession session);
var config = ConfigHelper.LoadTrainingConfig(TrainingConfig);
var ragged = graph["Identity_6"] != null;

if (config.ModelType != ModelType.Centroid)
{
Expand All @@ -84,7 +85,7 @@ private IObservable<CentroidCollection> Process(IObservable<IplImage[]> source)
var tensorSize = input[0].Size;
var batchSize = input.Length;
var scaleFactor = ScaleFactor;

if (scaleFactor.HasValue)
{
poseScale = scaleFactor.Value;
Expand All @@ -93,37 +94,48 @@ private IObservable<CentroidCollection> Process(IObservable<IplImage[]> source)
poseScale = 1.0 / poseScale;
}

if (tensor == null || tensor.Shape[0] != batchSize || tensor.Shape[1] != tensorSize.Height || tensor.Shape[2] != tensorSize.Width )
if (tensor == null || tensor.Shape[0] != batchSize || tensor.Shape[1] != tensorSize.Height || tensor.Shape[2] != tensorSize.Width)
{
tensor?.Dispose();
runner = session.GetRunner();
tensor = TensorHelper.CreatePlaceholder(graph, runner, tensorSize, batchSize, colorChannels);

runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_2"][0]);
if (ragged)
{
// ragged version of the frozen graph
runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_2"][0]);
}
else
{
// unragged version of the frozen graph
runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_1"][0]);
}
}

var frames = Array.ConvertAll(input, frame =>
var frames = Array.ConvertAll(input, frame =>
{
frame = TensorHelper.EnsureFrameSize(frame, tensorSize, ref resizeTemp);
frame = TensorHelper.EnsureColorFormat(frame, ColorConversion, ref colorTemp, colorChannels);
return frame;
});
TensorHelper.UpdateTensor(tensor, colorChannels, frames);
var output = runner.Run();


var shapeIdx = ragged ? 0 : 1;
var centroidCollection = new CentroidCollection(input[0]);
if (output[0].Shape[0] == 0) return centroidCollection;
if (output[0].Shape[shapeIdx] == 0) return centroidCollection;
else
{
// Fetch the results from output
var centroidConfidenceTensor = output[0];
float[] centroidConfArr = new float[centroidConfidenceTensor.Shape[0]];
centroidConfidenceTensor.GetValue(centroidConfArr);
float[] centroidConfArr = new float[centroidConfidenceTensor.Shape[shapeIdx]];
TensorHelper.GetTensorValue(centroidConfidenceTensor, centroidConfArr);

var centroidTensor = output[1];
float[,] centroidArr = new float[centroidTensor.Shape[0], centroidTensor.Shape[1]];
centroidTensor.GetValue(centroidArr);
float[,] centroidArr = new float[centroidTensor.Shape[shapeIdx], centroidTensor.Shape[shapeIdx + 1]];
TensorHelper.GetTensorValue(centroidTensor, centroidArr);

var confidenceThreshold = CentroidMinConfidence;
for (int i = 0; i < centroidConfArr.GetLength(0); i++)
Expand Down
43 changes: 29 additions & 14 deletions src/Bonsai.Sleap/PredictPoseIdentities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ private IObservable<PoseIdentityCollection> Process(IObservable<IplImage[]> sour
TFSession.Runner runner = null;
var graph = TensorHelper.ImportModel(ModelFileName, out TFSession session);
var config = ConfigHelper.LoadTrainingConfig(TrainingConfig);
var ragged = graph["Identity_6"] != null;

if (config.ModelType != ModelType.MultiClass)
{
Expand All @@ -114,17 +115,30 @@ private IObservable<PoseIdentityCollection> Process(IObservable<IplImage[]> sour
poseScale = 1.0 / poseScale;
}

if (tensor == null || tensor.Shape[0] != batchSize || tensor.Shape[1] != tensorSize.Height || tensor.Shape[2] != tensorSize.Width )
if (tensor == null || tensor.Shape[0] != batchSize || tensor.Shape[1] != tensorSize.Height || tensor.Shape[2] != tensorSize.Width)
{
tensor?.Dispose();
runner = session.GetRunner();
tensor = TensorHelper.CreatePlaceholder(graph, runner, tensorSize, batchSize, colorChannels);

runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_2"][0]);
runner.Fetch(graph["Identity_4"][0]);
runner.Fetch(graph["Identity_5"][0]);
runner.Fetch(graph["Identity_6"][0]);
if (ragged)
{
// ragged version of the frozen graph
runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_2"][0]);
runner.Fetch(graph["Identity_4"][0]);
runner.Fetch(graph["Identity_5"][0]);
runner.Fetch(graph["Identity_6"][0]);
}
else
{
// unragged version of the frozen graph
runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_1"][0]);
runner.Fetch(graph["Identity_2"][0]);
runner.Fetch(graph["Identity_3"][0]);
runner.Fetch(graph["Identity_4"][0]);
}
}

var frames = Array.ConvertAll(input, frame =>
Expand All @@ -136,30 +150,31 @@ private IObservable<PoseIdentityCollection> Process(IObservable<IplImage[]> sour
TensorHelper.UpdateTensor(tensor, colorChannels, frames);
var output = runner.Run();

var shapeIdx = ragged ? 0 : 1;
var identityCollection = new PoseIdentityCollection(input[0]);
if (output[0].Shape[0] == 0) return identityCollection;
if (output[0].Shape[shapeIdx] == 0) return identityCollection;
else
{
// Fetch the results from output
var centroidConfidenceTensor = output[0];
float[] centroidConfArr = new float[centroidConfidenceTensor.Shape[0]];
centroidConfidenceTensor.GetValue(centroidConfArr);
float[] centroidConfArr = new float[centroidConfidenceTensor.Shape[shapeIdx]];
TensorHelper.GetTensorValue(centroidConfidenceTensor, centroidConfArr);

var centroidTensor = output[1];
float[,] centroidArr = new float[centroidTensor.Shape[0], centroidTensor.Shape[1]];
centroidTensor.GetValue(centroidArr);
float[,] centroidArr = new float[centroidTensor.Shape[shapeIdx], centroidTensor.Shape[shapeIdx + 1]];
TensorHelper.GetTensorValue(centroidTensor, centroidArr);

var partConfTensor = output[2];
float[,] partConfArr = new float[partConfTensor.Shape[0], partConfTensor.Shape[1]];
partConfTensor.GetValue(partConfArr);
TensorHelper.GetTensorValue(partConfTensor, partConfArr);

var poseTensor = output[3];
float[,,] poseArr = new float[poseTensor.Shape[0], poseTensor.Shape[1], poseTensor.Shape[2]];
poseTensor.GetValue(poseArr);
TensorHelper.GetTensorValue(poseTensor, poseArr);

var idTensor = output[4];
float[,] idArr = new float[idTensor.Shape[0], idTensor.Shape[1]];
idTensor.GetValue(idArr);
TensorHelper.GetTensorValue(idTensor, idArr);

var partThreshold = PartMinConfidence;
var idThreshold = IdentityMinConfidence;
Expand Down
49 changes: 31 additions & 18 deletions src/Bonsai.Sleap/PredictPoses.cs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ private IObservable<PoseCollection> Process(IObservable<IplImage[]> source)
TFSession.Runner runner = null;
var graph = TensorHelper.ImportModel(ModelFileName, out TFSession session);
var config = ConfigHelper.LoadTrainingConfig(TrainingConfig);
var ragged = graph["Identity_6"] != null;

if (config.ModelType != ModelType.CenteredInstance)
{
Expand All @@ -93,7 +94,7 @@ private IObservable<PoseCollection> Process(IObservable<IplImage[]> source)
var tensorSize = input[0].Size;
var batchSize = input.Length;
var scaleFactor = ScaleFactor;

if (scaleFactor.HasValue)
{
poseScale = scaleFactor.Value;
Expand All @@ -102,20 +103,31 @@ private IObservable<PoseCollection> Process(IObservable<IplImage[]> source)
poseScale = 1.0 / poseScale;
}

if (tensor == null || tensor.Shape[0] != batchSize || tensor.Shape[1] != tensorSize.Height || tensor.Shape[2] != tensorSize.Width )
if (tensor == null || tensor.Shape[0] != batchSize || tensor.Shape[1] != tensorSize.Height || tensor.Shape[2] != tensorSize.Width)
{
tensor?.Dispose();
runner = session.GetRunner();
tensor = TensorHelper.CreatePlaceholder(graph, runner, tensorSize, batchSize, colorChannels);

runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_2"][0]);
runner.Fetch(graph["Identity_4"][0]);
runner.Fetch(graph["Identity_6"][0]);

if (ragged)
{
// ragged version of the frozen graph
runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_2"][0]);
runner.Fetch(graph["Identity_4"][0]);
runner.Fetch(graph["Identity_6"][0]);
}
else
{
// unragged version of the frozen graph
runner.Fetch(graph["Identity"][0]);
runner.Fetch(graph["Identity_1"][0]);
runner.Fetch(graph["Identity_2"][0]);
runner.Fetch(graph["Identity_3"][0]);
}
}

var frames = Array.ConvertAll(input, frame =>
var frames = Array.ConvertAll(input, frame =>
{
frame = TensorHelper.EnsureFrameSize(frame, tensorSize, ref resizeTemp);
frame = TensorHelper.EnsureColorFormat(frame, ColorConversion, ref colorTemp, colorChannels);
Expand All @@ -125,30 +137,31 @@ private IObservable<PoseCollection> Process(IObservable<IplImage[]> source)
TensorHelper.UpdateTensor(tensor, colorChannels, frames);
var output = runner.Run();

var shapeIdx = ragged ? 0 : 1;
var poseCollection = new PoseCollection(input[0]);
if (output[0].Shape[0] == 0) return poseCollection;
if (output[0].Shape[shapeIdx] == 0) return poseCollection;
else
{
var centroidConfidenceTensor = output[0];
float[] centroidConfArr = new float[centroidConfidenceTensor.Shape[0]];
centroidConfidenceTensor.GetValue(centroidConfArr);
float[] centroidConfArr = new float[centroidConfidenceTensor.Shape[shapeIdx]];
TensorHelper.GetTensorValue(centroidConfidenceTensor, centroidConfArr);

var centroidTensor = output[1];
float[,] centroidArr = new float[centroidTensor.Shape[0], centroidTensor.Shape[1]];
centroidTensor.GetValue(centroidArr);
float[,] centroidArr = new float[centroidTensor.Shape[shapeIdx], centroidTensor.Shape[shapeIdx + 1]];
TensorHelper.GetTensorValue(centroidTensor, centroidArr);

var partConfTensor = output[2];
float[,] partConfArr = new float[partConfTensor.Shape[0], partConfTensor.Shape[1]];
partConfTensor.GetValue(partConfArr);
float[,] partConfArr = new float[partConfTensor.Shape[shapeIdx], partConfTensor.Shape[shapeIdx + 1]];
TensorHelper.GetTensorValue(partConfTensor, partConfArr);

var poseTensor = output[3];
float[,,] poseArr = new float[poseTensor.Shape[0], poseTensor.Shape[1], poseTensor.Shape[2]];
poseTensor.GetValue(poseArr);
float[,,] poseArr = new float[poseTensor.Shape[shapeIdx], poseTensor.Shape[shapeIdx + 1], poseTensor.Shape[shapeIdx + 2]];
TensorHelper.GetTensorValue(poseTensor, poseArr);

var partThreshold = PartMinConfidence;
var centroidThreshold = CentroidMinConfidence;

//Loop the available identifications
// Loop the available identifications
for (int i = 0; i < centroidArr.GetLength(0); i++)
{
var pose = new Pose(input[0]);
Expand Down
14 changes: 14 additions & 0 deletions src/Bonsai.Sleap/TensorHelper.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using OpenCV.Net;
using System;
using System.IO;
using System.Runtime.InteropServices;
using TensorFlow;

namespace Bonsai.Sleap
Expand Down Expand Up @@ -132,5 +133,18 @@ public static IplImage[] GetTensorMaps(TFTensor tensor, int batchIndex = 0)
return result;
}
}

/// <summary>
/// Copies the raw data buffer of a tensor into a pre-allocated managed array.
/// </summary>
/// <param name="tensor">The tensor whose data buffer is read.</param>
/// <param name="array">
/// The destination array. Its element type and total element count are passed to
/// <c>CheckDataTypeAndSize</c> before any bytes are copied, and it is pinned for
/// the duration of the copy.
/// </param>
/// <exception cref="ArgumentNullException">
/// <paramref name="tensor"/> or <paramref name="array"/> is <c>null</c>.
/// </exception>
public static unsafe void GetTensorValue(TFTensor tensor, Array array)
{
    // Fail fast with a descriptive exception instead of a NullReferenceException
    // raised from the middle of the unsafe copy path.
    if (tensor == null)
    {
        throw new ArgumentNullException(nameof(tensor));
    }

    if (array == null)
    {
        throw new ArgumentNullException(nameof(array));
    }

    var elementType = array.GetType().GetElementType();

    // NOTE(review): presumably throws when the array's element type or length does
    // not match the tensor — confirm against the TFTensor implementation. The copy
    // below depends on this, because the tensor byte size is used as both the
    // destination capacity and the number of bytes to copy.
    tensor.CheckDataTypeAndSize(elementType, array.Length);

    // Pin the managed array so its address stays fixed while copying.
    var handle = GCHandle.Alloc(array, GCHandleType.Pinned);
    try
    {
        var byteCount = tensor.TensorByteSize.ToUInt64();
        Buffer.MemoryCopy(
            tensor.Data.ToPointer(),
            handle.AddrOfPinnedObject().ToPointer(),
            byteCount,
            byteCount);
    }
    finally
    {
        // Always release the pin, even if the copy throws.
        handle.Free();
    }
}
}
}