move logits to its own file
Hieu Hoang committed Mar 4, 2021 · 1 parent ca47eab · commit f726688
Showing 5 changed files with 5 additions and 68 deletions.
2 changes: 2 additions & 0 deletions src/layers/embedding.h
@@ -4,6 +4,8 @@

namespace marian {

+class FactoredVocab;

// A regular embedding layer.
// Note that this also applies dropout if the option is passed (pass 0 when in inference mode).
// It is best to not use Embedding directly, but rather via getEmbeddingLayer() in
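The forward declaration added above is sufficient because embedding.h only refers to FactoredVocab as an incomplete type (behind a Ptr<>), so the full header does not have to be pulled in. A minimal sketch of the pattern, with a hypothetical class body:

// Forward declaration: enough when the type is only used through
// pointers/references, so this header avoids a heavier #include.
class FactoredVocab;

class Embedding { // hypothetical sketch, not the real class body
  Ptr<FactoredVocab> factoredVocab_; // OK: incomplete type behind Ptr<>
  // Code that dereferences factoredVocab_ belongs in a .cpp file that
  // #includes the full FactoredVocab definition.
};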
66 changes: 0 additions & 66 deletions src/layers/generic.h
@@ -97,72 +97,6 @@ class EncoderDecoderLayerBase : public LayerBase {
Ptr<IEmbeddingLayer> getEmbeddingLayer(bool ulr = false) const;
};

class FactoredVocab;

// To support factors, any output projection (that is followed by a softmax) must
// retain multiple outputs, one for each factor. Such a layer returns not a single Expr,
// but a Logits object that holds one output per factor group.
// This makes it possible to compute softmax values in a factored manner, where we never
// create a fully expanded list of all factor combinations.
class RationalLoss;
class Logits {
public:
  Logits() {}
  explicit Logits(Ptr<RationalLoss> logits) { // single-output constructor
    logits_.push_back(logits);
  }
  explicit Logits(Expr logits); // single-output constructor from Expr only (RationalLoss has no count)
  Logits(std::vector<Ptr<RationalLoss>>&& logits, Ptr<FactoredVocab> embeddingFactorMapping) // factored-output constructor
      : logits_(std::move(logits)), factoredVocab_(embeddingFactorMapping) {}
  Expr getLogits() const; // assume it holds logits: get them, possibly aggregating over factors
  Expr getFactoredLogits(size_t groupIndex, Ptr<data::Shortlist> shortlist = nullptr, const std::vector<IndexType>& hypIndices = {}, size_t beamSize = 0) const; // get logits for only one factor group, with optional reshuffle
  //Ptr<RationalLoss> getRationalLoss() const; // assume it holds a loss: get that
  Expr applyLossFunction(const Words& labels, const std::function<Expr(Expr/*logits*/, Expr/*indices*/)>& lossFn) const;
  Logits applyUnaryFunction(const std::function<Expr(Expr)>& f) const; // clone this but apply f to all loss values
  Logits applyUnaryFunctions(const std::function<Expr(Expr)>& f1, const std::function<Expr(Expr)>& fother) const; // clone this but apply f1 to the first and fother to all other values

  struct MaskedFactorIndices {
    std::vector<WordIndex> indices; // factor index, or 0 if masked
    std::vector<float> masks;
    void reserve(size_t n) { indices.reserve(n); masks.reserve(n); }
    void push_back(size_t factorIndex); // push back into both arrays, setting mask and index to 0 for invalid entries
    MaskedFactorIndices() {}
    MaskedFactorIndices(const Words& words) { indices = toWordIndexVector(words); } // we can leave masks uninitialized for this special use case
  };
  std::vector<MaskedFactorIndices> factorizeWords(const Words& words) const; // breaks encoded Word into individual factor indices
  Tensor getFactoredLogitsTensor(size_t factorGroup) const; // used for breakDown() only
  size_t getNumFactorGroups() const { return logits_.size(); }
  bool empty() const { return logits_.empty(); }
  Logits withCounts(const Expr& count) const; // create new Logits with 'count' implanted into all logits_

private:
  // helper functions
  Ptr<ExpressionGraph> graph() const;
  Expr constant(const Shape& shape, const std::vector<float>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
  Expr constant(const Shape& shape, const std::vector<uint32_t>& data) const { return graph()->constant(shape, inits::fromVector(data)); }
  template <typename T> Expr constant(const std::vector<T>& data) const { return constant(Shape{(int)data.size()}, data); } // same as constant() above, but deriving the Shape from the vector size
  Expr indices(const std::vector<uint32_t>& data) const { return graph()->indices(data); } // actually the same as constant(data) for this data type
  std::vector<float> getFactorMasks(size_t factorGroup, const std::vector<WordIndex>& indices) const;

private:
  // members
  // @TODO: we no longer use the RationalLoss component, so it can be removed again and replaced by just the Expr
  std::vector<Ptr<RationalLoss>> logits_; // [group id][B..., num factors in group]
  Ptr<FactoredVocab> factoredVocab_;
};

// Unary function that returns a Logits object.
// Also implements IUnaryLayer, since Logits can be cast to Expr.
// This interface is implemented by all layers that take the form of a unary function
// returning multiple logits, to support factors.
struct IUnaryLogitLayer : public IUnaryLayer {
  virtual Logits applyAsLogits(Expr) = 0;
  virtual Logits applyAsLogits(const std::vector<Expr>& es) {
    ABORT_IF(es.size() > 1, "Not implemented"); // simple stub
    return applyAsLogits(es.front());
  }
  virtual Expr apply(Expr e) override { return applyAsLogits(e).getLogits(); }
  virtual Expr apply(const std::vector<Expr>& es) override { return applyAsLogits(es).getLogits(); }
};

namespace mlp {

class Dense : public LayerBase, public IUnaryLayer {
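The new src/layers/logits.h itself is not shown in this commit view; presumably it carries over the block deleted above more or less verbatim. A sketch of its likely skeleton (the exact include list is an assumption):

#pragma once

#include "marian.h"
#include "layers/generic.h" // assumed: for IUnaryLayer, which IUnaryLogitLayer extends
#include "data/shortlist.h"

namespace marian {

class FactoredVocab;
class RationalLoss;

class Logits {
  // ... body exactly as deleted from generic.h above ...
};

struct IUnaryLogitLayer : public IUnaryLayer {
  // ... as deleted from generic.h above ...
};

} // namespace marian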
2 changes: 1 addition & 1 deletion src/layers/loss.h
@@ -1,7 +1,7 @@
#pragma once

#include "graph/expression_operators.h"
#include "layers/generic.h" // for Logits (Frank's factor hack)
#include "layers/logits.h" // for Logits (Frank's factor hack)
#include "data/types.h"

namespace marian {
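Loss code needs only the Logits wrapper from the old header, which is what makes the narrower include possible. For orientation, a hypothetical call through applyLossFunction from the class above (logits, labels, and the loss primitive are stand-ins):

// Hypothetical: Logits dispatches the loss per factor group, so the fully
// expanded set of factor combinations is never materialized.
Expr loss = logits.applyLossFunction(labels, [](Expr logitsGroup, Expr indices) {
  return cross_entropy(logitsGroup, indices); // assumed loss primitive
});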
1 change: 1 addition & 0 deletions src/layers/output.h
@@ -2,6 +2,7 @@

#include "marian.h"
#include "generic.h"
#include "logits.h"
#include "data/shortlist.h"
#include "layers/factory.h"

2 changes: 1 addition & 1 deletion src/models/states.h
@@ -1,7 +1,7 @@
#pragma once

#include "marian.h"
#include "layers/generic.h" // @HACK: for factored embeddings only so far
#include "layers/logits.h" // @HACK: for factored embeddings only so far
#include "rnn/types.h"

namespace marian {
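For orientation, a sketch of how a caller consumes a Logits object through the interfaces above (outputLayer, state, shortlist, hypIndices, and beamSize are hypothetical stand-ins):

// An output layer implementing IUnaryLogitLayer returns Logits, not Expr.
Logits logits = outputLayer->applyAsLogits(state);

// Non-factored path: collapse to a single expression over the vocabulary.
Expr scores = logits.getLogits();

// Factored path: score one factor group at a time, optionally restricted
// to a shortlist and reshuffled across beam hypotheses.
for(size_t g = 0; g < logits.getNumFactorGroups(); ++g) {
  Expr groupScores = logits.getFactoredLogits(g, shortlist, hypIndices, beamSize);
  // ... combine per-group scores without expanding all combinations
}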
