From a7a5a4c39d12132846c00164dbd749e60382a13d Mon Sep 17 00:00:00 2001 From: Kyle Scott Date: Thu, 21 Nov 2024 16:01:47 -0500 Subject: [PATCH] bugfixes and optimization in pvar lazy reading * row indices create an issue with lazy execution. any function polars provides to add indices requires the lf be read into memory. * instead of reading the entire lf into memory, we track indices by using the variant ID. * we can determine a blank header from the number of columns in a pvar file - pvar spec allows only for 5/6 cols when header is blank. * the order of the control block for filtering variants provided an issue where the entire lf was read in the case that variant_idxs were not specified, even if variant_ids was. * fix the ordering logic into one if,elif,else block and copy the structure to the psam section. Signed-off-by: Kyle Scott --- snputils/snp/io/read/pgen.py | 87 ++++++++++++++++++++++++------------ 1 file changed, 59 insertions(+), 28 deletions(-) diff --git a/snputils/snp/io/read/pgen.py b/snputils/snp/io/read/pgen.py index 55c51e1..5803511 100644 --- a/snputils/snp/io/read/pgen.py +++ b/snputils/snp/io/read/pgen.py @@ -102,48 +102,83 @@ def open_textfile(filename): pvar_header_line_num = 0 with open_textfile(pvar_filename) as file: for line_num, line in enumerate(file): - if line.startswith("#CHROM"): + if line.startswith("##"): # Metadata + continue + elif line.startswith("#CHROM"): # Header pvar_header_line_num = line_num + header = line.strip().split() break - else: # if no break - pvar_has_header = False + elif not line.startswith("#"): # If no header, look at line 1 + pvar_has_header = False + cols_in_pvar = len(line.strip().split()) + if cols_in_pvar == 5: + header = ["#CHROM", "ID", "POS", "ALT", "REF"] + elif cols_in_pvar == 6: + header = ["#CHROM", "ID", "CM", "POS", "ALT", "REF"] + else: + raise ValueError( + f"{pvar_filename} is not a valid pvar file." + ) + break + + def lazy_read(filename: str, **kwargs) -> pl.LazyFrame: + """ + Simple reader function needed due to lack of support for scanning zstd files in polars. + + Args: + filename (str): pvar file, either .pvar or .pvar.zst + **kwargs: CSV arguments for polars + + Returns: + pl.LazyFrame + """ + if filename.endswith('.zst'): + return pl.read_csv(filename, **kwargs).lazy() + else: + return pl.scan_csv(filename, **kwargs) - pvar = pl.scan_csv( + pvar = lazy_read( pvar_filename, separator='\t', skip_rows=pvar_header_line_num, has_header=pvar_has_header, - new_columns=None if pvar_has_header else ["#CHROM", "ID", "CM", "POS", "REF", "ALT"], + new_columns=None if pvar_has_header else header, schema_overrides={ "#CHROM": pl.String, - "POS": pl.Int64, + "POS": pl.UInt32, "ID": pl.String, "REF": pl.String, "ALT": pl.String, }, - ).with_row_index() + ) + + # keeping track of indices provides a problem for lazy execution -> + # it can block predicate pushdown optimization and cause the whole lf to be read into memory + # instead we keep track of the index with "ID" and ensure that is the only columm always read into memory + idxs = pvar.select("ID").with_row_index().collect() - # since pvar is lazy, the skip_rows operation hasn't materialized - # pl.len() will return the length of the pvar + header - file_num_variants = pvar.select(pl.len()).collect().item() - pvar_header_line_num + file_num_variants = idxs.height if variant_ids is not None: + num_variants = np.size(variant_ids) + pvar = pvar.filter(pl.col("ID").is_in(variant_ids)).collect() variant_idxs = ( - pvar.filter(pl.col("ID").is_in(variant_ids)) + idxs.filter(pl.col("ID").is_in(variant_ids)) .select("index") - .collect() .to_series() .to_numpy() ) - - if variant_idxs is None: + elif variant_idxs is not None: + num_variants = np.size(variant_idxs) + variant_idxs = np.array(variant_idxs, dtype=np.uint32) + variant_ids = idxs.filter(pl.col("index").is_in(variant_idxs)).select( + "ID" + ) + pvar = pvar.filter(pl.col("ID").is_in(variant_ids)).collect() + else: num_variants = file_num_variants variant_idxs = np.arange(num_variants, dtype=np.uint32) pvar = pvar.collect() - else: - num_variants = np.size(variant_idxs) - variant_idxs = np.array(variant_idxs, dtype=np.uint32) - pvar = pvar.filter(pl.col("index").is_in(variant_idxs)).collect() log.info(f"Reading {filename_noext}.psam") @@ -165,19 +200,15 @@ def open_textfile(filename): file_num_samples = psam.height if sample_ids is not None: - sample_idxs = ( - psam.filter(pl.col("IID").is_in(sample_ids)) - .select("index") - .to_series() - .to_numpy() - ) - - if sample_idxs is None: - num_samples = file_num_samples - else: + num_samples = np.size(sample_ids) + psam = psam.filter(pl.col("IID").is_in(sample_ids)) + sample_idxs = psam.select("index").to_series().to_numpy() + elif sample_idxs is not None: num_samples = np.size(sample_idxs) sample_idxs = np.array(sample_idxs, dtype=np.uint32) psam = psam.filter(pl.col("index").is_in(sample_idxs)) + else: + num_samples = file_num_samples if "GT" in fields: log.info(f"Reading {filename_noext}.pgen")