diff --git a/torax/time_step_calculator/pydantic_model.py b/torax/time_step_calculator/pydantic_model.py index 16c82f2de..c2c445f14 100644 --- a/torax/time_step_calculator/pydantic_model.py +++ b/torax/time_step_calculator/pydantic_model.py @@ -30,14 +30,25 @@ class TimeStepCalculatorType(enum.Enum): class TimeStepCalculator(torax_pydantic.BaseModelFrozen): - """Config for a time step calculator.""" + """Config for a time step calculator. + + Attributes: + calculator_type: The type of time step calculator to use. + tolerance: The tolerance within the final time for which the simulation + will be considered done. + """ calculator_type: TimeStepCalculatorType = TimeStepCalculatorType.CHI + tolerance: float = 1e-7 @property def time_step_calculator(self) -> time_step_calculator.TimeStepCalculator: match self.calculator_type: case TimeStepCalculatorType.CHI: - return chi_time_step_calculator.ChiTimeStepCalculator() + return chi_time_step_calculator.ChiTimeStepCalculator( + tolerance=self.tolerance + ) case TimeStepCalculatorType.FIXED: - return fixed_time_step_calculator.FixedTimeStepCalculator() + return fixed_time_step_calculator.FixedTimeStepCalculator( + tolerance=self.tolerance + ) diff --git a/torax/time_step_calculator/time_step_calculator.py b/torax/time_step_calculator/time_step_calculator.py index 24537a390..56aa12d72 100644 --- a/torax/time_step_calculator/time_step_calculator.py +++ b/torax/time_step_calculator/time_step_calculator.py @@ -18,7 +18,6 @@ """ import abc -from typing import Protocol, Union import jax from torax import state as state_module @@ -26,7 +25,7 @@ from torax.geometry import geometry -class TimeStepCalculator(Protocol): +class TimeStepCalculator(abc.ABC): """Iterates over time during simulation. Usage follows this pattern: @@ -42,12 +41,15 @@ class TimeStepCalculator(Protocol): sim_state = """ + def __init__(self, tolerance: float = 1e-7): + self.tolerance = tolerance + def not_done( self, - t: Union[float, jax.Array], + t: float | jax.Array, t_final: float, - ) -> Union[bool, jax.Array]: - return t < t_final + ) -> bool | jax.Array: + return t < (t_final - self.tolerance) @abc.abstractmethod def next_dt(