Skip to content

Commit 11319f8

Browse files
tamaranormanTorax team
authored and
Torax team
committed
Support prescribed sources for multiple affected profiles
Fixes #806 PiperOrigin-RevId: 744684193
1 parent 8f95c3d commit 11319f8

25 files changed

+295
-153
lines changed

docs/configuration.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -918,9 +918,9 @@ The configurable runtime parameters of each source are as follows:
918918
This is documented in the individual source sections.
919919

920920
* ``'PRESCRIBED'``
921-
Source values are arbitrarily prescribed by the user. The value is set by ``prescribed_values``, and can contain the same
922-
data structures as :ref:`Time-varying arrays`. Currently, this is only supported for sources that have a 1D output
923-
along the cell grid or face grid.
921+
Source values are arbitrarily prescribed by the user. The value is set by
922+
``prescribed_values``, and should be a tuple of values. Each value can
923+
contain the same data structures as :ref:`Time-varying arrays`.
924924

925925
For example, to set 'fusion_power' to zero, e.g. for testing or sensitivity purposes, set:
926926

@@ -938,7 +938,7 @@ preamble to the CONFIG dict within config module, set:
938938
'sources': {
939939
'generic_current_source': {
940940
'mode': 'PRESCRIBED',
941-
'prescribed_values': (times, rhon, current_profiles),
941+
'prescribed_values': ((times, rhon, current_profiles),),
942942
},
943943
944944
where the example ``times`` is a 1D numpy array of times, ``rhon`` is a 1D numpy array of normalized toroidal flux

torax/sources/base.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ class SourceModelBase(torax_pydantic.BaseModelFrozen, abc.ABC):
4040
implicit or explicit. For example, file-based sources are always explicit.
4141
If an incorrect combination of source type and is_explicit is passed in,
4242
an error will be thrown when running the simulation.
43-
prescribed_values: Prescribed values for the source. Used only when the
44-
source is fully prescribed (i.e. source.mode == Mode.PRESCRIBED). The
45-
default here is a vector of all zeros along for all rho and time, and the
46-
output vector is along the cell grid.
43+
prescribed_values: Tuple of prescribed values for the source, one for each
44+
affected core profile. Used only when thesource is fully prescribed (i.e.
45+
source.mode == Mode.PRESCRIBED). The default here is a vector of all zeros
46+
along for all rho and time, and the output vector is along the cell grid.
4747
"""
4848
mode: runtime_params.Mode = runtime_params.Mode.ZERO
4949
is_explicit: bool = False
50-
prescribed_values: torax_pydantic.TimeVaryingArray = (
51-
torax_pydantic.ValidatedDefault({0: {0: 0, 1: 0}})
50+
prescribed_values: tuple[torax_pydantic.TimeVaryingArray, ...] = (
51+
torax_pydantic.ValidatedDefault(({0: {0: 0, 1: 0}},))
5252
)
5353

5454
def build_static_params(self) -> runtime_params.StaticRuntimeParams:

torax/sources/bootstrap_current_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def build_dynamic_params(
151151
t: chex.Numeric,
152152
) -> DynamicRuntimeParams:
153153
return DynamicRuntimeParams(
154-
prescribed_values=self.prescribed_values.get_value(t),
154+
prescribed_values=tuple(
155+
[v.get_value(t) for v in self.prescribed_values]
156+
),
155157
bootstrap_mult=self.bootstrap_mult,
156158
)
157159

torax/sources/bremsstrahlung_heat_sink.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,9 @@ def build_dynamic_params(
159159
t: chex.Numeric,
160160
) -> 'DynamicRuntimeParams':
161161
return DynamicRuntimeParams(
162-
prescribed_values=self.prescribed_values.get_value(t),
162+
prescribed_values=tuple(
163+
[v.get_value(t) for v in self.prescribed_values]
164+
),
163165
use_relativistic_correction=self.use_relativistic_correction,
164166
)
165167

torax/sources/cyclotron_radiation_heat_sink.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,9 @@ def build_dynamic_params(
415415
t: chex.Numeric,
416416
) -> 'DynamicRuntimeParams':
417417
return DynamicRuntimeParams(
418-
prescribed_values=self.prescribed_values.get_value(t),
418+
prescribed_values=tuple(
419+
[v.get_value(t) for v in self.prescribed_values]
420+
),
419421
wall_reflection_coeff=self.wall_reflection_coeff,
420422
)
421423

torax/sources/electron_cyclotron_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ def build_dynamic_params(
176176
t: chex.Numeric,
177177
) -> DynamicRuntimeParams:
178178
return DynamicRuntimeParams(
179-
prescribed_values=self.prescribed_values.get_value(t),
179+
prescribed_values=tuple(
180+
[v.get_value(t) for v in self.prescribed_values]
181+
),
180182
cd_efficiency=self.cd_efficiency.get_value(t),
181183
manual_ec_power_density=self.manual_ec_power_density.get_value(t),
182184
gaussian_ec_power_density_width=self.gaussian_ec_power_density_width.get_value(

torax/sources/fusion_heat_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def build_dynamic_params(
198198
t: chex.Numeric,
199199
) -> runtime_params_lib.DynamicRuntimeParams:
200200
return runtime_params_lib.DynamicRuntimeParams(
201-
prescribed_values=self.prescribed_values.get_value(t),
201+
prescribed_values=tuple(
202+
[v.get_value(t) for v in self.prescribed_values]
203+
),
202204
)
203205

204206
def build_source(self) -> FusionHeatSource:

torax/sources/gas_puff_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,9 @@ def build_dynamic_params(
110110
t: chex.Numeric,
111111
) -> DynamicGasPuffRuntimeParams:
112112
return DynamicGasPuffRuntimeParams(
113-
prescribed_values=self.prescribed_values.get_value(t),
113+
prescribed_values=tuple(
114+
[v.get_value(t) for v in self.prescribed_values]
115+
),
114116
puff_decay_length=self.puff_decay_length.get_value(t),
115117
S_puff_tot=self.S_puff_tot.get_value(t),
116118
)

torax/sources/generic_current_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def build_dynamic_params(
151151
t: chex.Numeric,
152152
) -> DynamicRuntimeParams:
153153
return DynamicRuntimeParams(
154-
prescribed_values=self.prescribed_values.get_value(t),
154+
prescribed_values=tuple(
155+
[v.get_value(t) for v in self.prescribed_values]
156+
),
155157
Iext=self.Iext.get_value(t),
156158
fext=self.fext.get_value(t),
157159
wext=self.wext.get_value(t),

torax/sources/generic_ion_el_heat_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ def build_dynamic_params(
158158
t: chex.Numeric,
159159
) -> DynamicRuntimeParams:
160160
return DynamicRuntimeParams(
161-
prescribed_values=self.prescribed_values.get_value(t),
161+
prescribed_values=tuple(
162+
[v.get_value(t) for v in self.prescribed_values]
163+
),
162164
w=self.w.get_value(t),
163165
rsource=self.rsource.get_value(t),
164166
Ptot=self.Ptot.get_value(t),

torax/sources/generic_particle_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,9 @@ def build_dynamic_params(
121121
t: chex.Numeric,
122122
) -> DynamicParticleRuntimeParams:
123123
return DynamicParticleRuntimeParams(
124-
prescribed_values=self.prescribed_values.get_value(t),
124+
prescribed_values=tuple(
125+
[v.get_value(t) for v in self.prescribed_values]
126+
),
125127
particle_width=self.particle_width.get_value(t),
126128
deposition_location=self.deposition_location.get_value(t),
127129
S_tot=self.S_tot.get_value(t),

torax/sources/impurity_radiation_heat_sink/impurity_radiation_constant_fraction.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,9 @@ def build_dynamic_params(
112112
t: chex.Numeric,
113113
) -> 'DynamicRuntimeParams':
114114
return DynamicRuntimeParams(
115-
prescribed_values=self.prescribed_values.get_value(t),
115+
prescribed_values=tuple(
116+
[v.get_value(t) for v in self.prescribed_values]
117+
),
116118
fraction_of_total_power_density=self.fraction_of_total_power_density.get_value(
117119
t
118120
),

torax/sources/impurity_radiation_heat_sink/impurity_radiation_mavrin_fit.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,9 @@ def build_dynamic_params(
242242
t: chex.Numeric,
243243
) -> 'DynamicRuntimeParams':
244244
return DynamicRuntimeParams(
245-
prescribed_values=self.prescribed_values.get_value(t),
245+
prescribed_values=tuple(
246+
[v.get_value(t) for v in self.prescribed_values]
247+
),
246248
radiation_multiplier=self.radiation_multiplier,
247249
)
248250

torax/sources/ion_cyclotron_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,9 @@ def build_dynamic_params(
489489
t: chex.Numeric,
490490
) -> DynamicRuntimeParams:
491491
return DynamicRuntimeParams(
492-
prescribed_values=self.prescribed_values.get_value(t),
492+
prescribed_values=tuple(
493+
[v.get_value(t) for v in self.prescribed_values]
494+
),
493495
wall_inner=self.wall_inner,
494496
wall_outer=self.wall_outer,
495497
frequency=self.frequency.get_value(t),

torax/sources/ohmic_heat_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def build_dynamic_params(
103103
t: chex.Numeric,
104104
) -> runtime_params_lib.DynamicRuntimeParams:
105105
return runtime_params_lib.DynamicRuntimeParams(
106-
prescribed_values=self.prescribed_values.get_value(t),
106+
prescribed_values=tuple(
107+
[v.get_value(t) for v in self.prescribed_values]
108+
),
107109
)
108110

109111
def build_source(self) -> OhmicHeatSource:

torax/sources/pellet_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,9 @@ def build_dynamic_params(
116116
t: chex.Numeric,
117117
) -> DynamicPelletRuntimeParams:
118118
return DynamicPelletRuntimeParams(
119-
prescribed_values=self.prescribed_values.get_value(t),
119+
prescribed_values=tuple(
120+
[v.get_value(t) for v in self.prescribed_values]
121+
),
120122
pellet_width=self.pellet_width.get_value(t),
121123
pellet_deposition_location=self.pellet_deposition_location.get_value(t),
122124
S_pellet_tot=self.S_pellet_tot.get_value(t),

torax/sources/qei_source.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,9 @@ def build_dynamic_params(
166166
t: chex.Numeric,
167167
) -> DynamicRuntimeParams:
168168
return DynamicRuntimeParams(
169-
prescribed_values=self.prescribed_values.get_value(t),
169+
prescribed_values=tuple(
170+
[v.get_value(t) for v in self.prescribed_values]
171+
),
170172
Qei_mult=self.Qei_mult,
171173
)
172174

torax/sources/runtime_params.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class DynamicRuntimeParams:
4848
stateless, so these params are their inputs to determine their output
4949
profiles.
5050
"""
51-
prescribed_values: array_typing.ArrayFloat
51+
prescribed_values: tuple[array_typing.ArrayFloat, ...]
5252

5353

5454
@chex.dataclass(frozen=True)

torax/sources/source.py

+10-3
Original file line numberDiff line numberDiff line change
@@ -163,9 +163,16 @@ def get_value(
163163
calculated_source_profiles,
164164
)
165165
case runtime_params_lib.Mode.PRESCRIBED.value:
166-
# TODO(b/395854896) add support for sources that affect multiple core
167-
# profiles.
168-
return (dynamic_source_runtime_params.prescribed_values,)
166+
if len(self.affected_core_profiles) != len(
167+
dynamic_source_runtime_params.prescribed_values
168+
):
169+
raise ValueError(
170+
'When using PRESCRIBED mode, the number of prescribed values must'
171+
' match the number of affected core profiles. Was: '
172+
f'{len(dynamic_source_runtime_params.prescribed_values)} '
173+
f' Expected: {len(self.affected_core_profiles)}.'
174+
)
175+
return dynamic_source_runtime_params.prescribed_values
169176
case runtime_params_lib.Mode.ZERO.value:
170177
zeros = jnp.zeros(geo.rho_norm.shape)
171178
return (zeros,) * len(self.affected_core_profiles)

torax/sources/tests/pydantic_model_test.py

+82
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from absl.testing import absltest
1717
from absl.testing import parameterized
18+
import numpy as np
1819
from torax.sources import base
1920
from torax.sources import bootstrap_current_source
2021
from torax.sources import fusion_heat_source
@@ -26,6 +27,7 @@
2627
from torax.sources import source_models as source_models_lib
2728
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_constant_fraction
2829
from torax.sources.impurity_radiation_heat_sink import impurity_radiation_mavrin_fit
30+
from torax.torax_pydantic import torax_pydantic
2931

3032

3133
class PydanticModelTest(parameterized.TestCase):
@@ -150,6 +152,86 @@ def test_empty_source_config_only_has_defaults_turned_off(self):
150152
)
151153
self.assertLen(sources.source_model_config, 3)
152154

155+
def test_adding_a_source_with_prescribed_values(self):
156+
"""Tests that a source can be added with overriding defaults."""
157+
sources = pydantic_model.Sources.from_dict({
158+
'generic_current_source': {
159+
'mode': 'PRESCRIBED',
160+
'prescribed_values': ((
161+
np.array([0.0, 1.0, 2.0, 3.0]),
162+
np.array([0., 0.5, 1.0]),
163+
np.full([4, 3], 42)
164+
),),
165+
},
166+
'electron_cyclotron_source': {
167+
'mode': 'PRESCRIBED',
168+
'prescribed_values': (
169+
3.,
170+
4.,
171+
),
172+
}
173+
})
174+
mesh = torax_pydantic.Grid1D(nx=4, dx=0.25)
175+
torax_pydantic.set_grid(sources, mesh)
176+
source = sources.source_model_config['generic_current_source']
177+
self.assertLen(source.prescribed_values, 1)
178+
self.assertIsInstance(
179+
source.prescribed_values[0], torax_pydantic.TimeVaryingArray)
180+
source = sources.source_model_config['electron_cyclotron_source']
181+
self.assertLen(source.prescribed_values, 2)
182+
self.assertIsInstance(
183+
source.prescribed_values[0], torax_pydantic.TimeVaryingArray)
184+
self.assertIsInstance(
185+
source.prescribed_values[1], torax_pydantic.TimeVaryingArray)
186+
value = source.prescribed_values[0].get_value(0.0)
187+
np.testing.assert_equal(value, 3.)
188+
value = source.prescribed_values[1].get_value(0.0)
189+
np.testing.assert_equal(value, 4.)
190+
191+
def test_bremsstrahlung_and_mavrin_validator_with_bremsstrahlung_zero(self):
192+
valid_config = {
193+
'bremsstrahlung_heat_sink': {'mode': 'ZERO'},
194+
'impurity_radiation_heat_sink': {
195+
'mode': 'PRESCRIBED',
196+
'model_function_name': 'impurity_radiation_mavrin_fit',
197+
},
198+
}
199+
pydantic_model.Sources.from_dict(valid_config)
200+
201+
def test_bremsstrahlung_and_mavrin_validator_with_mavrin_zero(self):
202+
valid_config = {
203+
'bremsstrahlung_heat_sink': {'mode': 'PRESCRIBED'},
204+
'impurity_radiation_heat_sink': {
205+
'mode': 'ZERO',
206+
'model_function_name': 'impurity_radiation_mavrin_fit',
207+
},
208+
}
209+
pydantic_model.Sources.from_dict(valid_config)
210+
211+
def test_bremsstrahlung_and_mavrin_validator_with_constant_fraction(self):
212+
valid_config = {
213+
'bremsstrahlung_heat_sink': {'mode': 'PRESCRIBED'},
214+
'impurity_radiation_heat_sink': {
215+
'mode': 'PRESCRIBED',
216+
'model_function_name': 'radially_constant_fraction_of_Pin',
217+
},
218+
}
219+
pydantic_model.Sources.from_dict(valid_config)
220+
221+
def test_bremsstrahlung_and_mavrin_validator_with_invalid_config(self):
222+
invalid_config = {
223+
'bremsstrahlung_heat_sink': {'mode': 'PRESCRIBED'},
224+
'impurity_radiation_heat_sink': {
225+
'mode': 'PRESCRIBED',
226+
'model_function_name': 'impurity_radiation_mavrin_fit',
227+
},
228+
}
229+
with self.assertRaisesRegex(
230+
ValueError,
231+
'Both bremsstrahlung_heat_sink and impurity_radiation_heat_sink',
232+
):
233+
pydantic_model.Sources.from_dict(invalid_config)
234+
153235

154236
if __name__ == '__main__':
155237
absltest.main()

torax/sources/tests/source_profile_builders_test.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def affected_core_profiles(
130130
runtime_params_slice.DynamicRuntimeParamsSlice,
131131
sources={
132132
'foo': source_runtime_params.DynamicRuntimeParams(
133-
prescribed_values=jnp.ones(self.geo.rho.shape)
133+
prescribed_values=(jnp.ones(self.geo.rho.shape),)
134134
)
135135
},
136136
)
@@ -191,7 +191,8 @@ def affected_core_profiles(
191191
runtime_params_slice.DynamicRuntimeParamsSlice,
192192
sources={
193193
'foo': source_runtime_params.DynamicRuntimeParams(
194-
prescribed_values=jnp.ones(self.geo.rho.shape)
194+
prescribed_values=(jnp.ones(self.geo.rho.shape),
195+
jnp.ones(self.geo.rho.shape))
195196
)
196197
},
197198
)
@@ -279,7 +280,7 @@ def affected_core_profiles(
279280
runtime_params_slice.DynamicRuntimeParamsSlice,
280281
sources={
281282
'foo': source_runtime_params.DynamicRuntimeParams(
282-
prescribed_values=jnp.ones(self.geo.rho.shape)
283+
prescribed_values=(jnp.ones(self.geo.rho.shape),)
283284
)
284285
},
285286
)

0 commit comments

Comments
 (0)