Skip to content

Commit

Permalink
Export jax.lax.sharding_constraint_p
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 534566582
  • Loading branch information
Jake VanderPlas authored and jax authors committed May 23, 2023
1 parent db87167 commit 7f7f995
Showing 1 changed file with 1 addition and 0 deletions.
1 change: 1 addition & 0 deletions jax/lax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,7 @@
from jax.lax import linalg as linalg

from jax._src.pjit import with_sharding_constraint as with_sharding_constraint
from jax._src.pjit import sharding_constraint_p as sharding_constraint_p
from jax._src.dispatch import device_put_p as device_put_p

from math import prod # TODO(phawkins): remove this accidental export

0 comments on commit 7f7f995

Please sign in to comment.