Skip to content

Commit

Permalink
Remove dependence on .mrc file extension
Browse files Browse the repository at this point in the history
  • Loading branch information
jamaliki committed Oct 18, 2023
1 parent 50cbf85 commit e05b495
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 38 deletions.
2 changes: 1 addition & 1 deletion model_angelo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
"""


__version__ = "1.0.6"
__version__ = "1.0.7"
5 changes: 4 additions & 1 deletion model_angelo/c_alpha/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import time
from itertools import product
from random import shuffle
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -156,7 +157,9 @@ def infer(args):
else:
cas = ds["cas"]

elif args.map_path.endswith("mrc"):
else:
if not args.map_path.endswith("mrc"):
warnings.warn(f"The file {args.map_path} does not end with '.mrc'\nPlease make sure it is an MRC file.")
grid_np, voxel_size, global_origin = load_mrc(args.map_path)
grid_np, voxel_size, global_origin = make_model_angelo_grid(
np.copy(grid_np), voxel_size, global_origin, target_voxel_size=1.5
Expand Down
34 changes: 16 additions & 18 deletions model_angelo/gnn/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -93,24 +94,21 @@ def infer(args):
if protein is None:
raise RuntimeError(f"File {args.struct} is not a supported file format.")

grid_data = None
if args.map.endswith("mrc"):
grid_data = load_mrc(args.map, multiply_global_origin=False)
grid_data = make_model_angelo_grid(
grid_data.grid,
grid_data.voxel_size,
grid_data.global_origin,
target_voxel_size=voxel_size,
)
grid_data = MRCObject(
grid=grid_data.grid,
voxel_size=grid_data.voxel_size,
global_origin=np.zeros((3,), dtype=np.float32),
)
if grid_data is None:
raise RuntimeError(
f"Grid volume file {args.map} is not a supported file format."
)
if not args.map.endswith("mrc"):
warnings.warn(f"The file {args.map} does not end with '.mrc'\nPlease make sure it is an MRC file.")

grid_data = load_mrc(args.map, multiply_global_origin=False)
grid_data = make_model_angelo_grid(
grid_data.grid,
grid_data.voxel_size,
grid_data.global_origin,
target_voxel_size=voxel_size,
)
grid_data = MRCObject(
grid=grid_data.grid,
voxel_size=grid_data.voxel_size,
global_origin=np.zeros((3,), dtype=np.float32),
)
# Standardize the grid to have a mean of 0 and a standard deviation of 1
grid_data = standardize_mrc(grid_data)

Expand Down
35 changes: 17 additions & 18 deletions model_angelo/gnn/inference_no_seq.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import warnings

import numpy as np
import torch
Expand Down Expand Up @@ -55,24 +56,22 @@ def infer(args):
if protein is None:
raise RuntimeError(f"File {args.struct} is not a supported file format.")

grid_data = None
if args.map.endswith("mrc"):
grid_data = load_mrc(args.map, multiply_global_origin=False)
grid_data = make_model_angelo_grid(
grid_data.grid,
grid_data.voxel_size,
grid_data.global_origin,
target_voxel_size=voxel_size,
)
grid_data = MRCObject(
grid=grid_data.grid,
voxel_size=grid_data.voxel_size,
global_origin=np.zeros((3,), dtype=np.float32),
)
if grid_data is None:
raise RuntimeError(
f"Grid volume file {args.map} is not a supported file format."
)
if not args.map.endswith("mrc"):
warnings.warn(f"The file {args.map} does not end with '.mrc'\nPlease make sure it is an MRC file.")

grid_data = load_mrc(args.map, multiply_global_origin=False)
grid_data = make_model_angelo_grid(
grid_data.grid,
grid_data.voxel_size,
grid_data.global_origin,
target_voxel_size=voxel_size,
)
grid_data = MRCObject(
grid=grid_data.grid,
voxel_size=grid_data.voxel_size,
global_origin=np.zeros((3,), dtype=np.float32),
)

# Standardize the grid to have a mean of 0 and a standard deviation of 1
grid_data = standardize_mrc(grid_data)
num_res = len(protein.rigidgroups_gt_frames)
Expand Down

0 comments on commit e05b495

Please sign in to comment.