17
17
from absl import logging
18
18
import jax .numpy as jnp
19
19
from torax import output
20
+ from torax import post_processing
20
21
from torax import state
22
+ from torax .config import build_runtime_params
21
23
from torax .config import config_args
22
24
from torax .config import runtime_params_slice
23
25
from torax .core_profiles import initialization
24
26
from torax .geometry import geometry
27
+ from torax .geometry import geometry_provider as geometry_provider_lib
25
28
from torax .orchestration import step_function
26
29
from torax .sources import source_profile_builders
27
30
from torax .torax_pydantic import file_restart as file_restart_pydantic_model
28
31
import xarray as xr
29
32
30
33
31
- def get_initial_state (
34
+ def get_initial_state_and_post_processed_outputs (
35
+ t : float ,
36
+ static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
37
+ dynamic_runtime_params_slice_provider : build_runtime_params .DynamicRuntimeParamsSliceProvider ,
38
+ geometry_provider : geometry_provider_lib .GeometryProvider ,
39
+ step_fn : step_function .SimulationStepFn ,
40
+ ) -> tuple [state .ToraxSimState , state .PostProcessedOutputs ]:
41
+ """Returns the initial state and post processed outputs."""
42
+ dynamic_runtime_params_slice_for_init , geo_for_init = (
43
+ build_runtime_params .get_consistent_dynamic_runtime_params_slice_and_geometry (
44
+ t = t ,
45
+ dynamic_runtime_params_slice_provider = dynamic_runtime_params_slice_provider ,
46
+ geometry_provider = geometry_provider ,
47
+ )
48
+ )
49
+ initial_state = _get_initial_state (
50
+ static_runtime_params_slice = static_runtime_params_slice ,
51
+ dynamic_runtime_params_slice = dynamic_runtime_params_slice_for_init ,
52
+ geo = geo_for_init ,
53
+ step_fn = step_fn ,
54
+ )
55
+ post_processed_outputs = post_processing .make_post_processed_outputs (
56
+ initial_state ,
57
+ dynamic_runtime_params_slice_for_init ,
58
+ )
59
+ return initial_state , post_processed_outputs
60
+
61
+
62
+ def _get_initial_state (
32
63
static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
33
64
dynamic_runtime_params_slice : runtime_params_slice .DynamicRuntimeParamsSlice ,
34
65
geo : geometry .Geometry ,
@@ -59,7 +90,6 @@ def get_initial_state(
59
90
# This will be overridden within run_simulation().
60
91
core_sources = initial_core_sources ,
61
92
core_transport = state .CoreTransport .zeros (geo ),
62
- post_processed_outputs = state .PostProcessedOutputs .zeros (geo ),
63
93
stepper_numeric_outputs = state .StepperNumericOutputs (
64
94
stepper_error_state = 0 ,
65
95
outer_stepper_iterations = 0 ,
@@ -69,34 +99,39 @@ def get_initial_state(
69
99
)
70
100
71
101
72
- def initial_state_from_file_restart (
102
+ def get_initial_state_and_post_processed_outputs_from_file (
103
+ t_initial : float ,
73
104
file_restart : file_restart_pydantic_model .FileRestart ,
74
105
static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
75
- dynamic_runtime_params_slice_for_init : runtime_params_slice . DynamicRuntimeParamsSlice ,
76
- geo_for_init : geometry . Geometry ,
106
+ dynamic_runtime_params_slice_provider : build_runtime_params . DynamicRuntimeParamsSliceProvider ,
107
+ geometry_provider : geometry_provider_lib . GeometryProvider ,
77
108
step_fn : step_function .SimulationStepFn ,
78
- ) -> state .ToraxSimState :
79
- """Returns the initial state for a file restart ."""
109
+ ) -> tuple [ state .ToraxSimState , state . PostProcessedOutputs ] :
110
+ """Returns the initial state and post processed outputs from a file."""
80
111
data_tree = output .load_state_file (file_restart .filename )
81
112
# Find the closest time in the given dataset.
82
113
data_tree = data_tree .sel (time = file_restart .time , method = 'nearest' )
83
114
t_restart = data_tree .time .item ()
84
115
core_profiles_dataset = data_tree .children [output .CORE_PROFILES ].dataset
85
116
# Remap coordinates in saved file to be consistent with expectations of
86
117
# how config_args parses xarrays.
87
- core_profiles_dataset = core_profiles_dataset .rename (
88
- {output .RHO_CELL_NORM : config_args .RHO_NORM }
89
- )
90
118
core_profiles_dataset = core_profiles_dataset .squeeze ()
91
- if t_restart != dynamic_runtime_params_slice_for_init . numerics . t_initial :
119
+ if t_restart != t_initial :
92
120
logging .warning (
93
121
'Requested restart time %f not exactly available in state file %s.'
94
122
' Restarting from closest available time %f instead.' ,
95
123
file_restart .time ,
96
124
file_restart .filename ,
97
125
t_restart ,
98
126
)
99
- # Override some of dynamic runtime params slice from t=t_initial.
127
+
128
+ dynamic_runtime_params_slice_for_init , geo_for_init = (
129
+ build_runtime_params .get_consistent_dynamic_runtime_params_slice_and_geometry (
130
+ t = t_initial ,
131
+ dynamic_runtime_params_slice_provider = dynamic_runtime_params_slice_provider ,
132
+ geometry_provider = geometry_provider ,
133
+ )
134
+ )
100
135
dynamic_runtime_params_slice_for_init , geo_for_init = (
101
136
_override_initial_runtime_params_from_file (
102
137
dynamic_runtime_params_slice_for_init ,
@@ -105,36 +140,42 @@ def initial_state_from_file_restart(
105
140
core_profiles_dataset ,
106
141
)
107
142
)
143
+ initial_state = _get_initial_state (
144
+ static_runtime_params_slice = static_runtime_params_slice ,
145
+ dynamic_runtime_params_slice = dynamic_runtime_params_slice_for_init ,
146
+ geo = geo_for_init ,
147
+ step_fn = step_fn ,
148
+ )
108
149
post_processed_dataset = data_tree .children [
109
150
output .POST_PROCESSED_OUTPUTS
110
151
].dataset
111
152
post_processed_dataset = post_processed_dataset .rename (
112
153
{output .RHO_CELL_NORM : config_args .RHO_NORM }
113
154
)
114
155
post_processed_dataset = post_processed_dataset .squeeze ()
115
- post_processed_outputs = (
116
- _override_initial_state_post_processed_outputs_from_file (
117
- geo_for_init ,
118
- post_processed_dataset ,
119
- )
156
+ post_processed_outputs = post_processing .make_post_processed_outputs (
157
+ initial_state ,
158
+ dynamic_runtime_params_slice_for_init ,
120
159
)
121
-
122
- initial_state = get_initial_state (
123
- static_runtime_params_slice = static_runtime_params_slice ,
124
- dynamic_runtime_params_slice = dynamic_runtime_params_slice_for_init ,
125
- geo = geo_for_init ,
126
- step_fn = step_fn ,
160
+ post_processed_outputs = dataclasses .replace (
161
+ post_processed_outputs ,
162
+ E_cumulative_fusion = post_processed_dataset .data_vars [
163
+ 'E_cumulative_fusion'
164
+ ].to_numpy (),
165
+ E_cumulative_external = post_processed_dataset .data_vars [
166
+ 'E_cumulative_external'
167
+ ].to_numpy (),
127
168
)
128
- # In restarts we always know the initial vloop_lcfs so replace the
129
- # zeros initialization (for Ip BC case) from get_initial_state.
130
169
core_profiles = dataclasses .replace (
131
170
initial_state .core_profiles ,
132
171
vloop_lcfs = core_profiles_dataset .vloop_lcfs .values ,
133
172
)
134
- return dataclasses .replace (
135
- initial_state ,
136
- post_processed_outputs = post_processed_outputs ,
137
- core_profiles = core_profiles ,
173
+ return (
174
+ dataclasses .replace (
175
+ initial_state ,
176
+ core_profiles = core_profiles ,
177
+ ),
178
+ post_processed_outputs ,
138
179
)
139
180
140
181
@@ -190,17 +231,3 @@ def _override_initial_runtime_params_from_file(
190
231
)
191
232
192
233
return dynamic_runtime_params_slice , geo
193
-
194
-
195
- def _override_initial_state_post_processed_outputs_from_file (
196
- geo : geometry .Geometry ,
197
- ds : xr .Dataset ,
198
- ) -> state .PostProcessedOutputs :
199
- """Override parts of initial state post processed outputs from file."""
200
- post_processed_outputs = state .PostProcessedOutputs .zeros (geo )
201
- post_processed_outputs = dataclasses .replace (
202
- post_processed_outputs ,
203
- E_cumulative_fusion = ds .data_vars ['E_cumulative_fusion' ].to_numpy (),
204
- E_cumulative_external = ds .data_vars ['E_cumulative_external' ].to_numpy (),
205
- )
206
- return post_processed_outputs
0 commit comments