-
Notifications
You must be signed in to change notification settings - Fork 4.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Guoye/merge embr final #3560
Open
guoli-ye
wants to merge
7
commits into
master
Choose a base branch
from
guoye/merge_embr_final
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Guoye/merge embr final #3560
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
86f9f65
cherry-pick commit 4018e1e, with manual edit
guoli-ye f6855ba
fix hard-tab in htkfeatio.h
guoli-ye 91456ba
add change until 5771cb8~ for guoye/merge_embr
guoli-ye 7f659f4
delete a redudant } in if statement
guoli-ye a64e306
fix removeextension with new CNTK usage
guoli-ye a5cd58d
add public for header_v1_v2 info
guoli-ye b568283
make the function interface of AssignSequenceError in NoGpu.cpp to be…
guoli-ye File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -23,7 +23,7 @@ | |
#include <algorithm> // for find() | ||
#include "simplesenonehmm.h" | ||
#include "Matrix.h" | ||
|
||
#include <set> | ||
namespace msra { namespace math { | ||
|
||
class ssematrixbase; | ||
|
@@ -67,7 +67,28 @@ enum mbrclassdefinition // used to identify definition of class in minimum bayes | |
// =========================================================================== | ||
class lattice | ||
{ | ||
public: | ||
public: | ||
// definie structure for nbest EMBR | ||
struct TokenInfo | ||
{ | ||
double score; // the score of the token | ||
size_t prev_edge_index; // edge ending with this token, edge start points to the previous node | ||
size_t prev_token_index; // the token index in the previous node | ||
}; | ||
struct PrevTokenInfo | ||
{ | ||
size_t prev_edge_index; | ||
size_t prev_token_index; | ||
double path_score; // use pure to indicatethe path score does not consider the WER of the path | ||
}; | ||
|
||
struct NBestToken | ||
{ | ||
// for sorting purpose | ||
// make sure the map is stored with keys in descending order | ||
std::map<double, std::vector<PrevTokenInfo>, std::greater <double>> mp_score_token_infos; // for sorting the tokens in map | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why not ordered_map? -- similarly in other places. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the map is for ordering purpose |
||
std::vector<TokenInfo> vt_nbest_tokens; // stores the nbest tokens in the node | ||
}; | ||
struct header_v1_v2 | ||
{ | ||
size_t numnodes : 32; | ||
|
@@ -90,12 +111,15 @@ class lattice | |
static const unsigned int NOEDGE = 0xffffff; // 24 bits | ||
// static_assert (sizeof (nodeinfo) == 8, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary | ||
// ensure type size as these are expected to be of this size in the files we read | ||
static_assert(sizeof(nodeinfo) == 2, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary | ||
static_assert(sizeof(nodeinfo) == 16, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary | ||
static_assert(sizeof(edgeinfowithscores) == 16, "unexpected size of edgeinfowithscores"); | ||
static_assert(sizeof(aligninfo) == 4, "unexpected size of aligninfo"); | ||
std::vector<nodeinfo> nodes; | ||
mutable std::vector<std::vector<uint64_t>> vt_node_out_edge_indices; // vt_node_out_edge_indices[i]: it stores the outgoing edge indices starting from node i | ||
std::vector<bool> is_special_words; // true if it is special words that do not count to WER computation, false if it is not | ||
std::vector<edgeinfowithscores> edges; | ||
std::vector<aligninfo> align; | ||
|
||
// V2 lattices --for a while, we will store both in RAM, until all code is updated | ||
static int fsgn(float f) | ||
{ | ||
|
@@ -217,6 +241,10 @@ class lattice | |
public: // TODO: make private again once | ||
// construct from edges/align | ||
// This is also used for merging, where the edges[] array is not correctly sorted. So don't assume this here. | ||
void erase_node_out_edges(size_t nodeidx, size_t edgeidx_start, size_t edgeidx_end) const | ||
{ | ||
vt_node_out_edge_indices[nodeidx].erase(vt_node_out_edge_indices[nodeidx].begin() + edgeidx_start, vt_node_out_edge_indices[nodeidx].begin() + edgeidx_end); | ||
} | ||
void builduniquealignments(size_t spunit = SIZE_MAX /*fix this later*/) | ||
{ | ||
// infer /sp/ unit if not given | ||
|
@@ -701,6 +729,7 @@ class lattice | |
const float lmf, const float wp, const float amf, const_array_ref<size_t>& uids, | ||
const edgealignments& thisedgealignments, std::vector<double>& Eframescorrect) const; | ||
|
||
|
||
void sMBRerrorsignal(parallelstate& parallelstate, | ||
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg, | ||
const std::vector<double>& logpps, const float amf, double minlogpp, | ||
|
@@ -736,7 +765,8 @@ class lattice | |
const std::vector<double>& logpps, const float amf, | ||
const std::vector<double>& logEframescorrect, const double logEframescorrecttotal, | ||
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg) const; | ||
|
||
void parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments, | ||
const std::vector<double>& edgeweights, msra::math::ssematrixbase& errorsignal) const; | ||
void parallelmmierrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments, | ||
const std::vector<double>& logpps, msra::math::ssematrixbase& errorsignal) const; | ||
|
||
|
@@ -747,6 +777,18 @@ class lattice | |
const_array_ref<size_t>& uids, std::vector<double>& logEframescorrect, | ||
std::vector<double>& Eframescorrectbuf, double& logEframescorrecttotal) const; | ||
|
||
double parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores, | ||
const float lmf, const float wp, | ||
const float amf, std::vector<double>& edgelogbetas, | ||
std::vector<double>& logbetas) const; | ||
|
||
void EMBRsamplepaths(const std::vector<double> &edgelogbetas, | ||
const std::vector<double> &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector< std::vector<size_t> > & vt_paths) const; | ||
|
||
void EMBRnbestpaths(std::vector<NBestToken>& tokenlattice, std::vector<std::vector<size_t>> & vt_paths, std::vector<double>& path_posterior_probs) const; | ||
|
||
double get_edge_weights(std::vector<size_t>& wids, std::vector<std::vector<size_t>>& vt_paths, std::vector<double>& vt_edge_weights, std::vector<double>& vt_path_posterior_probs, std::string getPathMethodEMBR, double& onebestwer) const; | ||
|
||
static double scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript, | ||
const std::vector<float>& transcriptunigrams, const msra::math::ssematrixbase& logLLs, | ||
const msra::asr::simplesenonehmm& hset, const float lmf, const float wp, const float amf); | ||
|
@@ -762,6 +804,14 @@ class lattice | |
std::vector<double>& logEframescorrect, std::vector<double>& Eframescorrectbuf, | ||
double& logEframescorrecttotal) const; | ||
|
||
double backwardlatticeEMBR(const std::vector<float>& edgeacscores, parallelstate& parallelstate, std::vector<double> &edgelogbetas, | ||
std::vector<double>& logbetas, | ||
const float lmf, const float wp, const float amf) const; | ||
|
||
void constructnodenbestoken(std::vector<NBestToken> &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const; | ||
|
||
double nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate ¶llelstate, std::vector<NBestToken> &vt_nbesttokens, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords, | ||
const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids) const; | ||
public: | ||
// construct from a HTK lattice file | ||
void fromhtklattice(const std::wstring& path, const std::unordered_map<std::string, size_t>& unitmap); | ||
|
@@ -1003,15 +1053,18 @@ class lattice | |
// This will also map the aligninfo entries to the new symbol table, through idmap. | ||
// V1 lattices will be converted. 'spsenoneid' is used in that process. | ||
template <class IDMAP> | ||
void fread(FILE* f, const IDMAP& idmap, size_t spunit) | ||
void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids) | ||
{ | ||
size_t version = freadtag(f, "LAT "); | ||
if (version == 1) | ||
{ | ||
freadOrDie(&info, sizeof(info), 1, f); | ||
freadvector(f, "NODE", nodes, info.numnodes); | ||
if (nodes.back().t != info.numframes) | ||
RuntimeError("fread: mismatch between info.numframes and last node's time"); | ||
{ | ||
// sometimes, the data is corrputed, let's try to live with it | ||
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes)); | ||
} | ||
freadvector(f, "EDGE", edges, info.numedges); | ||
freadvector(f, "ALIG", align); | ||
fcheckTag(f, "END "); | ||
|
@@ -1024,11 +1077,14 @@ class lattice | |
freadOrDie(&info, sizeof(info), 1, f); | ||
freadvector(f, "NODS", nodes, info.numnodes); | ||
if (nodes.back().t != info.numframes) | ||
RuntimeError("fread: mismatch between info.numframes and last node's time"); | ||
{ | ||
// sometimes, the data is corrputed, let's try to live with it | ||
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes)); | ||
} | ||
freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges | ||
freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments | ||
fcheckTag(f, "END "); | ||
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap); | ||
ProcessV2EMBRLattice(spunit, info, uniquededgedatatokens, idmap, specialwordids); | ||
} | ||
else | ||
RuntimeError("fread: unsupported lattice format version"); | ||
|
@@ -1124,7 +1180,28 @@ class lattice | |
rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/); | ||
|
||
} | ||
|
||
|
||
template <class IDMAP> | ||
void ProcessV2EMBRLattice(size_t spunit, header_v1_v2& info, std::vector<aligninfo>& uniquededgedatatokens, const IDMAP& idmap, std::set<int>& specialwordids) | ||
{ | ||
vt_node_out_edge_indices.resize(info.numnodes); | ||
for (size_t j = 0; j < info.numedges; j++) | ||
{ | ||
// an edge with !NULL pointing to not <s> | ||
// this code make sure if you always start from <s> in the sampled path. | ||
// mask here: we delay the processing in EMBRsamplepaths controlled by flag: enforceValidPathEMBR | ||
// if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue; | ||
vt_node_out_edge_indices[edges2[j].S].push_back(j); | ||
} | ||
is_special_words.resize(info.numnodes); | ||
for (size_t i = 0; i < info.numnodes; i++) | ||
{ | ||
if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true; | ||
else is_special_words[i] = false; | ||
} | ||
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap); | ||
} | ||
|
||
// parallel versions (defined in parallelforwardbackward.cpp) | ||
class parallelstate | ||
{ | ||
|
@@ -1152,6 +1229,10 @@ class lattice | |
const size_t getsilunitid(); | ||
void getedgeacscores(std::vector<float>& edgeacscores); | ||
void getedgealignments(std::vector<unsigned short>& edgealignments); | ||
void getlogbetas(std::vector<double>& logbetas); | ||
void getedgelogbetas(std::vector<double>& edgelogbetas); | ||
void getedgeweights(std::vector<double>& edgeweights); | ||
void setedgeweights(const std::vector<double>& edgeweights); | ||
// to work with CNTK's GPU memory | ||
void setdevice(size_t DeviceId); | ||
size_t getdevice(); | ||
|
@@ -1168,9 +1249,13 @@ class lattice | |
// Note: logLLs and posteriors may be the same matrix (aliased). | ||
double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms, | ||
class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf, | ||
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, array_ref<size_t> uids, const_array_ref<size_t> bounds = const_array_ref<size_t>(), | ||
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, const bool EMBR, const std::string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const std::string getPathMethodEMBR, const std::string showWERMode, | ||
const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR, | ||
array_ref<size_t> uids, std::vector<size_t> wids, const_array_ref<size_t> bounds = const_array_ref<size_t>(), | ||
const_array_ref<htkmlfwordsequence::word> transcript = const_array_ref<htkmlfwordsequence::word>(), const std::vector<float>& transcriptunigrams = std::vector<float>()) const; | ||
|
||
|
||
void EMBRerrorsignal(parallelstate ¶llelstate, | ||
const edgealignments &thisedgealignments, std::vector<double>& edge_weights, msra::math::ssematrixbase &errorsignal) const; | ||
std::wstring key; // (keep our own name (key) so we can identify ourselves for diagnostics messages) | ||
const wchar_t* getkey() const | ||
{ | ||
|
@@ -1358,8 +1443,10 @@ class archive | |
if (sscanf(q, "[%" PRIu64 "]%c", &offset, &c) != 1) | ||
#endif | ||
RuntimeError("open: invalid TOC line (bad [] expression): %s", line); | ||
|
||
if (!toc.insert(make_pair(key, latticeref(offset, archiveindex))).second) | ||
RuntimeError("open: TOC entry leads to duplicate key: %s", line); | ||
// sometimes, the training will report this error. I believe it is due to some small data corruption, and fine to go on, so change the error to warning | ||
fprintf(stderr, " open: TOC entry leads to duplicate key: %s\n", line); | ||
} | ||
|
||
// initialize symmaps --alloc the array, but actually read the symmap on demand | ||
|
@@ -1390,7 +1477,7 @@ class archive | |
// Lattices will have unit ids updated according to the modelsymmap. | ||
// V1 lattices will be converted. 'spsenoneid' is used in the conversion for optimizing storing 0-frame /sp/ aligns. | ||
void getlattice(const std::wstring& key, lattice& L, | ||
size_t expectedframes = SIZE_MAX /*if unknown*/) const | ||
std::set<int>& specialwordids, size_t expectedframes = SIZE_MAX) const | ||
{ | ||
auto iter = toc.find(key); | ||
if (iter == toc.end()) | ||
|
@@ -1417,7 +1504,7 @@ class archive | |
// seek to start | ||
fsetpos(f, offset); | ||
// get it | ||
L.fread(f, idmap, spunit); | ||
L.fread(f, idmap, spunit, specialwordids); | ||
L.setverbosity(verbosity); | ||
#ifdef HACK_IN_SILENCE // hack to simulate DEL in the lattice | ||
const size_t silunit = getid(modelsymmap, "sil"); | ||
|
@@ -1451,7 +1538,8 @@ class archive | |
// - dump to stdout | ||
// - merge two lattices (for merging numer into denom lattices) | ||
static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath, | ||
const msra::asr::simplesenonehmm& hset); | ||
const msra::asr::simplesenonehmm& hset, std::set<int>& specialwordids); | ||
}; | ||
}; | ||
}; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need them sorted? Otherwise we can use unordered_set
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change to unordered set