Skip to content

Commit

Permalink
updates to gater for smoother performance
Browse files Browse the repository at this point in the history
  • Loading branch information
ajitjohnson committed Nov 21, 2024
1 parent 22fb6df commit f7940a3
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 163 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tool.poetry]

name = "SCIMAP"
version = "2.2.4"
version = "2.2.5"
description = "Spatial Single-Cell Analysis Toolkit"

license = "MIT"
Expand Down
2 changes: 2 additions & 0 deletions scimap/plotting/gate_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Subsequently, the identified gating parameters can be applied to the dataset using `sm.pp.rescale`,
enabling precise control over data segmentation and analysis based on marker expression levels.
`gate_finder()` is deprecated and will be removed in a future version. Please use `sm.pl.napariGater()` instead.
## Function
"""

Expand Down
274 changes: 112 additions & 162 deletions scimap/plotting/napariGater.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
Subsequently, the identified gating parameters can be applied to the dataset using `sm.pp.rescale`,
enabling precise control over data segmentation and analysis based on marker expression levels.
Repacement for `sm.pl.gate_finder()`
## Function
"""

Expand Down Expand Up @@ -117,6 +119,26 @@ def initialize_gates(adata, imageid):
adata.uns['gates'].loc[marker, :] = gate_value
pbar.update(1)

# Initialize provenance tracking
if 'napariGaterProvenance' not in adata.uns:
adata.uns['napariGaterProvenance'] = {
'manually_adjusted': {}, # Track adjusted markers per image
'timestamp': {}, # Track when adjustments were made
'original_values': {}, # Track original GMM values
}

# Initialize for current image if needed
current_image = adata.obs[imageid].iloc[0]
if current_image not in adata.uns['napariGaterProvenance']['manually_adjusted']:
adata.uns['napariGaterProvenance']['manually_adjusted'][current_image] = {}
adata.uns['napariGaterProvenance']['timestamp'][current_image] = {}
adata.uns['napariGaterProvenance']['original_values'][current_image] = {}

# Store initial GMM values
for marker in adata.var.index:
adata.uns['napariGaterProvenance']['original_values'][current_image][marker] = \
float(adata.uns['gates'].loc[marker, current_image])

return adata


Expand Down Expand Up @@ -216,172 +238,47 @@ def initialize_contrast_settings(adata, img_data, channel_names, imageid='imagei
return adata


def check_pyramid_levels(tiff_file):
#"""Check if the TIFF file has pyramid levels"""
try:
series = tiff_file.series[0]
return hasattr(series, 'levels')
except Exception:
return False


def add_channel_to_viewer(viewer, img, channel_idx, channel_name, contrast_limits, colormap):
#"""Add a channel to viewer with proper pyramid handling"""
try:
# Store current view state if any layer exists
if len(viewer.layers) > 0:
current_zoom = viewer.camera.zoom
current_center = viewer.camera.center
else:
current_zoom = None
current_center = None

# Get the data shape from the series
tiff_file = img._store._source
series = tiff_file.series[0]

# Check if we have pyramid levels
has_pyramids = check_pyramid_levels(tiff_file)

if has_pyramids:
# Load pyramid levels
pyramid_data = []
for level in series.levels:
level_data = level.pages[channel_idx].asarray()
pyramid_data.append(level_data)

# Add to viewer as multiscale
viewer.add_image(
pyramid_data,
name=channel_name,
visible=False,
colormap=colormap,
blending='additive',
contrast_limits=contrast_limits,
multiscale=True,
rendering='mip',
interpolation2d='nearest'
)
else:
# Fallback to single resolution
channel_data = series.pages[channel_idx].asarray()
viewer.add_image(
channel_data,
name=channel_name,
visible=False,
colormap=colormap,
blending='additive',
contrast_limits=contrast_limits,
multiscale=False,
rendering='mip',
interpolation2d='nearest'
)

# After adding the new layer, restore view state if it existed
if current_zoom is not None:
viewer.camera.zoom = current_zoom
viewer.camera.center = current_center

return True

except Exception as e:
print(f"Warning: Channel {channel_name} could not be loaded")
return False


def load_image_efficiently(image_path):
#"""Efficiently load image with proper lazy loading"""
"""Efficiently load image using zarr conversion"""
if isinstance(image_path, str):
if image_path.endswith(('.tiff', '.tif')):
tiff_file = tiff.TiffFile(image_path, is_ome=False)
series = tiff_file.series[0]
image = tiff.TiffFile(image_path, is_ome=False)
z = zarr.open(image.aszarr(), mode='r')

if hasattr(series, 'levels'):
# For pyramidal images, create a list of dask arrays
data = []
for level in series.levels:
shape = (len(level.pages),) + level.pages[0].shape
chunks = (1,) + level.pages[0].shape # Chunk by channel

@delayed
def get_page(i, level=level):
return level.pages[i].asarray()

# Create lazy dask array for this level
level_data = da.stack([
da.from_delayed(
get_page(i),
shape=level.pages[0].shape,
dtype=level.pages[0].dtype
)
for i in range(len(level.pages))
])
data.append(level_data)

return data, tiff_file
# Check for pyramids
n_levels = len(image.series[0].levels)

if n_levels > 1:
data = [da.from_zarr(z[i]) for i in range(n_levels)]
multiscale = True
else:
# For non-pyramidal images, create single dask array
shape = (len(series.pages),) + series.pages[0].shape
chunks = (1,) + series.pages[0].shape
data = da.from_zarr(z)
multiscale = False

@delayed
def get_page(i):
return series.pages[i].asarray()

data = da.stack([
da.from_delayed(
get_page(i),
shape=series.pages[0].shape,
dtype=series.pages[0].dtype
)
for i in range(len(series.pages))
])

return data, tiff_file
return data, image, multiscale

return None, None
return None, None, False

def add_channels_to_viewer(viewer, img_data, channel_names, contrast_settings, colormaps):
#"""Add all channels to viewer efficiently"""
if isinstance(img_data, list): # Pyramidal
# Add all channels at once with multiscale
for channel_idx, channel_name in enumerate(channel_names):
contrast_limits = (
contrast_settings[channel_name]['low'],
contrast_settings[channel_name]['high']
)

# Extract this channel's data across all pyramid levels
channel_data = [level[channel_idx] for level in img_data]

viewer.add_image(
channel_data,
name=channel_name,
visible=False,
colormap=colormaps[channel_idx % len(colormaps)],
blending='additive',
contrast_limits=contrast_limits,
multiscale=True,
rendering='mip',
interpolation2d='nearest'
)
else: # Non-pyramidal
# Add all channels at once
viewer.add_image(
img_data,
channel_axis=0,
name=channel_names,
visible=False,
colormap=colormaps,
blending='additive',
contrast_limits=[
(contrast_settings[name]['low'], contrast_settings[name]['high'])
for name in channel_names
],
multiscale=False,
rendering='mip',
interpolation2d='nearest'
)
"""Add channels maintaining pyramid structure if available"""
n_channels = len(channel_names)
extended_colormaps = [colormaps[i % len(colormaps)] for i in range(n_channels)]

viewer.add_image(
img_data,
channel_axis=0,
name=channel_names,
visible=False,
colormap=extended_colormaps,
blending='additive',
contrast_limits=[
(contrast_settings[name]['low'], contrast_settings[name]['high'])
for name in channel_names
],
multiscale=isinstance(img_data, list), # True if pyramidal
rendering='mip',
interpolation2d='nearest'
)


def napariGater(
Expand Down Expand Up @@ -498,7 +395,7 @@ def napariGater(

# Load image efficiently
print("Loading image data...")
img_data, tiff_file = load_image_efficiently(image_path)
img_data, tiff_file, multiscale = load_image_efficiently(image_path)
if img_data is None:
raise ValueError("Failed to load image data")

Expand Down Expand Up @@ -532,12 +429,17 @@ def napariGater(
# Create the viewer and add all channels efficiently
viewer = napari.Viewer()

default_colormaps = [
'magenta', 'cyan', 'yellow', 'red', 'green', 'blue',
'magenta', 'cyan', 'yellow', 'red', 'green', 'blue'
] # Basic colors that will be cycled

add_channels_to_viewer(
viewer,
img_data,
channel_names,
adata.uns['image_contrast_settings'][current_image],
colormaps=['magenta', 'cyan', 'yellow', 'red', 'green', 'blue']
colormaps=default_colormaps
)

# Verify loaded channels
Expand Down Expand Up @@ -576,20 +478,37 @@ def napariGater(

@magicgui(
auto_call=True,
marker={'choices': list(adata.var.index), 'value': initial_marker},
layout='vertical',
marker={
'choices': list(adata.var.index),
'value': initial_marker,
'label': 'Select Marker:'
},
gate={
'widget_type': 'FloatSpinBox',
'min': min_val,
'max': max_val,
'value': initial_gate,
'step': 0.01,
'label': 'Gate Threshold:'
},
marker_status={
'widget_type': 'Label',
'value': '⚪ Not adjusted' # Initial value
},
confirm_gate={
'widget_type': 'PushButton',
'text': 'Confirm Gate'
},
finish={
'widget_type': 'PushButton',
'text': 'Finish Gating'
},
confirm_gate={'widget_type': 'PushButton', 'text': 'Confirm Gate'},
finish={'widget_type': 'PushButton', 'text': 'Finish Gating'},
)
def gate_controls(
marker: str,
gate: float = initial_gate,
marker_status: str = '⚪ Not adjusted',
confirm_gate=False,
finish=False,
):
Expand Down Expand Up @@ -656,12 +575,43 @@ def _on_marker_change(marker: str):
viewer.camera.zoom = current_state['zoom']
viewer.camera.center = current_state['center']

# Update status with more visible formatting and shorter timestamp
current_image = adata.obs[imageid].iloc[0] if subset is None else subset
is_adjusted = marker in adata.uns['napariGaterProvenance']['manually_adjusted'].get(current_image, {})
if is_adjusted:
status_text = "✓ ADJUSTED"
# Get and format timestamp
timestamp = adata.uns['napariGaterProvenance']['timestamp'][current_image][marker]
# Convert stored timestamp to shorter format
from datetime import datetime
try:
dt = datetime.strptime(timestamp, "%Y-%m-%d %H:%M:%S")
short_timestamp = dt.strftime("%y-%m-%d %H:%M")
status_text += f" ({short_timestamp})"
except:
status_text += f" ({timestamp})"
else:
status_text = "⚪ NOT ADJUSTED"

gate_controls.marker_status.value = status_text

@gate_controls.confirm_gate.clicked.connect
def _on_confirm():
marker = gate_controls.marker.value
gate = gate_controls.gate.value
current_image = adata.obs[imageid].iloc[0] if subset is None else subset

# Update gate value
adata.uns['gates'].loc[marker, current_image] = float(gate)

# Update provenance with shorter timestamp
from datetime import datetime
timestamp = datetime.now().strftime("%y-%m-%d %H:%M") # Shorter format
adata.uns['napariGaterProvenance']['manually_adjusted'][current_image][marker] = float(gate)
adata.uns['napariGaterProvenance']['timestamp'][current_image][marker] = timestamp

# Update status with confirmation message
gate_controls.marker_status.value = f"✓ ADJUSTED ({timestamp})"

# Add handler for finish button
@gate_controls.finish.clicked.connect
Expand Down

0 comments on commit f7940a3

Please sign in to comment.