Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Guoye/merge embr final #3560

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Source/Common/DataReader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,11 @@ bool DataReader::GetMinibatch(StreamMinibatchInputs& matrices)
// uids - lables stored in size_t vector instead of ElemType matrix
// boundary - phone boundaries
// returns - true if there are more minibatches, false if no more minibatches remain
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
bool DataReader::GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap)
{
bool bRet = true;
for (size_t i = 0; i < m_ioNames.size(); i++)
bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, boundaries, extrauttmap);
bRet &= m_dataReaders[m_ioNames[i]]->GetMinibatch4SE(latticeinput, uids, wids, nws, boundaries, extrauttmap);
return bRet;
}

Expand Down
4 changes: 2 additions & 2 deletions Source/Common/Include/DataReader.h
Original file line number Diff line number Diff line change
Expand Up @@ -264,7 +264,7 @@ class DATAREADER_API IDataReader
}

virtual bool GetMinibatch(StreamMinibatchInputs& matrices) = 0;
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& /*latticeinput*/, vector<size_t>& /*uids*/, vector<size_t>& /*wids*/, vector<short>& /*nws*/, vector<size_t>& /*boundaries*/, vector<size_t>& /*extrauttmap*/)
{
NOT_IMPLEMENTED;
};
Expand Down Expand Up @@ -444,7 +444,7 @@ class DataReader : public IDataReader, protected Plugin, public ScriptableObject
// [out] each matrix resized if necessary containing data.
// returns - true if there are more minibatches, false if no more minibatches remain
virtual bool GetMinibatch(StreamMinibatchInputs& matrices);
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
virtual bool GetMinibatch4SE(std::vector<shared_ptr<const msra::dbn::latticepair>>& latticeinput, vector<size_t>& uids, vector<size_t>& wids, vector<short>& nws, vector<size_t>& boundaries, vector<size_t>& extrauttmap);
virtual bool GetHmmData(msra::asr::simplesenonehmm* hmm);

size_t GetNumParallelSequencesForFixingBPTTMode();
Expand Down
118 changes: 103 additions & 15 deletions Source/Common/Include/latticearchive.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <algorithm> // for find()
#include "simplesenonehmm.h"
#include "Matrix.h"

#include <set>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need them sorted? Otherwise we can use unordered_set

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

change to unordered set

namespace msra { namespace math {

class ssematrixbase;
Expand Down Expand Up @@ -67,7 +67,28 @@ enum mbrclassdefinition // used to identify definition of class in minimum bayes
// ===========================================================================
class lattice
{
public:
public:
// definie structure for nbest EMBR
struct TokenInfo
{
double score; // the score of the token
size_t prev_edge_index; // edge ending with this token, edge start points to the previous node
size_t prev_token_index; // the token index in the previous node
};
struct PrevTokenInfo
{
size_t prev_edge_index;
size_t prev_token_index;
double path_score; // use pure to indicatethe path score does not consider the WER of the path
};

struct NBestToken
{
// for sorting purpose
// make sure the map is stored with keys in descending order
std::map<double, std::vector<PrevTokenInfo>, std::greater <double>> mp_score_token_infos; // for sorting the tokens in map
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why not ordered_map? -- similarly in other places.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the map is for ordering purpose

std::vector<TokenInfo> vt_nbest_tokens; // stores the nbest tokens in the node
};
struct header_v1_v2
{
size_t numnodes : 32;
Expand All @@ -90,12 +111,15 @@ class lattice
static const unsigned int NOEDGE = 0xffffff; // 24 bits
// static_assert (sizeof (nodeinfo) == 8, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
// ensure type size as these are expected to be of this size in the files we read
static_assert(sizeof(nodeinfo) == 2, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
static_assert(sizeof(nodeinfo) == 16, "unexpected size of nodeeinfo"); // note: int64_t required to allow going across 32-bit boundary
static_assert(sizeof(edgeinfowithscores) == 16, "unexpected size of edgeinfowithscores");
static_assert(sizeof(aligninfo) == 4, "unexpected size of aligninfo");
std::vector<nodeinfo> nodes;
mutable std::vector<std::vector<uint64_t>> vt_node_out_edge_indices; // vt_node_out_edge_indices[i]: it stores the outgoing edge indices starting from node i
std::vector<bool> is_special_words; // true if it is special words that do not count to WER computation, false if it is not
std::vector<edgeinfowithscores> edges;
std::vector<aligninfo> align;

// V2 lattices --for a while, we will store both in RAM, until all code is updated
static int fsgn(float f)
{
Expand Down Expand Up @@ -217,6 +241,10 @@ class lattice
public: // TODO: make private again once
// construct from edges/align
// This is also used for merging, where the edges[] array is not correctly sorted. So don't assume this here.
void erase_node_out_edges(size_t nodeidx, size_t edgeidx_start, size_t edgeidx_end) const
{
vt_node_out_edge_indices[nodeidx].erase(vt_node_out_edge_indices[nodeidx].begin() + edgeidx_start, vt_node_out_edge_indices[nodeidx].begin() + edgeidx_end);
}
void builduniquealignments(size_t spunit = SIZE_MAX /*fix this later*/)
{
// infer /sp/ unit if not given
Expand Down Expand Up @@ -701,6 +729,7 @@ class lattice
const float lmf, const float wp, const float amf, const_array_ref<size_t>& uids,
const edgealignments& thisedgealignments, std::vector<double>& Eframescorrect) const;


void sMBRerrorsignal(parallelstate& parallelstate,
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg,
const std::vector<double>& logpps, const float amf, double minlogpp,
Expand Down Expand Up @@ -736,7 +765,8 @@ class lattice
const std::vector<double>& logpps, const float amf,
const std::vector<double>& logEframescorrect, const double logEframescorrecttotal,
msra::math::ssematrixbase& errorsignal, msra::math::ssematrixbase& errorsignalneg) const;

void parallelEMBRerrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
const std::vector<double>& edgeweights, msra::math::ssematrixbase& errorsignal) const;
void parallelmmierrorsignal(parallelstate& parallelstate, const edgealignments& thisedgealignments,
const std::vector<double>& logpps, msra::math::ssematrixbase& errorsignal) const;

Expand All @@ -747,6 +777,18 @@ class lattice
const_array_ref<size_t>& uids, std::vector<double>& logEframescorrect,
std::vector<double>& Eframescorrectbuf, double& logEframescorrecttotal) const;

double parallelbackwardlatticeEMBR(parallelstate& parallelstate, const std::vector<float>& edgeacscores,
const float lmf, const float wp,
const float amf, std::vector<double>& edgelogbetas,
std::vector<double>& logbetas) const;

void EMBRsamplepaths(const std::vector<double> &edgelogbetas,
const std::vector<double> &logbetas, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const bool excludeSpecialWords, std::vector< std::vector<size_t> > & vt_paths) const;

void EMBRnbestpaths(std::vector<NBestToken>& tokenlattice, std::vector<std::vector<size_t>> & vt_paths, std::vector<double>& path_posterior_probs) const;

double get_edge_weights(std::vector<size_t>& wids, std::vector<std::vector<size_t>>& vt_paths, std::vector<double>& vt_edge_weights, std::vector<double>& vt_path_posterior_probs, std::string getPathMethodEMBR, double& onebestwer) const;

static double scoregroundtruth(const_array_ref<size_t> uids, const_array_ref<htkmlfwordsequence::word> transcript,
const std::vector<float>& transcriptunigrams, const msra::math::ssematrixbase& logLLs,
const msra::asr::simplesenonehmm& hset, const float lmf, const float wp, const float amf);
Expand All @@ -762,6 +804,14 @@ class lattice
std::vector<double>& logEframescorrect, std::vector<double>& Eframescorrectbuf,
double& logEframescorrecttotal) const;

double backwardlatticeEMBR(const std::vector<float>& edgeacscores, parallelstate& parallelstate, std::vector<double> &edgelogbetas,
std::vector<double>& logbetas,
const float lmf, const float wp, const float amf) const;

void constructnodenbestoken(std::vector<NBestToken> &tokenlattice, const bool wordNbest, size_t numtokens2keep, size_t nidx) const;

double nbestlatticeEMBR(const std::vector<float> &edgeacscores, parallelstate &parallelstate, std::vector<NBestToken> &vt_nbesttokens, const size_t numtokens, const bool enforceValidPathEMBR, const bool excludeSpecialWords,
const float lmf, const float wp, const float amf, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numPathsEMBR, std::vector<size_t> wids) const;
public:
// construct from a HTK lattice file
void fromhtklattice(const std::wstring& path, const std::unordered_map<std::string, size_t>& unitmap);
Expand Down Expand Up @@ -1003,15 +1053,18 @@ class lattice
// This will also map the aligninfo entries to the new symbol table, through idmap.
// V1 lattices will be converted. 'spsenoneid' is used in that process.
template <class IDMAP>
void fread(FILE* f, const IDMAP& idmap, size_t spunit)
void fread(FILE* f, const IDMAP& idmap, size_t spunit, std::set<int>& specialwordids)
{
size_t version = freadtag(f, "LAT ");
if (version == 1)
{
freadOrDie(&info, sizeof(info), 1, f);
freadvector(f, "NODE", nodes, info.numnodes);
if (nodes.back().t != info.numframes)
RuntimeError("fread: mismatch between info.numframes and last node's time");
{
// sometimes, the data is corrputed, let's try to live with it
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
}
freadvector(f, "EDGE", edges, info.numedges);
freadvector(f, "ALIG", align);
fcheckTag(f, "END ");
Expand All @@ -1024,11 +1077,14 @@ class lattice
freadOrDie(&info, sizeof(info), 1, f);
freadvector(f, "NODS", nodes, info.numnodes);
if (nodes.back().t != info.numframes)
RuntimeError("fread: mismatch between info.numframes and last node's time");
{
// sometimes, the data is corrputed, let's try to live with it
fprintf(stderr, "fread: mismatch between info.numframes and last node's time: nodes.back().t = %d vs. info.numframes = %d \n", int(nodes.back().t), int(info.numframes));
}
freadvector(f, "EDGS", edges2, info.numedges); // uniqued edges
freadvector(f, "ALNS", uniquededgedatatokens); // uniqued alignments
fcheckTag(f, "END ");
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap);
ProcessV2EMBRLattice(spunit, info, uniquededgedatatokens, idmap, specialwordids);
}
else
RuntimeError("fread: unsupported lattice format version");
Expand Down Expand Up @@ -1124,7 +1180,28 @@ class lattice
rebuildedges(info.impliedspunitid != spunit /*to be able to read somewhat broken V2 lattice archives*/);

}


template <class IDMAP>
void ProcessV2EMBRLattice(size_t spunit, header_v1_v2& info, std::vector<aligninfo>& uniquededgedatatokens, const IDMAP& idmap, std::set<int>& specialwordids)
{
vt_node_out_edge_indices.resize(info.numnodes);
for (size_t j = 0; j < info.numedges; j++)
{
// an edge with !NULL pointing to not <s>
// this code make sure if you always start from <s> in the sampled path.
// mask here: we delay the processing in EMBRsamplepaths controlled by flag: enforceValidPathEMBR
// if (edges2[j].S == 0 && nodes[edges2[j].E].wid != 1) continue;
vt_node_out_edge_indices[edges2[j].S].push_back(j);
}
is_special_words.resize(info.numnodes);
for (size_t i = 0; i < info.numnodes; i++)
{
if (specialwordids.find(int(nodes[i].wid)) != specialwordids.end()) is_special_words[i] = true;
else is_special_words[i] = false;
}
ProcessV2Lattice(spunit, info, uniquededgedatatokens, idmap);
}

// parallel versions (defined in parallelforwardbackward.cpp)
class parallelstate
{
Expand Down Expand Up @@ -1152,6 +1229,10 @@ class lattice
const size_t getsilunitid();
void getedgeacscores(std::vector<float>& edgeacscores);
void getedgealignments(std::vector<unsigned short>& edgealignments);
void getlogbetas(std::vector<double>& logbetas);
void getedgelogbetas(std::vector<double>& edgelogbetas);
void getedgeweights(std::vector<double>& edgeweights);
void setedgeweights(const std::vector<double>& edgeweights);
// to work with CNTK's GPU memory
void setdevice(size_t DeviceId);
size_t getdevice();
Expand All @@ -1168,9 +1249,13 @@ class lattice
// Note: logLLs and posteriors may be the same matrix (aliased).
double forwardbackward(parallelstate& parallelstate, const class msra::math::ssematrixbase& logLLs, const class msra::asr::simplesenonehmm& hmms,
class msra::math::ssematrixbase& result, class msra::math::ssematrixbase& errorsignalbuf,
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, array_ref<size_t> uids, const_array_ref<size_t> bounds = const_array_ref<size_t>(),
const float lmf, const float wp, const float amf, const float boostingfactor, const bool sMBRmode, const bool EMBR, const std::string EMBRUnit, const size_t numPathsEMBR, const bool enforceValidPathEMBR, const std::string getPathMethodEMBR, const std::string showWERMode,
const bool excludeSpecialWords, const bool wordNbest, const bool useAccInNbest, const float accWeightInNbest, const size_t numRawPathsEMBR,
array_ref<size_t> uids, std::vector<size_t> wids, const_array_ref<size_t> bounds = const_array_ref<size_t>(),
const_array_ref<htkmlfwordsequence::word> transcript = const_array_ref<htkmlfwordsequence::word>(), const std::vector<float>& transcriptunigrams = std::vector<float>()) const;


void EMBRerrorsignal(parallelstate &parallelstate,
const edgealignments &thisedgealignments, std::vector<double>& edge_weights, msra::math::ssematrixbase &errorsignal) const;
std::wstring key; // (keep our own name (key) so we can identify ourselves for diagnostics messages)
const wchar_t* getkey() const
{
Expand Down Expand Up @@ -1358,8 +1443,10 @@ class archive
if (sscanf(q, "[%" PRIu64 "]%c", &offset, &c) != 1)
#endif
RuntimeError("open: invalid TOC line (bad [] expression): %s", line);

if (!toc.insert(make_pair(key, latticeref(offset, archiveindex))).second)
RuntimeError("open: TOC entry leads to duplicate key: %s", line);
// sometimes, the training will report this error. I believe it is due to some small data corruption, and fine to go on, so change the error to warning
fprintf(stderr, " open: TOC entry leads to duplicate key: %s\n", line);
}

// initialize symmaps --alloc the array, but actually read the symmap on demand
Expand Down Expand Up @@ -1390,7 +1477,7 @@ class archive
// Lattices will have unit ids updated according to the modelsymmap.
// V1 lattices will be converted. 'spsenoneid' is used in the conversion for optimizing storing 0-frame /sp/ aligns.
void getlattice(const std::wstring& key, lattice& L,
size_t expectedframes = SIZE_MAX /*if unknown*/) const
std::set<int>& specialwordids, size_t expectedframes = SIZE_MAX) const
{
auto iter = toc.find(key);
if (iter == toc.end())
Expand All @@ -1417,7 +1504,7 @@ class archive
// seek to start
fsetpos(f, offset);
// get it
L.fread(f, idmap, spunit);
L.fread(f, idmap, spunit, specialwordids);
L.setverbosity(verbosity);
#ifdef HACK_IN_SILENCE // hack to simulate DEL in the lattice
const size_t silunit = getid(modelsymmap, "sil");
Expand Down Expand Up @@ -1451,7 +1538,8 @@ class archive
// - dump to stdout
// - merge two lattices (for merging numer into denom lattices)
static void convert(const std::wstring& intocpath, const std::wstring& intocpath2, const std::wstring& outpath,
const msra::asr::simplesenonehmm& hset);
const msra::asr::simplesenonehmm& hset, std::set<int>& specialwordids);
};
};
};

4 changes: 2 additions & 2 deletions Source/Common/Include/latticesource.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,10 @@ class latticesource
#endif
}

void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes) const
void getlattices(const std::wstring& key, std::shared_ptr<const latticepair>& L, size_t expectedframes, std::set<int>& specialwordids) const
{
std::shared_ptr<latticepair> LP(new latticepair);
denlattices.getlattice(key, LP->second, expectedframes); // this loads the lattice from disk, using the existing L.second object
denlattices.getlattice(key, LP->second, specialwordids, expectedframes); // this loads the lattice from disk, using the existing L.second object
L = LP;
}

Expand Down
9 changes: 7 additions & 2 deletions Source/Common/Include/latticestorage.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <stdexcept>
#include <stdint.h>
#include <cstdio>
#include <vector>

#undef INITIAL_STRANGE // [v-hansu] initialize structs to strange values
#define PARALLEL_SIL // [v-hansu] process sil on CUDA, used in other files, please search this
Expand All @@ -30,11 +31,15 @@ struct nodeinfo
// uint64_t firstinedge : 24; // index of first incoming edge
// uint64_t firstoutedge : 24; // index of first outgoing edge
// uint64_t t : 16; // time associated with this

uint64_t wid; // word ID associated with the node
unsigned short t; // time associated with this
nodeinfo(size_t pt)
: t((unsigned short) pt) // , firstinedge (NOEDGE), firstoutedge (NOEDGE)

nodeinfo(size_t pt, size_t pwid)
: t((unsigned short)pt), wid(pwid)
{
checkoverflow(t, pt, "nodeinfo::t");
checkoverflow(wid, pwid, "nodeinfo::wid");
// checkoverflow (firstinedge, NOEDGE, "nodeinfo::firstinedge");
// checkoverflow (firstoutedge, NOEDGE, "nodeinfo::firstoutedge");
}
Expand Down
Loading