From dbfc89d901d96cb3f50be58f4cc109e44f66b037 Mon Sep 17 00:00:00 2001
From: Yasset Perez-Riverol
Date: Sat, 30 Nov 2024 14:29:20 +0000
Subject: [PATCH] black applied.

---
 quantmsutils/mzml/mzml_statistics.py | 178 ++++++++++++++-------------
 1 file changed, 95 insertions(+), 83 deletions(-)

diff --git a/quantmsutils/mzml/mzml_statistics.py b/quantmsutils/mzml/mzml_statistics.py
index d582b0b..a5f6a33 100644
--- a/quantmsutils/mzml/mzml_statistics.py
+++ b/quantmsutils/mzml/mzml_statistics.py
@@ -17,7 +17,14 @@ class BatchWritingConsumer:
     pyopenms streaming.
     """
 
-    def __init__(self, parquet_schema: pa.Schema, id_parquet_schema: pa.Schema, output_path, batch_size=10000, id_only=False):
+    def __init__(
+        self,
+        parquet_schema: pa.Schema,
+        id_parquet_schema: pa.Schema,
+        output_path,
+        batch_size=10000,
+        id_only=False,
+    ):
         self.parquet_schema = parquet_schema
         self.id_parquet_schema = id_parquet_schema
         self.output_path = output_path
@@ -61,25 +68,31 @@ def consumeSpectrum(self, spectrum):
 
             if self.id_only:
                 scan_id = self.scan_pattern.findall(spectrum.getNativeID())[0]
-                self.psm_parts.append([
-                    {
-                        "scan": scan_id,
-                        "ms_level": ms_level,
-                        "mz": mz_array.tolist(),
-                        "intensity": intensity_array.tolist()
-                    }
-                ])
+                self.psm_parts.append(
+                    [
+                        {
+                            "scan": scan_id,
+                            "ms_level": ms_level,
+                            "mz": mz_array.tolist(),
+                            "intensity": intensity_array.tolist(),
+                        }
+                    ]
+                )
 
             row_data = {
                 "SpectrumID": spectrum.getNativeID(),
                 "MSLevel": float(ms_level),
                 "Charge": float(charge_state) if charge_state is not None else None,
                 "MS_peaks": float(peak_per_ms),
-                "Base_Peak_Intensity": float(base_peak_intensity) if base_peak_intensity is not None else None,
-                "Summed_Peak_Intensities": float(total_intensity) if total_intensity is not None else None,
+                "Base_Peak_Intensity": (
+                    float(base_peak_intensity) if base_peak_intensity is not None else None
+                ),
+                "Summed_Peak_Intensities": (
+                    float(total_intensity) if total_intensity is not None else None
+                ),
                 "Retention_Time": float(rt),
                 "Exp_Mass_To_Charge": float(exp_mz) if exp_mz is not None else None,
-                "AcquisitionDateTime": str(self.acquisition_datetime)
+                "AcquisitionDateTime": str(self.acquisition_datetime),
             }
         elif ms_level == 1:
             row_data = {
@@ -87,11 +100,15 @@ def consumeSpectrum(self, spectrum):
                 "MSLevel": float(ms_level),
                 "Charge": None,
                 "MS_peaks": float(peak_per_ms),
-                "Base_Peak_Intensity": float(base_peak_intensity) if base_peak_intensity is not None else None,
-                "Summed_Peak_Intensities": float(total_intensity) if total_intensity is not None else None,
+                "Base_Peak_Intensity": (
+                    float(base_peak_intensity) if base_peak_intensity is not None else None
+                ),
+                "Summed_Peak_Intensities": (
+                    float(total_intensity) if total_intensity is not None else None
+                ),
                 "Retention_Time": float(rt),
                 "Exp_Mass_To_Charge": None,
-                "AcquisitionDateTime": str(self.acquisition_datetime)
+                "AcquisitionDateTime": str(self.acquisition_datetime),
             }
         else:
             return
@@ -123,7 +140,7 @@ def _write_batch(self):
                 self.id_parquet_writer = pq.ParquetWriter(
                     where=f"{Path(self.output_path).stem}_spectrum_df.parquet",
                     schema=self.id_parquet_schema,
-                    compression="gzip"
+                    compression="gzip",
                 )
             self.id_parquet_writer.write_table(spectrum_table)
 
@@ -151,33 +168,23 @@ def finalize(self):
         if self.id_parquet_writer:
             self.id_parquet_writer.close()
 
+
 def column_exists(conn, table_name: str) -> List[str]:
     """
     Fetch the existing columns in the specified SQLite table.
""" table_info = pd.read_sql_query(f"PRAGMA table_info({table_name});", conn) - return set(table_info['name'].tolist()) + return set(table_info["name"].tolist()) @click.command("mzmlstats") @click.option("--ms_path", type=click.Path(exists=True), required=True) +@click.option("--id_only", is_flag=True, help="Generate a csv with the spectrum id and the peaks") @click.option( - "--id_only", is_flag=True, - help="Generate a csv with the spectrum id and the peaks" -) -@click.option( - "--batch_size", - type=int, - default=10000, - help="Number of rows to write in each batch" + "--batch_size", type=int, default=10000, help="Number of rows to write in each batch" ) @click.pass_context -def mzml_statistics( - ctx, - ms_path: str, - id_only: bool = False, - batch_size: int = 10000 -) -> None: +def mzml_statistics(ctx, ms_path: str, id_only: bool = False, batch_size: int = 10000) -> None: """ The mzml_statistics function parses mass spectrometry data files, either in .mzML or Bruker .d formats, to extract and compile a set of statistics about the spectra contained within. @@ -194,28 +201,37 @@ def mzml_statistics( :param batch_size: An integer specifying the number of rows to write in each batch. """ - schema = pa.schema([ - pa.field("SpectrumID", pa.string(), nullable=True), - pa.field("MSLevel", pa.float64(), nullable=True), - pa.field("Charge", pa.float64(), nullable=True), - pa.field("MS_peaks", pa.float64(), nullable=True), - pa.field("Base_Peak_Intensity", pa.float64(), nullable=True), - pa.field("Summed_Peak_Intensities", pa.float64(), nullable=True), - pa.field("Retention_Time", pa.float64(), nullable=True), - pa.field("Exp_Mass_To_Charge", pa.float64(), nullable=True), - pa.field("AcquisitionDateTime", pa.string(), nullable=True), - ]) - - id_schema = pa.schema([ - ("scan", pa.string()), - ("ms_level", pa.int32()), - ("mz", pa.list_(pa.float64())), - ("intensity", pa.list_(pa.float64())) - ]) - - def batch_write_mzml_streaming(file_name: str, parquet_schema: pa.Schema, output_path: str, - id_parquet_schema: pa.Schema, id_only: bool = False, - batch_size: int = 10000) -> Optional[str]: + schema = pa.schema( + [ + pa.field("SpectrumID", pa.string(), nullable=True), + pa.field("MSLevel", pa.float64(), nullable=True), + pa.field("Charge", pa.float64(), nullable=True), + pa.field("MS_peaks", pa.float64(), nullable=True), + pa.field("Base_Peak_Intensity", pa.float64(), nullable=True), + pa.field("Summed_Peak_Intensities", pa.float64(), nullable=True), + pa.field("Retention_Time", pa.float64(), nullable=True), + pa.field("Exp_Mass_To_Charge", pa.float64(), nullable=True), + pa.field("AcquisitionDateTime", pa.string(), nullable=True), + ] + ) + + id_schema = pa.schema( + [ + ("scan", pa.string()), + ("ms_level", pa.int32()), + ("mz", pa.list_(pa.float64())), + ("intensity", pa.list_(pa.float64())), + ] + ) + + def batch_write_mzml_streaming( + file_name: str, + parquet_schema: pa.Schema, + output_path: str, + id_parquet_schema: pa.Schema, + id_only: bool = False, + batch_size: int = 10000, + ) -> Optional[str]: """ Parse mzML file in a streaming manner and write to Parquet. """ @@ -228,11 +244,7 @@ def batch_write_mzml_streaming(file_name: str, parquet_schema: pa.Schema, output print(f"Error during streaming: {e}") return None - def batch_write_bruker_d( - file_name: str, - output_path: str, - batch_size: int = 10000 - ) -> str: + def batch_write_bruker_d(file_name: str, output_path: str, batch_size: int = 10000) -> str: """ Batch processing and writing of Bruker .d files. 
""" @@ -246,24 +258,22 @@ def batch_write_bruker_d( columns = column_exists(conn, "frames") - schema = pa.schema([ - pa.field("Id", pa.int32(), nullable=False), - pa.field("MsMsType", pa.int32(), nullable=True), - pa.field("NumPeaks", pa.int32(), nullable=True), - pa.field("MaxIntensity", pa.float64(), nullable=True), - pa.field("SummedIntensities", pa.float64(), nullable=True), - pa.field("Time", pa.float64(), nullable=True), - pa.field("Charge", pa.int32(), nullable=True), - pa.field("MonoisotopicMz", pa.float64(), nullable=True), - pa.field("AcquisitionDateTime", pa.string(), nullable=True)] + schema = pa.schema( + [ + pa.field("Id", pa.int32(), nullable=False), + pa.field("MsMsType", pa.int32(), nullable=True), + pa.field("NumPeaks", pa.int32(), nullable=True), + pa.field("MaxIntensity", pa.float64(), nullable=True), + pa.field("SummedIntensities", pa.float64(), nullable=True), + pa.field("Time", pa.float64(), nullable=True), + pa.field("Charge", pa.int32(), nullable=True), + pa.field("MonoisotopicMz", pa.float64(), nullable=True), + pa.field("AcquisitionDateTime", pa.string(), nullable=True), + ] ) # Set up parquet writer - parquet_writer = pq.ParquetWriter( - output_path, - schema=schema, - compression='gzip' - ) + parquet_writer = pq.ParquetWriter(output_path, schema=schema, compression="gzip") base_columns = [ "Id", @@ -271,7 +281,7 @@ def batch_write_bruker_d( "NumPeaks", "MaxIntensity", "SummedIntensities", - "Time" + "Time", ] if "Charge" in columns: @@ -289,7 +299,7 @@ def batch_write_bruker_d( try: # Stream data in batches for chunk in pd.read_sql_query(query, conn, chunksize=batch_size): - chunk['AcquisitionDateTime'] = acquisition_date_time + chunk["AcquisitionDateTime"] = acquisition_date_time for col in schema.names: if col not in chunk.columns: chunk[col] = None @@ -309,8 +319,14 @@ def batch_write_bruker_d( if Path(ms_path).suffix == ".d": batch_write_bruker_d(file_name=ms_path, output_path=output_path, batch_size=batch_size) elif Path(ms_path).suffix.lower() in [".mzml"]: - batch_write_mzml_streaming(file_name=ms_path, parquet_schema=schema, id_parquet_schema=id_schema, output_path=output_path, id_only=id_only, - batch_size=batch_size) + batch_write_mzml_streaming( + file_name=ms_path, + parquet_schema=schema, + id_parquet_schema=id_schema, + output_path=output_path, + id_only=id_only, + batch_size=batch_size, + ) else: raise RuntimeError(f"Unsupported file type: {ms_path}") @@ -324,12 +340,8 @@ def _resolve_ms_path(ms_path: str) -> str: return str(path_obj) candidates = list(path_obj.parent.glob(f"{path_obj.stem}*")) - valid_extensions = {'.d', '.mzml', '.mzML'} - candidates = [ - str(c.resolve()) - for c in candidates - if c.suffix.lower() in valid_extensions - ] + valid_extensions = {".d", ".mzml", ".mzML"} + candidates = [str(c.resolve()) for c in candidates if c.suffix.lower() in valid_extensions] if len(candidates) == 1: return candidates[0]