diff --git a/model_angelo/apps/build.py b/model_angelo/apps/build.py index 26bd042..ece30b4 100644 --- a/model_angelo/apps/build.py +++ b/model_angelo/apps/build.py @@ -234,7 +234,15 @@ def main(parsed_args): gnn_infer_args.write_hmm_profiles = False gnn_infer_args.refine = False - gnn_infer_args.aggressive_pruning = True + if i == total_gnn_rounds - 1: + if parsed_args.config_path is None: + gnn_infer_args.aggressive_pruning = True + else: + gnn_infer_args.aggressive_pruning = config["gnn_infer_args"][ + "aggressive_pruning" + ] + else: + gnn_infer_args.aggressive_pruning = False logger.info( f"GNN model refinement round {i + 1} with args: {gnn_infer_args}" diff --git a/model_angelo/apps/build_no_seq.py b/model_angelo/apps/build_no_seq.py index ece968e..a43de18 100644 --- a/model_angelo/apps/build_no_seq.py +++ b/model_angelo/apps/build_no_seq.py @@ -196,7 +196,14 @@ def main(parsed_args): gnn_infer_args.refine = False if i == total_gnn_rounds - 1: - gnn_infer_args.aggressive_pruning = True + if parsed_args.config_path is None: + gnn_infer_args.aggressive_pruning = True + else: + gnn_infer_args.aggressive_pruning = config["gnn_infer_args"][ + "aggressive_pruning" + ] + else: + gnn_infer_args.aggressive_pruning = False logger.info( f"GNN model refinement round {i + 1} with args: {gnn_infer_args}" @@ -225,7 +232,7 @@ def main(parsed_args): print( f"The HMM profiles are available in the directory: {hmm_profiles_dst}\n" f"They are named according to the chains found in {raw_file_dst}\n" - f"For example, chain A's profile is in {os.path.join(hmm_profiles_dst, 'A.hmm')}" + f"For example, chain A's profile is in {os.path.join(hmm_profiles_dst, 'A_u.hmm')}" ) print( f"You can use model_angelo hmm_search to search these HMM profiles against a database" diff --git a/model_angelo/apps/hmm_search.py b/model_angelo/apps/hmm_search.py index 6ec3d52..399fd1e 100644 --- a/model_angelo/apps/hmm_search.py +++ b/model_angelo/apps/hmm_search.py @@ -150,6 +150,7 @@ def main(parsed_args): os.makedirs(parsed_args.output_dir, exist_ok=True) + hmms = [(k.split("_")[0], v) for k, v in hmms] pruned_hmms = [k for k in hmms if k[1].alphabet == alphabet] try: diff --git a/model_angelo/apps/refine.py b/model_angelo/apps/refine.py index d4fd2b0..6111ef6 100644 --- a/model_angelo/apps/refine.py +++ b/model_angelo/apps/refine.py @@ -186,7 +186,12 @@ def main(parsed_args): gnn_infer_args.write_hmm_profiles = parsed_args.write_hmm_profiles gnn_infer_args.refine = True - gnn_infer_args.aggressive_pruning = True + if parsed_args.config_path is None: + gnn_infer_args.aggressive_pruning = True + else: + gnn_infer_args.aggressive_pruning = config["gnn_infer_args"][ + "aggressive_pruning" + ] logger.info(f"GNN model refinement round with args: {gnn_infer_args}") gnn_output = gnn_infer(gnn_infer_args) @@ -217,7 +222,7 @@ def main(parsed_args): print( f"The HMM profiles are available in the directory: {hmm_profiles_dst}\n" f"They are named according to the chains found in {file_dst}\n" - f"For example, chain A's profile is in {os.path.join(hmm_profiles_dst, 'A.hmm')}" + f"For example, chain A's profile is in {os.path.join(hmm_profiles_dst, 'A_u.hmm')}" ) print( f"You can use model_angelo hmm_search to search these HMM profiles against a database"