From 849850216d4e0df9205f6b269f6c2ac855ac5658 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 22 Oct 2024 11:10:10 -0700 Subject: [PATCH] fix mypy error --- jax/_src/pjit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index 12e310904645..a69e8987b2d8 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1568,7 +1568,7 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] 'Please see the jax.Array migration guide for more information ' 'https://jax.readthedocs.io/en/latest/jax_array_migration.html#handling-of-host-local-inputs-to-pjit-like-batch-etc. ' f'Got arg shape: {arg.shape}, arg value: {arg}') - if not is_unspecified(arg_s): + if not isinstance(arg_s, UnspecifiedValue): # jax.jit does not allow resharding across different memory kinds even # if the argument is uncommitted. Use jax.device_put for those cases, # either outside or inside jax.jit. @@ -1576,13 +1576,13 @@ def _resolve_in_shardings(args, pjit_in_shardings: Sequence[PjitSharding] raise ValueError( 'Memory kinds passed to jax.jit does not match memory kind on the' f' respective arg. Got pjit memory kind: {pjit_in_s.memory_kind}, ' # type: ignore - f'arg memory kind: {arg_s.memory_kind} for ' # pytype: disable=attribute-error + f'arg memory kind: {arg_s.memory_kind} for ' f'arg shape: {shaped_abstractify(arg).str_short()}') if (committed and not isinstance(arg_s, PmapSharding) and not op_shardings.are_op_shardings_equal( pjit_in_s._to_xla_hlo_sharding(arg.ndim), # type: ignore - arg_s._to_xla_hlo_sharding(arg.ndim))): # type: ignore + arg_s._to_xla_hlo_sharding(arg.ndim))): raise ValueError('Sharding passed to pjit does not match the sharding ' 'on the respective arg. ' f'Got pjit sharding: {pjit_in_s},\n'