Skip to content

Commit

Permalink
Merge pull request #35 from ericpre/fix_using_api_nogui
Browse files Browse the repository at this point in the history
Update for pint unit registry in hyperspy
  • Loading branch information
jlaehne authored Apr 15, 2024
2 parents bd7ade2 + e4c3b61 commit aeaa107
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 16 deletions.
42 changes: 27 additions & 15 deletions holospy/signals/hologram_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,15 @@
# You should have received a copy of the GNU General Public License
# along with HyperSpy. If not, see <https://www.gnu.org/licenses/#GPL>.

import importlib
import logging
from collections import OrderedDict
import scipy.constants as constants
import numpy as np
from dask.array import Array as daArray
from pint import UndefinedUnitError

from hyperspy.api_nogui import _ureg
from hyperspy._signals.signal2d import Signal2D
from hyperspy.signal import BaseSignal
from hyperspy._signals.signal1d import Signal1D
import hyperspy.api as hs
from hyperspy._signals.lazy import LazySignal
from holospy.reconstruct import (
reconstruct,
Expand All @@ -43,6 +41,20 @@
LAZYSIGNAL_DOC,
)

if importlib.util.find_spec("hyperspy.api_nogui") is None:
# Considering the usage of the UnitRegistry in holospy,
# sharing the same UnitRegistry in holospy is not necessary
# because there is no operations between quantities defined in
# hyperspy and holospy but this is good practise and
# can be used as a reference
import pint

_ureg = pint.get_application_registry()

else:
# Before hyperspy migrate to use pint default global UnitRegistry
from hyperspy.api_nogui import _ureg

_logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -70,13 +82,13 @@ def _parse_sb_position(s, reference, sb_position, sb, high_cf, num_workers=None)

else:
if (
isinstance(sb_position, BaseSignal)
isinstance(sb_position, hs.signals.BaseSignal)
and not sb_position._signal_dimension == 1
):
raise ValueError("sb_position dimension has to be 1.")

if not isinstance(sb_position, Signal1D):
sb_position = Signal1D(sb_position)
if not isinstance(sb_position, hs.signals.Signal1D):
sb_position = hs.signals.Signal1D(sb_position)
if isinstance(sb_position.data, daArray):
sb_position = sb_position.as_lazy()

Expand Down Expand Up @@ -107,12 +119,12 @@ def _parse_sb_size(s, reference, sb_position, sb_size, num_workers=None):
sb_position, num_workers=num_workers
)
else:
if not isinstance(sb_size, BaseSignal):
if not isinstance(sb_size, hs.signals.BaseSignal):
if isinstance(sb_size, (np.ndarray, daArray)) and sb_size.size > 1:
# transpose if np.array of multiple instances
sb_size = BaseSignal(sb_size).T
sb_size = hs.signals.BaseSignal(sb_size).T
else:
sb_size = BaseSignal(sb_size)
sb_size = hs.signals.BaseSignal(sb_size)
if isinstance(sb_size.data, daArray):
sb_size = sb_size.as_lazy()
if sb_size.axes_manager.navigation_size != s.axes_manager.navigation_size:
Expand Down Expand Up @@ -150,7 +162,7 @@ def _estimate_fringe_contrast_statistical(signal):
return signal.std(axes) / signal.mean(axes)


class HologramImage(Signal2D):
class HologramImage(hs.signals.Signal2D):
"""Signal class for holograms acquired via off-axis electron holography."""

_signal_type = "hologram"
Expand Down Expand Up @@ -412,7 +424,7 @@ def reconstruct_phase(

# Parsing reference:
if not isinstance(reference, HologramImage):
if isinstance(reference, Signal2D):
if isinstance(reference, hs.signals.Signal2D):
if (
not reference.axes_manager.navigation_shape
== self.axes_manager.navigation_shape
Expand Down Expand Up @@ -480,14 +492,14 @@ def reconstruct_phase(
if sb_smoothness is None:
sb_smoothness = sb_size * 0.05
else:
if not isinstance(sb_smoothness, BaseSignal):
if not isinstance(sb_smoothness, hs.signals.BaseSignal):
if (
isinstance(sb_smoothness, (np.ndarray, daArray))
and sb_smoothness.size > 1
):
sb_smoothness = BaseSignal(sb_smoothness).T
sb_smoothness = hs.signals.BaseSignal(sb_smoothness).T
else:
sb_smoothness = BaseSignal(sb_smoothness)
sb_smoothness = hs.signals.BaseSignal(sb_smoothness)
if isinstance(sb_smoothness.data, daArray):
sb_smoothness = sb_smoothness.as_lazy()

Expand Down
12 changes: 11 additions & 1 deletion holospy/tests/signals/test_hologram_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,10 @@
from scipy.interpolate import RectBivariateSpline

import hyperspy.api as hs
from holospy.signals.hologram_image import HologramImage
from holospy.signals.hologram_image import HologramImage, _ureg
from hyperspy.decorators import lazifyTestClass


# Set parameters outside the tests
img_size = 256
IMG_SIZE3X = 128
Expand Down Expand Up @@ -330,3 +331,12 @@ def test_reconstruct_phase_multi(lazy):
# e. Beam energy is not assigned, while 'mrad' units selected
with pytest.raises(AttributeError):
holo_image3.reconstruct_phase(sb_size=40, sb_unit="mrad")


def test_pint_unit_registry():
# Check that holospy and hyperspy share the same UnitRegistry
s = hs.signals.Signal1D(np.arange(10))
# this will not work if the UnitRegistry are not the same
s.axes_manager[0].scale_as_quantity = "2.5 µm"
s.axes_manager[0].scale_as_quantity += 2e-6 * _ureg.meter
assert s.axes_manager[0].scale == 4.5
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ classifiers = [
dependencies = [
"hyperspy>=2.0rc0",
"numpy>=1.20.0",
"pint>=0.10",
"scipy>=1.5.0",
]
dynamic = ["version"]
Expand Down
1 change: 1 addition & 0 deletions upcoming_changes/35.maintenance.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use ``pint.get_application_registry`` instead of HyperSpy private API to get the handle of the ``pint.UnitRegistry``.

0 comments on commit aeaa107

Please sign in to comment.