-
Notifications
You must be signed in to change notification settings - Fork 20
/
expectation_step.d
116 lines (104 loc) · 4.21 KB
/
expectation_step.d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/* Copyright (c) 2012,2013 Genome Research Ltd.
*
* Author: Stephan Schiffels <[email protected]>
*
* This file is part of msmc.
* msmc is free software: you can redistribute it and/or modify it under
* the terms of the GNU General Public License as published by the Free Software
* Foundation; either version 3 of the License, or (at your option) any later
* version.
*
* This program is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
* FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
* details.
*
* You should have received a copy of the GNU General Public License along with
* this program. If not, see <http://www.gnu.org/licenses/>.
*/
import std.typecons;
import std.stdio;
import std.string;
import std.algorithm;
import std.exception;
import std.conv;
import std.parallelism;
import std.math;
import core.memory;
import model.propagation_core;
import model.propagation_core_fastImpl;
import model.propagation_core_naiveImpl;
import model.msmc_model;
import model.msmc_hmm;
import model.data;
import logger;
alias Tuple!(double[], double[][], double) ExpectationResult_t;
ExpectationResult_t getExpectation(in SegSite_t[][] inputData, MSMCmodel msmc, size_t hmmStrideWidth,
size_t maxDistance, bool naiveImplementation=false)
{
PropagationCore propagationCore;
if(naiveImplementation)
propagationCore = new PropagationCoreNaive(msmc, maxDistance);
else
propagationCore = new PropagationCoreFast(msmc, maxDistance);
auto expectationResultVec = new double[msmc.nrMarginals];
auto expectationResultMat = new double[][](msmc.nrMarginals, msmc.nrMarginals);
foreach(au; 0 .. msmc.nrMarginals) {
expectationResultVec[au] = 0.0;
expectationResultMat[au][] = 0.0;
}
auto logLikelihood = 0.0;
auto cnt = 0;
foreach(data; taskPool.parallel(inputData)) {
logInfo(format("\r * [%s/%s] Expectation Step", ++cnt, inputData.length));
auto result = singleChromosomeExpectation(data, hmmStrideWidth, propagationCore);
foreach(au; 0 .. msmc.nrMarginals) {
expectationResultVec[au] += result[0][au];
expectationResultMat[au][] += result[1][au][];
}
logLikelihood += result[2];
}
logInfo(format(", log likelihood: %s", logLikelihood));
logInfo("\n");
return tuple(expectationResultVec, expectationResultMat, logLikelihood);
}
ExpectationResult_t singleChromosomeExpectation(in SegSite_t[] data, size_t hmmStrideWidth,
in PropagationCore propagationCore)
{
auto msmc_hmm = new MSMC_hmm(propagationCore, data);
msmc_hmm.runForward();
auto exp = msmc_hmm.runBackward(hmmStrideWidth);
auto logL = msmc_hmm.logLikelihood();
msmc_hmm.destroy();
return tuple(exp[0], exp[1], logL);
}
unittest {
writeln("test expectation step");
auto lambdaVec = new double[12];
lambdaVec[] = 1.0;
auto msmc = new MSMCmodel(0.01, 0.001, [0U, 0, 1, 1], lambdaVec, 4, 4, false);
auto fileName = "model/hmm_testData.txt";
auto data = readSegSites(fileName, false, [], false);
auto hmmStrideWidth = 100UL;
auto allData = [data, data, data];
std.parallelism.defaultPoolThreads(1U);
auto resultSingleThreaded = getExpectation(allData, msmc, hmmStrideWidth, 100UL, true);
std.parallelism.defaultPoolThreads(2U);
auto resultMultiThreaded = getExpectation(allData, msmc, hmmStrideWidth, 100UL);
auto lvl = 1.0e-8;
auto sumSingleThreaded = 0.0;
auto sumMultiThreaded = 0.0;
foreach(au; 0 .. msmc.nrMarginals) {
assert(resultSingleThreaded[0][au] >= 0.0);
assert(resultMultiThreaded[0][au] >= 0.0);
sumSingleThreaded += resultSingleThreaded[0][au];
sumMultiThreaded += resultMultiThreaded[0][au];
foreach(bv; 0 .. msmc.nrMarginals) {
assert(resultSingleThreaded[1][au][bv] >= 0.0, text(resultSingleThreaded[1][au][bv]));
assert(resultMultiThreaded[1][au][bv] >= 0.0, text(resultMultiThreaded[1][au][bv]));
sumSingleThreaded += resultSingleThreaded[1][au][bv];
sumMultiThreaded += resultMultiThreaded[1][au][bv];
}
}
assert(approxEqual(sumSingleThreaded, sumMultiThreaded, 1.0e-8, 0.0), text([sumSingleThreaded, sumMultiThreaded]));
}