Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

faster reading of h5ad file (~18X faster) #3365

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 123 additions & 5 deletions src/scanpy/readwrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
from pathlib import Path, PurePath
from typing import TYPE_CHECKING

import anndata
import anndata.utils
import h5py
import numpy as np
Expand Down Expand Up @@ -33,6 +33,10 @@
read_text,
)
from anndata import AnnData
import multiprocessing as mp
import time
import numba
import scipy
from matplotlib.image import imread

from . import logging as logg
Expand All @@ -45,6 +49,12 @@

from ._utils import Empty

# NumPy dtype used to view/allocate the CSR column-index buffers.
indices_type = np.int64
# ctypes/array typecode for the shared-memory buffer that is viewed as
# ``indices_type``.  "q" (C long long) is 8 bytes on every platform, whereas
# "l" (C long) is only 4 bytes on Windows and would not match np.int64 there.
indices_shm_type = "q"

# Per-worker semaphore lists shared between the loader processes and the copy
# kernel; they stay ``None`` until ``fastload`` creates them.
semDataLoaded = None  # will be initialized later
semDataCopied = None  # will be initialized later

# .gz and .bz2 suffixes are also allowed for text formats
text_exts = {
"csv",
Expand All @@ -65,6 +75,114 @@
} | text_exts
"""Available file formats for reading data. """

def _load_helper(fname, i, k, datalen, dataArray, indicesArray, startsArray, endsArray):
    """Loader-process body: stream worker ``i``'s share of ``X`` through a staging slot.

    The ``datalen`` stored values are split into ``k`` contiguous shares; this
    worker reads share ``i`` in pieces of at most 1024*1024 elements.  Each
    piece is read from ``X/data`` and ``X/indices`` into slot ``i`` of the
    shared staging buffers, its global ``[start, end)`` range is published in
    ``startsArray[i]`` / ``endsArray[i]``, and the worker then hands over to
    the copy side via ``semDataLoaded[i]`` and waits on ``semDataCopied[i]``
    before overwriting the slot again.

    Parameters
    ----------
    fname
        Path of the HDF5 file to read.
    i
        Index of this worker (``0 <= i < k``).
    k
        Total number of workers.
    datalen
        Total number of stored values in ``X``.
    dataArray, indicesArray
        Shared buffers holding ``k`` staging slots of 1024*1024 elements each,
        viewed here as ``float32`` and ``indices_type`` respectively.
    startsArray, endsArray
        Shared ``int64`` buffers of length ``k`` receiving the range of the
        piece currently held in each slot.

    Notes
    -----
    NOTE(review): the semaphores are module globals that are not passed in
    ``args``, so the child only sees them if it inherits the parent's module
    state (fork start method) -- confirm behaviour under spawn.
    """
    chunk = 1024 * 1024  # elements per staging slot; must match _fast_copy
    dataA = np.frombuffer(dataArray, dtype=np.float32)
    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
    startsA = np.frombuffer(startsArray, dtype=np.int64)
    endsA = np.frombuffer(endsArray, dtype=np.int64)

    slot = i * chunk                      # offset of this worker's staging slot
    share_end = (i + 1) * datalen // k    # exclusive end of this worker's share
    share_start = i * datalen // k

    # Context manager guarantees the file handle is released even on error.
    with h5py.File(fname, "r") as f:
        x_data = f["X"]["data"]
        x_indices = f["X"]["indices"]
        # The iteration count must equal the ``m`` passed to ``_fast_copy`` so
        # every release below is matched by exactly one acquire on the other side.
        for j in range(datalen // (k * chunk) + 1):
            # compute start, end
            s = share_start + j * chunk
            e = min(s + chunk, share_end)
            length = e - s
            startsA[i] = s
            endsA[i] = e

            # read direct (skipped for an empty trailing piece; the handshake
            # below still has to happen so both sides stay in lock-step)
            if length > 0:
                x_data.read_direct(dataA, np.s_[s:e], np.s_[slot:slot + length])
                x_indices.read_direct(indicesA, np.s_[s:e], np.s_[slot:slot + length])

            # coordinate with copy threads
            semDataLoaded[i].release()  # done data load
            semDataCopied[i].acquire()  # wait until data copied

Check warning on line 97 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L96-L97

Added lines #L96 - L97 were not covered by tests

def _waitload(i):
    """Block until ``semDataLoaded[i]`` has been released (slot ``i`` is filled)."""
    loaded = semDataLoaded[i]
    loaded.acquire()

Check warning on line 100 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L100

Added line #L100 was not covered by tests

def _signalcopy(i):
    """Release ``semDataCopied[i]`` to signal that slot ``i`` has been copied out."""
    copied = semDataCopied[i]
    copied.release()

Check warning on line 103 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L103

Added line #L103 was not covered by tests

@numba.njit(parallel=True)
def _fast_copy(data,dataA,indices,indicesA,starts,ends,k,m):
    """Drain the staging slots into the final ``data`` / ``indices`` arrays.

    For each of the ``k`` slots (processed in a ``prange``), repeat ``m``
    times: wait for the slot to be filled, copy the published
    ``[starts[slot], ends[slot])`` range out of the 1024*1024-element staging
    slot into the destination arrays, then signal that the slot may be reused.
    """
    for slot in numba.prange(k):
        for _ in range(m):
            # Semaphores are Python objects, so they are handled in object mode.
            with numba.objmode():
                _waitload(slot)
            lo = starts[slot]
            hi = ends[slot]
            count = hi - lo
            offset = slot*1024*1024
            data[lo:hi] = dataA[offset:offset+count]
            indices[lo:hi] = indicesA[offset:offset+count]
            with numba.objmode():
                _signalcopy(slot)

def _is_csr_group(x):
    """Return True when ``x`` is an HDF5 group holding CSR ``data``/``indices``/``indptr``."""
    if not isinstance(x, h5py.Group):
        return False  # dense X is stored as a plain dataset
    if not all(name in x for name in ("data", "indices", "indptr")):
        return False
    # Groups without the attribute are treated as CSR, as the fast path always
    # interprets the three arrays as CSR.
    encoding = x.attrs.get("encoding-type", "csr_matrix")
    if isinstance(encoding, bytes):
        encoding = encoding.decode()
    return encoding == "csr_matrix"


def _read_string_frame(group):
    """Build a DataFrame from an ``obs``/``var`` group, reading every member as strings."""
    if "_index" in group:
        frame = pd.DataFrame(index=pd.Series(group["_index"].asstr()[...]))
    else:
        frame = pd.DataFrame()
    for key in group:
        if key == "_index":
            continue
        # NOTE(review): ``.asstr()`` is applied to every member, so non-string
        # columns (numeric datasets, nested groups) are presumably unsupported
        # here -- confirm against the files this is meant to read.
        frame[key] = group[key].asstr()[...]
    return frame


def fastload(fname, backed):
    """Read ``X``, ``obs`` and ``var`` of an h5ad file using parallel loader processes.

    ``X`` must be stored as a CSR group.  Its ``data`` and ``indices`` arrays
    are read by ``numba.get_num_threads()`` loader processes into shared
    staging buffers and copied into the final arrays by ``_fast_copy``.
    Only ``X``, ``obs`` and ``var`` are read on this path; the result is an
    in-memory ``AnnData`` built from those three.

    Files whose ``X`` is not a CSR group, or that hold fewer than 1024*1024
    stored values, are delegated to ``read_h5ad(fname, backed=backed)``.

    Parameters
    ----------
    fname
        Path of the h5ad file.
    backed
        Only forwarded to ``read_h5ad`` on the fallback path; the fast path
        always opens the file read-only.

    Returns
    -------
    The loaded ``AnnData`` object.

    Raises
    ------
    ValueError
        If the file lacks one of the top-level ``X``, ``var`` or ``obs`` entries.

    Notes
    -----
    NOTE(review): the loader processes find the semaphores through module
    globals, which presumably requires the fork start method -- confirm.
    """
    global semDataLoaded
    global semDataCopied

    chunk = 1024 * 1024  # elements per staging slot; must match the helpers

    # Nothing is written here, so open read-only regardless of ``backed``; the
    # context manager also releases the handle if anything below raises.
    with h5py.File(fname, "r") as f:
        missing = [name for name in ("X", "var", "obs") if name not in f]
        if missing:
            raise ValueError(
                f"{fname} is missing required top-level entries: {missing}"
            )

        x_group = f["X"]
        use_fast_path = _is_csr_group(x_group)
        if use_fast_path:
            # load index pointers; for CSR they have one entry per row plus one
            indptr = x_group["indptr"][...]
            rows = len(indptr) - 1
            datalen = int(indptr[-1]) if len(indptr) else 0
            logg.debug(f"fastload: X holds {datalen} stored values")
            use_fast_path = datalen >= chunk
        if use_fast_path:
            dfobs = _read_string_frame(f["obs"])
            dfvar = _read_string_frame(f["var"])

    if not use_fast_path:
        # The file is already closed at this point, as before delegating.
        return read_h5ad(fname, backed=backed)

    k = numba.get_num_threads()
    rounds = datalen // (k * chunk) + 1  # same count as the loop in _load_helper

    # Shared staging buffers: one ``chunk``-sized slot per loader process.
    dataArray = mp.Array('f', k * chunk, lock=False)
    indicesArray = mp.Array(indices_shm_type, k * chunk, lock=False)
    # 'q' is always 8 bytes, matching the np.int64 views taken below.
    startsArray = mp.Array('q', k, lock=False)  # start index of data read
    endsArray = mp.Array('q', k, lock=False)  # end index (noninclusive) of data read

    semDataLoaded = [mp.Semaphore(0) for _ in range(k)]
    semDataCopied = [mp.Semaphore(0) for _ in range(k)]

    dataA = np.frombuffer(dataArray, dtype=np.float32)
    indicesA = np.frombuffer(indicesArray, dtype=indices_type)
    startsA = np.frombuffer(startsArray, dtype=np.int64)
    endsA = np.frombuffer(endsArray, dtype=np.int64)
    data = np.empty(datalen, dtype=np.float32)
    indices = np.empty(datalen, dtype=indices_type)

    procs = [
        mp.Process(
            target=_load_helper,
            args=(fname, i, k, datalen, dataArray, indicesArray, startsArray, endsArray),
        )
        for i in range(k)
    ]
    for p in procs:
        p.start()
    try:
        _fast_copy(data, dataA, indices, indicesA, startsA, endsA, k, rounds)
    except BaseException:
        # Loaders would stay blocked on semDataCopied forever; stop them so the
        # join below cannot hang.
        for p in procs:
            p.terminate()
        raise
    finally:
        for p in procs:
            p.join()

    # Assemble the CSR matrix by attaching the arrays to an empty matrix
    # (presumably to skip the constructor's validation/copies -- TODO confirm).
    X = scipy.sparse.csr_matrix((0, 0))
    X.data = data
    X.indices = indices
    X.indptr = indptr
    X._shape = (rows, dfvar.shape[0])

    # create AnnData
    adata = anndata.AnnData(X, dfobs, dfvar)
    return adata

Check warning on line 184 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L183-L184

Added lines #L183 - L184 were not covered by tests


# --------------------------------------------------------------------------------
# Reading and Writing data files and AnnData objects
Expand All @@ -82,7 +200,7 @@
)
def read(
filename: Path | str,
backed: Literal["r", "r+"] | None = None,
backed: Literal["r", "r+"] | None = 'r+',
*,
sheet: str | None = None,
ext: str | None = None,
Expand Down Expand Up @@ -162,7 +280,7 @@
f"ending on one of the available extensions {avail_exts} "
"or pass the parameter `ext`."
)
return read_h5ad(filename, backed=backed)
return fastload(filename, backed)

Check warning on line 283 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L283

Added line #L283 was not covered by tests


@old_positionals("genome", "gex_only", "backup_url")
Expand Down Expand Up @@ -774,7 +892,7 @@
# read hdf5 files
if ext in {"h5", "h5ad"}:
if sheet is None:
return read_h5ad(filename, backed=backed)
return fastload(filename, backed)
else:
logg.debug(f"reading sheet {sheet} from file {filename}")
return read_hdf(filename, sheet)
Expand All @@ -786,7 +904,7 @@
path_cache = path_cache.with_suffix("")
if cache and path_cache.is_file():
logg.info(f"... reading from cache file {path_cache}")
return read_h5ad(path_cache)
return fastload(path_cache, backed)

Check warning on line 907 in src/scanpy/readwrite.py

View check run for this annotation

Codecov / codecov/patch

src/scanpy/readwrite.py#L907

Added line #L907 was not covered by tests

if not is_present:
raise FileNotFoundError(f"Did not find file {filename}.")
Expand Down
Loading