File tree 2 files changed +49
-0
lines changed
2 files changed +49
-0
lines changed Original file line number Diff line number Diff line change 18
18
import functools
19
19
import os
20
20
from typing import Any , Callable , Optional , TypeVar
21
+
21
22
import chex
22
23
import equinox as eqx
23
24
import jax
@@ -257,4 +258,29 @@ def py_cond(
257
258
return false_fun ()
258
259
259
260
261
+ def get_number_of_compiles (
262
+ jitted_function : Callable [..., Any ],
263
+ ) -> int :
264
+ """Helper function for debugging JAX compilation.
265
+
266
+ This counts the number of times the function has been JIT compiled. This does
267
+ not include any uses of the AOT compile workflow.
268
+
269
+ Args:
270
+ jitted_function: A function that has been wrapped with `jax.jit`.
271
+ Returns:
272
+ The number of times the function has been compiled.
273
+ Raises:
274
+ RuntimeError: If the function does not have a _cache_size attribute.
275
+ """
276
+ # pylint: disable=protected-access
277
+ if not hasattr (jitted_function , '_cache_size' ):
278
+ raise RuntimeError (
279
+ 'The function does not have a _cache_size attribute. Possibly because'
280
+ ' the function was not jitted.'
281
+ )
282
+ return jitted_function ._cache_size ()
283
+ # pylint: enable=protected-access
284
+
285
+
260
286
# pylint: enable=g-bare-generic
Original file line number Diff line number Diff line change 16
16
from unittest import mock
17
17
from absl .testing import absltest
18
18
from absl .testing import parameterized
19
+ import jax
19
20
from jax import numpy as jnp
20
21
from torax import jax_utils
21
22
@@ -105,6 +106,28 @@ def test_f32_int_dtype(self):
105
106
"""Test that the dtype is int32 when JAX_PRECISION is set to 'f32'."""
106
107
self .assertEqual (jax_utils .get_int_dtype (), jnp .int32 )
107
108
109
+ def test_get_number_of_compiles (self ):
110
+ """Check assumptions on JAX internals are valid."""
111
+
112
+ def f (x : jax .Array ):
113
+ return x
114
+
115
+ jit_f = jax .jit (f )
116
+ self .assertTrue (hasattr (jit_f , '_cache_size' ))
117
+ # Should be 0 before any calls.
118
+ self .assertEqual (jax_utils .get_number_of_compiles (jit_f ), 0 )
119
+
120
+ # Should be 1 after one call.
121
+ jit_f (jnp .array (0 ))
122
+ self .assertEqual (jax_utils .get_number_of_compiles (jit_f ), 1 )
123
+ # Should be 1 after another call with same shape.
124
+ jit_f (jnp .array (1 ))
125
+ self .assertEqual (jax_utils .get_number_of_compiles (jit_f ), 1 )
126
+
127
+ # Should be 2 after another call with different shape.
128
+ jit_f (jnp .array ([1 ]))
129
+ self .assertEqual (jax_utils .get_number_of_compiles (jit_f ), 2 )
130
+
108
131
109
132
if __name__ == '__main__' :
110
133
absltest .main ()
You can’t perform that action at this time.
0 commit comments