diff --git a/src/scanpy/readwrite.py b/src/scanpy/readwrite.py
index 3c958a1e5..85a19214a 100644
--- a/src/scanpy/readwrite.py
+++ b/src/scanpy/readwrite.py
@@ -5,7 +5,7 @@
 import json
 from pathlib import Path, PurePath
 from typing import TYPE_CHECKING
-
+import anndata
 import anndata.utils
 import h5py
 import numpy as np
@@ -33,6 +33,10 @@
     read_text,
 )
 from anndata import AnnData
+import multiprocessing as mp
+import time
+import numba
+import scipy
 from matplotlib.image import imread
 
 from . import logging as logg
@@ -45,6 +49,14 @@
 
     from ._utils import Empty
 
+indices_type = np.int64
+# "q" is a signed 64-bit integer on every platform; "l" (C long) is only
+# 32 bits wide on Windows and would not match ``indices_type`` below.
+indices_shm_type = "q"
+
+semDataLoaded = None  # will be initialized later
+semDataCopied = None  # will be initialized later
+
 # .gz and .bz2 suffixes are also allowed for text formats
 text_exts = {
     "csv",
@@ -65,6 +77,207 @@
 } | text_exts
 """Available file formats for reading data. """
 
+
+def _load_helper(
+    fname,
+    i,
+    k,
+    datalen,
+    dataArray,
+    indicesArray,
+    startsArray,
+    endsArray,
+    loaded=None,
+    copied=None,
+):
+    """Worker process body: stream slice ``i`` of ``X`` through shared memory.
+
+    The ``data`` and ``indices`` datasets of the ``X`` group in ``fname`` are
+    split into ``k`` contiguous slices. This worker reads slice ``i`` in pieces
+    of at most ``1024 * 1024`` elements into staging slot ``i`` of
+    ``dataArray``/``indicesArray``, publishes the file range of each piece in
+    ``startsArray[i]``/``endsArray[i]``, releases ``loaded[i]`` and then blocks
+    on ``copied[i]`` before the slot is overwritten with the next piece.
+
+    ``loaded``/``copied`` default to the module-level ``semDataLoaded`` and
+    ``semDataCopied``; passing them explicitly keeps the worker usable when the
+    child process does not inherit this module's globals.
+    """
+    loaded = semDataLoaded if loaded is None else loaded
+    copied = semDataCopied if copied is None else copied
+    dataA = np.frombuffer(dataArray, dtype=np.float32)
+    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
+    startsA = np.frombuffer(startsArray, dtype=np.int64)
+    endsA = np.frombuffer(endsArray, dtype=np.int64)
+    # Context manager so the HDF5 handle is closed even if a read fails.
+    with h5py.File(fname, "r") as f:
+        for j in range(datalen // (k * 1024 * 1024) + 1):
+            # compute start, end
+            s = i * datalen // k + j * 1024 * 1024
+            e = min(s + 1024 * 1024, (i + 1) * datalen // k)
+            length = e - s
+            startsA[i] = s
+            endsA[i] = e
+            # The last piece can be empty; skip the read but keep the
+            # handshake so loader and copier run the same number of rounds.
+            if length > 0:
+                slot = np.s_[i * 1024 * 1024 : i * 1024 * 1024 + length]
+                f["X"]["data"].read_direct(dataA, np.s_[s:e], slot)
+                f["X"]["indices"].read_direct(indicesA, np.s_[s:e], slot)
+
+            # coordinate with copy threads
+            loaded[i].release()  # done data load
+            copied[i].acquire()  # wait until data copied
+
+
+def _waitload(i):
+    """Block until loader ``i`` has released ``semDataLoaded[i]``."""
+    semDataLoaded[i].acquire()
+
+
+def _signalcopy(i):
+    """Release ``semDataCopied[i]`` so loader ``i`` may reuse its slot."""
+    semDataCopied[i].release()
+
+
+@numba.njit(parallel=True)
+def _fast_copy(data, dataA, indices, indicesA, starts, ends, k, m):
+    # For each of the ``k`` staging slots, ``m`` times: wait for the loader,
+    # copy the slot to its final position, then hand the slot back.
+    for i in numba.prange(k):
+        for _ in range(m):
+            with numba.objmode():
+                _waitload(i)
+            length = ends[i] - starts[i]
+            data[starts[i] : ends[i]] = dataA[
+                i * 1024 * 1024 : i * 1024 * 1024 + length
+            ]
+            indices[starts[i] : ends[i]] = indicesA[
+                i * 1024 * 1024 : i * 1024 * 1024 + length
+            ]
+            with numba.objmode():
+                _signalcopy(i)
+
+
+def fastload(fname, backed):
+    """Read an ``.h5ad`` file, loading a large CSR ``X`` with parallel workers.
+
+    The parallel path is used only when ``X`` is a CSR group of ``float32``
+    values with at least ``1024 * 1024`` stored elements and the file has no
+    ``raw`` group. Any other file is handed to ``read_h5ad`` unchanged, so
+    dense or CSC matrices are not misread and wider dtypes are not truncated.
+
+    Parameters
+    ----------
+    fname
+        Path of the ``.h5ad`` file.
+    backed
+        Forwarded to ``read_h5ad`` on the fallback path; the parallel path
+        always builds an in-memory object.
+
+    Returns
+    -------
+    The loaded :class:`~anndata.AnnData`.
+    """
+    # Probe read-only: ``backed`` may be ``None``, which is not a file mode.
+    with h5py.File(fname, "r") as f:
+        x = f.get("X")
+        fast = (
+            isinstance(x, h5py.Group)
+            and all(name in x for name in ("data", "indices", "indptr"))
+            and "obs" in f
+            and "var" in f
+            and "raw" not in f
+        )
+        if fast:
+            layout = x.attrs.get(
+                "encoding-type", x.attrs.get("h5sparse_format", "csr")
+            )
+            if isinstance(layout, bytes):
+                layout = layout.decode()
+            fast = str(layout).startswith("csr") and x["data"].dtype == np.float32
+        if fast:
+            # load index pointers
+            indptr = x["indptr"][...]
+            datalen = int(indptr[-1]) if len(indptr) else 0
+            fast = datalen >= 1024 * 1024
+    if not fast:
+        logg.debug(f"fastload: falling back to read_h5ad for {fname}")
+        return read_h5ad(fname, backed=backed)
+
+    # Let anndata decode everything except ``X`` (backed mode leaves ``X`` on
+    # disk), so typed columns and ``uns``/``obsm``/``layers``/... are kept.
+    meta = read_h5ad(fname, backed="r")
+    try:
+        shape = meta.shape
+        fields = dict(
+            obs=meta.obs,
+            var=meta.var,
+            uns=meta.uns,
+            obsm=dict(meta.obsm),
+            varm=dict(meta.varm),
+            layers=dict(meta.layers),
+            obsp=dict(meta.obsp),
+            varp=dict(meta.varp),
+        )
+    finally:
+        # Close before the workers open the same file.
+        meta.file.close()
+
+    # prepare shared arrays
+    k = numba.get_num_threads()
+    dataArray = mp.Array("f", k * 1024 * 1024, lock=False)
+    indicesArray = mp.Array(indices_shm_type, k * 1024 * 1024, lock=False)
+    # "q" is 64 bits on every platform, matching the np.int64 views below.
+    startsArray = mp.Array("q", k, lock=False)  # start index of data read
+    endsArray = mp.Array("q", k, lock=False)  # end index (noninclusive) of data read
+    global semDataLoaded
+    global semDataCopied
+    semDataLoaded = [mp.Semaphore(0) for _ in range(k)]
+    semDataCopied = [mp.Semaphore(0) for _ in range(k)]
+    dataA = np.frombuffer(dataArray, dtype=np.float32)
+    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
+    startsA = np.frombuffer(startsArray, dtype=np.int64)
+    endsA = np.frombuffer(endsArray, dtype=np.int64)
+    data = np.empty(datalen, dtype=np.float32)
+    indices = np.empty(datalen, dtype=indices_type)
+
+    # Pass the semaphores explicitly so the workers do not depend on
+    # inheriting this module's globals.
+    procs = [
+        mp.Process(
+            target=_load_helper,
+            args=(
+                fname,
+                i,
+                k,
+                datalen,
+                dataArray,
+                indicesArray,
+                startsArray,
+                endsArray,
+                semDataLoaded,
+                semDataCopied,
+            ),
+        )
+        for i in range(k)
+    ]
+    for p in procs:
+        p.start()
+
+    rounds = datalen // (k * 1024 * 1024) + 1
+    _fast_copy(data, dataA, indices, indicesA, startsA, endsA, k, rounds)
+
+    for p in procs:
+        p.join()
+
+    # Public constructor instead of patching private attributes of an empty
+    # matrix; the shape comes from the metadata read above.
+    X = scipy.sparse.csr_matrix((data, indices, indptr), shape=shape)
+
+    # create AnnData
+    return anndata.AnnData(X, **fields)
+
 
 # --------------------------------------------------------------------------------
 # Reading and Writing data files and AnnData objects
@@ -82,7 +295,7 @@
 )
 def read(
     filename: Path | str,
-    backed: Literal["r", "r+"] | None = None,
+    backed: Literal["r", "r+"] | None = 'r+',
     *,
     sheet: str | None = None,
     ext: str | None = None,
@@ -162,7 +375,7 @@ def read(
             f"ending on one of the available extensions {avail_exts} "
             "or pass the parameter `ext`."
         )
-    return read_h5ad(filename, backed=backed)
+    return fastload(filename, backed)
 
 
 @old_positionals("genome", "gex_only", "backup_url")
@@ -774,7 +987,7 @@ def _read(
     # read hdf5 files
     if ext in {"h5", "h5ad"}:
         if sheet is None:
-            return read_h5ad(filename, backed=backed)
+            return fastload(filename, backed)
         else:
             logg.debug(f"reading sheet {sheet} from file {filename}")
             return read_hdf(filename, sheet)
@@ -786,7 +999,7 @@ def _read(
         path_cache = path_cache.with_suffix("")
     if cache and path_cache.is_file():
         logg.info(f"... reading from cache file {path_cache}")
-        return read_h5ad(path_cache)
+        return fastload(path_cache, backed)
 
     if not is_present:
         raise FileNotFoundError(f"Did not find file {filename}.")