Skip to content

Commit

Permalink
Clean up implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
astrofrog committed Jul 11, 2023
1 parent ed7647f commit 0f83a03
Show file tree
Hide file tree
Showing 7 changed files with 34 additions and 26 deletions.
16 changes: 13 additions & 3 deletions doc/customizing_guide/customization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,21 @@ Custom Stretches
----------------

You can add additional stretches for use in e.g. the image viewer. Stretches should be
provided as callables (e.g. functions or classes with a ``__call__`` method) which take
values in the range [0:1] and return values in the range [0:1]:
provided as a function or an initialized class with a ``__call__`` method which takes
values in the range [0:1] and return values in the range [0:1], and takes an optional
``out=`` keyword argument. If this is set, the array values should be modified in-place
and the output array should be returned::

from glue.config import stretches
stretches.add('cbrt', lambda x: x ** (1/3))

def cbrt(x, out=None):
if out is not None:
out[:] = out ** 1/3
return out
else:
return x ** 1/3

stretches.add('cbrt', cbrt)

.. _custom-actions:

Expand Down
18 changes: 13 additions & 5 deletions glue/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -634,18 +634,26 @@ class StretchRegistry(DictRegistry):
Stores custom stretches
"""

def add(self, label, stretch_cls):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._display = {}

def add(self, label, stretch_cls, display=None):
if label in self.members:
raise ValueError("Stretch class '{0}' already registered".format(label))
else:
self.members[label] = stretch_cls
self._display[label] = display or label

def __call__(self, label):
def adder(stretch_cls):
self.add(label, stretch_cls)
return stretch_cls
return adder

def display_func(self, label):
return self._display[label]


class QtClientRegistry(Registry):
"""
Expand Down Expand Up @@ -1024,10 +1032,10 @@ def __iter__(self):
from astropy.visualization import (LinearStretch, SqrtStretch, AsinhStretch,
LogStretch)
stretches = StretchRegistry()
stretches.add('linear', LinearStretch)
stretches.add('sqrt', SqrtStretch)
stretches.add('arcsinh', AsinhStretch)
stretches.add('log', LogStretch)
stretches.add('linear', LinearStretch(), display='Linear')
stretches.add('sqrt', SqrtStretch(), display='Square Root')
stretches.add('arcsinh', AsinhStretch(), display='Arcsinh')
stretches.add('log', LogStretch(), display='Logarithmic')

# Backward-compatibility
single_subset_action = layer_action
Expand Down
2 changes: 1 addition & 1 deletion glue/viewers/image/composite_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def __call__(self, bounds=None):

interval = ManualInterval(*layer['clim'])
contrast_bias = ContrastBiasStretch(layer['contrast'], layer['bias'])
stretch = stretches[layer['stretch']]()
stretch = stretches.members[layer['stretch']]

if callable(layer['array']):
array = layer['array'](bounds=bounds)
Expand Down
11 changes: 3 additions & 8 deletions glue/viewers/image/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections import defaultdict

from glue.core import BaseData
from glue.config import colormaps
from glue.config import colormaps, stretches
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
MatplotlibLayerState,
DeferredDrawCallbackProperty as DDCProperty,
Expand Down Expand Up @@ -525,13 +525,8 @@ def __init__(self, layer=None, viewer_state=None, **kwargs):
ImageLayerState.percentile.set_choices(self, [100, 99.5, 99, 95, 90, 'Custom'])
ImageLayerState.percentile.set_display_func(self, percentile_display.get)

stretch_display = {'linear': 'Linear',
'sqrt': 'Square Root',
'arcsinh': 'Arcsinh',
'log': 'Logarithmic'}

ImageLayerState.stretch.set_choices(self, ['linear', 'sqrt', 'arcsinh', 'log'])
ImageLayerState.stretch.set_display_func(self, stretch_display.get)
ImageLayerState.stretch.set_choices(self, list(stretches.members))
ImageLayerState.stretch.set_display_func(self, stretches.display_func)

self.add_callback('global_sync', self._update_syncing)
self.add_callback('layer', self._update_attribute)
Expand Down
2 changes: 1 addition & 1 deletion glue/viewers/scatter/layer_artist.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def _update_visual_attributes(self, changed, force=False):
set_mpl_artist_cmap(self.density_artist, c, self.state)

if force or 'stretch' in changed:
self.density_artist.set_norm(ImageNormalize(stretch=stretches[self.state.stretch]()))
self.density_artist.set_norm(ImageNormalize(stretch=stretches.members[self.state.stretch]))

if force or 'dpi' in changed:
self.density_artist.set_dpi(self._viewer_state.dpi)
Expand Down
2 changes: 1 addition & 1 deletion glue/viewers/scatter/python_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def python_export_scatter_layer(layer, *args):
options['color'] = layer.state.color
options['vmin'] = code('density_limits.min')
options['vmax'] = code('density_limits.max')
options['norm'] = code("ImageNormalize(stretch=stretches['{0}']())".format(layer.state.stretch))
options['norm'] = code("ImageNormalize(stretch=stretches.members['{0}'])".format(layer.state.stretch))
else:
options['c'] = code("layer_data['{0}']".format(layer.state.cmap_att.label))
options['cmap'] = code("plt.cm.{0}".format(layer.state.cmap.name))
Expand Down
9 changes: 2 additions & 7 deletions glue/viewers/scatter/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from glue.core import BaseData, Subset

from glue.config import colormaps
from glue.config import colormaps, stretches
from glue.viewers.matplotlib.state import (MatplotlibDataViewerState,
MatplotlibLayerState,
DeferredDrawCallbackProperty as DDCProperty,
Expand Down Expand Up @@ -327,13 +327,8 @@ def __init__(self, viewer_state=None, layer=None, **kwargs):
ScatterLayerState.vector_origin.set_choices(self, ['tail', 'middle', 'tip'])
ScatterLayerState.vector_origin.set_display_func(self, vector_origin_display.get)

stretch_display = {'linear': 'Linear',
'sqrt': 'Square Root',
'arcsinh': 'Arcsinh',
'log': 'Logarithmic'}

ScatterLayerState.stretch.set_choices(self, ['linear', 'sqrt', 'arcsinh', 'log'])
ScatterLayerState.stretch.set_display_func(self, stretch_display.get)
ScatterLayerState.stretch.set_display_func(self, stretches.display_func)

if self.viewer_state is not None:
self.viewer_state.add_callback('x_att', self._on_xy_change, priority=10000)
Expand Down

0 comments on commit 0f83a03

Please sign in to comment.