Skip to content

Commit

Permalink
Add tracking-only ui to sleap-track.
Browse files Browse the repository at this point in the history
If sleap-track is given labels file as data_path and tracking
policy but no models, then it will run the tracker on the
previously generated predictions. This is a better interface
than having to call sleap.nn.tracking to retrack.

Resolves issue #260.
  • Loading branch information
ntabris committed Jan 22, 2020
1 parent 8c90746 commit ec98962
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 41 deletions.
129 changes: 88 additions & 41 deletions sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Predictor:
* "confmap": Instance of `peak_finding.ConfmapPeakFinder`
* "paf": Instance of `paf_grouping.PAFGrouper`
* "tracking": Instance of `tracking.Tracker`
* "previous_predictions": `predicted.PredictedInstancePredictor`
Note: the pipeline will be determined by which policies are given.
"""
Expand Down Expand Up @@ -77,15 +78,27 @@ def predict(
if self.has_grayscale_models:
video_kwargs["grayscale"] = True

is_dummy_video = False
if "tracking" in self.policies and "previous_predictions" in self.policies:
if not self.policies["tracking"].uses_image:
# We're just running the tracker for previous predictions
# and this tracker doesn't use the images, so we'll load
# "dummy" images.
is_dummy_video = True

video_ds = utils.VideoLoader(
filename=video_filename, frame_inds=frames, **video_kwargs
filename=video_filename,
frame_inds=frames,
dummy=is_dummy_video,
**video_kwargs,
)

predicted_frames = []

for chunk_ind, frame_inds, imgs in video_ds:

predicted_instances_chunk = self.predict_chunk(
imgs, chunk_ind, video_ds.chunk_size
imgs, chunk_ind, video_ds.chunk_size, frame_inds=frame_inds
)

sample_inds = np.arange(len(imgs))
Expand All @@ -100,9 +113,12 @@ def predict(

return predicted_frames

def predict_chunk(self, img_chunk, chunk_ind, chunk_size):
def predict_chunk(self, img_chunk, chunk_ind, chunk_size, frame_inds=None):
"""Runs the inference components of pipeline for a chunk."""

if "previous_predictions" in self.policies:
return self.policies["previous_predictions"].get_chunk(frame_inds)

if "centroid" in self.policies:
# Detect centroids and pull out region proposals.
centroid_predictor = self.policies["centroid"]
Expand Down Expand Up @@ -306,7 +322,6 @@ def frame_list(frame_str: str):
action="append",
help="Path to saved model (confmaps, pafs, ...) JSON. "
"Multiple models can be specified, each preceded by --model.",
required=True,
)

parser.add_argument(
Expand Down Expand Up @@ -350,10 +365,14 @@ def frame_list(frame_str: str):
@classmethod
def cli_args_to_policies(cls, args):
policy_args = util.make_scoped_dictionary(vars(args), exclude_nones=True)
return cls.from_paths_and_policy_args(args.models, policy_args)
return cls.from_paths_and_policy_args(
model_paths=args.models, policy_args=policy_args, args=args,
)

@classmethod
def from_paths_and_policy_args(cls, model_paths: List[str], policy_args: dict):
def from_paths_and_policy_args(
cls, model_paths: List[str], policy_args: dict, args: dict
):
policy_args["region"]["merge_overlapping"] = True

inferred_box_length = 160 # default if not set by user or inferrable
Expand All @@ -369,35 +388,55 @@ def from_paths_and_policy_args(cls, model_paths: List[str], policy_args: dict):

# Load the information for these models
loaded_models = dict()
for model_path in model_paths:
training_job = job.TrainingJob.load_json(model_path)
inference_model = model.InferenceModel.from_training_job(training_job)
policy_key = model_type_policy_key_map[training_job.model.output_type]

loaded_models[policy_key] = dict(
job=training_job, inference_model=inference_model
)
if model_paths:
for model_path in model_paths:
training_job = job.TrainingJob.load_json(model_path)
inference_model = model.InferenceModel.from_training_job(training_job)
policy_key = model_type_policy_key_map[training_job.model.output_type]

loaded_models[policy_key] = dict(
job=training_job, inference_model=inference_model
)

# Add policy classes which depend on models
for policy_key, policy_model in loaded_models.items():
training_job = policy_model["job"]
inference_model = policy_model["inference_model"]
# Add policy classes which depend on models
for policy_key, policy_model in loaded_models.items():
training_job = policy_model["job"]
inference_model = policy_model["inference_model"]

if policy_key == "confmap" and "paf" not in loaded_models.keys():
# Use topdown class when we have confmaps and not pafs
policy_class = POLICY_CLASSES["topdown"]
else:
policy_class = POLICY_CLASSES[policy_key]
if policy_key == "confmap" and "paf" not in loaded_models.keys():
# Use topdown class when we have confmaps and not pafs
policy_class = POLICY_CLASSES["topdown"]
else:
policy_class = POLICY_CLASSES[policy_key]

policy_object = policy_class(
inference_model=inference_model, **policy_args[policy_key]
)
policy_object = policy_class(
inference_model=inference_model, **policy_args[policy_key]
)

policies[policy_key] = policy_object
policies[policy_key] = policy_object

if training_job.trainer.bounding_box_size is not None:
if training_job.trainer.bounding_box_size > 0:
inferred_box_length = training_job.trainer.bounding_box_size
if training_job.trainer.bounding_box_size is not None:
if training_job.trainer.bounding_box_size > 0:
inferred_box_length = training_job.trainer.bounding_box_size

# No models specified so see if we're using previous predictions
else:
try:
previous_labels = Labels.load_file(
args.data_path, video_callback=[os.path.dirname(args.data_path)],
)
from .predicted import PredictedInstancePredictor

policies["previous_predictions"] = PredictedInstancePredictor(
labels=previous_labels,
)
print(f"Using previous predictions from {args.data_path}")
args.data_path = previous_labels.videos[0].filename
print(f"Setting video to {args.data_path}")
except Exception:
# We weren't able to read file as Labels object
pass

if "topdown" in policies:
policy_args["region"]["merge_overlapping"] = False
Expand Down Expand Up @@ -492,22 +531,30 @@ def predict_subprocess(
def check_valid_policies(cls, policies: dict) -> bool:

has_topdown = "topdown" in policies

has_previous = "previous_predictions" in policies
has_tracker = "tracking" in policies
non_topdowns = [key for key in policies.keys() if key in ("confmap", "paf")]

if has_topdown and non_topdowns:
raise ValueError(
f"Cannot combine topdown model with non-topdown model"
f" {non_topdowns}."
)
if has_previous:
if not has_tracker:
raise ValueError(
f"No tracker specified for running on previous predictions"
)

if non_topdowns and "confmap" not in non_topdowns:
raise ValueError("Must have CONFIDENCE_MAP model.")
else:
if has_topdown and non_topdowns:
raise ValueError(
f"Cannot combine topdown model with non-topdown model"
f" {non_topdowns}."
)

if not has_topdown and not non_topdowns:
raise ValueError(
f"Must have either TOPDOWN or CONFIDENCE_MAP/PART_AFFINITY_FIELD models."
)
if non_topdowns and "confmap" not in non_topdowns:
raise ValueError("Must have CONFIDENCE_MAP model.")

if not has_topdown and not non_topdowns:
raise ValueError(
f"Must have either TOPDOWN or CONFIDENCE_MAP/PART_AFFINITY_FIELD models."
)

return True

Expand Down
25 changes: 25 additions & 0 deletions sleap/nn/predicted.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import attr


@attr.s(auto_attribs=True)
class PredictedInstancePredictor:
"""
Returns chunk of previously generated predictions in format of Predictor.
"""

labels: "Labels"
video_idx: int = 0

def get_chunk(self, frame_inds):
video = self.labels.videos[self.video_idx]

# Return dict keyed to sample index (i.e., offset in frame_inds), value
# is the list of instances for that frame.
return {
i: [
inst
for lf in self.labels.find(video=video, frame_idx=int(frame_idx))
for inst in lf.instances
]
for i, frame_idx in enumerate(frame_inds)
}

0 comments on commit ec98962

Please sign in to comment.