Skip to content

Commit

Permalink
fix mypy error
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Oct 22, 2024
1 parent 587832f commit 8498502
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,21 +1568,21 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding]
'Please see the jax.Array migration guide for more information '
'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. '
f'Got arg shape: {arg.shape}, arg value: {arg}')
if not is_unspecified(arg_s):
if not isinstance(arg_s, UnspecifiedValue):
# jax.jit does not allow resharding across different memory kinds even
# if the argument is uncommitted. Use jax.device_put for those cases,
# either outside or inside jax.jit.
if pjit_in_s.memory_kind != arg_s.memory_kind: # type: ignore
raise ValueError(
'Memory kinds passed to jax.jit does not match memory kind on the'
f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore
f'arg memory kind: {arg_s.memory_kind} for ' # pytype: disable=attribute-error
f'arg memory kind: {arg_s.memory_kind} for '
f'arg shape: {shaped_abstractify(arg).str_short()}')
if (committed and
not isinstance(arg_s, PmapSharding) and
not op_shardings.are_op_shardings_equal(
pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore
arg_s._to_xla_hlo_sharding(arg.ndim))): # type: ignore
arg_s._to_xla_hlo_sharding(arg.ndim))):
raise ValueError('Sharding passed to pjit does not match the sharding '
'on the respective arg. '
f'Got pjit sharding: {pjit_in_s},\n'
Expand Down

0 comments on commit 8498502

Please sign in to comment.