Skip to content

Commit bb95d96

Browse files
Nush395Torax team
authored and
Torax team
committed
Add tolerance to not_done condition of time step calculator.
This avoids running an apparent extra step for the final step when we are a very small amount below exact_t_final, the rounding of logged values makes it appear like t_final is reached. PiperOrigin-RevId: 744871073
1 parent 2943ce7 commit bb95d96

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

Diff for: torax/time_step_calculator/pydantic_model.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -30,14 +30,25 @@ class TimeStepCalculatorType(enum.Enum):
3030

3131

3232
class TimeStepCalculator(torax_pydantic.BaseModelFrozen):
33-
"""Config for a time step calculator."""
33+
"""Config for a time step calculator.
34+
35+
Attributes:
36+
calculator_type: The type of time step calculator to use.
37+
tolerance: The tolerance within the final time for which the simulation
38+
will be considered done.
39+
"""
3440

3541
calculator_type: TimeStepCalculatorType = TimeStepCalculatorType.CHI
42+
tolerance: float = 1e-7
3643

3744
@property
3845
def time_step_calculator(self) -> time_step_calculator.TimeStepCalculator:
3946
match self.calculator_type:
4047
case TimeStepCalculatorType.CHI:
41-
return chi_time_step_calculator.ChiTimeStepCalculator()
48+
return chi_time_step_calculator.ChiTimeStepCalculator(
49+
tolerance=self.tolerance
50+
)
4251
case TimeStepCalculatorType.FIXED:
43-
return fixed_time_step_calculator.FixedTimeStepCalculator()
52+
return fixed_time_step_calculator.FixedTimeStepCalculator(
53+
tolerance=self.tolerance
54+
)

Diff for: torax/time_step_calculator/time_step_calculator.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,14 @@
1818
"""
1919

2020
import abc
21-
from typing import Protocol, Union
2221

2322
import jax
2423
from torax import state as state_module
2524
from torax.config import runtime_params_slice
2625
from torax.geometry import geometry
2726

2827

29-
class TimeStepCalculator(Protocol):
28+
class TimeStepCalculator(abc.ABC):
3029
"""Iterates over time during simulation.
3130
3231
Usage follows this pattern:
@@ -42,12 +41,15 @@ class TimeStepCalculator(Protocol):
4241
sim_state = <update sim_state with step of size dt>
4342
"""
4443

44+
def __init__(self, tolerance: float = 1e-7):
45+
self.tolerance = tolerance
46+
4547
def not_done(
4648
self,
47-
t: Union[float, jax.Array],
49+
t: float | jax.Array,
4850
t_final: float,
49-
) -> Union[bool, jax.Array]:
50-
return t < t_final
51+
) -> bool | jax.Array:
52+
return t < (t_final - self.tolerance)
5153

5254
@abc.abstractmethod
5355
def next_dt(

0 commit comments

Comments
 (0)