From f7266886f0d478d802a88f2ce82b71f27c37bf07 Mon Sep 17 00:00:00 2001
From: Hieu Hoang <hieuhoang@gmail.com>
Date: Thu, 4 Mar 2021 04:18:19 +0000
Subject: [PATCH] move logits to its own file

---
 src/layers/embedding.h |  2 ++
 src/layers/generic.h   | 66 ------------------------------------------
 src/layers/loss.h      |  2 +-
 src/layers/output.h    |  1 +
 src/models/states.h    |  2 +-
 5 files changed, 5 insertions(+), 68 deletions(-)

diff --git a/src/layers/embedding.h b/src/layers/embedding.h
index 91fd0b9d7..b7898c76e 100644
--- a/src/layers/embedding.h
+++ b/src/layers/embedding.h
@@ -4,6 +4,8 @@
 
 namespace marian {
 
+class FactoredVocab;
+
 // A regular embedding layer.
 // Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
 // It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
diff --git a/src/layers/generic.h b/src/layers/generic.h
index 6d953fd8c..eddd597e8 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -97,72 +97,6 @@ class EncoderDecoderLayerBase : public LayerBase {
   Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
 };
 
-class FactoredVocab;
-
-// To support factors, any output projection (that is followed by a softmax) must
-// retain multiple outputs, one for each factor. Such layer returns not a single Expr,
-// but a Logits object that contains multiple.
-// This allows to compute softmax values in a factored manner, where we never create
-// a fully expanded list of all factor combinations.
-class RationalLoss;
-class Logits {
-public:
-  Logits() {}
-  explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
-    logits_.push_back(logits);
-  }
-  explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
-  Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
-    : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
-  Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
-  Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
-  //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
-  Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const;
-  Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
-  Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to first and fother to to all other values
-
-  struct MaskedFactorIndices {
-    std::vector<WordIndex> indices; // factor index, or 0 if masked
-    std::vector<float> masks;
-    void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
-    void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
-    MaskedFactorIndices() {}
-    MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
-  };
-  std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
-  Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
-  size_t getNumFactorGroups() const { return logits_.size(); }
-  bool empty() const { return logits_.empty(); }
-  Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_
-private:
-  // helper functions
-  Ptr<ExpressionGraph> graph() const;
-  Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
-  Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
-  template<typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() but assuming vector<T>
-  Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
-  std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;
-private:
-  // members
-  // @TODO: we don't use the RationalLoss component anymore, can be removed again, and replaced just by the Expr
-  std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
-  Ptr<FactoredVocab> factoredVocab_;
-};
-
-// Unary function that returns a Logits object
-// Also implements IUnaryLayer, since Logits can be cast to Expr.
-// This interface is implemented by all layers that are of the form of a unary function
-// that returns multiple logits, to support factors.
-struct IUnaryLogitLayer : public IUnaryLayer {
-  virtual Logits applyAsLogits(Expr) = 0;
-  virtual Logits applyAsLogits(const std::vector<Expr>& es) {
-    ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
-    return applyAsLogits(es.front());
-  }
-  virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
-  virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
-};
-
 namespace mlp {
 
 class Dense : public LayerBase, public IUnaryLayer {
diff --git a/src/layers/loss.h b/src/layers/loss.h
index d7bc19e4a..ba93cdac7 100644
--- a/src/layers/loss.h
+++ b/src/layers/loss.h
@@ -1,7 +1,7 @@
 #pragma once
 
 #include "graph/expression_operators.h"
-#include "layers/generic.h" // for Logits (Frank's factor hack)
+#include "layers/logits.h" // for Logits (Frank's factor hack)
 #include "data/types.h"
 
 namespace marian {
diff --git a/src/layers/output.h b/src/layers/output.h
index d091556a4..92e7eb25e 100644
--- a/src/layers/output.h
+++ b/src/layers/output.h
@@ -2,6 +2,7 @@
 
 #include "marian.h"
 #include "generic.h"
+#include "logits.h"
 #include "data/shortlist.h"
 #include "layers/factory.h"
 
diff --git a/src/models/states.h b/src/models/states.h
index c2f9ee05a..cfb6fd1b8 100644
--- a/src/models/states.h
+++ b/src/models/states.h
@@ -1,7 +1,7 @@
 #pragma once
 
 #include "marian.h"
-#include "layers/generic.h" // @HACK: for factored embeddings only so far
+#include "layers/logits.h" // @HACK: for factored embeddings only so far
 #include "rnn/types.h"
 
 namespace marian {