Skip to content

Jit individual parts of the transport model instead of the whole thing #898

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions torax/orchestration/step_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ def __init__(
self._time_step_calculator = time_step_calculator
self._transport_model = transport_model
self._pedestal_model = pedestal_model
self._jitted_transport_model = jax_utils.jit(
transport_model.__call__,
)

@property
def pedestal_model(self) -> pedestal_model_lib.PedestalModel:
Expand Down Expand Up @@ -236,7 +233,7 @@ def init_time_step_calculator(
pedestal_model_output = self._pedestal_model(
dynamic_runtime_params_slice_t, geo_t, input_state.core_profiles
)
transport_coeffs = self._jitted_transport_model(
transport_coeffs = self._transport_model(
dynamic_runtime_params_slice_t,
geo_t,
input_state.core_profiles,
Expand Down
13 changes: 13 additions & 0 deletions torax/pedestal_model/pedestal_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,3 +86,16 @@ def _call_implementation(
core_profiles: state.CoreProfiles,
) -> PedestalModelOutput:
"""Calculate the pedestal values."""

@abc.abstractmethod
def __hash__(self) -> int:
"""Hash function for the pedestal model.

Needed for jax.jit caching to work.
"""
...

@abc.abstractmethod
def __eq__(self, other) -> bool:
"""Equality function for the pedestal model."""
...
6 changes: 6 additions & 0 deletions torax/pedestal_model/set_pped_tpedratio_nped.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,3 +114,9 @@ def _call_implementation(
Teped=Teped,
rho_norm_ped_top=dynamic_runtime_params_slice.pedestal.rho_norm_ped_top,
)

def __hash__(self) -> int:
return hash(('SetPressureTemperatureRatioAndDensityPedestalModel'))

def __eq__(self, other) -> bool:
return isinstance(other, SetPressureTemperatureRatioAndDensityPedestalModel)
6 changes: 6 additions & 0 deletions torax/pedestal_model/set_tped_nped.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,9 @@ def _call_implementation(
Teped=dynamic_runtime_params_slice.pedestal.Teped,
rho_norm_ped_top=dynamic_runtime_params_slice.pedestal.rho_norm_ped_top,
)

def __hash__(self) -> int:
return hash(('SetTemperatureDensityPedestalModel'))

def __eq__(self, other) -> bool:
return isinstance(other, SetTemperatureDensityPedestalModel)
10 changes: 10 additions & 0 deletions torax/sources/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,3 +174,13 @@ def get_value(
return (zeros,) * len(self.affected_core_profiles)
case _:
raise ValueError(f'Unknown mode: {mode}')

def __hash__(self) -> int:
return hash((self.SOURCE_NAME, self.model_func))

def __eq__(self, other) -> bool:
return (
isinstance(other, type(self))
and self.SOURCE_NAME == other.SOURCE_NAME
and self.model_func == other.model_func
)
23 changes: 13 additions & 10 deletions torax/sources/source_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,8 @@ def __init__(
# The rest of the sources are "standard".
self._standard_sources = {}

# Divide up the sources based on which core profiles they affect.
# Pull out the psi sources as these are calculated first.
self._psi_sources: dict[str, source_lib.Source] = {}
self._ne_sources: dict[str, source_lib.Source] = {}
self._temp_ion_sources: dict[str, source_lib.Source] = {}
self._temp_el_sources: dict[str, source_lib.Source] = {}

# First set the "special" sources.
for source in sources.values():
Expand Down Expand Up @@ -180,12 +177,6 @@ def _add_standard_source(
self._standard_sources[source_name] = source
if source_lib.AffectedCoreProfile.PSI in source.affected_core_profiles:
self._psi_sources[source_name] = source
if source_lib.AffectedCoreProfile.NE in source.affected_core_profiles:
self._ne_sources[source_name] = source
if source_lib.AffectedCoreProfile.TEMP_ION in source.affected_core_profiles:
self._temp_ion_sources[source_name] = source
if source_lib.AffectedCoreProfile.TEMP_EL in source.affected_core_profiles:
self._temp_el_sources[source_name] = source

# Some sources require direct access, so this class defines properties for
# those sources.
Expand Down Expand Up @@ -229,3 +220,15 @@ def sources(self) -> dict[str, source_lib.Source]:
self.j_bootstrap_name: self.j_bootstrap,
self.qei_source_name: self.qei_source,
}

def __hash__(self) -> int:
hashes = [hash(source) for source in self.sources.values()]
return hash(tuple(hashes))

def __eq__(self, other) -> bool:
if set(self.sources.keys()) == set(other.sources.keys()):
return all(
self.sources[name] == other.sources[name]
for name in self.sources.keys()
)
return False
6 changes: 6 additions & 0 deletions torax/tests/sim_time_dependence_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,12 @@ def _call_implementation(
) -> state.CoreTransport:
return state.CoreTransport.zeros(geo)

def __hash__(self) -> int:
return hash(('FakeTransportModel'))

def __eq__(self, other) -> bool:
return isinstance(other, FakeTransportModel)


class FakeTransportConfig(transport_pydantic_model_base.TransportBase):
"""Fake transport config for a model that always returns zeros."""
Expand Down
9 changes: 9 additions & 0 deletions torax/transport_model/qualikiz_transport_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,15 @@ def _extract_run_data(
gyrobohm_flux_reference_length=geo.Rmin,
)

def __hash__(self) -> int:
return hash(('QualikizTransportModel' + self._runpath))

def __eq__(self, other) -> bool:
return (
isinstance(other, QualikizTransportModel)
and self._runpath == other._runpath
)


def _extract_qualikiz_plan(
qualikiz_inputs: qualikiz_based_transport_model.QualikizInputs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,12 @@ def _call_implementation(
gyrobohm_flux_reference_length=geo.Rmin,
)

def __hash__(self) -> int:
return hash(('FakeQualikizBasedTransportModel'))

def __eq__(self, other) -> bool:
return isinstance(other, FakeQualikizBasedTransportModel)


# pylint: disable=invalid-name
class QualikizBasedTransportModelConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,12 @@ def _call_implementation(
gyrobohm_flux_reference_length=1.0,
)

def __hash__(self):
return hash('FakeQuasilinearTransportModel')

def __eq__(self, other):
return isinstance(other, FakeQuasilinearTransportModel)


def _get_dummy_core_profiles(value, right_face_constraint):
"""Returns dummy core profiles for testing."""
Expand Down
6 changes: 6 additions & 0 deletions torax/transport_model/tests/transport_model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,12 @@ def _call_implementation(
v_face_el=v_face_el,
)

def __hash__(self) -> int:
return hash(('FakeTransportModel'))

def __eq__(self, other) -> bool:
return isinstance(other, FakeTransportModel)


class FakeTransportConfig(transport_pydantic_model_base.TransportBase):
"""Fake transport config for a model that always returns zeros."""
Expand Down
Loading