Skip to content

Commit

Permalink
Merge pull request #163 from scipp/streaming-workflow
Browse files Browse the repository at this point in the history
Restructure workflow to facilitate workflow partitioning for efficient streaming execution
  • Loading branch information
SimonHeybrock authored Sep 25, 2024
2 parents f813d90 + 8a55abd commit 8826f6b
Show file tree
Hide file tree
Showing 11 changed files with 261 additions and 185 deletions.
2 changes: 1 addition & 1 deletion docs/user-guide/common/beam-center-finder.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@
"kwargs = dict( # noqa: C408\n",
" workflow=workflow,\n",
" detector=detector['data'],\n",
" norm=workflow.compute(NormWavelengthTerm[SampleRun]),\n",
" norm=workflow.compute(CleanDirectBeam),\n",
")"
]
},
Expand Down
11 changes: 7 additions & 4 deletions docs/user-guide/isis/sans2d.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,10 @@
" WavelengthMonitor[BackgroundRun, Incident],\n",
" WavelengthMonitor[BackgroundRun, Transmission],\n",
")\n",
"parts = (CleanSummedQ[SampleRun, Numerator], CleanSummedQ[SampleRun, Denominator])\n",
"parts = (\n",
" WavelengthScaledQ[SampleRun, Numerator],\n",
" WavelengthScaledQ[SampleRun, Denominator],\n",
")\n",
"iofqs = (IofQ[SampleRun], IofQ[BackgroundRun], BackgroundSubtractedIofQ)\n",
"keys = (*monitors, MaskedData[SampleRun], *parts, *iofqs)\n",
"\n",
Expand All @@ -335,12 +338,12 @@
"\n",
"wavelength = workflow.compute(WavelengthBins)\n",
"display(\n",
" results[CleanSummedQ[SampleRun, Numerator]]\n",
" results[WavelengthScaledQ[SampleRun, Numerator]]\n",
" .hist(wavelength=wavelength)\n",
" .transpose()\n",
" .plot(norm='log')\n",
")\n",
"display(results[CleanSummedQ[SampleRun, Denominator]].plot(norm='log'))\n",
"display(results[WavelengthScaledQ[SampleRun, Denominator]].plot(norm='log'))\n",
"parts = {str(key): results[key].sum('wavelength') for key in parts}\n",
"display(sc.plot(parts, norm='log'))\n",
"\n",
Expand Down Expand Up @@ -429,7 +432,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
11 changes: 7 additions & 4 deletions docs/user-guide/isis/zoom.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@
" WavelengthMonitor[SampleRun, Incident],\n",
" WavelengthMonitor[SampleRun, Transmission],\n",
")\n",
"parts = (CleanSummedQ[SampleRun, Numerator], CleanSummedQ[SampleRun, Denominator])\n",
"parts = (\n",
" WavelengthScaledQ[SampleRun, Numerator],\n",
" WavelengthScaledQ[SampleRun, Denominator],\n",
")\n",
"iofqs = (IofQ[SampleRun],)\n",
"keys = (*monitors, MaskedData[SampleRun], *parts, *iofqs)\n",
"\n",
Expand All @@ -249,12 +252,12 @@
"\n",
"wavelength = workflow.compute(WavelengthBins)\n",
"display(\n",
" results[CleanSummedQ[SampleRun, Numerator]]\n",
" results[WavelengthScaledQ[SampleRun, Numerator]]\n",
" .hist(wavelength=wavelength)\n",
" .transpose()\n",
" .plot(norm='log')\n",
")\n",
"display(results[CleanSummedQ[SampleRun, Denominator]].plot(norm='log'))\n",
"display(results[WavelengthScaledQ[SampleRun, Denominator]].plot(norm='log'))\n",
"parts = {str(key): results[key] for key in parts}\n",
"parts = {\n",
" key: val.sum('wavelength') if val.bins is None else val.hist()\n",
Expand Down Expand Up @@ -309,7 +312,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
10 changes: 4 additions & 6 deletions src/ess/sans/beam_center_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
from .logging import get_logger
from .types import (
BeamCenter,
CleanDirectBeam,
DetectorBankSizes,
DimsToKeep,
IofQ,
MaskedData,
NeXusDetector,
NormWavelengthTerm,
QBins,
ReturnEvents,
SampleRun,
Expand Down Expand Up @@ -175,9 +175,7 @@ def _iofq_in_quadrants(
workflow[NeXusDetector[SampleRun]] = sc.DataGroup(data=detector[sel])
# MaskedData would be computed automatically, but we did it above already
workflow[MaskedData[SampleRun]] = calibrated[sel]
workflow[NormWavelengthTerm[SampleRun]] = (
norm if norm.dims == ('wavelength',) else norm[sel]
)
workflow[CleanDirectBeam] = norm if norm.dims == ('wavelength',) else norm[sel]
out[quad] = workflow.compute(IofQ[SampleRun])
return out

Expand Down Expand Up @@ -364,7 +362,7 @@ def beam_center_from_iofq(
keys = (
NeXusDetector[SampleRun],
MaskedData[SampleRun],
NormWavelengthTerm[SampleRun],
CleanDirectBeam,
ElasticCoordTransformGraph,
)
workflow = workflow.copy()
Expand All @@ -373,7 +371,7 @@ def beam_center_from_iofq(
results = workflow.compute(keys)
detector = results[NeXusDetector[SampleRun]]['data']
data = results[MaskedData[SampleRun]]
norm = results[NormWavelengthTerm[SampleRun]]
norm = results[CleanDirectBeam]
graph = results[ElasticCoordTransformGraph]

# Avoid reloading the detector
Expand Down
59 changes: 51 additions & 8 deletions src/ess/sans/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,22 +9,30 @@
)
from scippneutron.conversion.graph import beamline, tof

from ess.reduce.uncertainty import broadcast_uncertainties

from .common import mask_range
from .types import (
CleanQ,
CleanQxy,
CleanSummedQ,
CleanSummedQxy,
CleanWavelength,
CleanWavelengthMasked,
CorrectForGravity,
Denominator,
IofQPart,
MaskedData,
MonitorTerm,
MonitorType,
Numerator,
RunType,
ScatteringRunType,
TofMonitor,
UncertaintyBroadcastMode,
WavelengthMask,
WavelengthMonitor,
WavelengthScaledQ,
WavelengthScaledQxy,
)


Expand Down Expand Up @@ -163,12 +171,44 @@ def detector_to_wavelength(
)


def mask_wavelength(
da: CleanWavelength[ScatteringRunType, IofQPart], mask: WavelengthMask
) -> CleanWavelengthMasked[ScatteringRunType, IofQPart]:
def mask_wavelength_q(
da: CleanSummedQ[ScatteringRunType, Numerator], mask: WavelengthMask
) -> WavelengthScaledQ[ScatteringRunType, Numerator]:
if mask is not None:
da = mask_range(da, mask=mask)
return WavelengthScaledQ[ScatteringRunType, Numerator](da)


def mask_wavelength_qxy(
da: CleanSummedQxy[ScatteringRunType, Numerator], mask: WavelengthMask
) -> WavelengthScaledQxy[ScatteringRunType, Numerator]:
if mask is not None:
da = mask_range(da, mask=mask)
return WavelengthScaledQxy[ScatteringRunType, Numerator](da)


def mask_and_scale_wavelength_q(
da: CleanSummedQ[ScatteringRunType, Denominator],
mask: WavelengthMask,
wavelength_term: MonitorTerm[ScatteringRunType],
uncertainties: UncertaintyBroadcastMode,
) -> WavelengthScaledQ[ScatteringRunType, Denominator]:
da = da * broadcast_uncertainties(wavelength_term, prototype=da, mode=uncertainties)
if mask is not None:
da = mask_range(da, mask=mask)
return WavelengthScaledQ[ScatteringRunType, Denominator](da)


def mask_and_scale_wavelength_qxy(
da: CleanSummedQxy[ScatteringRunType, Denominator],
mask: WavelengthMask,
wavelength_term: MonitorTerm[ScatteringRunType],
uncertainties: UncertaintyBroadcastMode,
) -> WavelengthScaledQxy[ScatteringRunType, Denominator]:
da = da * broadcast_uncertainties(wavelength_term, prototype=da, mode=uncertainties)
if mask is not None:
da = mask_range(da, mask=mask)
return CleanWavelengthMasked[ScatteringRunType, IofQPart](da)
return WavelengthScaledQxy[ScatteringRunType, Denominator](da)


def _compute_Q(
Expand All @@ -186,7 +226,7 @@ def _compute_Q(


def compute_Q(
data: CleanWavelengthMasked[ScatteringRunType, IofQPart],
data: CleanWavelength[ScatteringRunType, IofQPart],
graph: ElasticCoordTransformGraph,
) -> CleanQ[ScatteringRunType, IofQPart]:
"""
Expand All @@ -198,7 +238,7 @@ def compute_Q(


def compute_Qxy(
data: CleanWavelengthMasked[ScatteringRunType, IofQPart],
data: CleanWavelength[ScatteringRunType, IofQPart],
graph: ElasticCoordTransformGraph,
) -> CleanQxy[ScatteringRunType, IofQPart]:
"""
Expand All @@ -214,7 +254,10 @@ def compute_Qxy(
sans_monitor,
monitor_to_wavelength,
detector_to_wavelength,
mask_wavelength,
mask_wavelength_q,
mask_wavelength_qxy,
mask_and_scale_wavelength_q,
mask_and_scale_wavelength_qxy,
compute_Q,
compute_Qxy,
)
22 changes: 11 additions & 11 deletions src/ess/sans/direct_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
from .types import (
BackgroundRun,
BackgroundSubtractedIofQ,
CleanSummedQ,
Denominator,
DirectBeam,
Numerator,
ProcessedWavelengthBands,
SampleRun,
WavelengthBands,
WavelengthBins,
WavelengthScaledQ,
)


Expand Down Expand Up @@ -103,16 +103,16 @@ def direct_beam(*, workflow: Pipeline, I0: sc.Variable, niter: int = 5) -> list[

wavelength_bins = workflow.compute(WavelengthBins)
parts = (
CleanSummedQ[SampleRun, Numerator],
CleanSummedQ[SampleRun, Denominator],
CleanSummedQ[BackgroundRun, Numerator],
CleanSummedQ[BackgroundRun, Denominator],
WavelengthScaledQ[SampleRun, Numerator],
WavelengthScaledQ[SampleRun, Denominator],
WavelengthScaledQ[BackgroundRun, Numerator],
WavelengthScaledQ[BackgroundRun, Denominator],
)
parts = workflow.compute(parts)
# Convert events to histograms to make normalization (in every iteration) cheap
for key in [
CleanSummedQ[SampleRun, Numerator],
CleanSummedQ[BackgroundRun, Numerator],
WavelengthScaledQ[SampleRun, Numerator],
WavelengthScaledQ[BackgroundRun, Numerator],
]:
parts[key] = parts[key].hist(wavelength=wavelength_bins)

Expand All @@ -121,8 +121,8 @@ def direct_beam(*, workflow: Pipeline, I0: sc.Variable, niter: int = 5) -> list[
parts = {key: sc.values(result) for key, result in parts.items()}
for key, part in parts.items():
workflow[key] = part
sample0 = parts[CleanSummedQ[SampleRun, Denominator]]
background0 = parts[CleanSummedQ[BackgroundRun, Denominator]]
sample0 = parts[WavelengthScaledQ[SampleRun, Denominator]]
background0 = parts[WavelengthScaledQ[BackgroundRun, Denominator]]

results = []

Expand Down Expand Up @@ -158,8 +158,8 @@ def direct_beam(*, workflow: Pipeline, I0: sc.Variable, niter: int = 5) -> list[
db.coords['wavelength'] = sc.midpoints(
db.coords['wavelength'], dim='wavelength'
)
workflow[CleanSummedQ[SampleRun, Denominator]] = sample0 * db
workflow[CleanSummedQ[BackgroundRun, Denominator]] = background0 * db
workflow[WavelengthScaledQ[SampleRun, Denominator]] = sample0 * db
workflow[WavelengthScaledQ[BackgroundRun, Denominator]] = background0 * db

results.append(
{
Expand Down
7 changes: 6 additions & 1 deletion src/ess/sans/i_of_q.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,12 @@ def resample_direct_beam(
The direct beam function resampled to the requested resolution.
"""
if direct_beam is None:
return CleanDirectBeam(None)
return CleanDirectBeam(
sc.DataArray(
sc.ones(dims=wavelength_bins.dims, shape=[len(wavelength_bins) - 1]),
coords={'wavelength': wavelength_bins},
)
)
if sc.identical(direct_beam.coords['wavelength'], wavelength_bins):
return direct_beam
if direct_beam.variances is not None:
Expand Down
Loading

0 comments on commit 8826f6b

Please sign in to comment.