Added io_commons.read_csv to address issues with formatting of sample names in gCNV. (#5811)

* Added io_commons.read_csv to address issues with formatting of sample names in gCNV (see the sketch after these notes).

* Cleaned up PEP8 violations.

* Added some minor edits to the dtype dictionaries.

* Removed unordered contig set from ploidy-model config JSON output, enforced sort in all config JSON output, and removed some dead code.

* Fixed Number entry of CNLP field in gCNV intervals VCF.
samuelklee authored Mar 22, 2019
1 parent fb2b5a2 commit 022800c
Showing 17 changed files with 172 additions and 105 deletions.
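A minimal sketch of the pandas pitfall behind the first bullet (the in-memory TSV and the sample name sample@007 are hypothetical; '@' is gCNV's default comment character, per io_consts below):

import io
import pandas as pd

tsv = "@a comment line\nCONTIG\tSTART\tEND\tsample@007\n1\t1000\t2000\t3\n"

# pandas applies `comment` anywhere in a line, not just at line starts,
# so the column header 'sample@007' is silently truncated to 'sample':
df = pd.read_csv(io.StringIO(tsv), sep="\t", comment="@")
print(list(df.columns))  # ['CONTIG', 'START', 'END', 'sample']

# the approach taken by the new io_commons.read_csv: skip leading comment
# lines by hand, then parse the remainder with no comment handling at all
fh = io.StringIO(tsv)
while True:
    pos = fh.tell()
    if not fh.readline().startswith("@"):
        break
fh.seek(pos)
df = pd.read_csv(fh, sep="\t")
print(list(df.columns))  # ['CONTIG', 'START', 'END', 'sample@007']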
@@ -78,7 +78,7 @@ public void composeVariantContextHeader(final Set<VCFHeaderLine> vcfDefaultToolH
VCFHeaderLineType.Integer, "Genotype"));
result.addMetaDataLine(new VCFFormatHeaderLine(CN, 1,
VCFHeaderLineType.Integer, "Copy number maximum a posteriori value"));
result.addMetaDataLine(new VCFFormatHeaderLine(CNLP, VCFHeaderLineCount.A,
result.addMetaDataLine(new VCFFormatHeaderLine(CNLP, VCFHeaderLineCount.UNBOUNDED,
VCFHeaderLineType.Integer, "Copy number log posterior (in Phred-scale) rounded down"));
result.addMetaDataLine(new VCFFormatHeaderLine(CNQ, 1,
VCFHeaderLineType.Integer, "Genotype call quality as the difference between" +
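Context for the CNLP fix above: CNLP carries one Phred-scaled log-posterior per copy-number state (the shape assertions in the Python diff below check intervals x copy-number states), so Number=A, meaning one value per alternate allele, was the wrong declaration; VCFHeaderLineCount.UNBOUNDED renders as Number='.', which leaves the count unconstrained. The emitted header line should now read roughly:

##FORMAT=<ID=CNLP,Number=.,Type=Integer,Description="Copy number log posterior (in Phred-scale) rounded down">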
@@ -249,4 +249,4 @@ def initialize_state_from_instance(self, instance: 'FancyAdamax'):
self.get_rho_m().set_value(instance.get_rho_m().get_value())
self.get_rho_u().set_value(instance.get_rho_u().get_value())
if not self.disable_bias_correction:
self.get_res_tensor().set_value(instance.get_res_tensor().get_value())
self.get_res_tensor().set_value(instance.get_res_tensor().get_value())
@@ -3,6 +3,7 @@
import logging
import os
import re
import pandas as pd
from ast import literal_eval as make_tuple
from typing import List, Optional, Tuple, Set, Dict

@@ -16,23 +17,71 @@
_logger = logging.getLogger(__name__)


def read_csv(input_file: str,
             dtypes_dict: Dict[str, object] = None,
             mandatory_columns_set: Set[str] = None,
             comment=io_consts.default_comment_char,
             delimiter=io_consts.default_delimiter_char) -> pd.DataFrame:
"""Opens a file and seeks to the first line that does not start with the comment character,
checks for mandatory columns in this column-header line, and returns a pandas DataFrame.
Prefer using this method rather than pandas read_csv, because the comment character will only have an effect
when it is present at the beginning of a line at the beginning of the file (pandas can otherwise strip
characters that follow a comment character that appears in the middle of a line, which can corrupt sample names).
Dtypes for columns can be provided, but those for column names that are not known ahead of time will be inferred.
Args:
input_file: input file
dtypes_dict: dictionary of column headers to dtypes; keys will be taken as mandatory columns unless
mandatory_columns_set is also provided
mandatory_columns_set: set of mandatory header columns; must be subset of dtypes_dict keys if provided
comment: comment character
delimiter: delimiter character
Returns:
pandas DataFrame
"""
with open(input_file, 'r') as fh:
while True:
pos = fh.tell()
line = fh.readline()
if not line.startswith(comment):
fh.seek(pos)
break
input_pd = pd.read_csv(fh, delimiter=delimiter, dtype=dtypes_dict) # dtypes_dict keys may not be present
found_columns_set = {str(column) for column in input_pd.columns.values}
assert dtypes_dict is not None or mandatory_columns_set is None, \
"Cannot specify mandatory_columns_set if dtypes_dict is not specified."
if dtypes_dict is not None:
dtypes_dict_keys_set = set(dtypes_dict.keys())
if mandatory_columns_set is None:
assert_mandatory_columns(dtypes_dict_keys_set, found_columns_set, input_file)
else:
assert mandatory_columns_set.issubset(dtypes_dict_keys_set), \
"The mandatory_columns_set must be a subset of the dtypes_dict keys."
assert_mandatory_columns(mandatory_columns_set, found_columns_set, input_file)

return input_pd
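A usage sketch for the new helper (the file name and import path are illustrative, not from the diff; the io_consts names are the ones defined later in this commit):

from gcnvkernel.io import io_commons, io_consts  # assumed package layout

counts_pd = io_commons.read_csv(
    "sample_counts.tsv",  # hypothetical input
    dtypes_dict=io_consts.read_count_dtypes_dict,
    mandatory_columns_set={io_consts.contig_column_name,
                           io_consts.start_column_name,
                           io_consts.end_column_name},  # must be a subset of dtypes_dict keys
)

When mandatory_columns_set is omitted, every key of dtypes_dict is treated as mandatory; columns found in the file but not listed in dtypes_dict are kept with inferred dtypes.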


def extract_sample_name_from_header(input_file: str,
max_scan_lines: int = 10000,
comment=io_consts.default_comment_char,
sample_name_header_regexp: str = io_consts.sample_name_header_regexp) -> str:
"""Extracts sample name from header.
"""Extracts sample name from header (all lines up to the first line that does not start with the comment character).
Args:
input_file: any readable text file
max_scan_lines: maximum number of lines to scan from the top of the file
comment: comment character
sample_name_header_regexp: the regular expression for identifying the header line that contains
the sample name
Returns:
Sample name
"""
with open(input_file, 'r') as f:
for _ in range(max_scan_lines):
while True:
line = f.readline()
if not line.startswith(comment):
break
match = re.search(sample_name_header_regexp, line, re.M)
if match is None:
continue
@@ -186,7 +235,7 @@ def _get_value(key: str, _line: str):
if shape is None:
shape = _get_value('shape', stripped_line)
else:
assert dtype is not None and shape is not None,\
assert dtype is not None and shape is not None, \
"Shape and dtype information could not be found in the header of " \
"\"{0}\"".format(input_file)
row = np.asarray(stripped_line.split(delimiter), dtype=dtype)
@@ -212,7 +261,7 @@ def get_var_map_list_from_mean_field_approx(approx: pm.MeanField) -> List[pm.blo
raise Exception("Unsupported PyMC3 version")


def extract_mean_field_posterior_parameters(approx: pm.MeanField)\
def extract_mean_field_posterior_parameters(approx: pm.MeanField) \
-> Tuple[Set[str], Dict[str, np.ndarray], Dict[str, np.ndarray]]:
"""Extracts mean-field posterior parameters in the right shape and dtype from an instance
of PyMC3 mean-field approximation.
@@ -247,7 +296,7 @@ def write_dict_to_json_file(output_file: str,
dict_to_write: dictionary to write to file
ignored_keys: a set of keys to ignore
"""
filtered_dict = {k: v for k, v in dict_to_write.items() if k not in ignored_keys}
filtered_dict = {k: v for k, v in sorted(dict_to_write.items()) if k not in ignored_keys}
with open(output_file, 'w') as fp:
json.dump(filtered_dict, fp, indent=1)
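A note on the sorted() change above: dict comprehensions preserve insertion order, so building filtered_dict from sorted(dict_to_write.items()) makes json.dump emit keys in sorted order and repeated runs produce byte-identical config JSONs. Sorting at serialization time would be an equivalent alternative, e.g.:

json.dump(filtered_dict, fp, indent=1, sort_keys=True)  # alternative sketch, same ordering guarantee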

@@ -293,12 +342,12 @@ def _get_singleton_slice_along_axis(array: np.ndarray, axis: int, index: int):


def write_mean_field_sample_specific_params(sample_index: int,
sample_posterior_path: str,
approx_var_name_set: Set[str],
approx_mu_map: Dict[str, np.ndarray],
approx_std_map: Dict[str, np.ndarray],
model: GeneralizedContinuousModel,
extra_comment_lines: Optional[List[str]] = None):
sample_posterior_path: str,
approx_var_name_set: Set[str],
approx_mu_map: Dict[str, np.ndarray],
approx_std_map: Dict[str, np.ndarray],
model: GeneralizedContinuousModel,
extra_comment_lines: Optional[List[str]] = None):
"""Writes sample-specific parameters contained in an instance of PyMC3 mean-field approximation
to disk.
@@ -329,8 +378,8 @@ def write_mean_field_sample_specific_params(sample_index: int,


def write_mean_field_global_params(output_path: str,
approx: pm.MeanField,
model: GeneralizedContinuousModel):
approx: pm.MeanField,
model: GeneralizedContinuousModel):
"""Writes global parameters contained in an instance of PyMC3 mean-field approximation to disk.
Args:
@@ -356,8 +405,8 @@ def write_mean_field_global_params(output_path: str,


def read_mean_field_global_params(input_model_path: str,
approx: pm.MeanField,
model: GeneralizedContinuousModel) -> None:
approx: pm.MeanField,
model: GeneralizedContinuousModel) -> None:
"""Reads global parameters of a given model from saved mean-field posteriors and injects them
into a provided mean-field instance.
@@ -392,7 +441,7 @@ def _update_param_inplace(param, slc, dtype, new_value):

for vmap in vmap_list:
if vmap.var == var_name:
assert var_mu.shape == vmap.shp,\
assert var_mu.shape == vmap.shp, \
"Loaded mean for \"{0}\" has an unexpected shape; loaded: {1}, " \
"expected: {2}".format(var_name, var_mu.shape, vmap.shp)
assert var_rho.shape == vmap.shp, \
@@ -405,10 +454,10 @@ def _update_param_inplace(param, slc, dtype, new_value):


def read_mean_field_sample_specific_params(input_sample_calls_path: str,
sample_index: int,
sample_name: str,
approx: pm.MeanField,
model: GeneralizedContinuousModel):
sample_index: int,
sample_name: str,
approx: pm.MeanField,
model: GeneralizedContinuousModel):
"""Reads sample-specific parameters of a given sample from saved mean-field posteriors and injects them
into a provided mean-field instance.
@@ -1,3 +1,5 @@
from .. import types

# interval list .tsv file column names
contig_column_name = "CONTIG"
start_column_name = "START"
@@ -22,8 +24,8 @@
ploidy_gq_column_name = "PLOIDY_GQ"

# column names for copy-number segments file
call_copy_number_column_name = "CALL_COPY_NUMBER"
num_points_column_name = "NUM_POINTS"
call_copy_number_column_name = "CALL_COPY_NUMBER"
quality_some_called_column_name = "QUALITY_SOME_CALLED"
quality_all_called_column_name = "QUALITY_ALL_CALLED"
quality_start_column_name = "QUALITY_START"
@@ -45,6 +47,55 @@
default_comment_char = "@"
default_delimiter_char = "\t"

# dtype dictionaries giving types of mandatory columns whose names are known ahead of time
# (some of these dictionaries are not currently used, but we define their formats for future reference)
interval_dtypes_dict = {
contig_column_name: str,
start_column_name: types.med_uint,
end_column_name: types.med_uint
}

read_count_dtypes_dict = {
**interval_dtypes_dict,
count_column_name: types.med_uint
}

ploidy_prior_dtypes_dict = {
ploidy_prior_contig_name_column: str
}

sample_coverage_metadata_dtypes_dict = {
sample_name_column_name: str
}

sample_ploidy_metadata_dtypes_dict = {
contig_column_name: str,
ploidy_column_name: types.small_uint,
ploidy_gq_column_name: types.floatX
}

sample_read_depth_metadata_dtypes_dict = {
global_read_depth_column_name: types.floatX,
average_ploidy_column_name: types.floatX
}

copy_number_segment_dtypes_dict = {
**interval_dtypes_dict,
num_points_column_name: types.med_uint,
call_copy_number_column_name: types.small_uint,
baseline_copy_number_column_name: types.small_uint,
quality_some_called_column_name: types.floatX,
quality_all_called_column_name: types.floatX,
quality_start_column_name: types.floatX,
quality_end_column_name: types.floatX
}

denoised_copy_ratio_dtypes_dict = {
**interval_dtypes_dict,
denoised_copy_ratio_mean_column_name: types.floatX,
denoised_copy_ratio_std_column_name: types.floatX
}
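
The dict-unpacking pattern above lets per-file schemas share the common interval columns; a hypothetical extension (GC_CONTENT is illustrative, not defined in this commit) would follow the same shape:

annotated_interval_dtypes_dict = {
    **interval_dtypes_dict,
    "GC_CONTENT": types.floatX,
}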

# default file names for loading and saving models, posteriors, and configurations
default_sample_read_depth_tsv_filename = "global_read_depth.tsv"
default_sample_name_txt_filename = "sample_name.txt"
@@ -3,7 +3,6 @@
from typing import List, Optional

import numpy as np
import pandas as pd
import pymc3 as pm

from . import io_commons
@@ -240,7 +239,7 @@ def read_ndarray_tc_with_copy_number_header(sample_posterior_path: str,
delimiter=io_consts.default_delimiter_char) -> np.ndarray:
"""Reads a TSV-formatted dim-2 (intervals x copy-number) ndarray from a sample posterior path."""
ndarray_tc_tsv_file = os.path.join(sample_posterior_path, input_file_name)
ndarray_tc_pd = pd.read_csv(ndarray_tc_tsv_file, delimiter=delimiter, comment=comment)
ndarray_tc_pd = io_commons.read_csv(ndarray_tc_tsv_file, comment=comment, delimiter=delimiter)
read_columns = [str(column_name) for column_name in ndarray_tc_pd.columns.values]
num_read_columns = len(read_columns)
expected_copy_number_header_columns =\
@@ -260,7 +259,7 @@ def _read_sample_copy_number_log_posterior(self,
delimiter=delimiter,
comment=comment)
assert read_log_q_c_tc.shape == (self.denoising_calling_workspace.num_intervals,
self.denoising_calling_workspace.calling_config.num_copy_number_states)
self.denoising_calling_workspace.calling_config.num_copy_number_states)
return read_log_q_c_tc

def _read_sample_copy_number_log_emission(self,
Expand All @@ -272,9 +271,8 @@ def _read_sample_copy_number_log_emission(self,
io_consts.default_copy_number_log_emission_tsv_filename,
delimiter=delimiter,
comment=comment)
assert read_log_emission_tc.shape ==\
(self.denoising_calling_workspace.num_intervals,
self.denoising_calling_workspace.calling_config.num_copy_number_states)
assert read_log_emission_tc.shape == (self.denoising_calling_workspace.num_intervals,
self.denoising_calling_workspace.calling_config.num_copy_number_states)
return read_log_emission_tc

def __call__(self):
@@ -10,20 +10,11 @@

_logger = logging.getLogger(__name__)

interval_dtypes_dict = {
io_consts.contig_column_name: np.str,
io_consts.start_column_name: types.med_uint,
io_consts.end_column_name: types.med_uint
}

read_count_dtypes_dict = {
**interval_dtypes_dict,
io_consts.count_column_name: types.med_uint
}
interval_dtypes_dict = io_consts.interval_dtypes_dict
read_count_dtypes_dict = io_consts.read_count_dtypes_dict


def load_read_counts_tsv_file(read_counts_tsv_file: str,
max_rows: Optional[int] = None,
return_interval_list: bool = False,
comment=io_consts.default_comment_char,
delimiter=io_consts.default_delimiter_char) \
@@ -32,7 +23,6 @@ def load_read_counts_tsv_file(read_counts_tsv_file: str,
Args:
read_counts_tsv_file: input read counts .tsv file
max_rows: (optional) maximum number of rows to process
return_interval_list: if true, an interval list will also be generated and returned
delimiter: delimiter character
comment: comment character
@@ -41,8 +31,10 @@ def load_read_counts_tsv_file(read_counts_tsv_file: str,
sample name, counts, (and optionally a list of intervals if `return_interval_list` == True)
"""
sample_name = io_commons.extract_sample_name_from_header(read_counts_tsv_file)
counts_pd = pd.read_csv(read_counts_tsv_file, delimiter=delimiter, comment=comment, nrows=max_rows,
dtype={**read_count_dtypes_dict})
counts_pd = io_commons.read_csv(read_counts_tsv_file,
dtypes_dict=read_count_dtypes_dict,
comment=comment,
delimiter=delimiter)
if return_interval_list:
interval_list_pd = counts_pd[list(interval_dtypes_dict.keys())]
interval_list = _convert_interval_list_pandas_to_gcnv_interval_list(interval_list_pd, read_counts_tsv_file)
@@ -57,14 +49,17 @@ def load_interval_list_tsv_file(interval_list_tsv_file: str,
"""Loads an interval list .tsv file.
Args:
interval_list_tsv_file: input interval list .tsv file
delimiter: delimiter character
comment: comment character
delimiter: delimiter character
Returns:
interval list
"""
interval_list_pd = pd.read_csv(interval_list_tsv_file, delimiter=delimiter, comment=comment,
dtype={**interval_dtypes_dict, **interval_annotations_dtypes})
interval_list_pd = io_commons.read_csv(interval_list_tsv_file,
dtypes_dict={**interval_dtypes_dict, **interval_annotations_dtypes},
mandatory_columns_set=set(interval_dtypes_dict.keys()),
comment=comment,
delimiter=delimiter)
return _convert_interval_list_pandas_to_gcnv_interval_list(interval_list_pd, interval_list_tsv_file)

