Skip to content

Commit

Permalink
Merge pull request #328 from bacpop/db_pruning_fix
Browse files Browse the repository at this point in the history
Small fixes to database pruning and updating
  • Loading branch information
nickjcroucher authored Oct 9, 2024
2 parents 4032bff + 5b4eb67 commit 2943e5f
Show file tree
Hide file tree
Showing 22 changed files with 538 additions and 203 deletions.
2 changes: 1 addition & 1 deletion PopPUNK/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

'''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)'''

__version__ = '2.7.1'
__version__ = '2.7.2'

# Minimum sketchlib version
SKETCHLIB_MAJOR = 2
Expand Down
37 changes: 28 additions & 9 deletions PopPUNK/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def get_options():
type=float)

# model refinement
refinementGroup = parser.add_argument_group('Refine model options')
refinementGroup = parser.add_argument_group('Network analysis and model refinement options')
refinementGroup.add_argument('--pos-shift', help='Maximum amount to move the boundary right past between-strain mean',
type=float, default = 0)
refinementGroup.add_argument('--neg-shift', help='Maximum amount to move the boundary left past within-strain mean]',
Expand All @@ -156,6 +156,9 @@ def get_options():
refinementGroup.add_argument('--score-idx',
help='Index of score to use [default = 0]',
type=int, default = 0, choices=[0, 1, 2])
refinementGroup.add_argument('--summary-sample',
help='Number of sequences used to estimate graph properties [default = all]',
type=int, default = None)
refinementGroup.add_argument('--betweenness-sample',
help='Number of sequences used to estimate betweeness with a GPU [default = 100]',
type = int, default = betweenness_sample_default)
Expand Down Expand Up @@ -253,7 +256,7 @@ def main():

# Imports are here because graph tool is very slow to load
from .models import loadClusterFit, BGMMFit, DBSCANFit, RefineFit, LineageFit
from .sketchlib import checkSketchlibLibrary, removeFromDB
from .sketchlib import checkSketchlibLibrary, removeFromDB, get_database_statistics

from .network import construct_network_from_edge_list
from .network import construct_network_from_assignments
Expand All @@ -264,6 +267,7 @@ def main():

from .plot import writeClusterCsv
from .plot import plot_scatter
from .plot import plot_database_evaluations

from .qc import prune_distance_matrix, qcDistMat, sketchlibAssemblyQC, remove_qc_fail

Expand Down Expand Up @@ -387,8 +391,10 @@ def main():
# Plot results
if not args.no_plot:
plot_scatter(distMat,
f"{args.output}/{os.path.basename(args.output)}_distanceDistribution",
args.output,
args.output + " distances")
genome_lengths, ambiguous_bases = get_database_statistics(args.output)
plot_database_evaluations(args.output, genome_lengths, ambiguous_bases)

#******************************#
#* *#
Expand Down Expand Up @@ -424,18 +430,18 @@ def main():
if args.remove_samples:
with open(args.remove_samples, 'r') as f:
for line in f:
fail_unconditionally[line.rstrip] = ["removed"]
sample_to_remove = line.rstrip()
if sample_to_remove in refList:
fail_unconditionally[sample_to_remove] = ["removed"]

# assembly qc
sys.stderr.write("Running sequence QC\n")
pass_assembly_qc, fail_assembly_qc = \
sketchlibAssemblyQC(args.ref_db,
refList,
qc_dict)
sys.stderr.write(f"{len(fail_assembly_qc)} samples failed\n")

# QC pairwise distances to identify long distances indicative of anomalous sequences in the collection
sys.stderr.write("Running distance QC\n")
pass_dist_qc, fail_dist_qc = \
qcDistMat(distMat,
refList,
Expand All @@ -446,19 +452,28 @@ def main():

# Get list of passing samples
pass_list = set(refList) - fail_unconditionally.keys() - fail_assembly_qc.keys() - fail_dist_qc.keys()
assert(pass_list == set(refList).intersection(set(pass_assembly_qc)).intersection(set(pass_dist_qc)))
assert(pass_list == (set(refList) - fail_unconditionally.keys()).intersection(set(pass_assembly_qc)).intersection(set(pass_dist_qc)))
passed = [x for x in refList if x in pass_list]
if qc_dict['type_isolate'] is not None and qc_dict['type_isolate'] not in pass_list:
raise RuntimeError('Type isolate ' + qc_dict['type_isolate'] + \
' not found in isolates after QC; check '
'name of type isolate and QC options\n')


sys.stderr.write(f"{len(passed)} samples passed QC\n")
if len(passed) < len(refList):
remove_qc_fail(qc_dict, refList, passed,
[fail_unconditionally, fail_assembly_qc, fail_dist_qc],
args.ref_db, distMat, output,
args.strand_preserved, args.threads)
args.strand_preserved, args.threads,
args.gpu_graph)

# Plot results
if not args.no_plot:
plot_scatter(distMat,
output,
output + " distances")
genome_lengths, ambiguous_bases = get_database_statistics(output)
plot_database_evaluations(output, genome_lengths, ambiguous_bases)

#******************************#
#* *#
Expand Down Expand Up @@ -545,6 +560,7 @@ def main():
args.score_idx,
args.no_local,
args.betweenness_sample,
args.summary_sample,
args.gpu_graph)
model = new_model
elif args.fit_model == "threshold":
Expand Down Expand Up @@ -613,6 +629,7 @@ def main():
model.within_label,
distMat = distMat,
weights_type = weights_type,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph)
else:
Expand All @@ -628,6 +645,7 @@ def main():
refList,
assignments[rank],
weights = weights,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph,
summarise = False
Expand Down Expand Up @@ -685,6 +703,7 @@ def main():
queryList,
indivAssignments,
model.within_label,
sample_size = args.summary_sample,
betweenness_sample = args.betweenness_sample,
use_gpu = args.gpu_graph)
isolateClustering[dist_type] = \
Expand Down
2 changes: 1 addition & 1 deletion PopPUNK/assign.py
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ def assign_query_hdf5(dbFuncs,
storePickle(combined_seq, combined_seq, True, None, dists_out)

# Clique pruning
if model.type != 'lineage':
if model.type != 'lineage' and os.path.isfile(ref_file_name):
existing_ref_list = []
with open(ref_file_name) as refFile:
for reference in refFile:
Expand Down
26 changes: 24 additions & 2 deletions PopPUNK/bgmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,33 @@ def fit2dMultiGaussian(X, dpgmm_max_K = 2):
return dpgmm


def findBetweenLabel_bgmm(means, assignments):
    """Identify between-strain links

    Finds the mixture component with the largest number of points
    assigned to it.

    Args:
        means (numpy.array)
            K x 2 array of mixture component means. Only the number of
            components (K) is needed to enumerate the candidate labels.
        assignments (numpy.array)
            Sample cluster assignments

    Returns:
        between_label (int)
            The cluster label with the most points assigned to it. When
            several components tie, the lowest label is returned.

    Raises:
        ValueError
            If ``means`` contains no components.
    """
    n_components = len(means)
    if n_components == 0:
        raise ValueError("means must contain at least one mixture component")

    # Coerce once so the element-wise comparison below also works for
    # plain Python sequences of labels
    assignments = np.asarray(assignments)

    # Number of points assigned to each component, indexed by label
    counts = [
        np.count_nonzero(assignments == mixture_component)
        for mixture_component in range(n_components)
    ]

    # max() returns the first maximal item, so ties resolve to the lowest
    # label - the same outcome as a stable descending sort
    return max(range(n_components), key=counts.__getitem__)

def findWithinLabel(means, assignments, rank = 0):
"""Identify within-strain links
Finds the component with mean closest to the origin and also akes sure
Finds the component with mean closest to the origin and also makes sure
some samples are assigned to it (in the case of small weighted
components with a Dirichlet prior some components are unused)
Expand All @@ -59,7 +82,6 @@ def findWithinLabel(means, assignments, rank = 0):
Sample cluster assignments
rank (int)
Which label to find, ordered by distance from origin. 0-indexed.
(default = 0)
Returns:
Expand Down
55 changes: 30 additions & 25 deletions PopPUNK/info.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,13 @@ def get_options():
# main code
def main():

# Import value
from .__main__ import betweenness_sample_default

# Import functions
from .network import load_network_file
from .network import sparse_mat_to_network
from .network import print_network_summary
from .utils import check_and_set_gpu
from .utils import setGtThreads

Expand Down Expand Up @@ -103,6 +107,32 @@ def main():
use_rc = ref_db['sketches'].attrs['use_rc'] == 1
print("Uses canonical k-mers:\t" + str(use_rc))

# Select network file name
if args.network_file is None:
if use_gpu:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz')
else:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt')
else:
network_file = args.network_file

# Open network file
if network_file.endswith('.gt'):
G = load_network_file(network_file, use_gpu = False)
elif network_file.endswith('.csv.gz'):
if use_gpu:
G = load_network_file(network_file, use_gpu = True)
else:
sys.stderr.write('Unable to load necessary GPU libraries\n')
sys.exit(1)
elif network_file.endswith('.npz'):
sparse_mat = sparse.load_npz(network_file)
G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu)
else:
sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n')
sys.exit(1)
print_network_summary(G, betweenness_sample = betweenness_sample_default, use_gpu = args.use_gpu)

# Print sample information
if not args.simple:
sample_names = list(ref_db['sketches'].keys())
Expand All @@ -115,31 +145,6 @@ def main():
sample_sequence_length[sample_name] = ref_db['sketches/' + sample_name].attrs['length']
sample_missing_bases[sample_name] = ref_db['sketches/' + sample_name].attrs['missing_bases']

# Select network file name
if args.network_file is None:
if use_gpu:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz')
else:
network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt')
else:
network_file = args.network_file

# Open network file
if network_file.endswith('.gt'):
G = load_network_file(network_file, use_gpu = False)
elif network_file.endswith('.csv.gz'):
if use_gpu:
G = load_network_file(network_file, use_gpu = True)
else:
sys.stderr.write('Unable to load necessary GPU libraries\n')
sys.exit(1)
elif network_file.endswith('.npz'):
sparse_mat = sparse.load_npz(network_file)
G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu)
else:
sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n')
sys.exit(1)

# Analyse network
if use_gpu:
component_assignments_df = cugraph.components.connectivity.connected_components(G)
Expand Down
Loading

0 comments on commit 2943e5f

Please sign in to comment.