From e851405560f4170c70e07343da78ec56e96f8630 Mon Sep 17 00:00:00 2001
From: TL Gao
Date: Mon, 30 Sep 2024 02:33:31 +0800
Subject: [PATCH] [Fix] Fix conflict of downloading dataset

The tutorial snippet constructed the training set with download=True
in code that runs once per rank, so several processes could try to
fetch and write the MNIST files under '../data' at the same time.

Let rank 0 alone trigger the download, make every rank wait on
dist.barrier() until it has finished, and only then build dataset1 and
dataset2 from the files on disk. Both datasets are still created on
every rank because the DistributedSampler calls that follow use them.
---
Note: the post-image blob id on the "index" line below predates the
re-added dataset1/dataset2 lines; regenerate the patch with
git format-patch if exact blob ids matter.

 intermediate_source/FSDP_tutorial.rst | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/intermediate_source/FSDP_tutorial.rst b/intermediate_source/FSDP_tutorial.rst
index 9b9845667f..10a141675d 100644
--- a/intermediate_source/FSDP_tutorial.rst
+++ b/intermediate_source/FSDP_tutorial.rst
@@ -208,10 +208,15 @@ We add the following code snippets to a python script “FSDP_mnist.py”.
             transforms.Normalize((0.1307,), (0.3081,))
         ])
 
-        dataset1 = datasets.MNIST('../data', train=True, download=True,
-                            transform=transform)
-        dataset2 = datasets.MNIST('../data', train=False,
-                            transform=transform)
+        # Download the dataset on rank 0 only (if it does not exist yet); the other ranks wait at the barrier
+        dataset_dir = '../data'
+        if rank == 0:
+            print(f"Preparing MNIST dataset on {rank=} ...")
+            datasets.MNIST(dataset_dir, train=True, download=True)
+        dist.barrier()
+
+        dataset1 = datasets.MNIST(dataset_dir, train=True, transform=transform)
+        dataset2 = datasets.MNIST(dataset_dir, train=False, transform=transform)
 
         sampler1 = DistributedSampler(dataset1, rank=rank, num_replicas=world_size, shuffle=True)
         sampler2 = DistributedSampler(dataset2, rank=rank, num_replicas=world_size)