diff --git a/PopPUNK/visualise.py b/PopPUNK/visualise.py index c6ff5509..8a01a0ca 100644 --- a/PopPUNK/visualise.py +++ b/PopPUNK/visualise.py @@ -314,16 +314,17 @@ def generate_visualisations(query_db, #******************************# #* *# - #* Process dense or sparse *# - #* distances *# + #* Determine type of distance *# + #* to use *# #* *# #******************************# # Determine whether to use sparse distances use_sparse = False use_dense = False - - if (tree == "mst" or tree == "both") and rank_fit is not None: + if (tree == "nj" or tree == "both") or rank_fit == None: + use_dense = True + elif (tree == "mst" or tree == "both") and rank_fit is not None: # Set flag use_sparse = True # Read list of sequence names and sparse distance matrix @@ -336,9 +337,91 @@ def generate_visualisations(query_db, elif previous_mst is not None: sys.stderr.write('The prefix of the distance files used to create the previous MST' ' is needed to use the network') - + + #**********************************# + #* *# + #* Process clustering information *# + #* *# + #**********************************# + + # identify existing model and cluster files + if model_dir is not None: + model_prefix = model_dir + else: + model_prefix = ref_db + try: + model_file = os.path.join(model_prefix, os.path.basename(model_prefix)) + model = loadClusterFit(model_file + '_fit.pkl', + model_file + '_fit.npz') + model.set_threads(threads) + except FileNotFoundError: + sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n') + sys.exit(1) + + # Either use strain definitions, lineage assignments or external clustering + isolateClustering = {} + # Use external clustering if specified + if external_clustering: + mode = 'external' + cluster_file = external_clustering + if cluster_file.endswith('_lineages.csv'): + suffix = "_lineages.csv" + else: + suffix = "_clusters.csv" + else: + # Load previous clusters + if previous_clustering is not None: + cluster_file = previous_clustering + mode = "clusters" + suffix = "_clusters.csv" + if cluster_file.endswith('_lineages.csv'): + mode = "lineages" + suffix = "_lineages.csv" + else: + # Identify type of clustering based on model + mode = "clusters" + suffix = "_clusters.csv" + if model.type == "lineage": + mode = "lineages" + suffix = "_lineages.csv" + cluster_file = os.path.join(model_prefix, os.path.basename(model_prefix) + suffix) + + isolateClustering = readIsolateTypeFromCsv(cluster_file, + mode = mode, + return_dict = True) + + # Add individual refinement clusters if they exist + if model.indiv_fitted: + for type, indiv_suffix in zip(['Core','Accessory'],['_core_clusters.csv','_accessory_clusters.csv']): + indiv_clustering = os.path.join(model_prefix, os.path.basename(model_prefix) + indiv_suffix) + if os.path.isfile(indiv_clustering): + indiv_isolateClustering = readIsolateTypeFromCsv(indiv_clustering, + mode = mode, + return_dict = True) + isolateClustering[type] = indiv_isolateClustering['Cluster'] + + # Join clusters with query clusters if required + if use_dense: + if query_db is not None: + if previous_query_clustering is not None: + prev_query_clustering = previous_query_clustering + else: + prev_query_clustering = os.path.join(query_db, os.path.basename(query_db) + suffix) + + queryIsolateClustering = readIsolateTypeFromCsv( + prev_query_clustering, + mode = mode, + return_dict = True) + isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering) + + #******************************# + #* *# + #* Process dense or sparse *# + #* distances *# + #* *# + #******************************# + if (tree == "nj" or tree == "both") or rank_fit == None: - use_dense = True # Either calculate or read distances if recalculate_distances: @@ -456,81 +539,6 @@ def generate_visualisations(query_db, core_distMat = core_distMat[np.ix_(row_slice, row_slice)] acc_distMat = acc_distMat[np.ix_(row_slice, row_slice)] - #**********************************# - #* *# - #* Process clustering information *# - #* *# - #**********************************# - - # identify existing model and cluster files - if model_dir is not None: - model_prefix = model_dir - else: - model_prefix = ref_db - try: - model_file = os.path.join(model_prefix, os.path.basename(model_prefix)) - model = loadClusterFit(model_file + '_fit.pkl', - model_file + '_fit.npz') - model.set_threads(threads) - except FileNotFoundError: - sys.stderr.write('Unable to locate previous model fit in ' + model_prefix + '\n') - sys.exit(1) - - # Either use strain definitions, lineage assignments or external clustering - isolateClustering = {} - # Use external clustering if specified - if external_clustering: - mode = 'external' - cluster_file = external_clustering - if cluster_file.endswith('_lineages.csv'): - suffix = "_lineages.csv" - else: - suffix = "_clusters.csv" - else: - # Load previous clusters - if previous_clustering is not None: - cluster_file = previous_clustering - mode = "clusters" - suffix = "_clusters.csv" - if cluster_file.endswith('_lineages.csv'): - mode = "lineages" - suffix = "_lineages.csv" - else: - # Identify type of clustering based on model - mode = "clusters" - suffix = "_clusters.csv" - if model.type == "lineage": - mode = "lineages" - suffix = "_lineages.csv" - cluster_file = os.path.join(model_prefix, os.path.basename(model_prefix) + suffix) - - isolateClustering = readIsolateTypeFromCsv(cluster_file, - mode = mode, - return_dict = True) - - # Add individual refinement clusters if they exist - if model.indiv_fitted: - for type, indiv_suffix in zip(['Core','Accessory'],['_core_clusters.csv','_accessory_clusters.csv']): - indiv_clustering = os.path.join(model_prefix, os.path.basename(model_prefix) + indiv_suffix) - if os.path.isfile(indiv_clustering): - indiv_isolateClustering = readIsolateTypeFromCsv(indiv_clustering, - mode = mode, - return_dict = True) - isolateClustering[type] = indiv_isolateClustering['Cluster'] - - # Join clusters with query clusters if required - if use_dense: - if query_db is not None: - if previous_query_clustering is not None: - prev_query_clustering = previous_query_clustering - else: - prev_query_clustering = os.path.join(query_db, os.path.basename(query_db) + suffix) - - queryIsolateClustering = readIsolateTypeFromCsv( - prev_query_clustering, - mode = mode, - return_dict = True) - isolateClustering = joinClusterDicts(isolateClustering, queryIsolateClustering) #*******************# #* *#