Skip to content

Commit df83bfe

Browse files
committed
refactor: share parts of the workflow between amor and estia
1 parent 4db26a8 commit df83bfe

File tree

9 files changed

+40
-49
lines changed

9 files changed

+40
-49
lines changed

src/ess/amor/conversions.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,15 @@
66
from ..reflectometry.conversions import reflectometry_q
77
from ..reflectometry.types import (
88
BeamDivergenceLimits,
9+
CoordTransformationGraph,
910
WavelengthBins,
1011
YIndexLimits,
1112
ZIndexLimits,
1213
)
1314
from .geometry import Detector
14-
from .types import CoordTransformationGraph
1515

1616

17-
def theta(wavelength, divergence_angle, L2, sample_rotation, detector_rotation):
17+
def theta(wavelength, pixel_divergence_angle, L2, sample_rotation, detector_rotation):
1818
'''
1919
Angle of reflection.
2020
@@ -61,14 +61,15 @@ def theta(wavelength, divergence_angle, L2, sample_rotation, detector_rotation):
6161
'''
6262
c = sc.constants.g * sc.constants.m_n**2 / sc.constants.h**2
6363
out = (c * L2 * wavelength**2).to(unit='dimensionless') + sc.sin(
64-
divergence_angle.to(unit='rad', copy=False) + detector_rotation.to(unit='rad')
64+
pixel_divergence_angle.to(unit='rad', copy=False)
65+
+ detector_rotation.to(unit='rad')
6566
)
6667
out = sc.asin(out, out=out)
6768
out -= sample_rotation.to(unit='rad')
6869
return out
6970

7071

71-
def angle_of_divergence(theta, sample_rotation, angle_to_center_of_beam):
72+
def divergence_angle(theta, sample_rotation, angle_to_center_of_beam):
7273
"""
7374
Difference between the incident angle and the center of the incident beam.
7475
Useful for filtering parts of the beam that have too high divergence.
@@ -84,7 +85,7 @@ def angle_of_divergence(theta, sample_rotation, angle_to_center_of_beam):
8485

8586

8687
def wavelength(
87-
event_time_offset, divergence_angle, L1, L2, chopper_phase, chopper_frequency
88+
event_time_offset, pixel_divergence_angle, L1, L2, chopper_phase, chopper_frequency
8889
):
8990
"Converts event_time_offset to wavelength using the chopper settings."
9091
out = event_time_offset.to(unit="ns", dtype="float64", copy=True)
@@ -108,37 +109,22 @@ def wavelength(
108109
)
109110
# Correction for path length through guides being different
110111
# depending on incident angle.
111-
out -= (divergence_angle.to(unit="rad") / (np.pi * sc.units.rad)) * tau
112+
out -= (pixel_divergence_angle.to(unit="rad") / (np.pi * sc.units.rad)) * tau
112113
out *= (sc.constants.h / sc.constants.m_n) / (L1 + L2)
113114
return out.to(unit='angstrom', copy=False)
114115

115116

116117
def coordinate_transformation_graph() -> CoordTransformationGraph:
117118
return {
118-
"divergence_angle": "pixel_divergence_angle",
119119
"wavelength": wavelength,
120120
"theta": theta,
121-
"angle_of_divergence": angle_of_divergence,
121+
"divergence_angle": divergence_angle,
122122
"Q": reflectometry_q,
123123
"L1": lambda chopper_distance: sc.abs(chopper_distance),
124124
"L2": lambda distance_in_detector: distance_in_detector + Detector.distance,
125125
}
126126

127127

128-
def add_coords(
129-
da: sc.DataArray,
130-
graph: dict,
131-
) -> sc.DataArray:
132-
"Adds scattering coordinates to the raw detector data."
133-
return da.transform_coords(
134-
("wavelength", "theta", "angle_of_divergence", "Q", "L1", "L2"),
135-
graph,
136-
rename_dims=False,
137-
keep_intermediate=False,
138-
keep_aliases=False,
139-
)
140-
141-
142128
def _not_between(v, a, b):
143129
return (v < a) | (v > b)
144130

@@ -161,9 +147,9 @@ def add_masks(
161147
)
162148
da = da.bins.assign_masks(
163149
divergence_too_large=_not_between(
164-
da.bins.coords["angle_of_divergence"],
165-
bdlim[0].to(unit=da.bins.coords["angle_of_divergence"].bins.unit),
166-
bdlim[1].to(unit=da.bins.coords["angle_of_divergence"].bins.unit),
150+
da.bins.coords["divergence_angle"],
151+
bdlim[0].to(unit=da.bins.coords["divergence_angle"].bins.unit),
152+
bdlim[1].to(unit=da.bins.coords["divergence_angle"].bins.unit),
167153
),
168154
wavelength=_not_between(
169155
da.bins.coords['wavelength'],

src/ess/amor/normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
supermirror_reflectivity,
99
)
1010
from ..reflectometry.types import (
11+
CoordTransformationGraph,
1112
DetectorSpatialResolution,
1213
ReducedReference,
1314
ReducibleData,
@@ -22,7 +23,6 @@
2223
sample_size_resolution,
2324
wavelength_resolution,
2425
)
25-
from .types import CoordTransformationGraph
2626

2727

2828
def mask_events_where_supermirror_does_not_cover(

src/ess/amor/types.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
AngularResolution = NewType("AngularResolution", sc.Variable)
1010
SampleSizeResolution = NewType("SampleSizeResolution", sc.Variable)
1111

12-
CoordTransformationGraph = NewType("CoordTransformationGraph", dict)
13-
1412

1513
class ChopperFrequency(sciline.Scope[RunType, sc.Variable], sc.Variable):
1614
"""Frequency of the choppers used in the run."""

src/ess/amor/workflow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from ..reflectometry.conversions import (
2+
add_coords,
23
add_proton_current_coord,
34
add_proton_current_mask,
45
)
56
from ..reflectometry.corrections import correct_by_footprint, correct_by_proton_current
67
from ..reflectometry.types import (
78
BeamDivergenceLimits,
9+
CoordTransformationGraph,
810
ProtonCurrent,
911
RawDetectorData,
1012
ReducibleData,
@@ -13,8 +15,7 @@
1315
YIndexLimits,
1416
ZIndexLimits,
1517
)
16-
from .conversions import add_coords, add_masks
17-
from .types import CoordTransformationGraph
18+
from .conversions import add_masks
1819

1920

2021
def add_coords_masks_and_apply_corrections(

src/ess/estia/conversions.py

Lines changed: 1 addition & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from ..reflectometry.conversions import reflectometry_q
66
from ..reflectometry.types import (
77
BeamDivergenceLimits,
8+
CoordTransformationGraph,
89
WavelengthBins,
910
YIndexLimits,
1011
ZIndexLimits,
1112
)
12-
from .types import CoordTransformationGraph
1313

1414

1515
def theta(
@@ -65,20 +65,6 @@ def coordinate_transformation_graph() -> CoordTransformationGraph:
6565
}
6666

6767

68-
def add_coords(
69-
da: sc.DataArray,
70-
graph: dict,
71-
) -> sc.DataArray:
72-
"Adds scattering coordinates to the raw detector data."
73-
return da.transform_coords(
74-
("wavelength", "theta", "divergence_angle", "Q", "L1", "L2"),
75-
graph,
76-
rename_dims=False,
77-
keep_intermediate=False,
78-
keep_aliases=False,
79-
)
80-
81-
8268
def _not_between(v, a, b):
8369
return (v < a) | (v > b)
8470

src/ess/estia/normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
supermirror_reflectivity,
99
)
1010
from ..reflectometry.types import (
11+
CoordTransformationGraph,
1112
DetectorSpatialResolution,
1213
ReducedReference,
1314
ReducibleData,
@@ -21,7 +22,6 @@
2122
q_resolution,
2223
sample_size_resolution,
2324
)
24-
from .types import CoordTransformationGraph
2525

2626

2727
def mask_events_where_supermirror_does_not_cover(

src/ess/estia/workflow.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from ..reflectometry.conversions import (
2+
add_coords,
23
add_proton_current_coord,
34
add_proton_current_mask,
45
)
56
from ..reflectometry.corrections import correct_by_proton_current
67
from ..reflectometry.types import (
78
BeamDivergenceLimits,
9+
CoordTransformationGraph,
810
ProtonCurrent,
911
RawDetectorData,
1012
ReducibleData,
@@ -13,9 +15,8 @@
1315
YIndexLimits,
1416
ZIndexLimits,
1517
)
16-
from .conversions import add_coords, add_masks
18+
from .conversions import add_masks
1719
from .corrections import correct_by_footprint
18-
from .types import CoordTransformationGraph
1920

2021

2122
def add_coords_masks_and_apply_corrections(

src/ess/reflectometry/conversions.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from scipp.constants import pi
55
from scippneutron._utils import elem_dtype
66

7-
from .types import ProtonCurrent, RunType
7+
from .types import (
8+
ProtonCurrent,
9+
RunType,
10+
)
811

912

1013
def reflectometry_q(wavelength: sc.Variable, theta: sc.Variable) -> sc.Variable:
@@ -61,4 +64,18 @@ def add_proton_current_mask(da: sc.DataArray) -> sc.DataArray:
6164
return da
6265

6366

67+
def add_coords(
68+
da: sc.DataArray,
69+
graph: dict,
70+
) -> sc.DataArray:
71+
"Adds scattering coordinates to the raw detector data."
72+
return da.transform_coords(
73+
("wavelength", "theta", "divergence_angle", "Q", "L1", "L2"),
74+
graph,
75+
rename_dims=False,
76+
keep_intermediate=False,
77+
keep_aliases=False,
78+
)
79+
80+
6481
providers = ()

src/ess/reflectometry/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
SampleRun = NewType("SampleRun", str)
88
RunType = TypeVar("RunType", ReferenceRun, SampleRun)
99

10+
CoordTransformationGraph = NewType("CoordTransformationGraph", dict)
11+
1012

1113
class NeXusDetectorName(sciline.Scope[RunType, str], str):
1214
"""Name of the detector in the nexus file containing the events of the RunType"""

0 commit comments

Comments
 (0)