Bug fix in split_index method (#5292)
Bug description: on a dataset of 20 samples, with 4 workers running
8 threads per worker, `split_dataset` would return the following for worker
id `1`:

```
self.worker_splits
[[0, 5], [5, 10], [10, 15], [15, 20]]

self.thread_splits
[[5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 10], [11, 10], [12, 10]]
```

`thread_splits` is wrong and causes a crash in the `DataAnalyzer`: the
end sample id is lower than the start sample id for the last two threads.
This PR fixes that by correcting the behaviour of `split_index`.
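To illustrate, a minimal sketch of the pre-fix partitioning logic (reconstructed from the diff in this commit; `split_index_old` is a hypothetical name for the old implementation) reproduces the bad thread splits when the range is smaller than the number of partitions:

```python
import math

def split_index_old(start_idx, end_idx, num_partitions):
    # Pre-fix logic: the ceil-based partition size overshoots the range
    # when (end_idx - start_idx) < num_partitions, so later partitions
    # start past end_idx while their end is clamped to end_idx.
    partition_size = math.ceil((end_idx - start_idx) / num_partitions)
    return [[start_idx + x * partition_size,
             min(end_idx, start_idx + (x + 1) * partition_size)]
            for x in range(num_partitions)]

# Worker id 1 owns samples [5, 10); splitting that range across 8 threads:
print(split_index_old(5, 10, 8))
# → [[5, 6], [6, 7], [7, 8], [8, 9], [9, 10], [10, 10], [11, 10], [12, 10]]
```

The last two entries have end < start, which is exactly the crash described above.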

---------

Co-authored-by: Logan Adams <[email protected]>
bm-synth and loadams authored Apr 18, 2024
1 parent 3194fe8 commit aaaf8bc
Showing 1 changed file with 2 additions and 5 deletions.
deepspeed/runtime/data_pipeline/data_sampling/utils.py:

```diff
@@ -3,7 +3,6 @@

 # DeepSpeed Team

-import math
 import numpy as np

 from deepspeed.utils import logger
@@ -32,10 +31,8 @@ def find_fit_int_dtype(min_value, max_value):


 def split_index(start_idx, end_idx, num_partitions):
-    partition_size = math.ceil((end_idx - start_idx) / num_partitions)
-    partitions = [[start_idx + x * partition_size,
-                   min(end_idx, start_idx + (x + 1) * partition_size)] for x in range(num_partitions)]
-    return partitions
+    partition_boundaries = np.linspace(start_idx, end_idx, dtype=int, num=num_partitions + 1)
+    return [(partition_boundaries[i], partition_boundaries[i + 1]) for i in range(num_partitions)]


 def split_dataset(dataset, num_workers, worker_id, num_threads):
```
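As a sanity check, a minimal sketch of the fixed `np.linspace`-based `split_index` (matching the added lines above) yields contiguous, monotonic splits for the same failing case:

```python
import numpy as np

def split_index(start_idx, end_idx, num_partitions):
    # Fixed logic: np.linspace produces num_partitions + 1 non-decreasing
    # boundaries covering [start_idx, end_idx], so every partition has
    # end >= start even when the range is smaller than num_partitions
    # (some partitions may simply be empty).
    partition_boundaries = np.linspace(start_idx, end_idx, dtype=int, num=num_partitions + 1)
    return [(partition_boundaries[i], partition_boundaries[i + 1]) for i in range(num_partitions)]

# The case from the bug report: worker id 1's range [5, 10) over 8 threads.
print(split_index(5, 10, 8))
```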
