Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return peak stats in focus_from_transverse_band #179

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions tests/test_focus_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,21 +38,26 @@ def test_focus_estimator(tmp_path):

plot_path = tmp_path.joinpath("test.pdf")
data3D = np.random.random((11, 256, 256))
slice = focus.focus_from_transverse_band(
slice, stats = focus.focus_from_transverse_band(
data3D, NA_det, lambda_ill, ps, plot_path=str(plot_path)
)
assert slice >= 0
assert slice <= data3D.shape[0]
assert plot_path.exists()
assert isinstance(stats, dict)
assert stats["peak_index"] == slice
assert stats["peak_FWHM"] > 0

# Check single slice
slice = focus.focus_from_transverse_band(
slice, stats = focus.focus_from_transverse_band(
np.random.random((1, 10, 10)),
NA_det,
lambda_ill,
ps,
)
assert slice == 0
assert stats["peak_index"] is None
assert stats["peak_FWHM"] is None


def test_focus_estimator_snr(tmp_path):
Expand Down Expand Up @@ -80,7 +85,7 @@ def test_focus_estimator_snr(tmp_path):
ps,
plot_path=plot_path,
threshold_FWHM=5,
)
)[0]
assert plot_path.exists()
if slice is not None:
assert np.abs(slice - 10) <= 2
10 changes: 7 additions & 3 deletions waveorder/focus.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def focus_from_transverse_band(
return the index of the in-focus slice
else:
return None
peak_stats : dict
Dictionary with statistics of the detected peaks, currently 'peak_index' and 'peak_FWHM'.

Example
------
Expand All @@ -62,6 +64,7 @@ def focus_from_transverse_band(
>>> in_focus_data = data[slice,:,:]
"""
minmaxfunc = _mode_to_minmaxfunc(mode)
peak_stats = {'peak_index': None, 'peak_FWHM': None}

_check_focus_inputs(
zyx_array, NA_det, lambda_ill, pixel_size, midband_fractions
Expand All @@ -72,7 +75,7 @@ def focus_from_transverse_band(
warnings.warn(
"The dataset only contained a single slice. Returning trivial slice index = 0."
)
return 0
return 0, peak_stats

# Calculate coordinates
_, Y, X = zyx_array.shape
Expand All @@ -95,9 +98,10 @@ def focus_from_transverse_band(

peak_results = peak_widths(midband_sum, [peak_index])
peak_FWHM = peak_results[0][0]
peak_stats.update({'peak_index': peak_index, 'peak_FWHM': peak_FWHM})

if peak_FWHM >= threshold_FWHM:
in_focus_index = peak_index
in_focus_index = int(peak_index)
else:
in_focus_index = None

Expand All @@ -112,7 +116,7 @@ def focus_from_transverse_band(
threshold_FWHM,
)

return in_focus_index
return in_focus_index, peak_stats


def _mode_to_minmaxfunc(mode):
Expand Down
Loading