Make +addclust back up the original files, to avoid data loss in case adding more clusters fails
matthewfallan committed Jan 14, 2024
1 parent bcf6e13 commit 8e30c46
Showing 6 changed files with 360 additions and 106 deletions.
162 changes: 97 additions & 65 deletions src/seismicrna/cluster/update.py
@@ -1,6 +1,7 @@
from datetime import datetime
from logging import getLogger
from pathlib import Path
from shutil import rmtree

from click import command

@@ -17,11 +18,13 @@
docdef,
arg_input_path,
opt_max_clusters,
opt_temp_dir,
opt_keep_temp,
opt_brotli_level,
opt_parallel,
opt_max_procs)
from ..core.io import recast_file_path
from ..core.parallel import as_list_of_tuples, dispatch
from ..core.io import make_temp_backup, recast_file_path, restore_temp_backup
from ..core.parallel import as_list_of_tuples, dispatch, lock_temp_dir
from ..core.report import (calc_dt_minutes,
Field,
TimeBeganF,
@@ -60,6 +63,8 @@ def update_field(report: ClustReport,

def add_orders(cluster_report_file: Path,
max_order: int, *,
temp_dir: Path,
keep_temp: bool,
brotli_level: int,
n_procs: int):
""" Add orders to an existing report and dataset. """
@@ -84,83 +89,105 @@ def add_orders(cluster_report_file: Path,
ClustReport,
MaskReport))
uniq_reads = UniqReads.from_dataset(mask_dataset)
# Run clustering for every new order.
orders = list(run_orders(uniq_reads,
original_max_order + 1,
max_order,
n_runs,
prev_bic=prev_bic,
min_iter=min_iter,
max_iter=max_iter,
conv_thresh=conv_thresh,
n_procs=n_procs,
top=mask_dataset.top))
if orders:
best_order = find_best_order(orders)
# Update the observed and expected counts for each best run.
try:
# Make a temporary backup of the original results.
backup_dir = make_temp_backup(cluster_report_file.parent,
mask_dataset.top,
temp_dir)
try:
# Run clustering for every new order.
orders = list(run_orders(uniq_reads,
original_max_order + 1,
max_order,
n_runs,
prev_bic=prev_bic,
min_iter=min_iter,
max_iter=max_iter,
conv_thresh=conv_thresh,
n_procs=n_procs,
top=mask_dataset.top))
if orders:
best_order = find_best_order(orders)
# Update the expected counts for each best run.
update_log_counts(orders,
top=mask_dataset.top,
sample=mask_dataset.sample,
ref=mask_dataset.ref,
sect=mask_dataset.sect)
except Exception as error:
logger.error(f"Failed to update counts: {error}")
# Output the cluster memberships in batches of reads.
cluster_dataset = load_cluster_dataset(cluster_report_file)
checksums = update_batches(cluster_dataset, orders, brotli_level)
ended = datetime.now()
new_report = ClustReport(
sample=cluster_dataset.sample,
ref=cluster_dataset.ref,
sect=cluster_dataset.sect,
n_uniq_reads=uniq_reads.num_uniq,
max_order=max_order,
num_runs=n_runs,
min_iter=min_iter,
max_iter=max_iter,
conv_thresh=conv_thresh,
checksums={ClustBatchIO.btype(): checksums},
n_batches=len(checksums),
converged=update_field(original_report,
ClustsConvF,
orders,
"converged"),
log_likes=update_field(original_report,
ClustsLogLikesF,
orders,
"log_likes"),
clusts_rmsds=update_field(original_report,
ClustsRMSDsF,
orders,
"rmsds"),
clusts_meanr=update_field(original_report,
ClustsMeanRsF,
orders,
"meanr"),
bic=update_field(original_report,
ClustsBicF,
orders,
"bic"),
best_order=best_order,
began=original_began,
ended=ended,
taken=taken + calc_dt_minutes(new_began, ended),
# Output the cluster memberships in batches of reads.
cluster_dataset = load_cluster_dataset(cluster_report_file)
checksums = update_batches(cluster_dataset,
orders,
brotli_level)
ended = datetime.now()
new_report = ClustReport(
sample=cluster_dataset.sample,
ref=cluster_dataset.ref,
sect=cluster_dataset.sect,
n_uniq_reads=uniq_reads.num_uniq,
max_order=max_order,
num_runs=n_runs,
min_iter=min_iter,
max_iter=max_iter,
conv_thresh=conv_thresh,
checksums={ClustBatchIO.btype(): checksums},
n_batches=len(checksums),
converged=update_field(original_report,
ClustsConvF,
orders,
"converged"),
log_likes=update_field(original_report,
ClustsLogLikesF,
orders,
"log_likes"),
clusts_rmsds=update_field(original_report,
ClustsRMSDsF,
orders,
"rmsds"),
clusts_meanr=update_field(original_report,
ClustsMeanRsF,
orders,
"meanr"),
bic=update_field(original_report,
ClustsBicF,
orders,
"bic"),
best_order=best_order,
began=original_began,
ended=ended,
taken=taken + calc_dt_minutes(new_began, ended),
)
new_report.save(cluster_dataset.top, force=True)
else:
best_order = original_best_order
n_new = best_order - original_best_order
logger.info(
f"Ended adding {n_new} cluster(s) to {cluster_report_file}"
)
new_report.save(cluster_dataset.top, force=True)
else:
best_order = original_best_order
n_new = best_order - original_best_order
logger.info(f"Ended adding {n_new} cluster(s) to {cluster_report_file}")
except Exception:
# If any error happens, then restore the original results
# (as if this function never ran) and re-raise the error.
restore_temp_backup(cluster_report_file.parent,
mask_dataset.top,
temp_dir)
raise
finally:
        # Delete the backup unless keep_temp is True.
if not keep_temp:
rmtree(backup_dir, ignore_errors=True)
logger.info(f"Deleted backup of {cluster_report_file.parent} "
f"in {backup_dir}")
else:
logger.warning(f"New maximum order ({max_order}) is not greater than "
f"original ({original_max_order}): nothing to update")
return cluster_report_file


@lock_temp_dir
@docdef.auto()
def run(input_path: tuple[str, ...], *,
max_clusters: int,
temp_dir: str,
keep_temp: bool,
brotli_level: int,
max_procs: int,
parallel: bool) -> list[Path]:
@@ -179,14 +206,19 @@ def run(input_path: tuple[str, ...], *,
pass_n_procs=True,
args=as_list_of_tuples(report_files),
kwargs=dict(max_order=max_clusters,
brotli_level=brotli_level))
brotli_level=brotli_level,
temp_dir=Path(temp_dir),
keep_temp=keep_temp))


params = [
# Input files
arg_input_path,
# Clustering options
opt_max_clusters,
# Backup
opt_temp_dir,
opt_keep_temp,
# Compression
opt_brotli_level,
# Parallelization
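In summary, the revised add_orders wraps the clustering step in a backup / attempt / restore-on-error / clean-up sequence. The sketch below illustrates that pattern using the new helpers added in src/seismicrna/core/io/file.py (shown next); the function update_with_backup, the callable risky_update, and the argument names are hypothetical, and it assumes that results_dir lies under out_dir so that make_temp_backup can map it to a path under temp_dir.

from pathlib import Path
from shutil import rmtree
from typing import Callable

from seismicrna.core.io import make_temp_backup, restore_temp_backup


def update_with_backup(results_dir: Path,
                       out_dir: Path,
                       temp_dir: Path,
                       keep_temp: bool,
                       risky_update: Callable[[], None]):
    """ Back up results_dir, attempt risky_update, and restore the
    original files if the update raises (mirroring add_orders). """
    # Copy the original results into the temporary directory first.
    backup_dir = make_temp_backup(results_dir, out_dir, temp_dir)
    try:
        # Attempt the operation that could corrupt the results.
        risky_update()
    except Exception:
        # On any failure, put the original files back, then re-raise.
        restore_temp_backup(results_dir, out_dir, temp_dir)
        raise
    finally:
        # Remove the backup unless the caller asked to keep temporary
        # files; after a restore, the backup has already been moved.
        if not keep_temp:
            rmtree(backup_dir, ignore_errors=True)
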
43 changes: 43 additions & 0 deletions src/seismicrna/core/io/file.py
@@ -2,6 +2,7 @@
from functools import cached_property
from logging import getLogger
from pathlib import Path
from shutil import copy2, copytree, move, rmtree
from typing import Any, Iterable

from .brickle import load_brickle, save_brickle
@@ -167,6 +168,48 @@ def recast_file_path(input_path: Path,
output_type.seg_types(),
**(override | output_type.auto_fields()))


def make_temp_backup(source_path: str | Path,
out_dir: str | Path,
temp_dir: str | Path):
""" Make a temporary backup of `source_path` in `temp_dir`. """
# Determine the path of the backup.
backup_path = path.transpath(temp_dir, out_dir, source_path)
# Copy the source files to the backup.
if source_path.is_dir():
if backup_path.exists():
rmtree(backup_path)
logger.debug(f"Deleted existing backup in {backup_path}")
copytree(source_path, backup_path)
logger.debug(f"Copied directory {source_path} to {backup_path}")
else:
backup_path.parent.mkdir(parents=True, exist_ok=True)
copy2(source_path, backup_path)
logger.debug(f"Copied file {source_path} to {backup_path}")
logger.info(f"Backed up {source_path} to {backup_path}")
return backup_path


def restore_temp_backup(source_path: str | Path,
out_dir: str | Path,
temp_dir: str | Path):
""" Restore the original files from a temporary backup. """
# Determine the path of the backup.
backup_path = path.transpath(temp_dir, out_dir, source_path)
# Replace the source files with the backup.
if backup_path.is_dir():
if source_path.exists():
rmtree(source_path)
logger.debug(f"Deleted original source in {source_path}")
move(backup_path, source_path.parent)
logger.debug(f"Moved directory {backup_path} to {source_path}")
else:
source_path.parent.mkdir(parents=True, exist_ok=True)
move(backup_path, source_path)
logger.debug(f"Moved file {backup_path} to {source_path}")
logger.info(f"Restored {source_path} from backup in {backup_path}")
return backup_path

########################################################################
# #
# © Copyright 2024, the Rouskin Lab. #
22 changes: 22 additions & 0 deletions src/seismicrna/core/io/tests/__init__.py
@@ -0,0 +1,22 @@


########################################################################
# #
# © Copyright 2024, the Rouskin Lab. #
# #
# This file is part of SEISMIC-RNA. #
# #
# SEISMIC-RNA is free software; you can redistribute it and/or modify #
# it under the terms of the GNU General Public License as published by #
# the Free Software Foundation; either version 3 of the License, or #
# (at your option) any later version. #
# #
# SEISMIC-RNA is distributed in the hope that it will be useful, but #
# WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANT- #
# ABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General #
# Public License for more details. #
# #
# You should have received a copy of the GNU General Public License #
# along with SEISMIC-RNA; if not, see <https://www.gnu.org/licenses>. #
# #
########################################################################
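The new src/seismicrna/core/io/tests package is empty apart from the license header. A minimal round-trip test for the backup helpers might look like the sketch below; it assumes the helpers are importable from seismicrna.core.io (as in the update.py import above) and that path.transpath maps a path under out_dir onto the corresponding path under temp_dir, as its use in add_orders suggests. The directory layout and file names are hypothetical.

import unittest
from pathlib import Path
from tempfile import TemporaryDirectory

from seismicrna.core.io import make_temp_backup, restore_temp_backup


class TestTempBackup(unittest.TestCase):
    """ Round-trip a results directory through the backup helpers. """

    def test_backup_and_restore_directory(self):
        with TemporaryDirectory() as out_str, TemporaryDirectory() as tmp_str:
            out_dir = Path(out_str)
            temp_dir = Path(tmp_str)
            # A results directory nested under the output directory.
            results = out_dir / "sample" / "cluster"
            results.mkdir(parents=True)
            (results / "report.json").write_text("original")
            # Back up the results into the temporary directory.
            backup = make_temp_backup(results, out_dir, temp_dir)
            self.assertTrue(backup.is_dir())
            # Simulate a failed update that corrupts the results.
            (results / "report.json").write_text("corrupted")
            # Restoring should replace the results with the backup.
            restore_temp_backup(results, out_dir, temp_dir)
            self.assertEqual((results / "report.json").read_text(),
                             "original")


if __name__ == "__main__":
    unittest.main()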
