Commit
distributed merge of per-rank Megatron data files (#55)
* add parallel merge using mpi
* handle case where some ranks might have 0 items
* add inclusive scan prefix sum
* report more timing info
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* rename total size variable for clarity
* move translation to bin/idx file names a level deeper
* parallel merge for cached dataset
* add alltrue function
* move collectives to new distdata class, add torch.distributed
* drop unused prefix_sum function
* allow ranks to pass a list of files to be merged
* check that input dataset files exist
* fix: using wrong doc_idx list for mmap
* move init dist and collectives to distdata class
* add --merge option, move parallel/serial to their own functions
* Update megatron/data/distdata.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* Update megatron/data/indexed_dataset.py
  Co-authored-by: Thomas Wang <[email protected]>
* drop extraneous numpy tolist calls
* rename self.MPI to mpi4py
* handle case where no ranks have elements in their file
* rename tokenize_start to time_start
* drop unrelated comment in distdata.min
* add comment why pointers_shift is not None and add assert
* note why pointers uses sizes count and offset values
* can just rely on rank 0 for the leading 0 element
* add write_list function
* determine element size
* add checks for consistent element_size values
* check that at least one rank has a file to merge
* assert that torch backend is gloo or mpi
* add collectives for assert and raise
* rename to allassert and allraise_if
* check dtype instead of element_size
* add uint32 to element_sizes table
* infer dtype from files being merged
* add write_header function to indexed dataset classes
* call write_header internally from IndexedDataset classes
* return number of bytes written from write calls
* move scatterv to distdata class
* add functions to format status and error messages
* defer merge_files_dist to future PR
* open files using with, refresh comments
* rely on default torch datatypes
* fix some status messages from preprocess script
* fix: exclusive scan computing pointers list
* fix: exclusive scan to compute mmap pointers list
* note about seek
* rename preprocess_dataset_mpi.py to preprocess_data_dist.py
* update usage comments at top of script
* restore commented print_rank_0 statements
* restore status message in mmap merge_file_
* drop mpi4py, sad :(
* add test case for parallel merge
* add preprocess_data_dist test for serial merge
* improve error handling
* refactor get_pointers code
* bug fix in exscan
* further refactor get_pointers
* move exscan collective for pointers outside of try block
* clarify some comments
* include string 1k in name of test files
* use temporary file for index
* fix: implement scatterv from torch.distributed.scatter
* switch to pad method in torch.nn.functional
* return data received in scatterv as new tensor
* raise exception if conflicting scratch and merge options
* use allraise method from distdata in preprocess_data_dist
Co-authored-by: Thomas Wang <[email protected]>
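
Several of the items above ("add inclusive scan prefix sum", "bug fix in exscan", "fix: exclusive scan computing pointers list", "can just rely on rank 0 for the leading 0 element") revolve around an exclusive prefix sum that gives each rank its write offset into the merged .bin/.idx files. The sketch below illustrates the idea on torch.distributed (gloo or mpi backend, as the commit requires a process group to already be initialized); the function name `exscan_offset` is illustrative and not taken from the repository's distdata class.

```python
# Minimal sketch, assuming torch.distributed is already initialized.
import torch
import torch.distributed as dist

def exscan_offset(local_bytes: int) -> int:
    """Return the sum of local_bytes over all lower-numbered ranks (exclusive scan)."""
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Each rank contributes its byte count at its own slot, then a sum
    # all-reduce gives every rank the full vector of counts.
    counts = torch.zeros(world_size, dtype=torch.int64)
    counts[rank] = local_bytes
    dist.all_reduce(counts, op=dist.ReduceOp.SUM)

    # Exclusive prefix sum: offsets of all ranks before this one.
    return int(counts[:rank].sum())
```

With such an offset each rank can seek to its position in the shared output file and write its shard independently; per the commit notes, only rank 0 needs to contribute the leading 0 element of the pointers list.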
- Loading branch information
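
The final items ("fix: implement scatterv from torch.distributed.scatter", "switch to pad method in torch.nn.functional", "return data received in scatterv as new tensor") describe building a variable-length scatter on top of torch.distributed.scatter, which only handles equal-sized chunks. Below is a minimal sketch of that approach; the signature, the `scatterv` name, and the int64 dtype are assumptions for illustration, not the actual distdata API.

```python
# Minimal sketch: variable-length scatter ("scatterv") built from
# torch.distributed.scatter by padding chunks to a common length.
from typing import List, Optional
import torch
import torch.distributed as dist
import torch.nn.functional as F

def scatterv(data: Optional[torch.Tensor], counts: List[int], root: int = 0) -> torch.Tensor:
    """Scatter counts[i] elements of the 1-D int64 tensor `data` (on root) to rank i."""
    rank = dist.get_rank()
    maxcount = max(counts)

    # Receive buffer sized for the largest chunk; dtype assumed int64.
    recv = torch.zeros(maxcount, dtype=torch.int64)

    if rank == root:
        # Split the source by the per-rank counts (which must sum to data.numel()),
        # then pad every chunk to maxcount so scatter sees equal-sized pieces.
        chunks = list(torch.split(data, counts))
        chunks = [F.pad(c, (0, maxcount - c.numel())) for c in chunks]
        dist.scatter(recv, scatter_list=chunks, src=root)
    else:
        dist.scatter(recv, src=root)

    # Drop the padding and return the received data as a new tensor.
    return recv[:counts[rank]].clone()
```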