When does recreating a model trigger nnx.jit recompilation? #4474
Unanswered
NiklasKappel asked this question in Q&A
Replies: 2 comments 2 replies
-
Further testing reveals the culprit is the use of …

```python
from functools import partial  # only needed for the commented-out variant

from flax import nnx


class CNN(nnx.Module):
    def __init__(self, *, rngs: nnx.Rngs):
        self.conv1 = nnx.Conv(1, 32, kernel_size=(3, 3), rngs=rngs)
        self.conv2 = nnx.Conv(32, 64, kernel_size=(3, 3), rngs=rngs)
        # self.avg_pool = partial(nnx.avg_pool, window_shape=(2, 2), strides=(2, 2))
        self.avg_pool = nnx.avg_pool
        self.linear1 = nnx.Linear(3136, 256, rngs=rngs)
        self.linear2 = nnx.Linear(256, 10, rngs=rngs)

    def __call__(self, x):
        # Pooling parameters are passed at call time instead of being bound in a partial.
        x = self.avg_pool(nnx.relu(self.conv1(x)), window_shape=(2, 2), strides=(2, 2))
        x = self.avg_pool(nnx.relu(self.conv2(x)), window_shape=(2, 2), strides=(2, 2))
        x = x.reshape(x.shape[0], -1)  # flatten
        x = nnx.relu(self.linear1(x))
        x = self.linear2(x)
        return x
```

My guess is that …
-
Maybe the behavior is different with `jax.tree_util.Partial`? Have you tried?
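For context on this suggestion: `jax.tree_util.Partial` is registered as a pytree in which the wrapped function is static metadata and the bound arguments are leaves, so two separately created `Partial` objects around the same function have the same tree structure and can hit the same `jax.jit` cache entry. The sketch below uses plain `jax.jit`, not `nnx.jit`, and `scale`/`apply` are hypothetical names; whether `nnx.jit` handles a `Partial` stored on a module the same way is not confirmed here. Because bound arguments become traced values under `jax.jit`, this approach likely does not suit arguments such as `window_shape` that must remain static Python values.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import Partial

trace_count = 0


def scale(factor, x):
    # Hypothetical stand-in for the wrapped function.
    return factor * x


@jax.jit
def apply(fn, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only while JAX traces `apply`
    return fn(x)


x = jnp.ones(3)
p1 = Partial(scale, 2.0)
p2 = Partial(scale, 2.0)  # a new object wrapping the same function

# The function is static tree metadata, so both objects share one structure.
print(jax.tree_util.tree_structure(p1) == jax.tree_util.tree_structure(p2))  # True
apply(p1, x)
apply(p2, x)  # same structure, leaf shapes and dtypes: cached, no retrace
print(trace_count)  # 1
```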
-
Consider this MWE that uses the model from the MNIST tutorial:
Note that, when running `try_CNN`, the `foo_step` function is JIT-compiled twice (i.e., the `print` side effect is triggered twice): once with the first CNN instance and once with the second, even though they are essentially the same. When running `try_Foo`, though, the `foo_step` function is compiled only once. The latter is the behavior I would expect from `jax.jit`, considering that recreating a model does not change anything about the shapes of the parameter arrays, etc.

What is the reason for the extra compilation happening in `try_CNN`?
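For reference, the expectation stated above matches how plain `jax.jit` caches: it traces once per combination of input tree structure, shapes, and dtypes (plus any static arguments), and Python side effects such as `print` run only during tracing. The sketch below uses a dict of arrays as a hypothetical stand-in for model parameters; `step` and `make_params` are illustrative names, not part of the original MWE.

```python
import jax
import jax.numpy as jnp

trace_count = 0


@jax.jit
def step(params, x):
    global trace_count
    trace_count += 1  # like the print in the question: runs only while tracing
    return x @ params["w"] + params["b"]


def make_params(seed):
    # Hypothetical "recreated model": new values, identical shapes and dtypes.
    key = jax.random.PRNGKey(seed)
    return {"w": jax.random.normal(key, (4, 2)), "b": jnp.zeros((2,))}


x = jnp.ones((3, 4))
step(make_params(0), x)
step(make_params(1), x)  # same structure, shapes and dtypes: no retrace
print(trace_count)  # 1

step(make_params(2), jnp.ones((5, 4)))  # new input shape: traced again
print(trace_count)  # 2
```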