Commit d406d34

Update io.read_csv to return a list[dict[str, str]], not `list[str]` to align the output format with `read_xlsx`.
akikuno committed Apr 30, 2024
1 parent 33e955f commit d406d34
Showing 2 changed files with 55 additions and 43 deletions.
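For context, the shape change this commit makes: read_csv previously returned each data row as a bare list of strings, and now pairs every row with the header so that each record comes back as a dict, matching read_xlsx. A minimal standalone sketch of the new behavior (the batch-file contents below are hypothetical):

import csv
from io import StringIO

# Hypothetical batch file, inlined for the example.
batch_csv = StringIO(
    "sample, control, allele, name\n"
    "barcode01.fq, barcode02.fq, design.fa, tyr-exp1\n"
)

reader = csv.reader(batch_csv)
header = [field.strip() for field in next(reader)]

# One dict per record, keyed by the header fields: the shape read_xlsx already returned.
records = [dict(zip(header, (field.strip() for field in row))) for row in reader if row]

print(records)
# [{'sample': 'barcode01.fq', 'control': 'barcode02.fq', 'allele': 'design.fa', 'name': 'tyr-exp1'}]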
69 changes: 36 additions & 33 deletions src/DAJIN2/main.py
@@ -13,6 +13,7 @@
 import shutil
 import argparse
 from pathlib import Path
+from copy import deepcopy
 from itertools import groupby
 
 from DAJIN2 import gui, view
@@ -55,42 +56,43 @@ def execute_single_mode(arguments: dict[str]):
 ################################################################################
 
 
-def validate_headers_of_batch_file(headers: list, filepath: str) -> None:
+def validate_headers_of_batch_file(headers: set[str], filepath: str) -> None:
     """Validate the headers of a batch file."""
-    required_headers = ["sample", "control", "allele", "name"]
-    accepted_headers = ["sample", "control", "allele", "name", "genome"]
+    required_headers = {"sample", "control", "allele", "name"}
+    accepted_headers = {"sample", "control", "allele", "name", "genome"}
 
-    if not set(required_headers).issubset(set(headers)):
+    if not required_headers.issubset(headers):
         raise ValueError(f"{filepath} must contain {', '.join(required_headers)} in the header")
 
-    if not set(headers).issubset(accepted_headers):
+    if not headers.issubset(accepted_headers):
         raise ValueError(f"Accepted header names of {filepath} are {', '.join(accepted_headers)}.")
 
 
-def create_argument_dict(headers: list, group: list, cache_urls_genome: dict, is_control: bool) -> dict:
+def create_argument_dict(args: dict, cache_urls_genome: dict, is_control: bool) -> dict[str, str]:
     """Create a dictionary of arguments from the given headers and group."""
-    args = dict(zip(headers, group))
-    args["threads"] = 1  # Set the number of threads to 1 for batch mode
+    args_update = deepcopy(args)
+
+    args_update["threads"] = 1  # Set the number of threads to 1 for batch mode
 
     # Assign the "sample" field depending on whether it's a control or not
     if is_control:
-        args["sample"] = args["control"]
+        args_update["sample"] = args_update["control"]
     else:
-        if args["sample"] == args["control"]:
+        if args_update["sample"] == args_update["control"]:
             return {}  # Return an empty dict to indicate a skipped group
 
-    if args.get("genome"):
-        args.update(cache_urls_genome[args["genome"]])
+    if args_update.get("genome"):
+        args_update.update(cache_urls_genome[args_update["genome"]])
 
-    return args
+    return args_update
 
 
 def run_DAJIN2(
-    groups: list, headers: list, cache_urls_genome: dict, is_control: bool = True, num_workers: int = 1
+    groups: list[dict[str, str]], cache_urls_genome: dict, is_control: bool = True, num_workers: int = 1
 ) -> None:
     contents = []
-    for group in groups:
-        args = create_argument_dict(headers, group, cache_urls_genome, is_control)
+    for args in groups:
+        args = create_argument_dict(args, cache_urls_genome, is_control)
         if args:  # Add args to contents only if it's not an empty dict
             contents.append(args)
 
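Why the deepcopy above matters: each record dict is consumed twice downstream (once with is_control=True for the control pass, once with is_control=False for the sample pass), so mutating the caller's dict in place would leak the control pass's "sample" overwrite into the sample pass. A minimal sketch of the hazard, using a hypothetical record and helper name:

from copy import deepcopy

record = {"sample": "s.fq", "control": "c.fq", "allele": "a.fa", "name": "exp1"}

def build_args(args: dict, is_control: bool) -> dict:
    args_update = deepcopy(args)  # drop this copy and the caller's dict gets mutated
    if is_control:
        args_update["sample"] = args_update["control"]
    return args_update

control_args = build_args(record, is_control=True)
sample_args = build_args(record, is_control=False)
assert control_args["sample"] == "c.fq"
assert sample_args["sample"] == "s.fq"  # intact because build_args copied first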
@@ -111,36 +113,37 @@ def execute_batch_mode(arguments: dict[str]):
     if not Path(path_batchfile).exists():
         raise FileNotFoundError(f"'{path_batchfile}' does not exist.")
 
-    inputs = io.load_batchfile(path_batchfile)
+    records = io.load_batchfile(path_batchfile)
 
     # Validate Column of the batch file
-    headers = inputs[0]
+    headers = set(records[0].keys())
     validate_headers_of_batch_file(headers, path_batchfile)
 
     # Validate contents and fetch genome urls
-    contents = inputs[1:]
     cache_urls_genome = dict()
-    index_of_name = headers.index("name")
-    contents.sort(key=lambda x: x[index_of_name])
-    for _, groups in groupby(contents, key=lambda x: x[index_of_name]):
-        for group in groups:
-            args = dict(zip(headers, group))
-            # validate contents in the batch file
+    records.sort(key=lambda x: x["name"])
+    for _, groups in groupby(records, key=lambda x: x["name"]):
+        for args in groups:
+            # Validate contents in the batch file
             input_validator.validate_files(args["sample"], args["control"], args["allele"])
-            # validate genome and fetch urls
+            # Validate genome and fetch urls
            if args.get("genome") and args["genome"] not in cache_urls_genome:
                urls_genome = input_validator.validate_genome_and_fetch_urls(args["genome"])
                cache_urls_genome[args["genome"]] = urls_genome
-    for name, groups in groupby(contents, key=lambda x: x[index_of_name]):
-        # set logging to export log to stderr and file
+
+    # Run DAJIN2
+    for name, groups in groupby(records, key=lambda x: x["name"]):
+        groups: list[dict[str, str]] = list(groups)
+        # Set logging to export log to stderr and file
         config.reset_logging()
         path_logfile = config.get_logfile()
         config.set_logging(path_logfile)
-        groups = list(groups)
-        # Run DAJIN2
-        run_DAJIN2(groups, headers, cache_urls_genome, is_control=True, num_workers=arguments["threads"])
-        run_DAJIN2(groups, headers, cache_urls_genome, is_control=False, num_workers=arguments["threads"])
-        # Finish
+
+        # Start DAJIN2
+        run_DAJIN2(groups, cache_urls_genome, is_control=True, num_workers=arguments["threads"])
+        run_DAJIN2(groups, cache_urls_genome, is_control=False, num_workers=arguments["threads"])
+
+        # Finish call
         generate_report(name)
         shutil.move(path_logfile, Path("DAJIN_Results", name))
         if not arguments["debug"]:
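execute_batch_mode now sorts and groups the dict records directly by their "name" field. itertools.groupby only merges adjacent items with equal keys, which is why the sort comes first; a short sketch with hypothetical records:

from itertools import groupby

records = [
    {"sample": "s2.fq", "control": "c.fq", "allele": "a.fa", "name": "exp2"},
    {"sample": "s1.fq", "control": "c.fq", "allele": "a.fa", "name": "exp1"},
    {"sample": "s3.fq", "control": "c.fq", "allele": "a.fa", "name": "exp1"},
]

# groupby only groups adjacent equal keys, so sort by the same key first.
records.sort(key=lambda x: x["name"])
for name, groups in groupby(records, key=lambda x: x["name"]):
    print(name, [args["sample"] for args in groups])
# exp1 ['s1.fq', 's3.fq']
# exp2 ['s2.fq']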
29 changes: 19 additions & 10 deletions src/DAJIN2/utils/io.py
@@ -95,30 +95,39 @@ def determine_file_type(file_path: str) -> str | None:


 def read_xlsx(file_path: str | Path) -> list[dict[str, str]]:
-    """Load data from an Excel file and return as a list."""
+    """Load data from an Excel file."""
     wb = load_workbook(filename=file_path)
     ws = wb.active
 
     headers = [cell for cell in next(ws.iter_rows(min_row=1, max_row=1, values_only=True))]
 
-    data = []
+    records = []
     for row in ws.iter_rows(min_row=2, values_only=True):
         if all(element is None for element in row):  # Skip rows with all None values
             continue
         row_data = {headers[i]: (row[i] if i < len(row) else None) for i in range(len(headers))}
-        data.append(row_data)
+        records.append(row_data)
 
-    return data
+    return records
 
 
-def read_csv(file_path: str) -> list[dict[str, str]]:
-    """Load data from a CSV file and return as a list."""
+def read_csv(file_path: str | Path) -> list[dict[str, str]]:
+    """Load data from a CSV file."""
     with open(file_path, "r") as csvfile:
-        inputs = []
+
+        header = [field.strip() for field in next(csv.reader(csvfile))]
+
+        records = []
         for row in csv.reader(csvfile):
             if not row:  # Skip empty rows
                 continue
-            trimmed_row = [field.strip() for field in row]
-            inputs.append(trimmed_row)
-    return inputs
+            if all(element is None for element in row):  # Skip rows with all None values
+                continue
+            row_trimmed = [field.strip() for field in row]
+            row_data = {h: v for h, v in zip(header, row_trimmed)}
+            records.append(row_data)
+
+    return records
 
 
 def load_batchfile(batchfile_path: str) -> list[dict[str, str]]:
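With both readers returning list[dict[str, str]], a caller such as load_batchfile (whose body is outside this diff) can treat CSV and Excel batch files interchangeably. A sketch of such a dispatcher; the suffix-based branching and the name load_batchfile_sketch are assumptions for illustration, not the actual implementation:

from pathlib import Path

from DAJIN2.utils import io

def load_batchfile_sketch(batchfile_path: str) -> list[dict[str, str]]:
    # Hypothetical dispatcher: pick the reader by file extension.
    if Path(batchfile_path).suffix.lower() == ".xlsx":
        return io.read_xlsx(batchfile_path)
    return io.read_csv(batchfile_path)

records = load_batchfile_sketch("batch.csv")  # assumes a batch.csv next to the script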
