From b198cff1408fdb19d6ea642498e4f9e77db118e9 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Fri, 8 Dec 2023 14:42:15 +0000 Subject: [PATCH] add test_scan_negative_axes --- tests/linen/linen_transforms_test.py | 29 ++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index e22ef97a28..e7cf13326e 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -373,6 +373,35 @@ def __call__(self, c, b, xs): np.testing.assert_allclose(c[0], c2[0], atol=1e-7) np.testing.assert_allclose(c[1], c2[1], atol=1e-7) + def test_scan_negative_axes(self): + class Foo(nn.Module): + @nn.compact + def __call__(self, _, x): + x = nn.Dense(4)(x) + return None, x + + class Bar(nn.Module): + @nn.compact + def __call__(self, x): + _, x = nn.scan( + Foo, + variable_broadcast='params', + split_rngs=dict(params=False), + in_axes=1, + out_axes=-1, + )()(None, x) + return x + + y, variables = Bar().init_with_output( + {'params': jax.random.PRNGKey(0)}, + jax.random.normal(jax.random.PRNGKey(1), shape=[1, 2, 3]), + ) + params = variables['params'] + + self.assertEqual(y.shape, (1, 4, 2)) + self.assertEqual(params['ScanFoo_0']['Dense_0']['kernel'].shape, (3, 4)) + self.assertEqual(params['ScanFoo_0']['Dense_0']['bias'].shape, (4,)) + def test_multiscope_lifting_simple(self): class Counter(nn.Module): @nn.compact