Skip to content

Commit

Permalink
Merge pull request #11 from kscott-1/dev
Browse files Browse the repository at this point in the history
Optimize lazy execution during pvar read
  • Loading branch information
salcc authored Nov 21, 2024
2 parents ce39cd7 + a7a5a4c commit 5dd876b
Showing 1 changed file with 59 additions and 28 deletions.
87 changes: 59 additions & 28 deletions snputils/snp/io/read/pgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,48 +104,83 @@ def open_textfile(filename):
pvar_header_line_num = 0
with open_textfile(pvar_filename) as file:
for line_num, line in enumerate(file):
if line.startswith("#CHROM"):
if line.startswith("##"): # Metadata
continue
elif line.startswith("#CHROM"): # Header
pvar_header_line_num = line_num
header = line.strip().split()
break
else: # if no break
pvar_has_header = False
elif not line.startswith("#"): # If no header, look at line 1
pvar_has_header = False
cols_in_pvar = len(line.strip().split())
if cols_in_pvar == 5:
header = ["#CHROM", "ID", "POS", "ALT", "REF"]
elif cols_in_pvar == 6:
header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"]
else:
raise ValueError(
f"{pvar_filename} is not a valid pvar file."
)
break

def lazy_read(filename: str, **kwargs) -> pl.LazyFrame:
    """
    Return a polars LazyFrame for a pvar file, compressed or not.

    Exists because polars lacks support for lazily scanning zstd files:
    a filename ending in '.zst' is read eagerly with ``pl.read_csv`` and
    then converted to a LazyFrame, while any other filename is handed to
    ``pl.scan_csv`` directly.

    Args:
        filename (str): pvar file, either .pvar or .pvar.zst
        **kwargs: CSV arguments for polars, forwarded unchanged to the
            reader that is selected

    Returns:
        pl.LazyFrame
    """
    is_zstd = filename.endswith('.zst')

    # Uncompressed input: polars can scan it lazily as-is.
    if not is_zstd:
        return pl.scan_csv(filename, **kwargs)

    # Compressed input: read eagerly, then wrap so callers still get a LazyFrame.
    eager_frame = pl.read_csv(filename, **kwargs)
    return eager_frame.lazy()

pvar = pl.scan_csv(
pvar = lazy_read(
pvar_filename,
separator='\t',
skip_rows=pvar_header_line_num,
has_header=pvar_has_header,
new_columns=None if pvar_has_header else ["#CHROM", "ID", "CM", "POS", "REF", "ALT"],
new_columns=None if pvar_has_header else header,
schema_overrides={
"#CHROM": pl.String,
"POS": pl.Int64,
"POS": pl.UInt32,
"ID": pl.String,
"REF": pl.String,
"ALT": pl.String,
},
).with_row_index()
)

# keeping track of indices poses a problem for lazy execution ->
# it can block predicate pushdown optimization and cause the whole lf to be read into memory
# instead we keep track of the index with "ID" and ensure that is the only column always read into memory
idxs = pvar.select("ID").with_row_index().collect()

# since pvar is lazy, the skip_rows operation hasn't materialized
# pl.len() will return the length of the pvar + header
file_num_variants = pvar.select(pl.len()).collect().item() - pvar_header_line_num
file_num_variants = idxs.height

if variant_ids is not None:
num_variants = np.size(variant_ids)
pvar = pvar.filter(pl.col("ID").is_in(variant_ids)).collect()
variant_idxs = (
pvar.filter(pl.col("ID").is_in(variant_ids))
idxs.filter(pl.col("ID").is_in(variant_ids))
.select("index")
.collect()
.to_series()
.to_numpy()
)

if variant_idxs is None:
elif variant_idxs is not None:
num_variants = np.size(variant_idxs)
variant_idxs = np.array(variant_idxs, dtype=np.uint32)
variant_ids = idxs.filter(pl.col("index").is_in(variant_idxs)).select(
"ID"
)
pvar = pvar.filter(pl.col("ID").is_in(variant_ids)).collect()
else:
num_variants = file_num_variants
variant_idxs = np.arange(num_variants, dtype=np.uint32)
pvar = pvar.collect()
else:
num_variants = np.size(variant_idxs)
variant_idxs = np.array(variant_idxs, dtype=np.uint32)
pvar = pvar.filter(pl.col("index").is_in(variant_idxs)).collect()

log.info(f"Reading {filename_noext}.psam")

Expand All @@ -167,19 +202,15 @@ def open_textfile(filename):
file_num_samples = psam.height

if sample_ids is not None:
sample_idxs = (
psam.filter(pl.col("IID").is_in(sample_ids))
.select("index")
.to_series()
.to_numpy()
)

if sample_idxs is None:
num_samples = file_num_samples
else:
num_samples = np.size(sample_ids)
psam = psam.filter(pl.col("IID").is_in(sample_ids))
sample_idxs = psam.select("index").to_series().to_numpy()
elif sample_idxs is not None:
num_samples = np.size(sample_idxs)
sample_idxs = np.array(sample_idxs, dtype=np.uint32)
psam = psam.filter(pl.col("index").is_in(sample_idxs))
else:
num_samples = file_num_samples

if "GT" in fields:
log.info(f"Reading {filename_noext}.pgen")
Expand Down

0 comments on commit 5dd876b

Please sign in to comment.