Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ramanishsingh committed Aug 22, 2024
1 parent 4d014e5 commit cdc5d31
Showing 1 changed file with 83 additions and 67 deletions.
150 changes: 83 additions & 67 deletions test/stateful_dataloader/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,10 @@ def setUp(self):

def test_initialization_StatefulDistributedSampler(self):

dataset = self.dataset
sampler = StatefulDistributedSampler(dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False)
self.assertEqual(sampler.dataset, dataset)
sampler = StatefulDistributedSampler(
self.dataset, num_replicas=10, rank=0, shuffle=False, seed=42, drop_last=False
)
self.assertEqual(sampler.dataset, self.dataset)
self.assertEqual(sampler.num_replicas, 10)
self.assertEqual(sampler.rank, 0)
self.assertFalse(sampler.shuffle)
Expand All @@ -169,22 +170,45 @@ def test_initialization_StatefulDistributedSampler(self):
self.assertEqual(sampler.yielded, 0)
self.assertIsNone(sampler.next_yielded)

def test_state_dict(self):
def test_dataloader_state_dict(self):
    """Resuming a StatefulDataLoader from a saved state must yield exactly
    the batches an uninterrupted run would have produced after the
    interruption point."""
    stop_after = 5  # batches consumed before checkpointing
    sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
    loader = StatefulDataLoader(self.dataset, batch_size=10, sampler=sampler)

    # Consume the first `stop_after` batches, then snapshot the loader state.
    for batch_idx, _ in enumerate(loader):
        if batch_idx + 1 == stop_after:
            break
    snapshot = loader.state_dict()

    # Fresh sampler/loader pair, restored from the checkpoint.
    restored_sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
    restored_loader = StatefulDataLoader(self.dataset, batch_size=10, sampler=restored_sampler)
    restored_loader.load_state_dict(snapshot)
    resumed_data = [batch.tolist() for batch in restored_loader]

    # Reference: a full, uninterrupted pass over the same data.
    reference_loader = DataLoader(self.dataset, batch_size=10, sampler=sampler)
    expected_data = [batch.tolist() for batch in reference_loader]

    self.assertEqual(resumed_data, expected_data[stop_after:])

def test_sampler_state_dict(self):
    """state_dict() must expose the sampler's progress under the "yielded" key."""
    sampler = StatefulDistributedSampler(self.dataset, num_replicas=10, rank=0)
    sampler.yielded = 5  # simulate partial progress through an epoch
    self.assertEqual(sampler.state_dict()["yielded"], 5)

def test_load_state_dict(self):
def test_sampler_load_state_dict(self):
    """load_state_dict() must stage restored progress in `next_yielded`
    and reject negative values with ValueError."""
    sampler = StatefulDistributedSampler(self.dataset, num_replicas=10, rank=0)
    sampler.load_state_dict({"yielded": 3})
    self.assertEqual(sampler.next_yielded, 3)
    # Negative progress is meaningless and must be refused.
    with self.assertRaises(ValueError):
        sampler.load_state_dict({"yielded": -1})

def test_next_yielded(self):
def test_sampler_next_yielded(self):

sampler = StatefulDistributedSampler(self.dataset, num_replicas=2, rank=0, shuffle=True, seed=42)
iterator = iter(sampler)
Expand All @@ -197,91 +221,83 @@ def test_next_yielded(self):
self.assertEqual(sampler.yielded, 6)

def test_drop_last_effect(self):
    # NOTE(review): this method appears to contain TWO merged variants of the
    # same test (likely a diff/merge artifact): the first half builds
    # DataLoaders and checks exact per-replica lengths; the second half
    # rebuilds the samplers and only checks the weaker <= relation.
    # Confirm which variant is intended and remove the other.
    num_replicas = 3
    total_samples = len(self.dataset)
    # Per-replica share: floor with drop_last=True, ceil with drop_last=False.
    expected_length_with_drop = total_samples // num_replicas
    expected_length_without_drop = math.ceil(total_samples / num_replicas)

    sampler_with_drop = StatefulDistributedSampler(
        self.dataset, num_replicas=3, rank=0, drop_last=True, shuffle=False
    )
    dataloader_with_drop = StatefulDataLoader(self.dataset, sampler=sampler_with_drop)

    sampler_without_drop = StatefulDistributedSampler(
        self.dataset, num_replicas=3, rank=0, drop_last=False, shuffle=False
    )
    dataloader_without_drop = StatefulDataLoader(self.dataset, sampler=sampler_without_drop)

    # Collect all indices from dataloaders
    indices_with_drop = [data for batch in dataloader_with_drop for data in batch]
    indices_without_drop = [data for batch in dataloader_without_drop for data in batch]

    # Check the lengths of the outputs
    self.assertEqual(
        len(indices_with_drop),
        expected_length_with_drop,
        "Length with drop_last=True should match expected truncated length",
    )
    self.assertEqual(
        len(indices_without_drop),
        expected_length_without_drop,
        "Length with drop_last=False should match total dataset size",
    )

    # Second variant: re-create the samplers and compare raw index counts only.
    sampler_with_drop = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=True)
    sampler_without_drop = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=False)
    indices_with_drop = list(iter(sampler_with_drop))
    indices_without_drop = list(iter(sampler_without_drop))
    self.assertTrue(
        len(indices_with_drop) <= len(indices_without_drop), "Drop last should result in fewer or equal indices"
    )

def test_data_order_with_shuffle(self):

sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=True, seed=42)
sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=True)
indices = list(iter(sampler))
data_sampled = [self.dataset[i] for i in indices]
self.assertNotEqual(data_sampled, list(range(100)), "Data should be shuffled")

def test_data_order_without_shuffle(self):
dataloader = StatefulDataLoader(self.dataset, sampler=sampler)
data_loaded = []
for batch in dataloader:
data_loaded.extend(batch)
self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")

def test_data_order_without_shuffle(self):
sampler = StatefulDistributedSampler(self.dataset, num_replicas=1, rank=0, shuffle=False)
indices = list(iter(sampler))
data_sampled = [self.dataset[i] for i in indices]
self.assertEqual(data_sampled, list(range(100)), "Data should be in sequential order when shuffle is False")
self.assertEqual(data_sampled, list(range(100)), "Data should not be shuffled")

def test_data_distribution_across_replicas(self):
batch_size = 32
dataloader = torch.utils.data.DataLoader(self.dataset, batch_size=batch_size, sampler=sampler)
data_loaded = []
for batch in dataloader:
data_loaded.extend(batch)
self.assertEqual(len(data_loaded), len(self.dataset), "All data should be loaded")
self.assertEqual(data_loaded, data_sampled, "Data loaded by DataLoader should match data sampled by sampler")
self.assertEqual(data_loaded, list(range(100)), "Data loaded by DataLoader should be in original order")

def test_data_distribution_across_replicas(self):
num_replicas = 5
all_data = []
for rank in range(num_replicas):
sampler = StatefulDistributedSampler(self.dataset, num_replicas=num_replicas, rank=rank, shuffle=False)
indices = list(iter(sampler))
data_sampled = [int(self.dataset[i].item()) for i in indices]
all_data.extend(data_sampled)
dataloader = torch.utils.data.DataLoader(self.dataset, sampler=sampler)
data_loaded = []
for batch in dataloader:
data_loaded.extend([int(x.item()) for x in batch])
all_data.extend(data_loaded)
self.assertEqual(
sorted(all_data), list(range(100)), "All data points should be covered exactly once across all replicas"
)

def test_consistency_across_epochs(self):
    """set_epoch() must reseed the shuffle so that consecutive epochs see
    a different data order on the same replica."""
    replicas, replica_rank = 3, 1
    sampler = StatefulDistributedSampler(
        self.dataset, num_replicas=replicas, rank=replica_rank, shuffle=True, seed=42
    )
    first_epoch_data = [self.dataset[i] for i in sampler]
    sampler.set_epoch(1)  # advance to the next epoch
    second_epoch_data = [self.dataset[i] for i in sampler]
    self.assertNotEqual(
        first_epoch_data, second_epoch_data, "Data order should change with different epochs due to shuffling"
    )

def test_no_data_loss_with_drop_last(self):
    """With drop_last=True each of the 3 replicas receives exactly an equal
    share of the largest multiple of 3 that fits in the dataset."""
    sampler = StatefulDistributedSampler(self.dataset, num_replicas=3, rank=0, drop_last=True)
    produced = list(iter(sampler))
    # (N // 3) * 3 truncates N to a multiple of 3; dividing by 3 again gives
    # the per-replica share.
    expected_length = (len(self.dataset) // 3) * 3 // 3
    self.assertEqual(
        len(produced), expected_length, "Length of indices should match expected length with drop_last=True"
    )

def test_state_dict_end_to_end(self):
    """Checkpoint the sampler mid-epoch, restore it into a fresh
    sampler/DataLoader pair, and verify the run resumes exactly where the
    first one stopped."""
    dataset = MockDataset(100)
    sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
    dataloader = StatefulDataLoader(dataset, batch_size=10, sampler=sampler)

    # Consume the first `iter_count` batches, then capture the sampler state.
    iter_count = 5
    for batch_idx, _ in enumerate(dataloader):
        if batch_idx == iter_count - 1:
            break
    saved_state = sampler.state_dict()

    # Restore the state into a brand-new sampler/loader pair.
    new_sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
    new_sampler.load_state_dict(saved_state)
    new_dataloader = StatefulDataLoader(dataset, batch_size=10, sampler=new_sampler)
    resumed_data = [batch.tolist() for batch in new_dataloader]

    # What a full, uninterrupted pass would have produced.
    full_dataloader = DataLoader(dataset, batch_size=10, sampler=sampler)
    expected_data = [batch.tolist() for batch in full_dataloader]

    # Everything after the interruption point must match.
    self.assertEqual(resumed_data, expected_data[iter_count:])


# Standard entry point: dispatch to torch's test runner when executed directly.
if __name__ == "__main__":
    run_tests()

0 comments on commit cdc5d31

Please sign in to comment.