Skip to content

Commit

Permalink
Merge pull request #10 from DeepRank/unit_test
Browse files Browse the repository at this point in the history
Add unit test for transform
  • Loading branch information
CunliangGeng authored Oct 25, 2019
2 parents fdfde48 + c96fdd9 commit 2bebbfb
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 4,921 deletions.
4,834 changes: 0 additions & 4,834 deletions pdb2sql/5hvd.pdb

This file was deleted.

11 changes: 6 additions & 5 deletions pdb2sql/StructureSimilarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ def __init__(self,decoy,ref,verbose=False):
self.decoy = decoy
self.ref = ref
self.verbose = verbose
self.origin = [0., 0., 0.]


################################################################################################
Expand Down Expand Up @@ -155,7 +156,7 @@ def compute_lrmsd_fast(self,lzone=None,method='svd',check=True):
U = self.get_rotation_matrix(xyz_decoy_long,xyz_ref_long,method=method)

# rotate the entire fragment
xyz_decoy_short = transform.rotation_matrix(xyz_decoy_short,U,center=False)
xyz_decoy_short = transform.rotate(xyz_decoy_short,U, center=self.origin)

# compute the RMSD
return self.get_rmsd(xyz_decoy_short,xyz_ref_short)
Expand Down Expand Up @@ -293,7 +294,7 @@ def compute_irmsd_fast(self,izone=None,method='svd',cutoff=10,check=True):
U = self.get_rotation_matrix(xyz_contact_decoy,xyz_contact_ref,method=method)

# rotate the entire fragment
xyz_contact_decoy = transform.rotation_matrix(xyz_contact_decoy,U,center=False)
xyz_contact_decoy = transform.rotate(xyz_contact_decoy,U,center=self.origin)

# return the RMSD
return self.get_rmsd(xyz_contact_decoy,xyz_contact_ref)
Expand Down Expand Up @@ -551,7 +552,7 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'):
U = self.get_rotation_matrix(xyz_decoy_long,xyz_ref_long,method=method)

# rotate the entire fragment
xyz_decoy_short = transform.rotation_matrix(xyz_decoy_short,U,center=False)
xyz_decoy_short = transform.rotate(xyz_decoy_short, U, center=self.origin)


# compute the RMSD
Expand All @@ -569,7 +570,7 @@ def compute_lrmsd_pdb2sql(self,exportpath=None,method='svd'):
xyz_decoy += tr_decoy

# rotate decoy
xyz_decoy = transform.rotation_matrix(xyz_decoy,U,center=False)
xyz_decoy = transform.rotate(xyz_decoy, U, center=self.origin)

# update the sql database
sql_decoy.update_column('x',xyz_decoy[:,0])
Expand Down Expand Up @@ -726,7 +727,7 @@ def compute_irmsd_pdb2sql(self,cutoff=10,method='svd',izone=None,exportpath=None
U = self.get_rotation_matrix(xyz_contact_decoy,xyz_contact_ref,method=method)

# rotate the entire fragment
xyz_contact_decoy = transform.rotation_matrix(xyz_contact_decoy,U,center=False)
xyz_contact_decoy = transform.rotate(xyz_contact_decoy, U, center=self.origin)

# compute the RMSD
irmsd = self.get_rmsd(xyz_contact_decoy,xyz_contact_ref)
Expand Down
2 changes: 1 addition & 1 deletion pdb2sql/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.0'
__version__ = '0.2.1'
174 changes: 93 additions & 81 deletions pdb2sql/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,73 +8,51 @@
definition of the data set.
'''

def get_rot_axis_angle(seed=None):
"""Get the rotation angle/axis.
Args:
seed(int): random seed for numpy
Returns:
list(float): axis of rotation
float: angle of rotation
"""
# define the axis
# uniform distribution on a sphere
# http://mathworld.wolfram.com/SpherePointPicking.html
if seed != None:
np.random.seed(seed)

u1, u2 = np.random.rand(), np.random.rand()
teta, phi = np.arccos(2 * u1 - 1), 2 * np.pi * u2
axis = [np.sin(teta) * np.cos(phi),
np.sin(teta) * np.sin(phi),
np.cos(teta)]

# and the rotation angle
angle = -np.pi + np.pi * np.random.rand()

return axis, angle


########################################################################
# Translation
########################################################################
def translation(db, vect, **kwargs):
xyz = _get_xyz(db, **kwargs)
xyz += vect
_update(db, xyz, **kwargs)


########################################################################
# Rotation using axis–angle presentation
# see https://en.wikipedia.org/wiki/Rotation_matrix#Rotation_matrix_from_axis_and_angle
########################################################################
def rot_axis(db, axis, angle, **kwargs):
xyz = _get_xyz(db, **kwargs)
xyz = rot_xyz_around_axis(xyz, axis, angle)
_update(db, xyz, **kwargs)


def rot_euler(db, alpha, beta, gamma, **kwargs):
"""Rotate molecule from Euler rotation axis.
def get_rot_axis_angle(seed=None):
"""Get the rotation angle and axis.
Args:
alpha (float): angle of rotation around the x axis
beta (float): angle of rotation around the y axis
gamma (float): angle of rotation around the z axis
**kwargs: keyword argument to select the atoms.
See pdb2sql.get()
"""
xyz = _get_xyz(db, **kwargs)
xyz = _rotation_euler(xyz, alpha, beta, gamma)
_update(db, xyz, **kwargs)
seed(int): random seed for numpy
Returns:
list(float): axis of rotation
float: angle of rotation
"""
if seed is not None:
np.random.seed(seed)

def rot_mat(db, mat, **kwargs):
"""Rotate molecule from a rotation matrix.
# define the rotation axis
# uniform distribution on a sphere
# eq1,2 in http://mathworld.wolfram.com/SpherePointPicking.html
u1, u2 = np.random.rand(), np.random.rand()
theta = 2 * np.pi * u1 # [0, 2*pi)
phi = np.arccos(2 * u2 - 1) # [0, pi]
# eq19 in http://mathworld.wolfram.com/SphericalCoordinates.html
axis = [np.sin(phi) * np.cos(theta),
np.sin(phi) * np.sin(theta),
np.cos(phi)]

Args:
mat (np.array): 3x3 rotation matrix
**kwargs: keyword argument to select the atoms.
See pdb2sql.get()
"""
xyz = _get_xyz(db, **kwargs)
xyz = _rotation_matrix(xyz, mat)
_update(db, xyz, **kwargs)
# define the rotation angle
angle = 2 * np.pi * np.random.rand()

return axis, angle

def rot_xyz_around_axis(xyz, axis, angle, center=None):
"""Get the rotated xyz.
Expand All @@ -89,17 +67,11 @@ def rot_xyz_around_axis(xyz, axis, angle, center=None):
Returns:
np.array: rotated xyz coordinates
"""

# check center
if center is None:
center = np.mean(xyz, 0)

# get the data
ct, st = np.cos(angle), np.sin(angle)
ux, uy, uz = axis

# definition of the rotation matrix
# see https://en.wikipedia.org/wiki/Rotation_matrix
rot_mat = np.array([[ct + ux**2 * (1 - ct),
ux * uy * (1 - ct) - uz * st,
ux * uz * (1 - ct) + uy * st],
Expand All @@ -111,51 +83,91 @@ def rot_xyz_around_axis(xyz, axis, angle, center=None):
ct + uz**2 * (1 - ct)]])

# apply the rotation
return np.dot(rot_mat, (xyz - center).T).T + center
return rotate(xyz, rot_mat, center)

########################################################################
# Rotation using Euler anlges
# see https://en.wikipedia.org/wiki/Rotation_matrix#General_rotations
########################################################################

def _rotation_euler(xyz, alpha, beta, gamma):
def rot_euler(db, alpha, beta, gamma, **kwargs):
"""Rotate molecule from Euler rotation axis.
Args:
alpha (float): angle of rotation around the x axis
beta (float): angle of rotation around the y axis
gamma (float): angle of rotation around the z axis
**kwargs: keyword argument to select the atoms.
See pdb2sql.get()
"""
xyz = _get_xyz(db, **kwargs)
xyz = rotation_euler(xyz, alpha, beta, gamma)
_update(db, xyz, **kwargs)

def rotation_euler(xyz, alpha, beta, gamma, center=None):

# precomte the trig
ca, sa = np.cos(alpha), np.sin(alpha)
cb, sb = np.cos(beta), np.sin(beta)
cg, sg = np.cos(gamma), np.sin(gamma)

# get the center of the molecule
xyz0 = np.mean(xyz, 0)

# rotation matrices
rx = np.array([[1, 0, 0], [0, ca, -sa], [0, sa, ca]])
ry = np.array([[cb, 0, sb], [0, 1, 0], [-sb, 0, cb]])
rz = np.array([[cg, -sg, 0], [sg, cs, 0], [0, 0, 1]])
rz = np.array([[cg, -sg, 0], [sg, cg, 0], [0, 0, 1]])

rot_mat = np.dot(rz, np.dot(ry, rz))
# get rotation matrix
rot_mat = np.dot(rz, np.dot(ry, rx))

# apply the rotation
return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
return rotate(xyz, rot_mat, center)

########################################################################
# Rotation using provided rotation matrix
########################################################################

def rotation_matrix(xyz, rot_mat, center=True):
if center:
xyz0 = np.mean(xyz)
return np.dot(rot_mat, (xyz - xyz0).T).T + xyz0
else:
return np.dot(rot_mat, (xyz).T).T
def rot_mat(db, mat, **kwargs):
"""Rotate molecule from a rotation matrix.
Args:
mat (np.array): 3x3 rotation matrix
**kwargs: keyword argument to select the atoms.
See pdb2sql.get()
"""
xyz = _get_xyz(db, **kwargs)
xyz = rotate(xyz, mat)
_update(db, xyz, **kwargs)

def _get_xyz(db, **kwargs):
return np.array(db.get('x,y,z', **kwargs))
def rotate(xyz, rot_mat, center=None):
"""[summary]
Args:
xyz(np.ndarray): x,y,z coordinates
rot_mat(np.ndarray): rotation matrix
center (list or np.ndarray, optional): rotation center.
Defaults to None, i.e. using molecule center as rotation
center.
def _update(db, xyz, **kwargs):
db.update('x,y,z', xyz, **kwargs)
Raises:
TypeError: Rotation center must be list or 1D np.ndarray.
Returns:
np.ndarray: x,y,z coordinates after rotation
"""
# the default rotation center is the center of molecule itself.
if center is None:
center = np.mean(xyz, 0)

if __name__ == "__main__":
if not isinstance(center, (list, np.ndarray)):
raise TypeError("Rotation center must be list or 1D np.ndarray")

t0 = time()
db = pdb2sql('5hvd.pdb')
print('SQL %f' % (time() - t0))
return np.dot(rot_mat, (xyz - center).T).T + center

tr = np.array([1, 2, 3])
translation(db, tr, chainID='A')
########################################################################
# helper functions
########################################################################
def _get_xyz(db, **kwargs):
return np.array(db.get('x,y,z', **kwargs))

def _update(db, xyz, **kwargs):
db.update('x,y,z', xyz, **kwargs)
6 changes: 6 additions & 0 deletions test/pdb/dummy_transform.pdb
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
ATOM 1 N THR A 1 1.000 0.000 0.000 1.00 0.69 N
ATOM 2 CA THR A 1 -1.000 0.000 0.000 1.00 0.50 C
ATOM 3 C THR A 1 0.000 1.000 0.000 1.00 0.45 C
ATOM 4 O THR A 1 0.000 -1.000 0.000 1.00 0.69 O
ATOM 5 CB THR A 1 0.000 0.000 1.000 1.00 0.50 C
ATOM 6 H1 THR A 1 0.000 0.000 -1.000 1.00 0.45 H
Loading

0 comments on commit 2bebbfb

Please sign in to comment.