Skip to content

Commit 4d7ff98

Browse files
tamaranormanTorax team
authored and
Torax team
committed
Move postprocessing out of sim state
PiperOrigin-RevId: 743894950
1 parent 72c270b commit 4d7ff98

14 files changed

+246
-203
lines changed

Diff for: torax/orchestration/initial_state.py

+70-43
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,49 @@
1717
from absl import logging
1818
import jax.numpy as jnp
1919
from torax import output
20+
from torax import post_processing
2021
from torax import state
22+
from torax.config import build_runtime_params
2123
from torax.config import config_args
2224
from torax.config import runtime_params_slice
2325
from torax.core_profiles import initialization
2426
from torax.geometry import geometry
27+
from torax.geometry import geometry_provider as geometry_provider_lib
2528
from torax.orchestration import step_function
2629
from torax.sources import source_profile_builders
2730
from torax.torax_pydantic import file_restart as file_restart_pydantic_model
2831
import xarray as xr
2932

3033

31-
def get_initial_state(
34+
def get_initial_state_and_post_processed_outputs(
35+
t: float,
36+
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
37+
dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider,
38+
geometry_provider: geometry_provider_lib.GeometryProvider,
39+
step_fn: step_function.SimulationStepFn,
40+
) -> tuple[state.ToraxSimState, state.PostProcessedOutputs]:
41+
"""Returns the initial state and post processed outputs."""
42+
dynamic_runtime_params_slice_for_init, geo_for_init = (
43+
build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry(
44+
t=t,
45+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
46+
geometry_provider=geometry_provider,
47+
)
48+
)
49+
initial_state = _get_initial_state(
50+
static_runtime_params_slice=static_runtime_params_slice,
51+
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
52+
geo=geo_for_init,
53+
step_fn=step_fn,
54+
)
55+
post_processed_outputs = post_processing.make_post_processed_outputs(
56+
initial_state,
57+
dynamic_runtime_params_slice_for_init,
58+
)
59+
return initial_state, post_processed_outputs
60+
61+
62+
def _get_initial_state(
3263
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
3364
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
3465
geo: geometry.Geometry,
@@ -59,7 +90,6 @@ def get_initial_state(
5990
# This will be overridden within run_simulation().
6091
core_sources=initial_core_sources,
6192
core_transport=state.CoreTransport.zeros(geo),
62-
post_processed_outputs=state.PostProcessedOutputs.zeros(geo),
6393
stepper_numeric_outputs=state.StepperNumericOutputs(
6494
stepper_error_state=0,
6595
outer_stepper_iterations=0,
@@ -69,34 +99,39 @@ def get_initial_state(
6999
)
70100

71101

72-
def initial_state_from_file_restart(
102+
def get_initial_state_and_post_processed_outputs_from_file(
103+
t_initial: float,
73104
file_restart: file_restart_pydantic_model.FileRestart,
74105
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
75-
dynamic_runtime_params_slice_for_init: runtime_params_slice.DynamicRuntimeParamsSlice,
76-
geo_for_init: geometry.Geometry,
106+
dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider,
107+
geometry_provider: geometry_provider_lib.GeometryProvider,
77108
step_fn: step_function.SimulationStepFn,
78-
) -> state.ToraxSimState:
79-
"""Returns the initial state for a file restart."""
109+
) -> tuple[state.ToraxSimState, state.PostProcessedOutputs]:
110+
"""Returns the initial state and post processed outputs from a file."""
80111
data_tree = output.load_state_file(file_restart.filename)
81112
# Find the closest time in the given dataset.
82113
data_tree = data_tree.sel(time=file_restart.time, method='nearest')
83114
t_restart = data_tree.time.item()
84115
core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset
85116
# Remap coordinates in saved file to be consistent with expectations of
86117
# how config_args parses xarrays.
87-
core_profiles_dataset = core_profiles_dataset.rename(
88-
{output.RHO_CELL_NORM: config_args.RHO_NORM}
89-
)
90118
core_profiles_dataset = core_profiles_dataset.squeeze()
91-
if t_restart != dynamic_runtime_params_slice_for_init.numerics.t_initial:
119+
if t_restart != t_initial:
92120
logging.warning(
93121
'Requested restart time %f not exactly available in state file %s.'
94122
' Restarting from closest available time %f instead.',
95123
file_restart.time,
96124
file_restart.filename,
97125
t_restart,
98126
)
99-
# Override some of dynamic runtime params slice from t=t_initial.
127+
128+
dynamic_runtime_params_slice_for_init, geo_for_init = (
129+
build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry(
130+
t=t_initial,
131+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
132+
geometry_provider=geometry_provider,
133+
)
134+
)
100135
dynamic_runtime_params_slice_for_init, geo_for_init = (
101136
_override_initial_runtime_params_from_file(
102137
dynamic_runtime_params_slice_for_init,
@@ -105,36 +140,42 @@ def initial_state_from_file_restart(
105140
core_profiles_dataset,
106141
)
107142
)
143+
initial_state = _get_initial_state(
144+
static_runtime_params_slice=static_runtime_params_slice,
145+
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
146+
geo=geo_for_init,
147+
step_fn=step_fn,
148+
)
108149
post_processed_dataset = data_tree.children[
109150
output.POST_PROCESSED_OUTPUTS
110151
].dataset
111152
post_processed_dataset = post_processed_dataset.rename(
112153
{output.RHO_CELL_NORM: config_args.RHO_NORM}
113154
)
114155
post_processed_dataset = post_processed_dataset.squeeze()
115-
post_processed_outputs = (
116-
_override_initial_state_post_processed_outputs_from_file(
117-
geo_for_init,
118-
post_processed_dataset,
119-
)
156+
post_processed_outputs = post_processing.make_post_processed_outputs(
157+
initial_state,
158+
dynamic_runtime_params_slice_for_init,
120159
)
121-
122-
initial_state = get_initial_state(
123-
static_runtime_params_slice=static_runtime_params_slice,
124-
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
125-
geo=geo_for_init,
126-
step_fn=step_fn,
160+
post_processed_outputs = dataclasses.replace(
161+
post_processed_outputs,
162+
E_cumulative_fusion=post_processed_dataset.data_vars[
163+
'E_cumulative_fusion'
164+
].to_numpy(),
165+
E_cumulative_external=post_processed_dataset.data_vars[
166+
'E_cumulative_external'
167+
].to_numpy(),
127168
)
128-
# In restarts we always know the initial vloop_lcfs so replace the
129-
# zeros initialization (for Ip BC case) from get_initial_state.
130169
core_profiles = dataclasses.replace(
131170
initial_state.core_profiles,
132171
vloop_lcfs=core_profiles_dataset.vloop_lcfs.values,
133172
)
134-
return dataclasses.replace(
135-
initial_state,
136-
post_processed_outputs=post_processed_outputs,
137-
core_profiles=core_profiles,
173+
return (
174+
dataclasses.replace(
175+
initial_state,
176+
core_profiles=core_profiles,
177+
),
178+
post_processed_outputs,
138179
)
139180

140181

@@ -190,17 +231,3 @@ def _override_initial_runtime_params_from_file(
190231
)
191232

192233
return dynamic_runtime_params_slice, geo
193-
194-
195-
def _override_initial_state_post_processed_outputs_from_file(
196-
geo: geometry.Geometry,
197-
ds: xr.Dataset,
198-
) -> state.PostProcessedOutputs:
199-
"""Override parts of initial state post processed outputs from file."""
200-
post_processed_outputs = state.PostProcessedOutputs.zeros(geo)
201-
post_processed_outputs = dataclasses.replace(
202-
post_processed_outputs,
203-
E_cumulative_fusion=ds.data_vars['E_cumulative_fusion'].to_numpy(),
204-
E_cumulative_external=ds.data_vars['E_cumulative_external'].to_numpy(),
205-
)
206-
return post_processed_outputs

Diff for: torax/orchestration/run_simulation.py

+20-20
Original file line numberDiff line numberDiff line change
@@ -86,37 +86,36 @@ def run_simulation(
8686
)
8787
)
8888

89-
dynamic_runtime_params_slice_for_init, geo_for_init = (
90-
build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry(
91-
t=torax_config.runtime_params.numerics.t_initial,
92-
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
93-
geometry_provider=geometry_provider,
94-
)
95-
)
96-
9789
if torax_config.restart and torax_config.restart.do_restart:
98-
initial_state = initial_state_lib.initial_state_from_file_restart(
99-
file_restart=torax_config.restart,
100-
static_runtime_params_slice=static_runtime_params_slice,
101-
dynamic_runtime_params_slice_for_init=dynamic_runtime_params_slice_for_init,
102-
geo_for_init=geo_for_init,
103-
step_fn=step_fn,
90+
initial_state, post_processed_outputs = (
91+
initial_state_lib.get_initial_state_and_post_processed_outputs_from_file(
92+
t_initial=torax_config.runtime_params.numerics.t_initial,
93+
file_restart=torax_config.restart,
94+
static_runtime_params_slice=static_runtime_params_slice,
95+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
96+
geometry_provider=geometry_provider,
97+
step_fn=step_fn,
98+
)
10499
)
105100
restart_case = True
106101
else:
107-
initial_state = initial_state_lib.get_initial_state(
108-
static_runtime_params_slice=static_runtime_params_slice,
109-
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
110-
geo=geo_for_init,
111-
step_fn=step_fn,
102+
initial_state, post_processed_outputs = (
103+
initial_state_lib.get_initial_state_and_post_processed_outputs(
104+
t=torax_config.runtime_params.numerics.t_initial,
105+
static_runtime_params_slice=static_runtime_params_slice,
106+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
107+
geometry_provider=geometry_provider,
108+
step_fn=step_fn,
109+
)
112110
)
113111
restart_case = False
114112

115-
state_history, sim_error = sim._run_simulation( # pylint: disable=protected-access
113+
state_history, post_processed_outputs_history, sim_error = sim._run_simulation( # pylint: disable=protected-access
116114
static_runtime_params_slice=static_runtime_params_slice,
117115
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
118116
geometry_provider=geometry_provider,
119117
initial_state=initial_state,
118+
initial_post_processed_outputs=post_processed_outputs,
120119
restart_case=restart_case,
121120
step_fn=step_fn,
122121
log_timestep_info=log_timestep_info,
@@ -125,6 +124,7 @@ def run_simulation(
125124

126125
return output.StateHistory(
127126
state_history=state_history,
127+
post_processed_outputs_history=post_processed_outputs_history,
128128
sim_error=sim_error,
129129
source_models=source_models,
130130
torax_config=torax_config,

Diff for: torax/orchestration/step_function.py

+11-8
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def __call__(
118118
dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider,
119119
geometry_provider: geometry_provider_lib.GeometryProvider,
120120
input_state: state.ToraxSimState,
121-
) -> tuple[state.ToraxSimState, state.SimError]:
121+
previous_post_processed_outputs: state.PostProcessedOutputs,
122+
) -> tuple[state.ToraxSimState, state.PostProcessedOutputs, state.SimError]:
122123
"""Advances the simulation state one time step.
123124
124125
Args:
@@ -138,6 +139,8 @@ def __call__(
138139
(in order to support time-dependent geometries).
139140
input_state: State at the start of the time step, including the core
140141
profiles which are being evolved.
142+
previous_post_processed_outputs: Post-processed outputs from the previous
143+
time step.
141144
142145
Returns:
143146
ToraxSimState containing:
@@ -150,6 +153,9 @@ def __call__(
150153
1 if solver did not converge for this step (was above coarse tol)
151154
2 if solver converged within coarse tolerance. Allowed to pass with
152155
a warning. Occasional error=2 has low impact on final sim state.
156+
PostProcessedOutputs containing:
157+
- post-processed outputs at the end of the time step.
158+
- cumulative quantities.
153159
SimError indicating if an error has occurred during simulation.
154160
"""
155161
dynamic_runtime_params_slice_t, geo_t = (
@@ -217,13 +223,14 @@ def __call__(
217223
explicit_source_profiles,
218224
)
219225

220-
output_state = post_processing.make_outputs(
226+
post_processed_outputs = post_processing.make_post_processed_outputs(
221227
sim_state=output_state,
222228
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t_plus_dt,
223-
previous_sim_state=input_state,
229+
previous_post_processed_outputs=previous_post_processed_outputs,
224230
)
225231

226-
return output_state, output_state.check_for_errors()
232+
return output_state, post_processed_outputs, state.check_for_errors(
233+
output_state, post_processed_outputs)
227234

228235
def init_time_step_calculator(
229236
self,
@@ -364,7 +371,6 @@ def step(
364371
core_profiles=core_profiles,
365372
core_transport=core_transport,
366373
core_sources=core_sources,
367-
post_processed_outputs=state.PostProcessedOutputs.zeros(geo_t_plus_dt),
368374
stepper_numeric_outputs=stepper_numeric_outputs,
369375
geometry=geo_t_plus_dt,
370376
)
@@ -492,9 +498,6 @@ def body_fun(
492498
core_profiles=core_profiles,
493499
core_transport=core_transport,
494500
core_sources=core_sources,
495-
post_processed_outputs=state.PostProcessedOutputs.zeros(
496-
geo_t_plus_dt
497-
),
498501
stepper_numeric_outputs=stepper_numeric_outputs,
499502
geometry=geo_t_plus_dt,
500503
),

0 commit comments

Comments
 (0)