Upgrade static_runtime_params #750

Open
3 tasks
jcitrin opened this issue Feb 24, 2025 · 2 comments
Comments

@jcitrin
Collaborator

jcitrin commented Feb 24, 2025

  • Make static_runtime_params a nested structure similar to dynamic_runtime_params, OR investigate unifying static and dynamic runtime_params and using pydantic to provide the relevant metadata for static variables
  • Make as many config booleans as possible static, to simplify control flow and speed up JAX compilation
  • Following the above, jit various inter-step routines like get_ne and improve runtime performance
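
As a rough illustration of the second and third tasks, the sketch below marks a config boolean as static for `jax.jit`, so the branch is an ordinary Python `if` resolved at trace time. The function `update_profile` and the flag `use_pedestal` are hypothetical names chosen for this example, not taken from the TORAX code base.

```python
import functools

import jax
import jax.numpy as jnp


# `use_pedestal` is a hypothetical config boolean. Declaring it static lets
# the function branch on it with a plain Python `if`; JAX compiles one
# specialised version per distinct value instead of tracing the condition.
@functools.partial(jax.jit, static_argnames=("use_pedestal",))
def update_profile(profile, use_pedestal):
    if use_pedestal:
        # Hypothetical behaviour: clamp the profile to a floor of 1.0.
        return jnp.maximum(profile, 1.0)
    return profile


profile = jnp.array([0.5, 1.5, 2.0])
with_pedestal = update_profile(profile, use_pedestal=True)
without_pedestal = update_profile(profile, use_pedestal=False)
```

The trade-off is that every new value of a static argument triggers a recompilation, so this pattern suits flags that rarely change during a run.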
@msaadg

msaadg commented Mar 5, 2025

Hi @jcitrin,

I've researched this issue and analyzed how static_runtime_params can be structured similarly to dynamic_runtime_params. I have a plan to:

  • Introduce a nested StaticRuntimeParamsSlice structure.
  • Group static parameters in a structured way (e.g., sources, stepper, torax_mesh, and boolean flags).
  • Refactor the code to build static_runtime_params in a way that simplifies control flow.
  • Ensure JAX treats these params as static to optimize compilation and improve performance.

I would like to take on this issue and implement the necessary changes. Let me know if there are any specific considerations I should keep in mind before proceeding.
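
One way the nested structure proposed above could look is sketched here with frozen dataclasses from the standard library. Only the name `StaticRuntimeParamsSlice` comes from this thread; the sub-structures and every field name are assumptions made for illustration. The point of the sketch is that JAX requires static arguments to be hashable and comparable, which frozen dataclasses holding only hashable fields provide.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class StaticStepperParams:
    # Hypothetical stepper flags.
    use_predictor_corrector: bool = False
    n_corrector_steps: int = 1


@dataclasses.dataclass(frozen=True)
class StaticSourceParams:
    # Hypothetical per-source flag.
    is_explicit: bool = False


@dataclasses.dataclass(frozen=True)
class StaticRuntimeParamsSlice:
    stepper: StaticStepperParams
    # A tuple of (name, params) pairs rather than a dict, because a dict
    # field would make the instance unhashable.
    sources: tuple[tuple[str, StaticSourceParams], ...]
    # Hypothetical top-level boolean flag.
    evolve_density: bool = True


def build_static_slice(evolve_density):
    """Builds a nested static slice with otherwise default values."""
    return StaticRuntimeParamsSlice(
        stepper=StaticStepperParams(),
        sources=(("ohmic", StaticSourceParams(is_explicit=True)),),
        evolve_density=evolve_density,
    )


a = build_static_slice(evolve_density=True)
b = build_static_slice(evolve_density=True)
c = build_static_slice(evolve_density=False)

# Equal slices hash identically, so a jitted function that takes the slice
# as a static argument could reuse its compiled version for `a` and `b`.
same_cache_key = (a == b) and (hash(a) == hash(b))
# A changed flag produces a different slice, which would force a retrace.
flag_change_detected = a != c
```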

@jcitrin
Collaborator Author

jcitrin commented Mar 6, 2025

Hi @msaadg. Thank you, we appreciate your desire to contribute!

Recently we realized that it may be possible to actually do away with the dynamic vs static runtime_params split and leverage the upcoming pydantic config patterns to annotate variables as static for JAX compilation. We will thus assign this ticket internally while investigating this. Apologies for that!
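
To make the idea of annotating variables as static more concrete, here is a minimal sketch of a single config class whose fields carry a static marker, plus a helper that splits an instance into static and dynamic parts. It uses standard-library dataclasses and `typing.Annotated` as a stand-in; the pydantic config patterns mentioned above are not shown, and the marker `JAX_STATIC`, the class `RuntimeParams`, and its fields are all hypothetical.

```python
import dataclasses
from typing import Annotated, get_type_hints

# Hypothetical marker attached to fields that JAX should treat as static.
JAX_STATIC = "jax_static"


@dataclasses.dataclass(frozen=True)
class RuntimeParams:
    # Hypothetical fields: two static flags and one dynamic value.
    evolve_density: Annotated[bool, JAX_STATIC] = True
    n_corrector_steps: Annotated[int, JAX_STATIC] = 1
    plasma_current: float = 15.0


def split_static_dynamic(params):
    """Splits field values into (static, dynamic) dicts using the marker."""
    hints = get_type_hints(type(params), include_extras=True)
    static, dynamic = {}, {}
    for field in dataclasses.fields(params):
        # Annotated types expose their extra arguments as `__metadata__`;
        # plain types have no such attribute.
        metadata = getattr(hints[field.name], "__metadata__", ())
        target = static if JAX_STATIC in metadata else dynamic
        target[field.name] = getattr(params, field.name)
    return static, dynamic


static_params, dynamic_params = split_static_dynamic(RuntimeParams())
```

With a split like this, a single user-facing config could feed both the hashable static arguments and the traced dynamic arguments of a jitted function, without maintaining two parallel class hierarchies.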
