|
26 | 26 |
|
27 | 27 | import dataclasses
|
28 | 28 | import time
|
29 |
| -from typing import Optional |
30 | 29 |
|
31 | 30 | from absl import logging
|
32 | 31 | import jax
|
|
36 | 35 | from torax import post_processing
|
37 | 36 | from torax import state
|
38 | 37 | from torax.config import build_runtime_params
|
39 |
| -from torax.config import config_args |
40 |
| -from torax.config import runtime_params as general_runtime_params |
41 | 38 | from torax.config import runtime_params_slice
|
42 | 39 | from torax.core_profiles import initialization
|
43 | 40 | from torax.geometry import geometry
|
44 | 41 | from torax.geometry import geometry_provider as geometry_provider_lib
|
45 | 42 | from torax.orchestration import step_function
|
46 |
| -from torax.pedestal_model import pedestal_model as pedestal_model_lib |
47 |
| -from torax.pedestal_model import pydantic_model as pedestal_pydantic_model |
48 |
| -from torax.sources import pydantic_model as source_pydantic_model |
49 |
| -from torax.sources import source_models as source_models_lib |
50 | 43 | from torax.sources import source_profile_builders
|
51 |
| -from torax.stepper import pydantic_model as stepper_pydantic_model |
52 |
| -from torax.stepper import stepper as stepper_lib |
53 |
| -from torax.time_step_calculator import chi_time_step_calculator |
54 |
| -from torax.time_step_calculator import time_step_calculator as ts |
55 |
| -from torax.torax_pydantic import file_restart as file_restart_pydantic_model |
56 |
| -from torax.transport_model import pydantic_model as transport_model_pydantic_model |
57 |
| -from torax.transport_model import transport_model as transport_model_lib |
58 | 44 | import tqdm
|
59 |
| -import typing_extensions |
60 | 45 | import xarray as xr
|
61 | 46 |
|
62 | 47 |
|
@@ -102,347 +87,6 @@ def get_initial_state(
|
102 | 87 | )
|
103 | 88 |
|
104 | 89 |
|
105 |
| -class Sim: |
106 |
| - """A lightweight object holding all components of a simulation. |
107 |
| -
|
108 |
| - Use of this object is optional, it is also fine to hold these objects |
109 |
| - in local variables of a script and call `run_simulation` directly. |
110 |
| -
|
111 |
| - The main purpose of the Sim object is to enable configuration via |
112 |
| - constructor arguments. Components are reused in subsequent simulation runs, so |
113 |
| - if a component is compiled, it will be reused for the next `Sim.run()` call |
114 |
| - and will not be recompiled unless a static argument or shape changes. |
115 |
| - """ |
116 |
| - |
117 |
| - def __init__( |
118 |
| - self, |
119 |
| - static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice, |
120 |
| - dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider, |
121 |
| - geometry_provider: geometry_provider_lib.GeometryProvider, |
122 |
| - initial_state: state.ToraxSimState, |
123 |
| - step_fn: step_function.SimulationStepFn, |
124 |
| - file_restart: file_restart_pydantic_model.FileRestart | None = None, |
125 |
| - ): |
126 |
| - self._static_runtime_params_slice = static_runtime_params_slice |
127 |
| - self._dynamic_runtime_params_slice_provider = ( |
128 |
| - dynamic_runtime_params_slice_provider |
129 |
| - ) |
130 |
| - self._geometry_provider = geometry_provider |
131 |
| - self._initial_state = initial_state |
132 |
| - self._step_fn = step_fn |
133 |
| - self._file_restart = file_restart |
134 |
| - |
135 |
| - @property |
136 |
| - def file_restart(self) -> file_restart_pydantic_model.FileRestart | None: |
137 |
| - return self._file_restart |
138 |
| - |
139 |
| - @property |
140 |
| - def time_step_calculator(self) -> ts.TimeStepCalculator: |
141 |
| - return self._step_fn.time_step_calculator |
142 |
| - |
143 |
| - @property |
144 |
| - def initial_state(self) -> state.ToraxSimState: |
145 |
| - return self._initial_state |
146 |
| - |
147 |
| - @property |
148 |
| - def geometry_provider(self) -> geometry_provider_lib.GeometryProvider: |
149 |
| - return self._geometry_provider |
150 |
| - |
151 |
| - @property |
152 |
| - def dynamic_runtime_params_slice_provider( |
153 |
| - self, |
154 |
| - ) -> build_runtime_params.DynamicRuntimeParamsSliceProvider: |
155 |
| - return self._dynamic_runtime_params_slice_provider |
156 |
| - |
157 |
| - @property |
158 |
| - def static_runtime_params_slice( |
159 |
| - self, |
160 |
| - ) -> runtime_params_slice.StaticRuntimeParamsSlice: |
161 |
| - return self._static_runtime_params_slice |
162 |
| - |
163 |
| - @property |
164 |
| - def step_fn(self) -> step_function.SimulationStepFn: |
165 |
| - return self._step_fn |
166 |
| - |
167 |
| - @property |
168 |
| - def stepper(self) -> stepper_lib.Stepper: |
169 |
| - return self._step_fn.stepper |
170 |
| - |
171 |
| - @property |
172 |
| - def transport_model(self) -> transport_model_lib.TransportModel: |
173 |
| - return self.stepper.transport_model |
174 |
| - |
175 |
| - @property |
176 |
| - def pedestal_model(self) -> pedestal_model_lib.PedestalModel: |
177 |
| - return self.stepper.pedestal_model |
178 |
| - |
179 |
| - @property |
180 |
| - def source_models(self) -> source_models_lib.SourceModels: |
181 |
| - return self.stepper.source_models |
182 |
| - |
183 |
| - def update_base_components( |
184 |
| - self, |
185 |
| - *, |
186 |
| - allow_recompilation: bool = False, |
187 |
| - static_runtime_params_slice: ( |
188 |
| - runtime_params_slice.StaticRuntimeParamsSlice | None |
189 |
| - ) = None, |
190 |
| - dynamic_runtime_params_slice_provider: ( |
191 |
| - build_runtime_params.DynamicRuntimeParamsSliceProvider | None |
192 |
| - ) = None, |
193 |
| - geometry_provider: geometry_provider_lib.GeometryProvider | None = None, |
194 |
| - ): |
195 |
| - """Updates the Sim object with components that have already been updated. |
196 |
| -
|
197 |
| - Currently this only supports updating the geometry provider and the dynamic |
198 |
| - runtime params slice provider, both of which can be updated without |
199 |
| - recompilation. |
200 |
| -
|
201 |
| - Args: |
202 |
| - allow_recompilation: Whether recompilation is allowed. If True, the static |
203 |
| - runtime params slice can be updated. NOTE: recompilaton may still occur |
204 |
| - if the mesh is updated or if the shapes returned in the dynamic runtime |
205 |
| - params slice provider change even if this is False. |
206 |
| - static_runtime_params_slice: The new static runtime params slice. If None, |
207 |
| - the existing one is kept. |
208 |
| - dynamic_runtime_params_slice_provider: The new dynamic runtime params |
209 |
| - slice provider. This should already have been updated with modifications |
210 |
| - to the various components. If None, the existing one is kept. |
211 |
| - geometry_provider: The new geometry provider. If None, the existing one is |
212 |
| - kept. |
213 |
| -
|
214 |
| - Raises: |
215 |
| - ValueError: If the Sim object has a file restart or if the geometry |
216 |
| - provider has a different mesh than the existing one. |
217 |
| - """ |
218 |
| - if self._file_restart is not None: |
219 |
| - # TODO(b/384767453): Add support for updating a Sim object with a file |
220 |
| - # restart. |
221 |
| - raise ValueError('Cannot update a Sim object with a file restart.') |
222 |
| - if not allow_recompilation and static_runtime_params_slice is not None: |
223 |
| - raise ValueError( |
224 |
| - 'Cannot update a Sim object with a static runtime params slice if ' |
225 |
| - 'recompilation is not allowed.' |
226 |
| - ) |
227 |
| - |
228 |
| - if static_runtime_params_slice is not None: |
229 |
| - assert isinstance( # Avoid pytype error. |
230 |
| - self._static_runtime_params_slice, |
231 |
| - runtime_params_slice.StaticRuntimeParamsSlice, |
232 |
| - ) |
233 |
| - self._static_runtime_params_slice.validate_new( |
234 |
| - static_runtime_params_slice |
235 |
| - ) |
236 |
| - self._static_runtime_params_slice = static_runtime_params_slice |
237 |
| - if dynamic_runtime_params_slice_provider is not None: |
238 |
| - assert isinstance( # Avoid pytype error. |
239 |
| - self._dynamic_runtime_params_slice_provider, |
240 |
| - build_runtime_params.DynamicRuntimeParamsSliceProvider, |
241 |
| - ) |
242 |
| - self._dynamic_runtime_params_slice_provider.validate_new( |
243 |
| - dynamic_runtime_params_slice_provider |
244 |
| - ) |
245 |
| - self._dynamic_runtime_params_slice_provider = ( |
246 |
| - dynamic_runtime_params_slice_provider |
247 |
| - ) |
248 |
| - if geometry_provider is not None: |
249 |
| - if geometry_provider.torax_mesh != self._geometry_provider.torax_mesh: |
250 |
| - raise ValueError( |
251 |
| - 'Cannot update a Sim object with a geometry provider with a ' |
252 |
| - 'different mesh.' |
253 |
| - ) |
254 |
| - self._geometry_provider = geometry_provider |
255 |
| - |
256 |
| - dynamic_runtime_params_slice_for_init, geo_for_init = ( |
257 |
| - build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry( |
258 |
| - t=self._dynamic_runtime_params_slice_provider._runtime_params.numerics.t_initial, # pylint: disable=protected-access |
259 |
| - dynamic_runtime_params_slice_provider=self._dynamic_runtime_params_slice_provider, |
260 |
| - geometry_provider=self._geometry_provider, |
261 |
| - ) |
262 |
| - ) |
263 |
| - self._initial_state = get_initial_state( |
264 |
| - static_runtime_params_slice=self._static_runtime_params_slice, |
265 |
| - dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init, |
266 |
| - geo=geo_for_init, |
267 |
| - step_fn=self._step_fn, |
268 |
| - ) |
269 |
| - |
270 |
| - def run( |
271 |
| - self, |
272 |
| - log_timestep_info: bool = False, |
273 |
| - ) -> output.ToraxSimOutputs: |
274 |
| - """Runs the transport simulation over a prescribed time interval. |
275 |
| -
|
276 |
| - See `run_simulation` for details. |
277 |
| -
|
278 |
| - Args: |
279 |
| - log_timestep_info: See `run_simulation()`. |
280 |
| -
|
281 |
| - Returns: |
282 |
| - Tuple of all ToraxSimStates, one per time step and an additional one at |
283 |
| - the beginning for the starting state. |
284 |
| - """ |
285 |
| - return _run_simulation( |
286 |
| - static_runtime_params_slice=self.static_runtime_params_slice, |
287 |
| - dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider, |
288 |
| - geometry_provider=self.geometry_provider, |
289 |
| - initial_state=self.initial_state, |
290 |
| - step_fn=self.step_fn, |
291 |
| - log_timestep_info=log_timestep_info, |
292 |
| - ) |
293 |
| - |
294 |
| - @classmethod |
295 |
| - def create( |
296 |
| - cls, |
297 |
| - *, |
298 |
| - runtime_params: general_runtime_params.GeneralRuntimeParams, |
299 |
| - geometry_provider: geometry_provider_lib.GeometryProvider, |
300 |
| - stepper: stepper_pydantic_model.Stepper, |
301 |
| - transport_model: transport_model_pydantic_model.Transport, |
302 |
| - sources: source_pydantic_model.Sources, |
303 |
| - pedestal: pedestal_pydantic_model.Pedestal, |
304 |
| - time_step_calculator: Optional[ts.TimeStepCalculator] = None, |
305 |
| - file_restart: file_restart_pydantic_model.FileRestart | None = None, |
306 |
| - ) -> typing_extensions.Self: |
307 |
| - """Builds a Sim object from the input runtime params and sim components. |
308 |
| -
|
309 |
| - Args: |
310 |
| - runtime_params: The input runtime params used throughout the simulation |
311 |
| - run. |
312 |
| - geometry_provider: The geometry used throughout the simulation run. |
313 |
| - stepper: The stepper config that can be used to build the stepper. |
314 |
| - transport_model: The transport model config that can be used to build the |
315 |
| - transport model. |
316 |
| - sources: Builds the sources. |
317 |
| - pedestal: The pedestal config that can be used to build the pedestal. |
318 |
| - time_step_calculator: The time_step_calculator, if built, otherwise a |
319 |
| - ChiTimeStepCalculator will be built by default. |
320 |
| - file_restart: If provided we will reconstruct the initial state from the |
321 |
| - provided file at the given time step. This state from the file will only |
322 |
| - be used for constructing the initial state (as well as the config) and |
323 |
| - for all subsequent steps, the evolved state and runtime parameters from |
324 |
| - config are used. |
325 |
| -
|
326 |
| - Returns: |
327 |
| - sim: The built Sim instance. |
328 |
| - """ |
329 |
| - pedestal_model = pedestal.build_pedestal_model() |
330 |
| - |
331 |
| - # TODO(b/385788907): Document all changes that lead to recompilations. |
332 |
| - static_runtime_params_slice = ( |
333 |
| - build_runtime_params.build_static_runtime_params_slice( |
334 |
| - runtime_params=runtime_params, |
335 |
| - sources=sources, |
336 |
| - torax_mesh=geometry_provider.torax_mesh, |
337 |
| - stepper=stepper, |
338 |
| - ) |
339 |
| - ) |
340 |
| - dynamic_runtime_params_slice_provider = ( |
341 |
| - build_runtime_params.DynamicRuntimeParamsSliceProvider( |
342 |
| - runtime_params=runtime_params, |
343 |
| - transport=transport_model, |
344 |
| - sources=sources, |
345 |
| - stepper=stepper, |
346 |
| - torax_mesh=geometry_provider.torax_mesh, |
347 |
| - pedestal=pedestal, |
348 |
| - ) |
349 |
| - ) |
350 |
| - source_models = source_models_lib.SourceModels( |
351 |
| - sources=sources.source_model_config |
352 |
| - ) |
353 |
| - transport_model = transport_model.build_transport_model() |
354 |
| - stepper_model = stepper.build_stepper_model( |
355 |
| - transport_model=transport_model, |
356 |
| - source_models=source_models, |
357 |
| - pedestal_model=pedestal_model, |
358 |
| - ) |
359 |
| - |
360 |
| - if time_step_calculator is None: |
361 |
| - time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator() |
362 |
| - |
363 |
| - # Build dynamic_runtime_params_slice at t_initial for initial conditions. |
364 |
| - dynamic_runtime_params_slice_for_init, geo_for_init = ( |
365 |
| - build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry( |
366 |
| - t=runtime_params.numerics.t_initial, |
367 |
| - dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, |
368 |
| - geometry_provider=geometry_provider, |
369 |
| - ) |
370 |
| - ) |
371 |
| - if file_restart is not None and file_restart.do_restart: |
372 |
| - data_tree = output.load_state_file(file_restart.filename) |
373 |
| - # Find the closest time in the given dataset. |
374 |
| - data_tree = data_tree.sel(time=file_restart.time, method='nearest') |
375 |
| - t_restart = data_tree.time.item() |
376 |
| - core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset |
377 |
| - # Remap coordinates in saved file to be consistent with expectations of |
378 |
| - # how config_args parses xarrays. |
379 |
| - core_profiles_dataset = core_profiles_dataset.rename( |
380 |
| - {output.RHO_CELL_NORM: config_args.RHO_NORM} |
381 |
| - ) |
382 |
| - core_profiles_dataset = core_profiles_dataset.squeeze() |
383 |
| - if t_restart != runtime_params.numerics.t_initial: |
384 |
| - logging.warning( |
385 |
| - 'Requested restart time %f not exactly available in state file %s.' |
386 |
| - ' Restarting from closest available time %f instead.', |
387 |
| - file_restart.time, |
388 |
| - file_restart.filename, |
389 |
| - t_restart, |
390 |
| - ) |
391 |
| - # Override some of dynamic runtime params slice from t=t_initial. |
392 |
| - dynamic_runtime_params_slice_for_init, geo_for_init = ( |
393 |
| - _override_initial_runtime_params_from_file( |
394 |
| - dynamic_runtime_params_slice_for_init, |
395 |
| - geo_for_init, |
396 |
| - t_restart, |
397 |
| - core_profiles_dataset, |
398 |
| - ) |
399 |
| - ) |
400 |
| - post_processed_dataset = data_tree.children[ |
401 |
| - output.POST_PROCESSED_OUTPUTS |
402 |
| - ].dataset |
403 |
| - post_processed_dataset = post_processed_dataset.rename( |
404 |
| - {output.RHO_CELL_NORM: config_args.RHO_NORM} |
405 |
| - ) |
406 |
| - post_processed_dataset = post_processed_dataset.squeeze() |
407 |
| - post_processed_outputs = ( |
408 |
| - _override_initial_state_post_processed_outputs_from_file( |
409 |
| - geo_for_init, |
410 |
| - post_processed_dataset, |
411 |
| - ) |
412 |
| - ) |
413 |
| - |
414 |
| - step_fn = step_function.SimulationStepFn( |
415 |
| - stepper=stepper_model, |
416 |
| - time_step_calculator=time_step_calculator, |
417 |
| - transport_model=transport_model, |
418 |
| - pedestal_model=pedestal_model, |
419 |
| - ) |
420 |
| - |
421 |
| - initial_state = get_initial_state( |
422 |
| - static_runtime_params_slice=static_runtime_params_slice, |
423 |
| - dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init, |
424 |
| - geo=geo_for_init, |
425 |
| - step_fn=step_fn, |
426 |
| - ) |
427 |
| - |
428 |
| - # If we are restarting from a file, we need to override the initial state |
429 |
| - # post processed outputs such that cumulative outputs remain correct. |
430 |
| - if file_restart is not None and file_restart.do_restart: |
431 |
| - initial_state = dataclasses.replace( |
432 |
| - initial_state, |
433 |
| - post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable |
434 |
| - ) |
435 |
| - |
436 |
| - return cls( |
437 |
| - static_runtime_params_slice=static_runtime_params_slice, |
438 |
| - dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider, |
439 |
| - geometry_provider=geometry_provider, |
440 |
| - initial_state=initial_state, |
441 |
| - step_fn=step_fn, |
442 |
| - file_restart=file_restart, |
443 |
| - ) |
444 |
| - |
445 |
| - |
446 | 90 | def _override_initial_runtime_params_from_file(
|
447 | 91 | dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
|
448 | 92 | geo: geometry.Geometry,
|
|
0 commit comments