From 357ae4c02049293e3a2d9781321fda0c5bce6526 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Mon, 30 Jan 2023 16:32:43 +0000 Subject: [PATCH 01/65] Move pairwise distance plot to db dir --- PopPUNK/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index f48eeaa3..443413bf 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -77,7 +77,7 @@ def plot_scatter(X, out_prefix, title, kde = True): plt.title(title) plt.xlabel('Core distance (' + r'$\pi$' + ')') plt.ylabel('Accessory distance (' + r'$a$' + ')') - plt.savefig(out_prefix + '.png') + plt.savefig(os.path.join(out_prefix, out_prefix + '.png')) plt.close() def plot_fit(klist, raw_matching, raw_fit, corrected_matching, corrected_fit, out_prefix, title): From d3843356d9107192919bd0d00053cf64da3a5da0 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Mon, 30 Jan 2023 16:58:16 +0000 Subject: [PATCH 02/65] Add db evaluation histograms --- PopPUNK/__main__.py | 4 ++- PopPUNK/plot.py | 59 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 61 insertions(+), 2 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 2fe144dc..6960dc11 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -240,6 +240,7 @@ def main(): from .plot import writeClusterCsv from .plot import plot_scatter + from .plot import plot_database_evaluations from .qc import prune_distance_matrix, qcDistMat, sketchlibAssemblyQC, remove_qc_fail @@ -363,8 +364,9 @@ def main(): # Plot results if not args.no_plot: plot_scatter(distMat, - f"{args.output}/{os.path.basename(args.output)}_distanceDistribution", + args.output, args.output + " distances") + plot_database_evaluations(args.output) #******************************# #* *# diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 443413bf..8a832e0f 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -15,6 +15,7 @@ import itertools # for other outputs import pandas as pd +import h5py from collections import defaultdict from sklearn import utils try: # sklearn >= 0.22 @@ -77,7 +78,63 @@ def plot_scatter(X, out_prefix, title, kde = True): plt.title(title) plt.xlabel('Core distance (' + r'$\pi$' + ')') plt.ylabel('Accessory distance (' + r'$a$' + ')') - plt.savefig(os.path.join(out_prefix, out_prefix + '.png')) + plt.savefig(os.path.join(out_prefix, os.path.basename(out_prefix) + '_distanceDistribution.png')) + plt.close() + +def plot_database_evaluations(prefix): + """Plot histograms of sequence characteristics for database evaluation. + + Args: + prefix (str) + Prefix of database + """ + db_file = prefix + "/" + os.path.basename(prefix) + ".h5" + ref_db = h5py.File(db_file, 'r') + + genome_lengths = [] + ambiguous_bases = [] + for sample_name in list(ref_db['sketches'].keys()): + genome_lengths.append(ref_db['sketches/' + sample_name].attrs['length']) + ambiguous_bases.append(ref_db['sketches/' + sample_name].attrs['missing_bases']) + plot_evaluation_histogram(genome_lengths, + n_bins = 100, + prefix = prefix, + suffix = 'genome_lengths', + plt_title = 'Distribution of sequence lengths', + xlab = 'Sequence length (nt)') + plot_evaluation_histogram(ambiguous_bases, + n_bins = 100, + prefix = prefix, + suffix = 'ambiguous_base_counts', + plt_title = 'Distribution of ambiguous base counts', + xlab = 'Number of ambiguous bases') + +def plot_evaluation_histogram(input_data, n_bins = 100, prefix = 'hist', + suffix = '', plt_title = 'histogram', xlab = 'x'): + """Plot histograms of sequence characteristics for database evaluation. 
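A note on the data source for the evaluation plots added in this commit: both histograms are driven by per-sketch HDF5 attributes rather than by re-reading the assemblies. A minimal sketch of that lookup, assuming only the standard PopPUNK layout seen above (a 'sketches' group whose members carry 'length' and 'missing_bases' attributes); the database path is hypothetical:

    import h5py

    # PopPUNK stores sketches as <prefix>/<basename(prefix)>.h5
    with h5py.File("example_db/example_db.h5", "r") as ref_db:
        for sample_name in ref_db["sketches"]:
            sketch = ref_db["sketches/" + sample_name]
            print(sample_name,
                  sketch.attrs["length"],          # sequence length (nt)
                  sketch.attrs["missing_bases"])   # ambiguous base count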
+ + Args: + input_data (list) + Input data (list of numbers) + n_bins (int) + Number of bins to use for the histogram + prefix (str) + Prefix of database + suffix (str) + Suffix specifying plot type + plt_title (str) + Title for plot + xlab (str) + Title for the horizontal axis + """ + plt.figure(figsize=(8, 8), dpi=160, facecolor='w', edgecolor='k') + counts, bins = np.histogram(input_data, bins = n_bins) + plt.stairs(counts, bins, fill = True) + plt.title(plt_title) + plt.xlabel(xlab) + plt.ylabel('Frequency') + plt.savefig(os.path.join(out_prefix, os.path.basename(out_prefix) + suffix + '.png')) + plt.savefig(os.path.join(prefix,prefix + '.png')) plt.close() def plot_fit(klist, raw_matching, raw_fit, corrected_matching, corrected_fit, out_prefix, title): From 17cb94798be5a86e1b29ed12d99986a3a65fd678 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Mon, 30 Jan 2023 17:01:46 +0000 Subject: [PATCH 03/65] Correct variable name --- PopPUNK/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 8a832e0f..b651e0e8 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -133,7 +133,7 @@ def plot_evaluation_histogram(input_data, n_bins = 100, prefix = 'hist', plt.title(plt_title) plt.xlabel(xlab) plt.ylabel('Frequency') - plt.savefig(os.path.join(out_prefix, os.path.basename(out_prefix) + suffix + '.png')) + plt.savefig(os.path.join(prefix, os.path.basename(prefix) + suffix + '.png')) plt.savefig(os.path.join(prefix,prefix + '.png')) plt.close() From 325171cdaf2f1ea6f5bd3a14dcbf78f9149ab79f Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 06:08:17 +0000 Subject: [PATCH 04/65] Inform user of QC process and output --- PopPUNK/__main__.py | 11 ++++++++--- PopPUNK/qc.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 3 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 6960dc11..aa59b15f 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -405,7 +405,6 @@ def main(): fail_unconditionally[line.rstrip] = ["removed"] # assembly qc - sys.stderr.write("Running sequence QC\n") pass_assembly_qc, fail_assembly_qc = \ sketchlibAssemblyQC(args.ref_db, refList, @@ -413,7 +412,6 @@ def main(): sys.stderr.write(f"{len(fail_assembly_qc)} samples failed\n") # QC pairwise distances to identify long distances indicative of anomalous sequences in the collection - sys.stderr.write("Running distance QC\n") pass_dist_qc, fail_dist_qc = \ qcDistMat(distMat, refList, @@ -430,13 +428,20 @@ def main(): raise RuntimeError('Type isolate ' + qc_dict['type_isolate'] + \ ' not found in isolates after QC; check ' 'name of type isolate and QC options\n') - + + sys.stderr.write(f"{len(passed)} samples passed QC\n") if len(passed) < len(refList): remove_qc_fail(qc_dict, refList, passed, [fail_unconditionally, fail_assembly_qc, fail_dist_qc], args.ref_db, distMat, output, args.strand_preserved, args.threads) + # Plot results + if not args.no_plot: + plot_scatter(distMat, + args.output, + args.output + " distances") + plot_database_evaluations(args.output) #******************************# #* *# diff --git a/PopPUNK/qc.py b/PopPUNK/qc.py index 9ae8359f..5b5f8b49 100755 --- a/PopPUNK/qc.py +++ b/PopPUNK/qc.py @@ -153,7 +153,17 @@ def sketchlibAssemblyQC(prefix, names, qc_dict): import h5py from .sketchlib import removeFromDB + # Make user aware of all filters being used (including defaults) sys.stderr.write("Running QC on sketches\n") + if (qc_dict['upper_n'] is not None: + sys.stderr.write("Using count cutoff for 
ambiguous bases: " + str(qc_dict['upper_n']) + "\n") + else: + sys.stderr.write("Using proportion cutoff for ambiguous bases: " + str(qc_dict['prop_n']) + "\n") + if qc_dict['length_range'][0] is None: + sys.stderr.write("Using standard deviation for length cutoff: " + str(qc_dict['length_sigma']) + "\n") + else: + sys.stderr.write("Using range for length cutoffs: " + str(qc_dict['length_range'][0]) + " - " + \ + str(qc_dict['length_range'][1]) + "\n") # open databases db_name = prefix + '/' + os.path.basename(prefix) + '.h5' @@ -245,6 +255,12 @@ def qcDistMat(distMat, refList, queryList, ref_db, qc_dict): failed (dict) List of sequences failing, and reasons """ + # Make user aware of all filters being used (including defaults) + sys.stderr.write("Running QC on distances\n") + sys.stderr.write("Using cutoff for core distances: " + str(qc_dict['max_pi_dist']) + "\n") + sys.stderr.write("Using cutoff for accessory distances: " + str(qc_dict['max_a_dist']) + "\n") + sys.stderr.write("Using cutoff for proportion of zero distances: " + str(qc_dict['prop_zero']) + "\n") + # Create overall list of sequences if refList == queryList: names = refList From 316283a50afcb98df124c668faf7c5e659c6eae1 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 06:16:54 +0000 Subject: [PATCH 05/65] Fix bracket --- PopPUNK/qc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/qc.py b/PopPUNK/qc.py index 5b5f8b49..14edbdf7 100755 --- a/PopPUNK/qc.py +++ b/PopPUNK/qc.py @@ -155,7 +155,7 @@ def sketchlibAssemblyQC(prefix, names, qc_dict): # Make user aware of all filters being used (including defaults) sys.stderr.write("Running QC on sketches\n") - if (qc_dict['upper_n'] is not None: + if qc_dict['upper_n'] is not None: sys.stderr.write("Using count cutoff for ambiguous bases: " + str(qc_dict['upper_n']) + "\n") else: sys.stderr.write("Using proportion cutoff for ambiguous bases: " + str(qc_dict['prop_n']) + "\n") From 7041cf75355b5e82006f3447fd693af7158f6834 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 06:34:22 +0000 Subject: [PATCH 06/65] Fix plot names --- PopPUNK/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index b651e0e8..0b34243e 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -133,7 +133,7 @@ def plot_evaluation_histogram(input_data, n_bins = 100, prefix = 'hist', plt.title(plt_title) plt.xlabel(xlab) plt.ylabel('Frequency') - plt.savefig(os.path.join(prefix, os.path.basename(prefix) + suffix + '.png')) + plt.savefig(os.path.join(prefix, os.path.basename(prefix) + '_' + suffix + '.png')) plt.savefig(os.path.join(prefix,prefix + '.png')) plt.close() From 185c939765694984f9740dea3c906cb095d1744e Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 06:36:10 +0000 Subject: [PATCH 07/65] Fix plot prefixes --- PopPUNK/__main__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index aa59b15f..61dd4b47 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -439,9 +439,9 @@ def main(): # Plot results if not args.no_plot: plot_scatter(distMat, - args.output, - args.output + " distances") - plot_database_evaluations(args.output) + output, + output + " distances") + plot_database_evaluations(output) #******************************# #* *# From bda439852dcc7694d518fb07afec1e0b7e26519a Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 11:10:36 +0000 Subject: [PATCH 08/65] Validated update of 
https://github.com/rapidsai/cugraph/pull/2671 --- PopPUNK/network.py | 27 +-------------------------- 1 file changed, 1 insertion(+), 26 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 3c9c6a4e..cf40e0f9 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1105,27 +1105,6 @@ def construct_network_from_assignments(rlist, qlist, assignments, within_label = return G -def get_cugraph_triangles(G): - """Counts the number of triangles in a cugraph - network. Can be removed when the cugraph issue - https://github.com/rapidsai/cugraph/issues/1043 is fixed. - - Args: - G (cugraph network) - Network to be analysed - - Returns: - triangle_count (int) - Count of triangles in graph - """ - nlen = G.number_of_vertices() - df = G.view_edge_list() - A = cp.full((nlen, nlen), 0, dtype = cp.int32) - A[df.src.values, df.dst.values] = 1 - A = cp.maximum( A, A.transpose() ) - triangle_count = int(cp.around(cp.trace(cp.matmul(A, cp.matmul(A, A)))/6,0)) - return triangle_count - def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sample_default, use_gpu = False): """Provides summary values about the network @@ -1150,11 +1129,7 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa component_nums = component_assignments['labels'].unique().astype(int) components = len(component_nums) density = G.number_of_edges()/(0.5 * G.number_of_vertices() * G.number_of_vertices() - 1) - # consistent with graph-tool for small graphs - triangle counts differ for large graphs - # could reflect issue https://github.com/rapidsai/cugraph/issues/1043 - # this command can be restored once the above issue is fixed - scheduled for cugraph 0.20 -# triangle_count = cugraph.community.triangle_count.triangles(G)/3 - triangle_count = 3*get_cugraph_triangles(G) + triangle_count = cugraph.community.triangle_count.triangles(G)/3 degree_df = G.in_degree() # consistent with graph-tool triad_count = 0.5 * sum([d * (d - 1) for d in degree_df[degree_df['degree'] > 1]['degree'].to_pandas()]) From d25cd5bced12815cef2f240566b6cd1573e8540e Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 13:55:40 +0000 Subject: [PATCH 09/65] Enable subsampling for graph analysis --- PopPUNK/__main__.py | 9 ++++++- PopPUNK/models.py | 6 ++++- PopPUNK/network.py | 64 ++++++++++++++++++++++++++++++++++++++------- PopPUNK/refine.py | 24 ++++++++++++++--- 4 files changed, 88 insertions(+), 15 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 61dd4b47..65f01eb2 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -121,7 +121,7 @@ def get_options(): modelGroup.add_argument('--threshold', help='Cutoff if using --fit-model threshold', type=float) # model refinement - refinementGroup = parser.add_argument_group('Refine model options') + refinementGroup = parser.add_argument_group('Network analysis and model refinement options') refinementGroup.add_argument('--pos-shift', help='Maximum amount to move the boundary right past between-strain mean', type=float, default = 0) refinementGroup.add_argument('--neg-shift', help='Maximum amount to move the boundary left past within-strain mean]', @@ -133,6 +133,9 @@ def get_options(): refinementGroup.add_argument('--score-idx', help='Index of score to use [default = 0]', type=int, default = 0, choices=[0, 1, 2]) + refinementGroup.add_argument('--summary-sample', + help='Number of sequences used to estimate graph properties [default = all]', + type=int, default = None) 
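The new --summary-sample option bounds the cost of the network summaries computed repeatedly during refinement by estimating them on a random induced subgraph. A hedged sketch of that idea for the graph-tool backend, mirroring the vertex-filter logic added to networkSummary later in this commit (the helper name and sizes are illustrative, not part of the patch):

    import numpy as np
    import graph_tool.all as gt

    def subsample_graph(G, subsample):
        # Mark a random fixed-size set of vertices, then summarise the
        # induced GraphView instead of the full network.
        n = G.num_vertices()
        chosen = np.random.choice(np.arange(n), size=subsample, replace=False)
        vfilt_bool = np.full(n, False)
        vfilt_bool[chosen] = True
        vfilt = G.new_vertex_property('bool', vals=vfilt_bool)
        return gt.GraphView(G, vfilt=vfilt)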
refinementGroup.add_argument('--betweenness-sample', help='Number of sequences used to estimate betweeness with a GPU [default = 100]', type = int, default = betweenness_sample_default) @@ -518,6 +521,7 @@ def main(): args.score_idx, args.no_local, args.betweenness_sample, + args.summary_sample, args.gpu_graph) model = new_model elif args.fit_model == "threshold": @@ -581,6 +585,7 @@ def main(): model.within_label, distMat = distMat, weights_type = weights_type, + sample_size = args.summary_sample, betweenness_sample = args.betweenness_sample, use_gpu = args.gpu_graph) else: @@ -596,6 +601,7 @@ def main(): refList, assignments[rank], weights = weights, + sample_size = args.summary_sample, betweenness_sample = args.betweenness_sample, use_gpu = args.gpu_graph, summarise = False @@ -653,6 +659,7 @@ def main(): queryList, indivAssignments, model.within_label, + sample_size = args.summary_sample, betweenness_sample = args.betweenness_sample, use_gpu = args.gpu_graph) isolateClustering[dist_type] = \ diff --git a/PopPUNK/models.py b/PopPUNK/models.py index d51f83f4..0b851c9b 100644 --- a/PopPUNK/models.py +++ b/PopPUNK/models.py @@ -702,7 +702,7 @@ def __init__(self, outPrefix): def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indiv_refine = False, unconstrained = False, multi_boundary = 0, score_idx = 0, no_local = False, - betweenness_sample = betweenness_sample_default, use_gpu = False): + sample_size = None, betweenness_sample = betweenness_sample_default, use_gpu = False): '''Extends :func:`~ClusterFit.fit` Fits the distances by optimising network score, by calling @@ -744,6 +744,8 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi betweenness_sample (int) Number of sequences per component used to estimate betweenness using a GPU. Smaller numbers are faster but less precise [default = 100] + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use cugraph for graph analyses @@ -791,6 +793,7 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi no_local = no_local, num_processes = self.threads, betweenness_sample = betweenness_sample, + sample_size = sample_size, use_gpu = use_gpu) self.fitted = True @@ -830,6 +833,7 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi no_local = no_local, num_processes = self.threads, betweenness_sample = betweenness_sample, + sample_size = sample_size, use_gpu = use_gpu) if dist_type == "core": self.core_boundary = core_boundary diff --git a/PopPUNK/network.py b/PopPUNK/network.py index cf40e0f9..d9fdef3c 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -548,12 +548,14 @@ def network_to_edges(prev_G_fn, rlist, adding_qq_dists = False, else: return source_ids, target_ids -def print_network_summary(G, betweenness_sample = betweenness_sample_default, use_gpu = False): +def print_network_summary(G, sample_size = None, betweenness_sample = betweenness_sample_default, use_gpu = False): """Wrapper function for printing network information Args: G (graph) List of reference sequence labels + sample_size (int) + Number of nodes to subsample for graph statistic calculation betweenness_sample (int) Number of sequences per component used to estimate betweenness using a GPU. 
Smaller numbers are faster but less precise [default = 100] @@ -561,7 +563,10 @@ def print_network_summary(G, betweenness_sample = betweenness_sample_default, us Whether to use GPUs for network construction """ # print some summaries - (metrics, scores) = networkSummary(G, betweenness_sample = betweenness_sample, use_gpu = use_gpu) + (metrics, scores) = networkSummary(G, + subsample = sample_size, + betweenness_sample = betweenness_sample, + use_gpu = use_gpu) sys.stderr.write("Network summary:\n" + "\n".join(["\tComponents\t\t\t\t" + str(metrics[0]), "\tDensity\t\t\t\t\t" + "{:.4f}".format(metrics[1]), "\tTransitivity\t\t\t\t" + "{:.4f}".format(metrics[2]), @@ -672,6 +677,7 @@ def construct_network_from_edge_list(rlist, previous_pkl = None, betweenness_sample = betweenness_sample_default, summarise = True, + sample_size = None, use_gpu = False): """Construct an undirected network using a data frame of edges. Nodes are samples and edges where samples are within the same cluster @@ -706,6 +712,8 @@ def construct_network_from_edge_list(rlist, summarise (bool) Whether to calculate and print network summaries with :func:`~networkSummary` (default = True) + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use GPUs for network construction @@ -780,7 +788,10 @@ def construct_network_from_edge_list(rlist, else: G.add_edge_list(edge_list) if summarise: - print_network_summary(G, betweenness_sample = betweenness_sample, use_gpu = use_gpu) + print_network_summary(G, + sample_size = sample_size, + betweenness_sample = betweenness_sample, + use_gpu = use_gpu) return G @@ -795,6 +806,7 @@ def construct_network_from_df(rlist, previous_pkl = None, betweenness_sample = betweenness_sample_default, summarise = True, + sample_size = None, use_gpu = False): """Construct an undirected network using a data frame of edges. Nodes are samples and edges where samples are within the same cluster @@ -829,6 +841,8 @@ def construct_network_from_df(rlist, summarise (bool) Whether to calculate and print network summaries with :func:`~networkSummary` (default = True) + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use GPUs for network construction @@ -894,7 +908,10 @@ def construct_network_from_df(rlist, summarise = False, use_gpu = use_gpu) if summarise: - print_network_summary(G, betweenness_sample = betweenness_sample, use_gpu = use_gpu) + print_network_summary(G, + sample_size = sample_size, + betweenness_sample = betweenness_sample, + use_gpu = use_gpu) return G def construct_network_from_sparse_matrix(rlist, @@ -905,6 +922,7 @@ def construct_network_from_sparse_matrix(rlist, previous_pkl = None, betweenness_sample = betweenness_sample_default, summarise = True, + sample_size = None, use_gpu = False): """Construct an undirected network using a sparse matrix. 
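With the constructors above passing sample_size through, the summary wrapper exposes the subsample size end to end. A hedged usage sketch (G is an existing graph-tool network; the numeric values are illustrative):

    from PopPUNK.network import print_network_summary

    print_network_summary(G,
                          sample_size = 1000,        # vertices used for the statistics
                          betweenness_sample = 100,  # per-component betweenness sample
                          use_gpu = False)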
Nodes are samples and edges where samples are within the same cluster @@ -933,6 +951,8 @@ def construct_network_from_sparse_matrix(rlist, summarise (bool) Whether to calculate and print network summaries with :func:`~networkSummary` (default = True) + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use GPUs for network construction @@ -956,7 +976,10 @@ def construct_network_from_sparse_matrix(rlist, summarise = False, use_gpu = use_gpu) if summarise: - print_network_summary(G, betweenness_sample = betweenness_sample, use_gpu = use_gpu) + print_network_summary(G, + sample_size = sample_size, + betweenness_sample = betweenness_sample, + use_gpu = use_gpu) return G def construct_dense_weighted_network(rlist, distMat, weights_type = None, use_gpu = False): @@ -1024,7 +1047,7 @@ def construct_dense_weighted_network(rlist, distMat, weights_type = None, use_gp def construct_network_from_assignments(rlist, qlist, assignments, within_label = 1, int_offset = 0, weights = None, distMat = None, weights_type = None, previous_network = None, old_ids = None, adding_qq_dists = False, previous_pkl = None, betweenness_sample = betweenness_sample_default, - summarise = True, use_gpu = False): + summarise = True, sample_size = None, use_gpu = False): """Construct an undirected network using sequence lists, assignments of pairwise distances to clusters, and the identifier of the cluster assigned to within-strain distances. Nodes are samples and edges where samples are within the same cluster @@ -1066,6 +1089,8 @@ def construct_network_from_assignments(rlist, qlist, assignments, within_label = summarise (bool) Whether to calculate and print network summaries with :func:`~networkSummary` (default = True) + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use GPUs for network construction @@ -1101,17 +1126,22 @@ def construct_network_from_assignments(rlist, qlist, assignments, within_label = summarise = False, use_gpu = use_gpu) if summarise: - print_network_summary(G, betweenness_sample = betweenness_sample, use_gpu = use_gpu) + print_network_summary(G, + sample_size = sample_size, + betweenness_sample = betweenness_sample, + use_gpu = use_gpu) return G -def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sample_default, - use_gpu = False): +def networkSummary(G, subsample = None, calc_betweenness=True, + betweenness_sample = betweenness_sample_default, use_gpu = False): """Provides summary values about the network Args: G (graph) The network of strains + subsample (int) + Number of vertices to randomly subsample from graph calc_betweenness (bool) Whether to calculate betweenness stats use_gpu (bool) @@ -1125,6 +1155,13 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa List of scores """ if use_gpu: + if subsample is None: + S = G + else: + vertex_subsample = cp.random.choice(cp.arange(0,G.number_of_vertices() - 1), + size = subsample, + replace = False) + S = cugraph.subgraph(G, vertex_subsample) component_assignments = cugraph.components.connectivity.connected_components(G) component_nums = component_assignments['labels'].unique().astype(int) components = len(component_nums) @@ -1138,6 +1175,15 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa else: transitivity = 0.0 else: + if subsample is None: + S = G + else: + vfilt = g.new_vertex_property('bool', val = False) + vertex_subsample = 
np.random.choice(np.arange(0,len(list(G.vertices())) - 1), + size = subsample, + replace = False) + vfilt[vertex_subsample] = True + S = gt.GraphView(G, vfilt=vfilt) component_assignments, component_frequencies = gt.label_components(G) components = len(component_frequencies) density = len(list(G.edges()))/(0.5 * len(list(G.vertices())) * (len(list(G.vertices())) - 1)) diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index 3239bb0a..38c67f7f 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -88,6 +88,8 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, betweenness_sample (int) Number of sequences per component used to estimate betweenness using a GPU. Smaller numbers are faster but less precise [default = 100] + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use cugraph for graph analyses @@ -134,6 +136,7 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, y_range = y_max, score_idx = score_idx, betweenness_sample = betweenness_sample, + sample_size = sample_size, use_gpu = True), range(global_grid_resolution)) else: @@ -154,6 +157,7 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, y_range = y_max, score_idx = score_idx, betweenness_sample = betweenness_sample, + sample_size = sample_size, use_gpu = False), range(global_grid_resolution)) @@ -202,6 +206,7 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, s_range, score_idx, betweenness_sample = betweenness_sample, + sample_size = sample_size, use_gpu = use_gpu)) global_s[np.isnan(global_s)] = 1 min_idx = np.argmin(np.array(global_s)) @@ -221,7 +226,7 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, method = 'Bounded', options={'disp': True}, args = (sample_names, distMat, mean0, mean1, gradient, slope, score_idx, num_processes, - betweenness_sample, use_gpu) + betweenness_sample, sample_size, use_gpu) ) optimised_s = local_s.x @@ -295,6 +300,7 @@ def multi_refine(distMat, sample_names, mean0, mean1, scale, s_max, s_range, 0, write_clusters = output_prefix, + sample_size = sample_size, use_gpu = use_gpu) def check_search_range(scale, mean0, mean1, lower_s, upper_s): @@ -360,7 +366,7 @@ def expand_cugraph_network(G, G_extra_df): def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx = 0, thread_idx = 0, betweenness_sample = betweenness_sample_default, - write_clusters = None, use_gpu = False): + write_clusters = None, sample_size = None, use_gpu = False): """Construct a network, then add edges to it iteratively. 
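The cugraph branch of networkSummary above follows the same pattern on the GPU; a minimal sketch of the subsample-then-summarise flow, assuming cupy and cugraph are importable and G is an existing cugraph network (the sample size is illustrative):

    import cupy as cp
    import cugraph

    # Draw a random vertex sample, take the induced subgraph, then compute
    # the summary statistics on the (much smaller) subgraph.
    vertex_subsample = cp.random.choice(cp.arange(0, G.number_of_vertices()),
                                        size = 1000, replace = False)
    S = cugraph.subgraph(G, vertex_subsample)
    labels = cugraph.components.connectivity.connected_components(S)['labels']
    n_components = labels.nunique()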
Input is from ``pp_sketchlib.iterateBoundary1D`` or``pp_sketchlib.iterateBoundary2D`` @@ -387,6 +393,8 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx = 0, write_clusters (str) Set to a prefix to write the clusters from each position to files [default = None] + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use cugraph for graph analyses @@ -435,6 +443,7 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx = 0, # Add score into vector for any offsets passed (should usually just be one) G_summary = networkSummary(G, score_idx > 0, + subsample = sample_size, betweenness_sample = betweenness_sample, use_gpu = use_gpu) latest_score = -G_summary[1][score_idx] @@ -458,7 +467,7 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx = 0, def newNetwork(s, sample_names, distMat, mean0, mean1, gradient, slope=2, score_idx=0, cpus=1, betweenness_sample = betweenness_sample_default, - use_gpu = False): + sample_size = None, use_gpu = False): """Wrapper function for :func:`~PopPUNK.network.construct_network_from_edge_list` which is called by optimisation functions moving a triangular decision boundary. @@ -490,6 +499,8 @@ def newNetwork(s, sample_names, distMat, mean0, mean1, gradient, betweenness_sample (int) Number of sequences per component used to estimate betweenness using a GPU. Smaller numbers are faster but less precise [default = 100] + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use cugraph for graph analysis @@ -523,12 +534,14 @@ def newNetwork(s, sample_names, distMat, mean0, mean1, gradient, # Return score score = networkSummary(G, score_idx > 0, + subsample = sample_size, betweenness_sample = betweenness_sample, use_gpu = use_gpu)[1][score_idx] return(-score) def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0, - betweenness_sample = betweenness_sample_default, use_gpu = False): + betweenness_sample = betweenness_sample_default, sample_size = None, + use_gpu = False): """Wrapper function for thresholdIterate2D and :func:`growNetwork`. For a given y_max, constructs networks across x_range and returns a list @@ -551,6 +564,8 @@ def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0, betweenness_sample (int) Number of sequences per component used to estimate betweenness using a GPU. 
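Each candidate boundary position is scored by summarising the grown network and negating the selected metric, so the scipy minimisers used elsewhere in refineFit can be applied directly. A hedged sketch using keyword arguments against the patched networkSummary (G, score_idx and sample_size are assumed to be in scope):

    from PopPUNK.network import networkSummary

    metrics, scores = networkSummary(G,
                                     calc_betweenness = score_idx > 0,
                                     betweenness_sample = 100,
                                     subsample = sample_size,
                                     use_gpu = False)
    latest_score = -scores[score_idx]   # lower is better for the optimiser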
Smaller numbers are faster but less precise [default = 100] + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use cugraph for graph analysis @@ -581,6 +596,7 @@ def newNetwork2D(y_idx, sample_names, distMat, x_range, y_range, score_idx=0, score_idx, y_idx, betweenness_sample, + sample_size = sample_size, use_gpu = use_gpu) return(scores) From 817f7ab01b270e371fc54c25ddb098d4da082623 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 13:59:57 +0000 Subject: [PATCH 10/65] Update triangle counting --- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index d9fdef3c..2ce369d9 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1166,7 +1166,7 @@ def networkSummary(G, subsample = None, calc_betweenness=True, component_nums = component_assignments['labels'].unique().astype(int) components = len(component_nums) density = G.number_of_edges()/(0.5 * G.number_of_vertices() * G.number_of_vertices() - 1) - triangle_count = cugraph.community.triangle_count.triangles(G)/3 + triangle_count = cugraph.triangle_count(G)/3 degree_df = G.in_degree() # consistent with graph-tool triad_count = 0.5 * sum([d * (d - 1) for d in degree_df[degree_df['degree'] > 1]['degree'].to_pandas()]) From 3e583862ee455bee9c360867e8e1bf3f1895ab26 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 14:48:13 +0000 Subject: [PATCH 11/65] Fix argument parsing --- PopPUNK/models.py | 2 +- PopPUNK/network.py | 13 ++++++++----- PopPUNK/refine.py | 5 +++-- 3 files changed, 12 insertions(+), 8 deletions(-) diff --git a/PopPUNK/models.py b/PopPUNK/models.py index 0b851c9b..1309b11c 100644 --- a/PopPUNK/models.py +++ b/PopPUNK/models.py @@ -702,7 +702,7 @@ def __init__(self, outPrefix): def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indiv_refine = False, unconstrained = False, multi_boundary = 0, score_idx = 0, no_local = False, - sample_size = None, betweenness_sample = betweenness_sample_default, use_gpu = False): + betweenness_sample = betweenness_sample_default, sample_size = None, use_gpu = False): '''Extends :func:`~ClusterFit.fit` Fits the distances by optimising network score, by calling diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 2ce369d9..cbc26214 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1133,17 +1133,20 @@ def construct_network_from_assignments(rlist, qlist, assignments, within_label = return G -def networkSummary(G, subsample = None, calc_betweenness=True, - betweenness_sample = betweenness_sample_default, use_gpu = False): +def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sample_default, + subsample = None, use_gpu = False): """Provides summary values about the network Args: G (graph) The network of strains - subsample (int) - Number of vertices to randomly subsample from graph calc_betweenness (bool) Whether to calculate betweenness stats + betweenness_sample (int) + Number of sequences per component used to estimate betweenness using + a GPU. 
Smaller numbers are faster but less precise [default = 100] + subsample (int) + Number of vertices to randomly subsample from graph use_gpu (bool) Whether to use cugraph for graph analysis @@ -1178,7 +1181,7 @@ def networkSummary(G, subsample = None, calc_betweenness=True, if subsample is None: S = G else: - vfilt = g.new_vertex_property('bool', val = False) + vfilt = G.new_vertex_property('bool', val = False) vertex_subsample = np.random.choice(np.arange(0,len(list(G.vertices())) - 1), size = subsample, replace = False) diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index 38c67f7f..350553ef 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -51,7 +51,8 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, max_move, min_move, slope = 2, score_idx = 0, unconstrained = False, no_local = False, num_processes = 1, - betweenness_sample = betweenness_sample_default, use_gpu = False): + betweenness_sample = betweenness_sample_default, sample_size = None, + use_gpu = False): """Try to refine a fit by maximising a network score based on transitivity and density. Iteratively move the decision boundary to do this, using starting point from existing model. @@ -443,8 +444,8 @@ def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx = 0, # Add score into vector for any offsets passed (should usually just be one) G_summary = networkSummary(G, score_idx > 0, - subsample = sample_size, betweenness_sample = betweenness_sample, + subsample = sample_size, use_gpu = use_gpu) latest_score = -G_summary[1][score_idx] for s in range(prev_idx, idx): From 17e3af7f8ee76fbbeffdd7416cd80819e9838b74 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 15:00:06 +0000 Subject: [PATCH 12/65] Use subgraph for statistics --- PopPUNK/network.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index cbc26214..1b4e937b 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1165,12 +1165,12 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa size = subsample, replace = False) S = cugraph.subgraph(G, vertex_subsample) - component_assignments = cugraph.components.connectivity.connected_components(G) + component_assignments = cugraph.components.connectivity.connected_components(S) component_nums = component_assignments['labels'].unique().astype(int) components = len(component_nums) - density = G.number_of_edges()/(0.5 * G.number_of_vertices() * G.number_of_vertices() - 1) - triangle_count = cugraph.triangle_count(G)/3 - degree_df = G.in_degree() + density = S.number_of_edges()/(0.5 * S.number_of_vertices() * S.number_of_vertices() - 1) + triangle_count = cugraph.triangle_count(S)/3 + degree_df = S.in_degree() # consistent with graph-tool triad_count = 0.5 * sum([d * (d - 1) for d in degree_df[degree_df['degree'] > 1]['degree'].to_pandas()]) if triad_count > 0: @@ -1187,10 +1187,10 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa replace = False) vfilt[vertex_subsample] = True S = gt.GraphView(G, vfilt=vfilt) - component_assignments, component_frequencies = gt.label_components(G) + component_assignments, component_frequencies = gt.label_components(S) components = len(component_frequencies) - density = len(list(G.edges()))/(0.5 * len(list(G.vertices())) * (len(list(G.vertices())) - 1)) - transitivity = gt.global_clustering(G)[0] + density = len(list(S.edges()))/(0.5 * len(list(S.vertices())) * (len(list(S.vertices())) - 1)) + transitivity = 
gt.global_clustering(S)[0] mean_bt = 0 weighted_mean_bt = 0 @@ -1204,7 +1204,7 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa size = component_frequencies[component_frequencies.index == component].iloc[0].astype(int) if size > 3: component_vertices = component_assignments['vertex'][component_assignments['labels']==component] - subgraph = cugraph.subgraph(G, component_vertices) + subgraph = cugraph.subgraph(S, component_vertices) if len(component_vertices) >= betweenness_sample: component_betweenness = cugraph.betweenness_centrality(subgraph, k = betweenness_sample, @@ -1218,7 +1218,7 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa for component, size in enumerate(component_frequencies): if size > 3: vfilt = component_assignments.a == component - subgraph = gt.GraphView(G, vfilt=vfilt) + subgraph = gt.GraphView(S, vfilt=vfilt) betweenness.append(max(vertex_betweenness(subgraph, norm=True))) sizes.append(size) From ae4958305cf97de0b8a8de1d64fce8041846318a Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 15:04:23 +0000 Subject: [PATCH 13/65] Update network count --- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 1b4e937b..6ee8d5ad 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1169,7 +1169,7 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa component_nums = component_assignments['labels'].unique().astype(int) components = len(component_nums) density = S.number_of_edges()/(0.5 * S.number_of_vertices() * S.number_of_vertices() - 1) - triangle_count = cugraph.triangle_count(S)/3 + triangle_count = cugraph.triangle_count(S)['counts'].sum()/3 degree_df = S.in_degree() # consistent with graph-tool triad_count = 0.5 * sum([d * (d - 1) for d in degree_df[degree_df['degree'] > 1]['degree'].to_pandas()]) From 934f8f5eb0369b84490bd997bb248ce69281c888 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 15:06:10 +0000 Subject: [PATCH 14/65] Remove correction factor --- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 6ee8d5ad..ee679215 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1169,7 +1169,7 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa component_nums = component_assignments['labels'].unique().astype(int) components = len(component_nums) density = S.number_of_edges()/(0.5 * S.number_of_vertices() * S.number_of_vertices() - 1) - triangle_count = cugraph.triangle_count(S)['counts'].sum()/3 + triangle_count = cugraph.triangle_count(S)['counts'].sum() degree_df = S.in_degree() # consistent with graph-tool triad_count = 0.5 * sum([d * (d - 1) for d in degree_df[degree_df['degree'] > 1]['degree'].to_pandas()]) From 5c29da10cf12d84097c75d40d96e71cf1976a6c1 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 15:26:35 +0000 Subject: [PATCH 15/65] Fix subsampling for graph-tool --- PopPUNK/network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index ee679215..d01494ec 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1181,11 +1181,12 @@ def networkSummary(G, calc_betweenness=True, betweenness_sample = betweenness_sa if subsample is None: S = G else: - vfilt = G.new_vertex_property('bool', val = False) vertex_subsample = 
np.random.choice(np.arange(0,len(list(G.vertices())) - 1), size = subsample, replace = False) - vfilt[vertex_subsample] = True + vfilt_bool = np.full(len(list(G.vertices())) - 1, False) + vfilt_bool[vertex_subsample] = True + vfilt = G.new_vertex_property('bool', vals = vfilt_bool) S = gt.GraphView(G, vfilt=vfilt) component_assignments, component_frequencies = gt.label_components(S) components = len(component_frequencies) From 16c274a14eded49bec76f4ac6af216020f9b8570 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 15:29:50 +0000 Subject: [PATCH 16/65] Add tests for subsampling --- test/run_test.py | 1 + test/test-gpu.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/run_test.py b/test/run_test.py index 5f26a2f8..68f3d01f 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -43,6 +43,7 @@ subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.15 --overwrite --score-idx 1", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.15 --overwrite --score-idx 2", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model threshold --threshold 0.003 --ref-db example_db --output example_threshold", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.15 --summary-sample 15 --overwrite", shell=True, check=True) sys.stderr.write("Running multi boundary refinement (--multi-boundary and poppunk_iterate.py)\n") subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_iterate --neg-shift -0.2 --overwrite --multi-boundary 10", shell=True, check=True) diff --git a/test/test-gpu.py b/test/test-gpu.py index 58af75c2..32fa4d6e 100755 --- a/test/test-gpu.py +++ b/test/test-gpu.py @@ -43,6 +43,7 @@ subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.2 --overwrite --score-idx 1 --gpu-graph", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.2 --overwrite --score-idx 2 --gpu-graph", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model threshold --threshold 0.003 --ref-db example_db --output example_threshold --gpu-graph", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --fit-model refine --ref-db example_db --output example_refine --neg-shift 0.15 --summary-sample 15 --overwrite --gpu-graph", shell=True, check=True) # lineage clustering sys.stderr.write("Running lineage clustering test (--fit-model lineage)\n") From 75ef0d52f3ba456f55ab1e120dc17de9707a606f Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 15:41:56 +0000 Subject: [PATCH 17/65] Enable sampling for multirefine --- PopPUNK/models.py | 2 ++ PopPUNK/refine.py | 11 +++++++++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/PopPUNK/models.py b/PopPUNK/models.py index 1309b11c..62c9620b 100644 --- a/PopPUNK/models.py +++ b/PopPUNK/models.py @@ -809,6 +809,8 @@ def fit(self, X, sample_names, model, max_move, min_move, startFile = None, indi multi_boundary, self.outPrefix, num_processes = self.threads, + betweenness_sample = betweenness_sample, + sample_size = sample_size, use_gpu = use_gpu) # Try and do a 1D 
refinement for both core and accessory diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index 350553ef..aa0f9a00 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -247,8 +247,9 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, return optimal_x, optimal_y, optimised_s def multi_refine(distMat, sample_names, mean0, mean1, scale, s_max, - n_boundary_points, output_prefix, - num_processes = 1, use_gpu = False): + n_boundary_points, output_prefix, num_processes = 1, + betweenness_sample = betweenness_sample_default, sample_size = None, + use_gpu = False): """Move the refinement boundary between the optimum and where it meets an axis. Discrete steps, output the clusers at each step @@ -270,6 +271,11 @@ def multi_refine(distMat, sample_names, mean0, mean1, scale, s_max, num_processes (int) Number of threads to use in the global optimisation step. (default = 1) + betweenness_sample (int) + Number of sequences per component used to estimate betweenness using + a GPU. Smaller numbers are faster but less precise [default = 100] + sample_size (int) + Number of nodes to subsample for graph statistic calculation use_gpu (bool) Whether to use cugraph for graph analyses """ @@ -301,6 +307,7 @@ def multi_refine(distMat, sample_names, mean0, mean1, scale, s_max, s_range, 0, write_clusters = output_prefix, + betweenness_sample = betweenness_sample, sample_size = sample_size, use_gpu = use_gpu) From bde0a3ec8981ccaf5e2097fcf35053978313ff2f Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 1 Feb 2023 16:17:29 +0000 Subject: [PATCH 18/65] Enable extraction of full graph statistics --- PopPUNK/info.py | 55 +++++++++++++++++++++++++++---------------------- 1 file changed, 30 insertions(+), 25 deletions(-) diff --git a/PopPUNK/info.py b/PopPUNK/info.py index 3e0607c7..8baf9599 100644 --- a/PopPUNK/info.py +++ b/PopPUNK/info.py @@ -52,9 +52,13 @@ def get_options(): # main code def main(): + # Import value + from .__main__ import betweenness_sample_default + # Import functions from .network import load_network_file from .network import sparse_mat_to_network + from .network import print_network_summary from .utils import check_and_set_gpu from .utils import setGtThreads @@ -103,6 +107,32 @@ def main(): use_rc = ref_db['sketches'].attrs['use_rc'] == 1 print("Uses canonical k-mers:\t" + str(use_rc)) + # Select network file name + if args.network_file is None: + if use_gpu: + network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz') + else: + network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt') + else: + network_file = args.network_file + + # Open network file + if network_file.endswith('.gt'): + G = load_network_file(network_file, use_gpu = False) + elif network_file.endswith('.csv.gz'): + if use_gpu: + G = load_network_file(network_file, use_gpu = True) + else: + sys.stderr.write('Unable to load necessary GPU libraries\n') + sys.exit(1) + elif network_file.endswith('.npz'): + sparse_mat = sparse.load_npz(network_file) + G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu) + else: + sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n') + sys.exit(1) + print_network_summary(G, betweenness_sample = betweenness_sample_default, use_gpu = args.use_gpu) + # Print sample information if not args.simple: sample_names = list(ref_db['sketches'].keys()) @@ -115,31 +145,6 @@ def main(): sample_sequence_length[sample_name] = ref_db['sketches/' + sample_name].attrs['length'] sample_missing_bases[sample_name] = 
ref_db['sketches/' + sample_name].attrs['missing_bases'] - # Select network file name - if args.network_file is None: - if use_gpu: - network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.csv.gz') - else: - network_file = os.path.join(args.db, os.path.basename(args.db) + '_graph.gt') - else: - network_file = args.network_file - - # Open network file - if network_file.endswith('.gt'): - G = load_network_file(network_file, use_gpu = False) - elif network_file.endswith('.csv.gz'): - if use_gpu: - G = load_network_file(network_file, use_gpu = True) - else: - sys.stderr.write('Unable to load necessary GPU libraries\n') - sys.exit(1) - elif network_file.endswith('.npz'): - sparse_mat = sparse.load_npz(network_file) - G = sparse_mat_to_network(sparse_mat, sample_names, use_gpu = use_gpu) - else: - sys.stderr.write('Unrecognised suffix: expected ".gt", ".csv.gz" or ".npz"\n') - sys.exit(1) - # Analyse network if use_gpu: component_assignments_df = cugraph.components.connectivity.connected_components(G) From 757eeffeecfd51f8f792989e51fd024ea080852b Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 2 Feb 2023 06:48:02 +0000 Subject: [PATCH 19/65] Update function docstring --- PopPUNK/network.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index d01494ec..23b8ab1f 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -689,8 +689,8 @@ def construct_network_from_edge_list(rlist, List of reference sequence labels qlist (list) List of query sequence labels - G_df (cudf or pandas data frame) - Data frame in which the first two columns are the nodes linked by edges + edge_list (list of tuples) + List of connections in the network weights (list) List of edge weights distMat (2 column ndarray) From 7311f71611493ed1fd6fda5544dbd0347eddcbc1 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 3 Feb 2023 14:49:15 +0000 Subject: [PATCH 20/65] Allow for tmp file space --- PopPUNK/trees.py | 9 +++++++-- PopPUNK/visualise.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/PopPUNK/trees.py b/PopPUNK/trees.py index 1891c0d0..a2e3bf72 100644 --- a/PopPUNK/trees.py +++ b/PopPUNK/trees.py @@ -147,7 +147,7 @@ def load_tree(prefix, type, distances = 'core'): return tree_string -def generate_nj_tree(coreMat, seqLabels, outPrefix, rapidnj, threads): +def generate_nj_tree(coreMat, seqLabels, outPrefix, tmp = None, rapidnj = None, threads = 1): """Generate phylogeny using dendropy or RapidNJ Writes a neighbour joining tree (.nwk) from core distances. @@ -159,6 +159,8 @@ def generate_nj_tree(coreMat, seqLabels, outPrefix, rapidnj, threads): Processed names of sequences being analysed. outPrefix (str) Output prefix for core distances file + tmp (str) + Directory in which to create large temporary pairwise distance file rapidnj (str) A string with the location of the rapidnj executable for tree-building. 
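The temporary-file handling introduced in this commit keeps the large pairwise distance matrix off the (possibly small or slow) output filesystem when --tmp is given. A minimal sketch of the path choice, assuming the same naming convention as the diff (the helper name is hypothetical):

    import os

    def core_dist_path(out_prefix, tmp=None):
        # Large temporary distance files go to --tmp when provided,
        # otherwise alongside the other outputs.
        base = os.path.basename(out_prefix) + "_core_dists.csv"
        return os.path.join(tmp if tmp is not None else out_prefix, base)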
If None, will use dendropy by default @@ -174,7 +176,10 @@ def generate_nj_tree(coreMat, seqLabels, outPrefix, rapidnj, threads): # calculate phylogeny sys.stderr.write("Building phylogeny\n") if rapidnj is not None: - core_dist_file = outPrefix + "/" + os.path.basename(outPrefix) + "_core_dists.csv" + if tmp is None: + core_dist_file = outPrefix + "/" + os.path.basename(outPrefix) + "_core_dists.csv" + else: + core_dist_file = tmp + "/" + os.path.basename(outPrefix) + "_core_dists.csv" np.savetxt(core_dist_file, coreMat, delimiter=",", header = ",".join(seqLabels), comments="") tree = buildRapidNJ(rapidnj, seqLabels, coreMat, outPrefix, threads = threads) os.remove(core_dist_file) diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index 4669da1b..069f55c5 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -134,6 +134,7 @@ def get_options(): other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]') other.add_argument('--gpu-graph', default=False, action='store_true', help='Use a GPU when calculating graphs [default = False]') other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]') + other.add_argument('--tmp', default=None, type=str, help='Directory for large temporary files') other.add_argument('--strand-preserved', default=False, action='store_true', help='If distances being calculated, treat strand as known when calculating random ' 'match chances [default = False]') @@ -184,7 +185,8 @@ def generate_visualisations(query_db, tree, mst_distances, overwrite, - display_cluster): + display_cluster, + tmp): from .models import loadClusterFit @@ -529,6 +531,7 @@ def generate_visualisations(query_db, nj_tree = generate_nj_tree(core_distMat, combined_seq, output, + tmp, rapidnj, threads = threads) else: @@ -650,7 +653,8 @@ def main(): args.tree, args.mst_distances, args.overwrite, - args.display_cluster) + args.display_cluster, + args.tmp) if __name__ == '__main__': main() From 9e010692fc6586d59c10f111d694b0a08186f1bb Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 3 Feb 2023 16:33:44 +0000 Subject: [PATCH 21/65] Reduce precision and size of distance matrix --- PopPUNK/trees.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/PopPUNK/trees.py b/PopPUNK/trees.py index a2e3bf72..66b3374e 100644 --- a/PopPUNK/trees.py +++ b/PopPUNK/trees.py @@ -180,7 +180,13 @@ def generate_nj_tree(coreMat, seqLabels, outPrefix, tmp = None, rapidnj = None, core_dist_file = outPrefix + "/" + os.path.basename(outPrefix) + "_core_dists.csv" else: core_dist_file = tmp + "/" + os.path.basename(outPrefix) + "_core_dists.csv" - np.savetxt(core_dist_file, coreMat, delimiter=",", header = ",".join(seqLabels), comments="") + np.savetxt(core_dist_file, + coreMat, + fmt='%.4e', + delimiter=",", + header = ",".join(seqLabels), + comments="" + ) tree = buildRapidNJ(rapidnj, seqLabels, coreMat, outPrefix, threads = threads) os.remove(core_dist_file) else: From c36489ab3c73cd8b221691a8955a6ae6b6e4fbd6 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 3 Feb 2023 16:42:35 +0000 Subject: [PATCH 22/65] Make indentation consistent --- PopPUNK/lineages.py | 172 ++++++++++++++++++++++---------------------- 1 file changed, 86 insertions(+), 86 deletions(-) diff --git a/PopPUNK/lineages.py b/PopPUNK/lineages.py index 38f0300e..298ab105 100755 --- a/PopPUNK/lineages.py +++ b/PopPUNK/lineages.py @@ -189,92 +189,92 @@ def create_db(args): lineage_dbs = {} 
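One step in the re-indented block below is easy to miss: each per-strain lineage database is a directory holding a relative symlink back to the parent sketch database, so the sketches themselves are never duplicated. A hedged sketch of that linkage with hypothetical paths:

    import os

    src_db = os.path.join("parent_db", "parent_db.h5")
    dest_db = os.path.join("parent_db_1_lineage_db",
                           "parent_db_1_lineage_db.h5")
    # Relative link so the whole directory tree can be moved together
    rel_path = os.path.relpath(src_db, os.path.dirname(dest_db))
    if not os.path.exists(dest_db):
        os.symlink(rel_path, dest_db)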
overall_lineage = {} for strain,isolates in strains: - # Make new database directory - sys.stderr.write("Making database for strain " + str(strain) + "\n") - strain_db_name = args.lineage_db_prefix + '_' + str(strain) + '_lineage_db' - isolate_list = isolates[isolates.columns.values[0]].to_list() - num_isolates = len(isolate_list) - if num_isolates >= args.min_count: - lineage_dbs[strain] = strain_db_name - if os.path.isdir(strain_db_name) and args.overwrite: - os.rmdir(strain_db_name) - if not os.path.isdir(strain_db_name): - try: - os.makedirs(strain_db_name) - except OSError: - sys.stderr.write("Cannot create output directory " + strain_db_name + "\n") - sys.exit(1) - # Make link to main database - src_db = os.path.join(args.create_db,os.path.basename(args.create_db) + '.h5') - dest_db = os.path.join(strain_db_name,os.path.basename(strain_db_name) + '.h5') - rel_path = os.path.relpath(src_db, os.path.dirname(dest_db)) - if os.path.exists(dest_db) and args.overwrite: - os.remove(dest_db) - elif not os.path.exists(dest_db): - os.symlink(rel_path,dest_db) - # Extract sparse distances - prune_distance_matrix(rlist, - list(set(rlist) - set(isolate_list)), - X, - os.path.join(strain_db_name,strain_db_name + '.dists')) - # Initialise model - model = LineageFit(strain_db_name, - rank_list, - max_search_depth, - args.reciprocal_only, - args.count_unique_distances, - use_gpu = args.gpu_graph) - model.set_threads(args.threads) - # Load pruned distance matrix - strain_rlist, strain_qlist, strain_self, strain_X = \ - readPickle(os.path.join(strain_db_name,strain_db_name + '.dists'), - enforce_self=False, - distances=True) - # Fit model - model.fit(strain_X, - args.use_accessory) - # Lineage fit requires some iteration - indivNetworks = {} - lineage_clusters = defaultdict(dict) - # Iterate over ranks - for rank in rank_list: - if rank <= num_isolates: - assignments = model.assign(rank) - # Generate networks - indivNetworks[rank] = construct_network_from_edge_list(strain_rlist, - strain_rlist, - assignments, - weights = None, - betweenness_sample = None, - use_gpu = args.gpu_graph, - summarise = False - ) - # Write networks - save_network(indivNetworks[rank], - prefix = strain_db_name, - suffix = '_rank_' + str(rank) + '_graph', - use_gpu = args.gpu_graph) - # Identify clusters from output - lineage_clusters[rank] = \ - printClusters(indivNetworks[rank], - strain_rlist, - printCSV = False, - use_gpu = args.gpu_graph) - n_clusters = max(lineage_clusters[rank].values()) - sys.stderr.write("Network for rank " + str(rank) + " has " + - str(n_clusters) + " lineages\n") - # For each strain, print output of each rank as CSV - overall_lineage[strain] = createOverallLineage(rank_list, lineage_clusters) - writeClusterCsv(os.path.join(strain_db_name,os.path.basename(strain_db_name) + '_lineages.csv'), - strain_rlist, - strain_rlist, - overall_lineage[strain], - output_format = 'phandango', - epiCsv = None, - suffix = '_Lineage') - genomeNetwork = indivNetworks[min(rank_list)] - # Save model - model.save() + # Make new database directory + sys.stderr.write("Making database for strain " + str(strain) + "\n") + strain_db_name = args.lineage_db_prefix + '_' + str(strain) + '_lineage_db' + isolate_list = isolates[isolates.columns.values[0]].to_list() + num_isolates = len(isolate_list) + if num_isolates >= args.min_count: + lineage_dbs[strain] = strain_db_name + if os.path.isdir(strain_db_name) and args.overwrite: + os.rmdir(strain_db_name) + if not os.path.isdir(strain_db_name): + try: + os.makedirs(strain_db_name) + 
except OSError: + sys.stderr.write("Cannot create output directory " + strain_db_name + "\n") + sys.exit(1) + # Make link to main database + src_db = os.path.join(args.create_db,os.path.basename(args.create_db) + '.h5') + dest_db = os.path.join(strain_db_name,os.path.basename(strain_db_name) + '.h5') + rel_path = os.path.relpath(src_db, os.path.dirname(dest_db)) + if os.path.exists(dest_db) and args.overwrite: + os.remove(dest_db) + elif not os.path.exists(dest_db): + os.symlink(rel_path,dest_db) + # Extract sparse distances + prune_distance_matrix(rlist, + list(set(rlist) - set(isolate_list)), + X, + os.path.join(strain_db_name,strain_db_name + '.dists')) + # Initialise model + model = LineageFit(strain_db_name, + rank_list, + max_search_depth, + args.reciprocal_only, + args.count_unique_distances, + use_gpu = args.gpu_graph) + model.set_threads(args.threads) + # Load pruned distance matrix + strain_rlist, strain_qlist, strain_self, strain_X = \ + readPickle(os.path.join(strain_db_name,strain_db_name + '.dists'), + enforce_self=False, + distances=True) + # Fit model + model.fit(strain_X, + args.use_accessory) + # Lineage fit requires some iteration + indivNetworks = {} + lineage_clusters = defaultdict(dict) + # Iterate over ranks + for rank in rank_list: + if rank <= num_isolates: + assignments = model.assign(rank) + # Generate networks + indivNetworks[rank] = construct_network_from_edge_list(strain_rlist, + strain_rlist, + assignments, + weights = None, + betweenness_sample = None, + use_gpu = args.gpu_graph, + summarise = False + ) + # Write networks + save_network(indivNetworks[rank], + prefix = strain_db_name, + suffix = '_rank_' + str(rank) + '_graph', + use_gpu = args.gpu_graph) + # Identify clusters from output + lineage_clusters[rank] = \ + printClusters(indivNetworks[rank], + strain_rlist, + printCSV = False, + use_gpu = args.gpu_graph) + n_clusters = max(lineage_clusters[rank].values()) + sys.stderr.write("Network for rank " + str(rank) + " has " + + str(n_clusters) + " lineages\n") + # For each strain, print output of each rank as CSV + overall_lineage[strain] = createOverallLineage(rank_list, lineage_clusters) + writeClusterCsv(os.path.join(strain_db_name,os.path.basename(strain_db_name) + '_lineages.csv'), + strain_rlist, + strain_rlist, + overall_lineage[strain], + output_format = 'phandango', + epiCsv = None, + suffix = '_Lineage') + genomeNetwork = indivNetworks[min(rank_list)] + # Save model + model.save() # Print combined strain and lineage clustering print_overall_clustering(overall_lineage,args.output + '.csv',rlist) From 628844699fa02a02f46edc41484b9eb47f6a8005 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 3 Feb 2023 17:03:05 +0000 Subject: [PATCH 23/65] Prune rank list for rare strains --- PopPUNK/lineages.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/PopPUNK/lineages.py b/PopPUNK/lineages.py index 298ab105..d3ed5b97 100755 --- a/PopPUNK/lineages.py +++ b/PopPUNK/lineages.py @@ -217,6 +217,8 @@ def create_db(args): list(set(rlist) - set(isolate_list)), X, os.path.join(strain_db_name,strain_db_name + '.dists')) + # Prune rank list + pruned_rank_list = [r for r in rank_list if r <= num_isolates] # Initialise model model = LineageFit(strain_db_name, rank_list, From 98e36a8c3f7d38f6abea2e72d2f47482a785b657 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 3 Feb 2023 17:14:40 +0000 Subject: [PATCH 24/65] Remove unnecessary file generation --- PopPUNK/trees.py | 25 +++++++++---------------- 1 file changed, 9 insertions(+), 16 deletions(-) diff 
--git a/PopPUNK/trees.py b/PopPUNK/trees.py index 66b3374e..5d501f4f 100644 --- a/PopPUNK/trees.py +++ b/PopPUNK/trees.py @@ -28,7 +28,7 @@ except ImportError: pass -def buildRapidNJ(rapidnj, refList, coreMat, outPrefix, threads = 1): +def buildRapidNJ(rapidnj, refList, coreMat, outPrefix, tmp = None, threads = 1): """Use rapidNJ for more rapid tree building Creates a phylip of core distances, system call to rapidnj executable, loads tree as @@ -43,6 +43,8 @@ def buildRapidNJ(rapidnj, refList, coreMat, outPrefix, threads = 1): NxN core distance matrix produced in :func:`~outputsForMicroreact` outPrefix (str) Prefix for all generated output files, which will be placed in `outPrefix` subdirectory + tmp (str) + Directory in which to create large temporary pairwise distance file threads (int) Number of threads to use @@ -51,12 +53,15 @@ def buildRapidNJ(rapidnj, refList, coreMat, outPrefix, threads = 1): Newick-formatted NJ tree from core distances """ # generate phylip matrix - phylip_name = outPrefix + "/" + os.path.basename(outPrefix) + "_core_distances.phylip" + if tmp is not None: + phylip_name = tmp + "/" + os.path.basename(outPrefix) + "_core_distances.phylip" + else: + phylip_name = outPrefix + "/" + os.path.basename(outPrefix) + "_core_distances.phylip" with open(phylip_name, 'w') as pFile: pFile.write(str(len(refList))+"\n") for coreDist, ref in zip(coreMat, refList): pFile.write(ref) - pFile.write(' '+' '.join(map(str, coreDist))) + pFile.write(' '+' '.join(map('{:.4f}'.format, coreDist))) pFile.write("\n") # construct tree @@ -176,19 +181,7 @@ def generate_nj_tree(coreMat, seqLabels, outPrefix, tmp = None, rapidnj = None, # calculate phylogeny sys.stderr.write("Building phylogeny\n") if rapidnj is not None: - if tmp is None: - core_dist_file = outPrefix + "/" + os.path.basename(outPrefix) + "_core_dists.csv" - else: - core_dist_file = tmp + "/" + os.path.basename(outPrefix) + "_core_dists.csv" - np.savetxt(core_dist_file, - coreMat, - fmt='%.4e', - delimiter=",", - header = ",".join(seqLabels), - comments="" - ) - tree = buildRapidNJ(rapidnj, seqLabels, coreMat, outPrefix, threads = threads) - os.remove(core_dist_file) + tree = buildRapidNJ(rapidnj, seqLabels, coreMat, outPrefix, tmp = tmp, threads = threads) else: matrix = [] for row, idx in enumerate(coreMat): From 9a19cdc570485664841def74fdc62dccf3876427 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Sat, 4 Feb 2023 06:58:58 +0000 Subject: [PATCH 25/65] Fix pandas error handling --- PopPUNK/plot.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 0b34243e..a0196abe 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -15,6 +15,7 @@ import itertools # for other outputs import pandas as pd +from pandas.errors import DataError import h5py from collections import defaultdict from sklearn import utils @@ -728,17 +729,19 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, sys.stderr.write("Parsed data, now writing to CSV\n") try: pd.DataFrame(data=d).to_csv(outfile, columns = colnames, index = False) - except subprocess.CalledProcessError as e: - sys.stderr.write("Problem with epidemiological data CSV; returned code: " + str(e.returncode) + "\n") + except (ValueError,DataError) as e: + sys.stderr.write("Problem with epidemiological data CSV; returned code: " + str(e) + "\n") # check CSV prev_col_items = -1 prev_col_name = "unknown" for col in d: this_col_items = len(d[col]) + sys.stderr.write(col + ' has length ' + str(this_col_items) + '\n') if 
prev_col_items > -1 and prev_col_items != this_col_items: sys.stderr.write("Discrepant length between " + prev_col_name + \ - " (length of " + prev_col_items + ") and " + \ - col + "(length of " + this_col_items + ")\n") + " (length of " + str(prev_col_items) + ") and " + \ + col + "(length of " + str(this_col_items) + ")\n") + prev_col_items = this_col_items sys.exit(1) def outputsForMicroreact(combined_list, clustering, nj_tree, mst_tree, accMat, perplexity, From df9117f16030480a0f4aecf2e808218ded0a1da3 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Sun, 5 Feb 2023 07:12:43 +0000 Subject: [PATCH 26/65] Filter duplicate rows in epi csv --- PopPUNK/plot.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index a0196abe..d9dad098 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -713,13 +713,15 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, d['Status'].append("Reference") if epiCsv is not None: if label in epiData.index: - for col, value in zip(epiData.columns.values, epiData.loc[label].values): - if col not in columns_to_be_omitted: - d[col].append(str(value)) - else: - for col in epiData.columns.values: - if col not in columns_to_be_omitted: - d[col].append('nan') + info_added.append(label) + if label in epiData.index: + for col, value in zip(epiData.columns.values, epiData.loc[[label]].iloc[0].values): + if col not in columns_to_be_omitted: + d[col].append(str(value)) + else: + for col in epiData.columns.values: + if col not in columns_to_be_omitted: + d[col].append('nan') else: sys.stderr.write("Cannot find " + name + " in clustering\n") From a4ab20183544e55623a5512e139a4a0dc36a547f Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Sun, 5 Feb 2023 12:11:17 +0000 Subject: [PATCH 27/65] Remove obsolete variable --- PopPUNK/plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index d9dad098..239b8768 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -713,7 +713,6 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, d['Status'].append("Reference") if epiCsv is not None: if label in epiData.index: - info_added.append(label) if label in epiData.index: for col, value in zip(epiData.columns.values, epiData.loc[[label]].iloc[0].values): if col not in columns_to_be_omitted: From aa637d7f6d94dc95774b085484475373241822d2 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 7 Feb 2023 06:11:44 +0000 Subject: [PATCH 28/65] Harmonise behaviour of BGMM and DBSCAN --- PopPUNK/bgmm.py | 29 +++++++++++++++++++++++++++-- PopPUNK/models.py | 3 ++- PopPUNK/plot.py | 3 ++- 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/PopPUNK/bgmm.py b/PopPUNK/bgmm.py index 92a001c6..b8db346d 100644 --- a/PopPUNK/bgmm.py +++ b/PopPUNK/bgmm.py @@ -45,10 +45,36 @@ def fit2dMultiGaussian(X, dpgmm_max_K = 2): return dpgmm +def findBetweenLabel_bgmm(means, assignments, rank = 0): + """Identify between-strain links + + Finds the component with the largest number of points + assigned to it + + Args: + means (numpy.array) + K x 2 array of mixture component means + assignments (numpy.array) + Sample cluster assignments + rank (int) + Which label to find, ordered by distance from origin. 0-indexed. 
+ (default = 0) + + Returns: + between_label (int) + The cluster label with the most points assigned to it + """ + most_dists = {} + for mixture_component, distance in enumerate(np.apply_along_axis(np.linalg.norm, 1, means)): + most_dists[mixture_component] = np.count_nonzero(assignments == mixture_component) + + sorted_dists = sorted(most_dists.items(), key=operator.itemgetter(1), reverse=True) + return(sorted_dists[rank][0]) + def findWithinLabel(means, assignments, rank = 0): """Identify within-strain links - Finds the component with mean closest to the origin and also akes sure + Finds the component with mean closest to the origin and also makes sure some samples are assigned to it (in the case of small weighted components with a Dirichlet prior some components are unused) @@ -59,7 +85,6 @@ def findWithinLabel(means, assignments, rank = 0): Sample cluster assignments rank (int) Which label to find, ordered by distance from origin. 0-indexed. - (default = 0) Returns: diff --git a/PopPUNK/models.py b/PopPUNK/models.py index 62c9620b..934598e2 100644 --- a/PopPUNK/models.py +++ b/PopPUNK/models.py @@ -53,6 +53,7 @@ # BGMM from .bgmm import fit2dMultiGaussian from .bgmm import findWithinLabel +from .bgmm import findBetweenLabel_bgmm from .bgmm import log_likelihood from .plot import plot_results from .plot import plot_contours @@ -328,7 +329,7 @@ def fit(self, X, max_components): y = self.assign(X) self.within_label = findWithinLabel(self.means, y) - self.between_label = findWithinLabel(self.means, y, 1) + self.between_label = findBetweenLabel_bgmm(self.means, y) return y diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 239b8768..8479f4e2 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -388,12 +388,13 @@ def plot_contours(model, assignments, title, out_prefix): # avoid recursive import from .bgmm import log_likelihood from .bgmm import findWithinLabel + from .bgmm import findBetweenLabel_bgmm xx, yy, xy = get_grid(0, 1, 100) # for likelihood boundary z = model.assign(xy, values=True, progress=False) - z_diff = z[:,findWithinLabel(model.means, assignments, 0)] - z[:,findWithinLabel(model.means, assignments, 1)] + z_diff = z[:,findWithinLabel(model.means, assignments, 0)] - z[:,findBetweenLabel_bgmm(model.means, assignments)] z = z_diff.reshape(xx.shape).T # For full likelihood surface From 3306f01fb7c1ca31b4f9b7b8b62b6b4303203434 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 9 Feb 2023 14:29:28 +0000 Subject: [PATCH 29/65] Enable alteration of unconstrained boundary search --- PopPUNK/refine.py | 4 ++-- PopPUNK/utils.py | 9 ++++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index aa0f9a00..5508f6cc 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -114,8 +114,8 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, raise RuntimeError("Unconstrained optimization and indiv-refine incompatible") global_grid_resolution = 20 - x_max_start, y_max_start = decisionBoundary(mean0, gradient) - x_max_end, y_max_end = decisionBoundary(mean1, gradient) + x_max_start, y_max_start = decisionBoundary(mean0, gradient, adj = -1*min_move) + x_max_end, y_max_end = decisionBoundary(mean1, gradient, adj = -max_move) if x_max_start < 0 or y_max_start < 0: raise RuntimeError("Boundary range below zero") diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py index cae085cc..497d190d 100644 --- a/PopPUNK/utils.py +++ b/PopPUNK/utils.py @@ -519,7 +519,7 @@ def transformLine(s, mean0, mean1): return np.array([x, y]) -def 
decisionBoundary(intercept, gradient): +def decisionBoundary(intercept, gradient, adj = 0.0): """Returns the co-ordinates where the triangle the decision boundary forms meets the x- and y-axes. @@ -529,12 +529,19 @@ def decisionBoundary(intercept, gradient): which intercepts the boundary gradient (float) Gradient of the line + adj (float) + Distance by which to shift the interception point Returns: x (float) The x-axis intercept y (float) The y-axis intercept """ + if adj != 0.0: + original_hypotenuse = (intercept[0]**2 + intercept[1]**2)**0.5 + length_ratio = (original_hypotenuse + adj)/original_hypotenuse + intercept[0] = intercept[0] * length_ratio + intercept[1] = intercept[1] * length_ratio x = intercept[0] + intercept[1] * gradient y = intercept[1] + intercept[0] / gradient return(x, y) From bf4fcecbc4ade8e22be282e4ab20172d80f3f61d Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 9 Feb 2023 16:32:46 +0000 Subject: [PATCH 30/65] Correct sign --- PopPUNK/refine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index 5508f6cc..5bf2c402 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -115,7 +115,7 @@ def refineFit(distMat, sample_names, mean0, mean1, scale, global_grid_resolution = 20 x_max_start, y_max_start = decisionBoundary(mean0, gradient, adj = -1*min_move) - x_max_end, y_max_end = decisionBoundary(mean1, gradient, adj = -max_move) + x_max_end, y_max_end = decisionBoundary(mean1, gradient, adj = max_move) if x_max_start < 0 or y_max_start < 0: raise RuntimeError("Boundary range below zero") From f5ec728c1a8718ee766fd3bdad69a7e03df2f198 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 29 Feb 2024 17:39:38 +0000 Subject: [PATCH 31/65] Update function call --- PopPUNK/refine.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/refine.py b/PopPUNK/refine.py index 45963b20..f256e0a1 100644 --- a/PopPUNK/refine.py +++ b/PopPUNK/refine.py @@ -369,7 +369,7 @@ def expand_cugraph_network(G, G_extra_df): if 'src' in G_original_df.columns: G_original_df.columns = ['source','destination'] G_df = cudf.concat([G_original_df,G_extra_df]) - G = add_self_loop(G_df, G_vertex_count, weights = False, renumber = False) + G = generate_cugraph(G_df, G_vertex_count, weights = False, renumber = False) return G def growNetwork(sample_names, i_vec, j_vec, idx_vec, s_range, score_idx = 0, From e8cb660307febe6fd1c40a37372580b4618ca318 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Mon, 11 Mar 2024 14:17:05 +0000 Subject: [PATCH 32/65] Change network parsing --- PopPUNK/network.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 44519a58..64706e7c 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -327,7 +327,10 @@ def extractReferences(G, dbOrder, outPrefix, outSuffix = '', type_isolate = None # Make a graph of the component from the overall graph vertices_in_component = component_assignments[component_assignments['labels']==component]['vertex'] references_in_component = vertices_in_component[vertices_in_component.isin(reference_indices)].values - G_component_df = G_df[G_df['source'].isin(vertices_in_component) & G_df['destination'].isin(vertices_in_component)] + if 'src' in G_df.columns: + G_component_df = G_df[G_df['src'].isin(vertices_in_component) & G_df['dst'].isin(vertices_in_component)] + else: + G_component_df = G_df[G_df['source'].isin(vertices_in_component) & G_df['destination'].isin(vertices_in_component)] G_component 
= cugraph.Graph() G_component.from_cudf_edgelist(G_component_df) # Find single shortest path from a reference to all other nodes in the component From c4ba9d5c43fcfc4506f0234ec501aa500909fd96 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 12 Mar 2024 11:12:50 +0000 Subject: [PATCH 33/65] Update column names --- PopPUNK/network.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 64706e7c..e14ce1b1 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -327,10 +327,7 @@ def extractReferences(G, dbOrder, outPrefix, outSuffix = '', type_isolate = None # Make a graph of the component from the overall graph vertices_in_component = component_assignments[component_assignments['labels']==component]['vertex'] references_in_component = vertices_in_component[vertices_in_component.isin(reference_indices)].values - if 'src' in G_df.columns: - G_component_df = G_df[G_df['src'].isin(vertices_in_component) & G_df['dst'].isin(vertices_in_component)] - else: - G_component_df = G_df[G_df['source'].isin(vertices_in_component) & G_df['destination'].isin(vertices_in_component)] + G_component_df = G_df[G_df['old_source'].isin(vertices_in_component) & G_df['old_destination'].isin(vertices_in_component)] G_component = cugraph.Graph() G_component.from_cudf_edgelist(G_component_df) # Find single shortest path from a reference to all other nodes in the component From 3c4a9cb121462208b9d21c1894679ab4b5693a6d Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 12 Mar 2024 12:25:10 +0000 Subject: [PATCH 34/65] Rename columns --- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index e14ce1b1..eb9d993d 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -329,7 +329,7 @@ def extractReferences(G, dbOrder, outPrefix, outSuffix = '', type_isolate = None references_in_component = vertices_in_component[vertices_in_component.isin(reference_indices)].values G_component_df = G_df[G_df['old_source'].isin(vertices_in_component) & G_df['old_destination'].isin(vertices_in_component)] G_component = cugraph.Graph() - G_component.from_cudf_edgelist(G_component_df) + G_component.from_cudf_edgelist(G_component_df.rename(columns={'old_source': 'src','old_destination': 'dst'}, inplace=True)) # Find single shortest path from a reference to all other nodes in the component traversal = cugraph.traversal.sssp(G_component,source = references_in_component[0]) reference_index_set = set(reference_indices) From ec4f8c10cb4cf12192aa00573005e4314a276073 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 13 Mar 2024 09:10:21 +0000 Subject: [PATCH 35/65] Fix component graph construction --- PopPUNK/network.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index eb9d993d..c0ae9a14 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -329,9 +329,10 @@ def extractReferences(G, dbOrder, outPrefix, outSuffix = '', type_isolate = None references_in_component = vertices_in_component[vertices_in_component.isin(reference_indices)].values G_component_df = G_df[G_df['old_source'].isin(vertices_in_component) & G_df['old_destination'].isin(vertices_in_component)] G_component = cugraph.Graph() - G_component.from_cudf_edgelist(G_component_df.rename(columns={'old_source': 'src','old_destination': 'dst'}, inplace=True)) + G_component_df.rename(columns={'old_source': 'source', 'old_destination': 'destination'}, inplace=True) + 
G_component.from_cudf_edgelist(G_component_df) # Find single shortest path from a reference to all other nodes in the component - traversal = cugraph.traversal.sssp(G_component,source = references_in_component[0]) + traversal = cugraph.traversal.bfs(G_component,source = references_in_component[0]) reference_index_set = set(reference_indices) # Add predecessors to reference sequences on the SSSPs predecessor_list = traversal[traversal['vertex'].isin(reference_indices)]['predecessor'].values From 986a43d44e2dc68e96ce2dc1471840b1abfb6dc7 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Wed, 13 Mar 2024 10:09:19 +0000 Subject: [PATCH 36/65] Fix BFS arg --- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index c0ae9a14..34930a21 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -332,7 +332,7 @@ def extractReferences(G, dbOrder, outPrefix, outSuffix = '', type_isolate = None G_component_df.rename(columns={'old_source': 'source', 'old_destination': 'destination'}, inplace=True) G_component.from_cudf_edgelist(G_component_df) # Find single shortest path from a reference to all other nodes in the component - traversal = cugraph.traversal.bfs(G_component,source = references_in_component[0]) + traversal = cugraph.traversal.bfs(G_component,start = references_in_component[0]) reference_index_set = set(reference_indices) # Add predecessors to reference sequences on the SSSPs predecessor_list = traversal[traversal['vertex'].isin(reference_indices)]['predecessor'].values From 5e63f2e30a13c61d0e682356dd68f29724a5b1ff Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 24 May 2024 06:29:00 +0100 Subject: [PATCH 37/65] Enable compatibility with CPU-only mandrake --- PopPUNK/mandrake.py | 40 +++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/PopPUNK/mandrake.py b/PopPUNK/mandrake.py index 2d14a350..fc9e81f5 100644 --- a/PopPUNK/mandrake.py +++ b/PopPUNK/mandrake.py @@ -70,23 +70,29 @@ def generate_embedding(seqLabels, accMat, perplexity, outPrefix, overwrite, kNN weights = np.ones((len(seqLabels))) random.Random() seed = random.randint(0, 2**32) - if use_gpu and gpu_fn_available: - sys.stderr.write("Running on GPU\n") - n_workers = 65536 - maxIter = round(maxIter / n_workers) - wtsne_call = partial(wtsne_gpu_fp64, - perplexity=perplexity, - maxIter=maxIter, - blockSize=128, - n_workers=n_workers, - nRepuSamp=5, - eta0=1, - bInit=0, - animated=False, - cpu_threads=n_threads, - device_id=device_id, - seed=seed) - else: + gpu_analysis_complete = False + try: + if use_gpu and gpu_fn_available: + sys.stderr.write("Running on GPU\n") + n_workers = 65536 + maxIter = round(maxIter / n_workers) + wtsne_call = partial(wtsne_gpu_fp64, + perplexity=perplexity, + maxIter=maxIter, + blockSize=128, + n_workers=n_workers, + nRepuSamp=5, + eta0=1, + bInit=0, + animated=False, + cpu_threads=n_threads, + device_id=device_id, + seed=seed) + gpu_analysis_complete = True + except: + # If installed through conda/mamba mandrake is not GPU-enabled by default + sys.stderr.write('Mandrake analysis with GPU failed; trying with CPU\n') + if not gpu_analysis_complete: sys.stderr.write("Running on CPU\n") maxIter = round(maxIter / n_threads) wtsne_call = partial(wtsne, From 137895fe6bf80ffd468348bd86e5fa18ed5ef484 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 19 Sep 2024 17:27:59 +0100 Subject: [PATCH 38/65] Fix processing of sample removal file --- PopPUNK/__main__.py | 6 ++++-- 1 file 
changed, 4 insertions(+), 2 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index e0483983..3faa29e4 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -429,7 +429,9 @@ def main(): if args.remove_samples: with open(args.remove_samples, 'r') as f: for line in f: - fail_unconditionally[line.rstrip] = ["removed"] + sample_to_remove = line.rstrip() + if sample_to_remove in refList: + fail_unconditionally[sample_to_remove] = ["removed"] # assembly qc pass_assembly_qc, fail_assembly_qc = \ @@ -449,7 +451,7 @@ def main(): # Get list of passing samples pass_list = set(refList) - fail_unconditionally.keys() - fail_assembly_qc.keys() - fail_dist_qc.keys() - assert(pass_list == set(refList).intersection(set(pass_assembly_qc)).intersection(set(pass_dist_qc))) + assert(pass_list == (set(refList) - fail_unconditionally.keys()).intersection(set(pass_assembly_qc)).intersection(set(pass_dist_qc))) passed = [x for x in refList if x in pass_list] if qc_dict['type_isolate'] is not None and qc_dict['type_isolate'] not in pass_list: raise RuntimeError('Type isolate ' + qc_dict['type_isolate'] + \ From 99a1e839b1f3580df3a0dd9b749a583064f82cbd Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 19 Sep 2024 18:51:28 +0100 Subject: [PATCH 39/65] Add test for pruning database --- test/remove.txt | 10 ++++++++++ test/run_test.py | 1 + 2 files changed, 11 insertions(+) create mode 100644 test/remove.txt diff --git a/test/remove.txt b/test/remove.txt new file mode 100644 index 00000000..473107fd --- /dev/null +++ b/test/remove.txt @@ -0,0 +1,10 @@ +12754_4#89_partial 12754_4#89.partial.contigs_velvet.fa +12673_8#43 12673_8#43.contigs_velvet.fa +19183_4#69 19183_4#69.contigs_velvet.fa +12754_5#76 12754_5#76.contigs_velvet.fa +12754_5#57 12754_5#57.contigs_velvet.fa +19183_4#59 19183_4#59.contigs_velvet.fa +19183_4#63 19183_4#63.contigs_velvet.fa +12754_4#77 12754_4#77.contigs_velvet.fa +19183_4#55 19183_4#55.contigs_velvet.fa +12754_4#85 12754_4#85.contigs_velvet.fa diff --git a/test/run_test.py b/test/run_test.py index 471b86b8..f75bdc7d 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -24,6 +24,7 @@ sys.stderr.write("Running database QC test (--qc-db)\n") subprocess.run(python_cmd + " ../poppunk-runner.py --qc-db --ref-db example_db --type-isolate \"12754_4#79\" --overwrite", shell=True, check=True) subprocess.run(python_cmd + " ../poppunk-runner.py --qc-db --ref-db example_db --output example_qc --type-isolate \"12754_4#79\" --length-range 2000000 3000000 --overwrite", shell=True, check=True) +subprocess.run(python_cmd + " ../poppunk-runner.py --qc-db --ref-db example_db --output example_qc --type-isolate \"12754_4#79\" --remove-samples remove.txt --overwrite", shell=True, check=True) #fit GMM sys.stderr.write("Running GMM model fit (--fit-model gmm)\n") From 2b6d085a45f995d4976355eb916fc058052d566f Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 19 Sep 2024 19:15:05 +0100 Subject: [PATCH 40/65] Pass arguments needed for network pruning --- PopPUNK/__main__.py | 3 ++- PopPUNK/qc.py | 18 ++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 3faa29e4..399be8e4 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -463,7 +463,8 @@ def main(): remove_qc_fail(qc_dict, refList, passed, [fail_unconditionally, fail_assembly_qc, fail_dist_qc], args.ref_db, distMat, output, - args.strand_preserved, args.threads) + args.strand_preserved, args.threads, + args.gpu_graph) # Plot results if not 
args.no_plot: diff --git a/PopPUNK/qc.py b/PopPUNK/qc.py index 14edbdf7..59775fe2 100755 --- a/PopPUNK/qc.py +++ b/PopPUNK/qc.py @@ -420,7 +420,7 @@ def prune_edges(long_edges, type_isolate, query_start, return failed def remove_qc_fail(qc_dict, names, passed, fail_dicts, ref_db, distMat, prefix, - strand_preserved=False, threads=1): + strand_preserved=False, threads=1, use_gpu=False): """Removes samples failing QC from the database and distances. Also recalculates random match chances. @@ -446,6 +446,8 @@ def remove_qc_fail(qc_dict, names, passed, fail_dicts, ref_db, distMat, prefix, threads (int) Number of CPU threads to use when recalculating random match chances [default = 1]. + use_gpu (bool) + Whether GPU libraries were used to generate the original network. """ from .sketchlib import removeFromDB, addRandom, readDBParams @@ -480,7 +482,19 @@ def remove_qc_fail(qc_dict, names, passed, fail_dicts, ref_db, distMat, prefix, failed, distMat, f"{prefix}/{os.path.basename(prefix)}.dists") - + + # Update the graph + if use_gpu: + graph_suffix = '.csv.gz' + else: + graph_suffix = '.gt' + network_file = f"{prefix}/{os.path.basename(prefix)}" + '_graph' + graph_suffix +# prune_graph(network_file, +# passed, +# output_db_name, +# threads, +# use_gpu) + #if any removed, recalculate random sys.stderr.write(f"Recalculating random matches with strand_preserved = {strand_preserved}\n") db_kmers = readDBParams(ref_db)[0] From 868fdfc9b0d20e351e239fdf6e9aa32f95e11b17 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 19 Sep 2024 21:32:07 +0100 Subject: [PATCH 41/65] Add and test prune_database function --- PopPUNK/network.py | 57 ++++++++++++++++++++++++++++++++++++++++++++++ PopPUNK/qc.py | 17 ++++++-------- test/clean_test.py | 3 ++- test/run_test.py | 3 +++ 4 files changed, 69 insertions(+), 11 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 34930a21..dd10a804 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1888,3 +1888,60 @@ def sparse_mat_to_network(sparse_mat, rlist, use_gpu = False): summarise=False) return G + + +def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): + """Keep only the specified sequences in a graph + + Args: + prefix (str) + Name of directory containing network + reflist (list) + Ordered list of sequences of database + passed (list) + The names of passing samples + output_db_name (str) + Name of output directory + threads (int) + Number of CPU threads to use when recalculating random match chances + [default = 1]. 
+ use_gpu (bool) + Whether graph is a cugraph or not + [default = False] + + Returns: + vlist (list) + List of integers corresponding to nodes + """ + if use_gpu: + graph_suffix = '.csv.gz' + else: + graph_suffix = '.gt' + network_fn = f"{prefix}/{os.path.basename(prefix)}" + '_graph' + graph_suffix + print('Network: ' + network_fn) + if os.path.exists(network_fn): + sys.stderr.write("Loading network from " + network_fn + "\n") + passed_set = frozenset(passed) + G = load_network_file(network_fn, use_gpu = use_gpu) + if use_gpu: + G_df = G.view_edge_list() + if 'src' in G_df.columns: + G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) + G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] + G_new = translate_network_indices(G_new_df, reference_indices) + else: + reference_vertex = G.new_vertex_property('bool') + for n, vertex in enumerate(G.vertices()): + if reflist[n] in passed_set: + reference_vertex[vertex] = True + else: + reference_vertex[vertex] = False + G_new = gt.GraphView(G, vfilt = reference_vertex) + G_new = gt.Graph(G_new, prune = True) + save_network(G_new, + prefix = output_db_name, + suffix = '_graph', + use_graphml = False, + use_gpu = use_gpu) + else: + sys.stderr.write('No network file found for pruning\n') diff --git a/PopPUNK/qc.py b/PopPUNK/qc.py index 59775fe2..b1e24056 100755 --- a/PopPUNK/qc.py +++ b/PopPUNK/qc.py @@ -11,6 +11,7 @@ import poppunk_refine +from .network import prune_graph from .utils import storePickle, iterDistRows, readIsolateTypeFromCsv def prune_distance_matrix(refList, remove_seqs_in, distMat, output): @@ -484,16 +485,12 @@ def remove_qc_fail(qc_dict, names, passed, fail_dicts, ref_db, distMat, prefix, f"{prefix}/{os.path.basename(prefix)}.dists") # Update the graph - if use_gpu: - graph_suffix = '.csv.gz' - else: - graph_suffix = '.gt' - network_file = f"{prefix}/{os.path.basename(prefix)}" + '_graph' + graph_suffix -# prune_graph(network_file, -# passed, -# output_db_name, -# threads, -# use_gpu) + prune_graph(ref_db, + names, + passed, + prefix, + threads, + use_gpu) #if any removed, recalculate random sys.stderr.write(f"Recalculating random matches with strand_preserved = {strand_preserved}\n") diff --git a/test/clean_test.py b/test/clean_test.py index 95690aba..7f277e50 100755 --- a/test/clean_test.py +++ b/test/clean_test.py @@ -47,7 +47,8 @@ def deleteDir(dirname): "batch12", "batch123", "strain_1_lineage_db", - "strain_2_lineage_db" + "strain_2_lineage_db", + "example_network_qc" ] for outDir in outputDirs: deleteDir(outDir) diff --git a/test/run_test.py b/test/run_test.py index f75bdc7d..2d0b0deb 100755 --- a/test/run_test.py +++ b/test/run_test.py @@ -59,6 +59,9 @@ sys.stderr.write("Running with an existing model (--use-model)\n") subprocess.run(python_cmd + " ../poppunk-runner.py --use-model --ref-db example_db --model-dir example_db --output example_use --overwrite", shell=True, check=True) +#test pruning a database with a graph +subprocess.run(python_cmd + " ../poppunk-runner.py --qc-db --ref-db example_db --output example_network_qc --type-isolate \"12754_4#79\" --remove-samples remove.txt --overwrite", shell=True, check=True) + # tests of other command line programs sys.stderr.write("Testing C++ extension\n") subprocess.run(python_cmd + " test-refine.py", shell=True, check=True) From 6e540344a57129ff346219523bbf54386a4affd1 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Thu, 19 Sep 2024 21:51:45 +0100 Subject: [PATCH 42/65] Fix file for graph pruning --- 
test/remove.txt | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/test/remove.txt b/test/remove.txt index 473107fd..b070180e 100644 --- a/test/remove.txt +++ b/test/remove.txt @@ -1,10 +1,10 @@ -12754_4#89_partial 12754_4#89.partial.contigs_velvet.fa -12673_8#43 12673_8#43.contigs_velvet.fa -19183_4#69 19183_4#69.contigs_velvet.fa -12754_5#76 12754_5#76.contigs_velvet.fa -12754_5#57 12754_5#57.contigs_velvet.fa -19183_4#59 19183_4#59.contigs_velvet.fa -19183_4#63 19183_4#63.contigs_velvet.fa -12754_4#77 12754_4#77.contigs_velvet.fa -19183_4#55 19183_4#55.contigs_velvet.fa -12754_4#85 12754_4#85.contigs_velvet.fa +12754_4#89_partial +12673_8#43 +19183_4#69 +12754_5#76 +12754_5#57 +19183_4#59 +19183_4#63 +12754_4#77 +19183_4#55 +12754_4#85 From f195985b8208a4a87f88e5a61466c7ec5bd1af6f Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 06:26:34 +0100 Subject: [PATCH 43/65] Resolve conflicts with master --- PopPUNK/lineages.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/PopPUNK/lineages.py b/PopPUNK/lineages.py index d3ed5b97..36afe729 100755 --- a/PopPUNK/lineages.py +++ b/PopPUNK/lineages.py @@ -7,6 +7,7 @@ import argparse import subprocess import pickle +import shutil import pandas as pd from collections import defaultdict @@ -142,7 +143,7 @@ def main(): create_db(args) elif args.query_db is not None: query_db(args) - + def create_db(args): @@ -150,8 +151,10 @@ def create_db(args): if not args.overwrite: if os.path.exists(args.output + '.csv'): sys.stderr.write('Output file ' + args.output + '.csv exists; use --overwrite to replace it\n') + sys.exit(1) if os.path.exists(args.db_scheme): sys.stderr.write('Output file ' + args.db_scheme + ' exists; use --overwrite to replace it\n') + sys.exit(1) sys.stderr.write("Identifying strains in existing database\n") # Read in strain information @@ -197,7 +200,8 @@ def create_db(args): if num_isolates >= args.min_count: lineage_dbs[strain] = strain_db_name if os.path.isdir(strain_db_name) and args.overwrite: - os.rmdir(strain_db_name) + sys.stderr.write("--overwrite means {strain_db_name} will be deleted now\n") + shutil.rmtree(strain_db_name) if not os.path.isdir(strain_db_name): try: os.makedirs(strain_db_name) @@ -209,7 +213,8 @@ def create_db(args): dest_db = os.path.join(strain_db_name,os.path.basename(strain_db_name) + '.h5') rel_path = os.path.relpath(src_db, os.path.dirname(dest_db)) if os.path.exists(dest_db) and args.overwrite: - os.remove(dest_db) + sys.stderr.write("--overwrite means {dest_db} will be deleted now\n") + shutil.rmtree(dest_db) elif not os.path.exists(dest_db): os.symlink(rel_path,dest_db) # Extract sparse distances @@ -217,8 +222,6 @@ def create_db(args): list(set(rlist) - set(isolate_list)), X, os.path.join(strain_db_name,strain_db_name + '.dists')) - # Prune rank list - pruned_rank_list = [r for r in rank_list if r <= num_isolates] # Initialise model model = LineageFit(strain_db_name, rank_list, @@ -306,7 +309,7 @@ def create_db(args): def query_db(args): - + # Read querying scheme with open(args.db_scheme, 'rb') as pickle_file: ref_db, rlist, model_dir, clustering_file, args.clustering_col_name, distances, \ @@ -376,6 +379,7 @@ def query_db(args): False, # write references - need to consider whether to support ref-only databases for assignment distances, False, # serial - needs to be supported for web version? + None, # stable - not supported here args.threads, True, # overwrite - probably OK? 
False, # plot_fit - turn off for now @@ -422,6 +426,7 @@ def query_db(args): False, # write references - need to consider whether to support ref-only databases for assignment lineage_distances, False, # serial - needs to be supported for web version? + None, # stable - not supported here args.threads, True, # overwrite - probably OK? False, # plot_fit - turn off for now @@ -436,10 +441,10 @@ def query_db(args): args.gpu_graph, save_partial_query_graph = False) overall_lineage[strain] = createOverallLineage(rank_list, lineageClustering) - + # Print combined strain and lineage clustering print_overall_clustering(overall_lineage,args.output + '.csv',qNames) - + def print_overall_clustering(overall_lineage,output,include_list): @@ -457,7 +462,7 @@ def print_overall_clustering(overall_lineage,output,include_list): isolate_info[isolate].append(str(overall_lineage[strain][rank][isolate])) else: isolate_info[isolate] = [str(strain),str(overall_lineage[strain][rank][isolate])] - + # Print output with open(output,'w') as out: out.write('id,Cluster,') @@ -470,3 +475,4 @@ def print_overall_clustering(overall_lineage,output,include_list): if __name__ == '__main__': main() sys.exit(0) + From 95e12c11d604d07939691e9fb759d3750b702296 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 06:27:55 +0100 Subject: [PATCH 44/65] Resolve conflicts with master --- test/clean_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/clean_test.py b/test/clean_test.py index 7f277e50..75466756 100755 --- a/test/clean_test.py +++ b/test/clean_test.py @@ -28,7 +28,9 @@ def deleteDir(dirname): "example_use", "example_query", "example_single_query", + "example_query_stable", "example_query_update", + "example_query_update_2", "example_lineage_query", "example_viz", "example_viz_subset", @@ -46,8 +48,10 @@ def deleteDir(dirname): "batch3", "batch12", "batch123", + "batch123_viz", "strain_1_lineage_db", "strain_2_lineage_db", + "lineage_querying_output", "example_network_qc" ] for outDir in outputDirs: From abd4cb6d43702d786011ab69da867c746e655726 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 06:38:31 +0100 Subject: [PATCH 45/65] Add GPU compatibility --- PopPUNK/network.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index dd10a804..6de3c852 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1924,10 +1924,15 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): passed_set = frozenset(passed) G = load_network_file(network_fn, use_gpu = use_gpu) if use_gpu: + # Identify indices + reference_indices = [i for (i,name) in enumerate(reflist) if name in passed_set] + # Generate data frame G_df = G.view_edge_list() if 'src' in G_df.columns: G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) + # Filter data frame G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] + # Translate network indices to match name order G_new = translate_network_indices(G_new_df, reference_indices) else: reference_vertex = G.new_vertex_property('bool') From c97c085b727705a1d1fd951e77b3a401b1adcbd1 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 07:01:00 +0100 Subject: [PATCH 46/65] Generalise function to any network --- PopPUNK/network.py | 69 ++++++++++++++++++++++++---------------------- 1 file changed, 36 insertions(+), 33 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 6de3c852..244ce184 100644 --- 
a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1890,7 +1890,7 @@ def sparse_mat_to_network(sparse_mat, rlist, use_gpu = False): return G -def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): +def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): """Keep only the specified sequences in a graph Args: @@ -1917,36 +1917,39 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): graph_suffix = '.csv.gz' else: graph_suffix = '.gt' - network_fn = f"{prefix}/{os.path.basename(prefix)}" + '_graph' + graph_suffix - print('Network: ' + network_fn) - if os.path.exists(network_fn): - sys.stderr.write("Loading network from " + network_fn + "\n") - passed_set = frozenset(passed) - G = load_network_file(network_fn, use_gpu = use_gpu) - if use_gpu: - # Identify indices - reference_indices = [i for (i,name) in enumerate(reflist) if name in passed_set] - # Generate data frame - G_df = G.view_edge_list() - if 'src' in G_df.columns: - G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) - # Filter data frame - G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] - # Translate network indices to match name order - G_new = translate_network_indices(G_new_df, reference_indices) - else: - reference_vertex = G.new_vertex_property('bool') - for n, vertex in enumerate(G.vertices()): - if reflist[n] in passed_set: - reference_vertex[vertex] = True - else: - reference_vertex[vertex] = False - G_new = gt.GraphView(G, vfilt = reference_vertex) - G_new = gt.Graph(G_new, prune = True) - save_network(G_new, - prefix = output_db_name, - suffix = '_graph', - use_graphml = False, - use_gpu = use_gpu) - else: + + network_found = False + for graph_name in ['_core.refs_graph','_core_graph','_accessory.refs_graph','_accessory_graph','.refs_graph','_graph'] + network_fn = f"{prefix}/{os.path.basename(prefix)}" + graph_name + graph_suffix + if os.path.exists(network_fn): + network_found = True + sys.stderr.write("Loading network from " + network_fn + "\n") + passed_set = frozenset(passed) + G = load_network_file(network_fn, use_gpu = use_gpu) + if use_gpu: + # Identify indices + reference_indices = [i for (i,name) in enumerate(reflist) if name in passed_set] + # Generate data frame + G_df = G.view_edge_list() + if 'src' in G_df.columns: + G_df.rename(columns={'src': 'source','dst': 'destination'}, inplace=True) + # Filter data frame + G_new_df = G_df[G_df['source'].isin(reference_indices) & G_df['destination'].isin(reference_indices)] + # Translate network indices to match name order + G_new = translate_network_indices(G_new_df, reference_indices) + else: + reference_vertex = G.new_vertex_property('bool') + for n, vertex in enumerate(G.vertices()): + if reflist[n] in passed_set: + reference_vertex[vertex] = True + else: + reference_vertex[vertex] = False + G_new = gt.GraphView(G, vfilt = reference_vertex) + G_new = gt.Graph(G_new, prune = True) + save_network(G_new, + prefix = output_db_name, + suffix = '_graph', + use_graphml = False, + use_gpu = use_gpu) + if not network_found: sys.stderr.write('No network file found for pruning\n') From 9453711953389a4881d0f407d0d0bd22dcd843de Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 07:16:14 +0100 Subject: [PATCH 47/65] Add missing colon --- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 244ce184..243f4d7e 100644 --- a/PopPUNK/network.py +++ 
b/PopPUNK/network.py @@ -1919,7 +1919,7 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): graph_suffix = '.gt' network_found = False - for graph_name in ['_core.refs_graph','_core_graph','_accessory.refs_graph','_accessory_graph','.refs_graph','_graph'] + for graph_name in ['_core.refs_graph','_core_graph','_accessory.refs_graph','_accessory_graph','.refs_graph','_graph']: network_fn = f"{prefix}/{os.path.basename(prefix)}" + graph_name + graph_suffix if os.path.exists(network_fn): network_found = True From bea29b26a221439d4a7ed506d819144dbdc19807 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 08:34:04 +0100 Subject: [PATCH 48/65] Bump version --- PopPUNK/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/__init__.py b/PopPUNK/__init__.py index cbaf1fcf..d7f9d184 100644 --- a/PopPUNK/__init__.py +++ b/PopPUNK/__init__.py @@ -3,7 +3,7 @@ '''PopPUNK (POPulation Partitioning Using Nucleotide Kmers)''' -__version__ = '2.7.1' +__version__ = '2.7.2' # Minimum sketchlib version SKETCHLIB_MAJOR = 2 From e811b7b67c6b330207b339886dd8b95cfc2be465 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 08:44:31 +0100 Subject: [PATCH 49/65] Consistent data structures for updated cython compatibility --- PopPUNK/unwords.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/unwords.py b/PopPUNK/unwords.py index c64f5e1c..be36d492 100644 --- a/PopPUNK/unwords.py +++ b/PopPUNK/unwords.py @@ -13,7 +13,7 @@ def gen_unword(unique=True): vowels = ["a", "e", "i", "o", "u"] trouble = ["q", "x", "y"] - consonants = set(string.ascii_lowercase) - set(vowels) - set(trouble) + consonants = list(set(string.ascii_lowercase) - set(vowels) - set(trouble)) vowel = lambda: random.sample(vowels, 1) consonant = lambda: random.sample(consonants, 1) From 962b198cbbc9af9691da971b32644c96aa8bad9a Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Fri, 20 Sep 2024 09:11:40 +0100 Subject: [PATCH 50/65] Remove assumption of reference database when updating --- PopPUNK/assign.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/assign.py b/PopPUNK/assign.py index 0d487e01..ec162121 100644 --- a/PopPUNK/assign.py +++ b/PopPUNK/assign.py @@ -753,7 +753,7 @@ def assign_query_hdf5(dbFuncs, storePickle(combined_seq, combined_seq, True, None, dists_out) # Clique pruning - if model.type != 'lineage': + if model.type != 'lineage' and os.path.isfile(ref_file_name): existing_ref_list = [] with open(ref_file_name) as refFile: for reference in refFile: From 957d4fe649c3ec95286518969e7460fc01f8ec59 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 09:49:33 +0100 Subject: [PATCH 51/65] Document network sampling arguments --- docs/model_fitting.rst | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/model_fitting.rst b/docs/model_fitting.rst index c49c7163..bea92c37 100644 --- a/docs/model_fitting.rst +++ b/docs/model_fitting.rst @@ -505,6 +505,10 @@ Which, looking at the `microreact output Date: Tue, 8 Oct 2024 11:17:58 +0100 Subject: [PATCH 52/65] Tidy up between strain cluster identification --- PopPUNK/bgmm.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/PopPUNK/bgmm.py b/PopPUNK/bgmm.py index b8db346d..7da7c49f 100644 --- a/PopPUNK/bgmm.py +++ b/PopPUNK/bgmm.py @@ -45,7 +45,7 @@ def fit2dMultiGaussian(X, dpgmm_max_K = 2): return dpgmm -def findBetweenLabel_bgmm(means, assignments, rank = 0): +def findBetweenLabel_bgmm(means, 
assignments): """Identify between-strain links Finds the component with the largest number of points @@ -56,9 +56,6 @@ def findBetweenLabel_bgmm(means, assignments, rank = 0): K x 2 array of mixture component means assignments (numpy.array) Sample cluster assignments - rank (int) - Which label to find, ordered by distance from origin. 0-indexed. - (default = 0) Returns: between_label (int) From a1691af36c4d1a75b01bb96b383837ffe32ddb32 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 11:21:12 +0100 Subject: [PATCH 53/65] Remove obsolete line --- PopPUNK/plot.py | 1 - 1 file changed, 1 deletion(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index bff863e1..abb8d441 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -746,7 +746,6 @@ def writeClusterCsv(outfile, nodeNames, nodeLabels, clustering, prev_col_name = "unknown" for col in d: this_col_items = len(d[col]) - sys.stderr.write(col + ' has length ' + str(this_col_items) + '\n') if prev_col_items > -1 and prev_col_items != this_col_items: sys.stderr.write("Discrepant length between " + prev_col_name + \ " (length of " + str(prev_col_items) + ") and " + \ From 7d9e17d5fb1a858c0abfe363098e5098aeae9fcc Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 11:29:19 +0100 Subject: [PATCH 54/65] Move database processing function into sketchlib.py --- PopPUNK/__main__.py | 8 +++++--- PopPUNK/plot.py | 17 +++++------------ PopPUNK/sketchlib.py | 18 ++++++++++++++++++ 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 399be8e4..58cfa9fe 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -256,7 +256,7 @@ def main(): # Imports are here because graph tool is very slow to load from .models import loadClusterFit, BGMMFit, DBSCANFit, RefineFit, LineageFit - from .sketchlib import checkSketchlibLibrary, removeFromDB + from .sketchlib import checkSketchlibLibrary, removeFromDB, get_database_statistics from .network import construct_network_from_edge_list from .network import construct_network_from_assignments @@ -393,7 +393,8 @@ def main(): plot_scatter(distMat, args.output, args.output + " distances") - plot_database_evaluations(args.output) + genome_lengths, ambiguous_bases = get_database_statistics(args.output) + plot_database_evaluations(genome_lengths, ambiguous_bases) #******************************# #* *# @@ -471,7 +472,8 @@ def main(): plot_scatter(distMat, output, output + " distances") - plot_database_evaluations(output) + genome_lengths, ambiguous_bases = get_database_statistics(args.output) + plot_database_evaluations(genome_lengths, ambiguous_bases) #******************************# #* *# diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index abb8d441..042e548c 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -16,7 +16,6 @@ # for other outputs import pandas as pd from pandas.errors import DataError -import h5py from collections import defaultdict from sklearn import utils try: # sklearn >= 0.22 @@ -82,21 +81,15 @@ def plot_scatter(X, out_prefix, title, kde = True): plt.savefig(os.path.join(out_prefix, os.path.basename(out_prefix) + '_distanceDistribution.png')) plt.close() -def plot_database_evaluations(prefix): +def plot_database_evaluations(genome_lengths, ambiguous_bases): """Plot histograms of sequence characteristics for database evaluation. 
Args: - prefix (str) - Prefix of database + genome_lengths (list) + Lengths of genomes in database + ambiguous_bases (list) + Counts of ambiguous bases in genomes in database """ - db_file = prefix + "/" + os.path.basename(prefix) + ".h5" - ref_db = h5py.File(db_file, 'r') - - genome_lengths = [] - ambiguous_bases = [] - for sample_name in list(ref_db['sketches'].keys()): - genome_lengths.append(ref_db['sketches/' + sample_name].attrs['length']) - ambiguous_bases.append(ref_db['sketches/' + sample_name].attrs['missing_bases']) plot_evaluation_histogram(genome_lengths, n_bins = 100, prefix = prefix, diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py index c9629df5..7d1bb23a 100644 --- a/PopPUNK/sketchlib.py +++ b/PopPUNK/sketchlib.py @@ -659,3 +659,21 @@ def fitKmerCurve(pairwise, klist, jacobian): # Return core, accessory return(np.flipud(transformed_params)) + +def plot_database_evaluations(prefix): + """Extract statistics for evaluating databases. + + Args: + prefix (str) + Prefix of database + """ + db_file = prefix + "/" + os.path.basename(prefix) + ".h5" + ref_db = h5py.File(db_file, 'r') + + genome_lengths = [] + ambiguous_bases = [] + for sample_name in list(ref_db['sketches'].keys()): + genome_lengths.append(ref_db['sketches/' + sample_name].attrs['length']) + ambiguous_bases.append(ref_db['sketches/' + sample_name].attrs['missing_bases']) + + return genome_lengths, ambiguous_bases From 77537af9479716412c4b625dea388a43a0d550d1 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 11:30:54 +0100 Subject: [PATCH 55/65] Set tmp as default --- PopPUNK/visualise.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index 6d42bbd1..40ce0fff 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -137,7 +137,7 @@ def get_options(): other.add_argument('--gpu-dist', default=False, action='store_true', help='Use a GPU when calculating distances [default = False]') other.add_argument('--gpu-graph', default=False, action='store_true', help='Use a GPU when calculating graphs [default = False]') other.add_argument('--deviceid', default=0, type=int, help='CUDA device ID, if using GPU [default = 0]') - other.add_argument('--tmp', default=None, type=str, help='Directory for large temporary files') + other.add_argument('--tmp', default='/tmp/', type=str, help='Directory for large temporary files') other.add_argument('--strand-preserved', default=False, action='store_true', help='If distances being calculated, treat strand as known when calculating random ' 'match chances [default = False]') From 2054e6cfdae3152cb218d1609db58b1dbed471c8 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 11:54:06 +0100 Subject: [PATCH 56/65] Define adj --- PopPUNK/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/utils.py b/PopPUNK/utils.py index 31433ca2..d411ecbf 100644 --- a/PopPUNK/utils.py +++ b/PopPUNK/utils.py @@ -543,7 +543,7 @@ def decisionBoundary(intercept, gradient, adj = 0.0): gradient (float) Gradient of the line adj (float) - Distance by which to shift the interception point + Fraction by which to shift the intercept up the y axis Returns: x (float) The x-axis intercept From 4d8f1b27238193fc6bef65eba638622b4a1552fa Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 11:55:32 +0100 Subject: [PATCH 57/65] Remove incorrect return description --- PopPUNK/network.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py 
index 6079a570..88c1adb1 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1908,10 +1908,6 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): use_gpu (bool) Whether graph is a cugraph or not [default = False] - - Returns: - vlist (list) - List of integers corresponding to nodes """ if use_gpu: graph_suffix = '.csv.gz' From f166cfa7c98c19ed11345fe3c4331705ef16d713 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:00:16 +0100 Subject: [PATCH 58/65] Fix selection of between strain cluster --- PopPUNK/bgmm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/bgmm.py b/PopPUNK/bgmm.py index 7da7c49f..fac52eca 100644 --- a/PopPUNK/bgmm.py +++ b/PopPUNK/bgmm.py @@ -66,7 +66,7 @@ def findBetweenLabel_bgmm(means, assignments): most_dists[mixture_component] = np.count_nonzero(assignments == mixture_component) sorted_dists = sorted(most_dists.items(), key=operator.itemgetter(1), reverse=True) - return(sorted_dists[rank][0]) + return(sorted_dists[0][0]) def findWithinLabel(means, assignments, rank = 0): """Identify within-strain links From a8b58c04db73e86ca79ffd1f5e7e7d5040483205 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:02:12 +0100 Subject: [PATCH 59/65] Improve variable name --- PopPUNK/network.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 88c1adb1..8df43e9b 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1898,8 +1898,8 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): Name of directory containing network reflist (list) Ordered list of sequences of database - passed (list) - The names of passing samples + samples_to_keep (list) + The names of samples to be retained in the graph output_db_name (str) Name of output directory threads (int) @@ -1920,11 +1920,11 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): if os.path.exists(network_fn): network_found = True sys.stderr.write("Loading network from " + network_fn + "\n") - passed_set = frozenset(passed) + samples_to_keep_set = frozenset(samples_to_keep) G = load_network_file(network_fn, use_gpu = use_gpu) if use_gpu: # Identify indices - reference_indices = [i for (i,name) in enumerate(reflist) if name in passed_set] + reference_indices = [i for (i,name) in enumerate(reflist) if name in samples_to_keep_set] # Generate data frame G_df = G.view_edge_list() if 'src' in G_df.columns: @@ -1936,7 +1936,7 @@ def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): else: reference_vertex = G.new_vertex_property('bool') for n, vertex in enumerate(G.vertices()): - if reflist[n] in passed_set: + if reflist[n] in samples_to_keep_set: reference_vertex[vertex] = True else: reference_vertex[vertex] = False From 96e5b5d92c2afafc56cba18d42350c4d48e10632 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:04:00 +0100 Subject: [PATCH 60/65] Correct function name --- PopPUNK/sketchlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/sketchlib.py b/PopPUNK/sketchlib.py index 7d1bb23a..df8b941d 100644 --- a/PopPUNK/sketchlib.py +++ b/PopPUNK/sketchlib.py @@ -660,7 +660,7 @@ def fitKmerCurve(pairwise, klist, jacobian): # Return core, accessory return(np.flipud(transformed_params)) -def plot_database_evaluations(prefix): +def get_database_statistics(prefix): """Extract statistics for evaluating databases. 
Args: From 769834b954dd68cfaa242ad4441a2242eb40263d Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:07:09 +0100 Subject: [PATCH 61/65] Pass prefix to plotting function --- PopPUNK/plot.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/PopPUNK/plot.py b/PopPUNK/plot.py index 042e548c..6f7a995b 100644 --- a/PopPUNK/plot.py +++ b/PopPUNK/plot.py @@ -81,10 +81,12 @@ def plot_scatter(X, out_prefix, title, kde = True): plt.savefig(os.path.join(out_prefix, os.path.basename(out_prefix) + '_distanceDistribution.png')) plt.close() -def plot_database_evaluations(genome_lengths, ambiguous_bases): +def plot_database_evaluations(prefix, genome_lengths, ambiguous_bases): """Plot histograms of sequence characteristics for database evaluation. Args: + prefix (str) + Prefix for output files genome_lengths (list) Lengths of genomes in database ambiguous_bases (list) From e56a3e6875452071a5816e8706822a70ebbf181d Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:09:06 +0100 Subject: [PATCH 62/65] Pass correct arguments to plotting function --- PopPUNK/__main__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 58cfa9fe..de8f8360 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -394,7 +394,7 @@ def main(): args.output, args.output + " distances") genome_lengths, ambiguous_bases = get_database_statistics(args.output) - plot_database_evaluations(genome_lengths, ambiguous_bases) + plot_database_evaluations(args.output, genome_lengths, ambiguous_bases) #******************************# #* *# @@ -473,7 +473,7 @@ def main(): output, output + " distances") genome_lengths, ambiguous_bases = get_database_statistics(args.output) - plot_database_evaluations(genome_lengths, ambiguous_bases) + plot_database_evaluations(args.output, genome_lengths, ambiguous_bases) #******************************# #* *# From 2d311bf379dd15257a26b20c5a59e5dc136183c9 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:13:36 +0100 Subject: [PATCH 63/65] Fix output name processing --- PopPUNK/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index de8f8360..616cc794 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -472,7 +472,7 @@ def main(): plot_scatter(distMat, output, output + " distances") - genome_lengths, ambiguous_bases = get_database_statistics(args.output) + genome_lengths, ambiguous_bases = get_database_statistics(output) plot_database_evaluations(args.output, genome_lengths, ambiguous_bases) #******************************# From fce1a21687940a004ed1d2b41a2861c7623589bf Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:16:47 +0100 Subject: [PATCH 64/65] Fix output name processing again --- PopPUNK/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/__main__.py b/PopPUNK/__main__.py index 616cc794..b35804ff 100644 --- a/PopPUNK/__main__.py +++ b/PopPUNK/__main__.py @@ -473,7 +473,7 @@ def main(): output, output + " distances") genome_lengths, ambiguous_bases = get_database_statistics(output) - plot_database_evaluations(args.output, genome_lengths, ambiguous_bases) + plot_database_evaluations(output, genome_lengths, ambiguous_bases) #******************************# #* *# From 5b4eb67a8469cb4807e4b128dbd11ff22d928880 Mon Sep 17 00:00:00 2001 From: Nick Croucher Date: Tue, 8 Oct 2024 12:30:06 +0100 Subject: [PATCH 65/65] Update argument name 
--- PopPUNK/network.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PopPUNK/network.py b/PopPUNK/network.py index 8df43e9b..2568e934 100644 --- a/PopPUNK/network.py +++ b/PopPUNK/network.py @@ -1890,7 +1890,7 @@ def sparse_mat_to_network(sparse_mat, rlist, use_gpu = False): return G -def prune_graph(prefix, reflist, passed, output_db_name, threads, use_gpu): +def prune_graph(prefix, reflist, samples_to_keep, output_db_name, threads, use_gpu): """Keep only the specified sequences in a graph Args:
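
Notes on selected changes in this series

The three sketches below restate, outside the diffs, the logic introduced by
some of the patches above. Each is a minimal illustration under the stated
assumptions, not the in-tree implementation, and the function names used
(decision_boundary, find_between_label, prune_to_samples) are hypothetical.

1. Boundary adjustment (patches 29, 30 and 56). The adj parameter added to
decisionBoundary() moves the point where the refinement line meets the
decision boundary a fixed distance along its vector from the origin: the
hypotenuse (the distance of the intercept point from the origin) is computed,
the point is rescaled by (hypotenuse + adj) / hypotenuse, and the axis
intercepts are then projected as before. After the sign fix in patch 30,
refineFit() uses adj = -min_move for the start of the unconstrained search
range and adj = max_move for its end. A standalone sketch, assuming a
two-element intercept and a positive gradient:

    import numpy as np

    def decision_boundary(intercept, gradient, adj=0.0):
        # Axis intercepts of the decision boundary through 'intercept'.
        # A non-zero 'adj' first shifts the point a fixed distance along
        # its vector from the origin (negative values move it closer).
        x0, y0 = float(intercept[0]), float(intercept[1])
        if adj != 0.0:
            hypotenuse = np.hypot(x0, y0)  # distance of the point from the origin
            ratio = (hypotenuse + adj) / hypotenuse
            x0, y0 = x0 * ratio, y0 * ratio
        # Project the (possibly shifted) point onto the two axes
        x = x0 + y0 * gradient
        y = y0 + x0 / gradient
        return x, y

    # Start of the search range, pulled min_move = 0.01 towards the origin
    print(decision_boundary((0.02, 0.15), 2.0, adj=-0.01))

2. Between-strain component selection (patches 28, 52 and 58). As settled in
patch 58, findBetweenLabel_bgmm() returns the mixture component with the most
points assigned to it; the component means are only used to enumerate the
components. Assuming at least one assignment, the same selection can be
written as:

    import numpy as np

    def find_between_label(assignments):
        # The between-strain component is the one with the most points
        components, counts = np.unique(assignments, return_counts=True)
        return components[np.argmax(counts)]

    print(find_between_label(np.array([0, 1, 1, 2, 1, 0])))  # -> 1

3. Graph pruning (patches 41, 46, 59 and 65). On the graph-tool code path,
prune_graph() keeps only the named samples by building a boolean vertex
filter in database order and pruning the filtered view to a new graph. A
sketch of that step, assuming vertex i of G corresponds to rlist[i] (as in
the patched code) and that graph-tool is installed:

    import graph_tool.all as gt

    def prune_to_samples(G, rlist, samples_to_keep):
        # Keep vertex i only if rlist[i] is retained, then materialise
        # the filtered view as a standalone graph
        keep = frozenset(samples_to_keep)
        vfilt = G.new_vertex_property('bool')
        for i, v in enumerate(G.vertices()):
            vfilt[v] = rlist[i] in keep
        return gt.Graph(gt.GraphView(G, vfilt=vfilt), prune=True)

    # e.g. G_kept = prune_to_samples(G, rlist, samples_passing_qc)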