diff --git a/docs/source/index.rst b/docs/source/index.rst index 55fccd0..dd713f1 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,10 +23,12 @@ Loading data is simple with the ``snub.io`` module. For example the following co snub.io.add_spikeplot(project_directory, 'my_ephys_data', spike_times, spike_labels) +We also support automatic conversion of NWB files to SNUB projects for a limited set of NWB neurodata types. + + SNUB Documentation ------------------ - .. toctree:: :maxdepth: 2 @@ -34,6 +36,8 @@ SNUB Documentation tutorials + nwb + snub.io diff --git a/docs/source/nwb.rst b/docs/source/nwb.rst new file mode 100644 index 0000000..1021f6b --- /dev/null +++ b/docs/source/nwb.rst @@ -0,0 +1,131 @@ +Neurodata Without Borders +========================= + +We provide a rudimentary tool for automatically generating a SNUB project from NWB files, which contain raw and processed data from neuroscience recordings. The data are stored hierarchically, and each component of the hierarchy has a specific neurodata type that reflects the measurement modality (e.g., ``Units`` for spike trains, ``ImageSeries`` for video, etc.). Our conversion tool generates a SNUB subplot for each supported neurodata type. Users can optionally restrict this process to a subset of the NWB hierarchy (e.g., include pose tracking while excluding electrophysiology, or include just a subset of electrophysiology measurements).
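+ +For example, the following call converts an NWB file to a SNUB project. This is a minimal sketch in which the paths and branch name are placeholders; the ``branches`` argument can be omitted to include all supported data (see the examples below for real datasets). + + .. code-block:: python + + import snub + + # Convert an NWB file to a browsable SNUB project, keeping + # only the selected branch of the NWB hierarchy + snub.io.create_project_from_nwb( + 'path/to/project_directory', 'path/to/recording.nwb', branches=['behavior'])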
+ + +Neurodata types +--------------- + +The following neurodata types are supported: + +- ``IntervalSeries`` + Contains start and end times for (possibly labeled) intervals. A SNUB trace plot is generated containing one trace per interval type. + +- ``RoiResponseSeries`` + Contains fluorescence traces for regions of interest (ROIs). A SNUB heatmap is generated containing one row per ROI. Metadata associated with each ROI is not linked in the SNUB plot. + +- ``TimeSeries`` + Contains time series in one or more dimensions. A SNUB heatmap is generated for 15 or more dimensions, and a SNUB trace plot is generated for fewer than 15 dimensions. + +- ``PoseEstimation`` + Contains pose tracking data (available via the ``ndx-pose`` extension). A SNUB trace plot is generated for each tracked body part and spatial dimension. For 3D data, a 3D pose plot is also generated. + +- ``ImageSeries`` + Contains video data. We assume that the video is stored as a separate file and that the ``ImageSeries`` object contains frame timestamps and a relative path to that file. A SNUB video plot is then generated. + +- ``LabelSeries`` + Contains discrete label time series in the form of a binary matrix with one column per label and one row per time bin (available via the ``ndx-labels`` extension). A SNUB heatmap is generated directly from this matrix. + +- ``TimeIntervals`` + Contains annotated intervals. Each interval has a start time, a stop time, and an arbitrary number of additional metadata fields. A SNUB trace plot is generated with one trace showing the start and stop times of each interval. All other metadata is ignored since it cannot be canonically represented using the currently available SNUB plot types. + +- ``Position`` + Contains position data in the form of one or more ``SpatialSeries`` objects. A SNUB trace plot is generated with traces for each spatial dimension of each constituent spatial series. + +- ``SpatialSeries`` + Contains spatial data in the form of a time series with one or more dimensions. A standalone SNUB trace plot is generated for the spatial series if it is not contained within a ``Position`` object. + +- ``Units`` + Contains spike trains for one or more units. A corresponding SNUB spike plot is generated. + +- ``Events`` + Contains a sequence of unlabeled event times (available via the ``ndx-events`` extension). A SNUB trace plot is generated with a single trace that spikes at each event time. + + +Examples +-------- + +For each example, run the first code block in a terminal and the second in a Python console or notebook. + + +A change in behavioral state switches the pattern of motor output that underlies rhythmic head and orofacial movements +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*Liao, Song-Mao; Kleinfeld, David; Rinehart, Duane; University of California San Diego (2023) Dataset for: A change in behavioral state switches the pattern of motor output that underlies rhythmic head and orofacial movements (Version 0.230515.0530) [Data set]. DANDI archive. https://doi.org/10.48324/dandi.000540/0.230515.0530* + +Includes the following SNUB-compatible neurodata types: ``TimeSeries``, ``ImageSeries`` + + .. code-block:: bash + + # Download NWB file + dandi download https://api.dandiarchive.org/api/dandisets/000540/versions/0.230515.0530/assets/94307bee-459c-424e-b3a0-1e86b23f04b2/download/ + + # Download associated video and create directory for it + dandi download https://api.dandiarchive.org/api/dandisets/000540/versions/0.230515.0530/assets/942b0806-2c8b-4289-a072-9e965884fcb6/download/ + mkdir sub-SLR087_ses-20180706_obj-14ua2bs_behavior+image + mv 9557b48e-46f0-45f2-a700-a2e15318c5bc_external_file_0.avi sub-SLR087_ses-20180706_obj-14ua2bs_behavior+image/ + + + .. code-block:: python + + import os, snub + + # Define paths + nwb_file = "sub-SLR087_ses-20180706_obj-14ua2bs_behavior+image.nwb" + name = os.path.splitext(os.path.basename(nwb_file))[0] + project_directory = os.path.join(os.path.dirname(nwb_file), f"SNUB-{name}") + + # Make SNUB plot that includes video and torso tracking + snub.io.create_project_from_nwb(project_directory, nwb_file, branches=['torso_dlc', 'ImageSeries']) + + +A Unified Framework for Dopamine Signals across Timescales +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*Kim, HyungGoo; Malik, Athar; Mikhael, John; Bech, Pol; Tsutsui-Kimura, Iku; Sun, Fangmiao; Zhang, Yajun; Li, Yulong; Watabe-Uchida, Mitsuko; Gershman, Samuel; Uchida, Naoshige (2023) A Unified Framework for Dopamine Signals across Timescales (Version draft) [Data set]. DANDI archive. https://dandiarchive.org/dandiset/000251/draft* + +Includes the following SNUB-compatible neurodata types: ``TimeSeries``, ``TimeIntervals``, ``SpatialSeries``, ``Events`` + + .. code-block:: bash + + # Download NWB file + dandi download https://api.dandiarchive.org/api/dandisets/000251/versions/draft/assets/b28fcb84-2e23-472c-913c-383151bc58ef/download/ + + + .. code-block:: python + + import os, snub + + # Define paths + nwb_file = "sub-108_ses-Ca-VS-VR-2.nwb" + name = os.path.splitext(os.path.basename(nwb_file))[0] + project_directory = os.path.join(os.path.dirname(nwb_file), f"SNUB-{name}") + + # Make SNUB plot + snub.io.create_project_from_nwb(project_directory, nwb_file) + + +Neural population dynamics during reaching +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +*Churchland, Mark; Cunningham, John P.; Kaufman, Matthew T.; Foster, Justin D.; Nuyujukian, Paul; Ryu, Stephen I.; Shenoy, Krishna V. 
(2022) Neural population dynamics during reaching (Version draft) [Data set]. DANDI archive. https://dandiarchive.org/dandiset/000070/draft* + +Includes the following SNUB-compatible neurodata types: ``Units``, ``TimeIntervals``, ``Position`` + + .. code-block:: bash + + # Download NWB file + dandi download https://api.dandiarchive.org/api/dandisets/000070/versions/draft/assets/7b95fe3a-c859-4406-b80d-e50bad775d01/download/ + + .. code-block:: python + + import os, snub + + # Define paths + nwb_file = "sub-Jenkins_ses-20090912_behavior+ecephys.nwb" + name = os.path.splitext(os.path.basename(nwb_file))[0] + project_directory = os.path.join(os.path.dirname(nwb_file), f"SNUB-{name}") + + # Make SNUB plot + snub.io.create_project_from_nwb(project_directory, nwb_file) diff --git a/paper.bib b/paper.bib index ba5f9e3..8913cc3 100644 --- a/paper.bib +++ b/paper.bib @@ -1,93 +1,124 @@ @software{vispy, - author = {Luke Campagnola and - Eric Larson and - Almar Klein and - David Hoese and - Siddharth and - Cyrille Rossant and - Adam Griffiths and - Nicolas P. Rougier and - asnt and - Kai Mühlbauer and - Alexander Taylor and - MSS and - Talley Lambert and - sylm21 and - Alex J. Champandard and - Max Hunter and - Thomas Robitaille and - Mustafa Furkan Kaptan and - Elliott Sales de Andrade and - Karl Czajkowski and - Lorenzo Gaifas and - Alessandro Bacchini and - Guillaume Favelier and - Etienne Combrisson and - ThenTech and - fschill and - Mark Harfouche and - Michael Aye and - Casper van Elteren and - Cedric GESTES}, - title = {vispy/vispy: Version 0.11.0}, - month = jul, - year = 2022, - publisher = {Zenodo}, - version = {v0.11.0}, - doi = {10.5281/zenodo.6795163}, - url = {https://doi.org/10.5281/zenodo.6795163} + author = {Luke Campagnola and + Eric Larson and + Almar Klein and + David Hoese and + Siddharth and + Cyrille Rossant and + Adam Griffiths and + Nicolas P. Rougier and + asnt and + Kai Mühlbauer and + Alexander Taylor and + MSS and + Talley Lambert and + sylm21 and + Alex J. Champandard and + Max Hunter and + Thomas Robitaille and + Mustafa Furkan Kaptan and + Elliott Sales de Andrade and + Karl Czajkowski and + Lorenzo Gaifas and + Alessandro Bacchini and + Guillaume Favelier and + Etienne Combrisson and + ThenTech and + fschill and + Mark Harfouche and + Michael Aye and + Casper van Elteren and + Cedric GESTES}, + title = {vispy/vispy: Version 0.11.0}, + month = jul, + year = 2022, + publisher = {Zenodo}, + version = {v0.11.0}, + doi = {10.5281/zenodo.6795163}, + url = {https://doi.org/10.5281/zenodo.6795163} } @article{bento, - abstract = {The study of naturalistic social behavior requires quantification of animals' interactions. This is generally done through manual annotation---a highly time-consuming and tedious process. Recent advances in computer vision enable tracking the pose (posture) of freely behaving animals. However, automatically and accurately classifying complex social behaviors remains technically challenging. We introduce the Mouse Action Recognition System (MARS), an automated pipeline for pose estimation and behavior quantification in pairs of freely interacting mice. We compare MARS's annotations to human annotations and find that MARS's pose estimation and behavior classification achieve human-level performance. We also release the pose and annotation datasets used to train MARS to serve as community benchmarks and resources. 
Finally, we introduce the Behavior Ensemble and Neural Trajectory Observatory (BENTO), a graphical user interface for analysis of multimodal neuroscience datasets. Together, MARS and BENTO provide an end-to-end pipeline for behavior data extraction and analysis in a package that is user-friendly and easily modifiable.}, - article_type = {journal}, - author = {Segalin, Cristina and Williams, Jalani and Karigo, Tomomi and Hui, May and Zelikowsky, Moriel and Sun, Jennifer J and Perona, Pietro and Anderson, David J and Kennedy, Ann}, - citation = {eLife 2021;10:e63720}, + abstract = {The study of naturalistic social behavior requires quantification of animals' interactions. This is generally done through manual annotation---a highly time-consuming and tedious process. Recent advances in computer vision enable tracking the pose (posture) of freely behaving animals. However, automatically and accurately classifying complex social behaviors remains technically challenging. We introduce the Mouse Action Recognition System (MARS), an automated pipeline for pose estimation and behavior quantification in pairs of freely interacting mice. We compare MARS's annotations to human annotations and find that MARS's pose estimation and behavior classification achieve human-level performance. We also release the pose and annotation datasets used to train MARS to serve as community benchmarks and resources. Finally, we introduce the Behavior Ensemble and Neural Trajectory Observatory (BENTO), a graphical user interface for analysis of multimodal neuroscience datasets. Together, MARS and BENTO provide an end-to-end pipeline for behavior data extraction and analysis in a package that is user-friendly and easily modifiable.}, + article_type = {journal}, + author = {Segalin, Cristina and Williams, Jalani and Karigo, Tomomi and Hui, May and Zelikowsky, Moriel and Sun, Jennifer J and Perona, Pietro and Anderson, David J and Kennedy, Ann}, + citation = {eLife 2021;10:e63720}, date-modified = {2022-10-05 13:18:49 -0400}, - doi = {10.7554/eLife.63720}, - editor = {Berman, Gordon J and Wassum, Kate M and Gal, Asaf}, - issn = {2050-084X}, - journal = {eLife}, - keywords = {social behavior, pose estimation, machine learning, computer vision, microendoscopic imaging, software}, - month = {nov}, - pages = {e63720}, - pub_date = {2021-11-30}, - publisher = {eLife Sciences Publications, Ltd}, - title = {The Mouse Action Recognition System (MARS) software pipeline for automated analysis of social behaviors in mice}, - url = {https://doi.org/10.7554/eLife.63720}, - volume = 10, - year = 2021, - bdsk-url-1 = {https://doi.org/10.7554/eLife.63720}} + doi = {10.7554/eLife.63720}, + editor = {Berman, Gordon J and Wassum, Kate M and Gal, Asaf}, + issn = {2050-084X}, + journal = {eLife}, + keywords = {social behavior, pose estimation, machine learning, computer vision, microendoscopic imaging, software}, + month = {nov}, + pages = {e63720}, + pub_date = {2021-11-30}, + publisher = {eLife Sciences Publications, Ltd}, + title = {The Mouse Action Recognition System (MARS) software pipeline for automated analysis of social behaviors in mice}, + url = {https://doi.org/10.7554/eLife.63720}, + volume = 10, + year = 2021, + bdsk-url-1 = {https://doi.org/10.7554/eLife.63720} +} @misc{rastermap, - author = {C. Stringer and M. Pachitariu}, - title = {rastermap}, - year = {2020}, + author = {C. Stringer and M. 
Pachitariu}, + title = {rastermap}, + year = {2020}, publisher = {GitHub}, - journal = {GitHub repository}, - url = {https://github.com/MouseLand/rastermap} + journal = {GitHub repository}, + url = {https://github.com/MouseLand/rastermap} } @misc{vidio, - author = {J. Bohnslav}, - title = {VidIO: simple, performant video reading and writing in python}, - year = {2020}, + author = {J. Bohnslav}, + title = {VidIO: simple, performant video reading and writing in python}, + year = {2020}, publisher = {GitHub}, - journal = {GitHub repository}, - url = {https://github.com/jbohnslav/vidio} + journal = {GitHub repository}, + url = {https://github.com/jbohnslav/vidio} } @misc{umap, - doi = {10.48550/ARXIV.1802.03426}, - url = {https://arxiv.org/abs/1802.03426}, - author = {McInnes, Leland and Healy, John and Melville, James}, - keywords = {Machine Learning (stat.ML), Computational Geometry (cs.CG), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, - title = {UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction}, + doi = {10.48550/ARXIV.1802.03426}, + url = {https://arxiv.org/abs/1802.03426}, + author = {McInnes, Leland and Healy, John and Melville, James}, + keywords = {Machine Learning (stat.ML), Computational Geometry (cs.CG), Machine Learning (cs.LG), FOS: Computer and information sciences, FOS: Computer and information sciences}, + title = {UMAP: Uniform Manifold Approximation and Projection for Dimension Reduction}, publisher = {arXiv}, - year = {2018}, + year = {2018}, copyright = {arXiv.org perpetual, non-exclusive license} } +@misc{petrucco_2020_3925903, + author = {Petrucco, Luigi}, + title = {Mouse head schema}, + month = jul, + year = 2020, + publisher = {Zenodo}, + doi = {10.5281/zenodo.3925903}, + url = {https://doi.org/10.5281/zenodo.3925903} +} + + +@article{NWB, + article_type = {journal}, + title = {The Neurodata Without Borders ecosystem for neurophysiological data science}, + author = {Rübel, Oliver and Tritt, Andrew and Ly, Ryan and Dichter, Benjamin K and Ghosh, Satrajit and Niu, Lawrence and Baker, Pamela and Soltesz, Ivan and Ng, Lydia and Svoboda, Karel and Frank, Loren and Bouchard, Kristofer E}, + editor = {Colgin, Laura L and Jadhav, Shantanu P}, + volume = 11, + year = 2022, + month = {oct}, + pub_date = {2022-10-04}, + pages = {e78362}, + citation = {eLife 2022;11:e78362}, + doi = {10.7554/eLife.78362}, + url = {https://doi.org/10.7554/eLife.78362}, + abstract = {The neurophysiology of cells and tissues are monitored electrophysiologically and optically in diverse experiments and species, ranging from flies to humans. Understanding the brain requires integration of data across this diversity, and thus these data must be findable, accessible, interoperable, and reusable (FAIR). This requires a standard language for data and metadata that can coevolve with neuroscience. We describe design and implementation principles for a language for neurophysiology data. Our open-source software (Neurodata Without Borders, NWB) defines and modularizes the interdependent, yet separable, components of a data language. We demonstrate NWB’s impact through unified description of neurophysiology data across diverse modalities and species. NWB exists in an ecosystem, which includes data management, analysis, visualization, and archive tools. Thus, the NWB data language enables reproduction, interchange, and reuse of diverse neurophysiology data. 
More broadly, the design principles of NWB are generally applicable to enhance discovery across biology through data FAIRness.}, + keywords = {Neurophysiology, data ecosystem, data language, data standard, FAIR data, archive}, + journal = {eLife}, + issn = {2050-084X}, + publisher = {eLife Sciences Publications, Ltd} +} diff --git a/paper.md b/paper.md index d35ca59..7a00bac 100644 --- a/paper.md +++ b/paper.md @@ -34,7 +34,7 @@ Direct inspection of behavior and neurophysiology recordings is hard because the We provide dedicated widgets and loading functions for exploring raw video, 3D animal pose, behavior annotations, electrophysiology recordings, and calcium imaging data - either as a raster or as a superposition of labeled regions of interest (ROIs). More broadly, SNUB can display any data that takes the form of a heatmap, scatter plot, video, or collection of named temporally-varying signals. -In addition to the front-end GUI, we include a library of functions for ingesting raw data and saving it to a format that is readable by the SNUB viewer. The following code, for example, creates a project with paired electrophysiology and video data. +In addition to the front-end GUI, we include a library of functions that ingest raw data and save it to a format that is readable by the SNUB viewer. The following code, for example, creates a project with paired electrophysiology and video data. ``` snub.io.create_project(project_directory, duration=1800) @@ -42,12 +42,14 @@ snub.io.add_video(project_directory, 'path/to/my_video.avi', name='IR_camera') snub.io.add_spikeplot(project_directory, 'my_ephys_data', spike_times, spike_labels) ``` +We also provide a rudimentary tool for automatically generating SNUB datasets from Neurodata Without Borders (NWB) files, which contain raw and processed data from neuroscience recordings [@NWB]. The data in NWB files are stored hierarchically, and each component of the hierarchy has a specific neurodata type that reflects the measurement modality (e.g., "Units" for spike trains, "ImageSeries" for video). Our conversion tool generates a SNUB display element for each supported neurodata type. Users can optionally restrict this process to a subset of the NWB hierarchy (e.g., include pose tracking while excluding electrophysiology, or include just a subset of electrophysiology measurements).
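+ +For example, a single function call converts an NWB file (the path below is a placeholder) into a SNUB project: + +``` +snub.io.create_project_from_nwb(project_directory, 'path/to/my_recording.nwb') +``` +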
SNUB is a flexible general-purpose tool that complements more specialized packages such as rastermap [@rastermap] and Bento [@bento]. The rastermap interface, for example, is hard-coded for the display of neural activity rasters, ROIs and 2D embeddings of neural activity. Bento is hard-coded for the display of neural activity rasters, behavioral videos and behavioral annotations. SNUB can reproduce either of these configurations and is especially useful when one wishes to include additional types of data or more directly customize the way that data is rendered. The graphics in SNUB are powered by vispy [@vispy]. SNUB includes wrappers for several dimensionality reduction methods, including rastermap [@rastermap] for ordering raster plots and UMAP [@umap] for 2D scatter plots. Fast video loading is enabled by vidio [@vidio]. The app icon was adapted from a drawing contributed to scidraw by Luigi Petrucco [@petrucco_2020_3925903]. # Acknowledgements -We are grateful to Mohammed Osman for initial contributions to the 3D keypoint visualization tool. CW is a Fellow of The Jane Coffin Childs Memorial Fund for Medical Research. SRD is supported by NIH grants U19NS113201, RF1AG073625, R01NS114020, the Brain Research Foundation, and the Simons Collaboration on the Global Brain. +We are grateful to Mohammed Osman for contributions to the 3D keypoint visualization and NWB conversion tools. CW is a Fellow of The Jane Coffin Childs Memorial Fund for Medical Research. SRD is supported by NIH grants U19NS113201, RF1AG073625, R01NS114020, the Brain Research Foundation, and the Simons Collaboration on the Global Brain. # References \ No newline at end of file diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..71f641c --- /dev/null +++ b/setup.cfg @@ -0,0 +1,56 @@ +[metadata] +name = snub +author = Caleb Weinreb +author_email = calebsw@gmail.com +url = https://github.com/dattalab/SNUB +classifiers = + Programming Language :: Python :: 3 + Operating System :: OS Independent + +[options] +packages = find: +include_package_data = True +python_requires = >=3.8 +install_requires = + PyQt5 + numpy + scikit-learn + tqdm + cmapy + interlap + numba + vispy + imageio + imageio-ffmpeg + umap-learn + rastermap==0.1.3 + ipykernel + pyqtgraph + networkx + opencv-python-headless + vidio>=0.0.3 + pytest + pytest-qt + pynwb + ndx-pose + ndx-photometry + ndx-labels + ndx-depth-moseq + pandas + dandi + + +[options.entry_points] +console_scripts = + snub = snub.gui.main:run + +[options.package_data] +* = *.md + +[versioneer] +VCS = git +style = pep440 +versionfile_source = snub/_version.py +versionfile_build = snub/_version.py +tag_prefix = +parentdir_prefix = \ No newline at end of file diff --git a/setup.py b/setup.py index 7b4600a..c5a6caf 100644 --- a/setup.py +++ b/setup.py @@ -1,40 +1,11 @@ import setuptools +import versioneer -with open("README.md", "r") as f: +with open("README.md", "r", encoding="utf-8") as f: long_description = f.read() setuptools.setup( name="snub", - version="0.0.3", - author="Caleb Weinreb", - author_email="calebsw@gmail.com", - description="Systems neuro browser", - include_package_data=True, - packages=setuptools.find_packages(), - classifiers=[ - "Programming Language :: Python :: 3", - "Operating System :: OS Independent", - ], - entry_points={"console_scripts": ["snub = snub.gui.main:run"]}, - python_requires=">=3.8", - install_requires=[ - "PyQt5", - "numpy", - "scikit-learn", - "tqdm", - "cmapy", - "interlap", - "numba", - "vispy", - "imageio", - "imageio-ffmpeg", - "umap-learn", - "rastermap", - "ipykernel", - "pyqtgraph", - "networkx", - "opencv-python-headless", - "vidio>=0.0.3", - ], - url="https://github.com/calebweinreb/SNUB", + version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ) diff --git a/snub/__init__.py b/snub/__init__.py index 3e2a2db..abc1336 100644 --- a/snub/__init__.py +++ b/snub/__init__.py @@ -1,2 +1,6 @@ from . import io from . import gui + +from . import _version + +__version__ = _version.get_versions()["version"] diff --git a/snub/__main__.py b/snub/__main__.py index 6978470..a42be42 100644 --- a/snub/__main__.py +++ b/snub/__main__.py @@ -1,4 +1,4 @@ from snub.gui.main import run -if __name__ == '__main__': - run() \ No newline at end of file +if __name__ == "__main__": + run() diff --git a/snub/_version.py b/snub/_version.py new file mode 100644 index 0000000..03d226b --- /dev/null +++ b/snub/_version.py @@ -0,0 +1,709 @@ +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by github's download-from-tag +# feature). 
Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. +# Generated by versioneer-0.28 +# https://github.com/python-versioneer/python-versioneer +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys +from typing import Callable, Dict +import functools + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "$Format:%d$" + git_full = "$Format:%H$" + git_date = "$Format:%ci$" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "pep440" + cfg.tag_prefix = "" + cfg.parentdir_prefix = "" + cfg.versionfile_source = "snub/_version.py" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + + return decorate + + +def run_command( + commands, args, cwd=None, verbose=False, hide_stderr=False, env=None +): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + + popen_kwargs = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen( + [command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) + break + except OSError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, process.returncode + return stdout, process.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. 
We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". 
+ tags = {r for r in refs if re.search(r"\d", r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue + if verbose: + print("picking %s" % r) + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner( + GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose + ) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner( + GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root + ) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. 
+ branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[: git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + if not mo: + # unparsable. Maybe git-describe is misbehaving? + pieces["error"] = ( + "unable to parse git-describe output: '%s'" % describe_out + ) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix) :] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 
0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver): + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the post-release + version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces): + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + if pieces["distance"]: + # update the post release segment + tag_version, post_version = pep440_split_post( + pieces["closest-tag"] + ) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % ( + post_version + 1, + pieces["distance"], + ) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 
0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords( + get_keywords(), cfg.tag_prefix, verbose + ) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. 
+ for _ in cfg.versionfile_source.split("/"): + root = os.path.dirname(root) + except NameError: + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None, + } + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return { + "version": "0+unknown", + "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", + "date": None, + } diff --git a/snub/gui/__init__.py b/snub/gui/__init__.py index 8a2fdfc..c626a80 100644 --- a/snub/gui/__init__.py +++ b/snub/gui/__init__.py @@ -1,4 +1,4 @@ from . import panels from . import stacks from . import tracks -from . import main \ No newline at end of file +from . import main diff --git a/snub/gui/main.py b/snub/gui/main.py index bec4103..ba85c28 100644 --- a/snub/gui/main.py +++ b/snub/gui/main.py @@ -2,12 +2,10 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import sys, os, json -import numpy as np from functools import partial from snub.gui.utils import IntervalIndex, CheckBox from snub.gui.stacks import PanelStack, TrackStack from snub.gui.tracks import TracePlot -import time def set_style(app): diff --git a/snub/gui/panels/pose3D.py b/snub/gui/panels/pose3D.py index 06bb555..2b991eb 100644 --- a/snub/gui/panels/pose3D.py +++ b/snub/gui/panels/pose3D.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time import os from vispy.scene import SceneCanvas diff --git a/snub/gui/panels/roi.py b/snub/gui/panels/roi.py index 83276da..ed30a2a 100644 --- a/snub/gui/panels/roi.py +++ b/snub/gui/panels/roi.py @@ -2,81 +2,116 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time import os -import cmapy import cv2 from scipy.sparse import load_npz -from functools import partial from vidio import VideoReader from vispy.scene import SceneCanvas from vispy.scene.visuals import Image, Line from snub.gui.panels import Panel -from snub.gui.utils import HeaderMixin, IntervalIndex, AdjustColormapDialog +from snub.gui.utils import HeaderMixin, AdjustColormapDialog from snub.io.project import _random_color def _roi_contours(rois, dims, threshold_max_ratio=0.2, blur_kernel=2): - rois = np.array(rois.todense()).reshape(rois.shape[0],*dims) + rois = np.array(rois.todense()).reshape(rois.shape[0], *dims) contour_coordinates = [] for roi in rois: - roi_blur = cv2.GaussianBlur(roi,(11,11),blur_kernel) - roi_mask = roi_blur > roi_blur.max()*threshold_max_ratio - xy = cv2.findContours(roi_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[0][0].squeeze() - contour_coordinates.append(np.vstack((xy,xy[:1]))) + roi_blur = cv2.GaussianBlur(roi, (11, 11), blur_kernel) + roi_mask = roi_blur > roi_blur.max() * threshold_max_ratio + xy = cv2.findContours( + roi_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + )[0][0].squeeze() + contour_coordinates.append(np.vstack((xy, xy[:1]))) return contour_coordinates class ROIPanel(Panel, HeaderMixin): eps = 1e-10 - def __init__(self, config, rois_path=None, labels_path=None, timestamps_path=None, - dimensions=None, video_paths=None, contour_colors={}, linewidth=3, - initial_selected_rois=[], vmin=0, vmax=1, colormap='viridis', **kwargs): - + def __init__( + self, + config, + rois_path=None, + 
labels_path=None, + timestamps_path=None, + dimensions=None, + video_paths=None, + contour_colors={}, + linewidth=3, + initial_selected_rois=[], + vmin=0, + vmax=1, + colormap="viridis", + **kwargs + ): super().__init__(config, **kwargs) self.linewidth = linewidth self.colormap = colormap - self.vmin,self.vmax = vmin,vmax + self.vmin, self.vmax = vmin, vmax self.dims = dimensions self.current_frame_index = None self.is_visible = True - self.rois = load_npz(os.path.join(config['project_directory'],rois_path)) - self.timestamps = np.load(os.path.join(config['project_directory'],timestamps_path)) + self.rois = load_npz(os.path.join(config["project_directory"], rois_path)) + self.timestamps = np.load( + os.path.join(config["project_directory"], timestamps_path) + ) - if labels_path is None: self.labels = [str(i) for i in range(self.rois.shape[0])] - else: self.labels = open(os.path.join(config['project_directory'],labels_path),'r').read().split('\n') + if labels_path is None: + self.labels = [str(i) for i in range(self.rois.shape[0])] + else: + self.labels = ( + open(os.path.join(config["project_directory"], labels_path), "r") + .read() + .split("\n") + ) self.adjust_colormap_dialog = AdjustColormapDialog(self, self.vmin, self.vmax) self.adjust_colormap_dialog.new_range.connect(self.update_colormap_range) - self.canvas = SceneCanvas(self, keys='interactive', show=True) + self.canvas = SceneCanvas(self, keys="interactive", show=True) self.canvas.events.mouse_release.connect(self.mouse_release) - self.viewbox = self.canvas.central_widget.add_grid().add_view(row=0, col=0, camera='panzoom') - self.viewbox.camera.set_range(x=(0,self.dims[1]), y=(0,self.dims[0]), margin=0) - self.viewbox.camera.aspect=1 + self.viewbox = self.canvas.central_widget.add_grid().add_view( + row=0, col=0, camera="panzoom" + ) + self.viewbox.camera.set_range( + x=(0, self.dims[1]), y=(0, self.dims[0]), margin=0 + ) + self.viewbox.camera.aspect = 1 self.contours = {} - for label,coordinates in zip(self.labels, _roi_contours(self.rois, self.dims)): - color = contour_colors[label] if label in contour_colors else _random_color() - self.contours[label] = Line(coordinates, color=np.array(color)/255, - width=self.linewidth, connect='strip', parent=None) - - self.vids = {name : VideoReader( - os.path.join(config['project_directory'],video_path) - ) for name,video_path in video_paths.items()} + for label, coordinates in zip(self.labels, _roi_contours(self.rois, self.dims)): + color = ( + contour_colors[label] if label in contour_colors else _random_color() + ) + self.contours[label] = Line( + coordinates, + color=np.array(color) / 255, + width=self.linewidth, + connect="strip", + parent=None, + ) + + self.vids = { + name: VideoReader(os.path.join(config["project_directory"], video_path)) + for name, video_path in video_paths.items() + } self.dropDown = QComboBox() self.dropDown.addItems(list(video_paths.keys())[::-1]) self.dropDown.activated.connect(self.update_image) - self.image = Image(np.zeros(self.dims, dtype=np.float32), - cmap=colormap, parent=self.viewbox.scene, clim=(0,1)) - self.update_current_time(config['init_current_time']) + self.image = Image( + np.zeros(self.dims, dtype=np.float32), + cmap=colormap, + parent=self.viewbox.scene, + clim=(0, 1), + ) + self.update_current_time(config["init_current_time"]) self.initUI(**kwargs) def initUI(self, **kwargs): @@ -84,64 +119,76 @@ def initUI(self, **kwargs): self.layout.addWidget(self.dropDown) self.layout.addWidget(self.canvas.native) self.image.order = 1 - for c in 
self.contours.values(): c.order=0 - self.dropDown.setStyleSheet(""" + for c in self.contours.values(): + c.order = 0 + self.dropDown.setStyleSheet( + """ QComboBox::item { color: white; background-color : #3E3E3E;} - QComboBox::item:selected { background-color: #999999;} """) + QComboBox::item:selected { background-color: #999999;} """ + ) def update_visible_contours(self, visible_contours): - for l,c in self.contours.items(): + for l, c in self.contours.items(): if l in visible_contours: c.parent = self.viewbox.scene - else: c.parent = None + else: + c.parent = None def update_current_time(self, t): - self.current_frame_index = min(self.timestamps.searchsorted(t), len(self.timestamps)-1) - if self.is_visible: self.update_image() + self.current_frame_index = min( + self.timestamps.searchsorted(t), len(self.timestamps) - 1 + ) + if self.is_visible: + self.update_image() def toggle_visiblity(self, *args): super().toggle_visiblity(*args) - if self.is_visible: self.update_image() + if self.is_visible: + self.update_image() def update_image(self): name = self.dropDown.currentText() - if self.current_frame_index is None: x = np.zeros(self.dims) - else: x = self.vids[name][self.current_frame_index][:,:,0]/255 - image = (np.clip(x, self.vmin, self.vmax)-self.vmin)/(self.vmax-self.vmin) + if self.current_frame_index is None: + x = np.zeros(self.dims) + else: + x = self.vids[name][self.current_frame_index][:, :, 0] / 255 + image = (np.clip(x, self.vmin, self.vmax) - self.vmin) / (self.vmax - self.vmin) self.image.set_data(image.astype(np.float32)) self.canvas.update() def update_colormap_range(self, vmin, vmax): - self.vmin,self.vmax = vmin,vmax + self.vmin, self.vmax = vmin, vmax self.update_image() def show_adjust_colormap_dialog(self): self.adjust_colormap_dialog.show() def mouse_release(self, event): - if event.button == 2: self.context_menu(event) + if event.button == 2: + self.context_menu(event) def context_menu(self, event): contextMenu = QMenu(self) - def add_menu_item(name, slot, item_type='label'): + + def add_menu_item(name, slot, item_type="label"): action = QWidgetAction(self) - if item_type=='checkbox': + if item_type == "checkbox": widget = QCheckBox(name) widget.stateChanged.connect(slot) - elif item_type=='label': + elif item_type == "label": widget = QLabel(name) action.triggered.connect(slot) action.setDefaultWidget(widget) - contextMenu.addAction(action) + contextMenu.addAction(action) return widget # click to show adjust colormap range dialog - label = add_menu_item('Adjust colormap range',self.show_adjust_colormap_dialog) + label = add_menu_item("Adjust colormap range", self.show_adjust_colormap_dialog) - contextMenu.setStyleSheet(""" + contextMenu.setStyleSheet( + """ QMenu::item, QLabel, QCheckBox { background-color : #3E3E3E; padding: 5px 6px 5px 6px;} QMenu::item:selected, QLabel:hover, QCheckBox:hover { background-color: #999999;} - QMenu::separator { background-color: rgb(20,20,20);} """) + QMenu::separator { background-color: rgb(20,20,20);} """ + ) action = contextMenu.exec_(event.native.globalPos()) - - diff --git a/snub/gui/panels/scatter.py b/snub/gui/panels/scatter.py index a67e39a..e9baac8 100644 --- a/snub/gui/panels/scatter.py +++ b/snub/gui/panels/scatter.py @@ -1,9 +1,7 @@ from PyQt5.QtCore import * from PyQt5.QtWidgets import * from PyQt5.QtGui import * -import pyqtgraph as pg import numpy as np -import time import os import cmapy from functools import partial @@ -16,173 +14,229 @@ from snub.gui.utils import HeaderMixin, AdjustColormapDialog, IntervalIndex - 
- - class ScatterPanel(Panel, HeaderMixin): eps = 1e-10 - def __init__(self, config, selected_intervals, data_path=None, name='', - pointsize=10, linewidth=1, facecolor=(180,180,180), xlim=None, ylim=None, - selected_edgecolor=(255,255,0), edgecolor=(0,0,0), current_node_size=20, - current_node_color=(255,0,0), colormap='viridis', - selection_intersection_threshold=0.5, variable_labels=[], **kwargs): - + def __init__( + self, + config, + selected_intervals, + data_path=None, + name="", + pointsize=10, + linewidth=1, + facecolor=(180, 180, 180), + xlim=None, + ylim=None, + selected_edgecolor=(255, 255, 0), + edgecolor=(0, 0, 0), + current_node_size=20, + current_node_color=(255, 0, 0), + colormap="viridis", + selection_intersection_threshold=0.5, + variable_labels=[], + **kwargs + ): super().__init__(config, **kwargs) assert data_path is not None self.selected_intervals = selected_intervals - self.bounds = config['bounds'] - self.min_step = config['min_step'] + self.bounds = config["bounds"] + self.min_step = config["min_step"] self.pointsize = pointsize self.linewidth = linewidth - self.facecolor = np.array(facecolor)/256 - self.edgecolor = np.array(edgecolor)/256 + self.facecolor = np.array(facecolor) / 256 + self.edgecolor = np.array(edgecolor) / 256 self.colormap = colormap - self.selected_edgecolor = np.array(selected_edgecolor)/256 + self.selected_edgecolor = np.array(selected_edgecolor) / 256 self.current_node_size = current_node_size - self.current_node_color = np.array(current_node_color)/256 + self.current_node_color = np.array(current_node_color) / 256 self.selection_intersection_threshold = selection_intersection_threshold - self.variable_labels = ['Interval start','Interval end']+variable_labels - self.vmin,self.vmax = 0,1 - self.current_variable_label = '(No color)' - self.sort_nodes_by_variable = True + self.variable_labels = ["Interval start", "Interval end"] + variable_labels + self.vmin, self.vmax = 0, 1 + self.current_variable_label = "(No color)" + self.sort_nodes_by_variable = True self.show_marker_trail = False - self.data = np.load(os.path.join(config['project_directory'],data_path)) - self.data[:,2:4] = self.data[:,2:4] + np.array([-self.eps, self.eps]) - self.is_selected = np.zeros(self.data.shape[0])>0 + self.data = np.load(os.path.join(config["project_directory"], data_path)) + self.data[:, 2:4] = self.data[:, 2:4] + np.array([-self.eps, self.eps]) + self.is_selected = np.zeros(self.data.shape[0]) > 0 self.plot_order = np.arange(self.data.shape[0]) - self.interval_index = IntervalIndex(min_step=self.min_step, intervals=self.data[:,2:4]) + self.interval_index = IntervalIndex( + min_step=self.min_step, intervals=self.data[:, 2:4] + ) self.adjust_colormap_dialog = AdjustColormapDialog(self, self.vmin, self.vmax) self.variable_menu = QListWidget(self) self.variable_menu.itemClicked.connect(self.variable_menu_item_clicked) self.show_variable_menu() - self.canvas = SceneCanvas(self, keys='interactive', show=True) + self.canvas = SceneCanvas(self, keys="interactive", show=True) self.canvas.events.mouse_move.connect(self.mouse_move) self.canvas.events.mouse_release.connect(self.mouse_release) - self.viewbox = self.canvas.central_widget.add_grid().add_view(row=0, col=0, camera='panzoom') - self.viewbox.camera.aspect=1 + self.viewbox = self.canvas.central_widget.add_grid().add_view( + row=0, col=0, camera="panzoom" + ) + self.viewbox.camera.aspect = 1 self.scatter = Markers(antialias=0) self.scatter_selected = Markers(antialias=0) self.current_node_marker = 
Markers(antialias=0) - self.rect = Rectangle(border_color=(1,1,1), color=(1,1,1,.2), center=(0,0), width=1, height=1) + self.rect = Rectangle( + border_color=(1, 1, 1), + color=(1, 1, 1, 0.2), + center=(0, 0), + width=1, + height=1, + ) self.viewbox.add(self.scatter) - self.initUI(name=name, xlim=xlim, ylim=ylim, ) + self.initUI( + name=name, + xlim=xlim, + ylim=ylim, + ) def initUI(self, xlim=None, ylim=None, **kwargs): super().initUI(**kwargs) splitter = QSplitter(Qt.Horizontal) splitter.addWidget(self.variable_menu) splitter.addWidget(self.canvas.native) - splitter.setStretchFactor(0,3) - splitter.setStretchFactor(1,3) + splitter.setStretchFactor(0, 3) + splitter.setStretchFactor(1, 3) self.layout.addWidget(splitter) self.update_scatter() - if xlim is None: xlim = [self.data[:,0].min(),self.data[:,0].max()] - if ylim is None: ylim = [self.data[:,1].min(),self.data[:,1].max()] + if xlim is None: + xlim = [self.data[:, 0].min(), self.data[:, 0].max()] + if ylim is None: + ylim = [self.data[:, 1].min(), self.data[:, 1].max()] self.viewbox.camera.set_range(x=xlim, y=ylim, margin=0.1) self.rect.order = 0 - self.current_node_marker.order=1 - self.scatter_selected.order=2 - self.scatter.order=3 - self.variable_menu.setStyleSheet(""" + self.current_node_marker.order = 1 + self.scatter_selected.order = 2 + self.scatter.order = 3 + self.variable_menu.setStyleSheet( + """ QListWidget::item { background-color : #3E3E3E; color:white; padding: 5px 6px 5px 6px;} QListWidget::item:hover, QLabel:hover { background-color: #999999; color:white; } - QListWidget { background-color : #3E3E3E; }""") - + QListWidget { background-color : #3E3E3E; }""" + ) def update_scatter(self): if self.current_variable_label in self.variable_labels: - x = self.data[:,2+self.variable_labels.index(self.current_variable_label)] - if self.sort_nodes_by_variable: self.plot_order = np.argsort(x)[::-1] - else: self.plot_order = np.arange(len(x)) - x = np.clip((x - self.vmin) / (self.vmax - self.vmin), 0, 1)[self.plot_order] - face_color = cmapy.cmap(self.colormap).squeeze()[:,::-1][(255*x).astype(int)]/255 - else: face_color = np.repeat(self.facecolor[None],self.data.shape[0],axis=0) + x = self.data[ + :, 2 + self.variable_labels.index(self.current_variable_label) + ] + if self.sort_nodes_by_variable: + self.plot_order = np.argsort(x)[::-1] + else: + self.plot_order = np.arange(len(x)) + x = np.clip((x - self.vmin) / (self.vmax - self.vmin), 0, 1)[ + self.plot_order + ] + face_color = ( + cmapy.cmap(self.colormap).squeeze()[:, ::-1][(255 * x).astype(int)] + / 255 + ) + else: + face_color = np.repeat(self.facecolor[None], self.data.shape[0], axis=0) self.scatter.set_data( - pos=self.data[self.plot_order,:2], + pos=self.data[self.plot_order, :2], face_color=face_color, - edge_color=self.edgecolor, - edge_width=self.linewidth, - size=self.pointsize) + edge_color=self.edgecolor, + edge_width=self.linewidth, + size=self.pointsize, + ) if self.is_selected.any(): is_selected = self.is_selected[self.plot_order] self.scatter_selected.set_data( - pos=self.data[self.plot_order,:2][is_selected], + pos=self.data[self.plot_order, :2][is_selected], face_color=face_color[is_selected], - edge_color=self.selected_edgecolor, - edge_width=(self.linewidth*2), - size=self.pointsize) + edge_color=self.selected_edgecolor, + edge_width=(self.linewidth * 2), + size=self.pointsize, + ) self.scatter_selected.parent = self.viewbox.scene - else: self.scatter_selected.parent = None - + else: + self.scatter_selected.parent = None def update_current_time(self, t): 
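        # When the marker trail is enabled, the last 5 seconds are drawn as 11
        # markers with exponentially decaying sizes; otherwise only the point(s)
        # whose interval contains the current time t are marked.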
- if self.show_marker_trail: - times = np.linspace(t,t-5,11) - sizes = np.exp(np.linspace(0,-1.5,11)) + if self.show_marker_trail: + times = np.linspace(t, t - 5, 11) + sizes = np.exp(np.linspace(0, -1.5, 11)) else: times = np.array([t]) sizes = np.array([1]) - nodes,time_indexes = self.interval_index.intervals_containing(times) - if len(nodes)>0: + nodes, time_indexes = self.interval_index.intervals_containing(times) + if len(nodes) > 0: self.current_node_marker.set_data( - pos=self.data[nodes,:2], + pos=self.data[nodes, :2], face_color=self.current_node_color, - size=sizes[time_indexes]*self.current_node_size) + size=sizes[time_indexes] * self.current_node_size, + ) self.current_node_marker.parent = self.viewbox.scene - else: self.current_node_marker.parent = None - + else: + self.current_node_marker.parent = None def context_menu(self, event): contextMenu = QMenu(self) - def add_menu_item(name, slot, item_type='label'): + + def add_menu_item(name, slot, item_type="label"): action = QWidgetAction(self) - if item_type=='checkbox': + if item_type == "checkbox": widget = QCheckBox(name) widget.stateChanged.connect(slot) - elif item_type=='label': + elif item_type == "label": widget = QLabel(name) action.triggered.connect(slot) action.setDefaultWidget(widget) - contextMenu.addAction(action) + contextMenu.addAction(action) return widget - + # show/hide variable menu - if self.variable_menu.isVisible(): add_menu_item('Hide variables menu', self.hide_variable_menu) - else: add_menu_item('Show variables menu', self.show_variable_menu) + if self.variable_menu.isVisible(): + add_menu_item("Hide variables menu", self.hide_variable_menu) + else: + add_menu_item("Show variables menu", self.show_variable_menu) # get enriched variables (only available if nodes are selected) - label = add_menu_item('Sort variables by enrichment', self.get_enriched_variables) - if self.is_selected.sum()==0: label.setStyleSheet("QLabel { color: rgb(120,120,120); }") - label = add_menu_item('Restore original variable order', self.show_variable_menu) + label = add_menu_item( + "Sort variables by enrichment", self.get_enriched_variables + ) + if self.is_selected.sum() == 0: + label.setStyleSheet("QLabel { color: rgb(120,120,120); }") + label = add_menu_item( + "Restore original variable order", self.show_variable_menu + ) contextMenu.addSeparator() # toggle whether to plot high-variable-val nodes on top - checkbox = add_menu_item('Plot high values on top', self.toggle_sort_by_color_value, item_type='checkbox') - if self.sort_nodes_by_variable: checkbox.setChecked(True) - else: checkbox.setChecked(False) + checkbox = add_menu_item( + "Plot high values on top", + self.toggle_sort_by_color_value, + item_type="checkbox", + ) + if self.sort_nodes_by_variable: + checkbox.setChecked(True) + else: + checkbox.setChecked(False) contextMenu.addSeparator() # click to show adjust colormap range dialog - label = add_menu_item('Adjust colormap range',self.show_adjust_colormap_dialog) + label = add_menu_item("Adjust colormap range", self.show_adjust_colormap_dialog) contextMenu.addSeparator() - if self.show_marker_trail: - add_menu_item('Hide marker trail',partial(self.toggle_marker_trail,False)) - else: add_menu_item('Show marker trail',partial(self.toggle_marker_trail,True)) - + if self.show_marker_trail: + add_menu_item("Hide marker trail", partial(self.toggle_marker_trail, False)) + else: + add_menu_item("Show marker trail", partial(self.toggle_marker_trail, True)) - contextMenu.setStyleSheet(""" + contextMenu.setStyleSheet( + """
QMenu::item, QLabel, QCheckBox { background-color : #3E3E3E; padding: 5px 6px 5px 6px;} QMenu::item:selected, QLabel:hover, QCheckBox:hover { background-color: #999999;} - QMenu::separator { background-color: rgb(20,20,20);} """) + QMenu::separator { background-color: rgb(20,20,20);} """ + ) action = contextMenu.exec_(event.native.globalPos()) - def toggle_marker_trail(self, visibility): self.show_marker_trail = visibility self.update_scatter() @@ -195,24 +249,30 @@ def hide_variable_menu(self): def show_variable_menu(self, *args, variable_order=None): self.variable_menu.clear() - if variable_order is None: variable_order = self.variable_labels - for name in variable_order: self.variable_menu.addItem(name) + if variable_order is None: + variable_order = self.variable_labels + for name in variable_order: + self.variable_menu.addItem(name) self.variable_menu.show() def get_enriched_variables(self): - if self.is_selected.sum() > 0 and len(self.variable_labels)>0: - variables_zscore = (self.data[:,2:] - self.data[:,2:].mean(0))/(np.std(self.data[:,2:],axis=0)+self.eps) + if self.is_selected.sum() > 0 and len(self.variable_labels) > 0: + variables_zscore = (self.data[:, 2:] - self.data[:, 2:].mean(0)) / ( + np.std(self.data[:, 2:], axis=0) + self.eps + ) enrichment = variables_zscore[self.is_selected].mean(0) variable_order = [self.variable_labels[i] for i in np.argsort(-enrichment)] self.show_variable_menu(variable_order=variable_order) def update_colormap_range(self, vmin, vmax): - self.vmin,self.vmax = vmin,vmax + self.vmin, self.vmax = vmin, vmax self.update_scatter() def toggle_sort_by_color_value(self, check_state): - if check_state == 0: self.sort_nodes_by_variable = False - else: self.sort_nodes_by_variable = True + if check_state == 0: + self.sort_nodes_by_variable = False + else: + self.sort_nodes_by_variable = True self.update_scatter() def show_adjust_colormap_dialog(self): @@ -221,14 +281,16 @@ def show_adjust_colormap_dialog(self): def colorby(self, label): self.current_variable_label = label if self.current_variable_label in self.variable_labels: - x = self.data[:,2+self.variable_labels.index(self.current_variable_label)] - self.vmin,self.vmax = x.min()-self.eps,x.max()+self.eps - self.adjust_colormap_dialog.update_range(self.vmin,self.vmax) + x = self.data[ + :, 2 + self.variable_labels.index(self.current_variable_label) + ] + self.vmin, self.vmax = x.min() - self.eps, x.max() + self.eps + self.adjust_colormap_dialog.update_range(self.vmin, self.vmax) self.update_scatter() def mouse_release(self, event): self.rect.parent = None - if event.button == 2: + if event.button == 2: self.context_menu(event) def mouse_move(self, event): @@ -237,23 +299,28 @@ def mouse_move(self, event): if keys.SHIFT in mods or keys.CONTROL in mods: current_pos = self.viewbox.scene.transform.imap(event.pos)[:2] start_pos = self.viewbox.scene.transform.imap(event.press_event.pos)[:2] - if all((current_pos-start_pos)!=0): - self.rect.center = (current_pos+start_pos)/2 - self.rect.width = np.abs(current_pos[0]-start_pos[0]) - self.rect.height = np.abs(current_pos[1]-start_pos[1]) + if all((current_pos - start_pos) != 0): + self.rect.center = (current_pos + start_pos) / 2 + self.rect.width = np.abs(current_pos[0] - start_pos[0]) + self.rect.height = np.abs(current_pos[1] - start_pos[1]) self.rect.parent = self.viewbox.scene - selection_value = int(mods[0]==keys.SHIFT) - enclosed_points = np.all([ - self.data[:,:2]>=np.minimum(current_pos, start_pos), - self.data[:,:2]<=np.maximum(current_pos, 
start_pos)],axis=(0,2)) + selection_value = int(mods[0] == keys.SHIFT) + enclosed_points = np.all( + [ + self.data[:, :2] >= np.minimum(current_pos, start_pos), + self.data[:, :2] <= np.maximum(current_pos, start_pos), + ], + axis=(0, 2), + ) self.selection_change.emit( - list(self.data[enclosed_points,2:4]), - [selection_value]*len(enclosed_points)) + list(self.data[enclosed_points, 2:4]), + [selection_value] * len(enclosed_points), + ) def update_selected_intervals(self): - intersections = self.selected_intervals.intersection_proportions(self.data[:,2:4]) + intersections = self.selected_intervals.intersection_proportions( + self.data[:, 2:4] + ) self.is_selected = intersections > self.selection_intersection_threshold self.update_scatter() - - diff --git a/snub/gui/panels/video.py b/snub/gui/panels/video.py index cef7a75..b4c7eaf 100644 --- a/snub/gui/panels/video.py +++ b/snub/gui/panels/video.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time import os from vidio import VideoReader diff --git a/snub/gui/stacks/__init__.py b/snub/gui/stacks/__init__.py index 50f3b74..f4e2fe5 100644 --- a/snub/gui/stacks/__init__.py +++ b/snub/gui/stacks/__init__.py @@ -1,3 +1,3 @@ from .base import Stack from .panel import PanelStack -from .track import TrackStack \ No newline at end of file +from .track import TrackStack diff --git a/snub/gui/stacks/base.py b/snub/gui/stacks/base.py index 7b82817..dfa90a4 100644 --- a/snub/gui/stacks/base.py +++ b/snub/gui/stacks/base.py @@ -11,12 +11,13 @@ def __init__(self, config, selected_intervals): self.selected_intervals = selected_intervals def change_layout_mode(self, layout_mode): - for widget in self.widgets: widget.change_layout_mode(layout_mode) - + for widget in self.widgets: + widget.change_layout_mode(layout_mode) + def initUI(self): sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) sizePolicy.setHorizontalStretch(self.size_ratio) self.setSizePolicy(sizePolicy) widget_order = np.argsort([w.order for w in self.widgets]) - self.widgets = [self.widgets[i] for i in widget_order] \ No newline at end of file + self.widgets = [self.widgets[i] for i in widget_order] diff --git a/snub/gui/stacks/panel.py b/snub/gui/stacks/panel.py index 31114cb..202787d 100644 --- a/snub/gui/stacks/panel.py +++ b/snub/gui/stacks/panel.py @@ -1,45 +1,45 @@ from PyQt5.QtCore import * from PyQt5.QtWidgets import * from PyQt5.QtGui import * -import numpy as np from snub.gui.stacks import Stack from snub.gui.panels import VideoPanel, ScatterPanel, ROIPanel, Pose3DPanel + class PanelStack(Stack): def __init__(self, config, selected_intervals): super().__init__(config, selected_intervals) - self.size_ratio = config['panels_size_ratio'] + self.size_ratio = config["panels_size_ratio"] - for props in config['scatter']: # initialize scatter plots + for props in config["scatter"]: # initialize scatter plots panel = ScatterPanel(config, self.selected_intervals, **props) self.widgets.append(panel) - for props in config['video']: # initialize video + for props in config["video"]: # initialize video panel = VideoPanel(config, **props) self.widgets.append(panel) - for props in config['pose3D']: # initialize 3D pose viewer + for props in config["pose3D"]: # initialize 3D pose viewer panel = Pose3DPanel(config, **props) self.widgets.append(panel) - for props in config['roiplot']: # initialize ROI plot + for props in config["roiplot"]: # initialize ROI plot panel = ROIPanel(config, **props) self.widgets.append(panel) 
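        # Each config section maps to one panel class. initUI() below hands the
        # collected widgets to Stack.initUI, which sorts them by their `order`
        # attribute before they are placed in a splitter.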
self.initUI() - def initUI(self): super().initUI() hbox = QHBoxLayout(self) self.splitter = QSplitter(Qt.Vertical) - for panel in self.widgets: self.splitter.addWidget(panel) + for panel in self.widgets: + self.splitter.addWidget(panel) self.splitter.setSizes([w.size_ratio for w in self.widgets]) hbox.addWidget(self.splitter) - self.splitter.setSizes([100000*p.size_ratio for p in self.widgets]) + self.splitter.setSizes([100000 * p.size_ratio for p in self.widgets]) hbox.setContentsMargins(0, 0, 0, 0) def get_by_name(self, name): @@ -47,14 +47,16 @@ def get_by_name(self, name): if panel.name == name: return panel - def update_current_time(self,t): + def update_current_time(self, t): for panel in self.widgets: panel.update_current_time(t) def update_selected_intervals(self): - for panel in self.widgets: + for panel in self.widgets: panel.update_selected_intervals() def change_layout_mode(self, layout_mode): - self.splitter.setOrientation({'columns':Qt.Vertical, 'rows':Qt.Horizontal}[layout_mode]) - super().change_layout_mode(layout_mode) \ No newline at end of file + self.splitter.setOrientation( + {"columns": Qt.Vertical, "rows": Qt.Horizontal}[layout_mode] + ) + super().change_layout_mode(layout_mode) diff --git a/snub/gui/stacks/track.py b/snub/gui/stacks/track.py index d1816a4..bc3fc79 100644 --- a/snub/gui/stacks/track.py +++ b/snub/gui/stacks/track.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time from snub.gui.stacks import Stack from snub.gui.tracks import * diff --git a/snub/gui/tracks/__init__.py b/snub/gui/tracks/__init__.py index 2eae32e..010e535 100644 --- a/snub/gui/tracks/__init__.py +++ b/snub/gui/tracks/__init__.py @@ -1,4 +1,12 @@ -from .base import Track, TrackGroup, Timeline, SelectionOverlay, LineOverlay, position_to_time, time_to_position +from .base import ( + Track, + TrackGroup, + Timeline, + SelectionOverlay, + LineOverlay, + position_to_time, + time_to_position, +) from .trace import TracePlot, HeadedTracePlot from .heatmap import Heatmap, HeatmapTraceGroup, HeadedHeatmap -from .spike import SpikePlot, HeadedSpikePlot, SpikePlotTraceGroup \ No newline at end of file +from .spike import SpikePlot, HeadedSpikePlot, SpikePlotTraceGroup diff --git a/snub/gui/tracks/base.py b/snub/gui/tracks/base.py index 1298d0e..eb5dcd9 100644 --- a/snub/gui/tracks/base.py +++ b/snub/gui/tracks/base.py @@ -2,9 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import os -import time -from functools import partial from snub.gui.utils import HeaderMixin @@ -226,7 +223,7 @@ def get_visible_tick_positions(self): ) tick_times = np.arange(first_tick, self.current_range[1], tick_interval) tick_positions = self._time_to_position(tick_times) - return tick_times, tick_positions + return tick_times, tick_positions.astype(int) def paintEvent(self, event): qp = QPainter() diff --git a/snub/gui/tracks/heatmap.py b/snub/gui/tracks/heatmap.py index 8c6a553..5b6292f 100644 --- a/snub/gui/tracks/heatmap.py +++ b/snub/gui/tracks/heatmap.py @@ -5,7 +5,6 @@ import os import numpy as np import cmapy -import time from numba import njit, prange from snub.gui.tracks import Track, TracePlot, TrackGroup @@ -220,7 +219,7 @@ def paintEvent(self, event): qp.setPen(QColor(*self.label_colors[i], self.base_alpha)) qp.drawText( self.label_margin, - height * self.height() - self.max_label_height // 2, + int(height * self.height()) - self.max_label_height // 2, self.max_label_width, self.max_label_height, Qt.AlignVCenter, diff 
--git a/snub/gui/tracks/spike.py b/snub/gui/tracks/spike.py index 277f24b..5702424 100644 --- a/snub/gui/tracks/spike.py +++ b/snub/gui/tracks/spike.py @@ -1,46 +1,69 @@ from PyQt5.QtCore import * from PyQt5.QtWidgets import * from PyQt5.QtGui import * -from functools import partial import os import numpy as np import cmapy -import time from vispy.scene import SceneCanvas from vispy.scene.visuals import Markers, Line -from snub.gui.tracks import Track, TracePlot, TrackGroup, Heatmap +from snub.gui.tracks import TracePlot, TrackGroup, Heatmap -''' +""" class SpikePlot(Heatmap): def __init__(self, config, selected_intervals, spikes_path=None, heatmap_path=None, **kwargs): print(spikes_path, heatmap_path) super().__init__(config, selected_intervals, data_path=heatmap_path, **kwargs) self.spike_data = np.load(os.path.join(config['project_directory'],spikes_path)) -''' +""" class SpikePlot(Heatmap): - def __init__(self, config, selected_intervals, spikes_path=None, markersize=5, - heatmap_path=None, heatmap_range=60, colormap='viridis', **kwargs): - + def __init__( + self, + config, + selected_intervals, + spikes_path=None, + markersize=5, + heatmap_path=None, + heatmap_range=60, + colormap="viridis", + **kwargs + ): super().__init__(config, selected_intervals, data_path=heatmap_path, **kwargs) self.heatmap_range = heatmap_range - spike_data = np.load(os.path.join(config['project_directory'],spikes_path)) - self.spike_times,self.spike_labels = spike_data[:,0], spike_data[:,1].astype(int) + spike_data = np.load(os.path.join(config["project_directory"], spikes_path)) + self.spike_times, self.spike_labels = spike_data[:, 0], spike_data[:, 1].astype( + int + ) self.max_label = self.spike_labels.max() - self.markersize=markersize + self.markersize = markersize self.colormap = colormap - self.cmap = cmapy.cmap(self.colormap).squeeze()[:,::-1].astype(np.float32)/255 - self.canvas = SceneCanvas(self, keys='interactive', bgcolor=self.cmap[0], show=True) - self.viewbox = self.canvas.central_widget.add_grid().add_view(row=0, col=0, camera='panzoom') - line_verts = np.vstack([ - np.ones(self.max_label)*self.spike_times.min()-10, - np.arange(self.max_label), - np.ones(self.max_label)*self.spike_times.max()+10, - np.arange(self.max_label)]).T - self.lines = Line(pos=line_verts.reshape(-1,2), color=np.clip(self.cmap[0]+.1,0,1), method='gl', width=0.5, connect='segments') + self.cmap = ( + cmapy.cmap(self.colormap).squeeze()[:, ::-1].astype(np.float32) / 255 + ) + self.canvas = SceneCanvas( + self, keys="interactive", bgcolor=self.cmap[0], show=True + ) + self.viewbox = self.canvas.central_widget.add_grid().add_view( + row=0, col=0, camera="panzoom" + ) + line_verts = np.vstack( + [ + np.ones(self.max_label) * self.spike_times.min() - 10, + np.arange(self.max_label), + np.ones(self.max_label) * self.spike_times.max() + 10, + np.arange(self.max_label), + ] + ).T + self.lines = Line( + pos=line_verts.reshape(-1, 2), + color=np.clip(self.cmap[0] + 0.1, 0, 1), + method="gl", + width=0.5, + connect="segments", + ) self.viewbox.add(self.lines) self.scatter = Markers() @@ -48,9 +71,9 @@ def __init__(self, config, selected_intervals, spikes_path=None, markersize=5, self.scatter.order = -1 self.set_scatter_data() self.viewbox.add(self.scatter) - + layout = QVBoxLayout(self) - layout.setContentsMargins(0,0,0,0) + layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self.canvas.native, 1) self.heatmap_image.raise_() self.heatmap_labels.raise_() @@ -58,41 +81,65 @@ def __init__(self, config, selected_intervals, 
spikes_path=None, markersize=5, def update_current_range(self, current_range): super().update_current_range(current_range) - if (self.current_range[1]-self.current_range[0]) >= self.heatmap_range: + if (self.current_range[1] - self.current_range[0]) >= self.heatmap_range: self.heatmap_image.show() else: self.heatmap_image.hide() - self.viewbox.camera.set_range(x=self.current_range, y=self.get_ylim(), margin=1e-10) - bgcolor = self.cmap[0]*(self.current_range[1]-self.current_range[0])/self.heatmap_range + self.viewbox.camera.set_range( + x=self.current_range, y=self.get_ylim(), margin=1e-10 + ) + bgcolor = ( + self.cmap[0] + * (self.current_range[1] - self.current_range[0]) + / self.heatmap_range + ) self.canvas.bgcolor = bgcolor - self.lines.set_data(color=np.clip(bgcolor+.1,0,1)) + self.lines.set_data(color=np.clip(bgcolor + 0.1, 0, 1)) def spike_coordinates(self): - ycoords = self.max_label-np.argsort(self.row_order)[self.spike_labels]+.5 - return np.vstack((self.spike_times,ycoords)).T + ycoords = self.max_label - np.argsort(self.row_order)[self.spike_labels] + 0.5 + return np.vstack((self.spike_times, ycoords)).T def spike_colors(self): image_data = self.get_image_data() rows = np.argsort(self.row_order)[self.spike_labels] - cols = np.around((self.spike_times-self.intervals[0,0])/self.min_step).astype(int) - colors = image_data[rows,np.clip(cols,0,image_data.shape[1]-1)].astype(np.float32)/255 + cols = np.around( + (self.spike_times - self.intervals[0, 0]) / self.min_step + ).astype(int) + colors = ( + image_data[rows, np.clip(cols, 0, image_data.shape[1] - 1)].astype( + np.float32 + ) + / 255 + ) return colors def zoom_in_vertical(self): super().zoom_in_vertical() - self.viewbox.camera.set_range(x=self.current_range, y=self.get_ylim(), margin=1e-10) + self.viewbox.camera.set_range( + x=self.current_range, y=self.get_ylim(), margin=1e-10 + ) - def zoom_vertical(self,origin,scale_factor): - super().zoom_vertical(origin,scale_factor) - self.viewbox.camera.set_range(x=self.current_range, y=self.get_ylim(), margin=1e-10) + def zoom_vertical(self, origin, scale_factor): + super().zoom_vertical(origin, scale_factor) + self.viewbox.camera.set_range( + x=self.current_range, y=self.get_ylim(), margin=1e-10 + ) def get_ylim(self): - return self.max_label - np.array(self.vertical_range)[::-1] + 1 + return self.max_label - np.array(self.vertical_range)[::-1] + 1 def set_scatter_data(self): xy = self.spike_coordinates() c = self.spike_colors() - self.scatter.set_data(xy, edge_width=0, face_color=c, edge_color=None, symbol='vbar', size=self.markersize) + self.scatter.set_data( + xy, + edge_width=0, + face_color=c, + edge_color=None, + symbol="vbar", + size=self.markersize, + ) def update_row_order(self, order): super().update_row_order(order) @@ -103,27 +150,43 @@ def update_colormap_range(self, *args): self.set_scatter_data() - - - - class HeadedSpikePlot(TrackGroup): def __init__(self, config, selected_intervals, **kwargs): spikeplot = SpikePlot(config, selected_intervals, **kwargs) - super().__init__(config, tracks={'spikeplot':spikeplot}, track_order=['spikeplot'], **kwargs) + super().__init__( + config, tracks={"spikeplot": spikeplot}, track_order=["spikeplot"], **kwargs + ) class SpikePlotTraceGroup(TrackGroup): - def __init__(self, config, selected_intervals, trace_height_ratio=1, - heatmap_height_ratio=2, height_ratio=1, **kwargs): + def __init__( + self, + config, + selected_intervals, + trace_height_ratio=1, + heatmap_height_ratio=2, + height_ratio=1, + **kwargs + ): self.height_ratio = 
trace_height_ratio + heatmap_height_ratio - spikeplot = SpikePlot(config, selected_intervals, height_ratio=heatmap_height_ratio, **kwargs) + spikeplot = SpikePlot( + config, selected_intervals, height_ratio=heatmap_height_ratio, **kwargs + ) x = spikeplot.intervals.mean(1) - trace_data = {l:np.vstack((x,d)).T for l,d in zip(spikeplot.labels, spikeplot.data)} - trace = TracePlot(config, height_ratio=trace_height_ratio, data=trace_data, **kwargs) - - super().__init__(config, tracks={'trace':trace, 'spikeplot':spikeplot}, - track_order=['trace','spikeplot'], height_ratio=height_ratio, **kwargs) + trace_data = { + l: np.vstack((x, d)).T for l, d in zip(spikeplot.labels, spikeplot.data) + } + trace = TracePlot( + config, height_ratio=trace_height_ratio, data=trace_data, **kwargs + ) + + super().__init__( + config, + tracks={"trace": trace, "spikeplot": spikeplot}, + track_order=["trace", "spikeplot"], + height_ratio=height_ratio, + **kwargs + ) spikeplot.display_trace_signal.connect(trace.show_trace) diff --git a/snub/gui/tracks/trace.py b/snub/gui/tracks/trace.py index 637a38f..ee24b59 100644 --- a/snub/gui/tracks/trace.py +++ b/snub/gui/tracks/trace.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import pyqtgraph as pg -import colorsys import numpy as np import pickle import os diff --git a/snub/gui/utils/__init__.py b/snub/gui/utils/__init__.py index eb4ec6d..862cea3 100644 --- a/snub/gui/utils/__init__.py +++ b/snub/gui/utils/__init__.py @@ -1,2 +1,2 @@ from .interval import IntervalIndex -from .widgets import AdjustColormapDialog, HeaderMixin, CheckBox \ No newline at end of file +from .widgets import AdjustColormapDialog, HeaderMixin, CheckBox diff --git a/snub/gui/utils/interval.py b/snub/gui/utils/interval.py index 1180f8a..4b512cb 100644 --- a/snub/gui/utils/interval.py +++ b/snub/gui/utils/interval.py @@ -2,8 +2,7 @@ from numba import njit, prange - -#@njit +@njit def sum_by_index(x, ixs, n): out = np.zeros(n) for i in prange(len(ixs)): @@ -11,69 +10,82 @@ def sum_by_index(x, ixs, n): return out -class IntervalIndexBase(): - def __init__(self, intervals=np.empty((0,2)), **kwargs): +class IntervalIndexBase: + def __init__(self, intervals=np.empty((0, 2)), **kwargs): self.intervals = intervals def clear(self): - self.intervals = np.empty((0,2)) + self.intervals = np.empty((0, 2)) def partition_intervals(self, start, end): - ends_before = self.intervals[:,1] < start - ends_after = self.intervals[:,1] >= start - starts_before = self.intervals[:,0] <= end - starts_after = self.intervals[:,0] > end + ends_before = self.intervals[:, 1] < start + ends_after = self.intervals[:, 1] >= start + starts_before = self.intervals[:, 0] <= end + starts_after = self.intervals[:, 0] > end intersect = self.intervals[np.bitwise_and(ends_after, starts_before)] pre = self.intervals[ends_before] post = self.intervals[starts_after] - return pre,intersect,post - + return pre, intersect, post + def add_interval(self, start, end): - pre,intersect,post = self.partition_intervals(start,end) + pre, intersect, post = self.partition_intervals(start, end) if intersect.shape[0] > 0: - merged_start = np.minimum(intersect[0,0],start) - merged_end = np.maximum(intersect[-1,1],end) - else: + merged_start = np.minimum(intersect[0, 0], start) + merged_end = np.maximum(intersect[-1, 1], end) + else: merged_start, merged_end = start, end - merged_interval = np.array([merged_start, merged_end]).reshape(1,2) + merged_interval = np.array([merged_start, merged_end]).reshape(1, 2) self.intervals = 
np.vstack((pre, merged_interval, post)) def remove_interval(self, start, end): - pre,intersect,post = self.partition_intervals(start,end) - pre_intersect = np.empty((0,2)) - post_intersect = np.empty((0,2)) + pre, intersect, post = self.partition_intervals(start, end) + pre_intersect = np.empty((0, 2)) + post_intersect = np.empty((0, 2)) if intersect.shape[0] > 0: - if intersect[0,0] < start: pre_intersect = np.array([intersect[0,0],start]) - if intersect[-1,1] > end: post_intersect = np.array([end,intersect[-1,1]]) - self.intervals = np.vstack((pre,pre_intersect,post_intersect,post)) + if intersect[0, 0] < start: + pre_intersect = np.array([intersect[0, 0], start]) + if intersect[-1, 1] > end: + post_intersect = np.array([end, intersect[-1, 1]]) + self.intervals = np.vstack((pre, pre_intersect, post_intersect, post)) - def intersection_proportions(self, query_intervals): + def intersection_proportions(self, query_intervals): query_ixs, ref_ixs = self.all_overlaps_both(self.intervals, query_intervals) - if len(query_ixs)>0: - intersection_starts = np.maximum(query_intervals[query_ixs,0], self.intervals[ref_ixs,0]) - intersection_ends = np.minimum(query_intervals[query_ixs,1], self.intervals[ref_ixs,1]) + if len(query_ixs) > 0: + intersection_starts = np.maximum( + query_intervals[query_ixs, 0], self.intervals[ref_ixs, 0] + ) + intersection_ends = np.minimum( + query_intervals[query_ixs, 1], self.intervals[ref_ixs, 1] + ) intersection_lengths = intersection_ends - intersection_starts - query_intersection_lengths = sum_by_index(intersection_lengths, query_ixs, query_intervals.shape[0]) - query_lengths = query_intervals[:,1] - query_intervals[:,0] + 1e-10 + query_intersection_lengths = sum_by_index( + intersection_lengths, query_ixs, query_intervals.shape[0] + ) + query_lengths = query_intervals[:, 1] - query_intervals[:, 0] + 1e-10 return query_intersection_lengths / query_lengths - else: return np.zeros(query_intervals.shape[0]) + else: + return np.zeros(query_intervals.shape[0]) def all_containments_both(self, ref_intervals, query_locations): raise NotImplementedError() def intervals_containing(self, query_locations): - query_ixs,ref_ixs = self.all_containments_both(self.intervals, query_locations) - valid_containments = np.all([ - self.intervals[ref_ixs,0] <= query_locations[query_ixs], - self.intervals[ref_ixs,1] >= query_locations[query_ixs]],axis=0) + query_ixs, ref_ixs = self.all_containments_both(self.intervals, query_locations) + valid_containments = np.all( + [ + self.intervals[ref_ixs, 0] <= query_locations[query_ixs], + self.intervals[ref_ixs, 1] >= query_locations[query_ixs], + ], + axis=0, + ) return ref_ixs[valid_containments], query_ixs[valid_containments] try: - from ncls import NCLS + # try executing so exception is triggered on import not at runtime - ncls.all_containments_both(np.arange(1),np.arange(1),np.arange(1)); + ncls.all_containments_both(np.arange(1), np.arange(1), np.arange(1)) class IntervalIndex(IntervalIndexBase): def __init__(self, min_step=0.033, **kwargs): @@ -81,15 +93,19 @@ def __init__(self, min_step=0.033, **kwargs): self.min_step = min_step def preprocess_for_ncls(self, intervals): - intervals_discretized = (intervals/self.min_step).astype(int) - return (intervals_discretized[:,0].copy(order='C'), - intervals_discretized[:,1].copy(order='C'), - np.arange(intervals_discretized.shape[0])) + intervals_discretized = (intervals / self.min_step).astype(int) + return ( + intervals_discretized[:, 0].copy(order="C"), + intervals_discretized[:, 
1].copy(order="C"), + np.arange(intervals_discretized.shape[0]), + ) def all_containments_both(self, ref_intervals, query_locations): query_locations = (query_locations / self.min_step).astype(int) ncls = NCLS(*self.preprocess_for_ncls(ref_intervals)) - return ncls.all_containments_both(query_locations, query_locations, np.arange(len(query_locations))) + return ncls.all_containments_both( + query_locations, query_locations, np.arange(len(query_locations)) + ) def all_overlaps_both(self, ref_intervals, query_intervals): query_intervals = self.preprocess_for_ncls(query_intervals) @@ -98,26 +114,23 @@ def all_overlaps_both(self, ref_intervals, query_intervals): return ncls.all_overlaps_both(*query_intervals) except: - from interlap import InterLap + class IntervalIndex(IntervalIndexBase): def __init__(self, **kwargs): super().__init__(**kwargs) def all_overlaps_both(self, ref_intervals, query_intervals): - inter = InterLap(ranges=[(s,e,i) for i,(s,e) in enumerate(ref_intervals)]) - query_ixs,ref_ixs = [],[] - for i,(s,e) in enumerate(query_intervals): - overlap_ixs = [interval[2] for interval in inter.find((s,e))] + inter = InterLap( + ranges=[(s, e, i) for i, (s, e) in enumerate(ref_intervals)] + ) + query_ixs, ref_ixs = [], [] + for i, (s, e) in enumerate(query_intervals): + overlap_ixs = [interval[2] for interval in inter.find((s, e))] ref_ixs.append(overlap_ixs) - query_ixs.append([i]*len(overlap_ixs)) - return np.hstack(query_ixs).astype(int),np.hstack(ref_ixs).astype(int) + query_ixs.append([i] * len(overlap_ixs)) + return np.hstack(query_ixs).astype(int), np.hstack(ref_ixs).astype(int) def all_containments_both(self, ref_intervals, query_locations): - query_intervals = np.repeat(query_locations[:,None],2,axis=1) + query_intervals = np.repeat(query_locations[:, None], 2, axis=1) return self.all_overlaps_both(ref_intervals, query_intervals) - - - - - diff --git a/snub/gui/utils/widgets.py b/snub/gui/utils/widgets.py index 7e4cf9e..2b8af34 100644 --- a/snub/gui/utils/widgets.py +++ b/snub/gui/utils/widgets.py @@ -1,4 +1,4 @@ -import numpy as np, os +import os from pyqtgraph import VerticalLabel from PyQt5.QtCore import * from PyQt5.QtWidgets import * diff --git a/snub/io/__init__.py b/snub/io/__init__.py index fab35c9..94b77df 100644 --- a/snub/io/__init__.py +++ b/snub/io/__init__.py @@ -1,4 +1,5 @@ from .project import * from .manifold import * from .video import * -from .plot import * \ No newline at end of file +from .plot import * +from .nwb import * diff --git a/snub/io/manifold.py b/snub/io/manifold.py index ca7f427..fb90df7 100644 --- a/snub/io/manifold.py +++ b/snub/io/manifold.py @@ -1,37 +1,29 @@ import numpy as np -import warnings -# Binning / smoothing - -def firing_rates( - spike_times, - spike_labels, - window_size=0.2, - window_step=0.02 -): +def firing_rates(spike_times, spike_labels, window_size=0.2, window_step=0.05): """Convert spike times to firing rates using a sliding window - + Parameters ---------- spike_times : ndarray Spike times (in seconds) for all units. The source of each spike is input separately using ``spike_labels`` - + spike_labels: ndarray The source/label for each spike in ``spike_times``. The maximum value of this array determines the number of rows in the heatmap.
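        Labels are used as row indices, so they are assumed to be non-negative integers.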
- + window_size: float, default=0.2 Length (in seconds) of the sliding window used to calculate firing rates - - window_step: float, default=0.02 + + window_step: float, default=0.05 Step-size (in seconds) between each window used to calculate firing rates Returns ------- firing_rates: ndarray - Array of firing rates, where rows units and columns are sliding + Array of firing rates, where rows are units and columns are sliding window locations. ``firing_rates`` has shape ``(N,M)`` where:: N = max(spike_labels)+1 @@ -42,28 +34,24 @@ def firing_rates( The time (in seconds) corresponding to the left-boundary of the first window in ``firing_rates``. """ - # round spikes to window_step and factor our start time - spike_times = np.around(spike_times/window_step).astype(int) + # round spikes to window_step and factor out start time + spike_times = np.around(spike_times / window_step).astype(int) start_time = spike_times.min() spike_times = spike_times - start_time - + # create heatmap of spike counts for each window_step-sized bin spike_labels = spike_labels.astype(int) - heatmap = np.zeros((spike_labels.max()+1, spike_times.max()+1)) - np.add.at(heatmap, (spike_labels, spike_times), 1/window_step) - + heatmap = np.zeros((spike_labels.max() + 1, spike_times.max() + 1)) + np.add.at(heatmap, (spike_labels, spike_times), 1 / window_step) + # use convolution to get sliding window counts - kernel = np.ones(int(window_size//window_step))/(window_size//window_step) - for i in range(heatmap.shape[0]): heatmap[i,:] = np.convolve(heatmap[i,:],kernel, mode='same') - return heatmap, start_time-window_step/2 + kernel = np.ones(int(window_size // window_step)) / (window_size // window_step) + for i in range(heatmap.shape[0]): + heatmap[i, :] = np.convolve(heatmap[i, :], kernel, mode="same") + return heatmap, (start_time - 1 / 2) * window_step -def bin_data( - data, - binsize, - axis=-1, - return_intervals=False -): +def bin_data(data, binsize, axis=-1, return_intervals=False): """Bin data using non-overlapping windows along `axis` Returns ------- data_binned: ndarray bin_intervals: ndarray (returned if ``return_intervals=True``) (N,2) array with the start and end index of each bin """ - data = np.moveaxis(data,axis,-1) - pad_amount = (-data.shape[-1])%binsize - num_bins = int((data.shape[-1]+pad_amount)/binsize) - data_padded = np.pad(data,[(0,0)]*(len(data.shape)-1)+[(0,pad_amount)]) + data = np.moveaxis(data, axis, -1) + pad_amount = (-data.shape[-1]) % binsize + num_bins = int((data.shape[-1] + pad_amount) / binsize) + data_padded = np.pad(data, [(0, 0)] * (len(data.shape) - 1) + [(0, pad_amount)]) data_binned = data_padded.reshape(*data.shape[:-1], num_bins, binsize).mean(-1) - if pad_amount > 0: data_binned[...,-1] = data_binned[...,-1] * binsize/(binsize-pad_amount) - data_binned = np.moveaxis(data_binned,-1,axis) + if pad_amount > 0: + data_binned[..., -1] = data_binned[..., -1] * binsize / (binsize - pad_amount) + data_binned = np.moveaxis(data_binned, -1, axis) if return_intervals: - bin_starts = np.arange(0,num_bins)*binsize - bin_ends = np.arange(1,num_bins+1)*binsize + bin_starts = np.arange(0, num_bins) * binsize + bin_ends = np.arange(1, num_bins + 1) * binsize bin_ends[-1] = data.shape[-1] - bin_intervals = np.vstack((bin_starts,bin_ends)).T + bin_intervals = np.vstack((bin_starts, bin_ends)).T return data_binned, bin_intervals - else: return data_binned - + else: + return data_binned -# Normalization - def zscore(data, axis=0, eps=1e-10): """ Z-score standardize the data along ``axis`` """
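A quick check of the sliding-window logic in ``firing_rates`` above (synthetic spikes; the spike times and labels are made up for illustration):

.. code-block:: python

    import numpy as np
    from snub.io.manifold import firing_rates

    # Unit 0 spikes at 0.00 s and 0.05 s; unit 1 spikes at 0.10 s.
    spike_times = np.array([0.0, 0.05, 0.10])
    spike_labels = np.array([0, 0, 1])

    rates, t0 = firing_rates(spike_times, spike_labels, window_step=0.05)
    print(rates.shape)  # (2, 3): one row per unit, one column per 0.05 s step
    print(t0)           # -0.025: left boundary of the first window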
mean = np.mean(data, axis=axis, keepdims=True) std = np.std(data, axis=axis, keepdims=True) + eps - return (data-mean)/std - - + return (data - mean) / std - -# Dimensionality reduction - -def sort( - data, - method='rastermap', - options={} -): +def sort(data, method="rastermap", options={}): """Compute neuron ordering that groups neurons with similar activity Parameters @@ -136,25 +114,25 @@ def sort( Ordering index that can be used for sorting (see `numpy.argsort`) """ - valid_sort_methods = ['rastermap'] + valid_sort_methods = ["rastermap"] if not method in valid_sort_methods: - raise AssertionError(method+' is not a valid sort method. Must be one of '+repr(valid_sort_methods)) - if method=='rastermap': - print('Computing row order with rastermap') + raise AssertionError( + method + + " is not a valid sort method. Must be one of " + + repr(valid_sort_methods) + ) + if method == "rastermap": + print("Computing row order with rastermap") from rastermap import mapping - model = mapping.Rastermap(n_components=1).fit(data) - return np.argsort(model.embedding[:,0]) + + model = mapping.Rastermap(n_components=1, **options).fit(data) + return np.argsort(model.embedding[:, 0]) def umap_embedding( - data, - standardize=True, - n_pcs=20, - n_components=2, - n_neighbors=100, - **kwargs + data, standardize=True, n_pcs=20, n_components=2, n_neighbors=100, **kwargs ): - """Generate a 2D embedding of neural activity using UMAP. The function + """Generate a 2D embedding of neural activity using UMAP. The function generates the embedding in three steps: 1. (Optionally) standardize (Z-score) the activity of each neuron @@ -172,7 +150,7 @@ def umap_embedding( Whether to standardize (Z-score) the data prior to PCA n_pcs: int, default=20 - Number of principal components to use during PCA. If ``n_pcs=None``, the binned + Number of principal components to use during PCA. If ``n_pcs=None``, the binned data will be passed directly to UMAP n_components: int, default=2 @@ -192,16 +170,11 @@ def umap_embedding( from sklearn.decomposition import PCA from umap import UMAP - if standardize: data = zscore(data, axis=1) + if standardize: + data = zscore(data, axis=1) PCs = PCA(n_components=n_pcs).fit_transform(data.T) - umap_obj = UMAP(n_neighbors=n_neighbors, n_components=n_components, n_epochs=500, **kwargs) + umap_obj = UMAP( + n_neighbors=n_neighbors, n_components=n_components, n_epochs=500, **kwargs + ) coordinates = umap_obj.fit_transform(PCs) return coordinates - - - - - - - - \ No newline at end of file diff --git a/snub/io/nwb.py b/snub/io/nwb.py new file mode 100644 index 0000000..8d16c7a --- /dev/null +++ b/snub/io/nwb.py @@ -0,0 +1,488 @@ +import pynwb +import os +from ndx_pose import PoseEstimation +from ndx_labels import LabelSeries +import numpy as np +import snub.io.project +from vidio import VideoReader + + +EPS = 1e-6 + + +def create_project_from_nwb( + project_directory, + nwb_path, + branches=["root"], + use_full_path=True, + project_options={}, + subplot_options={}, +): + """ + Given an NWB file and a specification of the branches of the file to be visualized, + this method creates a SNUB project with the file's data. + + Parameters + ---------- + project_directory : str + Project path. A directory will be created at this location. + + nwb_path : str + Path to the NWB file. + + branches : list of str, optional + A specification of which subtrees of the NWB file (on the basis of their name) + to include in the SNUB project. If None, all data with a supported type will be + included. 
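        For example, ``branches=["behavior"]`` restricts the project to datasets
        whose ancestors in the NWB hierarchy include a group named "behavior".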
+ + use_full_path : bool, optional + Whether to use the full path in the NWB file when naming the SNUB object + corresponding to the NWB data. If False, only the name of the leaf node will be + used. If these names are not unique, an error will be raised. + + project_options : dict + Additional keyword arguments for snub.io.create_project + + subplot_options : dict + Additional keyword arguments to be passed to the specific subplot-adding functions + as a dict mapping dataset names to dicts of options. The names should be full paths + if use_full_path is True, or just the names of the leaf nodes otherwise. + """ + nwb_type_mapping = { + "IntervalSeries": add_interval_series, + "RoiResponseSeries": add_roi_response_series, + "TimeSeries": add_generic_timeseries, + "PoseEstimation": add_pose_estimation, + "ImageSeries": add_image_series, + "LabelSeries": add_label_series, + "TimeIntervals": add_time_intervals, + "Position": add_position, + "SpatialSeries": add_spatial_series, + "Units": add_ephys_units, + "Events": add_events, + } + + with pynwb.NWBHDF5IO(nwb_path, mode="r", load_namespaces=True) as io: + nwbfile = io.read() + + # Get all datasets to be included + children = list_included_datasets(nwbfile, branches, nwb_type_mapping) + print("The following datasets will be included in the SNUB plot:") + for child in children: + print(f" {_generate_name(child, use_full_path)} ({child.neurodata_type})") + print("") + + # Check that there is at least one dataset to be included + assert len(children) > 0, ( + "No datasets found in NWB file that match the specified branch names and have" + " one of the following supported types: " + str(nwb_type_mapping.keys()) + ) + + # Check that all datasets to be included have unique names + if not use_full_path: + names = [c.name for c in children] + assert len(names) == len(set(names)), ( + "Not all datasets to be included have unique names. You must set " + "use_full_path=True to resolve this." + ) + + # Create project + start_time, end_time = _get_start_end_times(children) + snub.io.project.create_project( + project_directory, + start_time=start_time, + end_time=end_time, + **project_options, + ) + + # Add data + for child in children: + name = _generate_name(child, use_full_path) + + if name in subplot_options: + opts = subplot_options[name] + else: + opts = {} + + try: # try/except catches malformed data + nwb_type_mapping[child.neurodata_type]( + project_directory, child, name, start_time, end_time, opts + ) + except Exception as e: + print(f"Skipping data {name} because of error: {e}") + + +def _get_start_end_times(objects): + """ + Given a list of objects from an NWB file, returns the earliest and latest timestamps + from the object or its children.
+ """ + starts, ends = [], [] + for obj in objects: + if obj.neurodata_type == "PoseEstimation": + s, e = _get_start_end_times(obj.pose_estimation_series.values()) + elif obj.neurodata_type == "Position": + s, e = _get_start_end_times(obj.spatial_series.values()) + elif obj.neurodata_type == "Units": + s = obj.spike_times[()].min() + e = obj.spike_times[()].max() + elif obj.neurodata_type == "TimeIntervals": + s = obj.start_time[()].min() + e = obj.stop_time[()].max() + else: + print(obj.neurodata_type, obj.name) + timestamps = get_timestamps(obj) + s = timestamps.min() + e = timestamps.max() + starts.append(s) + ends.append(e) + return min(starts), max(ends) + + +def list_included_datasets(nwbfile, branches, nwb_type_mapping): + """ + Enumerates all datasets in an NWB file that should be included in the SNUB project. + + Datasets are included if (a) their type is supported and (b) they belong to one of + the branches specified by the user. If an included dataset is the child of another + dataset, only the parent will be included (e.g. when a Position dataset is included, + its child SpatialSeries datasets will not be included). Datasets are also excluded + if they have a "rate" parameter that is set to 0. + """ + included_datasets = [] + for child in nwbfile.all_children(): + ancestors = child.get_ancestors() + ancestor_names = [a.name for a in ancestors] + [child.name] + + has_supported_type = child.neurodata_type in nwb_type_mapping + from_included_branch = len(set(branches).intersection(ancestor_names)) > 0 + has_no_included_parents = ( + len([c for c in included_datasets if c in ancestors]) == 0 + ) + has_valid_rate = (not "rate" in child.fields) or (child.rate > 0) + + if np.all( + [ + has_supported_type, + from_included_branch, + has_no_included_parents, + has_valid_rate, + ] + ): + included_datasets.append(child) + + return included_datasets + + +def get_timestamps(obj): + """ + Get timestamps for a TimeSeries object or ImageSeries object. + + Parameters + ---------- + obj: pynwb.TimeSeries or pywnb.ImageSeries + NWB TimeSeries or ImageSeries object. + """ + if obj.timestamps is None: + # Get start time + if "start" in obj.fields: + start = obj.start + elif "starting_time" in obj.fields: + start = obj.starting_time + else: + raise AssertionError( + f"TimeSeries {obj.name} has no start time. Cannot determine timestamps." + ) + + # Get rate + if "rate" not in obj.fields: + raise AssertionError( + f"TimeSeries {obj.name} has no rate. Cannot determine timestamps." + ) + rate = obj.rate + + # Get duration + if obj.neurodata_type == "ImageSeries": + T = len(VideoReader(obj.external_file[0])) + else: + T = len(obj.data) + + # Compute timestamps + stamps = start + np.arange(T) / rate + else: + stamps = obj.timestamps[:] + stamps = stamps.astype(float) + return stamps + + +def _timestamps_to_intervals(timestamps): + """ + Given an array of timestamps, returns an array of intervals that are centered on the + timestamps and have width determined by the inter-timestamp intervals. 
+ """ + starts = np.hstack( + [ + [timestamps[0] - (timestamps[1] - timestamps[0]) / 2], + (timestamps[1:] + timestamps[:-1]) / 2, + ] + ) + ends = np.hstack( + [ + (timestamps[1:] + timestamps[:-1]) / 2, + [timestamps[-1] + (timestamps[-1] - timestamps[-2]) / 2], + ] + ) + return np.vstack([starts, ends]).T + + +def _generate_name(child, use_full_path): + if use_full_path: + return ".".join( + [a.name for a in list(child.get_ancestors())[:-1][::-1] + [child]] + ) + else: + return child.name + + +def add_interval_series(project_directory, obj, name, start_time, end_time, options): + """ + Adds an interval series to a SNUB project in the form of a traceplot. + """ + print(f'Adding interval series "{name}" as a traceplot.') + timestamps = obj.timestamps[()] # contains start and end times + data = obj.data[()] # contains interval types numbers: + for start, - for end + + traces = {} + interval_types = set(data[data > 0]) + for i in interval_types: + trace = [[start_time, 0]] + starts = timestamps[data == i] + ends = timestamps[data == -i] + for start, end in zip(starts, ends): + trace.append([start - EPS, 0]) + trace.append([start, 1]) + trace.append([end - EPS, 1]) + trace.append([end, 0]) + trace.append([end_time, 0]) + traces[str(i)] = np.array(trace) + + snub.io.project.add_traceplot(project_directory, name, traces, **options) + + +def add_roi_response_series( + project_directory, obj, name, start_time, end_time, options +): + """ + Adds an ROI response series to a SNUB project in the form of a heatmap. + """ + print(f'Adding ROI response series "{name}" as a heatmap.') + data = obj.data[()].T + start_time = obj.starting_time + binsize = 1 / obj.rate + + snub.io.project.add_heatmap( + project_directory, + name, + data, + start_time=start_time, + binsize=binsize, + **options, + ) + + +def add_spatial_series(project_directory, obj, name, start_time, end_time, options): + """ + Adds a spatial series to a SNUB project in the form of a traceplot. + """ + print(f"Adding spatial series {name} as a traceplot.") + data = obj.data[()] + timestamps = get_timestamps(obj) + + if len(data.shape) == 1: + data = data[:, None] + + traces = {} + for i in range(data.shape[1]): + trace = np.vstack([timestamps, data[:, i]]).T + traces[f"dim {i}"] = trace + + snub.io.project.add_traceplot(project_directory, name, traces, **options) + + +def add_position(project_directory, obj, name, start_time, end_time, options): + """ + Adds position data to a SNUB project in the form of a traceplot. + """ + print(f'Adding positions "{name}" as a traceplot.') + + traces = {} + for key, child in obj.spatial_series.items(): + timestamps = get_timestamps(child) + positions = child.data[()] + if len(positions.shape) == 1: + trace = np.vstack([timestamps, positions]).T + traces[key] = trace + else: + for i in range(positions.shape[1]): + trace = np.vstack([timestamps, positions[:, i]]).T + traces[f"{key} (dim {i})"] = trace + + snub.io.project.add_traceplot(project_directory, name, traces, **options) + + +def add_generic_timeseries( + project_directory, obj, name, start_time, end_time, options, heatmap_threshold=10 +): + """ + Adds a generic timeseries to a SNUB project in the form of a traceplot or heatmap, + depending on the number of dimensions. 
+ """ + data = obj.data[()] + timestamps = get_timestamps(obj) + + if len(data.shape) == 1: + data = data[:, None] + + if data.shape[1] > heatmap_threshold: + print(f'Adding generic timeseries "{name}" as a heatmap.') + snub.io.project.add_heatmap( + project_directory, + name, + data.T, + time_intervals=_timestamps_to_intervals(timestamps), + **options, + ) + else: + print(f'Adding generic timeseries "{name}" as a traceplot.') + traces = {} + for i in range(data.shape[1]): + trace = np.vstack([timestamps, data[:, i]]).T + traces[f"dim-{i}"] = trace + + snub.io.project.add_traceplot(project_directory, name, traces, **options) + + +def add_pose_estimation(project_directory, obj, name, start_time, end_time, options): + """ + Adds pose estimation to a SNUB project in the form of a traceplot (and a 3D pose plot + if the pose estimation is 3D). + """ + print(f'Adding pose estimation "{name}" as a traceplot.') + + joint_labels = obj.nodes[:] + keypoints, traces = [], {} + for joint in joint_labels: + child = obj.pose_estimation_series[joint] + timestamps = get_timestamps(child) + keypoints.append(child.data[()]) + for i in range(child.data.shape[1]): + label = f"{joint}-{['x','y','z'][i]}" + trace = np.vstack([timestamps, child.data[:, i]]).T + traces[label] = trace + + snub.io.project.add_traceplot(project_directory, name, traces, **options) + + if keypoints[-1].shape[1] == 3: + print(f'Adding pose estimation "{name}" as a 3D pose plot.') + snub.io.project.add_pose3D( + project_directory, + name, + np.stack(keypoints, axis=1), + links=obj.edges[:], + time_intervals=_timestamps_to_intervals(timestamps), + ) + + +def add_image_series(project_directory, obj, name, start_time, end_time, options): + """ + Adds an image series to a SNUB project in the form of a video. + """ + print(f'Adding image series "{name}" as a video.') + timestamps = get_timestamps(obj) + for path in obj.external_file: + if not os.path.exists(path): + print(f"Warning: external file {path} does not exist. Skipping.") + else: + snub.io.project.add_video( + project_directory, path, name, timestamps=timestamps, **options + ) + + +def add_label_series(project_directory, obj, name, start_time, end_time, options): + """ + Adds a label series to a SNUB project in the form of a heatmap. + """ + print(f'Adding label series "{name}" as a heatmap.') + data = obj.data[()].T + timestamps = get_timestamps(obj) + labels = obj.vocabulary[:] + + snub.io.project.add_heatmap( + project_directory, + name, + data, + time_intervals=_timestamps_to_intervals(timestamps), + labels=labels, + **options, + ) + + +def add_time_intervals(project_directory, obj, name, start_time, end_time, options): + """ + Adds a time intervals to a SNUB project in the form of a traceplot. 
+ """ + print(f'Adding time intervals "{name}" as a traceplot.') + + ignored_fields = [ + n for n in obj.colnames if n not in ["start_time", "stop_time", "timeseries"] + ] + if ignored_fields: + print(f'Warning: ignoring fields {ignored_fields} in time intervals "{name}"') + + starts = obj.start_time[()] + ends = obj.stop_time[()] + trace = [[start_time, 0]] + for start, end in zip(starts, ends): + trace.append([start - EPS, 0]) + trace.append([start, 1]) + trace.append([end - EPS, 1]) + trace.append([end, 0]) + trace.append([end_time, 0]) + traces = {"intervals": np.array(trace)} + snub.io.project.add_traceplot(project_directory, name, traces, **options) + + +def add_ephys_units(project_directory, obj, name, start_time, end_time, options): + """ + Adds ephys units to a SNUB project in the form of a spikeplot. + """ + print(f'Adding ephys units "{name}" as a spikeplot.') + + spike_times_per_unit = obj.to_dataframe()["spike_times"] + spike_times = np.hstack(spike_times_per_unit) + spike_labels = np.hstack( + [np.ones(len(spikes)) * i for i, spikes in enumerate(spike_times_per_unit)] + ) + snub.io.project.add_spikeplot( + project_directory, + name, + spike_times, + spike_labels, + **options, + ) + + +def add_events(project_directory, obj, name, start_time, end_time, options): + """ + Adds events to a SNUB project in the form of a traceplot. + """ + print(f'Adding events "{name}" as a traceplot.') + + trace = [(start_time, 0)] + for t in obj.timestamps[()]: + trace.append((t - EPS, 0)) + trace.append((t, 1)) + trace.append((t + EPS, 0)) + trace.append((end_time, 0)) + traces = {"events": np.array(trace)} + snub.io.project.add_traceplot(project_directory, name, traces, **options) diff --git a/snub/io/plot.py b/snub/io/plot.py index 42b4c13..d77c4a8 100644 --- a/snub/io/plot.py +++ b/snub/io/plot.py @@ -1,21 +1,28 @@ import numpy as np - def scatter_plot_bounds(xy, margin=0.05, n_neighbors=100, distance_cutoff=2): """ Get xlim and ylim for a scatter plot such that outliers are excluded. Bounds are based on the largest component of a knn graph with distance cutoff. """ import pynndescent, networkx as nx - edges,dists = pynndescent.NNDescent(xy,n_neighbors=n_neighbors).neighbor_graph + + edges, dists = pynndescent.NNDescent(xy, n_neighbors=n_neighbors).neighbor_graph G = nx.Graph() G.add_nodes_from(np.arange(xy.shape[0])) - for i,j in zip(*np.nonzero(dists= (3, 11): + import tomllib +else: + try: + import tomli as tomllib + except ImportError: + have_tomllib = False + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_root(): + """Get the project root directory. + + We require that all commands are run from the project root, i.e. the + directory that contains setup.py, setup.cfg, and versioneer.py . + """ + root = os.path.realpath(os.path.abspath(os.getcwd())) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + # allow 'python path/to/setup.py COMMAND' + root = os.path.dirname(os.path.realpath(os.path.abspath(sys.argv[0]))) + setup_py = os.path.join(root, "setup.py") + versioneer_py = os.path.join(root, "versioneer.py") + if not (os.path.exists(setup_py) or os.path.exists(versioneer_py)): + err = ( + "Versioneer was unable to run the project root directory. 
" + "Versioneer requires setup.py to be executed from " + "its immediate directory (like 'python setup.py COMMAND'), " + "or in a way that lets it use sys.argv[0] to find the root " + "(like 'python path/to/setup.py COMMAND')." + ) + raise VersioneerBadRootError(err) + try: + # Certain runtime workflows (setup.py install/develop in a setuptools + # tree) execute all dependencies in a single python process, so + # "versioneer" may be imported multiple times, and python's shared + # module-import table will cache the first one. So we can't use + # os.path.dirname(__file__), as that will find whichever + # versioneer.py was first imported, even in later projects. + my_path = os.path.realpath(os.path.abspath(__file__)) + me_dir = os.path.normcase(os.path.splitext(my_path)[0]) + vsr_dir = os.path.normcase(os.path.splitext(versioneer_py)[0]) + if me_dir != vsr_dir and "VERSIONEER_PEP518" not in globals(): + print( + "Warning: build in %s is using versioneer.py from %s" + % (os.path.dirname(my_path), versioneer_py) + ) + except NameError: + pass + return root + + +def get_config_from_root(root): + """Read the project setup.cfg file to determine Versioneer config.""" + # This might raise OSError (if setup.cfg is missing), or + # configparser.NoSectionError (if it lacks a [versioneer] section), or + # configparser.NoOptionError (if it lacks "VCS="). See the docstring at + # the top of versioneer.py for instructions on writing your setup.cfg . + root = Path(root) + pyproject_toml = root / "pyproject.toml" + setup_cfg = root / "setup.cfg" + section = None + if pyproject_toml.exists() and have_tomllib: + try: + with open(pyproject_toml, "rb") as fobj: + pp = tomllib.load(fobj) + section = pp["tool"]["versioneer"] + except (tomllib.TOMLDecodeError, KeyError): + pass + if not section: + parser = configparser.ConfigParser() + with open(setup_cfg) as cfg_file: + parser.read_file(cfg_file) + parser.get("versioneer", "VCS") # raise error if missing + + section = parser["versioneer"] + + cfg = VersioneerConfig() + cfg.VCS = section["VCS"] + cfg.style = section.get("style", "") + cfg.versionfile_source = section.get("versionfile_source") + cfg.versionfile_build = section.get("versionfile_build") + cfg.tag_prefix = section.get("tag_prefix") + if cfg.tag_prefix in ("''", '""', None): + cfg.tag_prefix = "" + cfg.parentdir_prefix = section.get("parentdir_prefix") + cfg.verbose = section.get("verbose") + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +# these dictionaries contain VCS-specific tools +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + HANDLERS.setdefault(vcs, {})[method] = f + return f + + return decorate + + +def run_command( + commands, args, cwd=None, verbose=False, hide_stderr=False, env=None +): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + + popen_kwargs = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen( + 
[command] + args, + cwd=cwd, + env=env, + stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr else None), + **popen_kwargs, + ) + break + except OSError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %s" % dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %s" % (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %s (error)" % dispcmd) + print("stdout was %s" % stdout) + return None, process.returncode + return stdout, process.returncode + + +LONG_VERSION_PY[ + "git" +] = r''' +# This file helps to compute a version number in source trees obtained from +# git-archive tarball (such as those provided by githubs download-from-tag +# feature). Distribution tarballs (built by setup.py sdist) and build +# directories (produced by setup.py build) will contain a much shorter file +# that just contains the computed version number. + +# This file is released into the public domain. +# Generated by versioneer-0.28 +# https://github.com/python-versioneer/python-versioneer + +"""Git implementation of _version.py.""" + +import errno +import os +import re +import subprocess +import sys +from typing import Callable, Dict +import functools + + +def get_keywords(): + """Get the keywords needed to look up the version information.""" + # these strings will be replaced by git during git-archive. + # setup.py/versioneer.py will grep for the variable names, so they must + # each be defined on a line of their own. _version.py will just call + # get_keywords(). + git_refnames = "%(DOLLAR)sFormat:%%d%(DOLLAR)s" + git_full = "%(DOLLAR)sFormat:%%H%(DOLLAR)s" + git_date = "%(DOLLAR)sFormat:%%ci%(DOLLAR)s" + keywords = {"refnames": git_refnames, "full": git_full, "date": git_date} + return keywords + + +class VersioneerConfig: + """Container for Versioneer configuration parameters.""" + + +def get_config(): + """Create, populate and return the VersioneerConfig() object.""" + # these strings are filled in when 'setup.py versioneer' creates + # _version.py + cfg = VersioneerConfig() + cfg.VCS = "git" + cfg.style = "%(STYLE)s" + cfg.tag_prefix = "%(TAG_PREFIX)s" + cfg.parentdir_prefix = "%(PARENTDIR_PREFIX)s" + cfg.versionfile_source = "%(VERSIONFILE_SOURCE)s" + cfg.verbose = False + return cfg + + +class NotThisMethod(Exception): + """Exception raised if a method is not valid for the current scenario.""" + + +LONG_VERSION_PY: Dict[str, str] = {} +HANDLERS: Dict[str, Dict[str, Callable]] = {} + + +def register_vcs_handler(vcs, method): # decorator + """Create decorator to mark a method as the handler of a VCS.""" + def decorate(f): + """Store f in HANDLERS[vcs][method].""" + if vcs not in HANDLERS: + HANDLERS[vcs] = {} + HANDLERS[vcs][method] = f + return f + return decorate + + +def run_command(commands, args, cwd=None, verbose=False, hide_stderr=False, + env=None): + """Call the given command(s).""" + assert isinstance(commands, list) + process = None + + popen_kwargs = {} + if sys.platform == "win32": + # This hides the console window if pythonw.exe is used + startupinfo = subprocess.STARTUPINFO() + startupinfo.dwFlags |= subprocess.STARTF_USESHOWWINDOW + popen_kwargs["startupinfo"] = startupinfo + + for command in commands: + try: + dispcmd = str([command] + args) + # remember shell=False, so use git.cmd on windows, not just git + process = subprocess.Popen([command] + args, cwd=cwd, env=env, + 
stdout=subprocess.PIPE, + stderr=(subprocess.PIPE if hide_stderr + else None), **popen_kwargs) + break + except OSError: + e = sys.exc_info()[1] + if e.errno == errno.ENOENT: + continue + if verbose: + print("unable to run %%s" %% dispcmd) + print(e) + return None, None + else: + if verbose: + print("unable to find command, tried %%s" %% (commands,)) + return None, None + stdout = process.communicate()[0].strip().decode() + if process.returncode != 0: + if verbose: + print("unable to run %%s (error)" %% dispcmd) + print("stdout was %%s" %% stdout) + return None, process.returncode + return stdout, process.returncode + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return {"version": dirname[len(parentdir_prefix):], + "full-revisionid": None, + "dirty": False, "error": None, "date": None} + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print("Tried directories %%s but none started with prefix %%s" %% + (str(rootdirs), parentdir_prefix)) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. + keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. 
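+    # For example, "2023-01-02 03:04:05 -0500" becomes
+    # "2023-01-02T03:04:05-0500" after the two replace() calls below.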
+ date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG):] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %%d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r'\d', r)} + if verbose: + print("discarding '%%s', no digits" %% ",".join(refs - tags)) + if verbose: + print("likely tags: %%s" %% ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix):] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r'\d', r): + continue + if verbose: + print("picking %%s" %% r) + return {"version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": None, + "date": date} + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return {"version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, "error": "no suitable tags", "date": None} + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. + + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. 
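+    # (GIT_DIR points git at an explicit repository, so a stray value would
+    # make the git calls below describe some other project.)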
+ env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner(GITS, ["rev-parse", "--git-dir"], cwd=root, + hide_stderr=not verbose) + if rc != 0: + if verbose: + print("Directory %%s not under git control" %% root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner(GITS, [ + "describe", "--tags", "--dirty", "--always", "--long", + "--match", f"{tag_prefix}[[:digit:]]*" + ], cwd=root) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner(GITS, ["rev-parse", "--abbrev-ref", "HEAD"], + cwd=root) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[:git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r'^(.+)-(\d+)-g([0-9a-f]+)$', git_describe) + if not mo: + # unparsable. Maybe git-describe is misbehaving? 
+ pieces["error"] = ("unable to parse git-describe output: '%%s'" + %% describe_out) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%%s' doesn't start with prefix '%%s'" + print(fmt %% (full_tag, tag_prefix)) + pieces["error"] = ("tag '%%s' doesn't start with prefix '%%s'" + %% (full_tag, tag_prefix)) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix):] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%%ci", "HEAD"], cwd=root)[0].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%%d.g%%s" %% (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%%d.g%%s" %% (pieces["distance"], + pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver): + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the + post-release version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces): + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + if pieces["distance"]: + # update the post release segment + tag_version, post_version = pep440_split_post(pieces["closest-tag"]) + rendered = tag_version + if post_version is not None: + rendered += ".post%%d.dev%%d" %% (post_version + 1, pieces["distance"]) + else: + rendered += ".post0.dev%%d" %% (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] + else: + # exception #1 + rendered = "0.post0.dev%%d" %% pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%%s" %% pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%%d" %% pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%%d-g%%s" %% (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return {"version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None} + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%%s'" %% style) + + return {"version": rendered, "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], "error": None, + "date": pieces.get("date")} + + +def get_versions(): + """Get version information or return default if unable to do so.""" + # I am in _version.py, which lives at ROOT/VERSIONFILE_SOURCE. If we have + # __file__, we can work backwards from there to the root. Some + # py2exe/bbfreeze/non-CPython implementations don't do __file__, in which + # case we can only use expanded keywords. + + cfg = get_config() + verbose = cfg.verbose + + try: + return git_versions_from_keywords(get_keywords(), cfg.tag_prefix, + verbose) + except NotThisMethod: + pass + + try: + root = os.path.realpath(__file__) + # versionfile_source is the relative path from the top of the source + # tree (where the .git directory might live) to this file. Invert + # this to find the root from __file__. + for _ in cfg.versionfile_source.split('/'): + root = os.path.dirname(root) + except NameError: + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to find root of source tree", + "date": None} + + try: + pieces = git_pieces_from_vcs(cfg.tag_prefix, root, verbose) + return render(pieces, cfg.style) + except NotThisMethod: + pass + + try: + if cfg.parentdir_prefix: + return versions_from_parentdir(cfg.parentdir_prefix, root, verbose) + except NotThisMethod: + pass + + return {"version": "0+unknown", "full-revisionid": None, + "dirty": None, + "error": "unable to compute version", "date": None} +''' + + +@register_vcs_handler("git", "get_keywords") +def git_get_keywords(versionfile_abs): + """Extract version information from the given file.""" + # the code embedded in _version.py can just fetch the value of these + # keywords. When used from setup.py, we don't want to import _version.py, + # so we do it with a regexp instead. This function is not used from + # _version.py. 
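+    # Before expansion the line looks like:  git_refnames = "$Format:%d$"
+    # After 'git archive' expansion it might read, for example:
+    #     git_refnames = " (HEAD -> main, tag: v1.0.0)"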
+ keywords = {} + try: + with open(versionfile_abs, "r") as fobj: + for line in fobj: + if line.strip().startswith("git_refnames ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["refnames"] = mo.group(1) + if line.strip().startswith("git_full ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["full"] = mo.group(1) + if line.strip().startswith("git_date ="): + mo = re.search(r'=\s*"(.*)"', line) + if mo: + keywords["date"] = mo.group(1) + except OSError: + pass + return keywords + + +@register_vcs_handler("git", "keywords") +def git_versions_from_keywords(keywords, tag_prefix, verbose): + """Get version information from git keywords.""" + if "refnames" not in keywords: + raise NotThisMethod("Short version file found") + date = keywords.get("date") + if date is not None: + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + + # git-2.2.0 added "%cI", which expands to an ISO-8601 -compliant + # datestamp. However we prefer "%ci" (which expands to an "ISO-8601 + # -like" string, which we must then edit to make compliant), because + # it's been around since git-1.5.3, and it's too difficult to + # discover which version we're using, or to work around using an + # older one. + date = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + refnames = keywords["refnames"].strip() + if refnames.startswith("$Format"): + if verbose: + print("keywords are unexpanded, not using") + raise NotThisMethod("unexpanded keywords, not a git-archive tarball") + refs = {r.strip() for r in refnames.strip("()").split(",")} + # starting in git-1.8.3, tags are listed as "tag: foo-1.0" instead of + # just "foo-1.0". If we see a "tag: " prefix, prefer those. + TAG = "tag: " + tags = {r[len(TAG) :] for r in refs if r.startswith(TAG)} + if not tags: + # Either we're using git < 1.8.3, or there really are no tags. We use + # a heuristic: assume all version tags have a digit. The old git %d + # expansion behaves like git log --decorate=short and strips out the + # refs/heads/ and refs/tags/ prefixes that would let us distinguish + # between branches and tags. By ignoring refnames without digits, we + # filter out many common branch names like "release" and + # "stabilization", as well as "HEAD" and "master". + tags = {r for r in refs if re.search(r"\d", r)} + if verbose: + print("discarding '%s', no digits" % ",".join(refs - tags)) + if verbose: + print("likely tags: %s" % ",".join(sorted(tags))) + for ref in sorted(tags): + # sorting will prefer e.g. "2.0" over "2.0rc1" + if ref.startswith(tag_prefix): + r = ref[len(tag_prefix) :] + # Filter out refs that exactly match prefix or that don't start + # with a number once the prefix is stripped (mostly a concern + # when prefix is '') + if not re.match(r"\d", r): + continue + if verbose: + print("picking %s" % r) + return { + "version": r, + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": None, + "date": date, + } + # no suitable tags, so version is "0+unknown", but full hex is still there + if verbose: + print("no suitable tags, using unknown + full revision id") + return { + "version": "0+unknown", + "full-revisionid": keywords["full"].strip(), + "dirty": False, + "error": "no suitable tags", + "date": None, + } + + +@register_vcs_handler("git", "pieces_from_vcs") +def git_pieces_from_vcs(tag_prefix, root, verbose, runner=run_command): + """Get version from 'git describe' in the root of the source tree. 
+ + This only gets called if the git-archive 'subst' keywords were *not* + expanded, and _version.py hasn't already been rewritten with a short + version string, meaning we're inside a checked out source tree. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + + # GIT_DIR can interfere with correct operation of Versioneer. + # It may be intended to be passed to the Versioneer-versioned project, + # but that should not change where we get our version from. + env = os.environ.copy() + env.pop("GIT_DIR", None) + runner = functools.partial(runner, env=env) + + _, rc = runner( + GITS, ["rev-parse", "--git-dir"], cwd=root, hide_stderr=not verbose + ) + if rc != 0: + if verbose: + print("Directory %s not under git control" % root) + raise NotThisMethod("'git rev-parse --git-dir' returned error") + + # if there is a tag matching tag_prefix, this yields TAG-NUM-gHEX[-dirty] + # if there isn't one, this yields HEX[-dirty] (no NUM) + describe_out, rc = runner( + GITS, + [ + "describe", + "--tags", + "--dirty", + "--always", + "--long", + "--match", + f"{tag_prefix}[[:digit:]]*", + ], + cwd=root, + ) + # --long was added in git-1.5.5 + if describe_out is None: + raise NotThisMethod("'git describe' failed") + describe_out = describe_out.strip() + full_out, rc = runner(GITS, ["rev-parse", "HEAD"], cwd=root) + if full_out is None: + raise NotThisMethod("'git rev-parse' failed") + full_out = full_out.strip() + + pieces = {} + pieces["long"] = full_out + pieces["short"] = full_out[:7] # maybe improved later + pieces["error"] = None + + branch_name, rc = runner( + GITS, ["rev-parse", "--abbrev-ref", "HEAD"], cwd=root + ) + # --abbrev-ref was added in git-1.6.3 + if rc != 0 or branch_name is None: + raise NotThisMethod("'git rev-parse --abbrev-ref' returned error") + branch_name = branch_name.strip() + + if branch_name == "HEAD": + # If we aren't exactly on a branch, pick a branch which represents + # the current commit. If all else fails, we are on a branchless + # commit. + branches, rc = runner(GITS, ["branch", "--contains"], cwd=root) + # --contains was added in git-1.5.4 + if rc != 0 or branches is None: + raise NotThisMethod("'git branch --contains' returned error") + branches = branches.split("\n") + + # Remove the first line if we're running detached + if "(" in branches[0]: + branches.pop(0) + + # Strip off the leading "* " from the list of branches. + branches = [branch[2:] for branch in branches] + if "master" in branches: + branch_name = "master" + elif not branches: + branch_name = None + else: + # Pick the first branch that is returned. Good or bad. + branch_name = branches[0] + + pieces["branch"] = branch_name + + # parse describe_out. It will be like TAG-NUM-gHEX[-dirty] or HEX[-dirty] + # TAG might have hyphens. + git_describe = describe_out + + # look for -dirty suffix + dirty = git_describe.endswith("-dirty") + pieces["dirty"] = dirty + if dirty: + git_describe = git_describe[: git_describe.rindex("-dirty")] + + # now we have TAG-NUM-gHEX or HEX + + if "-" in git_describe: + # TAG-NUM-gHEX + mo = re.search(r"^(.+)-(\d+)-g([0-9a-f]+)$", git_describe) + if not mo: + # unparsable. Maybe git-describe is misbehaving? 
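+            # (a parsable example would be "v1.2.3-4-g0123abc", i.e.
+            # TAG="v1.2.3", NUM=4, HEX="0123abc")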
+ pieces["error"] = ( + "unable to parse git-describe output: '%s'" % describe_out + ) + return pieces + + # tag + full_tag = mo.group(1) + if not full_tag.startswith(tag_prefix): + if verbose: + fmt = "tag '%s' doesn't start with prefix '%s'" + print(fmt % (full_tag, tag_prefix)) + pieces["error"] = "tag '%s' doesn't start with prefix '%s'" % ( + full_tag, + tag_prefix, + ) + return pieces + pieces["closest-tag"] = full_tag[len(tag_prefix) :] + + # distance: number of commits since tag + pieces["distance"] = int(mo.group(2)) + + # commit: short hex revision ID + pieces["short"] = mo.group(3) + + else: + # HEX: no tags + pieces["closest-tag"] = None + out, rc = runner(GITS, ["rev-list", "HEAD", "--left-right"], cwd=root) + pieces["distance"] = len(out.split()) # total number of commits + + # commit date: see ISO-8601 comment in git_versions_from_keywords() + date = runner(GITS, ["show", "-s", "--format=%ci", "HEAD"], cwd=root)[ + 0 + ].strip() + # Use only the last line. Previous lines may contain GPG signature + # information. + date = date.splitlines()[-1] + pieces["date"] = date.strip().replace(" ", "T", 1).replace(" ", "", 1) + + return pieces + + +def do_vcs_install(versionfile_source, ipy): + """Git-specific installation logic for Versioneer. + + For Git, this means creating/changing .gitattributes to mark _version.py + for export- subst keyword substitution. + """ + GITS = ["git"] + if sys.platform == "win32": + GITS = ["git.cmd", "git.exe"] + files = [versionfile_source] + if ipy: + files.append(ipy) + if "VERSIONEER_PEP518" not in globals(): + try: + my_path = __file__ + if my_path.endswith((".pyc", ".pyo")): + my_path = os.path.splitext(my_path)[0] + ".py" + versioneer_file = os.path.relpath(my_path) + except NameError: + versioneer_file = "versioneer.py" + files.append(versioneer_file) + present = False + try: + with open(".gitattributes", "r") as fobj: + for line in fobj: + if line.strip().startswith(versionfile_source): + if "export-subst" in line.strip().split()[1:]: + present = True + break + except OSError: + pass + if not present: + with open(".gitattributes", "a+") as fobj: + fobj.write(f"{versionfile_source} export-subst\n") + files.append(".gitattributes") + run_command(GITS, ["add", "--"] + files) + + +def versions_from_parentdir(parentdir_prefix, root, verbose): + """Try to determine the version from the parent directory name. + + Source tarballs conventionally unpack into a directory that includes both + the project name and a version string. We will also support searching up + two directory levels for an appropriately named parent directory + """ + rootdirs = [] + + for _ in range(3): + dirname = os.path.basename(root) + if dirname.startswith(parentdir_prefix): + return { + "version": dirname[len(parentdir_prefix) :], + "full-revisionid": None, + "dirty": False, + "error": None, + "date": None, + } + rootdirs.append(root) + root = os.path.dirname(root) # up a level + + if verbose: + print( + "Tried directories %s but none started with prefix %s" + % (str(rootdirs), parentdir_prefix) + ) + raise NotThisMethod("rootdir doesn't start with parentdir_prefix") + + +SHORT_VERSION_PY = """ +# This file was generated by 'versioneer.py' (0.28) from +# revision-control system data, or from the parent directory name of an +# unpacked source archive. Distribution tarballs contain a pre-generated copy +# of this file. 
+ +import json + +version_json = ''' +%s +''' # END VERSION_JSON + + +def get_versions(): + return json.loads(version_json) +""" + + +def versions_from_file(filename): + """Try to determine the version from _version.py if present.""" + try: + with open(filename) as f: + contents = f.read() + except OSError: + raise NotThisMethod("unable to read _version.py") + mo = re.search( + r"version_json = '''\n(.*)''' # END VERSION_JSON", + contents, + re.M | re.S, + ) + if not mo: + mo = re.search( + r"version_json = '''\r\n(.*)''' # END VERSION_JSON", + contents, + re.M | re.S, + ) + if not mo: + raise NotThisMethod("no version_json in _version.py") + return json.loads(mo.group(1)) + + +def write_to_version_file(filename, versions): + """Write the given version number to the given _version.py file.""" + os.unlink(filename) + contents = json.dumps( + versions, sort_keys=True, indent=1, separators=(",", ": ") + ) + with open(filename, "w") as f: + f.write(SHORT_VERSION_PY % contents) + + print("set %s to '%s'" % (filename, versions["version"])) + + +def plus_or_dot(pieces): + """Return a + if we don't already have one, else return a .""" + if "+" in pieces.get("closest-tag", ""): + return "." + return "+" + + +def render_pep440(pieces): + """Build up version string, with post-release "local version identifier". + + Our goal: TAG[+DISTANCE.gHEX[.dirty]] . Note that if you + get a tagged build and then dirty it, you'll get TAG+0.gHEX.dirty + + Exceptions: + 1: no tags. git_describe was just HEX. 0+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_branch(pieces): + """TAG[[.dev0]+DISTANCE.gHEX[.dirty]] . + + The ".dev0" means not master branch. Note that .dev0 sorts backwards + (a feature branch will appear "older" than the master branch). + + Exceptions: + 1: no tags. 0[.dev0]+untagged.DISTANCE.gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0" + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+untagged.%d.g%s" % (pieces["distance"], pieces["short"]) + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def pep440_split_post(ver): + """Split pep440 version string at the post-release segment. + + Returns the release segments before the post-release and the post-release + version number (or -1 if no post-release segment is present). + """ + vc = str.split(ver, ".post") + return vc[0], int(vc[1] or 0) if len(vc) == 2 else None + + +def render_pep440_pre(pieces): + """TAG[.postN.devDISTANCE] -- No -dirty. + + Exceptions: + 1: no tags. 
0.post0.devDISTANCE + """ + if pieces["closest-tag"]: + if pieces["distance"]: + # update the post release segment + tag_version, post_version = pep440_split_post( + pieces["closest-tag"] + ) + rendered = tag_version + if post_version is not None: + rendered += ".post%d.dev%d" % ( + post_version + 1, + pieces["distance"], + ) + else: + rendered += ".post0.dev%d" % (pieces["distance"]) + else: + # no commits, use the tag as the version + rendered = pieces["closest-tag"] + else: + # exception #1 + rendered = "0.post0.dev%d" % pieces["distance"] + return rendered + + +def render_pep440_post(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX] . + + The ".dev0" means dirty. Note that .dev0 sorts backwards + (a dirty tree will appear "older" than the corresponding clean one), + but you shouldn't be releasing software with -dirty anyways. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + return rendered + + +def render_pep440_post_branch(pieces): + """TAG[.postDISTANCE[.dev0]+gHEX[.dirty]] . + + The ".dev0" means not master branch. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0]+gHEX[.dirty] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += plus_or_dot(pieces) + rendered += "g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["branch"] != "master": + rendered += ".dev0" + rendered += "+g%s" % pieces["short"] + if pieces["dirty"]: + rendered += ".dirty" + return rendered + + +def render_pep440_old(pieces): + """TAG[.postDISTANCE[.dev0]] . + + The ".dev0" means dirty. + + Exceptions: + 1: no tags. 0.postDISTANCE[.dev0] + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"] or pieces["dirty"]: + rendered += ".post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + else: + # exception #1 + rendered = "0.post%d" % pieces["distance"] + if pieces["dirty"]: + rendered += ".dev0" + return rendered + + +def render_git_describe(pieces): + """TAG[-DISTANCE-gHEX][-dirty]. + + Like 'git describe --tags --dirty --always'. + + Exceptions: + 1: no tags. HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + if pieces["distance"]: + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render_git_describe_long(pieces): + """TAG-DISTANCE-gHEX[-dirty]. + + Like 'git describe --tags --dirty --always -long'. + The distance/hash is unconditional. + + Exceptions: + 1: no tags. 
HEX[-dirty] (note: no 'g' prefix) + """ + if pieces["closest-tag"]: + rendered = pieces["closest-tag"] + rendered += "-%d-g%s" % (pieces["distance"], pieces["short"]) + else: + # exception #1 + rendered = pieces["short"] + if pieces["dirty"]: + rendered += "-dirty" + return rendered + + +def render(pieces, style): + """Render the given version pieces into the requested style.""" + if pieces["error"]: + return { + "version": "unknown", + "full-revisionid": pieces.get("long"), + "dirty": None, + "error": pieces["error"], + "date": None, + } + + if not style or style == "default": + style = "pep440" # the default + + if style == "pep440": + rendered = render_pep440(pieces) + elif style == "pep440-branch": + rendered = render_pep440_branch(pieces) + elif style == "pep440-pre": + rendered = render_pep440_pre(pieces) + elif style == "pep440-post": + rendered = render_pep440_post(pieces) + elif style == "pep440-post-branch": + rendered = render_pep440_post_branch(pieces) + elif style == "pep440-old": + rendered = render_pep440_old(pieces) + elif style == "git-describe": + rendered = render_git_describe(pieces) + elif style == "git-describe-long": + rendered = render_git_describe_long(pieces) + else: + raise ValueError("unknown style '%s'" % style) + + return { + "version": rendered, + "full-revisionid": pieces["long"], + "dirty": pieces["dirty"], + "error": None, + "date": pieces.get("date"), + } + + +class VersioneerBadRootError(Exception): + """The project root directory is unknown or missing key files.""" + + +def get_versions(verbose=False): + """Get the project version from whatever source is available. + + Returns dict with two keys: 'version' and 'full'. + """ + if "versioneer" in sys.modules: + # see the discussion in cmdclass.py:get_cmdclass() + del sys.modules["versioneer"] + + root = get_root() + cfg = get_config_from_root(root) + + assert cfg.VCS is not None, "please set [versioneer]VCS= in setup.cfg" + handlers = HANDLERS.get(cfg.VCS) + assert handlers, "unrecognized VCS '%s'" % cfg.VCS + verbose = verbose or cfg.verbose + assert ( + cfg.versionfile_source is not None + ), "please set versioneer.versionfile_source" + assert cfg.tag_prefix is not None, "please set versioneer.tag_prefix" + + versionfile_abs = os.path.join(root, cfg.versionfile_source) + + # extract version from first of: _version.py, VCS command (e.g. 'git + # describe'), parentdir. This is meant to work for developers using a + # source checkout, for users of a tarball created by 'setup.py sdist', + # and for users of a tarball/zipball created by 'git archive' or github's + # download-from-tag feature or the equivalent in other VCSes. 
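+    # (The handlers below are tried in order: expanded keywords, an
+    # already-rewritten _version.py, 'git describe' output, and finally the
+    # parent directory name.)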
+
+    get_keywords_f = handlers.get("get_keywords")
+    from_keywords_f = handlers.get("keywords")
+    if get_keywords_f and from_keywords_f:
+        try:
+            keywords = get_keywords_f(versionfile_abs)
+            ver = from_keywords_f(keywords, cfg.tag_prefix, verbose)
+            if verbose:
+                print("got version from expanded keyword %s" % ver)
+            return ver
+        except NotThisMethod:
+            pass
+
+    try:
+        ver = versions_from_file(versionfile_abs)
+        if verbose:
+            print("got version from file %s %s" % (versionfile_abs, ver))
+        return ver
+    except NotThisMethod:
+        pass
+
+    from_vcs_f = handlers.get("pieces_from_vcs")
+    if from_vcs_f:
+        try:
+            pieces = from_vcs_f(cfg.tag_prefix, root, verbose)
+            ver = render(pieces, cfg.style)
+            if verbose:
+                print("got version from VCS %s" % ver)
+            return ver
+        except NotThisMethod:
+            pass
+
+    try:
+        if cfg.parentdir_prefix:
+            ver = versions_from_parentdir(cfg.parentdir_prefix, root, verbose)
+            if verbose:
+                print("got version from parentdir %s" % ver)
+            return ver
+    except NotThisMethod:
+        pass
+
+    if verbose:
+        print("unable to compute version")
+
+    return {
+        "version": "0+unknown",
+        "full-revisionid": None,
+        "dirty": None,
+        "error": "unable to compute version",
+        "date": None,
+    }
+
+
+def get_version():
+    """Get the short version string for this project."""
+    return get_versions()["version"]
+
+
+def get_cmdclass(cmdclass=None):
+    """Get the custom setuptools subclasses used by Versioneer.
+
+    If the package uses a different cmdclass (e.g. one from numpy), it should
+    be provided as an argument.
+    """
+    if "versioneer" in sys.modules:
+        del sys.modules["versioneer"]
+        # this fixes the "python setup.py develop" case (also 'install' and
+        # 'easy_install .'), in which subdependencies of the main project are
+        # built (using setup.py bdist_egg) in the same python process. Assume
+        # a main project A and a dependency B, which use different versions
+        # of Versioneer. A's setup.py imports A's Versioneer, leaving it in
+        # sys.modules by the time B's setup.py is executed, causing B to run
+        # with the wrong versioneer. Setuptools wraps the sub-dep builds in a
+        # sandbox that restores sys.modules to its pre-build state, so the
+        # parent is protected against the child's "import versioneer". By
+        # removing ourselves from sys.modules here, before the child build
+        # happens, we protect the child from the parent's versioneer too.
+        # Also see https://github.com/python-versioneer/python-versioneer/issues/52
+
+    cmds = {} if cmdclass is None else cmdclass.copy()
+
+    # we add "version" to setuptools
+    from setuptools import Command
+
+    class cmd_version(Command):
+        description = "report generated version string"
+        user_options = []
+        boolean_options = []
+
+        def initialize_options(self):
+            pass
+
+        def finalize_options(self):
+            pass
+
+        def run(self):
+            vers = get_versions(verbose=True)
+            print("Version: %s" % vers["version"])
+            print(" full-revisionid: %s" % vers.get("full-revisionid"))
+            print(" dirty: %s" % vers.get("dirty"))
+            print(" date: %s" % vers.get("date"))
+            if vers["error"]:
+                print(" error: %s" % vers["error"])
+
+    cmds["version"] = cmd_version
+
+    # we override "build_py" in setuptools
+    #
+    # most invocation pathways end up running build_py:
+    #  distutils/build -> build_py
+    #  distutils/install -> distutils/build ->..
+    #  setuptools/bdist_wheel -> distutils/install ->..
+    #  setuptools/bdist_egg -> distutils/install_lib -> build_py
+    #  setuptools/install -> bdist_egg ->..
+    #  setuptools/develop -> ?
+ # pip install: + # copies source tree to a tempdir before running egg_info/etc + # if .git isn't copied too, 'git describe' will fail + # then does setup.py bdist_wheel, or sometimes setup.py install + # setup.py egg_info -> ? + + # pip install -e . and setuptool/editable_wheel will invoke build_py + # but the build_py command is not expected to copy any files. + + # we override different "build_py" commands for both environments + if "build_py" in cmds: + _build_py = cmds["build_py"] + else: + from setuptools.command.build_py import build_py as _build_py + + class cmd_build_py(_build_py): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_py.run(self) + if getattr(self, "editable_mode", False): + # During editable installs `.py` and data files are + # not copied to build_lib + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if cfg.versionfile_build: + target_versionfile = os.path.join( + self.build_lib, cfg.versionfile_build + ) + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + cmds["build_py"] = cmd_build_py + + if "build_ext" in cmds: + _build_ext = cmds["build_ext"] + else: + from setuptools.command.build_ext import build_ext as _build_ext + + class cmd_build_ext(_build_ext): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + _build_ext.run(self) + if self.inplace: + # build_ext --inplace will only build extensions in + # build/lib<..> dir with no _version.py to write to. + # As in place builds will already have a _version.py + # in the module dir, we do not need to write one. + return + # now locate _version.py in the new build/ directory and replace + # it with an updated value + if not cfg.versionfile_build: + return + target_versionfile = os.path.join( + self.build_lib, cfg.versionfile_build + ) + if not os.path.exists(target_versionfile): + print( + f"Warning: {target_versionfile} does not exist, skipping " + "version update. This can happen if you are running build_ext " + "without first running build_py." + ) + return + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + cmds["build_ext"] = cmd_build_ext + + if "cx_Freeze" in sys.modules: # cx_freeze enabled? + from cx_Freeze.dist import build_exe as _build_exe + + # nczeczulin reports that py2exe won't like the pep440-style string + # as FILEVERSION, but it can be used for PRODUCTVERSION, e.g. + # setup(console=[{ + # "version": versioneer.get_version().split("+", 1)[0], # FILEVERSION + # "product_version": versioneer.get_version(), + # ... + + class cmd_build_exe(_build_exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _build_exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + cmds["build_exe"] = cmd_build_exe + del cmds["build_py"] + + if "py2exe" in sys.modules: # py2exe enabled? 
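+        # py2exe has shipped its command class under both setuptools_buildexe
+        # and distutils_buildexe; try the setuptools location first.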
+ try: + from py2exe.setuptools_buildexe import py2exe as _py2exe + except ImportError: + from py2exe.distutils_buildexe import py2exe as _py2exe + + class cmd_py2exe(_py2exe): + def run(self): + root = get_root() + cfg = get_config_from_root(root) + versions = get_versions() + target_versionfile = cfg.versionfile_source + print("UPDATING %s" % target_versionfile) + write_to_version_file(target_versionfile, versions) + + _py2exe.run(self) + os.unlink(target_versionfile) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + cmds["py2exe"] = cmd_py2exe + + # sdist farms its file list building out to egg_info + if "egg_info" in cmds: + _egg_info = cmds["egg_info"] + else: + from setuptools.command.egg_info import egg_info as _egg_info + + class cmd_egg_info(_egg_info): + def find_sources(self): + # egg_info.find_sources builds the manifest list and writes it + # in one shot + super().find_sources() + + # Modify the filelist and normalize it + root = get_root() + cfg = get_config_from_root(root) + self.filelist.append("versioneer.py") + if cfg.versionfile_source: + # There are rare cases where versionfile_source might not be + # included by default, so we must be explicit + self.filelist.append(cfg.versionfile_source) + self.filelist.sort() + self.filelist.remove_duplicates() + + # The write method is hidden in the manifest_maker instance that + # generated the filelist and was thrown away + # We will instead replicate their final normalization (to unicode, + # and POSIX-style paths) + from setuptools import unicode_utils + + normalized = [ + unicode_utils.filesys_decode(f).replace(os.sep, "/") + for f in self.filelist.files + ] + + manifest_filename = os.path.join(self.egg_info, "SOURCES.txt") + with open(manifest_filename, "w") as fobj: + fobj.write("\n".join(normalized)) + + cmds["egg_info"] = cmd_egg_info + + # we override different "sdist" commands for both environments + if "sdist" in cmds: + _sdist = cmds["sdist"] + else: + from setuptools.command.sdist import sdist as _sdist + + class cmd_sdist(_sdist): + def run(self): + versions = get_versions() + self._versioneer_generated_versions = versions + # unless we update this, the command will keep using the old + # version + self.distribution.metadata.version = versions["version"] + return _sdist.run(self) + + def make_release_tree(self, base_dir, files): + root = get_root() + cfg = get_config_from_root(root) + _sdist.make_release_tree(self, base_dir, files) + # now locate _version.py in the new base_dir directory + # (remembering that it may be a hardlink) and replace it with an + # updated value + target_versionfile = os.path.join(base_dir, cfg.versionfile_source) + print("UPDATING %s" % target_versionfile) + write_to_version_file( + target_versionfile, self._versioneer_generated_versions + ) + + cmds["sdist"] = cmd_sdist + + return cmds + + +CONFIG_ERROR = """ +setup.cfg is missing the necessary Versioneer configuration. You need +a section like: + + [versioneer] + VCS = git + style = pep440 + versionfile_source = src/myproject/_version.py + versionfile_build = myproject/_version.py + tag_prefix = + parentdir_prefix = myproject- + +You will also need to edit your setup.py to use the results: + + import versioneer + setup(version=versioneer.get_version(), + cmdclass=versioneer.get_cmdclass(), ...) 
+ +Please read the docstring in ./versioneer.py for configuration instructions, +edit setup.cfg, and re-run the installer or 'python versioneer.py setup'. +""" + +SAMPLE_CONFIG = """ +# See the docstring in versioneer.py for instructions. Note that you must +# re-run 'versioneer.py setup' after changing this section, and commit the +# resulting files. + +[versioneer] +#VCS = git +#style = pep440 +#versionfile_source = +#versionfile_build = +#tag_prefix = +#parentdir_prefix = + +""" + +OLD_SNIPPET = """ +from ._version import get_versions +__version__ = get_versions()['version'] +del get_versions +""" + +INIT_PY_SNIPPET = """ +from . import {0} +__version__ = {0}.get_versions()['version'] +""" + + +def do_setup(): + """Do main VCS-independent setup function for installing Versioneer.""" + root = get_root() + try: + cfg = get_config_from_root(root) + except ( + OSError, + configparser.NoSectionError, + configparser.NoOptionError, + ) as e: + if isinstance(e, (OSError, configparser.NoSectionError)): + print( + "Adding sample versioneer config to setup.cfg", file=sys.stderr + ) + with open(os.path.join(root, "setup.cfg"), "a") as f: + f.write(SAMPLE_CONFIG) + print(CONFIG_ERROR, file=sys.stderr) + return 1 + + print(" creating %s" % cfg.versionfile_source) + with open(cfg.versionfile_source, "w") as f: + LONG = LONG_VERSION_PY[cfg.VCS] + f.write( + LONG + % { + "DOLLAR": "$", + "STYLE": cfg.style, + "TAG_PREFIX": cfg.tag_prefix, + "PARENTDIR_PREFIX": cfg.parentdir_prefix, + "VERSIONFILE_SOURCE": cfg.versionfile_source, + } + ) + + ipy = os.path.join(os.path.dirname(cfg.versionfile_source), "__init__.py") + if os.path.exists(ipy): + try: + with open(ipy, "r") as f: + old = f.read() + except OSError: + old = "" + module = os.path.splitext(os.path.basename(cfg.versionfile_source))[0] + snippet = INIT_PY_SNIPPET.format(module) + if OLD_SNIPPET in old: + print(" replacing boilerplate in %s" % ipy) + with open(ipy, "w") as f: + f.write(old.replace(OLD_SNIPPET, snippet)) + elif snippet not in old: + print(" appending to %s" % ipy) + with open(ipy, "a") as f: + f.write(snippet) + else: + print(" %s unmodified" % ipy) + else: + print(" %s doesn't exist, ok" % ipy) + ipy = None + + # Make VCS-specific changes. For git, this means creating/changing + # .gitattributes to mark _version.py for export-subst keyword + # substitution. + do_vcs_install(cfg.versionfile_source, ipy) + return 0 + + +def scan_setup_py(): + """Validate the contents of setup.py against Versioneer's expectations.""" + found = set() + setters = False + errors = 0 + with open("setup.py", "r") as f: + for line in f.readlines(): + if "import versioneer" in line: + found.add("import") + if "versioneer.get_cmdclass()" in line: + found.add("cmdclass") + if "versioneer.get_version()" in line: + found.add("get_version") + if "versioneer.VCS" in line: + setters = True + if "versioneer.versionfile_source" in line: + setters = True + if len(found) != 3: + print("") + print("Your setup.py appears to be missing some important items") + print("(but I might be wrong). Please make sure it has something") + print("roughly like the following:") + print("") + print(" import versioneer") + print(" setup( version=versioneer.get_version(),") + print(" cmdclass=versioneer.get_cmdclass(), ...)") + print("") + errors += 1 + if setters: + print("You should remove lines like 'versioneer.VCS = ' and") + print("'versioneer.versionfile_source = ' . 
This configuration") + print("now lives in setup.cfg, and should be removed from setup.py") + print("") + errors += 1 + return errors + + +def setup_command(): + """Set up Versioneer and exit with appropriate error code.""" + errors = do_setup() + errors += scan_setup_py() + sys.exit(1 if errors else 0) + + +if __name__ == "__main__": + cmd = sys.argv[1] + if cmd == "setup": + setup_command()