From aee2d2adb4a5eba6938550e89898c60b1dac2162 Mon Sep 17 00:00:00 2001 From: Caleb Weinreb Date: Sat, 6 Jan 2024 14:20:28 -0500 Subject: [PATCH] Applied black formater to all files and removed unused imports --- snub/__main__.py | 4 +- snub/gui/__init__.py | 2 +- snub/gui/main.py | 2 - snub/gui/panels/pose3D.py | 1 - snub/gui/panels/roi.py | 157 +++++++++++++------- snub/gui/panels/scatter.py | 287 ++++++++++++++++++++++-------------- snub/gui/panels/video.py | 1 - snub/gui/stacks/__init__.py | 2 +- snub/gui/stacks/base.py | 7 +- snub/gui/stacks/panel.py | 28 ++-- snub/gui/stacks/track.py | 1 - snub/gui/tracks/__init__.py | 12 +- snub/gui/tracks/base.py | 3 - snub/gui/tracks/heatmap.py | 1 - snub/gui/tracks/spike.py | 161 ++++++++++++++------ snub/gui/tracks/trace.py | 1 - snub/gui/utils/__init__.py | 2 +- snub/gui/utils/interval.py | 119 ++++++++------- snub/gui/utils/widgets.py | 2 +- snub/io/__init__.py | 2 +- snub/io/manifold.py | 125 ++++++---------- snub/io/plot.py | 23 ++- snub/io/project.py | 2 - 23 files changed, 557 insertions(+), 388 deletions(-) diff --git a/snub/__main__.py b/snub/__main__.py index 6978470..a42be42 100644 --- a/snub/__main__.py +++ b/snub/__main__.py @@ -1,4 +1,4 @@ from snub.gui.main import run -if __name__ == '__main__': - run() \ No newline at end of file +if __name__ == "__main__": + run() diff --git a/snub/gui/__init__.py b/snub/gui/__init__.py index 8a2fdfc..c626a80 100644 --- a/snub/gui/__init__.py +++ b/snub/gui/__init__.py @@ -1,4 +1,4 @@ from . import panels from . import stacks from . import tracks -from . import main \ No newline at end of file +from . import main diff --git a/snub/gui/main.py b/snub/gui/main.py index bec4103..ba85c28 100644 --- a/snub/gui/main.py +++ b/snub/gui/main.py @@ -2,12 +2,10 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import sys, os, json -import numpy as np from functools import partial from snub.gui.utils import IntervalIndex, CheckBox from snub.gui.stacks import PanelStack, TrackStack from snub.gui.tracks import TracePlot -import time def set_style(app): diff --git a/snub/gui/panels/pose3D.py b/snub/gui/panels/pose3D.py index 06bb555..2b991eb 100644 --- a/snub/gui/panels/pose3D.py +++ b/snub/gui/panels/pose3D.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time import os from vispy.scene import SceneCanvas diff --git a/snub/gui/panels/roi.py b/snub/gui/panels/roi.py index 83276da..ed30a2a 100644 --- a/snub/gui/panels/roi.py +++ b/snub/gui/panels/roi.py @@ -2,81 +2,116 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time import os -import cmapy import cv2 from scipy.sparse import load_npz -from functools import partial from vidio import VideoReader from vispy.scene import SceneCanvas from vispy.scene.visuals import Image, Line from snub.gui.panels import Panel -from snub.gui.utils import HeaderMixin, IntervalIndex, AdjustColormapDialog +from snub.gui.utils import HeaderMixin, AdjustColormapDialog from snub.io.project import _random_color def _roi_contours(rois, dims, threshold_max_ratio=0.2, blur_kernel=2): - rois = np.array(rois.todense()).reshape(rois.shape[0],*dims) + rois = np.array(rois.todense()).reshape(rois.shape[0], *dims) contour_coordinates = [] for roi in rois: - roi_blur = cv2.GaussianBlur(roi,(11,11),blur_kernel) - roi_mask = roi_blur > roi_blur.max()*threshold_max_ratio - xy = cv2.findContours(roi_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[0][0].squeeze() - contour_coordinates.append(np.vstack((xy,xy[:1]))) + roi_blur = cv2.GaussianBlur(roi, (11, 11), blur_kernel) + roi_mask = roi_blur > roi_blur.max() * threshold_max_ratio + xy = cv2.findContours( + roi_mask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE + )[0][0].squeeze() + contour_coordinates.append(np.vstack((xy, xy[:1]))) return contour_coordinates class ROIPanel(Panel, HeaderMixin): eps = 1e-10 - def __init__(self, config, rois_path=None, labels_path=None, timestamps_path=None, - dimensions=None, video_paths=None, contour_colors={}, linewidth=3, - initial_selected_rois=[], vmin=0, vmax=1, colormap='viridis', **kwargs): - + def __init__( + self, + config, + rois_path=None, + labels_path=None, + timestamps_path=None, + dimensions=None, + video_paths=None, + contour_colors={}, + linewidth=3, + initial_selected_rois=[], + vmin=0, + vmax=1, + colormap="viridis", + **kwargs + ): super().__init__(config, **kwargs) self.linewidth = linewidth self.colormap = colormap - self.vmin,self.vmax = vmin,vmax + self.vmin, self.vmax = vmin, vmax self.dims = dimensions self.current_frame_index = None self.is_visible = True - self.rois = load_npz(os.path.join(config['project_directory'],rois_path)) - self.timestamps = np.load(os.path.join(config['project_directory'],timestamps_path)) + self.rois = load_npz(os.path.join(config["project_directory"], rois_path)) + self.timestamps = np.load( + os.path.join(config["project_directory"], timestamps_path) + ) - if labels_path is None: self.labels = [str(i) for i in range(self.rois.shape[0])] - else: self.labels = open(os.path.join(config['project_directory'],labels_path),'r').read().split('\n') + if labels_path is None: + self.labels = [str(i) for i in range(self.rois.shape[0])] + else: + self.labels = ( + open(os.path.join(config["project_directory"], labels_path), "r") + .read() + .split("\n") + ) self.adjust_colormap_dialog = AdjustColormapDialog(self, self.vmin, self.vmax) self.adjust_colormap_dialog.new_range.connect(self.update_colormap_range) - self.canvas = SceneCanvas(self, keys='interactive', show=True) + self.canvas = SceneCanvas(self, keys="interactive", show=True) self.canvas.events.mouse_release.connect(self.mouse_release) - self.viewbox = self.canvas.central_widget.add_grid().add_view(row=0, col=0, camera='panzoom') - self.viewbox.camera.set_range(x=(0,self.dims[1]), y=(0,self.dims[0]), margin=0) - self.viewbox.camera.aspect=1 + self.viewbox = self.canvas.central_widget.add_grid().add_view( + row=0, col=0, camera="panzoom" + ) + self.viewbox.camera.set_range( + x=(0, self.dims[1]), y=(0, self.dims[0]), margin=0 + ) + self.viewbox.camera.aspect = 1 self.contours = {} - for label,coordinates in zip(self.labels, _roi_contours(self.rois, self.dims)): - color = contour_colors[label] if label in contour_colors else _random_color() - self.contours[label] = Line(coordinates, color=np.array(color)/255, - width=self.linewidth, connect='strip', parent=None) - - self.vids = {name : VideoReader( - os.path.join(config['project_directory'],video_path) - ) for name,video_path in video_paths.items()} + for label, coordinates in zip(self.labels, _roi_contours(self.rois, self.dims)): + color = ( + contour_colors[label] if label in contour_colors else _random_color() + ) + self.contours[label] = Line( + coordinates, + color=np.array(color) / 255, + width=self.linewidth, + connect="strip", + parent=None, + ) + + self.vids = { + name: VideoReader(os.path.join(config["project_directory"], video_path)) + for name, video_path in video_paths.items() + } self.dropDown = QComboBox() self.dropDown.addItems(list(video_paths.keys())[::-1]) self.dropDown.activated.connect(self.update_image) - self.image = Image(np.zeros(self.dims, dtype=np.float32), - cmap=colormap, parent=self.viewbox.scene, clim=(0,1)) - self.update_current_time(config['init_current_time']) + self.image = Image( + np.zeros(self.dims, dtype=np.float32), + cmap=colormap, + parent=self.viewbox.scene, + clim=(0, 1), + ) + self.update_current_time(config["init_current_time"]) self.initUI(**kwargs) def initUI(self, **kwargs): @@ -84,64 +119,76 @@ def initUI(self, **kwargs): self.layout.addWidget(self.dropDown) self.layout.addWidget(self.canvas.native) self.image.order = 1 - for c in self.contours.values(): c.order=0 - self.dropDown.setStyleSheet(""" + for c in self.contours.values(): + c.order = 0 + self.dropDown.setStyleSheet( + """ QComboBox::item { color: white; background-color : #3E3E3E;} - QComboBox::item:selected { background-color: #999999;} """) + QComboBox::item:selected { background-color: #999999;} """ + ) def update_visible_contours(self, visible_contours): - for l,c in self.contours.items(): + for l, c in self.contours.items(): if l in visible_contours: c.parent = self.viewbox.scene - else: c.parent = None + else: + c.parent = None def update_current_time(self, t): - self.current_frame_index = min(self.timestamps.searchsorted(t), len(self.timestamps)-1) - if self.is_visible: self.update_image() + self.current_frame_index = min( + self.timestamps.searchsorted(t), len(self.timestamps) - 1 + ) + if self.is_visible: + self.update_image() def toggle_visiblity(self, *args): super().toggle_visiblity(*args) - if self.is_visible: self.update_image() + if self.is_visible: + self.update_image() def update_image(self): name = self.dropDown.currentText() - if self.current_frame_index is None: x = np.zeros(self.dims) - else: x = self.vids[name][self.current_frame_index][:,:,0]/255 - image = (np.clip(x, self.vmin, self.vmax)-self.vmin)/(self.vmax-self.vmin) + if self.current_frame_index is None: + x = np.zeros(self.dims) + else: + x = self.vids[name][self.current_frame_index][:, :, 0] / 255 + image = (np.clip(x, self.vmin, self.vmax) - self.vmin) / (self.vmax - self.vmin) self.image.set_data(image.astype(np.float32)) self.canvas.update() def update_colormap_range(self, vmin, vmax): - self.vmin,self.vmax = vmin,vmax + self.vmin, self.vmax = vmin, vmax self.update_image() def show_adjust_colormap_dialog(self): self.adjust_colormap_dialog.show() def mouse_release(self, event): - if event.button == 2: self.context_menu(event) + if event.button == 2: + self.context_menu(event) def context_menu(self, event): contextMenu = QMenu(self) - def add_menu_item(name, slot, item_type='label'): + + def add_menu_item(name, slot, item_type="label"): action = QWidgetAction(self) - if item_type=='checkbox': + if item_type == "checkbox": widget = QCheckBox(name) widget.stateChanged.connect(slot) - elif item_type=='label': + elif item_type == "label": widget = QLabel(name) action.triggered.connect(slot) action.setDefaultWidget(widget) - contextMenu.addAction(action) + contextMenu.addAction(action) return widget # click to show adjust colormap range dialog - label = add_menu_item('Adjust colormap range',self.show_adjust_colormap_dialog) + label = add_menu_item("Adjust colormap range", self.show_adjust_colormap_dialog) - contextMenu.setStyleSheet(""" + contextMenu.setStyleSheet( + """ QMenu::item, QLabel, QCheckBox { background-color : #3E3E3E; padding: 5px 6px 5px 6px;} QMenu::item:selected, QLabel:hover, QCheckBox:hover { background-color: #999999;} - QMenu::separator { background-color: rgb(20,20,20);} """) + QMenu::separator { background-color: rgb(20,20,20);} """ + ) action = contextMenu.exec_(event.native.globalPos()) - - diff --git a/snub/gui/panels/scatter.py b/snub/gui/panels/scatter.py index a67e39a..e9baac8 100644 --- a/snub/gui/panels/scatter.py +++ b/snub/gui/panels/scatter.py @@ -1,9 +1,7 @@ from PyQt5.QtCore import * from PyQt5.QtWidgets import * from PyQt5.QtGui import * -import pyqtgraph as pg import numpy as np -import time import os import cmapy from functools import partial @@ -16,173 +14,229 @@ from snub.gui.utils import HeaderMixin, AdjustColormapDialog, IntervalIndex - - - class ScatterPanel(Panel, HeaderMixin): eps = 1e-10 - def __init__(self, config, selected_intervals, data_path=None, name='', - pointsize=10, linewidth=1, facecolor=(180,180,180), xlim=None, ylim=None, - selected_edgecolor=(255,255,0), edgecolor=(0,0,0), current_node_size=20, - current_node_color=(255,0,0), colormap='viridis', - selection_intersection_threshold=0.5, variable_labels=[], **kwargs): - + def __init__( + self, + config, + selected_intervals, + data_path=None, + name="", + pointsize=10, + linewidth=1, + facecolor=(180, 180, 180), + xlim=None, + ylim=None, + selected_edgecolor=(255, 255, 0), + edgecolor=(0, 0, 0), + current_node_size=20, + current_node_color=(255, 0, 0), + colormap="viridis", + selection_intersection_threshold=0.5, + variable_labels=[], + **kwargs + ): super().__init__(config, **kwargs) assert data_path is not None self.selected_intervals = selected_intervals - self.bounds = config['bounds'] - self.min_step = config['min_step'] + self.bounds = config["bounds"] + self.min_step = config["min_step"] self.pointsize = pointsize self.linewidth = linewidth - self.facecolor = np.array(facecolor)/256 - self.edgecolor = np.array(edgecolor)/256 + self.facecolor = np.array(facecolor) / 256 + self.edgecolor = np.array(edgecolor) / 256 self.colormap = colormap - self.selected_edgecolor = np.array(selected_edgecolor)/256 + self.selected_edgecolor = np.array(selected_edgecolor) / 256 self.current_node_size = current_node_size - self.current_node_color = np.array(current_node_color)/256 + self.current_node_color = np.array(current_node_color) / 256 self.selection_intersection_threshold = selection_intersection_threshold - self.variable_labels = ['Interval start','Interval end']+variable_labels - self.vmin,self.vmax = 0,1 - self.current_variable_label = '(No color)' - self.sort_nodes_by_variable = True + self.variable_labels = ["Interval start", "Interval end"] + variable_labels + self.vmin, self.vmax = 0, 1 + self.current_variable_label = "(No color)" + self.sort_nodes_by_variable = True self.show_marker_trail = False - self.data = np.load(os.path.join(config['project_directory'],data_path)) - self.data[:,2:4] = self.data[:,2:4] + np.array([-self.eps, self.eps]) - self.is_selected = np.zeros(self.data.shape[0])>0 + self.data = np.load(os.path.join(config["project_directory"], data_path)) + self.data[:, 2:4] = self.data[:, 2:4] + np.array([-self.eps, self.eps]) + self.is_selected = np.zeros(self.data.shape[0]) > 0 self.plot_order = np.arange(self.data.shape[0]) - self.interval_index = IntervalIndex(min_step=self.min_step, intervals=self.data[:,2:4]) + self.interval_index = IntervalIndex( + min_step=self.min_step, intervals=self.data[:, 2:4] + ) self.adjust_colormap_dialog = AdjustColormapDialog(self, self.vmin, self.vmax) self.variable_menu = QListWidget(self) self.variable_menu.itemClicked.connect(self.variable_menu_item_clicked) self.show_variable_menu() - self.canvas = SceneCanvas(self, keys='interactive', show=True) + self.canvas = SceneCanvas(self, keys="interactive", show=True) self.canvas.events.mouse_move.connect(self.mouse_move) self.canvas.events.mouse_release.connect(self.mouse_release) - self.viewbox = self.canvas.central_widget.add_grid().add_view(row=0, col=0, camera='panzoom') - self.viewbox.camera.aspect=1 + self.viewbox = self.canvas.central_widget.add_grid().add_view( + row=0, col=0, camera="panzoom" + ) + self.viewbox.camera.aspect = 1 self.scatter = Markers(antialias=0) self.scatter_selected = Markers(antialias=0) self.current_node_marker = Markers(antialias=0) - self.rect = Rectangle(border_color=(1,1,1), color=(1,1,1,.2), center=(0,0), width=1, height=1) + self.rect = Rectangle( + border_color=(1, 1, 1), + color=(1, 1, 1, 0.2), + center=(0, 0), + width=1, + height=1, + ) self.viewbox.add(self.scatter) - self.initUI(name=name, xlim=xlim, ylim=ylim, ) + self.initUI( + name=name, + xlim=xlim, + ylim=ylim, + ) def initUI(self, xlim=None, ylim=None, **kwargs): super().initUI(**kwargs) splitter = QSplitter(Qt.Horizontal) splitter.addWidget(self.variable_menu) splitter.addWidget(self.canvas.native) - splitter.setStretchFactor(0,3) - splitter.setStretchFactor(1,3) + splitter.setStretchFactor(0, 3) + splitter.setStretchFactor(1, 3) self.layout.addWidget(splitter) self.update_scatter() - if xlim is None: xlim = [self.data[:,0].min(),self.data[:,0].max()] - if ylim is None: ylim = [self.data[:,1].min(),self.data[:,1].max()] + if xlim is None: + xlim = [self.data[:, 0].min(), self.data[:, 0].max()] + if ylim is None: + ylim = [self.data[:, 1].min(), self.data[:, 1].max()] self.viewbox.camera.set_range(x=xlim, y=ylim, margin=0.1) self.rect.order = 0 - self.current_node_marker.order=1 - self.scatter_selected.order=2 - self.scatter.order=3 - self.variable_menu.setStyleSheet(""" + self.current_node_marker.order = 1 + self.scatter_selected.order = 2 + self.scatter.order = 3 + self.variable_menu.setStyleSheet( + """ QListWidget::item { background-color : #3E3E3E; color:white; padding: 5px 6px 5px 6px;} QListWidget::item:hover, QLabel:hover { background-color: #999999; color:white; } - QListWidget { background-color : #3E3E3E; }""") - + QListWidget { background-color : #3E3E3E; }""" + ) def update_scatter(self): if self.current_variable_label in self.variable_labels: - x = self.data[:,2+self.variable_labels.index(self.current_variable_label)] - if self.sort_nodes_by_variable: self.plot_order = np.argsort(x)[::-1] - else: self.plot_order = np.arange(len(x)) - x = np.clip((x - self.vmin) / (self.vmax - self.vmin), 0, 1)[self.plot_order] - face_color = cmapy.cmap(self.colormap).squeeze()[:,::-1][(255*x).astype(int)]/255 - else: face_color = np.repeat(self.facecolor[None],self.data.shape[0],axis=0) + x = self.data[ + :, 2 + self.variable_labels.index(self.current_variable_label) + ] + if self.sort_nodes_by_variable: + self.plot_order = np.argsort(x)[::-1] + else: + self.plot_order = np.arange(len(x)) + x = np.clip((x - self.vmin) / (self.vmax - self.vmin), 0, 1)[ + self.plot_order + ] + face_color = ( + cmapy.cmap(self.colormap).squeeze()[:, ::-1][(255 * x).astype(int)] + / 255 + ) + else: + face_color = np.repeat(self.facecolor[None], self.data.shape[0], axis=0) self.scatter.set_data( - pos=self.data[self.plot_order,:2], + pos=self.data[self.plot_order, :2], face_color=face_color, - edge_color=self.edgecolor, - edge_width=self.linewidth, - size=self.pointsize) + edge_color=self.edgecolor, + edge_width=self.linewidth, + size=self.pointsize, + ) if self.is_selected.any(): is_selected = self.is_selected[self.plot_order] self.scatter_selected.set_data( - pos=self.data[self.plot_order,:2][is_selected], + pos=self.data[self.plot_order, :2][is_selected], face_color=face_color[is_selected], - edge_color=self.selected_edgecolor, - edge_width=(self.linewidth*2), - size=self.pointsize) + edge_color=self.selected_edgecolor, + edge_width=(self.linewidth * 2), + size=self.pointsize, + ) self.scatter_selected.parent = self.viewbox.scene - else: self.scatter_selected.parent = None - + else: + self.scatter_selected.parent = None def update_current_time(self, t): - if self.show_marker_trail: - times = np.linspace(t,t-5,11) - sizes = np.exp(np.linspace(0,-1.5,11)) + if self.show_marker_trail: + times = np.linspace(t, t - 5, 11) + sizes = np.exp(np.linspace(0, -1.5, 11)) else: times = np.array([t]) sizes = np.array([1]) - nodes,time_indexes = self.interval_index.intervals_containing(times) - if len(nodes)>0: + nodes, time_indexes = self.interval_index.intervals_containing(times) + if len(nodes) > 0: self.current_node_marker.set_data( - pos=self.data[nodes,:2], + pos=self.data[nodes, :2], face_color=self.current_node_color, - size=sizes[time_indexes]*self.current_node_size) + size=sizes[time_indexes] * self.current_node_size, + ) self.current_node_marker.parent = self.viewbox.scene - else: self.current_node_marker.parent = None - + else: + self.current_node_marker.parent = None def context_menu(self, event): contextMenu = QMenu(self) - def add_menu_item(name, slot, item_type='label'): + + def add_menu_item(name, slot, item_type="label"): action = QWidgetAction(self) - if item_type=='checkbox': + if item_type == "checkbox": widget = QCheckBox(name) widget.stateChanged.connect(slot) - elif item_type=='label': + elif item_type == "label": widget = QLabel(name) action.triggered.connect(slot) action.setDefaultWidget(widget) - contextMenu.addAction(action) + contextMenu.addAction(action) return widget - + # show/hide variable menu - if self.variable_menu.isVisible(): add_menu_item('Hide variables menu', self.hide_variable_menu) - else: add_menu_item('Show variables menu', self.show_variable_menu) + if self.variable_menu.isVisible(): + add_menu_item("Hide variables menu", self.hide_variable_menu) + else: + add_menu_item("Show variables menu", self.show_variable_menu) # get enriched variables (only available is nodes are selected) - label = add_menu_item('Sort variables by enrichment', self.get_enriched_variables) - if self.is_selected.sum()==0: label.setStyleSheet("QLabel { color: rgb(120,120,120); }") - label = add_menu_item('Restore original variable order', self.show_variable_menu) + label = add_menu_item( + "Sort variables by enrichment", self.get_enriched_variables + ) + if self.is_selected.sum() == 0: + label.setStyleSheet("QLabel { color: rgb(120,120,120); }") + label = add_menu_item( + "Restore original variable order", self.show_variable_menu + ) contextMenu.addSeparator() # toggle whether to plot high-variable-val nodes on top - checkbox = add_menu_item('Plot high values on top', self.toggle_sort_by_color_value, item_type='checkbox') - if self.sort_nodes_by_variable: checkbox.setChecked(True) - else: checkbox.setChecked(False) + checkbox = add_menu_item( + "Plot high values on top", + self.toggle_sort_by_color_value, + item_type="checkbox", + ) + if self.sort_nodes_by_variable: + checkbox.setChecked(True) + else: + checkbox.setChecked(False) contextMenu.addSeparator() # click to show adjust colormap range dialog - label = add_menu_item('Adjust colormap range',self.show_adjust_colormap_dialog) + label = add_menu_item("Adjust colormap range", self.show_adjust_colormap_dialog) contextMenu.addSeparator() - if self.show_marker_trail: - add_menu_item('Hide marker trail',partial(self.toggle_marker_trail,False)) - else: add_menu_item('Show marker trail',partial(self.toggle_marker_trail,True)) - + if self.show_marker_trail: + add_menu_item("Hide marker trail", partial(self.toggle_marker_trail, False)) + else: + add_menu_item("Show marker trail", partial(self.toggle_marker_trail, True)) - contextMenu.setStyleSheet(""" + contextMenu.setStyleSheet( + """ QMenu::item, QLabel, QCheckBox { background-color : #3E3E3E; padding: 5px 6px 5px 6px;} QMenu::item:selected, QLabel:hover, QCheckBox:hover { background-color: #999999;} - QMenu::separator { background-color: rgb(20,20,20);} """) + QMenu::separator { background-color: rgb(20,20,20);} """ + ) action = contextMenu.exec_(event.native.globalPos()) - def toggle_marker_trail(self, visibility): self.show_marker_trail = visibility self.update_scatter() @@ -195,24 +249,30 @@ def hide_variable_menu(self): def show_variable_menu(self, *args, variable_order=None): self.variable_menu.clear() - if variable_order is None: variable_order = self.variable_labels - for name in variable_order: self.variable_menu.addItem(name) + if variable_order is None: + variable_order = self.variable_labels + for name in variable_order: + self.variable_menu.addItem(name) self.variable_menu.show() def get_enriched_variables(self): - if self.is_selected.sum() > 0 and len(self.variable_labels)>0: - variables_zscore = (self.data[:,2:] - self.data[:,2:].mean(0))/(np.std(self.data[:,2:],axis=0)+self.eps) + if self.is_selected.sum() > 0 and len(self.variable_labels) > 0: + variables_zscore = (self.data[:, 2:] - self.data[:, 2:].mean(0)) / ( + np.std(self.data[:, 2:], axis=0) + self.eps + ) enrichment = variables_zscore[self.is_selected].mean(0) variable_order = [self.variable_labels[i] for i in np.argsort(-enrichment)] self.show_variable_menu(variable_order=variable_order) def update_colormap_range(self, vmin, vmax): - self.vmin,self.vmax = vmin,vmax + self.vmin, self.vmax = vmin, vmax self.update_scatter() def toggle_sort_by_color_value(self, check_state): - if check_state == 0: self.sort_nodes_by_variable = False - else: self.sort_nodes_by_variable = True + if check_state == 0: + self.sort_nodes_by_variable = False + else: + self.sort_nodes_by_variable = True self.update_scatter() def show_adjust_colormap_dialog(self): @@ -221,14 +281,16 @@ def show_adjust_colormap_dialog(self): def colorby(self, label): self.current_variable_label = label if self.current_variable_label in self.variable_labels: - x = self.data[:,2+self.variable_labels.index(self.current_variable_label)] - self.vmin,self.vmax = x.min()-self.eps,x.max()+self.eps - self.adjust_colormap_dialog.update_range(self.vmin,self.vmax) + x = self.data[ + :, 2 + self.variable_labels.index(self.current_variable_label) + ] + self.vmin, self.vmax = x.min() - self.eps, x.max() + self.eps + self.adjust_colormap_dialog.update_range(self.vmin, self.vmax) self.update_scatter() def mouse_release(self, event): self.rect.parent = None - if event.button == 2: + if event.button == 2: self.context_menu(event) def mouse_move(self, event): @@ -237,23 +299,28 @@ def mouse_move(self, event): if keys.SHIFT in mods or keys.CONTROL in mods: current_pos = self.viewbox.scene.transform.imap(event.pos)[:2] start_pos = self.viewbox.scene.transform.imap(event.press_event.pos)[:2] - if all((current_pos-start_pos)!=0): - self.rect.center = (current_pos+start_pos)/2 - self.rect.width = np.abs(current_pos[0]-start_pos[0]) - self.rect.height = np.abs(current_pos[1]-start_pos[1]) + if all((current_pos - start_pos) != 0): + self.rect.center = (current_pos + start_pos) / 2 + self.rect.width = np.abs(current_pos[0] - start_pos[0]) + self.rect.height = np.abs(current_pos[1] - start_pos[1]) self.rect.parent = self.viewbox.scene - selection_value = int(mods[0]==keys.SHIFT) - enclosed_points = np.all([ - self.data[:,:2]>=np.minimum(current_pos, start_pos), - self.data[:,:2]<=np.maximum(current_pos, start_pos)],axis=(0,2)) + selection_value = int(mods[0] == keys.SHIFT) + enclosed_points = np.all( + [ + self.data[:, :2] >= np.minimum(current_pos, start_pos), + self.data[:, :2] <= np.maximum(current_pos, start_pos), + ], + axis=(0, 2), + ) self.selection_change.emit( - list(self.data[enclosed_points,2:4]), - [selection_value]*len(enclosed_points)) + list(self.data[enclosed_points, 2:4]), + [selection_value] * len(enclosed_points), + ) def update_selected_intervals(self): - intersections = self.selected_intervals.intersection_proportions(self.data[:,2:4]) + intersections = self.selected_intervals.intersection_proportions( + self.data[:, 2:4] + ) self.is_selected = intersections > self.selection_intersection_threshold self.update_scatter() - - diff --git a/snub/gui/panels/video.py b/snub/gui/panels/video.py index cef7a75..b4c7eaf 100644 --- a/snub/gui/panels/video.py +++ b/snub/gui/panels/video.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time import os from vidio import VideoReader diff --git a/snub/gui/stacks/__init__.py b/snub/gui/stacks/__init__.py index 50f3b74..f4e2fe5 100644 --- a/snub/gui/stacks/__init__.py +++ b/snub/gui/stacks/__init__.py @@ -1,3 +1,3 @@ from .base import Stack from .panel import PanelStack -from .track import TrackStack \ No newline at end of file +from .track import TrackStack diff --git a/snub/gui/stacks/base.py b/snub/gui/stacks/base.py index 7b82817..dfa90a4 100644 --- a/snub/gui/stacks/base.py +++ b/snub/gui/stacks/base.py @@ -11,12 +11,13 @@ def __init__(self, config, selected_intervals): self.selected_intervals = selected_intervals def change_layout_mode(self, layout_mode): - for widget in self.widgets: widget.change_layout_mode(layout_mode) - + for widget in self.widgets: + widget.change_layout_mode(layout_mode) + def initUI(self): sizePolicy = QSizePolicy(QSizePolicy.Expanding, QSizePolicy.Expanding) sizePolicy.setHorizontalStretch(self.size_ratio) self.setSizePolicy(sizePolicy) widget_order = np.argsort([w.order for w in self.widgets]) - self.widgets = [self.widgets[i] for i in widget_order] \ No newline at end of file + self.widgets = [self.widgets[i] for i in widget_order] diff --git a/snub/gui/stacks/panel.py b/snub/gui/stacks/panel.py index 31114cb..202787d 100644 --- a/snub/gui/stacks/panel.py +++ b/snub/gui/stacks/panel.py @@ -1,45 +1,45 @@ from PyQt5.QtCore import * from PyQt5.QtWidgets import * from PyQt5.QtGui import * -import numpy as np from snub.gui.stacks import Stack from snub.gui.panels import VideoPanel, ScatterPanel, ROIPanel, Pose3DPanel + class PanelStack(Stack): def __init__(self, config, selected_intervals): super().__init__(config, selected_intervals) - self.size_ratio = config['panels_size_ratio'] + self.size_ratio = config["panels_size_ratio"] - for props in config['scatter']: # initialize scatter plots + for props in config["scatter"]: # initialize scatter plots panel = ScatterPanel(config, self.selected_intervals, **props) self.widgets.append(panel) - for props in config['video']: # initialize video + for props in config["video"]: # initialize video panel = VideoPanel(config, **props) self.widgets.append(panel) - for props in config['pose3D']: # initialize 3D pose viewer + for props in config["pose3D"]: # initialize 3D pose viewer panel = Pose3DPanel(config, **props) self.widgets.append(panel) - for props in config['roiplot']: # initialize ROI plot + for props in config["roiplot"]: # initialize ROI plot panel = ROIPanel(config, **props) self.widgets.append(panel) self.initUI() - def initUI(self): super().initUI() hbox = QHBoxLayout(self) self.splitter = QSplitter(Qt.Vertical) - for panel in self.widgets: self.splitter.addWidget(panel) + for panel in self.widgets: + self.splitter.addWidget(panel) self.splitter.setSizes([w.size_ratio for w in self.widgets]) hbox.addWidget(self.splitter) - self.splitter.setSizes([100000*p.size_ratio for p in self.widgets]) + self.splitter.setSizes([100000 * p.size_ratio for p in self.widgets]) hbox.setContentsMargins(0, 0, 0, 0) def get_by_name(self, name): @@ -47,14 +47,16 @@ def get_by_name(self, name): if panel.name == name: return panel - def update_current_time(self,t): + def update_current_time(self, t): for panel in self.widgets: panel.update_current_time(t) def update_selected_intervals(self): - for panel in self.widgets: + for panel in self.widgets: panel.update_selected_intervals() def change_layout_mode(self, layout_mode): - self.splitter.setOrientation({'columns':Qt.Vertical, 'rows':Qt.Horizontal}[layout_mode]) - super().change_layout_mode(layout_mode) \ No newline at end of file + self.splitter.setOrientation( + {"columns": Qt.Vertical, "rows": Qt.Horizontal}[layout_mode] + ) + super().change_layout_mode(layout_mode) diff --git a/snub/gui/stacks/track.py b/snub/gui/stacks/track.py index d1816a4..bc3fc79 100644 --- a/snub/gui/stacks/track.py +++ b/snub/gui/stacks/track.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import time from snub.gui.stacks import Stack from snub.gui.tracks import * diff --git a/snub/gui/tracks/__init__.py b/snub/gui/tracks/__init__.py index 2eae32e..010e535 100644 --- a/snub/gui/tracks/__init__.py +++ b/snub/gui/tracks/__init__.py @@ -1,4 +1,12 @@ -from .base import Track, TrackGroup, Timeline, SelectionOverlay, LineOverlay, position_to_time, time_to_position +from .base import ( + Track, + TrackGroup, + Timeline, + SelectionOverlay, + LineOverlay, + position_to_time, + time_to_position, +) from .trace import TracePlot, HeadedTracePlot from .heatmap import Heatmap, HeatmapTraceGroup, HeadedHeatmap -from .spike import SpikePlot, HeadedSpikePlot, SpikePlotTraceGroup \ No newline at end of file +from .spike import SpikePlot, HeadedSpikePlot, SpikePlotTraceGroup diff --git a/snub/gui/tracks/base.py b/snub/gui/tracks/base.py index 1298d0e..1803365 100644 --- a/snub/gui/tracks/base.py +++ b/snub/gui/tracks/base.py @@ -2,9 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import numpy as np -import os -import time -from functools import partial from snub.gui.utils import HeaderMixin diff --git a/snub/gui/tracks/heatmap.py b/snub/gui/tracks/heatmap.py index 8c6a553..bdb6f8f 100644 --- a/snub/gui/tracks/heatmap.py +++ b/snub/gui/tracks/heatmap.py @@ -5,7 +5,6 @@ import os import numpy as np import cmapy -import time from numba import njit, prange from snub.gui.tracks import Track, TracePlot, TrackGroup diff --git a/snub/gui/tracks/spike.py b/snub/gui/tracks/spike.py index 277f24b..5702424 100644 --- a/snub/gui/tracks/spike.py +++ b/snub/gui/tracks/spike.py @@ -1,46 +1,69 @@ from PyQt5.QtCore import * from PyQt5.QtWidgets import * from PyQt5.QtGui import * -from functools import partial import os import numpy as np import cmapy -import time from vispy.scene import SceneCanvas from vispy.scene.visuals import Markers, Line -from snub.gui.tracks import Track, TracePlot, TrackGroup, Heatmap +from snub.gui.tracks import TracePlot, TrackGroup, Heatmap -''' +""" class SpikePlot(Heatmap): def __init__(self, config, selected_intervals, spikes_path=None, heatmap_path=None, **kwargs): print(spikes_path, heatmap_path) super().__init__(config, selected_intervals, data_path=heatmap_path, **kwargs) self.spike_data = np.load(os.path.join(config['project_directory'],spikes_path)) -''' +""" class SpikePlot(Heatmap): - def __init__(self, config, selected_intervals, spikes_path=None, markersize=5, - heatmap_path=None, heatmap_range=60, colormap='viridis', **kwargs): - + def __init__( + self, + config, + selected_intervals, + spikes_path=None, + markersize=5, + heatmap_path=None, + heatmap_range=60, + colormap="viridis", + **kwargs + ): super().__init__(config, selected_intervals, data_path=heatmap_path, **kwargs) self.heatmap_range = heatmap_range - spike_data = np.load(os.path.join(config['project_directory'],spikes_path)) - self.spike_times,self.spike_labels = spike_data[:,0], spike_data[:,1].astype(int) + spike_data = np.load(os.path.join(config["project_directory"], spikes_path)) + self.spike_times, self.spike_labels = spike_data[:, 0], spike_data[:, 1].astype( + int + ) self.max_label = self.spike_labels.max() - self.markersize=markersize + self.markersize = markersize self.colormap = colormap - self.cmap = cmapy.cmap(self.colormap).squeeze()[:,::-1].astype(np.float32)/255 - self.canvas = SceneCanvas(self, keys='interactive', bgcolor=self.cmap[0], show=True) - self.viewbox = self.canvas.central_widget.add_grid().add_view(row=0, col=0, camera='panzoom') - line_verts = np.vstack([ - np.ones(self.max_label)*self.spike_times.min()-10, - np.arange(self.max_label), - np.ones(self.max_label)*self.spike_times.max()+10, - np.arange(self.max_label)]).T - self.lines = Line(pos=line_verts.reshape(-1,2), color=np.clip(self.cmap[0]+.1,0,1), method='gl', width=0.5, connect='segments') + self.cmap = ( + cmapy.cmap(self.colormap).squeeze()[:, ::-1].astype(np.float32) / 255 + ) + self.canvas = SceneCanvas( + self, keys="interactive", bgcolor=self.cmap[0], show=True + ) + self.viewbox = self.canvas.central_widget.add_grid().add_view( + row=0, col=0, camera="panzoom" + ) + line_verts = np.vstack( + [ + np.ones(self.max_label) * self.spike_times.min() - 10, + np.arange(self.max_label), + np.ones(self.max_label) * self.spike_times.max() + 10, + np.arange(self.max_label), + ] + ).T + self.lines = Line( + pos=line_verts.reshape(-1, 2), + color=np.clip(self.cmap[0] + 0.1, 0, 1), + method="gl", + width=0.5, + connect="segments", + ) self.viewbox.add(self.lines) self.scatter = Markers() @@ -48,9 +71,9 @@ def __init__(self, config, selected_intervals, spikes_path=None, markersize=5, self.scatter.order = -1 self.set_scatter_data() self.viewbox.add(self.scatter) - + layout = QVBoxLayout(self) - layout.setContentsMargins(0,0,0,0) + layout.setContentsMargins(0, 0, 0, 0) layout.addWidget(self.canvas.native, 1) self.heatmap_image.raise_() self.heatmap_labels.raise_() @@ -58,41 +81,65 @@ def __init__(self, config, selected_intervals, spikes_path=None, markersize=5, def update_current_range(self, current_range): super().update_current_range(current_range) - if (self.current_range[1]-self.current_range[0]) >= self.heatmap_range: + if (self.current_range[1] - self.current_range[0]) >= self.heatmap_range: self.heatmap_image.show() else: self.heatmap_image.hide() - self.viewbox.camera.set_range(x=self.current_range, y=self.get_ylim(), margin=1e-10) - bgcolor = self.cmap[0]*(self.current_range[1]-self.current_range[0])/self.heatmap_range + self.viewbox.camera.set_range( + x=self.current_range, y=self.get_ylim(), margin=1e-10 + ) + bgcolor = ( + self.cmap[0] + * (self.current_range[1] - self.current_range[0]) + / self.heatmap_range + ) self.canvas.bgcolor = bgcolor - self.lines.set_data(color=np.clip(bgcolor+.1,0,1)) + self.lines.set_data(color=np.clip(bgcolor + 0.1, 0, 1)) def spike_coordinates(self): - ycoords = self.max_label-np.argsort(self.row_order)[self.spike_labels]+.5 - return np.vstack((self.spike_times,ycoords)).T + ycoords = self.max_label - np.argsort(self.row_order)[self.spike_labels] + 0.5 + return np.vstack((self.spike_times, ycoords)).T def spike_colors(self): image_data = self.get_image_data() rows = np.argsort(self.row_order)[self.spike_labels] - cols = np.around((self.spike_times-self.intervals[0,0])/self.min_step).astype(int) - colors = image_data[rows,np.clip(cols,0,image_data.shape[1]-1)].astype(np.float32)/255 + cols = np.around( + (self.spike_times - self.intervals[0, 0]) / self.min_step + ).astype(int) + colors = ( + image_data[rows, np.clip(cols, 0, image_data.shape[1] - 1)].astype( + np.float32 + ) + / 255 + ) return colors def zoom_in_vertical(self): super().zoom_in_vertical() - self.viewbox.camera.set_range(x=self.current_range, y=self.get_ylim(), margin=1e-10) + self.viewbox.camera.set_range( + x=self.current_range, y=self.get_ylim(), margin=1e-10 + ) - def zoom_vertical(self,origin,scale_factor): - super().zoom_vertical(origin,scale_factor) - self.viewbox.camera.set_range(x=self.current_range, y=self.get_ylim(), margin=1e-10) + def zoom_vertical(self, origin, scale_factor): + super().zoom_vertical(origin, scale_factor) + self.viewbox.camera.set_range( + x=self.current_range, y=self.get_ylim(), margin=1e-10 + ) def get_ylim(self): - return self.max_label - np.array(self.vertical_range)[::-1] + 1 + return self.max_label - np.array(self.vertical_range)[::-1] + 1 def set_scatter_data(self): xy = self.spike_coordinates() c = self.spike_colors() - self.scatter.set_data(xy, edge_width=0, face_color=c, edge_color=None, symbol='vbar', size=self.markersize) + self.scatter.set_data( + xy, + edge_width=0, + face_color=c, + edge_color=None, + symbol="vbar", + size=self.markersize, + ) def update_row_order(self, order): super().update_row_order(order) @@ -103,27 +150,43 @@ def update_colormap_range(self, *args): self.set_scatter_data() - - - - class HeadedSpikePlot(TrackGroup): def __init__(self, config, selected_intervals, **kwargs): spikeplot = SpikePlot(config, selected_intervals, **kwargs) - super().__init__(config, tracks={'spikeplot':spikeplot}, track_order=['spikeplot'], **kwargs) + super().__init__( + config, tracks={"spikeplot": spikeplot}, track_order=["spikeplot"], **kwargs + ) class SpikePlotTraceGroup(TrackGroup): - def __init__(self, config, selected_intervals, trace_height_ratio=1, - heatmap_height_ratio=2, height_ratio=1, **kwargs): + def __init__( + self, + config, + selected_intervals, + trace_height_ratio=1, + heatmap_height_ratio=2, + height_ratio=1, + **kwargs + ): self.height_ratio = trace_height_ratio + heatmap_height_ratio - spikeplot = SpikePlot(config, selected_intervals, height_ratio=heatmap_height_ratio, **kwargs) + spikeplot = SpikePlot( + config, selected_intervals, height_ratio=heatmap_height_ratio, **kwargs + ) x = spikeplot.intervals.mean(1) - trace_data = {l:np.vstack((x,d)).T for l,d in zip(spikeplot.labels, spikeplot.data)} - trace = TracePlot(config, height_ratio=trace_height_ratio, data=trace_data, **kwargs) - - super().__init__(config, tracks={'trace':trace, 'spikeplot':spikeplot}, - track_order=['trace','spikeplot'], height_ratio=height_ratio, **kwargs) + trace_data = { + l: np.vstack((x, d)).T for l, d in zip(spikeplot.labels, spikeplot.data) + } + trace = TracePlot( + config, height_ratio=trace_height_ratio, data=trace_data, **kwargs + ) + + super().__init__( + config, + tracks={"trace": trace, "spikeplot": spikeplot}, + track_order=["trace", "spikeplot"], + height_ratio=height_ratio, + **kwargs + ) spikeplot.display_trace_signal.connect(trace.show_trace) diff --git a/snub/gui/tracks/trace.py b/snub/gui/tracks/trace.py index 637a38f..ee24b59 100644 --- a/snub/gui/tracks/trace.py +++ b/snub/gui/tracks/trace.py @@ -2,7 +2,6 @@ from PyQt5.QtWidgets import * from PyQt5.QtGui import * import pyqtgraph as pg -import colorsys import numpy as np import pickle import os diff --git a/snub/gui/utils/__init__.py b/snub/gui/utils/__init__.py index eb4ec6d..862cea3 100644 --- a/snub/gui/utils/__init__.py +++ b/snub/gui/utils/__init__.py @@ -1,2 +1,2 @@ from .interval import IntervalIndex -from .widgets import AdjustColormapDialog, HeaderMixin, CheckBox \ No newline at end of file +from .widgets import AdjustColormapDialog, HeaderMixin, CheckBox diff --git a/snub/gui/utils/interval.py b/snub/gui/utils/interval.py index 1180f8a..4b512cb 100644 --- a/snub/gui/utils/interval.py +++ b/snub/gui/utils/interval.py @@ -2,8 +2,7 @@ from numba import njit, prange - -#@njit +@njit def sum_by_index(x, ixs, n): out = np.zeros(n) for i in prange(len(ixs)): @@ -11,69 +10,82 @@ def sum_by_index(x, ixs, n): return out -class IntervalIndexBase(): - def __init__(self, intervals=np.empty((0,2)), **kwargs): +class IntervalIndexBase: + def __init__(self, intervals=np.empty((0, 2)), **kwargs): self.intervals = intervals def clear(self): - self.intervals = np.empty((0,2)) + self.intervals = np.empty((0, 2)) def partition_intervals(self, start, end): - ends_before = self.intervals[:,1] < start - ends_after = self.intervals[:,1] >= start - starts_before = self.intervals[:,0] <= end - starts_after = self.intervals[:,0] > end + ends_before = self.intervals[:, 1] < start + ends_after = self.intervals[:, 1] >= start + starts_before = self.intervals[:, 0] <= end + starts_after = self.intervals[:, 0] > end intersect = self.intervals[np.bitwise_and(ends_after, starts_before)] pre = self.intervals[ends_before] post = self.intervals[starts_after] - return pre,intersect,post - + return pre, intersect, post + def add_interval(self, start, end): - pre,intersect,post = self.partition_intervals(start,end) + pre, intersect, post = self.partition_intervals(start, end) if intersect.shape[0] > 0: - merged_start = np.minimum(intersect[0,0],start) - merged_end = np.maximum(intersect[-1,1],end) - else: + merged_start = np.minimum(intersect[0, 0], start) + merged_end = np.maximum(intersect[-1, 1], end) + else: merged_start, merged_end = start, end - merged_interval = np.array([merged_start, merged_end]).reshape(1,2) + merged_interval = np.array([merged_start, merged_end]).reshape(1, 2) self.intervals = np.vstack((pre, merged_interval, post)) def remove_interval(self, start, end): - pre,intersect,post = self.partition_intervals(start,end) - pre_intersect = np.empty((0,2)) - post_intersect = np.empty((0,2)) + pre, intersect, post = self.partition_intervals(start, end) + pre_intersect = np.empty((0, 2)) + post_intersect = np.empty((0, 2)) if intersect.shape[0] > 0: - if intersect[0,0] < start: pre_intersect = np.array([intersect[0,0],start]) - if intersect[-1,1] > end: post_intersect = np.array([end,intersect[-1,1]]) - self.intervals = np.vstack((pre,pre_intersect,post_intersect,post)) + if intersect[0, 0] < start: + pre_intersect = np.array([intersect[0, 0], start]) + if intersect[-1, 1] > end: + post_intersect = np.array([end, intersect[-1, 1]]) + self.intervals = np.vstack((pre, pre_intersect, post_intersect, post)) - def intersection_proportions(self, query_intervals): + def intersection_proportions(self, query_intervals): query_ixs, ref_ixs = self.all_overlaps_both(self.intervals, query_intervals) - if len(query_ixs)>0: - intersection_starts = np.maximum(query_intervals[query_ixs,0], self.intervals[ref_ixs,0]) - intersection_ends = np.minimum(query_intervals[query_ixs,1], self.intervals[ref_ixs,1]) + if len(query_ixs) > 0: + intersection_starts = np.maximum( + query_intervals[query_ixs, 0], self.intervals[ref_ixs, 0] + ) + intersection_ends = np.minimum( + query_intervals[query_ixs, 1], self.intervals[ref_ixs, 1] + ) intersection_lengths = intersection_ends - intersection_starts - query_intersection_lengths = sum_by_index(intersection_lengths, query_ixs, query_intervals.shape[0]) - query_lengths = query_intervals[:,1] - query_intervals[:,0] + 1e-10 + query_intersection_lengths = sum_by_index( + intersection_lengths, query_ixs, query_intervals.shape[0] + ) + query_lengths = query_intervals[:, 1] - query_intervals[:, 0] + 1e-10 return query_intersection_lengths / query_lengths - else: return np.zeros(query_intervals.shape[0]) + else: + return np.zeros(query_intervals.shape[0]) def all_containments_both(self, ref_intervals, query_locations): raise NotImplementedError() def intervals_containing(self, query_locations): - query_ixs,ref_ixs = self.all_containments_both(self.intervals, query_locations) - valid_containments = np.all([ - self.intervals[ref_ixs,0] <= query_locations[query_ixs], - self.intervals[ref_ixs,1] >= query_locations[query_ixs]],axis=0) + query_ixs, ref_ixs = self.all_containments_both(self.intervals, query_locations) + valid_containments = np.all( + [ + self.intervals[ref_ixs, 0] <= query_locations[query_ixs], + self.intervals[ref_ixs, 1] >= query_locations[query_ixs], + ], + axis=0, + ) return ref_ixs[valid_containments], query_ixs[valid_containments] try: - from ncls import NCLS + # try executing so exception is triggered on import not at runtime - ncls.all_containments_both(np.arange(1),np.arange(1),np.arange(1)); + ncls.all_containments_both(np.arange(1), np.arange(1), np.arange(1)) class IntervalIndex(IntervalIndexBase): def __init__(self, min_step=0.033, **kwargs): @@ -81,15 +93,19 @@ def __init__(self, min_step=0.033, **kwargs): self.min_step = min_step def preprocess_for_ncls(self, intervals): - intervals_discretized = (intervals/self.min_step).astype(int) - return (intervals_discretized[:,0].copy(order='C'), - intervals_discretized[:,1].copy(order='C'), - np.arange(intervals_discretized.shape[0])) + intervals_discretized = (intervals / self.min_step).astype(int) + return ( + intervals_discretized[:, 0].copy(order="C"), + intervals_discretized[:, 1].copy(order="C"), + np.arange(intervals_discretized.shape[0]), + ) def all_containments_both(self, ref_intervals, query_locations): query_locations = (query_locations / self.min_step).astype(int) ncls = NCLS(*self.preprocess_for_ncls(ref_intervals)) - return ncls.all_containments_both(query_locations, query_locations, np.arange(len(query_locations))) + return ncls.all_containments_both( + query_locations, query_locations, np.arange(len(query_locations)) + ) def all_overlaps_both(self, ref_intervals, query_intervals): query_intervals = self.preprocess_for_ncls(query_intervals) @@ -98,26 +114,23 @@ def all_overlaps_both(self, ref_intervals, query_intervals): return ncls.all_overlaps_both(*query_intervals) except: - from interlap import InterLap + class IntervalIndex(IntervalIndexBase): def __init__(self, **kwargs): super().__init__(**kwargs) def all_overlaps_both(self, ref_intervals, query_intervals): - inter = InterLap(ranges=[(s,e,i) for i,(s,e) in enumerate(ref_intervals)]) - query_ixs,ref_ixs = [],[] - for i,(s,e) in enumerate(query_intervals): - overlap_ixs = [interval[2] for interval in inter.find((s,e))] + inter = InterLap( + ranges=[(s, e, i) for i, (s, e) in enumerate(ref_intervals)] + ) + query_ixs, ref_ixs = [], [] + for i, (s, e) in enumerate(query_intervals): + overlap_ixs = [interval[2] for interval in inter.find((s, e))] ref_ixs.append(overlap_ixs) - query_ixs.append([i]*len(overlap_ixs)) - return np.hstack(query_ixs).astype(int),np.hstack(ref_ixs).astype(int) + query_ixs.append([i] * len(overlap_ixs)) + return np.hstack(query_ixs).astype(int), np.hstack(ref_ixs).astype(int) def all_containments_both(self, ref_intervals, query_locations): - query_intervals = np.repeat(query_locations[:,None],2,axis=1) + query_intervals = np.repeat(query_locations[:, None], 2, axis=1) return self.all_overlaps_both(ref_intervals, query_intervals) - - - - - diff --git a/snub/gui/utils/widgets.py b/snub/gui/utils/widgets.py index 7e4cf9e..2b8af34 100644 --- a/snub/gui/utils/widgets.py +++ b/snub/gui/utils/widgets.py @@ -1,4 +1,4 @@ -import numpy as np, os +import os from pyqtgraph import VerticalLabel from PyQt5.QtCore import * from PyQt5.QtWidgets import * diff --git a/snub/io/__init__.py b/snub/io/__init__.py index fab35c9..1096828 100644 --- a/snub/io/__init__.py +++ b/snub/io/__init__.py @@ -1,4 +1,4 @@ from .project import * from .manifold import * from .video import * -from .plot import * \ No newline at end of file +from .plot import * diff --git a/snub/io/manifold.py b/snub/io/manifold.py index ca7f427..1eae667 100644 --- a/snub/io/manifold.py +++ b/snub/io/manifold.py @@ -1,37 +1,29 @@ import numpy as np -import warnings -# Binning / smoothing - -def firing_rates( - spike_times, - spike_labels, - window_size=0.2, - window_step=0.02 -): +def firing_rates(spike_times, spike_labels, window_size=0.2, window_step=0.02): """Convert spike tikes to firing rates using a sliding window - + Parameters ---------- spike_times : ndarray Spike times (in seconds) for all units. The source of each spike is input separately using ``spike_labels`` - + spike_labels: ndarray The source/label for each spike in ``spike_times``. The maximum value of this array determines the number of rows in the heatmap. - + window_size: float, default=0.2 Length (in seconds) of the sliding window used to calculate firing rates - + window_step: float, default=0.02 Step-size (in seconds) between each window used to calculate firing rates Returns ------- firing_rates: ndarray - Array of firing rates, where rows units and columns are sliding + Array of firing rates, where rows units and columns are sliding window locations. ``firing_rates`` has shape ``(N,M)`` where:: N = max(spike_labels)+1 @@ -43,27 +35,23 @@ def firing_rates( of the first window in ``firing_rates``. """ # round spikes to window_step and factor our start time - spike_times = np.around(spike_times/window_step).astype(int) + spike_times = np.around(spike_times / window_step).astype(int) start_time = spike_times.min() spike_times = spike_times - start_time - + # create heatmap of spike counts for each window_step-sized bin spike_labels = spike_labels.astype(int) - heatmap = np.zeros((spike_labels.max()+1, spike_times.max()+1)) - np.add.at(heatmap, (spike_labels, spike_times), 1/window_step) - + heatmap = np.zeros((spike_labels.max() + 1, spike_times.max() + 1)) + np.add.at(heatmap, (spike_labels, spike_times), 1 / window_step) + # use convolution to get sliding window counts - kernel = np.ones(int(window_size//window_step))/(window_size//window_step) - for i in range(heatmap.shape[0]): heatmap[i,:] = np.convolve(heatmap[i,:],kernel, mode='same') - return heatmap, start_time-window_step/2 + kernel = np.ones(int(window_size // window_step)) / (window_size // window_step) + for i in range(heatmap.shape[0]): + heatmap[i, :] = np.convolve(heatmap[i, :], kernel, mode="same") + return heatmap, start_time - window_step / 2 -def bin_data( - data, - binsize, - axis=-1, - return_intervals=False -): +def bin_data(data, binsize, axis=-1, return_intervals=False): """Bin data using non-overlaping windows along `axis` Returns @@ -73,46 +61,36 @@ def bin_data( bin_intervals: ndarray (returned if ``rerturn_intervals=True``) (N,2) array with the start and end index of each bin """ - data = np.moveaxis(data,axis,-1) - pad_amount = (-data.shape[-1])%binsize - num_bins = int((data.shape[-1]+pad_amount)/binsize) + data = np.moveaxis(data, axis, -1) + pad_amount = (-data.shape[-1]) % binsize + num_bins = int((data.shape[-1] + pad_amount) / binsize) - data_padded = np.pad(data,[(0,0)]*(len(data.shape)-1)+[(0,pad_amount)]) + data_padded = np.pad(data, [(0, 0)] * (len(data.shape) - 1) + [(0, pad_amount)]) data_binned = data_padded.reshape(*data.shape[:-1], num_bins, binsize).mean(-1) - if pad_amount > 0: data_binned[...,-1] = data_binned[...,-1] * binsize/(binsize-pad_amount) - data_binned = np.moveaxis(data_binned,-1,axis) + if pad_amount > 0: + data_binned[..., -1] = data_binned[..., -1] * binsize / (binsize - pad_amount) + data_binned = np.moveaxis(data_binned, -1, axis) if return_intervals: - bin_starts = np.arange(0,num_bins)*binsize - bin_ends = np.arange(1,num_bins+1)*binsize + bin_starts = np.arange(0, num_bins) * binsize + bin_ends = np.arange(1, num_bins + 1) * binsize bin_ends[-1] = data.shape[-1] - bin_intervals = np.vstack((bin_starts,bin_ends)).T + bin_intervals = np.vstack((bin_starts, bin_ends)).T return data_binned, bin_intervals - else: return data_binned - + else: + return data_binned -# Normalization - def zscore(data, axis=0, eps=1e-10): """ Z-score standardize the data along ``axis`` """ mean = np.mean(data, axis=axis, keepdims=True) std = np.std(data, axis=axis, keepdims=True) + eps - return (data-mean)/std - - + return (data - mean) / std - -# Dimensionality reduction - -def sort( - data, - method='rastermap', - options={} -): +def sort(data, method="rastermap", options={}): """Compute neuron ordering that groups neurons with similar activity Parameters @@ -136,25 +114,25 @@ def sort( Ordering index that can be used for sorting (see `numpy.argsort`) """ - valid_sort_methods = ['rastermap'] + valid_sort_methods = ["rastermap"] if not method in valid_sort_methods: - raise AssertionError(method+' is not a valid sort method. Must be one of '+repr(valid_sort_methods)) - if method=='rastermap': - print('Computing row order with rastermap') + raise AssertionError( + method + + " is not a valid sort method. Must be one of " + + repr(valid_sort_methods) + ) + if method == "rastermap": + print("Computing row order with rastermap") from rastermap import mapping - model = mapping.Rastermap(n_components=1).fit(data) - return np.argsort(model.embedding[:,0]) + + model = mapping.Rastermap(n_components=1, **options).fit(data) + return np.argsort(model.embedding[:, 0]) def umap_embedding( - data, - standardize=True, - n_pcs=20, - n_components=2, - n_neighbors=100, - **kwargs + data, standardize=True, n_pcs=20, n_components=2, n_neighbors=100, **kwargs ): - """Generate a 2D embedding of neural activity using UMAP. The function + """Generate a 2D embedding of neural activity using UMAP. The function generates the embedding in three steps: 1. (Optionally) standardize (Z-score) the activity of each neuron @@ -172,7 +150,7 @@ def umap_embedding( Whether to standardize (Z-score) the data prior to PCA n_pcs: int, default=20 - Number of principal components to use during PCA. If ``n_pcs=None``, the binned + Number of principal components to use during PCA. If ``n_pcs=None``, the binned data will be passed directly to UMAP n_components: int, default=2 @@ -192,16 +170,11 @@ def umap_embedding( from sklearn.decomposition import PCA from umap import UMAP - if standardize: data = zscore(data, axis=1) + if standardize: + data = zscore(data, axis=1) PCs = PCA(n_components=n_pcs).fit_transform(data.T) - umap_obj = UMAP(n_neighbors=n_neighbors, n_components=n_components, n_epochs=500, **kwargs) + umap_obj = UMAP( + n_neighbors=n_neighbors, n_components=n_components, n_epochs=500, **kwargs + ) coordinates = umap_obj.fit_transform(PCs) return coordinates - - - - - - - - \ No newline at end of file diff --git a/snub/io/plot.py b/snub/io/plot.py index 42b4c13..d77c4a8 100644 --- a/snub/io/plot.py +++ b/snub/io/plot.py @@ -1,21 +1,28 @@ import numpy as np - def scatter_plot_bounds(xy, margin=0.05, n_neighbors=100, distance_cutoff=2): """ Get xlim and ylim for a scatter plot such that outliers are excluded. Bounds are based on the largest component of a knn graph with distance cutoff. """ import pynndescent, networkx as nx - edges,dists = pynndescent.NNDescent(xy,n_neighbors=n_neighbors).neighbor_graph + + edges, dists = pynndescent.NNDescent(xy, n_neighbors=n_neighbors).neighbor_graph G = nx.Graph() G.add_nodes_from(np.arange(xy.shape[0])) - for i,j in zip(*np.nonzero(dists