Skip to content

Commit

Permalink
Expose attributes of NWBFile and create Labels API for exporting to N…
Browse files Browse the repository at this point in the history
…WB (#855)
  • Loading branch information
roomrys authored Jul 23, 2022
1 parent d522f1d commit 7e833a2
Show file tree
Hide file tree
Showing 4 changed files with 201 additions and 97 deletions.
21 changes: 21 additions & 0 deletions sleap/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import cattr
import h5py as h5
import numpy as np
import datetime
from sklearn.model_selection import train_test_split

try:
Expand Down Expand Up @@ -2028,6 +2029,26 @@ def export(self, filename: str):

SleapAnalysisAdaptor.write(filename, self)

def export_nwb(
self,
filename: str,
overwrite: bool = False,
session_description: str = "Processed SLEAP pose data",
identifier: Optional[str] = None,
session_start_time: Optional[datetime.datetime] = None,
):
from sleap.io.format.ndx_pose import NDXPoseAdaptor

NDXPoseAdaptor.write(
NDXPoseAdaptor,
filename=filename,
labels=self,
overwrite=overwrite,
session_description=session_description,
identifier=identifier,
session_start_time=session_start_time,
)

@classmethod
def load_json(cls, filename: str, *args, **kwargs) -> "Labels":
from .format import read
Expand Down
208 changes: 132 additions & 76 deletions sleap/io/format/ndx_pose.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
import datetime
import re
import numpy as np
import uuid

from pathlib import PurePath
from typing import List
from pathlib import Path, PurePath
from typing import List, Optional
from pynwb import NWBFile, NWBHDF5IO, ProcessingModule
from ndx_pose import PoseEstimationSeries, PoseEstimation

Expand Down Expand Up @@ -68,7 +69,9 @@ def read(self, file: FileHandle) -> Labels:
nwb_file = read_nwbfile.processing

# Get list of videos
video_keys: List[str] = list(nwb_file.keys())
video_keys: List[str] = [
key for key in nwb_file.keys() if "SLEAP_VIDEO" in key
]
video_tracks = dict()

# Get track keys
Expand Down Expand Up @@ -164,7 +167,15 @@ def read(self, file: FileHandle) -> Labels:
labels = Labels(lfs)
return labels

def write(self, filename: str, labels: Labels):
def write(
self,
filename: str,
labels: Labels,
overwrite: bool = False,
session_description: str = "Processed SLEAP pose data",
identifier: Optional[str] = None,
session_start_time: Optional[datetime.datetime] = None,
):
"""Write all `PredictedInstance` objects in a `Labels` object to an NWB file.
Use `Labels.numpy` to create a `pynwb.NWBFile` with a separate
Expand Down Expand Up @@ -198,97 +209,142 @@ def write(self, filename: str, labels: Labels):
Args:
filename: Output path for the NWB format file.
labels: The `Labels` object to covert to a NWB format file.
overwrite: Boolean that overwrites existing NWB file if True. If False, data
will be appended to existing NWB file.
session_description: Description for entire project. Stored under
NWBFile "session_description" key. If appending data to a preexisting
file, then the session_description will not be used.
identifier: Unique identifier for project. If no identifier is
specified, then will generate a GUID. If appending data to a
preexisting file, then the identifier will not be used.
session_start_time: THe datetime associated with the project. If no
session_start_time is given, then the current datetime will be used. If
appending data to a preexisting file, then the session_start_time will
not be used.
Returns:
A `pynwb.NWBFile` with a separate `pynwb.ProcessingModule` for each
`Video` in the `Labels` object.
"""

skeleton = labels.skeleton

# Check that this project contains predicted instances
if len(labels.predicted_instances) == 0:
raise TypeError(
"Only predicted instances are written to the NWB format. "
"This project has no predicted instances"
"This project has no predicted instances."
)

print(f"\nCreating NWB file...")
nwb_file = NWBFile(
session_description="session_description",
identifier="identifier",
session_start_time=datetime.datetime.now(datetime.timezone.utc),
)
# Set optional kwargs if not specified by user
if session_start_time is None:
session_start_time = datetime.datetime.now(datetime.timezone.utc)
identifier = str(uuid.uuid4()) if identifier is None else identifier

for video_idx, video in enumerate(labels.videos):
# Create new processing module for each video
video_fn = PurePath(video.backend.filename)
nwb_processing_module = nwb_file.create_processing_module(
name=f"{video_idx:03}_{video_fn.stem}",
description=f"Processed SLEAP pose data for {video_fn.name} with "
f"{skeleton.name} skeleton.",
)

# Get tracks for each video
video_lfs = labels.get(video)
untracked = all(
[inst.track is None for lf in video_lfs for inst in lf.instances]
)
tracks_numpy = labels.numpy(
video=video,
all_frames=True,
untracked=untracked,
return_confidence=True,
)
n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
timestamps = np.arange(n_frames)
for track_idx in list(range(n_tracks)):
pose_estimation_series: List[PoseEstimationSeries] = []

for node_idx, node in enumerate(skeleton.nodes):

# Create instance of PoseEstimationSeries for each node
data = tracks_numpy[:, track_idx, node_idx, :2]
confidence = tracks_numpy[:, track_idx, node_idx, 2]

pose_estimation_series.append(
PoseEstimationSeries(
name=f"{node.name}",
description=f"Sequential trajectory of {node.name}.",
data=data,
unit="pixels",
reference_frame="No reference.",
timestamps=timestamps,
confidence=confidence,
confidence_definition="Point-wise confidence scores.",
)
try:
io = None
if Path(filename).exists() and not overwrite:
# Append to file if it exists and we do not want to overwrite
print(f"\nOpening existing NWB file...")
io = NWBHDF5IO(filename, mode="a", load_namespaces=True)
nwb_file = io.read()
else:
# If file does not exist or we want to overwrite, create new file
if not overwrite:
print(f"\nCould not find the file specified: {filename}")
print(f"\nCreating NWB file...")
nwb_file = NWBFile(
session_description=session_description,
identifier=identifier,
session_start_time=session_start_time,
)
io = NWBHDF5IO(filename, mode="w")

skeleton = labels.skeleton

for video_idx, video in enumerate(labels.videos):
# Create new processing module for each video
video_fn = PurePath(video.backend.filename)
try:
name = f"SLEAP_VIDEO_{video_idx:03}_{video_fn.stem}"
nwb_processing_module = nwb_file.create_processing_module(
name=name,
description=f"{session_description} for {video_fn.name} with "
f"{skeleton.name} skeleton.",
)
except ValueError:
# Cannot overwrite or delete processing modules
print(
f"Processing module for {video_fn.name} already exists... "
f"Skipping: {name}"
)
continue

# Combine each node's PoseEstimationSeries to create a PoseEstimation
name_prefix = "untracked" if untracked else "track"
pose_estimation = PoseEstimation(
name=f"{name_prefix}{track_idx:03}",
pose_estimation_series=pose_estimation_series,
description=(
f"Estimated positions of {skeleton.name} in video {video_fn} "
f"using SLEAP."
),
original_videos=[f"{video_fn}"],
labeled_videos=[f"{video_fn}"],
dimensions=np.array([[video.backend.height, video.backend.width]]),
scorer=str(labels.provenance),
source_software="SLEAP",
source_software_version=f"{sleap.__version__}",
nodes=skeleton.node_names,
edges=skeleton.edge_inds,
# Get tracks for each video
video_lfs = labels.get(video)
untracked = all(
[inst.track is None for lf in video_lfs for inst in lf.instances]
)
tracks_numpy = labels.numpy(
video=video,
all_frames=True,
untracked=untracked,
return_confidence=True,
)
n_frames, n_tracks, n_nodes, _ = tracks_numpy.shape
timestamps = np.arange(n_frames)
for track_idx in list(range(n_tracks)):
pose_estimation_series: List[PoseEstimationSeries] = []

for node_idx, node in enumerate(skeleton.nodes):

# Create instance of PoseEstimationSeries for each node
data = tracks_numpy[:, track_idx, node_idx, :2]
confidence = tracks_numpy[:, track_idx, node_idx, 2]

pose_estimation_series.append(
PoseEstimationSeries(
name=f"{node.name}",
description=f"Sequential trajectory of {node.name}.",
data=data,
unit="pixels",
reference_frame="No reference.",
timestamps=timestamps,
confidence=confidence,
confidence_definition="Point-wise confidence scores.",
)
)

# Combine each node's PoseEstimationSeries to create a PoseEstimation
name_prefix = "untracked" if untracked else "track"
pose_estimation = PoseEstimation(
name=f"{name_prefix}{track_idx:03}",
pose_estimation_series=pose_estimation_series,
description=(
f"Estimated positions of {skeleton.name} in video {video_fn} "
f"using SLEAP."
),
original_videos=[f"{video_fn}"],
labeled_videos=[f"{video_fn}"],
dimensions=np.array(
[[video.backend.height, video.backend.width]]
),
scorer=str(labels.provenance),
source_software="SLEAP",
source_software_version=f"{sleap.__version__}",
nodes=skeleton.node_names,
edges=skeleton.edge_inds,
)

# Create a processing module for each
nwb_processing_module.add(pose_estimation)
# Create a processing module for each
nwb_processing_module.add(pose_estimation)

path = filename
with NWBHDF5IO(path, mode="w") as io:
io.write(nwb_file)

except Exception as e:
raise e

finally:
if io is not None:
io.close()

print(f"Finished writing NWB file to {filename}\n")
16 changes: 15 additions & 1 deletion tests/io/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os
import pytest
import numpy as np
from pathlib import Path
from pathlib import Path, PurePath

import sleap
from sleap.skeleton import Skeleton
from sleap.instance import Instance, Point, LabeledFrame, PredictedInstance, Track
from sleap.io.video import Video, MediaVideo
from sleap.io.dataset import Labels, load_file
from sleap.io.legacy import load_labels_json_old
from sleap.io.format.ndx_pose import NDXPoseAdaptor
from sleap.io.format import filehandle
from sleap.gui.suggestions import VideoFrameSuggestions, SuggestionFrame
from tests.io.test_formats import assert_read_labels_match

TEST_H5_DATASET = "tests/data/hdf5_format_v1/training.scale=0.50,sigma=10.h5"

Expand Down Expand Up @@ -1493,3 +1496,14 @@ def test_remove_untracked_instances(min_tracks_2node_labels):
# Test function with remove_empty_frames=True
labels.remove_untracked_instances(remove_empty_frames=True)
assert all([len(lf.instances) > 0 for lf in labels.labeled_frames])


def test_export_nwb(centered_pair_predictions: Labels, tmpdir):
filename = str(PurePath(tmpdir, "ndx_pose_test.nwb"))

# Export to NWB file
centered_pair_predictions.export_nwb(filename)

# Read from NWB file
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
assert_read_labels_match(centered_pair_predictions, read_labels)
53 changes: 33 additions & 20 deletions tests/io/test_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,26 +320,7 @@ def test_tracking_scores(tmpdir, centered_pair_predictions_slp_path):
assert hasattr(instance, "tracking_score")


def test_nwb(centered_pair_predictions: Labels, small_robot_mp4_vid: Video, tmpdir):
"""Test that `Labels` can be written to and recreated from an NWB file."""
labels = centered_pair_predictions
filename = str(PurePath(tmpdir, "ndx_pose_test.nwb"))

# Add another video with an untracked PredictedInstance
labels.videos.append(small_robot_mp4_vid)
pred_instance: PredictedInstance = PredictedInstance.from_instance(
labels[0].instances[0], score=5
)
pred_instance.track = None
lf = LabeledFrame(video=small_robot_mp4_vid, frame_idx=6, instances=[pred_instance])
labels.append(lf)

# Write to NWB file
NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels)

# Read from NWB file
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))

def assert_read_labels_match(labels, read_labels):
# Labeled Frames
assert len(read_labels.labeled_frames) == len(labels.labeled_frames)

Expand Down Expand Up @@ -380,6 +361,38 @@ def test_nwb(centered_pair_predictions: Labels, small_robot_mp4_vid: Video, tmpd
assert read_labels.skeleton.edge_inds == labels.skeleton.edge_inds
assert len(read_labels.tracks) == len(labels.tracks)


def test_nwb(
centered_pair_predictions: Labels,
small_robot_mp4_vid: Video,
tmpdir,
):
"""Test that `Labels` can be written to and recreated from an NWB file."""

labels = centered_pair_predictions
filename = str(PurePath(tmpdir, "ndx_pose_test.nwb"))

# Add another video with an untracked PredictedInstance
labels.videos.append(small_robot_mp4_vid)
pred_instance: PredictedInstance = PredictedInstance.from_instance(
labels[0].instances[0], score=5
)
pred_instance.track = None
lf = LabeledFrame(video=small_robot_mp4_vid, frame_idx=6, instances=[pred_instance])
labels.append(lf)

# Write to NWB file
NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels)

# Read from NWB file
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
assert_read_labels_match(labels, read_labels)

# Append to NWB File (no changes expected)
NDXPoseAdaptor.write(NDXPoseAdaptor, filename, labels)
read_labels = NDXPoseAdaptor.read(NDXPoseAdaptor, filehandle.FileHandle(filename))
assert_read_labels_match(labels, read_labels)

# Project with no predicted instances
labels.instances = []
with pytest.raises(TypeError):
Expand Down

0 comments on commit 7e833a2

Please sign in to comment.