Make static_runtime_params a nested structure similar to dynamic_runtime_params OR investigate unifying static and dynamic runtime_params, using pydantic to provide the relevant metadata for static variables
Make as many config booleans as possible static to simplify control flow and speed up JAX compilation
Following the above, jit various inter-step routines like get_ne to improve runtime performance
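To illustrate the last two items, here is a minimal JAX sketch, assuming a toy routine: `apply_correction` and its `use_correction` flag are hypothetical and are not TORAX's get_ne. Declaring the boolean in `static_argnames` lets the function branch on it with a plain Python `if`. jax.jit then compiles a separate, specialized program for each flag value, and the untaken branch never enters the compiled graph. If the flag were passed as an ordinary (traced) argument instead, the `if` would fail at trace time and would have to be rewritten with `jnp.where` or `jax.lax.cond`.

```python
import jax
import jax.numpy as jnp


# Hypothetical inter-step routine; not the real TORAX get_ne signature.
def apply_correction(profile: jax.Array, use_correction: bool) -> jax.Array:
  # use_correction is static, so this is ordinary Python control flow that is
  # resolved while tracing; each flag value gets its own compiled program.
  if use_correction:
    return profile * 2.0
  return profile


# Mark the boolean as a compile-time constant for jax.jit.
apply_correction_jit = jax.jit(apply_correction, static_argnames="use_correction")

profile = jnp.linspace(0.0, 1.0, 5)
corrected = apply_correction_jit(profile, use_correction=True)
unchanged = apply_correction_jit(profile, use_correction=False)
```

The trade-off is that every distinct combination of static values costs one compilation, so the best candidates are flags that stay fixed for a whole simulation run.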
I've researched this issue and analyzed how static_runtime_params can be structured similarly to dynamic_runtime_params. I have a plan to:
Introduce a nested StaticRuntimeParamsSlice structure.
Group static parameters in a structured way (e.g., sources, stepper, torax_mesh, and boolean flags).
Refactor the code to build static_runtime_params in a way that simplifies control flow.
Ensure JAX treats these params as static to optimize compilation and improve performance.
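A minimal sketch of the nested structure described in this plan, assuming frozen dataclasses: the class and field names below (`StaticSourceParams`, `StaticStepperParams`, `include_ohmic`, `n_cells`, and so on) are hypothetical and do not reflect TORAX's actual StaticRuntimeParamsSlice. Static arguments to jax.jit must be hashable, and frozen dataclasses whose fields are themselves hashable satisfy that requirement at every nesting level.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical nested static params (illustrative names only). frozen=True
# makes each level immutable and hashable, as jax.jit requires for statics.
@dataclasses.dataclass(frozen=True)
class StaticSourceParams:
  include_ohmic: bool
  include_fusion: bool


@dataclasses.dataclass(frozen=True)
class StaticStepperParams:
  evolve_temperature: bool


@dataclasses.dataclass(frozen=True)
class StaticRuntimeParamsSlice:
  sources: StaticSourceParams
  stepper: StaticStepperParams
  n_cells: int  # e.g. a mesh size that is fixed for the whole run


def step(temperature: jax.Array, static: StaticRuntimeParamsSlice) -> jax.Array:
  # Every branch reads a plain Python value, so it is resolved at trace time.
  if not static.stepper.evolve_temperature:
    return temperature
  heating = jnp.zeros(static.n_cells)
  if static.sources.include_ohmic:
    heating = heating + 0.1
  if static.sources.include_fusion:
    heating = heating + 0.5
  return temperature + heating


step_jit = jax.jit(step, static_argnames="static")

static = StaticRuntimeParamsSlice(
    sources=StaticSourceParams(include_ohmic=True, include_fusion=False),
    stepper=StaticStepperParams(evolve_temperature=True),
    n_cells=4,
)
new_temperature = step_jit(jnp.ones(4), static=static)
```

Changing any field of `static` makes it compare unequal to the cached value and triggers a retrace and recompilation, whereas new `temperature` values reuse the compiled program.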
I would like to take on this issue and implement the necessary changes. Let me know if there are any specific considerations I should keep in mind before proceeding.
Hi @msaadg. Thank you, we appreciate your desire to contribute!
We recently realized that it may be possible to do away with the dynamic vs. static runtime_params split entirely and instead leverage the upcoming pydantic config patterns to annotate variables as static for JAX compilation. We will therefore assign this ticket internally while we investigate. Apologies for that!
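One way the unified approach could look, sketched under explicit assumptions: plain `dataclasses` field metadata (`{"static": True}`) stands in for whatever annotation the pydantic-based config would provide, and `RuntimeParams`, `use_pedestal`, and `pedestal_height` are hypothetical names, not pydantic or TORAX APIs. The class is registered as a JAX pytree with `jax.tree_util.register_pytree_node`. Fields marked static go into the auxiliary data (part of the tree structure), so they reach jitted code as ordinary Python values, while the remaining fields become traced leaves.

```python
import dataclasses

import jax
import jax.numpy as jnp


# Hypothetical unified config. The "static" metadata key is an assumption
# standing in for a pydantic-style annotation.
@dataclasses.dataclass(frozen=True)
class RuntimeParams:
  use_pedestal: bool = dataclasses.field(metadata={"static": True})
  pedestal_height: float = dataclasses.field(metadata={"static": False})


def _is_static(f: dataclasses.Field) -> bool:
  return f.metadata.get("static", False)


def _flatten(params: RuntimeParams):
  fields = dataclasses.fields(params)
  # Non-static fields become pytree leaves (traced under jit).
  dynamic = tuple(getattr(params, f.name) for f in fields if not _is_static(f))
  # Static fields become hashable aux data (part of the treedef).
  static = tuple(getattr(params, f.name) for f in fields if _is_static(f))
  return dynamic, static


def _unflatten(static, dynamic) -> RuntimeParams:
  static_iter, dynamic_iter = iter(static), iter(dynamic)
  kwargs = {
      f.name: next(static_iter) if _is_static(f) else next(dynamic_iter)
      for f in dataclasses.fields(RuntimeParams)
  }
  return RuntimeParams(**kwargs)


jax.tree_util.register_pytree_node(RuntimeParams, _flatten, _unflatten)


@jax.jit
def edge_value(params: RuntimeParams, profile: jax.Array) -> jax.Array:
  # use_pedestal travels in the treedef, so it is a plain Python bool here.
  if params.use_pedestal:
    return profile[-1] + params.pedestal_height
  return profile[-1]


params = RuntimeParams(use_pedestal=True, pedestal_height=0.5)
value = edge_value(params, jnp.array([1.0, 2.0]))
```

With this split, flipping `use_pedestal` changes the tree structure and forces a recompile, while a new `pedestal_height` reuses the cached program, so a single config class can carry both kinds of parameter.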