Skip to content

Commit 931a1df

Browse files
tamaranormanTorax team
authored and
Torax team
committed
Change build_static_params to take runtime_params separatly
PiperOrigin-RevId: 743963715
1 parent 4d7ff98 commit 931a1df

35 files changed

+485
-426
lines changed

torax/config/build_runtime_params.py

+21-11
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@
2323
DynamicRuntimeParamsSlice and a corresponding geometry with consistent Ip.
2424
"""
2525
import chex
26+
from torax.config import numerics as numerics_lib
27+
from torax.config import plasma_composition as plasma_composition_lib
28+
from torax.config import profile_conditions as profile_conditions_lib
2629
from torax.config import runtime_params as general_runtime_params_lib
2730
from torax.config import runtime_params_slice
2831
from torax.geometry import geometry
@@ -38,16 +41,23 @@
3841

3942
def build_static_runtime_params_slice(
4043
*,
41-
runtime_params: general_runtime_params_lib.GeneralRuntimeParams,
44+
profile_conditions: profile_conditions_lib.ProfileConditions,
45+
numerics: numerics_lib.Numerics,
46+
plasma_composition: plasma_composition_lib.PlasmaComposition,
4247
sources: sources_pydantic_model.Sources,
4348
torax_mesh: torax_pydantic.Grid1D,
4449
stepper: stepper_pydantic_model.Stepper | None = None,
4550
) -> runtime_params_slice.StaticRuntimeParamsSlice:
4651
"""Builds a StaticRuntimeParamsSlice.
4752
4853
Args:
49-
runtime_params: General runtime params from which static params are taken,
50-
which are the choices on equations being solved, and adaptive dt.
54+
profile_conditions: Profile conditions from which the profile conditions
55+
static variables are taken, which are the boundary conditions for the
56+
plasma.
57+
numerics: Numerics from which the numerics static variables are taken, which
58+
are the equations being solved, adaptive dt, and the fixed dt.
59+
plasma_composition: Plasma composition from which the plasma composition
60+
static variables are taken, which are the main and impurity ion names.
5161
sources: data from which the source related static variables are taken,
5262
which are the explicit/implicit toggle and calculation mode for each
5363
source.
@@ -70,14 +80,14 @@ def build_static_runtime_params_slice(
7080
},
7181
torax_mesh=torax_mesh,
7282
stepper=stepper.build_static_params(),
73-
ion_heat_eq=runtime_params.numerics.ion_heat_eq,
74-
el_heat_eq=runtime_params.numerics.el_heat_eq,
75-
current_eq=runtime_params.numerics.current_eq,
76-
dens_eq=runtime_params.numerics.dens_eq,
77-
main_ion_names=runtime_params.plasma_composition.get_main_ion_names(),
78-
impurity_names=runtime_params.plasma_composition.get_impurity_names(),
79-
adaptive_dt=runtime_params.numerics.adaptive_dt,
80-
use_vloop_lcfs_boundary_condition=runtime_params.profile_conditions.use_vloop_lcfs_boundary_condition,
83+
ion_heat_eq=numerics.ion_heat_eq,
84+
el_heat_eq=numerics.el_heat_eq,
85+
current_eq=numerics.current_eq,
86+
dens_eq=numerics.dens_eq,
87+
main_ion_names=plasma_composition.get_main_ion_names(),
88+
impurity_names=plasma_composition.get_impurity_names(),
89+
adaptive_dt=numerics.adaptive_dt,
90+
use_vloop_lcfs_boundary_condition=profile_conditions.use_vloop_lcfs_boundary_condition,
8191
)
8292

8393

torax/config/tests/runtime_params_slice_test.py

+43-28
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,34 @@
1313
# limitations under the License.
1414

1515

16+
import copy
17+
1618
from absl.testing import absltest
1719
from absl.testing import parameterized
1820
import jax
1921
from torax.config import build_runtime_params
20-
from torax.config import runtime_params as general_runtime_params
2122
from torax.config import runtime_params_slice as runtime_params_slice_lib
22-
from torax.geometry import pydantic_model as geometry_pydantic_model
2323
from torax.tests.test_lib import default_sources
24-
from torax.torax_pydantic import torax_pydantic
24+
from torax.torax_pydantic import model_config
2525

2626

2727
class RuntimeParamsSliceTest(parameterized.TestCase):
2828

2929
def setUp(self):
3030
super().setUp()
31-
self._geo = geometry_pydantic_model.CircularConfig().build_geometry()
31+
self._python_config = {
32+
'runtime_params': {'numerics': {}},
33+
'geometry': {
34+
'geometry_type': 'circular',
35+
},
36+
'pedestal': {},
37+
'sources': default_sources.get_default_source_config(),
38+
'stepper': {},
39+
'time_step_calculator': {},
40+
'transport': {},
41+
}
42+
self._torax_config = model_config.ToraxConfig.from_dict(self._python_config)
43+
self._torax_mesh = self._torax_config.geometry.build_provider.torax_mesh
3244

3345
def test_dynamic_slice_can_be_input_to_jitted_function(self):
3446
"""Tests that the slice can be input to a jitted function."""
@@ -39,53 +51,56 @@ def foo(
3951
_ = runtime_params_slice # do nothing.
4052

4153
foo_jitted = jax.jit(foo)
42-
runtime_params = general_runtime_params.GeneralRuntimeParams()
43-
torax_pydantic.set_grid(runtime_params, self._geo.torax_mesh)
54+
runtime_params = self._torax_config.runtime_params
4455
dynamic_slice = build_runtime_params.DynamicRuntimeParamsSliceProvider(
4556
runtime_params,
46-
torax_mesh=self._geo.torax_mesh,
57+
torax_mesh=self._torax_mesh,
4758
)(
48-
t=runtime_params.numerics.t_initial,
59+
t=self._torax_config.numerics.t_initial,
4960
)
5061
# Make sure you can call the function with dynamic_slice as an arg.
5162
foo_jitted(dynamic_slice)
5263

5364
def test_static_runtime_params_slice_hash_same_for_same_params(self):
5465
"""Tests that the hash is the same for the same static params."""
55-
runtime_params = general_runtime_params.GeneralRuntimeParams()
56-
torax_pydantic.set_grid(runtime_params, self._geo.torax_mesh)
57-
sources = default_sources.get_default_sources()
5866
static_slice1 = build_runtime_params.build_static_runtime_params_slice(
59-
runtime_params=runtime_params,
60-
sources=sources,
61-
torax_mesh=self._geo.torax_mesh,
67+
profile_conditions=self._torax_config.profile_conditions,
68+
numerics=self._torax_config.numerics,
69+
plasma_composition=self._torax_config.plasma_composition,
70+
sources=self._torax_config.sources,
71+
torax_mesh=self._torax_mesh,
6272
)
6373
static_slice2 = build_runtime_params.build_static_runtime_params_slice(
64-
runtime_params=runtime_params,
65-
sources=sources,
66-
torax_mesh=self._geo.torax_mesh,
74+
profile_conditions=self._torax_config.profile_conditions,
75+
numerics=self._torax_config.numerics,
76+
plasma_composition=self._torax_config.plasma_composition,
77+
sources=self._torax_config.sources,
78+
torax_mesh=self._torax_mesh,
6779
)
6880
self.assertEqual(hash(static_slice1), hash(static_slice2))
6981

7082
def test_static_runtime_params_slice_hash_different_for_different_params(
7183
self,
7284
):
7385
"""Test that the hash changes when the static params change."""
74-
runtime_params = general_runtime_params.GeneralRuntimeParams()
75-
sources = default_sources.get_default_sources()
7686
static_slice1 = build_runtime_params.build_static_runtime_params_slice(
77-
runtime_params=runtime_params,
78-
sources=sources,
79-
torax_mesh=self._geo.torax_mesh,
87+
profile_conditions=self._torax_config.profile_conditions,
88+
numerics=self._torax_config.numerics,
89+
plasma_composition=self._torax_config.plasma_composition,
90+
sources=self._torax_config.sources,
91+
torax_mesh=self._torax_mesh,
8092
)
81-
runtime_params_mod = runtime_params.model_copy()
82-
runtime_params_mod._update_fields(
83-
{'numerics.ion_heat_eq': not runtime_params.numerics.ion_heat_eq}
93+
new_config = copy.deepcopy(self._python_config)
94+
new_config['runtime_params']['numerics']['ion_heat_eq'] = (
95+
not self._torax_config.runtime_params.numerics.ion_heat_eq
8496
)
97+
new_torax_config = model_config.ToraxConfig.from_dict(new_config)
8598
static_slice2 = build_runtime_params.build_static_runtime_params_slice(
86-
runtime_params=runtime_params_mod,
87-
sources=sources,
88-
torax_mesh=self._geo.torax_mesh,
99+
profile_conditions=new_torax_config.profile_conditions,
100+
numerics=new_torax_config.numerics,
101+
plasma_composition=new_torax_config.plasma_composition,
102+
sources=new_torax_config.sources,
103+
torax_mesh=self._torax_mesh,
89104
)
90105
self.assertNotEqual(hash(static_slice1), hash(static_slice2))
91106

torax/core_profiles/tests/getters_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,9 @@ def test_get_ion_density_and_charge_states(self):
266266
torax_mesh=self.geo.torax_mesh,
267267
)
268268
static_slice = build_runtime_params.build_static_runtime_params_slice(
269-
runtime_params=runtime_params,
269+
profile_conditions=runtime_params.profile_conditions,
270+
numerics=runtime_params.numerics,
271+
plasma_composition=runtime_params.plasma_composition,
270272
sources=sources,
271273
torax_mesh=self.geo.torax_mesh,
272274
)

torax/core_profiles/tests/initialization_test.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,9 @@ def test_initial_psi(
106106
t=1.0,
107107
)
108108
static_slice = build_runtime_params.build_static_runtime_params_slice(
109-
runtime_params=runtime_params,
109+
profile_conditions=runtime_params.profile_conditions,
110+
numerics=runtime_params.numerics,
111+
plasma_composition=runtime_params.plasma_composition,
110112
sources=sources,
111113
torax_mesh=self.geo.torax_mesh,
112114
)

torax/core_profiles/tests/updaters_test.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,9 @@ def test_compute_boundary_conditions_ne(
134134
)
135135
sources = source_pydantic_model.Sources()
136136
static_slice = build_runtime_params.build_static_runtime_params_slice(
137-
runtime_params=runtime_params,
137+
profile_conditions=runtime_params.profile_conditions,
138+
numerics=runtime_params.numerics,
139+
plasma_composition=runtime_params.plasma_composition,
138140
sources=sources,
139141
torax_mesh=self.geo.torax_mesh,
140142
)
@@ -199,7 +201,9 @@ def test_compute_boundary_conditions_Te(
199201
)
200202
sources = source_pydantic_model.Sources.from_dict({})
201203
static_slice = build_runtime_params.build_static_runtime_params_slice(
202-
runtime_params=runtime_params,
204+
profile_conditions=runtime_params.profile_conditions,
205+
numerics=runtime_params.numerics,
206+
plasma_composition=runtime_params.plasma_composition,
203207
sources=sources,
204208
torax_mesh=self.geo.torax_mesh,
205209
)
@@ -250,7 +254,9 @@ def test_compute_boundary_conditions_Ti(
250254
)
251255
sources = source_pydantic_model.Sources.from_dict({})
252256
static_slice = build_runtime_params.build_static_runtime_params_slice(
253-
runtime_params=runtime_params,
257+
profile_conditions=runtime_params.profile_conditions,
258+
numerics=runtime_params.numerics,
259+
plasma_composition=runtime_params.plasma_composition,
254260
sources=sources,
255261
torax_mesh=self.geo.torax_mesh,
256262
)

torax/fvm/tests/calc_coeffs_test.py

+37-53
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,13 @@
1515
from absl.testing import absltest
1616
from absl.testing import parameterized
1717
from torax.config import build_runtime_params
18-
from torax.config import numerics as numerics_lib
19-
from torax.config import profile_conditions as profile_conditions_lib
20-
from torax.config import runtime_params as general_runtime_params
2118
from torax.core_profiles import initialization
2219
from torax.fvm import calc_coeffs
23-
from torax.geometry import pydantic_model as geometry_pydantic_model
24-
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
25-
from torax.sources import pydantic_model as sources_pydantic_model
2620
from torax.sources import runtime_params as source_runtime_params
2721
from torax.sources import source_models as source_models_lib
2822
from torax.sources import source_profile_builders
29-
from torax.stepper import pydantic_model as stepper_pydantic_model
3023
from torax.tests.test_lib import default_sources
31-
from torax.transport_model import pydantic_model as transport_pydantic_model
24+
from torax.torax_pydantic import model_config
3225

3326

3427
class CoreProfileSettersTest(parameterized.TestCase):
@@ -42,66 +35,55 @@ class CoreProfileSettersTest(parameterized.TestCase):
4235
def test_calc_coeffs_smoke_test(
4336
self, num_cells, theta_imp, set_pedestal
4437
):
45-
runtime_params = general_runtime_params.GeneralRuntimeParams(
46-
profile_conditions=profile_conditions_lib.ProfileConditions(
47-
set_pedestal=set_pedestal,
48-
),
49-
numerics=numerics_lib.Numerics(
50-
el_heat_eq=False,
51-
),
52-
)
53-
stepper_params = stepper_pydantic_model.Stepper.from_dict(
54-
dict(
55-
predictor_corrector=False,
56-
theta_imp=theta_imp,
57-
)
58-
)
59-
geo = geometry_pydantic_model.CircularConfig(
60-
n_rho=num_cells
61-
).build_geometry()
62-
63-
transport = transport_pydantic_model.Transport.from_dict(
64-
{'transport_model': 'constant', 'chimin': 0, 'chii_const': 1}
65-
)
66-
pedestal = pedestal_pydantic_model.Pedestal()
67-
pedestal_model = pedestal.build_pedestal_model()
68-
transport_model = transport.build_transport_model()
69-
sources = default_sources.get_default_sources()
70-
sources_dict = sources.to_dict()
71-
sources_dict['qei_source']['Qei_mult'] = 0.0
72-
sources_dict['generic_ion_el_heat_source']['Ptot'] = (
73-
0.0
74-
)
75-
sources_dict['fusion_heat_source']['mode'] = (
38+
sources_config = default_sources.get_default_source_config()
39+
sources_config['qei_source']['Qei_mult'] = 0.0
40+
sources_config['generic_ion_el_heat_source']['Ptot'] = 0.0
41+
sources_config['fusion_heat_source']['mode'] = (
7642
source_runtime_params.Mode.ZERO
7743
)
78-
sources_dict['ohmic_heat_source']['mode'] = (
44+
sources_config['ohmic_heat_source']['mode'] = (
7945
source_runtime_params.Mode.ZERO
8046
)
81-
sources = sources_pydantic_model.Sources.from_dict(sources_dict)
47+
torax_config = model_config.ToraxConfig.from_dict(
48+
dict(
49+
runtime_params=dict(
50+
profile_conditions=dict(set_pedestal=set_pedestal),
51+
numerics=dict(el_heat_eq=False),
52+
),
53+
geometry=dict(geometry_type='circular', n_rho=num_cells),
54+
pedestal=dict(),
55+
sources=sources_config,
56+
stepper=dict(predictor_corrector=False, theta_imp=theta_imp),
57+
transport=dict(transport_model='constant', chimin=0, chii_const=1),
58+
time_step_calculator=dict(),
59+
)
60+
)
8261
source_models = source_models_lib.SourceModels(
83-
sources=sources.source_model_config
62+
sources=torax_config.sources.source_model_config
8463
)
8564
dynamic_runtime_params_slice = (
8665
build_runtime_params.DynamicRuntimeParamsSliceProvider(
87-
runtime_params,
88-
transport=transport,
89-
sources=sources,
90-
stepper=stepper_params,
91-
pedestal=pedestal,
92-
torax_mesh=geo.torax_mesh,
66+
torax_config.runtime_params,
67+
transport=torax_config.transport,
68+
sources=torax_config.sources,
69+
stepper=torax_config.stepper,
70+
pedestal=torax_config.pedestal,
71+
torax_mesh=torax_config.geometry.build_provider.torax_mesh,
9372
)(
94-
t=runtime_params.numerics.t_initial,
73+
t=torax_config.numerics.t_initial,
9574
)
9675
)
9776
static_runtime_params_slice = (
9877
build_runtime_params.build_static_runtime_params_slice(
99-
runtime_params=runtime_params,
100-
sources=sources,
101-
torax_mesh=geo.torax_mesh,
102-
stepper=stepper_params,
78+
profile_conditions=torax_config.profile_conditions,
79+
numerics=torax_config.numerics,
80+
plasma_composition=torax_config.plasma_composition,
81+
sources=torax_config.sources,
82+
torax_mesh=torax_config.geometry.build_provider.torax_mesh,
83+
stepper=torax_config.stepper,
10384
)
10485
)
86+
geo = torax_config.geometry.build_provider(torax_config.numerics.t_initial)
10587
core_profiles = initialization.initial_core_profiles(
10688
static_runtime_params_slice,
10789
dynamic_runtime_params_slice,
@@ -117,6 +99,8 @@ def test_calc_coeffs_smoke_test(
11799
core_profiles=core_profiles,
118100
explicit=True,
119101
)
102+
pedestal_model = torax_config.pedestal.build_pedestal_model()
103+
transport_model = torax_config.transport.build_transport_model()
120104
calc_coeffs.calc_coeffs(
121105
static_runtime_params_slice=static_runtime_params_slice,
122106
dynamic_runtime_params_slice=dynamic_runtime_params_slice,

0 commit comments

Comments
 (0)