Commit dbfc89d
black applied.
ypriverol committed Nov 30, 2024
1 parent 3e4613c commit dbfc89d
Showing 1 changed file with 95 additions and 83 deletions.

quantmsutils/mzml/mzml_statistics.py (178 changes: 95 additions & 83 deletions)
@@ -17,7 +17,14 @@ class BatchWritingConsumer:
     pyopenms streaming.
     """

-    def __init__(self, parquet_schema: pa.Schema, id_parquet_schema: pa.Schema, output_path, batch_size=10000, id_only=False):
+    def __init__(
+        self,
+        parquet_schema: pa.Schema,
+        id_parquet_schema: pa.Schema,
+        output_path,
+        batch_size=10000,
+        id_only=False,
+    ):
         self.parquet_schema = parquet_schema
         self.id_parquet_schema = id_parquet_schema
         self.output_path = output_path
@@ -61,37 +68,47 @@ def consumeSpectrum(self, spectrum):

             if self.id_only:
                 scan_id = self.scan_pattern.findall(spectrum.getNativeID())[0]
-                self.psm_parts.append([
-                    {
-                        "scan": scan_id,
-                        "ms_level": ms_level,
-                        "mz": mz_array.tolist(),
-                        "intensity": intensity_array.tolist()
-                    }
-                ])
+                self.psm_parts.append(
+                    [
+                        {
+                            "scan": scan_id,
+                            "ms_level": ms_level,
+                            "mz": mz_array.tolist(),
+                            "intensity": intensity_array.tolist(),
+                        }
+                    ]
+                )

             row_data = {
                 "SpectrumID": spectrum.getNativeID(),
                 "MSLevel": float(ms_level),
                 "Charge": float(charge_state) if charge_state is not None else None,
                 "MS_peaks": float(peak_per_ms),
-                "Base_Peak_Intensity": float(base_peak_intensity) if base_peak_intensity is not None else None,
-                "Summed_Peak_Intensities": float(total_intensity) if total_intensity is not None else None,
+                "Base_Peak_Intensity": (
+                    float(base_peak_intensity) if base_peak_intensity is not None else None
+                ),
+                "Summed_Peak_Intensities": (
+                    float(total_intensity) if total_intensity is not None else None
+                ),
                 "Retention_Time": float(rt),
                 "Exp_Mass_To_Charge": float(exp_mz) if exp_mz is not None else None,
-                "AcquisitionDateTime": str(self.acquisition_datetime)
+                "AcquisitionDateTime": str(self.acquisition_datetime),
             }
         elif ms_level == 1:
             row_data = {
                 "SpectrumID": spectrum.getNativeID(),
                 "MSLevel": float(ms_level),
                 "Charge": None,
                 "MS_peaks": float(peak_per_ms),
-                "Base_Peak_Intensity": float(base_peak_intensity) if base_peak_intensity is not None else None,
-                "Summed_Peak_Intensities": float(total_intensity) if total_intensity is not None else None,
+                "Base_Peak_Intensity": (
+                    float(base_peak_intensity) if base_peak_intensity is not None else None
+                ),
+                "Summed_Peak_Intensities": (
+                    float(total_intensity) if total_intensity is not None else None
+                ),
                 "Retention_Time": float(rt),
                 "Exp_Mass_To_Charge": None,
-                "AcquisitionDateTime": str(self.acquisition_datetime)
+                "AcquisitionDateTime": str(self.acquisition_datetime),
             }
         else:
             return
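
Note: the consumer above is driven by pyopenms, which calls consumeSpectrum() once per parsed spectrum. A minimal driver sketch, assuming the `schema` and `id_schema` objects defined later in this file; the input and output file names are made up:

```python
# Hedged sketch of the streaming wiring, not this file's actual code.
from pyopenms import MzMLFile

consumer = BatchWritingConsumer(
    parquet_schema=schema,            # pa.Schema defined inside mzml_statistics()
    id_parquet_schema=id_schema,
    output_path="run01_ms_info.parquet",  # illustrative output name
    batch_size=10000,
    id_only=False,
)

# transform() parses the mzML spectrum by spectrum and invokes
# consumer.consumeSpectrum() for each one, so peak memory is bounded by
# batch_size rather than by file size.
MzMLFile().transform(b"run01.mzML", consumer)
consumer.finalize()  # flush the last partial batch and close the writers
```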
@@ -123,7 +140,7 @@ def _write_batch(self):
                 self.id_parquet_writer = pq.ParquetWriter(
                     where=f"{Path(self.output_path).stem}_spectrum_df.parquet",
                     schema=self.id_parquet_schema,
-                    compression="gzip"
+                    compression="gzip",
                 )

                 self.id_parquet_writer.write_table(spectrum_table)
@@ -151,33 +168,23 @@ def finalize(self):
         if self.id_parquet_writer:
             self.id_parquet_writer.close()

+
 def column_exists(conn, table_name: str) -> List[str]:
     """
     Fetch the existing columns in the specified SQLite table.
     """
     table_info = pd.read_sql_query(f"PRAGMA table_info({table_name});", conn)
-    return set(table_info['name'].tolist())
+    return set(table_info["name"].tolist())


 @click.command("mzmlstats")
 @click.option("--ms_path", type=click.Path(exists=True), required=True)
-@click.option(
-    "--id_only", is_flag=True,
-    help="Generate a csv with the spectrum id and the peaks"
-)
+@click.option("--id_only", is_flag=True, help="Generate a csv with the spectrum id and the peaks")
 @click.option(
-    "--batch_size",
-    type=int,
-    default=10000,
-    help="Number of rows to write in each batch"
+    "--batch_size", type=int, default=10000, help="Number of rows to write in each batch"
 )
 @click.pass_context
-def mzml_statistics(
-    ctx,
-    ms_path: str,
-    id_only: bool = False,
-    batch_size: int = 10000
-) -> None:
+def mzml_statistics(ctx, ms_path: str, id_only: bool = False, batch_size: int = 10000) -> None:
     """
     The mzml_statistics function parses mass spectrometry data files, either in
     .mzML or Bruker .d formats, to extract and compile a set of statistics about the spectra contained within.
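
Note: the collapsed decorators leave the CLI surface unchanged. A minimal invocation sketch using click's test runner; the import path follows the file path above, and the data file name is made up:

```python
# Hedged usage sketch for the mzmlstats command.
from click.testing import CliRunner

from quantmsutils.mzml.mzml_statistics import mzml_statistics

runner = CliRunner()
result = runner.invoke(
    mzml_statistics,
    ["--ms_path", "run01.mzML", "--id_only", "--batch_size", "5000"],
)
print(result.exit_code, result.output)
```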
@@ -194,28 +201,37 @@ def mzml_statistics(
     :param batch_size: An integer specifying the number of rows to write in each batch.
     """

-    schema = pa.schema([
-        pa.field("SpectrumID", pa.string(), nullable=True),
-        pa.field("MSLevel", pa.float64(), nullable=True),
-        pa.field("Charge", pa.float64(), nullable=True),
-        pa.field("MS_peaks", pa.float64(), nullable=True),
-        pa.field("Base_Peak_Intensity", pa.float64(), nullable=True),
-        pa.field("Summed_Peak_Intensities", pa.float64(), nullable=True),
-        pa.field("Retention_Time", pa.float64(), nullable=True),
-        pa.field("Exp_Mass_To_Charge", pa.float64(), nullable=True),
-        pa.field("AcquisitionDateTime", pa.string(), nullable=True),
-    ])
+    schema = pa.schema(
+        [
+            pa.field("SpectrumID", pa.string(), nullable=True),
+            pa.field("MSLevel", pa.float64(), nullable=True),
+            pa.field("Charge", pa.float64(), nullable=True),
+            pa.field("MS_peaks", pa.float64(), nullable=True),
+            pa.field("Base_Peak_Intensity", pa.float64(), nullable=True),
+            pa.field("Summed_Peak_Intensities", pa.float64(), nullable=True),
+            pa.field("Retention_Time", pa.float64(), nullable=True),
+            pa.field("Exp_Mass_To_Charge", pa.float64(), nullable=True),
+            pa.field("AcquisitionDateTime", pa.string(), nullable=True),
+        ]
+    )

-    id_schema = pa.schema([
-        ("scan", pa.string()),
-        ("ms_level", pa.int32()),
-        ("mz", pa.list_(pa.float64())),
-        ("intensity", pa.list_(pa.float64()))
-    ])
+    id_schema = pa.schema(
+        [
+            ("scan", pa.string()),
+            ("ms_level", pa.int32()),
+            ("mz", pa.list_(pa.float64())),
+            ("intensity", pa.list_(pa.float64())),
+        ]
+    )

-    def batch_write_mzml_streaming(file_name: str, parquet_schema: pa.Schema, output_path: str,
-                                   id_parquet_schema: pa.Schema, id_only: bool = False,
-                                   batch_size: int = 10000) -> Optional[str]:
+    def batch_write_mzml_streaming(
+        file_name: str,
+        parquet_schema: pa.Schema,
+        output_path: str,
+        id_parquet_schema: pa.Schema,
+        id_only: bool = False,
+        batch_size: int = 10000,
+    ) -> Optional[str]:
         """
         Parse mzML file in a streaming manner and write to Parquet.
         """
@@ -228,11 +244,7 @@ def batch_write_mzml_streaming(
             print(f"Error during streaming: {e}")
             return None

-    def batch_write_bruker_d(
-        file_name: str,
-        output_path: str,
-        batch_size: int = 10000
-    ) -> str:
+    def batch_write_bruker_d(file_name: str, output_path: str, batch_size: int = 10000) -> str:
         """
         Batch processing and writing of Bruker .d files.
         """
Expand All @@ -246,32 +258,30 @@ def batch_write_bruker_d(

columns = column_exists(conn, "frames")

schema = pa.schema([
pa.field("Id", pa.int32(), nullable=False),
pa.field("MsMsType", pa.int32(), nullable=True),
pa.field("NumPeaks", pa.int32(), nullable=True),
pa.field("MaxIntensity", pa.float64(), nullable=True),
pa.field("SummedIntensities", pa.float64(), nullable=True),
pa.field("Time", pa.float64(), nullable=True),
pa.field("Charge", pa.int32(), nullable=True),
pa.field("MonoisotopicMz", pa.float64(), nullable=True),
pa.field("AcquisitionDateTime", pa.string(), nullable=True)]
schema = pa.schema(
[
pa.field("Id", pa.int32(), nullable=False),
pa.field("MsMsType", pa.int32(), nullable=True),
pa.field("NumPeaks", pa.int32(), nullable=True),
pa.field("MaxIntensity", pa.float64(), nullable=True),
pa.field("SummedIntensities", pa.float64(), nullable=True),
pa.field("Time", pa.float64(), nullable=True),
pa.field("Charge", pa.int32(), nullable=True),
pa.field("MonoisotopicMz", pa.float64(), nullable=True),
pa.field("AcquisitionDateTime", pa.string(), nullable=True),
]
)

# Set up parquet writer
parquet_writer = pq.ParquetWriter(
output_path,
schema=schema,
compression='gzip'
)
parquet_writer = pq.ParquetWriter(output_path, schema=schema, compression="gzip")

base_columns = [
"Id",
"CASE WHEN MsMsType IN (8, 9) THEN 2 WHEN MsMsType = 0 THEN 1 ELSE NULL END as MsMsType",
"NumPeaks",
"MaxIntensity",
"SummedIntensities",
"Time"
"Time",
]

if "Charge" in columns:
@@ -289,7 +299,7 @@ def batch_write_bruker_d(
         try:
             # Stream data in batches
             for chunk in pd.read_sql_query(query, conn, chunksize=batch_size):
-                chunk['AcquisitionDateTime'] = acquisition_date_time
+                chunk["AcquisitionDateTime"] = acquisition_date_time
                 for col in schema.names:
                     if col not in chunk.columns:
                         chunk[col] = None
@@ -309,8 +319,14 @@ def batch_write_bruker_d(

     if Path(ms_path).suffix == ".d":
         batch_write_bruker_d(file_name=ms_path, output_path=output_path, batch_size=batch_size)
     elif Path(ms_path).suffix.lower() in [".mzml"]:
-        batch_write_mzml_streaming(file_name=ms_path, parquet_schema=schema, id_parquet_schema=id_schema, output_path=output_path, id_only=id_only,
-                                   batch_size=batch_size)
+        batch_write_mzml_streaming(
+            file_name=ms_path,
+            parquet_schema=schema,
+            id_parquet_schema=id_schema,
+            output_path=output_path,
+            id_only=id_only,
+            batch_size=batch_size,
+        )
     else:
         raise RuntimeError(f"Unsupported file type: {ms_path}")
@@ -324,12 +340,8 @@ def _resolve_ms_path(ms_path: str) -> str:
         return str(path_obj)

     candidates = list(path_obj.parent.glob(f"{path_obj.stem}*"))
-    valid_extensions = {'.d', '.mzml', '.mzML'}
-    candidates = [
-        str(c.resolve())
-        for c in candidates
-        if c.suffix.lower() in valid_extensions
-    ]
+    valid_extensions = {".d", ".mzml", ".mzML"}
+    candidates = [str(c.resolve()) for c in candidates if c.suffix.lower() in valid_extensions]

     if len(candidates) == 1:
         return candidates[0]
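
Note: the collapsed comprehension is behaviour-preserving. A self-contained illustration of the resolution logic, using hypothetical paths on disk:

```python
# Sketch of what _resolve_ms_path does for an input missing its extension.
from pathlib import Path

path_obj = Path("/data/run01")  # hypothetical input
candidates = list(path_obj.parent.glob(f"{path_obj.stem}*"))
valid_extensions = {".d", ".mzml", ".mzML"}
candidates = [str(c.resolve()) for c in candidates if c.suffix.lower() in valid_extensions]
# With only /data/run01.mzML on disk, candidates == ["/data/run01.mzML"] and
# the resolver returns it; zero or several matches fall through to the error
# handling below this hunk.
```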