w_postanalysis_matrix.py
from __future__ import print_function, division; __metaclass__ = type
import sys, logging

import numpy as np
import scipy.sparse as sp

import westpa
from westpa import h5io
from west.data_manager import weight_dtype
from west.data_manager import seg_id_dtype
from westpa.binning import index_dtype

from westtools import (WESTTool, WESTDataReader, IterRangeSelection,
                       ProgressIndicatorComponent)

from postanalysis import stats_process

log = logging.getLogger('westtools.w_postanalysis_matrix')
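
# Typical invocation (a sketch; the -a/-o options are defined in add_args()
# below, and the iteration-range options come from IterRangeSelection):
#   w_postanalysis_matrix.py -a assign.h5 -o flux_matrices.h5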


class MatrixRw(WESTTool):
    prog = 'w_postanalysis_matrix'
    description = '''\
Generate a colored transition matrix from a WE assignment file. The subsequent
analysis requires that the assignments be calculated using only the initial and
final time points of each trajectory segment, which may require downsampling the
h5file generated by a WE simulation. In the future, w_assign may be enhanced to
optionally generate the necessary assignment file from an h5file with intermediate
time points. Additionally, this analysis is currently only valid for simulations
performed under either equilibrium or steady-state conditions, without recycling
at target states.
'''

    def __init__(self):
        super(MatrixRw, self).__init__()

        self.progress = ProgressIndicatorComponent()
        self.data_reader = WESTDataReader()
        self.iter_range = IterRangeSelection()
        self.output_file = None
        self.assignments_file = None

        self.default_output_file = 'flux_matrices.h5'
        self.window_size = None

    def add_args(self, parser):
        self.data_reader.add_args(parser)
        self.iter_range.add_args(parser)

        iogroup = parser.add_argument_group('input/output options')
        iogroup.add_argument('-a', '--assignments', default='assign.h5',
                             help='''Bin assignments and macrostate definitions are in ASSIGNMENTS
                             (default: %(default)s).''')
        iogroup.add_argument('-o', '--output', dest='output', default=self.default_output_file,
                             help='''Store results in OUTPUT (default: %(default)s).''')

        self.progress.add_args(parser)

    def process_args(self, args):
        self.progress.process_args(args)
        self.assignments_file = h5io.WESTPAH5File(args.assignments, 'r')
        self.data_reader.process_args(args)
        with self.data_reader:
            self.iter_range.process_args(args)

        self.output_file = h5io.WESTPAH5File(args.output, 'w', creating_program=True)
        h5io.stamp_creator_data(self.output_file)

        if not self.iter_range.check_data_iter_range_least(self.assignments_file):
            raise ValueError('assignments do not span the requested iterations')

    def go(self):
        pi = self.progress.indicator
        pi.new_operation('Initializing')
        with pi:
            self.data_reader.open('r')
            nbins = self.assignments_file.attrs['nbins']

            state_labels = self.assignments_file['state_labels'][...]
            state_map = self.assignments_file['state_map'][...]
            nstates = len(state_labels)

            start_iter, stop_iter = self.iter_range.iter_start, self.iter_range.iter_stop  # h5io.get_iter_range(self.assignments_file)
            iter_count = stop_iter - start_iter

            nfbins = nbins * nstates
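            # Each "colored" index encodes both a bin and the macrostate label of
            # the trajectory occupying it (index = nstates * bin + state; see the
            # transformation of bin_assignments below), so the labeled flux matrix
            # has nbins * nstates rows and columns.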

            flux_shape = (iter_count, nfbins, nfbins)
            pop_shape = (iter_count, nfbins)

            h5io.stamp_iter_range(self.output_file, start_iter, stop_iter)

            bin_populations_ds = self.output_file.create_dataset('bin_populations', shape=pop_shape, dtype=weight_dtype)
            h5io.stamp_iter_range(bin_populations_ds, start_iter, stop_iter)
            h5io.label_axes(bin_populations_ds, ['iteration', 'bin'])

            flux_grp = self.output_file.create_group('iterations')
            self.output_file.attrs['nrows'] = nfbins
            self.output_file.attrs['ncols'] = nfbins

            fluxes = np.empty(flux_shape[1:], weight_dtype)
            populations = np.empty(pop_shape[1:], weight_dtype)
            trans = np.empty(flux_shape[1:], np.int64)
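            # These three buffers hold a single iteration's worth of statistics;
            # calc_stats() zeroes and refills them on each pass through the loop.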

            # Check to make sure this isn't a data set with target states
            tstates = self.data_reader.data_manager.get_target_states(0)
            if len(tstates) > 0:
                raise ValueError('Postanalysis reweighting does not support WE simulations run under recycling conditions')

            pi.new_operation('Calculating flux matrices', iter_count)
            # Calculate instantaneous statistics
            for iiter, n_iter in enumerate(xrange(start_iter, stop_iter)):
                # Get data from the main HDF5 file
                iter_group = self.data_reader.get_iter_group(n_iter)
                seg_index = iter_group['seg_index']
                nsegs, npts = iter_group['pcoord'].shape[0:2]
                weights = seg_index['weight']

                # Get bin and trajectory-ensemble assignments from the previously-generated assignments file
                assignment_iiter = h5io.get_iteration_entry(self.assignments_file, n_iter)
                bin_assignments = np.require(self.assignments_file['assignments'][assignment_iiter + np.s_[:nsegs, :npts]],
                                             dtype=index_dtype)

                mask_unknown = np.zeros_like(bin_assignments, dtype=np.uint16)

                macrostate_iiter = h5io.get_iteration_entry(self.assignments_file, n_iter)
                macrostate_assignments = np.require(self.assignments_file['trajlabels'][macrostate_iiter + np.s_[:nsegs, :npts]],
                                                    dtype=index_dtype)

                # Transform bin_assignments to take macrostate membership into account
                bin_assignments = nstates * bin_assignments + macrostate_assignments

                mask_indx = np.where(macrostate_assignments == nstates)
                mask_unknown[mask_indx] = 1
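                # Time points whose trajectory label equals nstates (one past the
                # last defined state, i.e. not yet assigned to any macrostate) are
                # flagged so stats_process() can exclude them from the statistics.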

                # Calculate bin-to-bin fluxes, bin populations, and the number of observed transitions
                calc_stats(bin_assignments, weights, fluxes, populations, trans, mask_unknown)

                # Store bin-based kinetics data
                bin_populations_ds[iiter] = populations

                # Set up sparse data structures for flux and obs
                fluxes_sp = sp.coo_matrix(fluxes)
                trans_sp = sp.coo_matrix(trans)

                assert fluxes_sp.nnz == trans_sp.nnz

                flux_iter_grp = flux_grp.create_group('iter_{:08d}'.format(n_iter))
                flux_iter_grp.create_dataset('flux', data=fluxes_sp.data, dtype=weight_dtype)
                flux_iter_grp.create_dataset('obs', data=trans_sp.data, dtype=np.int32)
                flux_iter_grp.create_dataset('rows', data=fluxes_sp.row, dtype=np.int32)
                flux_iter_grp.create_dataset('cols', data=fluxes_sp.col, dtype=np.int32)
                flux_iter_grp.attrs['nrows'] = nfbins
                flux_iter_grp.attrs['ncols'] = nfbins
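                # Only the nonzero entries are stored; a downstream reader can
                # rebuild the dense matrix with, e.g.,
                #   sp.coo_matrix((flux, (rows, cols)), shape=(nfbins, nfbins))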

                # Do a little manual clean-up to prevent memory explosion
                del iter_group, weights, bin_assignments
                del macrostate_assignments

                pi.progress += 1

            # Save the number of time points per iteration; this is used to normalize
            # the flux and kinetics to tau in w_postanalysis_reweight.
            self.output_file.attrs['npts'] = npts


def calc_stats(bin_assignments, weights, fluxes, populations, trans, mask):
    '''Zero the preallocated per-iteration arrays, then let stats_process()
    accumulate bin-to-bin fluxes, bin populations, and observed transition
    counts in place.'''
    fluxes.fill(0.0)
    populations.fill(0.0)
    trans.fill(0)

    stats_process(bin_assignments, weights, fluxes, populations, trans, mask)


if __name__ == '__main__':
    MatrixRw().main()
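
# Sketch of reading the stored matrices back (a hypothetical snippet, assuming
# an output file produced with the defaults above):
#   import h5py, scipy.sparse as sp
#   with h5py.File('flux_matrices.h5', 'r') as f:
#       g = f['iterations/iter_00000001']
#       flux = sp.coo_matrix((g['flux'][...], (g['rows'][...], g['cols'][...])),
#                            shape=(f.attrs['nrows'], f.attrs['ncols'])).toarray()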