Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update __init__.py #528

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
196 changes: 149 additions & 47 deletions mmtbx/refinement/ensemble_refinement/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,42 +2,49 @@
from __future__ import absolute_import, division, print_function
import mmtbx.solvent.ensemble_ordered_solvent as ensemble_ordered_solvent
from mmtbx.refinement.ensemble_refinement import ensemble_utils
from mmtbx.refinement import refinement_flags
from mmtbx.dynamics import ensemble_cd
import mmtbx.tls.tools as tls_tools
import mmtbx.command_line
import mmtbx.utils
import mmtbx.model
import mmtbx.maps
from mmtbx import conformation_dependent_library as cdl
from mmtbx.conformation_dependent_library import cdl_setup
from mmtbx.command_line import validation_summary
from iotbx.option_parser import iotbx_option_parser
from iotbx import pdb
import iotbx.phil
import iotbx
from cctbx import geometry_restraints
from cctbx.array_family import flex
from cctbx import miller
from cctbx import adptbx
from cctbx import xray
import scitbx.math
from libtbx.utils import Sorry, user_plus_sys_time, multi_out, show_total_time
from libtbx import adopt_init_args, slots_getstate_setstate
from libtbx.str_utils import format_value, make_header
from libtbx import runtime_utils
from libtbx import easy_mp
import libtbx.load_env
from mmtbx.validation import rotalyze
from mmtbx.rotamer import sidechain_angles
from phenix import phenix_info
from six.moves import cStringIO as StringIO
from six.moves import cPickle as pickle
import random
import gzip
import math
import time, os
import os
import sys
from six.moves import range

# these supersede the defaults in included scopes
# Revert to defaults
# ensemble_refinement.mask.ignore_hydrogens = False
# ensemble_refinement.mask.n_radial_shells = 1
# ensemble_refinement.mask.radial_shell_width = 1.5

customization_params = iotbx.phil.parse("""
ensemble_refinement.mask.ignore_hydrogens = False
ensemble_refinement.mask.n_radial_shells = 1
ensemble_refinement.mask.radial_shell_width = 1.5
ensemble_refinement.den.kappa_burn_in_cycles = 0
ensemble_refinement.cartesian_dynamics.number_of_steps = 10
ensemble_refinement.ensemble_ordered_solvent.b_iso_min = 0.0
ensemble_refinement.ensemble_ordered_solvent.b_iso_max = 100.0
Expand Down Expand Up @@ -65,6 +72,13 @@
.type = bool
.help = Use protein atoms thermostat
}
den_restraints = True
.type = bool
.help = 'Use DEN restraints'
den
{
include scope mmtbx.den.den_params
}
update_sigmaa_rfree = 0.001
.type = float
.help = test function
Expand Down Expand Up @@ -133,6 +147,11 @@
.help = 'The fraction of atoms to include in TLS fitting'
.short_caption = Fraction of atoms to include in TLS fitting
.style = bold
import_tls_pdb = None
.type = str
.help = 'PDB path to import TLS from external structure'
.short_caption = PDB path to import TLS from external structure
.style = bold
max_ptls_cycles = 25
.type = int
.help = 'Maximum cycles to use in TLS fitting; TLS will stop prior to this if convergence is reached'
Expand Down Expand Up @@ -320,8 +339,6 @@ def __init__(self, fmodel,
ptls,
run_number=None):
adopt_init_args(self, locals())
# self.params = params.extract().ensemble_refinement

if self.params.target_name in ['ml', 'mlhl'] :
self.fix_scale = False
else:
Expand Down Expand Up @@ -369,6 +386,7 @@ def __init__(self, fmodel,
* self.cdp.number_of_steps * self.cdp.time_step, file=log)
#
print("\nAcquisition block", file=log)
print(" Number blocks : ", self.params.number_of_acquisition_periods, file=log)
print(" Number Tx periods : ", self.params.acquisition_block_n_tx, file=log)
print(" Number macro cycles : ", self.acquisition_block_macro_cycles, file=log)
print(" Time (ps) : ", self.acquisition_block_macro_cycles \
Expand Down Expand Up @@ -426,7 +444,7 @@ def __init__(self, fmodel,
max_number_of_bins = 999).show_rfactors_targets_in_bins(out = self.log)

if self.params.target_name in ['ml', 'mlhl'] :
#Must be called before reseting ADPs
# Must be called before reseting ADPs
if self.params.scale_wrt_n_calc_start:
make_header("Calculate Ncalc and restrain to scale kn", out = self.log)
self.fmodel_running.n_obs_n_calc(update_nobs_ncalc = True)
Expand All @@ -446,30 +464,40 @@ def __init__(self, fmodel,
self.target_k1 = self.fmodel_running.scale_k1()
self.update_normalisation_factors()
else:
make_header("Calculate and fix scale of Ncalc", out = self.log)
self.fmodel_running.n_obs_n_calc(update_nobs_ncalc = True)
make_header("Calculate and fix scale of Ncalc", out=self.log)
self.fmodel_running.n_obs_n_calc(update_nobs_ncalc=True)
print("Fix Ncalc scale : True", file=self.log)
print("Sum current Ncalc : {0:5.3f}".format(sum(self.fmodel_running.n_calc)), file=self.log)

#Set ADP model
# Set ADP model
self.tls_manager = er_tls_manager()
self.setup_tls_selections(tls_group_selection_strings = self.params.tls_group_selections)
self.fit_tls(input_model = self.model)
# Fit pTLS to starting atomic model
if self.params.import_tls_pdb is None:
self.setup_tls_selections(
tls_group_selection_strings=self.params.tls_group_selections)
self.fit_tls(input_model=self.model)
# Import TLS from reference model
else:
fit_tlsos, fit_tls_strings = self.import_tls_selections()
self.setup_tls_selections(tls_group_selection_strings=fit_tls_strings)
self.model.tls_groups.tlsos = fit_tlsos
self.tls_manager.tls_operators = fit_tlsos
# Assign solvent to TLS groups
self.assign_solvent_tls_groups()

#Set occupancies to 1.0
# Set occupancies to 1.0
if self.params.set_occupancies:
make_header("Set occupancies to 1.0", out = self.log)
self.model.get_xray_structure().set_occupancies(
value = 1.0)
self.model.show_occupancy_statistics(out = self.log)
#Initiates running average SFs
# Initiates running average SFs
self.er_data.f_calc_running = self.fmodel_running.f_calc().data().deep_copy()
#self.fc_running_ave = self.fmodel_running.f_calc()
self.fc_running_ave = self.fmodel_running.f_calc().deep_copy()

#Initial sigmaa array, required for ML target function
#Set eobs and ecalc normalization factors in Fmodel, required for ML
# Initial sigmaa array, required for ML target function
# Set eobs and ecalc normalization factors in Fmodel, required for ML
if self.params.target_name in ['ml', 'mlhl'] :
self.sigmaa_array = self.fmodel_running.sigmaa().sigmaa().data()
self.best_r_free = self.fmodel_running.r_free()
Expand Down Expand Up @@ -499,10 +527,9 @@ def __init__(self, fmodel,
self.update_normalisation_factors()

# Ordered Solvent Update
if self.params.ordered_solvent_update \
and (self.macro_cycle == 1\
or self.macro_cycle%self.params.ordered_solvent_update_cycle == 0):
self.ordered_solvent_update()
if self.params.ordered_solvent_update:
if self.macro_cycle == 1 or self.macro_cycle%self.params.ordered_solvent_update_cycle == 0:
self.ordered_solvent_update()

xrs_previous = self.model.get_xray_structure().deep_copy_scatterers()
assert self.fmodel_running.xray_structure is self.model.get_xray_structure()
Expand All @@ -515,6 +542,69 @@ def __init__(self, fmodel,
else:
cdp_verbose = -1

if self.params.den_restraints:
if self.macro_cycle == 1:
make_header("Create DEN restraints", out = self.log)
# Update den manager due to solvent chain changes from start model
pdb_hierarchy = self.model.get_hierarchy()
den_manager = mmtbx.den.den_restraints(
pdb_hierarchy = pdb_hierarchy,
pdb_hierarchy_ref = None,
params = self.params.den,
log = self.log)
self.model.restraints_manager.geometry.den_manager = den_manager
print(
"DEN weight : ",
self.model.restraints_manager.geometry.den_manager.weight,
file=self.log)
print(
"DEN gamma : ",
self.model.restraints_manager.geometry.den_manager.gamma,
file=self.log)
#
den_seed = self.params.random_seed
flex.set_random_seed(value=den_seed)
random.seed(den_seed)
self.model.restraints_manager.geometry.den_manager.build_den_proxies(
pdb_hierarchy=pdb_hierarchy)
self.model.restraints_manager.geometry.den_manager.build_den_restraints()
self.model.restraints_manager.geometry.den_manager.current_cycle = 1
sites_cart = self.model.get_xray_structure().sites_cart()
if self.params.verbose > 0:
print(
self.model.restraints_manager.geometry.den_manager.show_den_summary(
sites_cart=sites_cart),
file=self.log)

else:
# Reassign random pairs
if self.macro_cycle % 500 == 0:
make_header("Create DEN restraints", out = self.log)
den_seed += 1
flex.set_random_seed(value=den_seed)
pdb_hierarchy = self.model.get_hierarchy()
den_manager = mmtbx.den.den_restraints(
pdb_hierarchy = pdb_hierarchy,
pdb_hierarchy_ref = None,
params = self.params.den,
log = self.log)
self.model.restraints_manager.geometry.den_manager = den_manager
self.model.restraints_manager.geometry.den_manager.build_den_proxies(
pdb_hierarchy=pdb_hierarchy)
self.model.restraints_manager.geometry.den_manager.build_den_restraints()
self.model.restraints_manager.geometry.den_manager.current_cycle = 1
sites_cart = self.model.get_xray_structure().sites_cart()
if self.params.verbose > 0:
print(
self.model.restraints_manager.geometry.den_manager.show_den_summary(
sites_cart=sites_cart),
file=self.log)

# Update eq distances per macro cycle
self.model.restraints_manager.geometry.den_manager.current_cycle = 1
self.model.restraints_manager.geometry.den_manager.update_eq_distances(
sites_cart=self.model.get_xray_structure().sites_cart())

cd_manager = ensemble_cd.cartesian_dynamics(
structure = self.model.get_xray_structure(),
restraints_manager = self.model.restraints_manager,
Expand All @@ -540,6 +630,15 @@ def __init__(self, fmodel,
self.reset_velocities = False
self.cmremove = False

# Update CDL restraints
cdl_proxies = cdl_setup.setup_restraints(
self.model.restraints_manager.geometry,
verbose=True)
cdl.update_restraints(self.model.get_hierarchy(),
geometry=self.model.restraints_manager.geometry,
cdl_proxies=cdl_proxies,
verbose=True)

#Calc rolling average KE energy
self.kinetic_energy_running_average()
#Show KE stats
Expand Down Expand Up @@ -888,6 +987,16 @@ def update_sigmaa(self):
self.best_r_free = self.fmodel_running.r_free()
print("|"+"-"*77+"|\n", file=self.log)

def import_tls_selections(self):
make_header("Import External TLS", out = self.log)
print('External TLS model: ' + self.params.import_tls_pdb, file=self.log)
pdb_import_tls = self.params.import_tls_pdb
pdb_tls_inp = iotbx.pdb.hierarchy.input(file_name=pdb_import_tls)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please use iotbx.pdb.input() instead, preferably checking if the file exists.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still this, with the explanation in the previous message.

tls_params = pdb_tls_inp.input.extract_tls_params(pdb_tls_inp.hierarchy)
fit_tlsos = [tls_tools.tlso(t=o.t, l=o.l, s=o.s, origin=o.origin) for o in tls_params.tls_params]
tls_strings = [o.selection_string for o in tls_params.tls_params]
return fit_tlsos, tls_strings

def setup_tls_selections(self, tls_group_selection_strings):
make_header("Generating TLS selections from input parameters (not including solvent)", out = self.log)
model_no_solvent = self.model.deep_copy()
Expand Down Expand Up @@ -948,7 +1057,6 @@ def setup_tls_selections(self, tls_group_selection_strings):
tls_no_sol_no_hd_selections = mmtbx.utils.get_atom_selections(
model = model_no_solvent,
selection_strings = tls_no_hd_selection_strings)

#
assert self.tls_manager is not None
self.tls_manager.tls_selection_strings_no_sol = tls_group_selection_strings
Expand All @@ -964,7 +1072,7 @@ def setup_tls_selections(self, tls_group_selection_strings):
selection_strings = self.tls_manager.tls_selection_strings_no_sol,
tlsos = self.tls_manager.tls_operators)

def fit_tls(self, input_model, verbose = False):
def fit_tls(self, input_model, verbose=False):
make_header("Fit TLS from reference model", out = self.log)
model_copy = input_model.deep_copy()
model_copy = model_copy.remove_solvent()
Expand Down Expand Up @@ -1135,7 +1243,8 @@ def tls_parameters_update(self):
update_f_mask = False)

def assign_solvent_tls_groups(self):
self.model.get_xray_structure().convert_to_anisotropic(selection = self.model.solvent_selection())
self.model.get_xray_structure().convert_to_anisotropic(
selection=self.model.solvent_selection())
self.fmodel_running.update_xray_structure(
xray_structure = self.model.get_xray_structure(),
update_f_calc = False,
Expand Down Expand Up @@ -1183,6 +1292,7 @@ def kinetic_energy_running_average(self):
= (self.a_prime * self.er_data.ke_protein_running) + ( (1-self.a_prime) * ke)

def ordered_solvent_update(self):
make_header("Update ordered solvent", out = self.log)
ensemble_ordered_solvent_manager = ensemble_ordered_solvent.manager(
model = self.model,
fmodel = self.fmodel_running,
Expand Down Expand Up @@ -1432,7 +1542,6 @@ def write_ensemble_pdb(self, out):
pr = "REMARK 3"
print(pr, file=out)
print("REMARK 3 TIME-AVERAGED ENSEMBLE REFINEMENT.", file=out)
from phenix import phenix_info # FIXME ???
ver, tag = phenix_info.version_and_release_tag(f = out)
if(ver is None):
prog = " PROGRAM : PHENIX (phenix.ensemble_refinement)"
Expand Down Expand Up @@ -1686,6 +1795,9 @@ def run(args, command_name = "phenix.ensemble_refinement", out=None,
new_file_object=log_file)
timer = user_plus_sys_time()
mmtbx.utils.print_programs_start_header(log=log, text=command_name)
#
make_header("Ensemble Refinement 2020", out=log)
#
make_header("Ensemble refinement parameters", out = log)
working_phil.show(out = log)
make_header("Model and data statistics", out = log)
Expand Down Expand Up @@ -1739,22 +1851,16 @@ def run(args, command_name = "phenix.ensemble_refinement", out=None,
log = log)

# Refinement flags
# Worst hack I've ever seen! No wonder ensemble refinement is semi-broken!
class rf:
def __init__(self, size):
self.individual_sites = True
self.individual_adp = False
self.sites_individual = flex.bool(size, True)
self.sites_torsion_angles = None
self.torsion_angles = None
self.adp_individual_iso = None
self.adp_individual_aniso = None
def inflate(self, **keywords): pass
def select_detached(self, **keywords): pass

refinement_flags = rf(size = model.get_number_of_atoms())

model.set_refinement_flags(refinement_flags)
rf = refinement_flags.manager(
individual_sites = True,
individual_adp = False,
sites_individual = flex.bool(model.get_number_of_atoms(), True),
sites_torsion_angles = None,
torsion_angles = None,
# den = er_params.den_restraints,
adp_individual_iso = None,
adp_individual_aniso = None)
model.set_refinement_flags(rf)
model.get_restraints_manager()

# Geometry file
Expand Down Expand Up @@ -1946,7 +2052,6 @@ def __init__(self,
self.directory = os.getcwd()
self.validation = None
if (validate):
from mmtbx.command_line import validation_summary
self.validation = validation_summary.run(
args=[self.pdb_file],
out=log)
Expand Down Expand Up @@ -2004,9 +2109,6 @@ def validate_params(params):
return params

# =============================================================================
from mmtbx.validation import rotalyze
from mmtbx.rotamer import sidechain_angles

def calculate_chi_angles(model=None):
'''

Expand Down