diff --git a/visanalysis/analysis/imaging_data.py b/visanalysis/analysis/imaging_data.py index 5f1a1b0..85271eb 100644 --- a/visanalysis/analysis/imaging_data.py +++ b/visanalysis/analysis/imaging_data.py @@ -8,20 +8,20 @@ """ import functools import os +import warnings + import h5py -import numpy as np +import matplotlib.colors as mcolors import matplotlib.gridspec as gridspec import matplotlib.pyplot as plt -import matplotlib.colors as mcolors +import numpy as np import pandas as pd import scipy.signal as signal -import warnings - -from visanalysis.util import plot_tools, h5io from visanalysis.util import general_utils as gu +from visanalysis.util import h5io, plot_tools -class ImagingDataObject(): +class ImagingDataObject: """ ImagingDataObject is the central analysis class in visanalysis. -Each instance of this class is an object associated with a single visprotocol @@ -29,8 +29,17 @@ class ImagingDataObject(): -It interacts strictly with the .hdf5 datafile generated by a visprotocol experiment (i.e. not with raw imaging data or metadata) """ - __slots__ = ["file_path", "series_number", "colors", "quiet", - "timing_channel_ind", "threshold", "frame_slop", "command_frame_rate"] + + __slots__ = [ + "file_path", + "series_number", + "colors", + "quiet", + "timing_channel_ind", + "threshold", + "frame_slop", + "command_frame_rate", + ] def __init__(self, file_path, series_number, quiet=False, cfg_dict=None): self.file_path = file_path @@ -43,8 +52,10 @@ def __init__(self, file_path, series_number, quiet=False, cfg_dict=None): # SET DEFAULTS: # For stimulus timing... self.timing_channel_ind = 0 # For multiple photodiode channels, which one to use to define stim timing? - self.threshold = 0.6, # photodiode trace threshold for up/down finding. Normalized to 0-1 - self.frame_slop = 20, # datapoints +/- ideal frame duration + self.threshold = ( + 0.6, + ) # photodiode trace threshold for up/down finding. Normalized to 0-1 + self.frame_slop = (20,) # datapoints +/- ideal frame duration self.command_frame_rate = 120 # Hz, expected frame rate for monitor # Use user-supplied cfg dict to overwrite attributes if cfg_dict is not None: @@ -52,15 +63,25 @@ def __init__(self, file_path, series_number, quiet=False, cfg_dict=None): setattr(self, k, v) # check to see if hdf5 file exists - assert self.file_path.split('.')[-1] == 'hdf5', 'file_path must point to an .hdf5 file, \n current file_path is {}'.format(self.file_path) - assert os.path.exists(self.file_path), 'No hdf5 file found at \n {}, check your filepath'.format(self.file_path) + assert ( + self.file_path.split(".")[-1] == "hdf5" + ), "file_path must point to an .hdf5 file, \n current file_path is {}".format( + self.file_path + ) + assert os.path.exists( + self.file_path + ), "No hdf5 file found at \n {}, check your filepath".format(self.file_path) # check to see if series exists in this file: - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) if epoch_run_group is None: - raise Exception('No series {} found in {}'.format(self.series_number, self.file_path)) + raise Exception( + "No series {} found in {}".format( + self.series_number, self.file_path + ) + ) def getRunParameters(self, param_key=None): """ @@ -69,13 +90,18 @@ def getRunParameters(self, param_key=None): :param_key: name of run parameter to return. If None, return all run parametrs as dict. Default=None """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) if param_key: - assert param_key in epoch_run_group.attrs, 'Run parameter "{}" not found in run parameters. \n'.format(param_key) \ - + 'Available run parameters are {}'.format([x for x in epoch_run_group.attrs]) + assert ( + param_key in epoch_run_group.attrs + ), 'Run parameter "{}" not found in run parameters. \n'.format( + param_key + ) + "Available run parameters are {}".format( + [x for x in epoch_run_group.attrs] + ) run_parameter = epoch_run_group.attrs[param_key] else: # Get all run parameters in a dict run_parameter = {} @@ -91,14 +117,19 @@ def getEpochParameters(self, param_key=None): :param_key: name of epoch parameter to return. If None, return all epoch parametrs as list of dicts. Default=None """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) epoch_parameter = [] - for epoch in epoch_run_group['epochs'].values(): + for epoch in epoch_run_group["epochs"].values(): if param_key: - assert param_key in epoch.attrs, 'Epoch parameter "{}" not found in epoch_parameters. \n'.format(param_key) \ - + 'Available epoch_parameters are {}'.format([x for x in epoch.attrs]) + assert ( + param_key in epoch.attrs + ), 'Epoch parameter "{}" not found in epoch_parameters. \n'.format( + param_key + ) + "Available epoch_parameters are {}".format( + [x for x in epoch.attrs] + ) epoch_parameter.append(epoch.attrs[param_key]) else: # Get all epoch parameters as a dict new_params = {} @@ -116,8 +147,12 @@ def getEpochParameterDicts(self, target_keys=None): """ stim_dicts = [] for ep in self.getEpochParameters(): - if target_keys is None: # Default: any params with "current" or "component" in them, as in "current_intensity" - new_keys = [key for key in ep.keys() if 'current' in key or 'component' in key] + if ( + target_keys is None + ): # Default: any params with "current" or "component" in them, as in "current_intensity" + new_keys = [ + key for key in ep.keys() if "current" in key or "component" in key + ] else: new_keys = [key for key in target_keys if key in ep.keys()] @@ -133,10 +168,15 @@ def getExperimentMetadata(self, metadata_key=None): :metadata_key: name of metadata item to return. If None, return all metadata as dict. Default=None """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: if metadata_key: - assert metadata_key in experiment_file.attrs, 'metadata_key "{}" not found in metadata. \n'.format(metadata_key) \ - + 'Available metadata keys are {}'.format([x for x in experiment_file.attrs]) + assert ( + metadata_key in experiment_file.attrs + ), 'metadata_key "{}" not found in metadata. \n'.format( + metadata_key + ) + "Available metadata keys are {}".format( + [x for x in experiment_file.attrs] + ) exp_metadata = experiment_file.attrs[metadata_key] else: # Get all fly metadata in a dict exp_metadata = {} @@ -152,14 +192,19 @@ def getFlyMetadata(self, metadata_key=None): :metadata_key: name of metadata item to return. If None, return all metadata as dict. Default=None """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) fly_group = epoch_run_group.parent.parent if metadata_key: - assert metadata_key in fly_group.attrs, 'metadata_key "{}" not found in metadata. \n'.format(metadata_key) \ - + 'Available metadata keys are {}'.format([x for x in fly_group.attrs]) + assert ( + metadata_key in fly_group.attrs + ), 'metadata_key "{}" not found in metadata. \n'.format( + metadata_key + ) + "Available metadata keys are {}".format( + [x for x in fly_group.attrs] + ) fly_metadata = fly_group.attrs[metadata_key] else: # Get all fly metadata in a dict fly_metadata = {} @@ -175,14 +220,19 @@ def getAcquisitionMetadata(self, metadata_key=None): :metadata_key: name of metadata item to return. If None, return all metadata as dict. Default=None """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) - acquisition_group = epoch_run_group['acquisition'] + acquisition_group = epoch_run_group["acquisition"] if metadata_key: - assert metadata_key in acquisition_group.attrs, 'metadata_key "{}" not found in metadata. \n'.format(metadata_key) \ - + 'Available metadata keys are {}'.format([x for x in acquisition_group.attrs]) + assert ( + metadata_key in acquisition_group.attrs + ), 'metadata_key "{}" not found in metadata. \n'.format( + metadata_key + ) + "Available metadata keys are {}".format( + [x for x in acquisition_group.attrs] + ) acquisition_metadata = acquisition_group.attrs[metadata_key] else: # Get all imaging metadata in a dict acquisition_metadata = {} @@ -200,17 +250,17 @@ def getVoltageData(self): voltage_time_vector: array voltage_sample_rate: (Hz) """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) - stimulus_timing_group = epoch_run_group['stimulus_timing'] + stimulus_timing_group = epoch_run_group["stimulus_timing"] - voltage_trace = stimulus_timing_group.get('frame_monitor')[:] + voltage_trace = stimulus_timing_group.get("frame_monitor")[:] if len(voltage_trace.shape) < 2: # dummy dim for single channel photodiode voltage_trace = voltage_trace[np.newaxis, :] - voltage_time_vector = stimulus_timing_group.get('time_vector')[:] - voltage_sample_rate = stimulus_timing_group.attrs['sample_rate'] + voltage_time_vector = stimulus_timing_group.get("time_vector")[:] + voltage_sample_rate = stimulus_timing_group.attrs["sample_rate"] return voltage_trace, voltage_time_vector, voltage_sample_rate @@ -222,37 +272,40 @@ def getResponseTiming(self): response_timing: dict """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) - acquisition_group = epoch_run_group['acquisition'] + acquisition_group = epoch_run_group["acquisition"] response_timing = {} - response_timing['time_vector'] = acquisition_group.get('time_points')[:] # sec - response_timing['sample_period'] = acquisition_group.attrs['sample_period'] # sec + response_timing["time_vector"] = acquisition_group.get("time_points")[ + : + ] # sec + response_timing["sample_period"] = acquisition_group.attrs[ + "sample_period" + ] # sec return response_timing def getVolumeFrameOffsets(self): """ For volumetric scans, return temporal offset for each z slice relative to frame_times - + returns: frame_offsets (sec), temporal sample offset for each z slice component of a volume """ - with h5py.File(self.file_path, 'r+') as experiment_file: + with h5py.File(self.file_path, "r+") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) series_group = experiment_file.visititems(find_partial) - acquisition_group = series_group['acquisition'] - frame_times = np.asarray(acquisition_group.get('frame_times')) + acquisition_group = series_group["acquisition"] + frame_times = np.asarray(acquisition_group.get("frame_times")) # initialize frame_offsets and populate frame_offsets = np.zeros(frame_times.shape[1]) - for frame_ind in range(0,frame_times.shape[1]): + for frame_ind in range(0, frame_times.shape[1]): # offsets are consitent for every single volume, so we only need to do this calculation once for each z-slice - frame_offsets[frame_ind] = frame_times[0,frame_ind]-frame_times[0,0] + frame_offsets[frame_ind] = frame_times[0, frame_ind] - frame_times[0, 0] return frame_offsets - def getStimulusTiming(self, - plot_trace_flag=False): + def getStimulusTiming(self, plot_trace_flag=False): """ Returns stimulus timing information based on photodiode voltage trace from frame tracker signal. @@ -270,7 +323,11 @@ def getStimulusTiming(self, if len(frame_monitor_channels.shape) == 1: frame_monitor_channels = frame_monitor_channels[np.newaxis, :] - minimum_epoch_separation = 0.9 * (run_parameters['pre_time'] + run_parameters['tail_time']) * sample_rate + minimum_epoch_separation = ( + 0.9 + * (run_parameters["pre_time"] + run_parameters["tail_time"]) + * sample_rate + ) num_channels = frame_monitor_channels.shape[0] channel_timing = [] @@ -278,7 +335,9 @@ def getStimulusTiming(self, frame_monitor = frame_monitor_channels[ch, :] # Low-pass filter frame_monitor trace - b, a = signal.butter(4, 10*self.command_frame_rate, btype='low', fs=sample_rate) + b, a = signal.butter( + 4, 10 * self.command_frame_rate, btype="low", fs=sample_rate + ) frame_monitor = signal.filtfilt(b, a, frame_monitor) # shift & normalize so frame monitor trace lives on [0 1] @@ -288,15 +347,34 @@ def getStimulusTiming(self, # find frame flip times V_orig = frame_monitor[0:-2] V_shift = frame_monitor[1:-1] - ups = np.where(np.logical_and(V_orig < self.threshold, V_shift >= self.threshold))[0] + 1 - downs = np.where(np.logical_and(V_orig >= self.threshold, V_shift < self.threshold))[0] + 1 + ups = ( + np.where( + np.logical_and(V_orig < self.threshold, V_shift >= self.threshold) + )[0] + + 1 + ) + downs = ( + np.where( + np.logical_and(V_orig >= self.threshold, V_shift < self.threshold) + )[0] + + 1 + ) frame_times = np.sort(np.append(ups, downs)) # Use frame flip times to find stimulus start times - stimulus_start_frames = np.append(0, np.where(np.diff(frame_times) > minimum_epoch_separation)[0] + 1) - stimulus_end_frames = np.append(np.where(np.diff(frame_times) > minimum_epoch_separation)[0], len(frame_times)-1) - stimulus_start_times = frame_times[stimulus_start_frames] / sample_rate # datapoints -> sec - stimulus_end_times = frame_times[stimulus_end_frames] / sample_rate # datapoints -> sec + stimulus_start_frames = np.append( + 0, np.where(np.diff(frame_times) > minimum_epoch_separation)[0] + 1 + ) + stimulus_end_frames = np.append( + np.where(np.diff(frame_times) > minimum_epoch_separation)[0], + len(frame_times) - 1, + ) + stimulus_start_times = ( + frame_times[stimulus_start_frames] / sample_rate + ) # datapoints -> sec + stimulus_end_times = ( + frame_times[stimulus_end_frames] / sample_rate + ) # datapoints -> sec stim_durations = stimulus_end_times - stimulus_start_times # sec @@ -304,13 +382,26 @@ def getStimulusTiming(self, frame_durations = [] dropped_frame_times = [] for s_ind, ss in enumerate(stimulus_start_frames): - frame_len = np.diff(frame_times[stimulus_start_frames[s_ind]:stimulus_end_frames[s_ind]+1]) - dropped_frame_inds = np.where(np.abs(frame_len - ideal_frame_len)>self.frame_slop)[0] + 1 # +1 b/c diff + frame_len = np.diff( + frame_times[ + stimulus_start_frames[s_ind] : stimulus_end_frames[s_ind] + 1 + ] + ) + dropped_frame_inds = ( + np.where(np.abs(frame_len - ideal_frame_len) > self.frame_slop)[0] + + 1 + ) # +1 b/c diff if len(dropped_frame_inds) > 0: - dropped_frame_times.append(frame_times[ss]+dropped_frame_inds * ideal_frame_len) # time when dropped frames should have flipped + dropped_frame_times.append( + frame_times[ss] + dropped_frame_inds * ideal_frame_len + ) # time when dropped frames should have flipped # print('Warning! Ch. {} Dropped {} frames in epoch {}'.format(ch, len(dropped_frame_inds), s_ind)) - good_frame_inds = np.where(np.abs(frame_len - ideal_frame_len) <= self.frame_slop)[0] - frame_durations.append(frame_len[good_frame_inds]) # only include non-dropped frames in frame rate calc + good_frame_inds = np.where( + np.abs(frame_len - ideal_frame_len) <= self.frame_slop + )[0] + frame_durations.append( + frame_len[good_frame_inds] + ) # only include non-dropped frames in frame rate calc if len(dropped_frame_times) > 0: dropped_frame_times = np.hstack(dropped_frame_times) # datapoints @@ -327,24 +418,50 @@ def getStimulusTiming(self, ax = frame_monitor_figure.add_subplot(gs1[1, :]) ax.plot(time_vector, frame_monitor) # ax.plot(time_vector[frame_times], self.threshold * np.ones(frame_times.shape), 'ko') - ax.plot(stimulus_start_times, self.threshold * np.ones(stimulus_start_times.shape), 'go') - ax.plot(stimulus_end_times, self.threshold * np.ones(stimulus_end_times.shape), 'ro') - ax.plot(dropped_frame_times / sample_rate, 1 * np.ones(dropped_frame_times.shape), 'rx') - ax.set_title('Ch. {}: Frame rate = {:.2f} Hz'.format(ch, frame_rate), fontsize=12) + ax.plot( + stimulus_start_times, + self.threshold * np.ones(stimulus_start_times.shape), + "go", + ) + ax.plot( + stimulus_end_times, + self.threshold * np.ones(stimulus_end_times.shape), + "ro", + ) + ax.plot( + dropped_frame_times / sample_rate, + 1 * np.ones(dropped_frame_times.shape), + "rx", + ) + ax.set_title( + "Ch. {}: Frame rate = {:.2f} Hz".format(ch, frame_rate), fontsize=12 + ) ax = frame_monitor_figure.add_subplot(gs1[0, 0]) ax.hist(frame_durations) - ax.axvline(ideal_frame_len, color='k') - ax.set_xlabel('Frame duration (datapoints)') + ax.axvline(ideal_frame_len, color="k") + ax.set_xlabel("Frame duration (datapoints)") ax = frame_monitor_figure.add_subplot(gs1[0, 1]) - ax.plot(stim_durations, 'b.') - ax.axhline(y=run_parameters['stim_time'], xmin=0, xmax=run_parameters['num_epochs'], color='k', linestyle='-', marker='None', alpha=0.50) - ymin = np.min([0.9 * run_parameters['stim_time'], np.min(stim_durations)]) - ymax = np.max([1.1 * run_parameters['stim_time'], np.max(stim_durations)]) + ax.plot(stim_durations, "b.") + ax.axhline( + y=run_parameters["stim_time"], + xmin=0, + xmax=run_parameters["num_epochs"], + color="k", + linestyle="-", + marker="None", + alpha=0.50, + ) + ymin = np.min( + [0.9 * run_parameters["stim_time"], np.min(stim_durations)] + ) + ymax = np.max( + [1.1 * run_parameters["stim_time"], np.max(stim_durations)] + ) ax.set_ylim([ymin, ymax]) - ax.set_xlabel('Epoch') - ax.set_ylabel('Stim duration (sec)') + ax.set_xlabel("Epoch") + ax.set_ylabel("Stim duration (sec)") frame_monitor_figure.tight_layout() plt.show() @@ -353,23 +470,52 @@ def getStimulusTiming(self, pass else: # Print timing summary - print('===================TIMING: Channel {}======================'.format(ch)) - print('{} Stims presented (of {} parameterized)'.format(len(stim_durations), len(epoch_parameters))) + print( + "===================TIMING: Channel {}======================".format( + ch + ) + ) + print( + "{} Stims presented (of {} parameterized)".format( + len(stim_durations), len(epoch_parameters) + ) + ) inter_stim_starts = np.diff(stimulus_start_times) - print('Stim start to start: [min={:.3f}, median={:.3f}, max={:.3f}] / parameterized = {:.3f} sec'.format(inter_stim_starts.min(), - np.median(inter_stim_starts), - inter_stim_starts.max(), - run_parameters['stim_time'] + run_parameters['pre_time'] + run_parameters['tail_time'])) - print('Stim duration: [min={:.3f}, median={:.3f}, max={:.3f}] / parameterized = {:.3f} sec'.format(stim_durations.min(), np.median(stim_durations), stim_durations.max(), run_parameters['stim_time'])) + print( + "Stim start to start: [min={:.3f}, median={:.3f}, max={:.3f}] / parameterized = {:.3f} sec".format( + inter_stim_starts.min(), + np.median(inter_stim_starts), + inter_stim_starts.max(), + run_parameters["stim_time"] + + run_parameters["pre_time"] + + run_parameters["tail_time"], + ) + ) + print( + "Stim duration: [min={:.3f}, median={:.3f}, max={:.3f}] / parameterized = {:.3f} sec".format( + stim_durations.min(), + np.median(stim_durations), + stim_durations.max(), + run_parameters["stim_time"], + ) + ) total_frames = len(frame_times) dropped_frames = len(dropped_frame_times) - print('Dropped {} / {} frames ({:.2f}%)'.format(dropped_frames, total_frames, 100*dropped_frames/total_frames)) - print('==========================================================') - - new_dict = {'stimulus_end_times': stimulus_end_times, - 'stimulus_start_times': stimulus_start_times, - 'dropped_frame_times': dropped_frame_times, - 'frame_rate': frame_rate} + print( + "Dropped {} / {} frames ({:.2f}%)".format( + dropped_frames, + total_frames, + 100 * dropped_frames / total_frames, + ) + ) + print("==========================================================") + + new_dict = { + "stimulus_end_times": stimulus_end_times, + "stimulus_start_times": stimulus_start_times, + "dropped_frame_times": dropped_frame_times, + "frame_rate": frame_rate, + } channel_timing.append(new_dict) return channel_timing[self.timing_channel_ind] @@ -377,103 +523,139 @@ def getStimulusTiming(self, def getBehaviorTiming(self): """ Get behavior timing (e.g. behavior camera exposures) - + Returns: response_timing: dict - + """ - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) - behavior_group = epoch_run_group['behavior'] + behavior_group = epoch_run_group["behavior"] behavior_timing = {} - cam_group_names = [x for x in behavior_group.keys() if isinstance(behavior_group[x], h5py.Group) and x!='log_lines'] + cam_group_names = [ + x + for x in behavior_group.keys() + if isinstance(behavior_group[x], h5py.Group) and x != "log_lines" + ] for cam_name in cam_group_names: cam_group = behavior_group[cam_name] cam_timing = {} - cam_timing['frame_time'] = cam_group.get('exposure_onset')[:] # sec - cam_timing['exposure_time'] = cam_group.attrs['exposure_time'] # sec - cam_timing['frame_rate'] = cam_group.attrs['frame_rate'] # Hz + cam_timing["frame_time"] = cam_group.get("exposure_onset")[:] # sec + cam_timing["exposure_time"] = cam_group.attrs["exposure_time"] # sec + cam_timing["frame_rate"] = cam_group.attrs["frame_rate"] # Hz behavior_timing[cam_name] = cam_timing - + # Fictrac data exists, but no fictrac camera group with precise timing - fictrac_cam_group_candidates = [x for x in cam_group_names if 'fictrac' in x.lower()] - if len(fictrac_cam_group_candidates) == 0 and 'fictrac_data' in behavior_group.keys(): - fictrac_header = behavior_group['fictrac_data'].attrs['fictrac_data_header'] - frame_time = behavior_group['fictrac_data'][:, np.where(fictrac_header=='timestamp')[0][0]] / 1000 - frame_rate = 1/np.mean(np.diff(frame_time)) - - behavior_timing['fictrac'] = {} - behavior_timing['fictrac']['frame_time'] = frame_time - behavior_timing['fictrac']['frame_rate'] = frame_rate + fictrac_cam_group_candidates = [ + x for x in cam_group_names if "fictrac" in x.lower() + ] + if ( + len(fictrac_cam_group_candidates) == 0 + and "fictrac_data" in behavior_group.keys() + ): + fictrac_header = behavior_group["fictrac_data"].attrs[ + "fictrac_data_header" + ] + frame_time = ( + behavior_group["fictrac_data"][ + :, np.where(fictrac_header == "timestamp")[0][0] + ] + / 1000 + ) + frame_rate = 1 / np.mean(np.diff(frame_time)) + + behavior_timing["fictrac"] = {} + behavior_timing["fictrac"]["frame_time"] = frame_time + behavior_timing["fictrac"]["frame_rate"] = frame_rate return behavior_timing - + def getBehaviorData(self, stimulus_timing): """ Get Fictrac data """ - + epoch_params = self.getEpochParameters() n_epochs = len(epoch_params) behavior_timing = self.getBehaviorTiming() - - fictrac_timing_candidates = [x for x in behavior_timing.keys() if 'fictrac' in x.lower()] + + fictrac_timing_candidates = [ + x for x in behavior_timing.keys() if "fictrac" in x.lower() + ] if len(fictrac_timing_candidates) == 0: - print('No Fictrac data found.') + print("No Fictrac data found.") return None elif len(fictrac_timing_candidates) > 1: - print('Multiple Fictrac data found. Using first one.') - fictrac_timestamps = behavior_timing[fictrac_timing_candidates[0]]['frame_time'] - - with h5py.File(self.file_path, 'r') as experiment_file: + print("Multiple Fictrac data found. Using first one.") + fictrac_timestamps = behavior_timing[fictrac_timing_candidates[0]]["frame_time"] + + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) epoch_run_group = experiment_file.visititems(find_partial) - behavior_group = epoch_run_group['behavior'] - - if 'fictrac_data' in behavior_group.keys(): - header = behavior_group['fictrac_data'].attrs['fictrac_data_header'] - fictrac_data = pd.DataFrame(behavior_group['fictrac_data'][:], columns=header) - fictrac_data = fictrac_data.astype({'frame_count': int, 'sequence_number': int}) - fictrac_data = fictrac_data.set_index('frame_count') + behavior_group = epoch_run_group["behavior"] + + if "fictrac_data" in behavior_group.keys(): + header = behavior_group["fictrac_data"].attrs["fictrac_data_header"] + fictrac_data = pd.DataFrame( + behavior_group["fictrac_data"][:], columns=header + ) + fictrac_data = fictrac_data.astype( + {"frame_count": int, "sequence_number": int} + ) + fictrac_data = fictrac_data.set_index("frame_count") else: print("No Fictrac data found in the HDF data file.") return - - if 'log_lines' in behavior_group.keys(): - log_group = behavior_group['log_lines'] - set_pos_0_groups = [(log_group[line_name]['set_pos_0'], log_group[line_name].attrs['ts']) for line_name in log_group.keys() if 'set_pos_0' in log_group[line_name].keys()] - + + if "log_lines" in behavior_group.keys(): + log_group = behavior_group["log_lines"] + set_pos_0_groups = [ + ( + log_group[line_name]["set_pos_0"], + log_group[line_name].attrs["ts"], + ) + for line_name in log_group.keys() + if "set_pos_0" in log_group[line_name].keys() + ] + assert len(set_pos_0_groups) == n_epochs - - set_pos_0_frame_nums = [set_pos_0_group.attrs['frame_num'] for set_pos_0_group, ts in set_pos_0_groups] # 1:1 correspondence to epochs - + + set_pos_0_frame_nums = [ + set_pos_0_group.attrs["frame_num"] + for set_pos_0_group, ts in set_pos_0_groups + ] # 1:1 correspondence to epochs + # Trim Fictrac timestamps and data to the min length of the two n_fictrac_valid_frames = min(len(fictrac_timestamps), len(fictrac_data)) fictrac_timestamps = fictrac_timestamps[:n_fictrac_valid_frames] fictrac_data = fictrac_data[:n_fictrac_valid_frames] - + ##### Pull out Fictrac data for each epoch ##### - - stim_start_times = stimulus_timing['stimulus_start_times'] # from photodiodes - stim_end_times = stimulus_timing['stimulus_end_times'] # from photodiodes + + stim_start_times = stimulus_timing["stimulus_start_times"] # from photodiodes + stim_end_times = stimulus_timing["stimulus_end_times"] # from photodiodes assert len(stim_start_times) == len(stim_end_times) - + if len(stim_start_times) < n_epochs: - print('CAUTION! stimulus_timing has length less than number of epochs. Trimming # of epochs.') + print( + "CAUTION! stimulus_timing has length less than number of epochs. Trimming # of epochs." + ) n_epochs = len(stim_start_times) run_params = self.getRunParameters() - iti = run_params['pre_time'] + run_params['tail_time'] + iti = run_params["pre_time"] + run_params["tail_time"] fictrac_data_for_epoch = [] ts_pointer = 0 prev_ts_pointer = 0 # Epoch -1: - current_stim_end_time = stim_start_times[0] - iti # "stim end time" for epoch -1 + current_stim_end_time = ( + stim_start_times[0] - iti + ) # "stim end time" for epoch -1 # Find Fictrac index and timestamp for when stimulus starts while fictrac_timestamps[ts_pointer] <= current_stim_end_time: ts_pointer += 1 @@ -487,9 +669,9 @@ def getBehaviorData(self, stimulus_timing): # next_stim_start_fictrac_index = ts_pointer # next_stim_start_fictrac_timestamp = fictrac_timestamps[ts_pointer] - print('Pulling out Fictrac data for each epoch...') + print("Pulling out Fictrac data for each epoch...") for e in range(n_epochs): - print(f'Epoch {e+1}/{n_epochs}') + print(f"Epoch {e+1}/{n_epochs}") prev_stim_end_fictrac_index = current_stim_end_fictrac_index # prev_stim_end_fictrac_timestamp = current_stim_end_fictrac_timestamp # current_stim_start_fictrac_index = next_stim_start_fictrac_index @@ -497,46 +679,71 @@ def getBehaviorData(self, stimulus_timing): current_stim_start_time = stim_start_times[e] current_stim_end_time = stim_end_times[e] - + # print(current_stim_end_time - current_stim_start_time) # print(fictrac_timestamps[ts_pointer]) # print(current_stim_end_time) - + # Find Fictrac index and timestamp for when stimulus ends - while ts_pointer < len(fictrac_timestamps) and fictrac_timestamps[ts_pointer] <= current_stim_end_time: + while ( + ts_pointer < len(fictrac_timestamps) + and fictrac_timestamps[ts_pointer] <= current_stim_end_time + ): # if fictrac_timestamps[ts_pointer] > 1200: # print(fictrac_timestamps[ts_pointer]) ts_pointer += 1 current_stim_end_fictrac_index = ts_pointer - 1 # current_stim_end_fictrac_timestamp = fictrac_timestamps[ts_pointer] - if e < n_epochs-1: - next_stim_start_time = stim_start_times[e+1] + if e < n_epochs - 1: + next_stim_start_time = stim_start_times[e + 1] else: next_stim_start_time = current_stim_end_time + iti # Find Fictrac index and timestamp for when stimulus starts - while ts_pointer < len(fictrac_timestamps) and fictrac_timestamps[ts_pointer] <= next_stim_start_time: + while ( + ts_pointer < len(fictrac_timestamps) + and fictrac_timestamps[ts_pointer] <= next_stim_start_time + ): ts_pointer += 1 next_stim_start_fictrac_index = ts_pointer # next_stim_start_fictrac_timestamp = fictrac_timestamps[ts_pointer] epoch_set_pos_0_frame_num = set_pos_0_frame_nums[e] - epoch_fictrac_data = fictrac_data[prev_stim_end_fictrac_index : next_stim_start_fictrac_index].copy() + epoch_fictrac_data = fictrac_data[ + prev_stim_end_fictrac_index:next_stim_start_fictrac_index + ].copy() # print(f'{len(epoch_fictrac_data)} Fictrac frames') - epoch_fictrac_data['integrated_xpos'] -= epoch_fictrac_data['integrated_xpos'][epoch_set_pos_0_frame_num] - epoch_fictrac_data['integrated_ypos'] -= epoch_fictrac_data['integrated_ypos'][epoch_set_pos_0_frame_num] - epoch_fictrac_data['integrated_heading'] = np.unwrap(epoch_fictrac_data['integrated_heading']) - epoch_fictrac_data['integrated_heading'] -= epoch_fictrac_data['integrated_heading'][epoch_set_pos_0_frame_num] - epoch_fictrac_data['integrated_heading'] = np.mod((epoch_fictrac_data['integrated_heading'] + np.pi), np.pi*2) - np.pi - - epoch_fictrac_timestamps = np.copy(fictrac_timestamps[prev_stim_end_fictrac_index : next_stim_start_fictrac_index]) + epoch_fictrac_data["integrated_xpos"] -= epoch_fictrac_data[ + "integrated_xpos" + ][epoch_set_pos_0_frame_num] + epoch_fictrac_data["integrated_ypos"] -= epoch_fictrac_data[ + "integrated_ypos" + ][epoch_set_pos_0_frame_num] + epoch_fictrac_data["integrated_heading"] = np.unwrap( + epoch_fictrac_data["integrated_heading"] + ) + epoch_fictrac_data["integrated_heading"] -= epoch_fictrac_data[ + "integrated_heading" + ][epoch_set_pos_0_frame_num] + epoch_fictrac_data["integrated_heading"] = ( + np.mod((epoch_fictrac_data["integrated_heading"] + np.pi), np.pi * 2) + - np.pi + ) + + epoch_fictrac_timestamps = np.copy( + fictrac_timestamps[ + prev_stim_end_fictrac_index:next_stim_start_fictrac_index + ] + ) epoch_fictrac_timestamps -= current_stim_start_time - fictrac_data_for_epoch.append({'timestamps':epoch_fictrac_timestamps, 'data':epoch_fictrac_data}) + fictrac_data_for_epoch.append( + {"timestamps": epoch_fictrac_timestamps, "data": epoch_fictrac_data} + ) return fictrac_data_for_epoch - def getRoiSetNames(self, roi_prefix='rois'): + def getRoiSetNames(self, roi_prefix="rois"): """ -roi_prefix: 'rois' or 'aligned' @@ -544,7 +751,7 @@ def getRoiSetNames(self, roi_prefix='rois'): 'aligned' used for mask-generated, no path objects for drawing """ roi_set_names = [] - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) roi_parent_group = experiment_file.visititems(find_partial)[roi_prefix] for roi_set_name in roi_parent_group.keys(): @@ -552,7 +759,13 @@ def getRoiSetNames(self, roi_prefix='rois'): return roi_set_names - def getRoiResponses(self, roi_set_name, background_subtraction=False, roi_prefix='rois', return_erm=True): + def getRoiResponses( + self, + roi_set_name, + background_subtraction=False, + roi_prefix="rois", + return_erm=True, + ): """ Get responses for indicated roi Params: @@ -572,28 +785,36 @@ def getRoiResponses(self, roi_set_name, background_subtraction=False, roi_prefix time_vector: 1d array, time values for epoch_response traces (sec) """ roi_data = {} - with h5py.File(self.file_path, 'r') as experiment_file: + with h5py.File(self.file_path, "r") as experiment_file: find_partial = functools.partial(h5io.find_series, sn=self.series_number) roi_parent_group = experiment_file.visititems(find_partial)[roi_prefix] - assert roi_set_name in roi_parent_group, 'roi_set_name "{}" not found in roi group'.format(roi_set_name) + assert ( + roi_set_name in roi_parent_group + ), 'roi_set_name "{}" not found in roi group'.format(roi_set_name) roi_set_group = roi_parent_group[roi_set_name] - roi_data['roi_response'] = list(roi_set_group.get("roi_response")[:]) - roi_data['roi_mask'] = roi_set_group.get("roi_mask")[:] - roi_data['roi_image'] = roi_set_group.get("roi_image")[:] + roi_data["roi_response"] = list(roi_set_group.get("roi_response")[:]) + roi_data["roi_mask"] = roi_set_group.get("roi_mask")[:] + roi_data["roi_image"] = roi_set_group.get("roi_image")[:] if background_subtraction: - with h5py.File(self.file_path, 'r') as experiment_file: - find_partial = functools.partial(h5io.find_series, sn=self.series_number) + with h5py.File(self.file_path, "r") as experiment_file: + find_partial = functools.partial( + h5io.find_series, sn=self.series_number + ) roi_parent_group = experiment_file.visititems(find_partial)[roi_prefix] - bg_roi_group = roi_parent_group['bg'] + bg_roi_group = roi_parent_group["bg"] bg_roi_response = list(bg_roi_group.get("roi_response")[:]) - roi_data['roi_response'] = [x - np.squeeze(bg_roi_response) for x in roi_data['roi_response']] + roi_data["roi_response"] = [ + x - np.squeeze(bg_roi_response) for x in roi_data["roi_response"] + ] if return_erm: - time_vector, response_matrix = self.getEpochResponseMatrix(np.vstack(roi_data.get('roi_response'))) - roi_data['epoch_response'] = response_matrix - roi_data['time_vector'] = time_vector + time_vector, response_matrix = self.getEpochResponseMatrix( + np.vstack(roi_data.get("roi_response")) + ) + roi_data["epoch_response"] = response_matrix + roi_data["time_vector"] = time_vector return roi_data @@ -616,34 +837,58 @@ def getEpochResponseMatrix(self, region_response, dff=True, df=False): response_timing = self.getResponseTiming() stimulus_timing = self.getStimulusTiming() - epoch_start_times = stimulus_timing['stimulus_start_times'] - run_parameters['pre_time'] - epoch_end_times = stimulus_timing['stimulus_end_times'] + run_parameters['tail_time'] - epoch_time = (run_parameters['pre_time'] - + run_parameters['stim_time'] - + run_parameters['tail_time']) # sec + epoch_start_times = ( + stimulus_timing["stimulus_start_times"] - run_parameters["pre_time"] + ) + epoch_end_times = ( + stimulus_timing["stimulus_end_times"] + run_parameters["tail_time"] + ) + epoch_time = ( + run_parameters["pre_time"] + + run_parameters["stim_time"] + + run_parameters["tail_time"] + ) # sec # find how many acquisition frames correspond to pre, stim, tail time - epoch_frames = int(epoch_time / response_timing['sample_period']) # in acquisition frames - pre_frames = int(run_parameters['pre_time'] / response_timing['sample_period']) # in acquisition frames - time_vector = np.arange(0, epoch_frames) * response_timing['sample_period'] # sec + epoch_frames = int( + epoch_time / response_timing["sample_period"] + ) # in acquisition frames + pre_frames = int( + run_parameters["pre_time"] / response_timing["sample_period"] + ) # in acquisition frames + time_vector = ( + np.arange(0, epoch_frames) * response_timing["sample_period"] + ) # sec no_trials = len(epoch_start_times) - response_matrix = np.empty(shape=(no_regions, no_trials, epoch_frames), dtype=float) + response_matrix = np.empty( + shape=(no_regions, no_trials, epoch_frames), dtype=float + ) response_matrix[:] = np.nan - cut_inds = np.empty(0, dtype=int) # trial/epoch indices to cut from response_matrix + cut_inds = np.empty( + 0, dtype=int + ) # trial/epoch indices to cut from response_matrix for idx, val in enumerate(epoch_start_times): - stack_inds = np.where(np.logical_and(response_timing['time_vector'] < epoch_end_times[idx], - response_timing['time_vector'] >= epoch_start_times[idx]))[0] - if len(stack_inds) == 0: # no imaging acquisitions happened during this epoch presentation + stack_inds = np.where( + np.logical_and( + response_timing["time_vector"] < epoch_end_times[idx], + response_timing["time_vector"] >= epoch_start_times[idx], + ) + )[0] + if ( + len(stack_inds) == 0 + ): # no imaging acquisitions happened during this epoch presentation cut_inds = np.append(cut_inds, idx) continue if np.any(stack_inds > region_response.shape[1]): cut_inds = np.append(cut_inds, idx) continue if idx == no_trials: - if len(stack_inds) < epoch_frames: # missed images for the end of the stimulus + if ( + len(stack_inds) < epoch_frames + ): # missed images for the end of the stimulus cut_inds = np.append(cut_inds, idx) - print('Missed acquisition frames at the end of the stimulus!') + print("Missed acquisition frames at the end of the stimulus!") continue # pull out Roi values for these scans. shape of newRespChunk is (nROIs,nScans) new_resp_chunk = region_response[:, stack_inds] @@ -651,7 +896,9 @@ def getEpochResponseMatrix(self, region_response, dff=True, df=False): if dff: # calculate baseline using pre frames - baseline = np.mean(new_resp_chunk[:, 0:pre_frames], axis=1, keepdims=True) + baseline = np.mean( + new_resp_chunk[:, 0:pre_frames], axis=1, keepdims=True + ) # to dF/F with warnings.catch_warnings(): # Warning to catch divide by zero or nan. Will return nan or inf warnings.simplefilter("ignore", category=RuntimeWarning) @@ -660,15 +907,21 @@ def getEpochResponseMatrix(self, region_response, dff=True, df=False): if df: # calculate baseline using pre frames, don't divide by f - baseline = np.mean(new_resp_chunk[:, 0:pre_frames], axis=1, keepdims=True) + baseline = np.mean( + new_resp_chunk[:, 0:pre_frames], axis=1, keepdims=True + ) # to dF/F with warnings.catch_warnings(): # Warning to catch divide by zero or nan. Will return nan or inf warnings.simplefilter("ignore", category=RuntimeWarning) - new_resp_chunk = (new_resp_chunk - baseline) + new_resp_chunk = new_resp_chunk - baseline if epoch_frames > new_resp_chunk.shape[-1]: - print('Warnging: Size mismatch idx = {}'.format(idx)) # the end of a response clipped off - response_matrix[:, idx, :new_resp_chunk.shape[-1]] = new_resp_chunk[:, 0:] + print( + "Warning: Size mismatch idx = {}".format(idx) + ) # the end of a response clipped off + response_matrix[:, idx, : new_resp_chunk.shape[-1]] = new_resp_chunk[ + :, 0: + ] else: response_matrix[:, idx, :] = new_resp_chunk[:, 0:epoch_frames] # except: @@ -676,14 +929,20 @@ def getEpochResponseMatrix(self, region_response, dff=True, df=False): # print(response_matrix.shape) # print(new_resp_chunk.shape) # print(epoch_frames) - # cut_inds = np.append(cut_inds, idx) + # cut_inds = np.append(cut_inds, idx) if len(cut_inds) > 0: - print('Warning: cut {} epochs from epoch response matrix'.format(len(cut_inds))) - response_matrix = np.delete(response_matrix, cut_inds, axis=1) # shape = (region, trial, time) + print( + "Warning: cut {} epochs from epoch response matrix".format( + len(cut_inds) + ) + ) + response_matrix = np.delete( + response_matrix, cut_inds, axis=1 + ) # shape = (region, trial, time) return time_vector, response_matrix -# # # # # # # # # # # # # CONVENIENCE METHODS # # # # # # # # # # # # # # # # # # # # # # # # # # + # # # # # # # # # # # # # CONVENIENCE METHODS # # # # # # # # # # # # # # # # # # # # # # # # # # def getEpochGroupingsByParameters(self, parameter_key=None): """ @@ -703,19 +962,26 @@ def getEpochGroupingsByParameters(self, parameter_key=None): if parameter_key is not None: if type(parameter_key) is str: # single param key parameter_key = [parameter_key] # list-ify - parameter_values = [list(gu.convert_iterables_to_tuples(pd.values())) for pd in self.getEpochParameterDicts(target_keys=parameter_key)] + parameter_values = [ + list(gu.convert_iterables_to_tuples(pd.values())) + for pd in self.getEpochParameterDicts(target_keys=parameter_key) + ] # Get unique parameter combinations - unique_parameter_values = [list(s) for s in set(tuple(pv) for pv in parameter_values)] + unique_parameter_values = [ + list(s) for s in set(tuple(pv) for pv in parameter_values) + ] # Sort unique_parameter_values into some sensible ordering unique_parameter_values.sort() # Get epoch indices for each unique parameter combination - epoch_inds = [np.where([pv == up for pv in parameter_values])[0] for up in unique_parameter_values] - - return unique_parameter_values, epoch_inds + epoch_inds = [ + np.where([pv == up for pv in parameter_values])[0] + for up in unique_parameter_values + ] + return unique_parameter_values, epoch_inds def getTrialAverages(self, epoch_response_matrix, parameter_key=None): """ @@ -738,35 +1004,58 @@ def getTrialAverages(self, epoch_response_matrix, parameter_key=None): """ # Get unique parameter combinations and epoch indices for each unique parameter combination - unique_parameter_values, epoch_inds = self.getEpochGroupingsByParameters(parameter_key) + unique_parameter_values, epoch_inds = self.getEpochGroupingsByParameters( + parameter_key + ) n_stimuli = len(unique_parameter_values) n_regions, n_trials, t_dim = epoch_response_matrix.shape - mean_response = np.ndarray(shape=(n_regions, n_stimuli, t_dim)) # n_regions x stim condition x time - sem_response = np.ndarray(shape=(n_regions, n_stimuli, t_dim)) # n_regions x stim condition x time + mean_response = np.ndarray( + shape=(n_regions, n_stimuli, t_dim) + ) # n_regions x stim condition x time + sem_response = np.ndarray( + shape=(n_regions, n_stimuli, t_dim) + ) # n_regions x stim condition x time mean_response[:] = np.nan sem_response[:] = np.nan trial_response_by_stimulus = [] - for p_ind, up in enumerate(unique_parameter_values): # For distinct stimulus parameterizations - pull_inds = epoch_inds[p_ind] # trial indices matching this parameterization + for p_ind, up in enumerate( + unique_parameter_values + ): # For distinct stimulus parameterizations + pull_inds = epoch_inds[ + p_ind + ] # trial indices matching this parameterization # Check to see if any epochs are not in the epoch_response_matrix. Possibly because of a cut-short recording if np.any(np.array(pull_inds) >= epoch_response_matrix.shape[1]): tmp = np.where(pull_inds >= epoch_response_matrix.shape[1])[0] - print('Epoch(s) {} not included in epoch_response_matrix'.format(pull_inds[tmp])) + print( + "Epoch(s) {} not included in epoch_response_matrix".format( + pull_inds[tmp] + ) + ) pull_inds = pull_inds[pull_inds < epoch_response_matrix.shape[1]] with warnings.catch_warnings(): # Warning if nanmean run on an axis with ALL nans warnings.simplefilter("ignore", category=RuntimeWarning) - mean_response[:, p_ind, :] = np.nanmean(epoch_response_matrix[:, pull_inds, :], axis=1) - sem_response[:, p_ind, :] = np.nanstd(epoch_response_matrix[:, pull_inds, :], axis=1) / np.sqrt(len(pull_inds)) + mean_response[:, p_ind, :] = np.nanmean( + epoch_response_matrix[:, pull_inds, :], axis=1 + ) + sem_response[:, p_ind, :] = np.nanstd( + epoch_response_matrix[:, pull_inds, :], axis=1 + ) / np.sqrt(len(pull_inds)) trial_response_by_stimulus.append(epoch_response_matrix[:, pull_inds, :]) - return unique_parameter_values, mean_response, sem_response, trial_response_by_stimulus + return ( + unique_parameter_values, + mean_response, + sem_response, + trial_response_by_stimulus, + ) - def getResponseAmplitude(self, epoch_response_matrix, metric='max'): + def getResponseAmplitude(self, epoch_response_matrix, metric="max"): """ Get response amplitude from the start of the stimulus to the end of the trial Params: @@ -779,14 +1068,22 @@ def getResponseAmplitude(self, epoch_response_matrix, metric='max'): run_parameters = self.getRunParameters() response_timing = self.getResponseTiming() - pre_frames = int(run_parameters['pre_time'] / response_timing.get('sample_period')) - - if metric == 'max': - response_amplitude = np.nanmax(epoch_response_matrix[..., pre_frames:], axis=-1) - elif metric == 'mean': - response_amplitude = np.nanmean(epoch_response_matrix[..., pre_frames:], axis=-1) - elif metric == 'min': - response_amplitude = np.nanmin(epoch_response_matrix[..., pre_frames:], axis=-1) + pre_frames = int( + run_parameters["pre_time"] / response_timing.get("sample_period") + ) + + if metric == "max": + response_amplitude = np.nanmax( + epoch_response_matrix[..., pre_frames:], axis=-1 + ) + elif metric == "mean": + response_amplitude = np.nanmean( + epoch_response_matrix[..., pre_frames:], axis=-1 + ) + elif metric == "min": + response_amplitude = np.nanmin( + epoch_response_matrix[..., pre_frames:], axis=-1 + ) return response_amplitude @@ -801,18 +1098,22 @@ def generateRoiMap(self, roi_name, scale_bar_length=0, z=None): """ roi_data = self.getRoiResponses(roi_name) if z is None: - im = roi_data.get('roi_image') - msk = roi_data.get('roi_mask') + im = roi_data.get("roi_image") + msk = roi_data.get("roi_mask") else: - im = roi_data.get('roi_image')[..., z] - msk = roi_data.get('roi_mask')[..., z] + im = roi_data.get("roi_image")[..., z] + msk = roi_data.get("roi_mask")[..., z] new_image = plot_tools.overlayImage(im, msk, 0.5, self.colors) - fh, ax = plt.subplots(1, 1, figsize=(4,4)) + fh, ax = plt.subplots(1, 1, figsize=(4, 4)) ax.imshow(new_image) - ax.set_aspect('equal') + ax.set_aspect("equal") ax.set_axis_off() if scale_bar_length > 0: - microns_per_pixel = float(self.getAcquisitionMetadata()['micronsPerPixel_XAxis']) - plot_tools.addImageScaleBar(ax, new_image, scale_bar_length, microns_per_pixel, 'lr') + microns_per_pixel = float( + self.getAcquisitionMetadata()["micronsPerPixel_XAxis"] + ) + plot_tools.addImageScaleBar( + ax, new_image, scale_bar_length, microns_per_pixel, "lr" + )