Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

optimize lazy execution during pvar read #11

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 59 additions & 28 deletions snputils/snp/io/read/pgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,48 +102,83 @@ def open_textfile(filename):
pvar_header_line_num = 0
with open_textfile(pvar_filename) as file:
for line_num, line in enumerate(file):
if line.startswith("#CHROM"):
if line.startswith("##"): # Metadata
continue
elif line.startswith("#CHROM"): # Header
pvar_header_line_num = line_num
header = line.strip().split()
break
else: # if no break
pvar_has_header = False
elif not line.startswith("#"): # If no header, look at line 1
pvar_has_header = False
cols_in_pvar = len(line.strip().split())
if cols_in_pvar == 5:
header = ["#CHROM", "ID", "POS", "ALT", "REF"]
elif cols_in_pvar == 6:
header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"]
else:
raise ValueError(
f"{pvar_filename} is not a valid pvar file."
)
break

def lazy_read(filename: str, **kwargs) -> pl.LazyFrame:
    """
    Open a pvar file as a polars LazyFrame, whether or not it is zstd-compressed.

    This small wrapper exists because polars lacks support for scanning
    zstd files: a ``.zst`` file is read eagerly with ``pl.read_csv`` and the
    result is then converted to a LazyFrame, while any other file is
    scanned lazily with ``pl.scan_csv``.

    Args:
        filename (str): pvar file, either .pvar or .pvar.zst
        **kwargs: CSV arguments for polars, forwarded unchanged to the reader

    Returns:
        pl.LazyFrame
    """
    is_zstd = filename.endswith('.zst')
    if not is_zstd:
        # Plain-text pvar: polars can scan it lazily.
        return pl.scan_csv(filename, **kwargs)

    # Compressed pvar: read eagerly, then hand back a lazy view of the frame.
    eager_frame = pl.read_csv(filename, **kwargs)
    return eager_frame.lazy()

pvar = pl.scan_csv(
pvar = lazy_read(
salcc marked this conversation as resolved.
Show resolved Hide resolved
pvar_filename,
separator='\t',
skip_rows=pvar_header_line_num,
has_header=pvar_has_header,
new_columns=None if pvar_has_header else ["#CHROM", "ID", "CM", "POS", "REF", "ALT"],
new_columns=None if pvar_has_header else header,
schema_overrides={
"#CHROM": pl.String,
"POS": pl.Int64,
"POS": pl.UInt32,
"ID": pl.String,
"REF": pl.String,
"ALT": pl.String,
},
).with_row_index()
)

# Keeping a row index on the lazy frame poses a problem for lazy execution:
# it can block the predicate pushdown optimization and cause the whole LazyFrame to be read into memory.
# Instead, we keep track of the index alongside "ID" only, so that "ID" is the only column always read into memory.
idxs = pvar.select("ID").with_row_index().collect()

# since pvar is lazy, the skip_rows operation hasn't materialized
# pl.len() will return the length of the pvar + header
file_num_variants = pvar.select(pl.len()).collect().item() - pvar_header_line_num
file_num_variants = idxs.height

if variant_ids is not None:
num_variants = np.size(variant_ids)
pvar = pvar.filter(pl.col("ID").is_in(variant_ids)).collect()
variant_idxs = (
pvar.filter(pl.col("ID").is_in(variant_ids))
idxs.filter(pl.col("ID").is_in(variant_ids))
.select("index")
.collect()
.to_series()
.to_numpy()
)

if variant_idxs is None:
elif variant_idxs is not None:
num_variants = np.size(variant_idxs)
variant_idxs = np.array(variant_idxs, dtype=np.uint32)
variant_ids = idxs.filter(pl.col("index").is_in(variant_idxs)).select(
"ID"
)
pvar = pvar.filter(pl.col("ID").is_in(variant_ids)).collect()
else:
num_variants = file_num_variants
variant_idxs = np.arange(num_variants, dtype=np.uint32)
pvar = pvar.collect()
else:
num_variants = np.size(variant_idxs)
variant_idxs = np.array(variant_idxs, dtype=np.uint32)
pvar = pvar.filter(pl.col("index").is_in(variant_idxs)).collect()

log.info(f"Reading {filename_noext}.psam")

Expand All @@ -165,19 +200,15 @@ def open_textfile(filename):
file_num_samples = psam.height

if sample_ids is not None:
sample_idxs = (
psam.filter(pl.col("IID").is_in(sample_ids))
.select("index")
.to_series()
.to_numpy()
)

if sample_idxs is None:
num_samples = file_num_samples
else:
num_samples = np.size(sample_ids)
psam = psam.filter(pl.col("IID").is_in(sample_ids))
sample_idxs = psam.select("index").to_series().to_numpy()
elif sample_idxs is not None:
num_samples = np.size(sample_idxs)
sample_idxs = np.array(sample_idxs, dtype=np.uint32)
psam = psam.filter(pl.col("index").is_in(sample_idxs))
else:
num_samples = file_num_samples

if "GT" in fields:
log.info(f"Reading {filename_noext}.pgen")
Expand Down