Getting the following error while re-running the workload from the saved checkpoint.
Note that the error does not occur when I run the command for the first time, or if I rm -rf the my_first_experiment folder.
Traceback (most recent call last):
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 714, in <module>
app.run(main)
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 308, in run
_run_main(main, args)
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/absl/app.py", line 254, in _run_main
sys.exit(main(argv))
^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 682, in main
score = score_submission_on_workload(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 587, in score_submission_on_workload
timing, metrics = train_once(workload, workload_name,
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/submission_runner.py", line 302, in train_once
preemption_count) = checkpoint_utils.maybe_restore_checkpoint(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/Desktop/opensource/isaac_algorithmic_efficiency/algorithmic-efficiency/algorithmic_efficiency/checkpoint_utils.py", line 82, in maybe_restore_checkpoint
latest_ckpt = flax_checkpoints.restore_checkpoint(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/flax/training/checkpoints.py", line 1128, in restore_checkpoint
restored = orbax_checkpointer.restore(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/orbax/checkpoint/checkpointer.py", line 239, in restore
restored = self._handler.restore(directory, args=ckpt_args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 820, in restore
restored_item = asyncio_utils.run_sync(
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/orbax/checkpoint/_src/asyncio_utils.py", line 50, in run_sync
return asyncio.run(coro)
^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/asyncio/runners.py", line 190, in run
return runner.run(main)
^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/asyncio/runners.py", line 118, in run
return self._loop.run_until_complete(task)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/asyncio/base_events.py", line 654, in run_until_complete
return future.result()
^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler.py", line 617, in _maybe_deserialize
deserialized_batches += await asyncio.gather(*deserialized_batches_ops)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/type_handlers.py", line 1442, in deserialize
ret = await asyncio.gather(*deserialize_ops)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/orbax/checkpoint/_src/serialization/serialization.py", line 615, in async_deserialize
new_shard_shape = in_sharding.shard_shape(tuple(shape))
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/isaac/miniconda3/envs/main/lib/python3.11/site-packages/jax/_src/sharding_impls.py", line 648, in shard_shape
raise ValueError(
ValueError: The sharded dimension must be equal to the number of devices passed to PmapSharding. Got sharded dimension 0 with value 128 in shape (128,) and the number of devices=1
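For context on why the restore trips this check: when Orbax re-shards a restored array as pmap-sharded, JAX requires the array's leading (sharded) axis to equal the number of currently visible devices. A checkpoint whose leading axis is 128 therefore cannot be mapped onto a single GPU. The sketch below is a simplified, hypothetical re-implementation of that leading-dimension check (this `shard_shape` is a stand-in written for illustration, not the real `jax.sharding.PmapSharding.shard_shape`):

```python
def shard_shape(global_shape, num_devices):
    """Simplified stand-in for PmapSharding's leading-dimension check.

    A pmap-sharded array splits axis 0 across devices, one row per device,
    so axis 0 must exactly equal the device count.
    """
    if global_shape[0] != num_devices:
        raise ValueError(
            "The sharded dimension must be equal to the number of devices "
            "passed to PmapSharding. Got sharded dimension 0 with value "
            f"{global_shape[0]} in shape {tuple(global_shape)} and the "
            f"number of devices={num_devices}")
    # Each device holds one slice along axis 0, so the per-device
    # shard shape is the global shape with axis 0 removed.
    return tuple(global_shape[1:])


# Mirrors the failure above: a (128,)-shaped value written by the
# checkpointer, restored on a machine where JAX sees only 1 device.
try:
    shard_shape((128,), 1)
except ValueError as e:
    print(e)
```

This suggests the checkpoint metadata baked in a device-dependent layout, so restoring on a machine with a different visible device count fails even though the model itself is unchanged.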
priyakasimbeg changed the title from "ValueError: The sharded dimension must be equal to the number of devices passed to PmapSharding. Got sharded dimension 0 with value 128 in shape (128,) and the number of devices=1" to "Resume from Checkpoint bug: ValueError: The sharded dimension must be equal to the number of devices passed to PmapSharding." on Nov 25, 2024.
System Info:
Ubuntu 20.04
Python 3.11
Nvidia 3080 Ti
Jax Versions:
jax==0.4.35
jax-cuda12-pjrt==0.4.35
jax-cuda12-plugin==0.4.35
jaxlib==0.4.35