Skip to content

Commit 7401b1c

Browse files
jcitrinTorax team
authored and
Torax team
committed
Implement hash and eq for sawtooth models
PiperOrigin-RevId: 745079788
1 parent 3ac544c commit 7401b1c

File tree

4 files changed

+45
-1
lines changed

4 files changed

+45
-1
lines changed

Diff for: torax/mhd/sawtooth/sawtooth_model.py

+32
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,17 @@ def __call__(
3737
) -> tuple[array_typing.ScalarBool, array_typing.ScalarFloat]:
3838
"""Indicates if a crash is triggered and the radius of the q=1 surface."""
3939

40+
@abc.abstractmethod
41+
def __hash__(self) -> int:
42+
"""Returns a hash of the trigger model.
43+
44+
Should be implemented to support jax.jit caching.
45+
"""
46+
47+
@abc.abstractmethod
48+
def __eq__(self, other: object) -> bool:
49+
"""Equality method to be implemented to support jax.jit caching."""
50+
4051

4152
class RedistributionModel(abc.ABC):
4253
"""Abstract base class for sawtooth redistribution models."""
@@ -52,6 +63,17 @@ def __call__(
5263
) -> state.CoreProfiles:
5364
"""Returns a redistributed core_profiles if sawtooth has been triggered."""
5465

66+
@abc.abstractmethod
67+
def __hash__(self) -> int:
68+
"""Returns a hash of the redistribution model.
69+
70+
Should be implemented to support jax.jit caching.
71+
"""
72+
73+
@abc.abstractmethod
74+
def __eq__(self, other) -> bool:
75+
"""Equality method to be implemented to support jax.jit caching."""
76+
5577

5678
class SawtoothModel:
5779
"""Sawtooth trigger and redistribution, and carries out sawtooth step."""
@@ -112,3 +134,13 @@ def __call__(
112134
# modify output state with new time, dt, and core_profiles if triggered.
113135

114136
return input_state
137+
138+
def __hash__(self) -> int:
139+
return hash((self.trigger_model, self.redistribution_model))
140+
141+
def __eq__(self, other: object) -> bool:
142+
return (
143+
isinstance(other, SawtoothModel)
144+
and self.trigger_model == other.trigger_model
145+
and self.redistribution_model == other.redistribution_model
146+
)

Diff for: torax/mhd/sawtooth/simple_redistribution.py

+6
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ def __call__(
5454

5555
return core_profiles
5656

57+
def __hash__(self) -> int:
58+
return hash(self.__class__.__name__)
59+
60+
def __eq__(self, other: object) -> bool:
61+
return isinstance(other, SimpleRedistribution)
62+
5763

5864
@chex.dataclass(frozen=True)
5965
class DynamicRuntimeParams(runtime_params.RedistributionDynamicRuntimeParams):

Diff for: torax/mhd/sawtooth/simple_trigger.py

+6
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ def __call__(
103103
rho_norm_q1,
104104
)
105105

106+
def __hash__(self) -> int:
107+
return hash(self.__class__.__name__)
108+
109+
def __eq__(self, other: object) -> bool:
110+
return isinstance(other, SimpleTrigger)
111+
106112

107113
@chex.dataclass(frozen=True)
108114
class DynamicRuntimeParams(runtime_params.TriggerDynamicRuntimeParams):

Diff for: torax/tests/sim_test.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ class SimTest(sim_test_case.SimTestCase):
405405
'test_iterhybrid_newton',
406406
'test_iterhybrid_newton.py',
407407
_ALL_PROFILES,
408-
5e-7,
408+
1e-6,
409409
),
410410
# Tests current and density rampup for for ITER-hybrid-like-config
411411
# using Newton-Raphson. Only case which reverts to coarse_tol for several

0 commit comments

Comments
 (0)