Skip to content

Commit

Permalink
Merge pull request #10 from 3dem/minimize-memory-usage
Browse files · Browse the repository at this point in the history
Minimize memory usage
  • Loading branch information
jamaliki authored Nov 2, 2022
2 parents 37ec442 + d50b453 commit 1e5a6b3
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 90 deletions.
2 changes: 1 addition & 1 deletion model_angelo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
"""


__version__ = "0.1.0"
__version__ = "0.2"
21 changes: 17 additions & 4 deletions model_angelo/apps/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import argparse
import json
import os
import shutil
import sys

import torch
Expand Down Expand Up @@ -94,6 +95,11 @@ def add_args(parser):
default=None,
help="Inference model bundle path. If this is set, --model-bundle-name is not used."
)
advanced_args.add_argument(
"--keep-intermediate-results",
action="store_true",
help="Keep intermediate results, ie see_alpha_output and gnn_round_x_output"
)

# Below are RELION arguments, make sure to always add help=argparse.SUPPRESS

Expand Down Expand Up @@ -213,9 +219,7 @@ def main(parsed_args):
gnn_infer_args.output_dir = current_output_dir
gnn_infer_args.model_dir = gnn_model_logdir
gnn_infer_args.device = parsed_args.device

if i == total_gnn_rounds - 1:
gnn_infer_args.aggressive_pruning = True
gnn_infer_args.aggressive_pruning = True

logger.info(f"GNN model refinement round {i + 1} with args: {gnn_infer_args}")
gnn_output = gnn_infer(gnn_infer_args)
Expand All @@ -236,7 +240,16 @@ def main(parsed_args):
os.replace(raw_file_src, raw_file_dst)

os.remove(standarized_mrc_path)


if not parsed_args.keep_intermediate_output:
shutil.rmtree(ca_infer_args.output_path, ignore_errors=True)
for i in range(total_gnn_rounds):
shutil.rmtree(
os.path.join(
parsed_args.output_dir, f"gnn_output_round_{i + 1}"
)
)

print("-" * 70)
print("ModelAngelo build has been completed successfully!")
print("-" * 70)
Expand Down
14 changes: 14 additions & 0 deletions model_angelo/apps/build_no_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,11 @@ def add_args(parser):
default=None,
help="Inference model bundle path. If this is set, --model-bundle-name is not used."
)
advanced_args.add_argument(
"--keep-intermediate-results",
action="store_true",
help="Keep intermediate results, ie see_alpha_output and gnn_round_x_output"
)

# Below are RELION arguments, make sure to always add help=argparse.SUPPRESS

Expand Down Expand Up @@ -221,6 +226,15 @@ def main(parsed_args):
os.replace(hmm_profiles_src, hmm_profiles_dst)

os.remove(standarized_mrc_path)

if not parsed_args.keep_intermediate_output:
shutil.rmtree(ca_infer_args.output_path, ignore_errors=True)
for i in range(total_gnn_rounds):
shutil.rmtree(
os.path.join(
parsed_args.output_dir, f"gnn_output_round_{i + 1}"
)
)

print("-" * 70)
print("ModelAngelo build_no_seq has been completed successfully!")
Expand Down
7 changes: 7 additions & 0 deletions model_angelo/apps/eval_per_resid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,13 @@ def get_residue_fit_report(
input_cas, target_cas, max_dist, verbose, two_rounds=two_rounds, get_unmatched=True
)

target_correspondence, input_correspondence, unmatched_target_idxs, unmatched_input_idxs = (
target_correspondence.astype(np.int32),
input_correspondence.astype(np.int32),
unmatched_target_idxs.astype(np.int32),
unmatched_input_idxs.astype(np.int32),
)

input_cas_cor = input_cas[input_correspondence]
target_cas_cor = target_cas[target_correspondence]

Expand Down
21 changes: 10 additions & 11 deletions model_angelo/gnn/flood_fill.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,25 +240,24 @@ def final_results_to_cif(
def flood_fill(atom14_positions, b_factors, n_c_distance_threshold=2.1):
n_positions = atom14_positions[:, 0]
c_positions = atom14_positions[:, 2]
n_c_distances = np.linalg.norm(n_positions[:, None] - c_positions[None], axis=-1)
kdtree = cKDTree(c_positions)
b_factors_copy = np.copy(b_factors)
idxs = np.arange(len(atom14_positions))

chains = []
chain_ends = {}
while np.any(b_factors_copy != -1):
idx = np.argmax(b_factors_copy)
possible_edges = (n_c_distances[idx] < n_c_distance_threshold) * (
n_c_distances[idx] > 0
possible_indices = np.array(
kdtree.query_ball_point(
n_positions[idx],
r=n_c_distance_threshold,
return_sorted=True
)
)
got_chain = False
if np.sum(possible_edges) > 0:
idx_n_c_distances = n_c_distances[idx][possible_edges]
possible_indices = idxs[possible_edges]

sorted_indices = np.argsort(idx_n_c_distances)
possible_indices = possible_indices[sorted_indices]
possible_indices = possible_indices[possible_indices != idx]

got_chain = False
if len(possible_indices) > 0:
for possible_prev_residue in possible_indices:
if possible_prev_residue == idx:
continue
Expand Down
33 changes: 24 additions & 9 deletions model_angelo/gnn/gnn_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,20 @@ def __init__(
init_affine: torch.Tensor = None,
):
self.result_dict = {}
self.keys = [
"pred_positions",
"pred_ncac",
"cryo_edges",
"cryo_edge_logits",
"cryo_aa_logits",
"local_confidence_score",
"pred_existence_mask",
"pred_affines",
"pred_torsions",
"seq_attention_scores",
"x",
]

self.refresh(
positions=positions,
hidden_features=hidden_features,
Expand All @@ -37,15 +51,9 @@ def refresh(
hidden_features: int = 256,
init_affine: torch.Tensor = None,
):
self.result_dict = {
"pred_positions": [],
"pred_ncac": [],
"cryo_edges": [],
"cryo_edge_logits": [],
"cryo_aa_logits": [],
"local_confidence_score": [],
"pred_existence_mask": [],
}
self.result_dict = {}
for key in self.keys:
self.result_dict[key] = []

if positions is not None:
self.result_dict["x"] = torch.zeros(
Expand All @@ -60,3 +68,10 @@ def refresh(
else init_affine
).requires_grad_()
]

def to(self, device: str):
for key in self.keys:
if torch.is_tensor(self.result_dict[key]):
self.result_dict[key] = self.result_dict[key].to(device)
else:
self.result_dict[key] = [x.to(device) for x in self.result_dict[key]]
39 changes: 7 additions & 32 deletions model_angelo/gnn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,17 @@ def run_inference_on_data(
run_iters=run_iters,
seq_attention_batch_size=seq_attention_batch_size,
)
result.to("cpu")
return result


def init_empty_collate_results(num_residues, unified_seq_len, device="cpu"):
result = {}
result["counts"] = torch.zeros(num_residues, device=device)
result["edge_counts"] = torch.zeros(num_residues, num_residues, device=device)

result["pred_positions"] = torch.zeros(num_residues, 3, device=device)
result["pred_affines"] = torch.zeros(num_residues, 3, 4, device=device)
result["pred_torsions"] = torch.zeros(num_residues, 83, 2, device=device)
result["aa_logits"] = torch.zeros(num_residues, 20, device=device)
result["edges"] = torch.zeros(num_residues, num_residues, device=device)
result["local_confidence"] = torch.zeros(num_residues, device=device)
result["existence_mask"] = torch.zeros(num_residues, device=device)
result["seq_attention_scores"] = torch.zeros(
Expand All @@ -150,7 +148,8 @@ def collate_nn_results(
/ collated_results["counts"][indices[:num_pred_residues]][..., None]
)
collated_results["pred_affines"][indices[:num_pred_residues]] = get_affine(
get_affine_rot(results["pred_affines"][-1][:num_pred_residues]), curr_pos_avg
get_affine_rot(results["pred_affines"][-1][:num_pred_residues]).cpu(),
curr_pos_avg
)
collated_results["aa_logits"][indices[:num_pred_residues]] += results[
"cryo_aa_logits"
Expand All @@ -165,31 +164,10 @@ def collate_nn_results(
"seq_attention_scores"
][:num_pred_residues][..., 0]

source_idx = (
indices[results["cryo_edges"][-1][1]]
.reshape(crop_length, 20)[:num_pred_residues]
.flatten()
)
target_idx = (
indices[results["cryo_edges"][-1][0]]
.reshape(crop_length, 20)[:num_pred_residues]
.flatten()
)

collated_results["edges"][source_idx, target_idx] += results["cryo_edge_logits"][
-1
][:num_pred_residues].flatten()
collated_results["edge_counts"][source_idx, target_idx] += 1

collated_results["edges"][target_idx, source_idx] += results["cryo_edge_logits"][
-1
][:num_pred_residues].flatten()
collated_results["edge_counts"][target_idx, source_idx] += 1

protein = update_protein_gt_frames(
protein,
indices[:num_pred_residues].cpu().numpy(),
collated_results["pred_affines"][indices[:num_pred_residues]].cpu().numpy(),
indices[:num_pred_residues].numpy(),
collated_results["pred_affines"][indices[:num_pred_residues]].numpy(),
)
return collated_results, protein

Expand All @@ -213,9 +191,6 @@ def get_final_nn_results(collated_results):
final_results["seq_attention_scores"] = (
collated_results["seq_attention_scores"] / collated_results["counts"][..., None]
)
final_results["edges"] = collated_results["edges"] / collated_results[
"edge_counts"
].clamp(min=1)
final_results["local_confidence"] = collated_results["local_confidence"]
final_results["existence_mask"] = collated_results["existence_mask"]

Expand All @@ -229,7 +204,7 @@ def get_final_nn_results(collated_results):
final_results["normalized_aa_entropy"].max()
)

return dict([(k, v.detach().cpu().numpy()) for (k, v) in final_results.items()])
return dict([(k, v.numpy()) for (k, v) in final_results.items()])


def infer(args):
Expand Down Expand Up @@ -293,7 +268,7 @@ def infer(args):
collated_results = init_empty_collate_results(
num_res,
protein.unified_seq_len,
device=get_module_device(module),
device="cpu",
)

residues_left = num_res
Expand Down
43 changes: 10 additions & 33 deletions model_angelo/gnn/inference_no_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,19 +103,17 @@ def run_inference_on_data(
record_training=False,
run_iters=run_iters,
)
result.to("cpu")
return result


def init_empty_collate_results(num_residues, device="cpu"):
result = {}
result["counts"] = torch.zeros(num_residues, device=device)
result["edge_counts"] = torch.zeros(num_residues, num_residues, device=device)

result["pred_positions"] = torch.zeros(num_residues, 3, device=device)
result["pred_affines"] = torch.zeros(num_residues, 3, 4, device=device)
result["pred_torsions"] = torch.zeros(num_residues, 83, 2, device=device)
result["aa_logits"] = torch.zeros(num_residues, 20, device=device)
result["edges"] = torch.zeros(num_residues, num_residues, device=device)
result["local_confidence"] = torch.zeros(num_residues, device=device)
result["existence_mask"] = torch.zeros(num_residues, device=device)
return result
Expand All @@ -137,7 +135,8 @@ def collate_nn_results(
/ collated_results["counts"][indices[:num_pred_residues]][..., None]
)
collated_results["pred_affines"][indices[:num_pred_residues]] = get_affine(
get_affine_rot(results["pred_affines"][-1][:num_pred_residues]), curr_pos_avg
get_affine_rot(results["pred_affines"][-1][:num_pred_residues]),
curr_pos_avg
)
collated_results["aa_logits"][indices[:num_pred_residues]] += results[
"cryo_aa_logits"
Expand All @@ -149,31 +148,10 @@ def collate_nn_results(
"pred_existence_mask"
][-1][:num_pred_residues][..., 0]

source_idx = (
indices[results["cryo_edges"][-1][1]]
.reshape(crop_length, 20)[:num_pred_residues]
.flatten()
)
target_idx = (
indices[results["cryo_edges"][-1][0]]
.reshape(crop_length, 20)[:num_pred_residues]
.flatten()
)

collated_results["edges"][source_idx, target_idx] += results["cryo_edge_logits"][
-1
][:num_pred_residues].flatten()
collated_results["edge_counts"][source_idx, target_idx] += 1

collated_results["edges"][target_idx, source_idx] += results["cryo_edge_logits"][
-1
][:num_pred_residues].flatten()
collated_results["edge_counts"][target_idx, source_idx] += 1

protein = update_protein_gt_frames(
protein,
indices[:num_pred_residues].cpu().numpy(),
collated_results["pred_affines"][indices[:num_pred_residues]].cpu().numpy(),
indices[:num_pred_residues].numpy(),
collated_results["pred_affines"][indices[:num_pred_residues]].numpy(),
)
return collated_results, protein

Expand All @@ -194,9 +172,6 @@ def get_final_nn_results(collated_results):
final_results["aa_logits"] = (
collated_results["aa_logits"] / collated_results["counts"][..., None]
)
final_results["edges"] = collated_results["edges"] / collated_results[
"edge_counts"
].clamp(min=1)
final_results["local_confidence"] = collated_results["local_confidence"]
final_results["existence_mask"] = collated_results["existence_mask"]

Expand All @@ -210,7 +185,7 @@ def get_final_nn_results(collated_results):
final_results["normalized_aa_entropy"].max()
)

return dict([(k, v.detach().cpu().numpy()) for (k, v) in final_results.items()])
return dict([(k, v.numpy()) for (k, v) in final_results.items()])


def infer(args):
Expand Down Expand Up @@ -262,7 +237,7 @@ def infer(args):

collated_results = init_empty_collate_results(
num_res,
device=get_module_device(module),
device="cpu",
)

residues_left = num_res
Expand Down Expand Up @@ -300,13 +275,15 @@ def infer(args):

final_results = get_final_nn_results(collated_results)
output_path = os.path.join(args.output_dir, "output.cif")

# Aggressive pruning does not make sense here
final_results_to_cif(
final_results,
output_path,
sequences=None,
verbose=True,
print_fn=logger.info,
aggressive_pruning=args.aggressive_pruning,
aggressive_pruning=False,
)

return output_path
Expand Down

0 comments on commit 1e5a6b3

Please sign in to comment.