Skip to content

Commit

Permalink
compare only within tlc of mutation
Browse files Browse the repository at this point in the history
  • Loading branch information
JohannesKarwou committed May 17, 2023
1 parent 10793e7 commit e7587ec
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 36 deletions.
55 changes: 20 additions & 35 deletions transformato/mutate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ def _performe_linear_charge_scaling(
intermediate_factory,
mutation,
):

for lambda_value in np.linspace(1, 0, nr_of_steps + 1)[1:]:
print("####################")
print(
Expand All @@ -50,7 +49,6 @@ def _performe_linear_cc_scaling(
intermediate_factory,
mutation,
) -> int:

for lambda_value in np.linspace(1, 0, nr_of_steps + 1)[1:]:
print("####################")
print(
Expand Down Expand Up @@ -192,7 +190,6 @@ def perform_mutations(
== list_of_heavy_atoms_to_be_mutated[-1]
and configuration["simulation"]["free-energy-type"] == "asfe"
):

for lambda_value in np.linspace(
0.75, 0, nr_of_mutation_steps_lj_of_heavy_atoms + 1
):
Expand Down Expand Up @@ -236,7 +233,6 @@ def perform_mutations(
# generate terminal LJ
######################################
if not configuration["simulation"]["free-energy-type"] == "asfe":

print("####################")
print(
f"Generate terminal LJ particle in step: {i.current_step} on atoms: {[v.vdw_atom_idx for v in mutation_list['default-lj']]}"
Expand Down Expand Up @@ -280,7 +276,6 @@ class DummyRegion:
lj_default: list

def return_connecting_real_atom(self, dummy_atoms: list):

for real_atom in self.match_termin_real_and_dummy_atoms:
for dummy_atom in self.match_termin_real_and_dummy_atoms[real_atom]:
if dummy_atom in dummy_atoms:
Expand All @@ -300,7 +295,6 @@ class MutationDefinition:
steric_mutation_to_default: bool = False

def print_details(self):

print("####################")
print(f"Atoms to be mutated: {self.atoms_to_be_mutated}")
print(f"Mutated on common core: {self.common_core}")
Expand Down Expand Up @@ -380,7 +374,6 @@ def __init__(
self.dummy_region_cc1: DummyRegion

def _check_cgenff_versions(self):

cgenff_sys1 = self.system["system1"].cgenff_version
cgenff_sys2 = self.system["system2"].cgenff_version
if cgenff_sys1 == cgenff_sys2:
Expand Down Expand Up @@ -493,7 +486,6 @@ def _match_terminal_dummy_atoms_between_common_cores(
match_terminal_atoms_cc1: dict,
match_terminal_atoms_cc2: dict,
) -> Tuple[list, list]:

cc1_idx = self._substructure_match["m1"]
cc2_idx = self._substructure_match["m2"]

Expand All @@ -502,13 +494,11 @@ def _match_terminal_dummy_atoms_between_common_cores(

# iterate through the common core substracter (the order represents the matched atoms)
for idx1, idx2 in zip(cc1_idx, cc2_idx):

# if both atoms are terminal atoms connected dummy regions can be identified
if (
idx1 in match_terminal_atoms_cc1.keys()
and idx2 in match_terminal_atoms_cc2.keys()
):

connected_dummy_cc1 = list(match_terminal_atoms_cc1[idx1])
connected_dummy_cc2 = list(match_terminal_atoms_cc2[idx2])

Expand Down Expand Up @@ -543,7 +533,6 @@ def _calculate_order_of_LJ_mutations(
match_terminal_atoms: dict,
G: nx.Graph,
) -> list:

try:
from tf_routes.routes import (
_calculate_order_of_LJ_mutations_new as _calculate_order_of_LJ_mutations_with_bfs,
Expand Down Expand Up @@ -685,7 +674,6 @@ def finish_common_core(
"""

if not self.asfe:

# set the teriminal real/dummy atom indices
self._set_common_core_parameters()
# match the real/dummy atoms
Expand Down Expand Up @@ -783,7 +771,6 @@ def finish_common_core(
self.charge_compensated_ligand2_psf = psf2

else:

# all atoms should become dummy atoms in the end
central_atoms = nx.center(self.graphs["m1"])

Expand Down Expand Up @@ -821,7 +808,6 @@ def finish_common_core(
)

def calculate_common_core(self):

self.propose_common_core()
self.finish_common_core()

Expand All @@ -840,7 +826,6 @@ def _prepare_cc_for_charge_transfer(self):
[self.get_common_core_idx_mol1(), self.get_common_core_idx_mol2()],
[self.dummy_region_cc1, self.dummy_region_cc2],
):

# set `initial_charge` parameter for Mutation
for atom in psf.view[f":{tlc}"].atoms:
# charge, epsilon and rmin are directly modiefied
Expand Down Expand Up @@ -939,7 +924,6 @@ def get_idx_not_in_common_core_for_mol2(self) -> list:
return self._get_idx_not_in_common_core_for_mol("m2")

def _get_idx_not_in_common_core_for_mol(self, mol_name: str) -> list:

dummy_list_mol = [
atom.GetIdx()
for atom in self.mols[mol_name].GetAtoms()
Expand Down Expand Up @@ -1036,7 +1020,6 @@ def _return_atom_idx_from_bond_idx(self, mol: Chem.Mol, bond_idx: int):
)

def _find_connected_dummy_regions(self, mol_name: str) -> List[set]:

sub = self._get_common_core(mol_name)
#############################
# start
Expand Down Expand Up @@ -1187,7 +1170,6 @@ def _transform_common_core(self) -> list:
self.get_common_core_idx_mol1() + self.dummy_region_cc1.lj_default,
self.get_common_core_idx_mol2() + self.dummy_region_cc2.lj_default,
):

# did atom type change? if not don't add BondedMutations
atom1 = self.psfs["m1"][cc1]
atom2 = self.psfs["m2"][cc2]
Expand Down Expand Up @@ -1220,6 +1202,17 @@ def _transform_common_core(self) -> list:
logger.warning(f"Bonded parameters mutation: {bonded_terms_mutation}.")
logger.warning(f"Charge parameters mutation: {charge_mutation}.")

# in point mutations all residues are in the tlc section -> we want only
# the one where the mutation happens (should be only necessary for s1_tlc)
# TODO: make an assert statement which checks that all atoms of dummy region
# belong to the same residue!
if len(self.s1_tlc) > 4:
self.s1_tlc = (
self.psf1["waterbox"]
.atoms[self.get_idx_not_in_common_core_for_mol1()[0]]
.residue.name
)

t = CommonCoreTransformation(
self.get_common_core_idx_mol1(),
self.get_common_core_idx_mol2(),
Expand Down Expand Up @@ -1509,7 +1502,6 @@ def _get_atom_mapping(self) -> dict:
return match_atom_names_cc1_to_cc2

def _mutate_charges(self, psf: pm.charmm.CharmmPsfFile, scale: float):

# common core of psf 1 is transformed to psf 2
for ligand1_atom in psf.view[f":{self.tlc_cc1}"]:
if ligand1_atom.name not in self.atom_names_mapping:
Expand Down Expand Up @@ -1644,13 +1636,11 @@ def _mutate_atoms(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):
raise RuntimeError("No corresponding atom in cc2 found")

def _mutate_bonds(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):

logger.debug("#######################")
logger.debug("mutate_bonds")

mod_type = namedtuple("Bond", "k, req")
for ligand1_bond in psf.view[f":{self.tlc_cc1}"].bonds:

ligand1_atom1_name = ligand1_bond.atom1.name
ligand1_atom2_name = ligand1_bond.atom2.name
# all atoms of the bond must be in cc
Expand Down Expand Up @@ -1717,14 +1707,17 @@ def _mutate_bonds(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):
ligand1_bond.mod_type = mod_type(modified_k, modified_req)
logger.debug(ligand1_bond.mod_type)

if not found:
# AND statement because in point mutations the atom_names_mapping check
# might mislead. There are situations where two atom names are both in
# the CC and in the dummy regions. E.g. in tlc1: C6-H62(is the dummy atom),
# in tlc2: C6-N7-H62. Now it would search for a C6-H62 bond in cc2!
if not found and ligand2_bond.atom1.residue.name == self.tlc_cc1:
logger.critical(ligand1_bond)
raise RuntimeError(
"No corresponding bond in cc2 found: {}".format(ligand1_bond)
)

def _mutate_angles(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):

mod_type = namedtuple("Angle", "k, theteq")
for cc1_angle in psf.view[f":{self.tlc_cc1}"].angles:
ligand1_atom1_name = cc1_angle.atom1.name
Expand Down Expand Up @@ -1801,18 +1794,16 @@ def _mutate_angles(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):
logging.debug(f"New k: {modified_theteq}")

cc1_angle.mod_type = mod_type(modified_k, modified_theteq)

if not found:
# AND statement is explained in bond section
if not found and cc2_angle.atom1.residue.name == self.tlc_cc1:
logger.critical(cc1_angle)
raise RuntimeError("No corresponding angle in cc2 found")

def _mutate_torsions(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):

mod_type = namedtuple("Torsion", "phi_k, per, phase, scee, scnb")

# get all torsions present in initial topology
for original_torsion in psf.view[f":{self.tlc_cc1}"].dihedrals:

found: bool = False
original_atom1_name = original_torsion.atom1.name
original_atom2_name = original_torsion.atom2.name
Expand Down Expand Up @@ -1925,8 +1916,8 @@ def _mutate_torsions(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):
)

original_torsion.mod_type = mod_types

if not found:
# AND section is explained in bond section anologon
if not found and new_torsion.atom1.residue.name == self.tlc_cc1:
logger.critical(original_torsion)
raise RuntimeError("No corresponding torsion in cc2 found")

Expand Down Expand Up @@ -1963,7 +1954,6 @@ def mutate(self, psf: pm.charmm.CharmmPsfFile, lambda_value: float):

@staticmethod
def _modify_type_in_cc(atom: pm.Atom, psf: pm.charmm.CharmmPsfFile):

if hasattr(atom, "initial_type"):
# only change parameters
pass
Expand All @@ -1979,7 +1969,6 @@ def _modify_type_in_cc(atom: pm.Atom, psf: pm.charmm.CharmmPsfFile):

class Mutation(object):
def __init__(self, atoms_to_be_mutated: list, dummy_region: DummyRegion):

assert type(atoms_to_be_mutated) == list
self.atoms_to_be_mutated = atoms_to_be_mutated
self.dummy_region = dummy_region
Expand All @@ -1988,7 +1977,6 @@ def __init__(self, atoms_to_be_mutated: list, dummy_region: DummyRegion):
def _mutate_charge(
self, psf: pm.charmm.CharmmPsfFile, lambda_value: float, offset: int
):

total_charge = int(
round(sum([atom.initial_charge for atom in psf.view[f":{self.tlc}"].atoms]))
)
Expand Down Expand Up @@ -2019,7 +2007,6 @@ def _mutate_vdw(
offset: int,
to_default: bool,
):

if not set(vdw_atom_idx).issubset(set(self.atoms_to_be_mutated)):
raise RuntimeError(
f"Specified atom {vdw_atom_idx} is not in atom_idx list {self.atoms_to_be_mutated}. Aborting."
Expand Down Expand Up @@ -2168,7 +2155,6 @@ def _scale_rmin(atom, lambda_value: float):

@staticmethod
def _modify_type(atom, psf, atom_type_suffix: str):

if hasattr(atom, "initial_type"):
# only change parameters
pass
Expand All @@ -2195,7 +2181,6 @@ def mutate_pure_tautomers(
single_state=False,
nr_of_bonded_windows: int = 4,
):

from transformato import (
IntermediateStateFactory,
)
Expand Down
36 changes: 35 additions & 1 deletion transformato/tests/test_point_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,42 @@
os.getenv("CI") == "true",
reason="Skipping tests that cannot pass in github actions",
)
def test_setting_up_point_mutation():
def test_point_mutation_reduced_system():
configuration = load_config_yaml(
config=f"/site/raid3/johannes/h2u/data/config/H2U_1_cano-H2U_1_mod.yaml",
input_dir="/site/raid3/johannes/h2u/data",
output_dir=f"/site/raid3/johannes",
)

s1 = SystemStructure(configuration, "structure1")
s2 = SystemStructure(configuration, "structure2")
s1_to_s2 = ProposeMutationRoute(s1, s2)
s1_to_s2.propose_common_core()
s1_to_s2.finish_common_core()

mutation_list = s1_to_s2.generate_mutations_to_common_core_for_mol1()
i = IntermediateStateFactory(
system=s1,
configuration=configuration,
)

perform_mutations(
configuration=configuration,
nr_of_mutation_steps_charge=3,
nr_of_mutation_steps_cc=3,
i=i,
mutation_list=mutation_list,
)

assert s1_to_s2.get_idx_not_in_common_core_for_mol1() == [40, 43]


@pytest.mark.point_mutation
@pytest.mark.skipif(
os.getenv("CI") == "true",
reason="Skipping tests that cannot pass in github actions",
)
def test_setting_up_point_mutation():
configuration = load_config_yaml(
config=f"/site/raid3/johannes/bioinfo/data/config/cano10-psu10.yaml",
input_dir="/site/raid3/johannes/bioinfo/data/psul",
Expand Down

0 comments on commit e7587ec

Please sign in to comment.