Skip to content

Commit

Permalink
[JAX] Replace jax.core.call_p and jax.core.closed_call_p with `ja…
Browse files Browse the repository at this point in the history
…x.extend.core.primitives.call_p` and `jax.extend.core.primitives.closed_call_p`.

PiperOrigin-RevId: 705573507
  • Loading branch information
Jake VanderPlas authored and The oryx Authors committed Dec 13, 2024
1 parent 4ae0aaa commit 2492771
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 19 deletions.
5 changes: 3 additions & 2 deletions oryx/core/interpreters/log_prob_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from jax import random
from jax._src import api_util
from jax._src import core as jax_core
from jax.extend.core import primitives
from jax.extend import linear_util as lu
import jax.numpy as jnp

Expand Down Expand Up @@ -72,13 +73,13 @@ def wrapped(*args, **kwargs):
fun = lu.wrap_init(f, kwargs)
flat_args, in_tree = jax.tree_util.tree_flatten(args)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree)
ans = jax_core.call_p.bind(flat_fun, *flat_args)
ans = primitives.call_p.bind(flat_fun, *flat_args)
return jax.tree_util.tree_unflatten(out_tree(), ans)

return wrapped


jax_core.call_p.call_primitive = True
primitives.call_p.call_primitive = True


class LogProbTest(test_util.TestCase):
Expand Down
5 changes: 3 additions & 2 deletions oryx/core/interpreters/propagate.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from jax._src import core as jax_core
from jax._src import pjit
from jax._src import sharding_impls
from jax.extend.core import primitives
from jax.extend import linear_util as lu
from jax.interpreters import partial_eval as pe

Expand Down Expand Up @@ -356,8 +357,8 @@ def call_rule(prim, incells, outcells, **params):


default_call_rules = {}
default_call_rules[jax_core.call_p] = functools.partial(call_rule,
jax_core.call_p)
default_call_rules[primitives.call_p] = functools.partial(call_rule,
primitives.call_p)
default_call_rules[harvest.nest_p] = functools.partial(call_rule,
harvest.nest_p)

Expand Down
34 changes: 19 additions & 15 deletions oryx/experimental/matching/jax_rewrite_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@
"""Tests for oryx.experimental.matching.jax_rewrite."""

from absl.testing import absltest

import jax
from jax import lax
from jax.extend.core import primitives
import jax.numpy as jnp

from oryx.experimental.matching import jax_rewrite as jr
from oryx.experimental.matching import matcher
from oryx.experimental.matching import rules
Expand Down Expand Up @@ -79,22 +78,24 @@ def test_primitive_should_infer_shape_dtype_correctly(self):

def test_call_primitive_should_include_call_in_trace(self):
exp_expr = Exp(jr.Literal(0.))
call_expr = jr.CallPrimitive(jax.core.call_p, (), (exp_expr,), jr.Params(),
[])
call_expr = jr.CallPrimitive(
primitives.call_p, (), (exp_expr,), jr.Params(), []
)
jaxpr = jax.make_jaxpr(lambda: jr.evaluate(call_expr, {}))()
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, jax.core.call_p)
self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, primitives.call_p)

def test_call_primitive_shape_and_dtype_are_multi_part(self):
exp_expr = Exp(jr.Literal(0.))
call_expr = jr.CallPrimitive(jax.core.call_p, (), (exp_expr,), jr.Params(),
[])
call_expr = jr.CallPrimitive(
primitives.call_p, (), (exp_expr,), jr.Params(), []
)
self.assertTupleEqual(call_expr.shape, ((),))
self.assertEqual(call_expr.dtype, (jnp.float32,))

def test_part_infers_correct_shape_dtype(self):
call_expr = jr.CallPrimitive(jax.core.call_p, (),
(jr.Literal(0.), jr.Literal(1)), jr.Params(),
[])
call_expr = jr.CallPrimitive(
primitives.call_p, (), (jr.Literal(0.0), jr.Literal(1)), jr.Params(), []
)
p0_expr = jr.Part(call_expr, 0)
p1_expr = jr.Part(call_expr, 1)
self.assertTupleEqual(p0_expr.shape, ())
Expand Down Expand Up @@ -151,16 +152,19 @@ def test_can_match_call_primitive_parts(self):
pattern = jr.CallPrimitive(
matcher.Var('prim'), matcher.Var('args'), matcher.Var('expression'),
matcher.Var('params'), matcher.Var('names'))
expr = jr.CallPrimitive(jax.core.call_p, (),
(jr.Literal(0.), jr.Literal(1)), jr.Params(), [])
expr = jr.CallPrimitive(
primitives.call_p, (), (jr.Literal(0.0), jr.Literal(1)), jr.Params(), []
)
self.assertDictEqual(
matcher.match(pattern, expr),
dict(
prim=jax.core.call_p,
prim=primitives.call_p,
args=(),
expression=(jr.Literal(0.), jr.Literal(1.)),
expression=(jr.Literal(0.0), jr.Literal(1.0)),
params=jr.Params(),
names=[]))
names=[],
),
)


class RewriteTest(test_util.TestCase):
Expand Down

0 comments on commit 2492771

Please sign in to comment.