From dd4a9d67799127a503d452a842934067e1ec630e Mon Sep 17 00:00:00 2001 From: Dibya Ghosh Date: Wed, 13 Dec 2023 21:48:04 -0800 Subject: [PATCH 1/2] Fixing finetune.py for arbitrary # GPUs --- scripts/finetune.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/scripts/finetune.py b/scripts/finetune.py index 463375ce..1d50c5c1 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -82,6 +82,13 @@ def main(_): # ######### + assert ( + FLAGS.config.batch_size % len(devices) == 0 + ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})" + assert ( + FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0, + ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})" + # create a 1D mesh with a single axis named "batch" mesh = Mesh(jax.devices(), axis_names="batch") # Our batches will be data-parallel sharded -- each device will get a slice of the batch From e595cd9baeb84ff42c828d960ebaeae4b404c3b8 Mon Sep 17 00:00:00 2001 From: Kevin Black <12429600+kvablack@users.noreply.github.com> Date: Wed, 13 Dec 2023 22:35:30 -0800 Subject: [PATCH 2/2] Fix typo --- scripts/finetune.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 1d50c5c1..fbe554a2 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -86,7 +86,7 @@ def main(_): FLAGS.config.batch_size % len(devices) == 0 ), f"Batch size ({FLAGS.config.batch_size}) must be divisible by the number of devices ({len(devices)})" assert ( - FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0, + FLAGS.config.viz_kwargs.eval_batch_size % len(devices) == 0 ), f"Eval batch size ({FLAGS.config.viz_kwargs.eval_batch_size}) must be divisible by the number of devices ({len(devices)})" # create a 1D mesh with a single axis named "batch"