diff --git a/oryx/core/interpreters/log_prob_test.py b/oryx/core/interpreters/log_prob_test.py index aa18fb9..1d34211 100644 --- a/oryx/core/interpreters/log_prob_test.py +++ b/oryx/core/interpreters/log_prob_test.py @@ -18,6 +18,7 @@ from jax import random from jax._src import api_util from jax._src import core as jax_core +from jax.extend.core import primitives from jax.extend import linear_util as lu import jax.numpy as jnp @@ -72,13 +73,13 @@ def wrapped(*args, **kwargs): fun = lu.wrap_init(f, kwargs) flat_args, in_tree = jax.tree_util.tree_flatten(args) flat_fun, out_tree = api_util.flatten_fun_nokwargs(fun, in_tree) - ans = jax_core.call_p.bind(flat_fun, *flat_args) + ans = primitives.call_p.bind(flat_fun, *flat_args) return jax.tree_util.tree_unflatten(out_tree(), ans) return wrapped -jax_core.call_p.call_primitive = True +primitives.call_p.call_primitive = True class LogProbTest(test_util.TestCase): diff --git a/oryx/core/interpreters/propagate.py b/oryx/core/interpreters/propagate.py index 7b0a84b..f1893f4 100644 --- a/oryx/core/interpreters/propagate.py +++ b/oryx/core/interpreters/propagate.py @@ -36,6 +36,7 @@ from jax._src import core as jax_core from jax._src import pjit from jax._src import sharding_impls +from jax.extend.core import primitives from jax.extend import linear_util as lu from jax.interpreters import partial_eval as pe @@ -356,8 +357,8 @@ def call_rule(prim, incells, outcells, **params): default_call_rules = {} -default_call_rules[jax_core.call_p] = functools.partial(call_rule, - jax_core.call_p) +default_call_rules[primitives.call_p] = functools.partial(call_rule, + primitives.call_p) default_call_rules[harvest.nest_p] = functools.partial(call_rule, harvest.nest_p) diff --git a/oryx/experimental/matching/jax_rewrite_test.py b/oryx/experimental/matching/jax_rewrite_test.py index bebe5a0..e6af893 100644 --- a/oryx/experimental/matching/jax_rewrite_test.py +++ b/oryx/experimental/matching/jax_rewrite_test.py @@ -15,11 +15,10 @@ """Tests for oryx.experimental.matching.jax_rewrite.""" from absl.testing import absltest - import jax from jax import lax +from jax.extend.core import primitives import jax.numpy as jnp - from oryx.experimental.matching import jax_rewrite as jr from oryx.experimental.matching import matcher from oryx.experimental.matching import rules @@ -79,22 +78,24 @@ def test_primitive_should_infer_shape_dtype_correctly(self): def test_call_primitive_should_include_call_in_trace(self): exp_expr = Exp(jr.Literal(0.)) - call_expr = jr.CallPrimitive(jax.core.call_p, (), (exp_expr,), jr.Params(), - []) + call_expr = jr.CallPrimitive( + primitives.call_p, (), (exp_expr,), jr.Params(), [] + ) jaxpr = jax.make_jaxpr(lambda: jr.evaluate(call_expr, {}))() - self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, jax.core.call_p) + self.assertEqual(jaxpr.jaxpr.eqns[0].primitive, primitives.call_p) def test_call_primitive_shape_and_dtype_are_multi_part(self): exp_expr = Exp(jr.Literal(0.)) - call_expr = jr.CallPrimitive(jax.core.call_p, (), (exp_expr,), jr.Params(), - []) + call_expr = jr.CallPrimitive( + primitives.call_p, (), (exp_expr,), jr.Params(), [] + ) self.assertTupleEqual(call_expr.shape, ((),)) self.assertEqual(call_expr.dtype, (jnp.float32,)) def test_part_infers_correct_shape_dtype(self): - call_expr = jr.CallPrimitive(jax.core.call_p, (), - (jr.Literal(0.), jr.Literal(1)), jr.Params(), - []) + call_expr = jr.CallPrimitive( + primitives.call_p, (), (jr.Literal(0.0), jr.Literal(1)), jr.Params(), [] + ) p0_expr = jr.Part(call_expr, 0) p1_expr = jr.Part(call_expr, 1) self.assertTupleEqual(p0_expr.shape, ()) @@ -151,16 +152,19 @@ def test_can_match_call_primitive_parts(self): pattern = jr.CallPrimitive( matcher.Var('prim'), matcher.Var('args'), matcher.Var('expression'), matcher.Var('params'), matcher.Var('names')) - expr = jr.CallPrimitive(jax.core.call_p, (), - (jr.Literal(0.), jr.Literal(1)), jr.Params(), []) + expr = jr.CallPrimitive( + primitives.call_p, (), (jr.Literal(0.0), jr.Literal(1)), jr.Params(), [] + ) self.assertDictEqual( matcher.match(pattern, expr), dict( - prim=jax.core.call_p, + prim=primitives.call_p, args=(), - expression=(jr.Literal(0.), jr.Literal(1.)), + expression=(jr.Literal(0.0), jr.Literal(1.0)), params=jr.Params(), - names=[])) + names=[], + ), + ) class RewriteTest(test_util.TestCase):