From f7940a337f93a95b694947deea51ff2fe23d54f1 Mon Sep 17 00:00:00 2001 From: Ajit Johnson Nirmal Date: Thu, 21 Nov 2024 15:08:53 -0500 Subject: [PATCH] updates to gater for smoother performance --- pyproject.toml | 2 +- scimap/plotting/gate_finder.py | 2 + scimap/plotting/napariGater.py | 274 ++++++++++++++------------------- 3 files changed, 115 insertions(+), 163 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e684fb3d..61cafa1c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [tool.poetry] name = "SCIMAP" -version = "2.2.4" +version = "2.2.5" description = "Spatial Single-Cell Analysis Toolkit" license = "MIT" diff --git a/scimap/plotting/gate_finder.py b/scimap/plotting/gate_finder.py index 26f4cc02..185974b7 100644 --- a/scimap/plotting/gate_finder.py +++ b/scimap/plotting/gate_finder.py @@ -10,6 +10,8 @@ Subsequently, the identified gating parameters can be applied to the dataset using `sm.pp.rescale`, enabling precise control over data segmentation and analysis based on marker expression levels. + `gate_finder()` is deprecated and will be removed in a future version. Please use `sm.pl.napariGater()` instead. + ## Function """ diff --git a/scimap/plotting/napariGater.py b/scimap/plotting/napariGater.py index 0aee39a3..5d567660 100644 --- a/scimap/plotting/napariGater.py +++ b/scimap/plotting/napariGater.py @@ -10,6 +10,8 @@ Subsequently, the identified gating parameters can be applied to the dataset using `sm.pp.rescale`, enabling precise control over data segmentation and analysis based on marker expression levels. + Repacement for `sm.pl.gate_finder()` + ## Function """ @@ -117,6 +119,26 @@ def initialize_gates(adata, imageid): adata.uns['gates'].loc[marker, :] = gate_value pbar.update(1) + # Initialize provenance tracking + if 'napariGaterProvenance' not in adata.uns: + adata.uns['napariGaterProvenance'] = { + 'manually_adjusted': {}, # Track adjusted markers per image + 'timestamp': {}, # Track when adjustments were made + 'original_values': {}, # Track original GMM values + } + + # Initialize for current image if needed + current_image = adata.obs[imageid].iloc[0] + if current_image not in adata.uns['napariGaterProvenance']['manually_adjusted']: + adata.uns['napariGaterProvenance']['manually_adjusted'][current_image] = {} + adata.uns['napariGaterProvenance']['timestamp'][current_image] = {} + adata.uns['napariGaterProvenance']['original_values'][current_image] = {} + + # Store initial GMM values + for marker in adata.var.index: + adata.uns['napariGaterProvenance']['original_values'][current_image][marker] = \ + float(adata.uns['gates'].loc[marker, current_image]) + return adata @@ -216,172 +238,47 @@ def initialize_contrast_settings(adata, img_data, channel_names, imageid='imagei return adata -def check_pyramid_levels(tiff_file): - #"""Check if the TIFF file has pyramid levels""" - try: - series = tiff_file.series[0] - return hasattr(series, 'levels') - except Exception: - return False - - -def add_channel_to_viewer(viewer, img, channel_idx, channel_name, contrast_limits, colormap): - #"""Add a channel to viewer with proper pyramid handling""" - try: - # Store current view state if any layer exists - if len(viewer.layers) > 0: - current_zoom = viewer.camera.zoom - current_center = viewer.camera.center - else: - current_zoom = None - current_center = None - - # Get the data shape from the series - tiff_file = img._store._source - series = tiff_file.series[0] - - # Check if we have pyramid levels - has_pyramids = check_pyramid_levels(tiff_file) - - if has_pyramids: - # Load pyramid levels - pyramid_data = [] - for level in series.levels: - level_data = level.pages[channel_idx].asarray() - pyramid_data.append(level_data) - - # Add to viewer as multiscale - viewer.add_image( - pyramid_data, - name=channel_name, - visible=False, - colormap=colormap, - blending='additive', - contrast_limits=contrast_limits, - multiscale=True, - rendering='mip', - interpolation2d='nearest' - ) - else: - # Fallback to single resolution - channel_data = series.pages[channel_idx].asarray() - viewer.add_image( - channel_data, - name=channel_name, - visible=False, - colormap=colormap, - blending='additive', - contrast_limits=contrast_limits, - multiscale=False, - rendering='mip', - interpolation2d='nearest' - ) - - # After adding the new layer, restore view state if it existed - if current_zoom is not None: - viewer.camera.zoom = current_zoom - viewer.camera.center = current_center - - return True - - except Exception as e: - print(f"Warning: Channel {channel_name} could not be loaded") - return False - - def load_image_efficiently(image_path): - #"""Efficiently load image with proper lazy loading""" + """Efficiently load image using zarr conversion""" if isinstance(image_path, str): if image_path.endswith(('.tiff', '.tif')): - tiff_file = tiff.TiffFile(image_path, is_ome=False) - series = tiff_file.series[0] + image = tiff.TiffFile(image_path, is_ome=False) + z = zarr.open(image.aszarr(), mode='r') - if hasattr(series, 'levels'): - # For pyramidal images, create a list of dask arrays - data = [] - for level in series.levels: - shape = (len(level.pages),) + level.pages[0].shape - chunks = (1,) + level.pages[0].shape # Chunk by channel - - @delayed - def get_page(i, level=level): - return level.pages[i].asarray() - - # Create lazy dask array for this level - level_data = da.stack([ - da.from_delayed( - get_page(i), - shape=level.pages[0].shape, - dtype=level.pages[0].dtype - ) - for i in range(len(level.pages)) - ]) - data.append(level_data) - - return data, tiff_file + # Check for pyramids + n_levels = len(image.series[0].levels) + + if n_levels > 1: + data = [da.from_zarr(z[i]) for i in range(n_levels)] + multiscale = True else: - # For non-pyramidal images, create single dask array - shape = (len(series.pages),) + series.pages[0].shape - chunks = (1,) + series.pages[0].shape + data = da.from_zarr(z) + multiscale = False - @delayed - def get_page(i): - return series.pages[i].asarray() - - data = da.stack([ - da.from_delayed( - get_page(i), - shape=series.pages[0].shape, - dtype=series.pages[0].dtype - ) - for i in range(len(series.pages)) - ]) - - return data, tiff_file + return data, image, multiscale - return None, None + return None, None, False def add_channels_to_viewer(viewer, img_data, channel_names, contrast_settings, colormaps): - #"""Add all channels to viewer efficiently""" - if isinstance(img_data, list): # Pyramidal - # Add all channels at once with multiscale - for channel_idx, channel_name in enumerate(channel_names): - contrast_limits = ( - contrast_settings[channel_name]['low'], - contrast_settings[channel_name]['high'] - ) - - # Extract this channel's data across all pyramid levels - channel_data = [level[channel_idx] for level in img_data] - - viewer.add_image( - channel_data, - name=channel_name, - visible=False, - colormap=colormaps[channel_idx % len(colormaps)], - blending='additive', - contrast_limits=contrast_limits, - multiscale=True, - rendering='mip', - interpolation2d='nearest' - ) - else: # Non-pyramidal - # Add all channels at once - viewer.add_image( - img_data, - channel_axis=0, - name=channel_names, - visible=False, - colormap=colormaps, - blending='additive', - contrast_limits=[ - (contrast_settings[name]['low'], contrast_settings[name]['high']) - for name in channel_names - ], - multiscale=False, - rendering='mip', - interpolation2d='nearest' - ) + """Add channels maintaining pyramid structure if available""" + n_channels = len(channel_names) + extended_colormaps = [colormaps[i % len(colormaps)] for i in range(n_channels)] + + viewer.add_image( + img_data, + channel_axis=0, + name=channel_names, + visible=False, + colormap=extended_colormaps, + blending='additive', + contrast_limits=[ + (contrast_settings[name]['low'], contrast_settings[name]['high']) + for name in channel_names + ], + multiscale=isinstance(img_data, list), # True if pyramidal + rendering='mip', + interpolation2d='nearest' + ) def napariGater( @@ -498,7 +395,7 @@ def napariGater( # Load image efficiently print("Loading image data...") - img_data, tiff_file = load_image_efficiently(image_path) + img_data, tiff_file, multiscale = load_image_efficiently(image_path) if img_data is None: raise ValueError("Failed to load image data") @@ -532,12 +429,17 @@ def napariGater( # Create the viewer and add all channels efficiently viewer = napari.Viewer() + default_colormaps = [ + 'magenta', 'cyan', 'yellow', 'red', 'green', 'blue', + 'magenta', 'cyan', 'yellow', 'red', 'green', 'blue' + ] # Basic colors that will be cycled + add_channels_to_viewer( viewer, img_data, channel_names, adata.uns['image_contrast_settings'][current_image], - colormaps=['magenta', 'cyan', 'yellow', 'red', 'green', 'blue'] + colormaps=default_colormaps ) # Verify loaded channels @@ -576,20 +478,37 @@ def napariGater( @magicgui( auto_call=True, - marker={'choices': list(adata.var.index), 'value': initial_marker}, + layout='vertical', + marker={ + 'choices': list(adata.var.index), + 'value': initial_marker, + 'label': 'Select Marker:' + }, gate={ 'widget_type': 'FloatSpinBox', 'min': min_val, 'max': max_val, 'value': initial_gate, 'step': 0.01, + 'label': 'Gate Threshold:' + }, + marker_status={ + 'widget_type': 'Label', + 'value': '⚪ Not adjusted' # Initial value + }, + confirm_gate={ + 'widget_type': 'PushButton', + 'text': 'Confirm Gate' + }, + finish={ + 'widget_type': 'PushButton', + 'text': 'Finish Gating' }, - confirm_gate={'widget_type': 'PushButton', 'text': 'Confirm Gate'}, - finish={'widget_type': 'PushButton', 'text': 'Finish Gating'}, ) def gate_controls( marker: str, gate: float = initial_gate, + marker_status: str = '⚪ Not adjusted', confirm_gate=False, finish=False, ): @@ -656,12 +575,43 @@ def _on_marker_change(marker: str): viewer.camera.zoom = current_state['zoom'] viewer.camera.center = current_state['center'] + # Update status with more visible formatting and shorter timestamp + current_image = adata.obs[imageid].iloc[0] if subset is None else subset + is_adjusted = marker in adata.uns['napariGaterProvenance']['manually_adjusted'].get(current_image, {}) + if is_adjusted: + status_text = "✓ ADJUSTED" + # Get and format timestamp + timestamp = adata.uns['napariGaterProvenance']['timestamp'][current_image][marker] + # Convert stored timestamp to shorter format + from datetime import datetime + try: + dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S") + short_timestamp = dt.strftime("%y-%m-%d %H:%M") + status_text += f" ({short_timestamp})" + except: + status_text += f" ({timestamp})" + else: + status_text = "⚪ NOT ADJUSTED" + + gate_controls.marker_status.value = status_text + @gate_controls.confirm_gate.clicked.connect def _on_confirm(): marker = gate_controls.marker.value gate = gate_controls.gate.value current_image = adata.obs[imageid].iloc[0] if subset is None else subset + + # Update gate value adata.uns['gates'].loc[marker, current_image] = float(gate) + + # Update provenance with shorter timestamp + from datetime import datetime + timestamp = datetime.now().strftime("%y-%m-%d %H:%M") # Shorter format + adata.uns['napariGaterProvenance']['manually_adjusted'][current_image][marker] = float(gate) + adata.uns['napariGaterProvenance']['timestamp'][current_image][marker] = timestamp + + # Update status with confirmation message + gate_controls.marker_status.value = f"✓ ADJUSTED ({timestamp})" # Add handler for finish button @gate_controls.finish.clicked.connect