Add multi-animal support to DLCLive #72

Closed
wants to merge 5 commits into from
172 changes: 169 additions & 3 deletions dlclive/dlclive.py
@@ -13,6 +13,7 @@
import tensorflow as tf
import typing
from pathlib import Path
from scipy.optimize import linear_sum_assignment
from typing import Optional, Tuple, List

try:
@@ -24,6 +25,13 @@
except Exception:
pass

from deeplabcut.pose_estimation_tensorflow.config import load_config
from deeplabcut.pose_estimation_tensorflow.core import (
predict, predict_multianimal,
)
from deeplabcut.pose_estimation_tensorflow.lib import (
trackingutils, inferenceutils,
)
from dlclive.graph import (
read_graph,
finalize_graph,
@@ -86,7 +94,7 @@ class DLCLive(object):
display_lik : float, optional
Likelihood threshold for display

- display_raidus : int, optional
+ display_radius : int, optional
radius for keypoint display in pixels, default=3
"""

@@ -294,7 +302,7 @@ def init_inference(self, frame=None, **kwargs):
graph = finalize_graph(graph_def)
output_nodes = get_output_nodes(graph)
output_nodes = [on.replace("DLC/", "") for on in output_nodes]

tf_version_2 = tf.__version__[0] == '2'

if tf_version_2:
@@ -311,7 +319,7 @@ def init_inference(self, frame=None, **kwargs):
output_nodes,
input_shapes={"Placeholder": [1, processed_frame.shape[0], processed_frame.shape[1], 3]},
)

try:
tflite_model = converter.convert()
except Exception:
@@ -478,3 +486,161 @@ def close(self):
self.is_initialized = False
if self.display is not None:
self.display.destroy()


class MultiAnimalDLCLive(DLCLive):
    def __init__(
        self,
        model_path,
Collaborator

type hint!
model_path: Union[Path, str]
n_animals: int
n_multibodyparts: int

        n_animals,
Collaborator

are these not specified in the config file?

        n_multibodyparts,
        track_method: str = "box",
Collaborator

this could be a literal
track_method: Literal['box', 'ellipse']
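
A minimal sketch of the `Literal` suggestion, assuming Python 3.8+ where `typing.Literal` and `typing.get_args` are available. The alias `TrackMethod` and the helper `validate_track_method` are hypothetical names used only for illustration. A `Literal` annotation is checked by static type checkers only, so the runtime check is still needed.

```python
from typing import Literal, get_args

# Hypothetical alias: the two tracking methods accepted by the PR.
TrackMethod = Literal["box", "ellipse"]


def validate_track_method(track_method: str) -> str:
    """Return track_method unchanged, or raise if it is not an allowed literal."""
    allowed = get_args(TrackMethod)  # ("box", "ellipse")
    if track_method not in allowed:
        raise ValueError(f"`track_method` should be one of {allowed}.")
    return track_method
```

Deriving the allowed values from the alias keeps the annotation and the runtime check from drifting apart.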

        min_hits: int = 1,
        max_age: int = 1,
        sim_threshold: float = 0.6,
        resize: Optional[float] = None,
        convert2rgb: bool = True,
        processor: Optional['Processor'] = None,
        display: typing.Union[bool, Display] = False,
        pcutoff: float = 0.5,
        display_radius: int = 3,
        display_cmap: str = "bmy",
Collaborator

missing arguments from parent signature:

  • model_type
  • precision
  • tf_config
  • cropping
  • dynamic

Any reason? We should try and make uniform API if possible.

Since these are all in the DLCLive.PARAMETERS tuple, this will cause .parameterization to error. This also removes control over the process_frame method.

Two approaches: add the missing arguments, or else remove all of them and accept a **kwargs that gets passed to the parent class.
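
A rough sketch of the second approach (forwarding `**kwargs` to the parent), using stand-in classes rather than the real `DLCLive`. The class names `Base` and `MultiAnimal`, and the reduced `PARAMETERS` tuple, are assumptions made for illustration only.

```python
class Base:
    """Stand-in for the parent class; not the real DLCLive."""

    PARAMETERS = ("model_path", "model_type", "precision", "resize")

    def __init__(self, model_path, model_type="base", precision="FP32", resize=None):
        self.model_path = model_path
        self.model_type = model_type
        self.precision = precision
        self.resize = resize

    @property
    def parameterization(self):
        # Fails with AttributeError if any name in PARAMETERS was never set.
        return {name: getattr(self, name) for name in self.PARAMETERS}


class MultiAnimal(Base):
    """Subclass that only names its own arguments and forwards the rest."""

    PARAMETERS = Base.PARAMETERS + ("n_animals", "track_method")

    def __init__(self, model_path, n_animals, track_method="box", **kwargs):
        # Every parent argument stays reachable without being repeated here.
        super().__init__(model_path, **kwargs)
        self.n_animals = n_animals
        self.track_method = track_method


live = MultiAnimal("exported_model", n_animals=2, precision="FP16")
params = live.parameterization
```

With this layout the subclass signature cannot fall out of sync with the parent, at the cost of a less self-documenting signature, so the forwarded arguments would need to be listed in the docstring.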

    ):
Collaborator

needs docstring! Since inherits from DLCLive but the signature differs, need to also document changes in calling convention. Also need to document any new attrs

        if track_method not in ("box", "ellipse"):
            raise ValueError("`track_method` should be either `box` or `ellipse`.")

        self.model_path = model_path
        super().__init__(
            Path(model_path).parent,
Collaborator

parent class expects a str.

            resize=resize,
            convert2rgb=convert2rgb,
            processor=processor,
            display=display,
            pcutoff=pcutoff,
            display_radius=display_radius,
            display_cmap=display_cmap,
        )
        self.n_animals = n_animals
        self.n_multibodyparts = n_multibodyparts
        self.track_method = track_method
        self.min_hits = min_hits
        self.max_age = max_age
        self.sim_threshold = sim_threshold

    def read_config(self):
        cfg_path = Path(self.path).resolve() / "pose_cfg.yaml"
        if not cfg_path.exists():
            raise FileNotFoundError(
                f"The pose configuration file for the exported model at {str(cfg_path)} was not found. Please check the path to the exported model directory"
            )

        self.cfg = load_config(cfg_path)
        self.cfg["batch_size"] = 1
        self.cfg["init_weights"] = self.model_path.split(".")[0]
        self.identity_only = self.cfg["num_idchannel"] > 0
Collaborator

not declared in init


    def init_inference(
        self,
        frame=None,
        allow_growth=False,
        **kwargs,
    ):
        self.sess, self.inputs, self.outputs = predict.setup_pose_prediction(
            self.cfg, allow_growth=allow_growth,
        )

        if self.track_method == "box":
            self.mot_tracker = trackingutils.SORTBox(
Collaborator

attribute not declared in __init__. would be nice to just make this a property

                self.max_age, self.min_hits, self.sim_threshold,
            )
        else:
            self.mot_tracker = trackingutils.SORTEllipse(
                self.max_age, self.min_hits, self.sim_threshold,
            )


        data = {
            "metadata": {
                "all_joints_names": self.cfg["all_joints_names"],
                "PAFgraph": self.cfg["partaffinityfield_graph"],
                "PAFinds": self.cfg.get("paf_best", np.arange(self.n_multibodyparts))
            }
        }
        # Hack to avoid IndexError when determining _has_identity
        temp = {"identity": []} if self.identity_only else {}
        data["frame0"] = temp
        self.ass = inferenceutils.Assembler(
Collaborator

__init__ also does not declare that the object has an ass

            data,
            max_n_individuals=self.n_animals,
            n_multibodyparts=self.n_multibodyparts,
            greedy=True,  # TODO Benchmark vs optimal matching
            identity_only=self.identity_only,
            max_overlap=1,
        )

        if frame is not None:
            pose = self.get_pose(frame, **kwargs)
        else:
            pose = None

        self.is_initialized = True

        return pose

    def get_pose(self, frame=None, **kwargs):
Collaborator

now might also be a good time to add docstrings to these methods too -- their function is straightforward enough, but documenting what goes on in them, etc. either here or documenting the arguments/attrs in the __init__ docstring that change the function here.

        if frame is None:
Collaborator

iirc frame being optional is a hangover from an old version. I can't think of a reason to keep it that way?

            raise DLCLiveError("No frame provided for live pose estimation")

        frame = self.process_frame(frame)
        data_dict = predict_multianimal.predict_batched_peaks_and_costs(
Collaborator

not sure what happens from here to L627, some comments would be nice!
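
For the identity-assignment step in that span, a standalone sketch may help. It assumes each assembly exposes a dict of per-identity vote weights (as `soft_identity` appears to in the diff); the function name `assign_identities` and the sample weights are made up for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def assign_identities(soft_identities, n_animals):
    """Pair identities with assemblies so the total vote weight is maximal.

    soft_identities: one dict per assembly, {identity_index: vote_weight}.
    Returns an array whose rows are (identity_index, assembly_index).
    """
    votes = np.zeros((len(soft_identities), n_animals))
    for row, soft_identity in enumerate(soft_identities):
        for identity, weight in soft_identity.items():
            votes[row, identity] = weight
    # One-to-one pairing of rows (assemblies) and columns (identities).
    assembly_inds, identity_inds = linear_sum_assignment(votes, maximize=True)
    return np.column_stack((identity_inds, assembly_inds))


# Three candidate assemblies compete for two identities; the weakest is dropped.
pairs = assign_identities(
    [{0: 0.1, 1: 0.9}, {0: 0.8, 1: 0.2}, {0: 0.3, 1: 0.3}], n_animals=2
)
```

Here identity 1 goes to assembly 0 and identity 0 to assembly 1, while the third assembly is left unassigned because the matrix has more rows than columns.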

            self.cfg,
            np.expand_dims(frame, axis=0),
            self.sess,
            self.inputs,
            self.outputs,
        )
        if not data_dict:
            return
        else:
            data_dict = data_dict[0]

        pose = np.full((self.n_animals, self.n_multibodyparts, 4), np.nan)
        assemblies, unique = self.ass._assemble(data_dict, ind_frame=0)
        if assemblies:
            if self.n_animals == 1:
                pose[0] = assemblies[0].data
            else:
                animals = np.stack([a.data for a in assemblies])
                if not self.ass.identity_only:
                    if self.track_method == "box":
                        xy = trackingutils.calc_bboxes_from_keypoints(animals)
                    else:
                        xy = animals[..., :2]
                    trackers = self.mot_tracker.track(xy)[:, -2:].astype(int)
                else:
                    # Optimal identity assignment based on soft voting
                    mat = np.zeros(
                        (len(assemblies), self.n_animals)
                    )
                    for nrow, assembly in enumerate(assemblies):
                        for k, v in assembly.soft_identity.items():
                            mat[nrow, k] = v
                    inds = linear_sum_assignment(mat, maximize=True)
                    trackers = np.c_[inds][:, ::-1]
                # Discard trackers of false positives
                trackers = trackers[trackers[:, 0] < self.n_animals]
                for pose_ind, animal_ind in trackers:
                    pose[pose_ind] = animals[animal_ind]
        self.pose = (pose, unique)

        if self.display is not None:
            self.display.display_frame(frame, self.pose)
Collaborator

did this get tested? not sure if the maDLC pose is any different than regular


        if self.resize is not None:
Collaborator

missing corrections for cropping and dynamic_cropping which can't be set from __init__, but can be set as instance attributes. these should be split out into a separate method in the parent class to avoid having to duplicate code.
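
A possible shape for such a shared helper, sketched as a free function. The name `to_original_coordinates`, its arguments, and the assumption that frames are cropped first and resized second are all hypothetical and would need checking against the parent class.

```python
import numpy as np


def to_original_coordinates(xy, resize=None, crop_offset=(0.0, 0.0)):
    """Map (x, y) values from processed-frame pixels back to the original frame.

    Assumes the frame was cropped and then resized, so the inverse undoes the
    resize first and adds the crop's top-left (x, y) offset second.
    """
    xy = np.asarray(xy, dtype=float).copy()  # leave the caller's array untouched
    if resize is not None:
        xy /= resize
    xy += np.asarray(crop_offset, dtype=float)
    return xy


# Works for any leading shape, e.g. (n_animals, n_bodyparts, 2).
restored = to_original_coordinates(
    [[10.0, 20.0], [np.nan, np.nan]], resize=0.5, crop_offset=(100.0, 50.0)
)
```

Missing detections stored as NaN pass through unchanged, so the same helper could serve both the single-animal and multi-animal pose arrays.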

            self.pose[0][..., :2] *= 1 / self.resize
Collaborator

seems like this should happen for n poses, right? rather than a hard 2? but i'm not sure what the structure of pose is here.
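
On the indexing question, a small NumPy check, assuming the pose array has the shape `(n_animals, n_multibodyparts, 4)` that `np.full` gives it earlier in `get_pose`. The sizes and the resize factor below are arbitrary.

```python
import numpy as np

n_animals, n_bodyparts = 3, 5
pose = np.ones((n_animals, n_bodyparts, 4))
resize = 0.5

# The ellipsis spans every leading axis, so this touches all animals and all
# body parts; `:2` picks the first two entries of the last axis only.
pose[..., :2] *= 1 / resize
```

So the `2` limits the scaling to the first two columns of the last axis (the x and y coordinates) and does not limit it to two poses.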

            self.pose[1][:, :2] *= 1 / self.resize

        if self.processor:
            self.pose = self.processor.process(self.pose, **kwargs)
Collaborator

does the processor need to be different for maDLC pose?


        return self.pose
Collaborator

does the close method need to be updated at all? again not sure how the maDLC stuff works.