In theory, EDIT: I went with sprinkling
For anyone coming back to this: calling a function that has its own `jit` from inside another `jit` can make tracing and compiling the outer function much faster. For example, in IPython:

```python
>>> import jax
>>> f = lambda x: sum(i * x for i in range(256))  # common function
>>> jf = jax.jit(f)
>>> %timeit jax.jit(lambda x: sum(f(x) for _ in range(16)))(3.14)
2.79 s ± 133 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
>>> %timeit jax.jit(lambda x: sum(jf(x) for _ in range(16)))(3.14)
36.5 ms ± 9.6 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```
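The `%timeit` lines above need IPython. A plain-script version of the same comparison might look like the sketch below. The helper `time_fresh_outer` and its best-of-`runs` timing are hypothetical additions, not part of the original measurement, and absolute numbers will depend on your machine and JAX version.

```python
import time

import jax


def f(x):
    # Common inner function: the generator is unrolled at trace time
    # into 256 multiply-adds.
    return sum(i * x for i in range(256))


jf = jax.jit(f)  # the same function wrapped in its own jit


def time_fresh_outer(inner, repeats=16, runs=3):
    """Best-of-`runs` wall time to trace, compile and run a *new* outer jit.

    A fresh lambda is built on every run, so the outer function is traced
    and compiled again each time (as with the %timeit one-liners).
    `inner` is called `repeats` times inside the outer function.
    """
    best, result = float("inf"), None
    for _ in range(runs):
        outer = jax.jit(lambda x: sum(inner(x) for _rep in range(repeats)))
        start = time.perf_counter()
        result = outer(3.14).block_until_ready()
        best = min(best, time.perf_counter() - start)
    return best, result


if __name__ == "__main__":
    for label, inner in [("plain f", f), ("jitted jf", jf)]:
        seconds, value = time_fresh_outer(inner)
        print(f"{label}: {seconds * 1e3:.1f} ms (result {float(value):.1f})")
```

Both variants compute the same value; only the cost of building the outer computation differs, because the inner `jf` is traced once and then reused as a single call instead of being re-traced 16 times.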
Right now, nested `jit` calls are preserved as function calls in the IR that JAX generates, but XLA flattens (inlines) them during compilation. In the future this may no longer hold: at some point, XLA (or another compiler) may inline less aggressively.
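One way to see the nested call survive in JAX's IR is to print the jaxpr and the lowered StableHLO for a small example. This is a sketch: `inner` and `outer` are hypothetical names, the primitive name printed for the nested call (e.g. `pjit`) depends on the JAX version, and the exact StableHLO text also varies between versions.

```python
import jax
import jax.numpy as jnp


@jax.jit
def inner(x):
    return jnp.sin(x) * 2.0


def outer(x):
    # Two calls to the nested jit function.
    return inner(x) + inner(x + 1.0)


# In the jaxpr, each call to `inner` shows up as one call-like equation that
# carries `inner`'s own jaxpr as a parameter; `sin` does not appear at the
# top level of `outer`'s jaxpr.
closed = jax.make_jaxpr(outer)(1.0)
print(closed)

# The lowered StableHLO typically keeps `inner` as a separate private
# function that is called from `main`; XLA inlines it when compiling.
hlo_text = jax.jit(outer).lower(1.0).as_text()
print(hlo_text)
```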