File tree 2 files changed +21
-8
lines changed
torax/time_step_calculator
2 files changed +21
-8
lines changed Original file line number Diff line number Diff line change @@ -30,14 +30,25 @@ class TimeStepCalculatorType(enum.Enum):
30
30
31
31
32
32
class TimeStepCalculator (torax_pydantic .BaseModelFrozen ):
33
- """Config for a time step calculator."""
33
+ """Config for a time step calculator.
34
+
35
+ Attributes:
36
+ calculator_type: The type of time step calculator to use.
37
+ tolerance: The tolerance within the final time for which the simulation
38
+ will be considered done.
39
+ """
34
40
35
41
calculator_type : TimeStepCalculatorType = TimeStepCalculatorType .CHI
42
+ tolerance : float = 1e-7
36
43
37
44
@property
38
45
def time_step_calculator (self ) -> time_step_calculator .TimeStepCalculator :
39
46
match self .calculator_type :
40
47
case TimeStepCalculatorType .CHI :
41
- return chi_time_step_calculator .ChiTimeStepCalculator ()
48
+ return chi_time_step_calculator .ChiTimeStepCalculator (
49
+ tolerance = self .tolerance
50
+ )
42
51
case TimeStepCalculatorType .FIXED :
43
- return fixed_time_step_calculator .FixedTimeStepCalculator ()
52
+ return fixed_time_step_calculator .FixedTimeStepCalculator (
53
+ tolerance = self .tolerance
54
+ )
Original file line number Diff line number Diff line change 18
18
"""
19
19
20
20
import abc
21
- from typing import Protocol , Union
22
21
23
22
import jax
24
23
from torax import state as state_module
25
24
from torax .config import runtime_params_slice
26
25
from torax .geometry import geometry
27
26
28
27
29
- class TimeStepCalculator (Protocol ):
28
+ class TimeStepCalculator (abc . ABC ):
30
29
"""Iterates over time during simulation.
31
30
32
31
Usage follows this pattern:
@@ -42,12 +41,15 @@ class TimeStepCalculator(Protocol):
42
41
sim_state = <update sim_state with step of size dt>
43
42
"""
44
43
44
+ def __init__ (self , tolerance : float = 1e-7 ):
45
+ self .tolerance = tolerance
46
+
45
47
def not_done (
46
48
self ,
47
- t : Union [ float , jax .Array ] ,
49
+ t : float | jax .Array ,
48
50
t_final : float ,
49
- ) -> Union [ bool , jax .Array ] :
50
- return t < t_final
51
+ ) -> bool | jax .Array :
52
+ return t < ( t_final - self . tolerance )
51
53
52
54
@abc .abstractmethod
53
55
def next_dt (
You can’t perform that action at this time.
0 commit comments