Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restructure workflow to facilitate workflow partitioning for efficient streaming execution #163

Merged
merged 12 commits into from
Sep 25, 2024
2 changes: 1 addition & 1 deletion docs/user-guide/common/beam-center-finder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@
"kwargs = dict( # noqa: C408\n",
" workflow=workflow,\n",
" detector=detector['data'],\n",
" norm=workflow.compute(NormWavelengthTerm[SampleRun]),\n",
" norm=workflow.compute(CleanDirectBeam),\n",
")"
]
},
Expand Down
11 changes: 7 additions & 4 deletions docs/user-guide/isis/sans2d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@
" WavelengthMonitor[BackgroundRun, Incident],\n",
" WavelengthMonitor[BackgroundRun, Transmission],\n",
")\n",
"parts = (CleanSummedQ[SampleRun, Numerator], CleanSummedQ[SampleRun, Denominator])\n",
"parts = (\n",
" WavelengthScaledQ[SampleRun, Numerator],\n",
" WavelengthScaledQ[SampleRun, Denominator],\n",
")\n",
"iofqs = (IofQ[SampleRun], IofQ[BackgroundRun], BackgroundSubtractedIofQ)\n",
"keys = (*monitors, MaskedData[SampleRun], *parts, *iofqs)\n",
"\n",
Expand All @@ -335,12 +338,12 @@
"\n",
"wavelength = workflow.compute(WavelengthBins)\n",
"display(\n",
" results[CleanSummedQ[SampleRun, Numerator]]\n",
" results[WavelengthScaledQ[SampleRun, Numerator]]\n",
" .hist(wavelength=wavelength)\n",
" .transpose()\n",
" .plot(norm='log')\n",
")\n",
"display(results[CleanSummedQ[SampleRun, Denominator]].plot(norm='log'))\n",
"display(results[WavelengthScaledQ[SampleRun, Denominator]].plot(norm='log'))\n",
"parts = {str(key): results[key].sum('wavelength') for key in parts}\n",
"display(sc.plot(parts, norm='log'))\n",
"\n",
Expand Down Expand Up @@ -429,7 +432,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
11 changes: 7 additions & 4 deletions docs/user-guide/isis/zoom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@
" WavelengthMonitor[SampleRun, Incident],\n",
" WavelengthMonitor[SampleRun, Transmission],\n",
")\n",
"parts = (CleanSummedQ[SampleRun, Numerator], CleanSummedQ[SampleRun, Denominator])\n",
"parts = (\n",
" WavelengthScaledQ[SampleRun, Numerator],\n",
" WavelengthScaledQ[SampleRun, Denominator],\n",
")\n",
"iofqs = (IofQ[SampleRun],)\n",
"keys = (*monitors, MaskedData[SampleRun], *parts, *iofqs)\n",
"\n",
Expand All @@ -249,12 +252,12 @@
"\n",
"wavelength = workflow.compute(WavelengthBins)\n",
"display(\n",
" results[CleanSummedQ[SampleRun, Numerator]]\n",
" results[WavelengthScaledQ[SampleRun, Numerator]]\n",
" .hist(wavelength=wavelength)\n",
" .transpose()\n",
" .plot(norm='log')\n",
")\n",
"display(results[CleanSummedQ[SampleRun, Denominator]].plot(norm='log'))\n",
"display(results[WavelengthScaledQ[SampleRun, Denominator]].plot(norm='log'))\n",
"parts = {str(key): results[key] for key in parts}\n",
"parts = {\n",
" key: val.sum('wavelength') if val.bins is None else val.hist()\n",
Expand Down Expand Up @@ -309,7 +312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
10 changes: 4 additions & 6 deletions src/ess/sans/beam_center_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
from .logging import get_logger
from .types import (
BeamCenter,
CleanDirectBeam,
DetectorBankSizes,
DimsToKeep,
IofQ,
MaskedData,
NeXusDetector,
NormWavelengthTerm,
QBins,
ReturnEvents,
SampleRun,
Expand Down Expand Up @@ -174,9 +174,7 @@ def _iofq_in_quadrants(
workflow[NeXusDetector[SampleRun]] = sc.DataGroup(data=detector[sel])
# MaskedData would be computed automatically, but we did it above already
workflow[MaskedData[SampleRun]] = calibrated[sel]
workflow[NormWavelengthTerm[SampleRun]] = (
norm if norm.dims == ('wavelength',) else norm[sel]
)
workflow[CleanDirectBeam] = norm if norm.dims == ('wavelength',) else norm[sel]
out[quad] = workflow.compute(IofQ[SampleRun])
return out

Expand Down Expand Up @@ -363,7 +361,7 @@ def beam_center_from_iofq(
keys = (
NeXusDetector[SampleRun],
MaskedData[SampleRun],
NormWavelengthTerm[SampleRun],
CleanDirectBeam,
ElasticCoordTransformGraph,
)
workflow = workflow.copy()
Expand All @@ -372,7 +370,7 @@ def beam_center_from_iofq(
results = workflow.compute(keys)
detector = results[NeXusDetector[SampleRun]]['data']
data = results[MaskedData[SampleRun]]
norm = results[NormWavelengthTerm[SampleRun]]
norm = results[CleanDirectBeam]
graph = results[ElasticCoordTransformGraph]

# Avoid reloading the detector
Expand Down
58 changes: 50 additions & 8 deletions src/ess/sans/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import NewType

import scipp as sc
from ess.reduce.uncertainty import broadcast_uncertainties
from scippneutron.conversion.beamline import (
beam_aligned_unit_vectors,
scattering_angles_with_gravity,
Expand All @@ -13,18 +14,24 @@
from .types import (
CleanQ,
CleanQxy,
CleanSummedQ,
CleanSummedQxy,
CleanWavelength,
CleanWavelengthMasked,
CorrectForGravity,
Denominator,
IofQPart,
MaskedData,
MonitorTerm,
MonitorType,
Numerator,
RunType,
ScatteringRunType,
TofMonitor,
UncertaintyBroadcastMode,
WavelengthMask,
WavelengthMonitor,
WavelengthScaledQ,
WavelengthScaledQxy,
)


Expand Down Expand Up @@ -163,12 +170,44 @@ def detector_to_wavelength(
)


def mask_wavelength(
da: CleanWavelength[ScatteringRunType, IofQPart], mask: WavelengthMask
) -> CleanWavelengthMasked[ScatteringRunType, IofQPart]:
def mask_wavelength_q(
da: CleanSummedQ[ScatteringRunType, Numerator], mask: WavelengthMask
) -> WavelengthScaledQ[ScatteringRunType, Numerator]:
if mask is not None:
da = mask_range(da, mask=mask)
return CleanWavelengthMasked[ScatteringRunType, IofQPart](da)
return CleanSummedQ[ScatteringRunType, Numerator](da)
SimonHeybrock marked this conversation as resolved.
Show resolved Hide resolved


def mask_wavelength_qxy(
da: CleanSummedQxy[ScatteringRunType, Numerator], mask: WavelengthMask
) -> WavelengthScaledQxy[ScatteringRunType, Numerator]:
if mask is not None:
da = mask_range(da, mask=mask)
return CleanSummedQxy[ScatteringRunType, Numerator](da)
SimonHeybrock marked this conversation as resolved.
Show resolved Hide resolved


def mask_and_scale_wavelength_q(
da: CleanSummedQ[ScatteringRunType, Denominator],
mask: WavelengthMask,
wavelength_term: MonitorTerm[ScatteringRunType],
uncertainties: UncertaintyBroadcastMode,
) -> WavelengthScaledQ[ScatteringRunType, Denominator]:
da = da * broadcast_uncertainties(wavelength_term, prototype=da, mode=uncertainties)
if mask is not None:
da = mask_range(da, mask=mask)
return WavelengthScaledQ[ScatteringRunType, Denominator](da)


def mask_and_scale_wavelength_qxy(
da: CleanSummedQxy[ScatteringRunType, Denominator],
mask: WavelengthMask,
wavelength_term: MonitorTerm[ScatteringRunType],
uncertainties: UncertaintyBroadcastMode,
) -> WavelengthScaledQxy[ScatteringRunType, Denominator]:
da = da * broadcast_uncertainties(wavelength_term, prototype=da, mode=uncertainties)
if mask is not None:
da = mask_range(da, mask=mask)
return WavelengthScaledQxy[ScatteringRunType, Denominator](da)


def _compute_Q(
Expand All @@ -186,7 +225,7 @@ def _compute_Q(


def compute_Q(
data: CleanWavelengthMasked[ScatteringRunType, IofQPart],
data: CleanWavelength[ScatteringRunType, IofQPart],
graph: ElasticCoordTransformGraph,
) -> CleanQ[ScatteringRunType, IofQPart]:
"""
Expand All @@ -198,7 +237,7 @@ def compute_Q(


def compute_Qxy(
data: CleanWavelengthMasked[ScatteringRunType, IofQPart],
data: CleanWavelength[ScatteringRunType, IofQPart],
graph: ElasticCoordTransformGraph,
) -> CleanQxy[ScatteringRunType, IofQPart]:
"""
Expand All @@ -214,7 +253,10 @@ def compute_Qxy(
sans_monitor,
monitor_to_wavelength,
detector_to_wavelength,
mask_wavelength,
mask_wavelength_q,
mask_wavelength_qxy,
mask_and_scale_wavelength_q,
mask_and_scale_wavelength_qxy,
compute_Q,
compute_Qxy,
)
22 changes: 11 additions & 11 deletions src/ess/sans/direct_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from .types import (
BackgroundRun,
BackgroundSubtractedIofQ,
CleanSummedQ,
Denominator,
DirectBeam,
Numerator,
ProcessedWavelengthBands,
SampleRun,
WavelengthBands,
WavelengthBins,
WavelengthScaledQ,
)


Expand Down Expand Up @@ -103,16 +103,16 @@ def direct_beam(*, workflow: Pipeline, I0: sc.Variable, niter: int = 5) -> list[

wavelength_bins = workflow.compute(WavelengthBins)
parts = (
CleanSummedQ[SampleRun, Numerator],
CleanSummedQ[SampleRun, Denominator],
CleanSummedQ[BackgroundRun, Numerator],
CleanSummedQ[BackgroundRun, Denominator],
WavelengthScaledQ[SampleRun, Numerator],
WavelengthScaledQ[SampleRun, Denominator],
WavelengthScaledQ[BackgroundRun, Numerator],
WavelengthScaledQ[BackgroundRun, Denominator],
)
parts = workflow.compute(parts)
# Convert events to histograms to make normalization (in every iteration) cheap
for key in [
CleanSummedQ[SampleRun, Numerator],
CleanSummedQ[BackgroundRun, Numerator],
WavelengthScaledQ[SampleRun, Numerator],
WavelengthScaledQ[BackgroundRun, Numerator],
]:
parts[key] = parts[key].hist(wavelength=wavelength_bins)

Expand All @@ -121,8 +121,8 @@ def direct_beam(*, workflow: Pipeline, I0: sc.Variable, niter: int = 5) -> list[
parts = {key: sc.values(result) for key, result in parts.items()}
for key, part in parts.items():
workflow[key] = part
sample0 = parts[CleanSummedQ[SampleRun, Denominator]]
background0 = parts[CleanSummedQ[BackgroundRun, Denominator]]
sample0 = parts[WavelengthScaledQ[SampleRun, Denominator]]
background0 = parts[WavelengthScaledQ[BackgroundRun, Denominator]]

results = []

Expand Down Expand Up @@ -158,8 +158,8 @@ def direct_beam(*, workflow: Pipeline, I0: sc.Variable, niter: int = 5) -> list[
db.coords['wavelength'] = sc.midpoints(
db.coords['wavelength'], dim='wavelength'
)
workflow[CleanSummedQ[SampleRun, Denominator]] = sample0 * db
workflow[CleanSummedQ[BackgroundRun, Denominator]] = background0 * db
workflow[WavelengthScaledQ[SampleRun, Denominator]] = sample0 * db
workflow[WavelengthScaledQ[BackgroundRun, Denominator]] = background0 * db

results.append(
{
Expand Down
7 changes: 6 additions & 1 deletion src/ess/sans/i_of_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ def resample_direct_beam(
The direct beam function resampled to the requested resolution.
"""
if direct_beam is None:
return CleanDirectBeam(None)
return CleanDirectBeam(
sc.DataArray(
sc.ones(dims=wavelength_bins.dims, shape=[len(wavelength_bins) - 1]),
coords={'wavelength': wavelength_bins},
)
)
if sc.identical(direct_beam.coords['wavelength'], wavelength_bins):
return direct_beam
if direct_beam.variances is not None:
Expand Down
Loading