From 82611eb8ae85dd3d66a31a2a19ad289a1fa202f4 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 5 Feb 2024 12:01:33 -0800 Subject: [PATCH] document that under disable_jit, individual primitives are still compiled --- jax/_src/api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/jax/_src/api.py b/jax/_src/api.py index e9b502c84a24..7af35361d4ab 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -336,6 +336,8 @@ def disable_jit(disable: bool = True): `cond` functions passed to higher-level primitives like :func:`~jax.lax.scan` and :func:`~jax.lax.while_loop`, JIT used in implementations of :mod:`jax.numpy` functions, and any other case where :func:`jit` is used within an API's implementation. + Note however that even under `disable_jit`, individual primitive operations + will still be compiled by XLA as in normal eager op-by-op execution. Values that have a data dependence on the arguments to a jitted function are traced and abstracted. For example, an abstract value may be a