Skip to content

Commit 35abf3d

Browse files
Nush395Torax team
authored and
Torax team
committed
Add helper function for counting number of times a function has been JIT compiled.
Creating a standalone function allows us to: - have single point of responsibility - test whether the assumptions we make on JAX internals are valid. PiperOrigin-RevId: 743897133
1 parent 8f95c3d commit 35abf3d

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

torax/jax_utils.py

+26
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import functools
1919
import os
2020
from typing import Any, Callable, Optional, TypeVar
21+
2122
import chex
2223
import equinox as eqx
2324
import jax
@@ -257,4 +258,29 @@ def py_cond(
257258
return false_fun()
258259

259260

261+
def get_number_of_compiles(
262+
jitted_function: Callable[..., Any],
263+
) -> int:
264+
"""Helper function for debugging JAX compilation.
265+
266+
This counts the number of times the function has been JIT compiled. This does
267+
not include any uses of the AOT compile workflow.
268+
269+
Args:
270+
jitted_function: A function that has been wrapped with `jax.jit`.
271+
Returns:
272+
The number of times the function has been compiled.
273+
Raises:
274+
RuntimeError: If the function does not have a _cache_size attribute.
275+
"""
276+
# pylint: disable=protected-access
277+
if not hasattr(jitted_function, '_cache_size'):
278+
raise RuntimeError(
279+
'The function does not have a _cache_size attribute. Possibly because'
280+
' the function was not jitted.'
281+
)
282+
return jitted_function._cache_size()
283+
# pylint: enable=protected-access
284+
285+
260286
# pylint: enable=g-bare-generic

torax/tests/jax_utils_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from unittest import mock
1717
from absl.testing import absltest
1818
from absl.testing import parameterized
19+
import jax
1920
from jax import numpy as jnp
2021
from torax import jax_utils
2122

@@ -105,6 +106,28 @@ def test_f32_int_dtype(self):
105106
"""Test that the dtype is int32 when JAX_PRECISION is set to 'f32'."""
106107
self.assertEqual(jax_utils.get_int_dtype(), jnp.int32)
107108

109+
def test_get_number_of_compiles(self):
110+
"""Check assumptions on JAX internals are valid."""
111+
112+
def f(x: jax.Array):
113+
return x
114+
115+
jit_f = jax.jit(f)
116+
self.assertTrue(hasattr(jit_f, '_cache_size'))
117+
# Should be 0 before any calls.
118+
self.assertEqual(jax_utils.get_number_of_compiles(jit_f), 0)
119+
120+
# Should be 1 after one call.
121+
jit_f(jnp.array(0))
122+
self.assertEqual(jax_utils.get_number_of_compiles(jit_f), 1)
123+
# Should be 1 after another call with same shape.
124+
jit_f(jnp.array(1))
125+
self.assertEqual(jax_utils.get_number_of_compiles(jit_f), 1)
126+
127+
# Should be 2 after another call with different shape.
128+
jit_f(jnp.array([1]))
129+
self.assertEqual(jax_utils.get_number_of_compiles(jit_f), 2)
130+
108131

109132
if __name__ == '__main__':
110133
absltest.main()

0 commit comments

Comments
 (0)