Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Option to Export CSV #1438

Merged
merged 4 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions sleap/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,20 @@ def add_submenu_choices(menu, title, options, key):
lambda: self.commands.exportAnalysisFile(all_videos=True),
)

export_csv_menu = fileMenu.addMenu("Export Analysis CSV...")
add_menu_item(
export_csv_menu,
"export_csv_current",
"Current Video...",
self.commands.exportCSVFile,
)
add_menu_item(
export_csv_menu,
"export_csv_all",
"All Videos...",
lambda: self.commands.exportCSVFile(all_videos=True),
)

add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB)

fileMenu.addSeparator()
Expand Down
44 changes: 34 additions & 10 deletions sleap/gui/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class which inherits from `AppCommand` (or a more specialized class such as
import cv2
import attr
from qtpy import QtCore, QtWidgets, QtGui
from qtpy.QtWidgets import QMessageBox, QProgressDialog

from sleap.util import get_package_file
from sleap.skeleton import Node, Skeleton
Expand All @@ -51,6 +50,7 @@ class which inherits from `AppCommand` (or a more specialized class such as
from sleap.io.convert import default_analysis_filename
from sleap.io.dataset import Labels
from sleap.io.format.adaptor import Adaptor
from sleap.io.format.csv import CSVAdaptor
from sleap.io.format.ndx_pose import NDXPoseAdaptor
from sleap.gui.dialogs.delete import DeleteDialog
from sleap.gui.dialogs.importvideos import ImportVideos
Expand Down Expand Up @@ -331,7 +331,11 @@ def saveProjectAs(self):

def exportAnalysisFile(self, all_videos: bool = False):
"""Shows gui for exporting analysis h5 file."""
self.execute(ExportAnalysisFile, all_videos=all_videos)
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False)

def exportCSVFile(self, all_videos: bool = False):
"""Shows gui for exporting analysis csv file."""
self.execute(ExportAnalysisFile, all_videos=all_videos, csv=True)

def exportNWB(self):
"""Show gui for exporting nwb file."""
Expand Down Expand Up @@ -1130,13 +1134,20 @@ class ExportAnalysisFile(AppCommand):
}
export_filter = ";;".join(export_formats.keys())

export_formats_csv = {
"CSV (*.csv)": "csv",
}
export_filter_csv = ";;".join(export_formats_csv.keys())

@classmethod
def do_action(cls, context: CommandContext, params: dict):
from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor
from sleap.io.format.nix import NixAdaptor

for output_path, video in params["analysis_videos"]:
if Path(output_path).suffix[1:] == "nix":
if params["csv"]:
adaptor = CSVAdaptor
elif Path(output_path).suffix[1:] == "nix":
adaptor = NixAdaptor
else:
adaptor = SleapAnalysisAdaptor
Expand All @@ -1149,18 +1160,24 @@ def do_action(cls, context: CommandContext, params: dict):

@staticmethod
def ask(context: CommandContext, params: dict) -> bool:
def ask_for_filename(default_name: str) -> str:
def ask_for_filename(default_name: str, csv: bool) -> str:
"""Allow user to specify the filename"""
filter = (
ExportAnalysisFile.export_filter_csv
if csv
else ExportAnalysisFile.export_filter
)
filename, selected_filter = FileDialog.save(
context.app,
caption="Export Analysis File...",
dir=default_name,
filter=ExportAnalysisFile.export_filter,
filter=filter,
)
return filename

# Ensure labels has labeled frames
labels = context.labels
is_csv = params["csv"]
if len(labels.labeled_frames) == 0:
raise ValueError("No labeled frames in project. Nothing to export.")

Expand All @@ -1178,7 +1195,7 @@ def ask_for_filename(default_name: str) -> str:
# Specify (how to get) the output filename
default_name = context.state["filename"] or "labels"
fn = PurePath(default_name)
file_extension = "h5"
file_extension = "csv" if is_csv else "h5"
if len(videos) == 1:
# Allow user to specify the filename
use_default = False
Expand All @@ -1191,18 +1208,23 @@ def ask_for_filename(default_name: str) -> str:
caption="Select Folder to Export Analysis Files...",
dir=str(fn.parent),
)
if len(ExportAnalysisFile.export_formats) > 1:
export_format = (
ExportAnalysisFile.export_formats_csv
if is_csv
else ExportAnalysisFile.export_formats
)
if len(export_format) > 1:
item, ok = QtWidgets.QInputDialog.getItem(
context.app,
"Select export format",
"Available export formats",
list(ExportAnalysisFile.export_formats.keys()),
list(export_format.keys()),
0,
False,
)
if not ok:
return False
file_extension = ExportAnalysisFile.export_formats[item]
file_extension = export_format[item]
if len(dirname) == 0:
return False

Expand All @@ -1219,7 +1241,9 @@ def ask_for_filename(default_name: str) -> str:
format_suffix=file_extension,
)

filename = default_name if use_default else ask_for_filename(default_name)
filename = (
default_name if use_default else ask_for_filename(default_name, is_csv)
)
# Check that filename is valid and create list of video / output paths
if len(filename) != 0:
analysis_videos.append(video)
Expand Down
74 changes: 72 additions & 2 deletions sleap/info/write_tracking_h5.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Generate an HDF5 file with track occupancy and point location data.
"""Generate an HDF5 or CSV file with track occupancy and point location data.

Ignores tracks that are entirely empty. By default will also ignore
empty frames from the beginning and end of video, although
Expand Down Expand Up @@ -29,6 +29,7 @@
import json
import h5py as h5
import numpy as np
import pandas as pd

from typing import Any, Dict, List, Tuple, Union

Expand Down Expand Up @@ -286,12 +287,77 @@ def write_occupancy_file(
print(f"Saved as {output_path}")


def write_csv_file(output_path, data_dict):

"""Write CSV file with data from given dictionary.

Args:
output_path: Path of HDF5 file.
data_dict: Dictionary with data to save. Keys are dataset names,
values are the data.

Returns:
None
"""

if data_dict["tracks"].shape[-1] == 0:
print(f"No tracks to export in {data_dict['video_path']}. Skipping the export")
return

data_dict["node_names"] = [s.decode() for s in data_dict["node_names"]]
data_dict["track_names"] = [s.decode() for s in data_dict["track_names"]]
data_dict["track_occupancy"] = np.transpose(data_dict["track_occupancy"]).astype(
bool
)

# Find frames with at least one animal tracked.
valid_frame_idxs = np.argwhere(data_dict["track_occupancy"].any(axis=1)).flatten()

tracks = []
for frame_idx in valid_frame_idxs:
frame_tracks = data_dict["tracks"][frame_idx]

for i in range(frame_tracks.shape[-1]):
pts = frame_tracks[..., i]
conf_scores = data_dict["point_scores"][frame_idx][..., i]

if np.isnan(pts).all():
# Skip if animal wasn't detected in the current frame.
continue
if data_dict["track_names"]:
track = data_dict["track_names"][i]
else:
track = None

instance_score = data_dict["instance_scores"][frame_idx][i]

detection = {
"track": track,
"frame_idx": frame_idx,
"instance.score": instance_score,
}

# Coordinates for each body part.
for node_name, score, (x, y) in zip(
data_dict["node_names"], conf_scores, pts
):
detection[f"{node_name}.x"] = x
detection[f"{node_name}.y"] = y
detection[f"{node_name}.score"] = score

tracks.append(detection)

tracks = pd.DataFrame(tracks)
tracks.to_csv(output_path, index=False)


def main(
labels: Labels,
output_path: str,
labels_path: str = None,
all_frames: bool = True,
video: Video = None,
csv: bool = False,
):
"""Writes HDF5 file with matrices of track occupancy and coordinates.

Expand All @@ -306,6 +372,7 @@ def main(
video: The :py:class:`Video` from which to get data. If no `video` is specified,
then the first video in `source_object` videos list will be used. If there
are no labeled frames in the `video`, then no output file will be written.
csv: Bool to save the analysis as a csv file if set to True

Returns:
None
Expand Down Expand Up @@ -367,7 +434,10 @@ def main(
provenance=json.dumps(labels.provenance), # dict cannot be written to hdf5.
)

write_occupancy_file(output_path, data_dict, transpose=True)
if csv:
write_csv_file(output_path, data_dict)
else:
write_occupancy_file(output_path, data_dict, transpose=True)


if __name__ == "__main__":
Expand Down
70 changes: 70 additions & 0 deletions sleap/io/format/csv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Adaptor for writing SLEAP analysis as csv."""

from sleap.io import format

from sleap import Labels, Video
from typing import Optional, Callable, List, Text, Union


class CSVAdaptor(format.adaptor.Adaptor):
FORMAT_ID = 1.0

# 1.0 initial implementation

@property
def handles(self):
return format.adaptor.SleapObjectType.labels

@property
def default_ext(self):
return "csv"

@property
def all_exts(self):
return ["csv", "xlsx"]

@property
def name(self):
return "CSV"

def can_read_file(self, file: format.filehandle.FileHandle):
return False

def can_write_filename(self, filename: str):
return self.does_match_ext(filename)

def does_read(self) -> bool:
return False

def does_write(self) -> bool:
return True

@classmethod
def write(
cls,
filename: str,
source_object: Labels,
source_path: str = None,
video: Video = None,
):
"""Writes csv file for :py:class:`Labels` `source_object`.

Args:
filename: The filename for the output file.
source_object: The :py:class:`Labels` from which to get data from.
source_path: Path for the labels object
video: The :py:class:`Video` from which toget data from. If no `video` is
specified, then the first video in `source_object` videos list will be
used. If there are no :py:class:`Labeled Frame`s in the `video`, then no
analysis file will be written.
"""
from sleap.info.write_tracking_h5 import main as write_analysis

write_analysis(
labels=source_object,
output_path=filename,
labels_path=source_path,
all_frames=True,
video=video,
csv=True,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
track,frame_idx,instance.score,A.x,A.y,A.score,B.x,B.y,B.score
,0,nan,205.9300539013689,187.88964024221963,,278.63521449272383,203.3658657346604,
8 changes: 8 additions & 0 deletions tests/fixtures/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
TEST_HDF5_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.h5"
TEST_SLP_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.slp"
TEST_MIN_DANCE_LABELS = "tests/data/slp_hdf5/dance.mp4.labels.slp"
TEST_CSV_PREDICTIONS = (
"tests/data/csv_format/minimal_instance.000_centered_pair_low_quality.analysis.csv"
)


@pytest.fixture
Expand Down Expand Up @@ -247,6 +250,11 @@ def centered_pair_predictions_hdf5_path():
return TEST_HDF5_PREDICTIONS


@pytest.fixture
def minimal_instance_predictions_csv_path():
return TEST_CSV_PREDICTIONS


@pytest.fixture
def centered_pair_predictions_slp_path():
return TEST_SLP_PREDICTIONS
Expand Down
Loading