Skip to content

Commit

Permalink
move logits to its own file
Browse files Browse the repository at this point in the history
  • Loading branch information
Hieu Hoang committed Mar 4, 2021
1 parent f726688 commit 42406cc
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 71 deletions.
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ set(MARIAN_SOURCES
layers/lsh.cpp
layers/embedding.cpp
layers/output.cpp
layers/logits.cpp

rnn/cells.cpp
rnn/attention.cpp
Expand Down
71 changes: 0 additions & 71 deletions src/layers/generic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,76 +143,5 @@ namespace marian {
#endif
}

// Append one factor index together with its validity mask.
// A valid factor is stored as-is with mask 1; an invalid one is stored as
// placeholder index 0 with mask 0.
void Logits::MaskedFactorIndices::push_back(size_t factorIndex) {
  const bool valid = FactoredVocab::isFactorValid(factorIndex);
  if (valid) {
    indices.push_back((WordIndex)factorIndex);
    masks.push_back(1.0f);
  } else {
    indices.push_back(0); // placeholder index; the zero mask marks it as invalid
    masks.push_back(0.0f);
  }
}

// Break each encoded Word into its individual factor indices.
// Result layout: [numGroups][words.size()].
// Without a factored vocabulary the words are wrapped into a single
// MaskedFactorIndices entry, which requires exactly one logits_ element.
std::vector<Logits::MaskedFactorIndices> Logits::factorizeWords(const Words& words) const {
  if (!factoredVocab_) {
    ABORT_IF(logits_.size() != 1, "Factors without factor mappings??");
    return {MaskedFactorIndices(words)};
  }

  const auto groupCount = factoredVocab_->getNumGroups();
  std::vector<MaskedFactorIndices> perGroup(groupCount);
  size_t group = 0; // index of the factor group currently being filled
  for (auto& factorsOfGroup : perGroup) {
    factorsOfGroup.reserve(words.size());
    for (const auto& word : words)
      factorsOfGroup.push_back(factoredVocab_->getFactor(word, group));
    ++group;
  }
  return perGroup;
}

//// use first factor of each word to determine whether it has a specific factor
//std::vector<float> Logits::getFactorMasks(const Words& words, size_t factorGroup) const { // 1.0 for words that do have this factor; else 0
// std::vector<float> res;
// res.reserve(words.size());
// for (const auto& word : words) {
// auto lemma = factoredVocab_->getFactor(word, 0);
// res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
// }
// return res;
//}

// return a vector of 1 or 0 indicating for each lemma whether it has a specific factor
// If 'indices' is given, then return the masks for the indices; otherwise for all lemmas
std::vector<float> Logits::getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const { // [lemmaIndex] -> 1.0 for words that do have this factor; else 0
size_t n = indices.empty() ? (factoredVocab_->getGroupRange(0).second - factoredVocab_->getGroupRange(0).first) : indices.size();
std::vector<float> res;
res.reserve(n);
// @TODO: we should rearrange lemmaHasFactorGroup as vector[groups[i] of float; then move this into FactoredVocab
for (size_t i = 0; i < n; i++) {
auto lemma = indices.empty() ? i : (indices[i] - factoredVocab_->getGroupRange(0).first);
res.push_back((float)factoredVocab_->lemmaHasFactorGroup(lemma, factorGroup));
}
return res;
}

// Clone this Logits object, replacing every loss value by f(loss).
// Counts and the factored vocabulary are carried over unchanged.
Logits Logits::applyUnaryFunction(const std::function<Expr(Expr)>& f) const {
  std::vector<Ptr<RationalLoss>> mapped;
  mapped.reserve(logits_.size());
  for (const auto& entry : logits_) {
    mapped.push_back(New<RationalLoss>(f(entry->loss()), entry->count()));
  }
  return Logits(std::move(mapped), factoredVocab_);
}

// Clone this Logits object, applying f1 to the loss of the first entry and
// fother to the loss of every subsequent entry. Counts and the factored
// vocabulary are carried over unchanged.
Logits Logits::applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const {
  std::vector<Ptr<RationalLoss>> mapped;
  mapped.reserve(logits_.size());
  const std::function<Expr(Expr)>* transform = &f1; // switches to fother after the first entry
  for (const auto& entry : logits_) {
    mapped.push_back(New<RationalLoss>((*transform)(entry->loss()), entry->count()));
    transform = &fother;
  }
  return Logits(std::move(mapped), factoredVocab_);
}

// @TODO: code dup with above; we can merge it into applyToRationalLoss()
// Create a new Logits object in which every entry keeps its loss but has
// 'count' implanted as its count. The factored vocabulary is carried over.
Logits Logits::withCounts(const Expr& count) const {
  std::vector<Ptr<RationalLoss>> recounted;
  recounted.reserve(logits_.size());
  for (const auto& entry : logits_) {
    recounted.push_back(New<RationalLoss>(entry->loss(), count));
  }
  return Logits(std::move(recounted), factoredVocab_);
}

} // namespace marian

0 comments on commit 42406cc

Please sign in to comment.