diff --git a/sleap/gui/app.py b/sleap/gui/app.py index b1c7880bb..6f5733830 100644 --- a/sleap/gui/app.py +++ b/sleap/gui/app.py @@ -484,6 +484,20 @@ def add_submenu_choices(menu, title, options, key): lambda: self.commands.exportAnalysisFile(all_videos=True), ) + export_csv_menu = fileMenu.addMenu("Export Analysis CSV...") + add_menu_item( + export_csv_menu, + "export_csv_current", + "Current Video...", + self.commands.exportCSVFile, + ) + add_menu_item( + export_csv_menu, + "export_csv_all", + "All Videos...", + lambda: self.commands.exportCSVFile(all_videos=True), + ) + add_menu_item(fileMenu, "export_nwb", "Export NWB...", self.commands.exportNWB) fileMenu.addSeparator() diff --git a/sleap/gui/commands.py b/sleap/gui/commands.py index 33fc75a4a..127f2ebb9 100644 --- a/sleap/gui/commands.py +++ b/sleap/gui/commands.py @@ -42,7 +42,6 @@ class which inherits from `AppCommand` (or a more specialized class such as import cv2 import attr from qtpy import QtCore, QtWidgets, QtGui -from qtpy.QtWidgets import QMessageBox, QProgressDialog from sleap.util import get_package_file from sleap.skeleton import Node, Skeleton @@ -51,6 +50,7 @@ class which inherits from `AppCommand` (or a more specialized class such as from sleap.io.convert import default_analysis_filename from sleap.io.dataset import Labels from sleap.io.format.adaptor import Adaptor +from sleap.io.format.csv import CSVAdaptor from sleap.io.format.ndx_pose import NDXPoseAdaptor from sleap.gui.dialogs.delete import DeleteDialog from sleap.gui.dialogs.importvideos import ImportVideos @@ -331,7 +331,11 @@ def saveProjectAs(self): def exportAnalysisFile(self, all_videos: bool = False): """Shows gui for exporting analysis h5 file.""" - self.execute(ExportAnalysisFile, all_videos=all_videos) + self.execute(ExportAnalysisFile, all_videos=all_videos, csv=False) + + def exportCSVFile(self, all_videos: bool = False): + """Shows gui for exporting analysis csv file.""" + self.execute(ExportAnalysisFile, 
all_videos=all_videos, csv=True) def exportNWB(self): """Show gui for exporting nwb file.""" @@ -1130,13 +1134,20 @@ class ExportAnalysisFile(AppCommand): } export_filter = ";;".join(export_formats.keys()) + export_formats_csv = { + "CSV (*.csv)": "csv", + } + export_filter_csv = ";;".join(export_formats_csv.keys()) + @classmethod def do_action(cls, context: CommandContext, params: dict): from sleap.io.format.sleap_analysis import SleapAnalysisAdaptor from sleap.io.format.nix import NixAdaptor for output_path, video in params["analysis_videos"]: - if Path(output_path).suffix[1:] == "nix": + if params["csv"]: + adaptor = CSVAdaptor + elif Path(output_path).suffix[1:] == "nix": adaptor = NixAdaptor else: adaptor = SleapAnalysisAdaptor @@ -1149,18 +1160,24 @@ def do_action(cls, context: CommandContext, params: dict): @staticmethod def ask(context: CommandContext, params: dict) -> bool: - def ask_for_filename(default_name: str) -> str: + def ask_for_filename(default_name: str, csv: bool) -> str: """Allow user to specify the filename""" + filter = ( + ExportAnalysisFile.export_filter_csv + if csv + else ExportAnalysisFile.export_filter + ) filename, selected_filter = FileDialog.save( context.app, caption="Export Analysis File...", dir=default_name, - filter=ExportAnalysisFile.export_filter, + filter=filter, ) return filename # Ensure labels has labeled frames labels = context.labels + is_csv = params["csv"] if len(labels.labeled_frames) == 0: raise ValueError("No labeled frames in project. 
Nothing to export.") @@ -1178,7 +1195,7 @@ def ask_for_filename(default_name: str) -> str: # Specify (how to get) the output filename default_name = context.state["filename"] or "labels" fn = PurePath(default_name) - file_extension = "h5" + file_extension = "csv" if is_csv else "h5" if len(videos) == 1: # Allow user to specify the filename use_default = False @@ -1191,18 +1208,23 @@ def ask_for_filename(default_name: str) -> str: caption="Select Folder to Export Analysis Files...", dir=str(fn.parent), ) - if len(ExportAnalysisFile.export_formats) > 1: + export_format = ( + ExportAnalysisFile.export_formats_csv + if is_csv + else ExportAnalysisFile.export_formats + ) + if len(export_format) > 1: item, ok = QtWidgets.QInputDialog.getItem( context.app, "Select export format", "Available export formats", - list(ExportAnalysisFile.export_formats.keys()), + list(export_format.keys()), 0, False, ) if not ok: return False - file_extension = ExportAnalysisFile.export_formats[item] + file_extension = export_format[item] if len(dirname) == 0: return False @@ -1219,7 +1241,9 @@ def ask_for_filename(default_name: str) -> str: format_suffix=file_extension, ) - filename = default_name if use_default else ask_for_filename(default_name) + filename = ( + default_name if use_default else ask_for_filename(default_name, is_csv) + ) # Check that filename is valid and create list of video / output paths if len(filename) != 0: analysis_videos.append(video) diff --git a/sleap/info/write_tracking_h5.py b/sleap/info/write_tracking_h5.py index 8bd583230..2b714eeb5 100644 --- a/sleap/info/write_tracking_h5.py +++ b/sleap/info/write_tracking_h5.py @@ -1,4 +1,4 @@ -"""Generate an HDF5 file with track occupancy and point location data. +"""Generate an HDF5 or CSV file with track occupancy and point location data. Ignores tracks that are entirely empty. 
By default will also ignore empty frames from the beginning and end of video, although @@ -29,6 +29,7 @@ import json import h5py as h5 import numpy as np +import pandas as pd from typing import Any, Dict, List, Tuple, Union @@ -286,12 +287,77 @@ def write_occupancy_file( print(f"Saved as {output_path}") +def write_csv_file(output_path, data_dict): + + """Write CSV file with data from given dictionary. + + Args: + output_path: Path of the output CSV file. + data_dict: Dictionary with data to save. Keys are dataset names, + values are the data. + + Returns: + None + """ + + if data_dict["tracks"].shape[-1] == 0: + print(f"No tracks to export in {data_dict['video_path']}. Skipping the export") + return + + data_dict["node_names"] = [s.decode() for s in data_dict["node_names"]] + data_dict["track_names"] = [s.decode() for s in data_dict["track_names"]] + data_dict["track_occupancy"] = np.transpose(data_dict["track_occupancy"]).astype( + bool + ) + + # Find frames with at least one animal tracked. + valid_frame_idxs = np.argwhere(data_dict["track_occupancy"].any(axis=1)).flatten() + + tracks = [] + for frame_idx in valid_frame_idxs: + frame_tracks = data_dict["tracks"][frame_idx] + + for i in range(frame_tracks.shape[-1]): + pts = frame_tracks[..., i] + conf_scores = data_dict["point_scores"][frame_idx][..., i] + + if np.isnan(pts).all(): + # Skip if animal wasn't detected in the current frame. + continue + if data_dict["track_names"]: + track = data_dict["track_names"][i] + else: + track = None + + instance_score = data_dict["instance_scores"][frame_idx][i] + + detection = { + "track": track, + "frame_idx": frame_idx, + "instance.score": instance_score, + } + + # Coordinates for each body part.
+ for node_name, score, (x, y) in zip( + data_dict["node_names"], conf_scores, pts + ): + detection[f"{node_name}.x"] = x + detection[f"{node_name}.y"] = y + detection[f"{node_name}.score"] = score + + tracks.append(detection) + + tracks = pd.DataFrame(tracks) + tracks.to_csv(output_path, index=False) + + def main( labels: Labels, output_path: str, labels_path: str = None, all_frames: bool = True, video: Video = None, + csv: bool = False, ): """Writes HDF5 file with matrices of track occupancy and coordinates. @@ -306,6 +372,7 @@ def main( video: The :py:class:`Video` from which to get data. If no `video` is specified, then the first video in `source_object` videos list will be used. If there are no labeled frames in the `video`, then no output file will be written. + csv: Bool to save the analysis as a csv file if set to True Returns: None @@ -367,7 +434,10 @@ def main( provenance=json.dumps(labels.provenance), # dict cannot be written to hdf5. ) - write_occupancy_file(output_path, data_dict, transpose=True) + if csv: + write_csv_file(output_path, data_dict) + else: + write_occupancy_file(output_path, data_dict, transpose=True) if __name__ == "__main__": diff --git a/sleap/io/format/csv.py b/sleap/io/format/csv.py new file mode 100644 index 000000000..4640ee117 --- /dev/null +++ b/sleap/io/format/csv.py @@ -0,0 +1,70 @@ +"""Adaptor for writing SLEAP analysis as csv.""" + +from sleap.io import format + +from sleap import Labels, Video +from typing import Optional, Callable, List, Text, Union + + +class CSVAdaptor(format.adaptor.Adaptor): + FORMAT_ID = 1.0 + + # 1.0 initial implementation + + @property + def handles(self): + return format.adaptor.SleapObjectType.labels + + @property + def default_ext(self): + return "csv" + + @property + def all_exts(self): + return ["csv", "xlsx"] + + @property + def name(self): + return "CSV" + + def can_read_file(self, file: format.filehandle.FileHandle): + return False + + def can_write_filename(self, filename: str): + return 
self.does_match_ext(filename) + + def does_read(self) -> bool: + return False + + def does_write(self) -> bool: + return True + + @classmethod + def write( + cls, + filename: str, + source_object: Labels, + source_path: str = None, + video: Video = None, + ): + """Writes csv file for :py:class:`Labels` `source_object`. + + Args: + filename: The filename for the output file. + source_object: The :py:class:`Labels` from which to get data. + source_path: Path for the labels object + video: The :py:class:`Video` from which to get data. If no `video` is + specified, then the first video in `source_object` videos list will be + used. If there are no :py:class:`LabeledFrame`s in the `video`, then no + analysis file will be written. + """ + from sleap.info.write_tracking_h5 import main as write_analysis + + write_analysis( + labels=source_object, + output_path=filename, + labels_path=source_path, + all_frames=True, + video=video, + csv=True, + ) diff --git a/tests/data/csv_format/minimal_instance.000_centered_pair_low_quality.analysis.csv b/tests/data/csv_format/minimal_instance.000_centered_pair_low_quality.analysis.csv new file mode 100644 index 000000000..83d3259be --- /dev/null +++ b/tests/data/csv_format/minimal_instance.000_centered_pair_low_quality.analysis.csv @@ -0,0 +1,2 @@ +track,frame_idx,instance.score,A.x,A.y,A.score,B.x,B.y,B.score +,0,nan,205.9300539013689,187.88964024221963,,278.63521449272383,203.3658657346604, diff --git a/tests/fixtures/datasets.py b/tests/fixtures/datasets.py index b8d438fb6..801fcc092 100644 --- a/tests/fixtures/datasets.py +++ b/tests/fixtures/datasets.py @@ -26,6 +26,9 @@ TEST_HDF5_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.h5" TEST_SLP_PREDICTIONS = "tests/data/hdf5_format_v1/centered_pair_predictions.slp" TEST_MIN_DANCE_LABELS = "tests/data/slp_hdf5/dance.mp4.labels.slp" +TEST_CSV_PREDICTIONS = ( + "tests/data/csv_format/minimal_instance.000_centered_pair_low_quality.analysis.csv" +)
@pytest.fixture @@ -247,6 +250,11 @@ def centered_pair_predictions_hdf5_path(): return TEST_HDF5_PREDICTIONS +@pytest.fixture +def minimal_instance_predictions_csv_path(): + return TEST_CSV_PREDICTIONS + + @pytest.fixture def centered_pair_predictions_slp_path(): return TEST_SLP_PREDICTIONS diff --git a/tests/gui/test_commands.py b/tests/gui/test_commands.py index fa3ff3d9c..bb708354b 100644 --- a/tests/gui/test_commands.py +++ b/tests/gui/test_commands.py @@ -122,13 +122,19 @@ def ask(obj: RemoveVideo, context: CommandContext, params: dict) -> bool: assert context.state["video"] not in videos_to_remove -@pytest.mark.parametrize("out_suffix", ["h5", "nix"]) +@pytest.mark.parametrize("out_suffix", ["h5", "nix", "csv"]) def test_ExportAnalysisFile( centered_pair_predictions: Labels, + centered_pair_predictions_hdf5_path: str, small_robot_mp4_vid: Video, out_suffix: str, tmpdir, ): + if out_suffix == "csv": + csv = True + else: + csv = False + def ExportAnalysisFile_ask(context: CommandContext, params: dict): """Taken from ExportAnalysisFile.ask()""" @@ -151,7 +157,7 @@ def ask_for_filename(default_name: str) -> str: if len(videos) == 0: raise ValueError("No labeled frames in video(s). 
Nothing to export.") - default_name = context.state["filename"] or "labels" + default_name = "labels" fn = PurePath(tmpdir, default_name) if len(videos) == 1: # Allow user to specify the filename @@ -194,7 +200,7 @@ def assert_videos_written(num_videos: int, labels_path: str = None): assert Path(output_path).exists() output_paths.append(output_path) - if labels_path is not None: + if labels_path is not None and not params["csv"]: meta_reader = extract_meta_hdf5 if out_suffix == "h5" else read_nix_meta labels_key = "labels_path" if out_suffix == "h5" else "project" read_meta = meta_reader(output_path, dset_names_in=["labels_path"]) @@ -209,8 +215,20 @@ def assert_videos_written(num_videos: int, labels_path: str = None): context = CommandContext.from_labels(labels) context.state["filename"] = None + if csv: + + context.state["filename"] = centered_pair_predictions_hdf5_path + + params = {"all_videos": True, "csv": csv} + okay = ExportAnalysisFile_ask(context=context, params=params) + assert okay == True + ExportAnalysisFile.do_action(context=context, params=params) + assert_videos_written(num_videos=1, labels_path=context.state["filename"]) + + return + # Test with all_videos False (single video) - params = {"all_videos": False} + params = {"all_videos": False, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) @@ -218,7 +236,7 @@ def assert_videos_written(num_videos: int, labels_path: str = None): # Add labels path and test with all_videos True (single video) context.state["filename"] = str(tmpdir.with_name("path.to.labels")) - params = {"all_videos": True} + params = {"all_videos": True, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) @@ -227,7 +245,7 @@ def assert_videos_written(num_videos: int, labels_path: str = None): # Add a video (no labels) and test 
with all_videos True labels.add_video(small_robot_mp4_vid) - params = {"all_videos": True} + params = {"all_videos": True, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) @@ -239,7 +257,7 @@ def assert_videos_written(num_videos: int, labels_path: str = None): labels.add_instance(frame=labeled_frame, instance=instance) labels.append(labeled_frame) - params = {"all_videos": False} + params = {"all_videos": False, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) @@ -248,14 +266,14 @@ def assert_videos_written(num_videos: int, labels_path: str = None): # Add specific video and test with all_videos False context.state["videos"] = labels.videos[1] - params = {"all_videos": False} + params = {"all_videos": False, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) assert_videos_written(num_videos=1, labels_path=context.state["filename"]) # Test with all videos True - params = {"all_videos": True} + params = {"all_videos": True, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) @@ -273,7 +291,7 @@ def assert_videos_written(num_videos: int, labels_path: str = None): labels.videos[0].backend.filename = str(tmpdir / "session1" / "video.mp4") labels.videos[1].backend.filename = str(tmpdir / "session2" / "video.mp4") - params = {"all_videos": True} + params = {"all_videos": True, "csv": csv} okay = ExportAnalysisFile_ask(context=context, params=params) assert okay == True ExportAnalysisFile.do_action(context=context, params=params) @@ -284,7 +302,7 @@ def assert_videos_written(num_videos: int, labels_path: str = None): for video in all_videos: 
labels.remove_video(labels.videos[-1]) - params = {"all_videos": True} + params = {"all_videos": True, "csv": csv} with pytest.raises(ValueError): okay = ExportAnalysisFile_ask(context=context, params=params) diff --git a/tests/io/test_formats.py b/tests/io/test_formats.py index b28de176e..a89bf60d7 100644 --- a/tests/io/test_formats.py +++ b/tests/io/test_formats.py @@ -2,6 +2,7 @@ from pathlib import Path, PurePath import numpy as np +import pandas as pd from numpy.testing import assert_array_equal import pytest import nixio @@ -17,6 +18,7 @@ from sleap.gui.commands import ImportAlphaTracker from sleap.gui.app import MainWindow from sleap.gui.state import GuiState +from sleap.info.write_tracking_h5 import get_nodes_as_np_strings def test_text_adaptor(tmpdir): @@ -126,6 +128,24 @@ def test_hdf5_v1_filehandle(centered_pair_predictions_hdf5_path): ) +def test_csv(tmpdir, min_labels_slp, minimal_instance_predictions_csv_path): + from sleap.info.write_tracking_h5 import main as write_analysis + + filename_csv = str(tmpdir + "\\analysis.csv") + write_analysis(min_labels_slp, output_path=filename_csv, all_frames=True, csv=True) + + labels_csv = pd.read_csv(filename_csv) + + csv_predictions = pd.read_csv(minimal_instance_predictions_csv_path) + + assert labels_csv.equals(csv_predictions) + + labels = min_labels_slp + + # check number of cols + assert len(labels_csv.columns) - 3 == len(get_nodes_as_np_strings(labels)) * 3 + + def test_analysis_hdf5(tmpdir, centered_pair_predictions): from sleap.info.write_tracking_h5 import main as write_analysis