Skip to content

Commit 17ca3aa

Browse files
tamaranormanTorax team
authored and
Torax team
committed
Delete sim.Sim which has been replaced by run_simulation
PiperOrigin-RevId: 738807490
1 parent e795fb7 commit 17ca3aa

File tree

1 file changed

+0
-356
lines changed

1 file changed

+0
-356
lines changed

Diff for: torax/sim.py

-356
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
import dataclasses
2828
import time
29-
from typing import Optional
3029

3130
from absl import logging
3231
import jax
@@ -36,27 +35,13 @@
3635
from torax import post_processing
3736
from torax import state
3837
from torax.config import build_runtime_params
39-
from torax.config import config_args
40-
from torax.config import runtime_params as general_runtime_params
4138
from torax.config import runtime_params_slice
4239
from torax.core_profiles import initialization
4340
from torax.geometry import geometry
4441
from torax.geometry import geometry_provider as geometry_provider_lib
4542
from torax.orchestration import step_function
46-
from torax.pedestal_model import pedestal_model as pedestal_model_lib
47-
from torax.pedestal_model import pydantic_model as pedestal_pydantic_model
48-
from torax.sources import pydantic_model as source_pydantic_model
49-
from torax.sources import source_models as source_models_lib
5043
from torax.sources import source_profile_builders
51-
from torax.stepper import pydantic_model as stepper_pydantic_model
52-
from torax.stepper import stepper as stepper_lib
53-
from torax.time_step_calculator import chi_time_step_calculator
54-
from torax.time_step_calculator import time_step_calculator as ts
55-
from torax.torax_pydantic import file_restart as file_restart_pydantic_model
56-
from torax.transport_model import pydantic_model as transport_model_pydantic_model
57-
from torax.transport_model import transport_model as transport_model_lib
5844
import tqdm
59-
import typing_extensions
6045
import xarray as xr
6146

6247

@@ -102,347 +87,6 @@ def get_initial_state(
10287
)
10388

10489

105-
class Sim:
106-
"""A lightweight object holding all components of a simulation.
107-
108-
Use of this object is optional, it is also fine to hold these objects
109-
in local variables of a script and call `run_simulation` directly.
110-
111-
The main purpose of the Sim object is to enable configuration via
112-
constructor arguments. Components are reused in subsequent simulation runs, so
113-
if a component is compiled, it will be reused for the next `Sim.run()` call
114-
and will not be recompiled unless a static argument or shape changes.
115-
"""
116-
117-
def __init__(
118-
self,
119-
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
120-
dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider,
121-
geometry_provider: geometry_provider_lib.GeometryProvider,
122-
initial_state: state.ToraxSimState,
123-
step_fn: step_function.SimulationStepFn,
124-
file_restart: file_restart_pydantic_model.FileRestart | None = None,
125-
):
126-
self._static_runtime_params_slice = static_runtime_params_slice
127-
self._dynamic_runtime_params_slice_provider = (
128-
dynamic_runtime_params_slice_provider
129-
)
130-
self._geometry_provider = geometry_provider
131-
self._initial_state = initial_state
132-
self._step_fn = step_fn
133-
self._file_restart = file_restart
134-
135-
@property
136-
def file_restart(self) -> file_restart_pydantic_model.FileRestart | None:
137-
return self._file_restart
138-
139-
@property
140-
def time_step_calculator(self) -> ts.TimeStepCalculator:
141-
return self._step_fn.time_step_calculator
142-
143-
@property
144-
def initial_state(self) -> state.ToraxSimState:
145-
return self._initial_state
146-
147-
@property
148-
def geometry_provider(self) -> geometry_provider_lib.GeometryProvider:
149-
return self._geometry_provider
150-
151-
@property
152-
def dynamic_runtime_params_slice_provider(
153-
self,
154-
) -> build_runtime_params.DynamicRuntimeParamsSliceProvider:
155-
return self._dynamic_runtime_params_slice_provider
156-
157-
@property
158-
def static_runtime_params_slice(
159-
self,
160-
) -> runtime_params_slice.StaticRuntimeParamsSlice:
161-
return self._static_runtime_params_slice
162-
163-
@property
164-
def step_fn(self) -> step_function.SimulationStepFn:
165-
return self._step_fn
166-
167-
@property
168-
def stepper(self) -> stepper_lib.Stepper:
169-
return self._step_fn.stepper
170-
171-
@property
172-
def transport_model(self) -> transport_model_lib.TransportModel:
173-
return self.stepper.transport_model
174-
175-
@property
176-
def pedestal_model(self) -> pedestal_model_lib.PedestalModel:
177-
return self.stepper.pedestal_model
178-
179-
@property
180-
def source_models(self) -> source_models_lib.SourceModels:
181-
return self.stepper.source_models
182-
183-
def update_base_components(
184-
self,
185-
*,
186-
allow_recompilation: bool = False,
187-
static_runtime_params_slice: (
188-
runtime_params_slice.StaticRuntimeParamsSlice | None
189-
) = None,
190-
dynamic_runtime_params_slice_provider: (
191-
build_runtime_params.DynamicRuntimeParamsSliceProvider | None
192-
) = None,
193-
geometry_provider: geometry_provider_lib.GeometryProvider | None = None,
194-
):
195-
"""Updates the Sim object with components that have already been updated.
196-
197-
Currently this only supports updating the geometry provider and the dynamic
198-
runtime params slice provider, both of which can be updated without
199-
recompilation.
200-
201-
Args:
202-
allow_recompilation: Whether recompilation is allowed. If True, the static
203-
runtime params slice can be updated. NOTE: recompilaton may still occur
204-
if the mesh is updated or if the shapes returned in the dynamic runtime
205-
params slice provider change even if this is False.
206-
static_runtime_params_slice: The new static runtime params slice. If None,
207-
the existing one is kept.
208-
dynamic_runtime_params_slice_provider: The new dynamic runtime params
209-
slice provider. This should already have been updated with modifications
210-
to the various components. If None, the existing one is kept.
211-
geometry_provider: The new geometry provider. If None, the existing one is
212-
kept.
213-
214-
Raises:
215-
ValueError: If the Sim object has a file restart or if the geometry
216-
provider has a different mesh than the existing one.
217-
"""
218-
if self._file_restart is not None:
219-
# TODO(b/384767453): Add support for updating a Sim object with a file
220-
# restart.
221-
raise ValueError('Cannot update a Sim object with a file restart.')
222-
if not allow_recompilation and static_runtime_params_slice is not None:
223-
raise ValueError(
224-
'Cannot update a Sim object with a static runtime params slice if '
225-
'recompilation is not allowed.'
226-
)
227-
228-
if static_runtime_params_slice is not None:
229-
assert isinstance( # Avoid pytype error.
230-
self._static_runtime_params_slice,
231-
runtime_params_slice.StaticRuntimeParamsSlice,
232-
)
233-
self._static_runtime_params_slice.validate_new(
234-
static_runtime_params_slice
235-
)
236-
self._static_runtime_params_slice = static_runtime_params_slice
237-
if dynamic_runtime_params_slice_provider is not None:
238-
assert isinstance( # Avoid pytype error.
239-
self._dynamic_runtime_params_slice_provider,
240-
build_runtime_params.DynamicRuntimeParamsSliceProvider,
241-
)
242-
self._dynamic_runtime_params_slice_provider.validate_new(
243-
dynamic_runtime_params_slice_provider
244-
)
245-
self._dynamic_runtime_params_slice_provider = (
246-
dynamic_runtime_params_slice_provider
247-
)
248-
if geometry_provider is not None:
249-
if geometry_provider.torax_mesh != self._geometry_provider.torax_mesh:
250-
raise ValueError(
251-
'Cannot update a Sim object with a geometry provider with a '
252-
'different mesh.'
253-
)
254-
self._geometry_provider = geometry_provider
255-
256-
dynamic_runtime_params_slice_for_init, geo_for_init = (
257-
build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry(
258-
t=self._dynamic_runtime_params_slice_provider._runtime_params.numerics.t_initial, # pylint: disable=protected-access
259-
dynamic_runtime_params_slice_provider=self._dynamic_runtime_params_slice_provider,
260-
geometry_provider=self._geometry_provider,
261-
)
262-
)
263-
self._initial_state = get_initial_state(
264-
static_runtime_params_slice=self._static_runtime_params_slice,
265-
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
266-
geo=geo_for_init,
267-
step_fn=self._step_fn,
268-
)
269-
270-
def run(
271-
self,
272-
log_timestep_info: bool = False,
273-
) -> output.ToraxSimOutputs:
274-
"""Runs the transport simulation over a prescribed time interval.
275-
276-
See `run_simulation` for details.
277-
278-
Args:
279-
log_timestep_info: See `run_simulation()`.
280-
281-
Returns:
282-
Tuple of all ToraxSimStates, one per time step and an additional one at
283-
the beginning for the starting state.
284-
"""
285-
return _run_simulation(
286-
static_runtime_params_slice=self.static_runtime_params_slice,
287-
dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider,
288-
geometry_provider=self.geometry_provider,
289-
initial_state=self.initial_state,
290-
step_fn=self.step_fn,
291-
log_timestep_info=log_timestep_info,
292-
)
293-
294-
@classmethod
295-
def create(
296-
cls,
297-
*,
298-
runtime_params: general_runtime_params.GeneralRuntimeParams,
299-
geometry_provider: geometry_provider_lib.GeometryProvider,
300-
stepper: stepper_pydantic_model.Stepper,
301-
transport_model: transport_model_pydantic_model.Transport,
302-
sources: source_pydantic_model.Sources,
303-
pedestal: pedestal_pydantic_model.Pedestal,
304-
time_step_calculator: Optional[ts.TimeStepCalculator] = None,
305-
file_restart: file_restart_pydantic_model.FileRestart | None = None,
306-
) -> typing_extensions.Self:
307-
"""Builds a Sim object from the input runtime params and sim components.
308-
309-
Args:
310-
runtime_params: The input runtime params used throughout the simulation
311-
run.
312-
geometry_provider: The geometry used throughout the simulation run.
313-
stepper: The stepper config that can be used to build the stepper.
314-
transport_model: The transport model config that can be used to build the
315-
transport model.
316-
sources: Builds the sources.
317-
pedestal: The pedestal config that can be used to build the pedestal.
318-
time_step_calculator: The time_step_calculator, if built, otherwise a
319-
ChiTimeStepCalculator will be built by default.
320-
file_restart: If provided we will reconstruct the initial state from the
321-
provided file at the given time step. This state from the file will only
322-
be used for constructing the initial state (as well as the config) and
323-
for all subsequent steps, the evolved state and runtime parameters from
324-
config are used.
325-
326-
Returns:
327-
sim: The built Sim instance.
328-
"""
329-
pedestal_model = pedestal.build_pedestal_model()
330-
331-
# TODO(b/385788907): Document all changes that lead to recompilations.
332-
static_runtime_params_slice = (
333-
build_runtime_params.build_static_runtime_params_slice(
334-
runtime_params=runtime_params,
335-
sources=sources,
336-
torax_mesh=geometry_provider.torax_mesh,
337-
stepper=stepper,
338-
)
339-
)
340-
dynamic_runtime_params_slice_provider = (
341-
build_runtime_params.DynamicRuntimeParamsSliceProvider(
342-
runtime_params=runtime_params,
343-
transport=transport_model,
344-
sources=sources,
345-
stepper=stepper,
346-
torax_mesh=geometry_provider.torax_mesh,
347-
pedestal=pedestal,
348-
)
349-
)
350-
source_models = source_models_lib.SourceModels(
351-
sources=sources.source_model_config
352-
)
353-
transport_model = transport_model.build_transport_model()
354-
stepper_model = stepper.build_stepper_model(
355-
transport_model=transport_model,
356-
source_models=source_models,
357-
pedestal_model=pedestal_model,
358-
)
359-
360-
if time_step_calculator is None:
361-
time_step_calculator = chi_time_step_calculator.ChiTimeStepCalculator()
362-
363-
# Build dynamic_runtime_params_slice at t_initial for initial conditions.
364-
dynamic_runtime_params_slice_for_init, geo_for_init = (
365-
build_runtime_params.get_consistent_dynamic_runtime_params_slice_and_geometry(
366-
t=runtime_params.numerics.t_initial,
367-
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
368-
geometry_provider=geometry_provider,
369-
)
370-
)
371-
if file_restart is not None and file_restart.do_restart:
372-
data_tree = output.load_state_file(file_restart.filename)
373-
# Find the closest time in the given dataset.
374-
data_tree = data_tree.sel(time=file_restart.time, method='nearest')
375-
t_restart = data_tree.time.item()
376-
core_profiles_dataset = data_tree.children[output.CORE_PROFILES].dataset
377-
# Remap coordinates in saved file to be consistent with expectations of
378-
# how config_args parses xarrays.
379-
core_profiles_dataset = core_profiles_dataset.rename(
380-
{output.RHO_CELL_NORM: config_args.RHO_NORM}
381-
)
382-
core_profiles_dataset = core_profiles_dataset.squeeze()
383-
if t_restart != runtime_params.numerics.t_initial:
384-
logging.warning(
385-
'Requested restart time %f not exactly available in state file %s.'
386-
' Restarting from closest available time %f instead.',
387-
file_restart.time,
388-
file_restart.filename,
389-
t_restart,
390-
)
391-
# Override some of dynamic runtime params slice from t=t_initial.
392-
dynamic_runtime_params_slice_for_init, geo_for_init = (
393-
_override_initial_runtime_params_from_file(
394-
dynamic_runtime_params_slice_for_init,
395-
geo_for_init,
396-
t_restart,
397-
core_profiles_dataset,
398-
)
399-
)
400-
post_processed_dataset = data_tree.children[
401-
output.POST_PROCESSED_OUTPUTS
402-
].dataset
403-
post_processed_dataset = post_processed_dataset.rename(
404-
{output.RHO_CELL_NORM: config_args.RHO_NORM}
405-
)
406-
post_processed_dataset = post_processed_dataset.squeeze()
407-
post_processed_outputs = (
408-
_override_initial_state_post_processed_outputs_from_file(
409-
geo_for_init,
410-
post_processed_dataset,
411-
)
412-
)
413-
414-
step_fn = step_function.SimulationStepFn(
415-
stepper=stepper_model,
416-
time_step_calculator=time_step_calculator,
417-
transport_model=transport_model,
418-
pedestal_model=pedestal_model,
419-
)
420-
421-
initial_state = get_initial_state(
422-
static_runtime_params_slice=static_runtime_params_slice,
423-
dynamic_runtime_params_slice=dynamic_runtime_params_slice_for_init,
424-
geo=geo_for_init,
425-
step_fn=step_fn,
426-
)
427-
428-
# If we are restarting from a file, we need to override the initial state
429-
# post processed outputs such that cumulative outputs remain correct.
430-
if file_restart is not None and file_restart.do_restart:
431-
initial_state = dataclasses.replace(
432-
initial_state,
433-
post_processed_outputs=post_processed_outputs, # pylint: disable=undefined-variable
434-
)
435-
436-
return cls(
437-
static_runtime_params_slice=static_runtime_params_slice,
438-
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
439-
geometry_provider=geometry_provider,
440-
initial_state=initial_state,
441-
step_fn=step_fn,
442-
file_restart=file_restart,
443-
)
444-
445-
44690
def _override_initial_runtime_params_from_file(
44791
dynamic_runtime_params_slice: runtime_params_slice.DynamicRuntimeParamsSlice,
44892
geo: geometry.Geometry,

0 commit comments

Comments
 (0)