diff --git a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb index b49836d9d5..25031497a5 100644 --- a/examples/getting_started/tf/nvflare_tf_getting_started.ipynb +++ b/examples/getting_started/tf/nvflare_tf_getting_started.ipynb @@ -314,9 +314,12 @@ "metadata": {}, "outputs": [], "source": [ + "from nvflare.client.config import ExchangeFormat\n", + "\n", "for i in range(n_clients):\n", " executor = ScriptExecutor(\n", " task_script_path=\"src/cifar10_tf_fl.py\", task_script_args=\"\" # f\"--batch_size 32 --data_path /tmp/data/site-{i}\"\n", + " params_exchange_format=ExchangeFormat.NUMPY,\n", " )\n", " job.to(executor, f\"site-{i}\", gpu=0)" ] diff --git a/examples/getting_started/tf/tf_fedavg_script_executor_cifar10.py b/examples/getting_started/tf/tf_fedavg_script_executor_cifar10.py index f6603b5f95..a7972d8a58 100644 --- a/examples/getting_started/tf/tf_fedavg_script_executor_cifar10.py +++ b/examples/getting_started/tf/tf_fedavg_script_executor_cifar10.py @@ -15,6 +15,7 @@ from src.tf_net import TFNet from nvflare import FedAvg, FedJob, ScriptExecutor +from nvflare.client.config import ExchangeFormat if __name__ == "__main__": n_clients = 2 @@ -36,7 +37,9 @@ # Add clients for i in range(n_clients): executor = ScriptExecutor( - task_script_path=train_script, task_script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" + task_script_path=train_script, + task_script_args="", # f"--batch_size 32 --data_path /tmp/data/site-{i}" + params_exchange_format=ExchangeFormat.NUMPY, ) job.to(executor, f"site-{i}", gpu=0)