Add back const correctness
Pull Request: #22
georg3tom authored Oct 21, 2023
1 parent 32a589b commit d9ffd96
Showing 15 changed files with 172 additions and 148 deletions.
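
The commit restores the standard C++ const-correctness idiom across these files: member functions that do not mutate state become const-qualified, read-only parameters are taken by const reference, and accessors become const (or gain const overloads) so the classes stay usable through const references. A minimal sketch of the idiom, using a hypothetical Buffer class that is not part of slimt:

    #include <cstddef>
    #include <vector>

    class Buffer {
     public:
      // Const overload: callable through const Buffer&, returns a read-only view.
      const std::vector<int> &data() const { return data_; }
      // Non-const overload: selected for a mutable Buffer, allows modification.
      std::vector<int> &data() { return data_; }
      size_t size() const { return data_.size(); }  // never mutates

     private:
      std::vector<int> data_;
    };

    // Compiles only because data() and size() are const-qualified.
    int sum(const Buffer &buffer) {
      int total = 0;
      for (int value : buffer.data()) total += value;
      return total;
    }

This is the same shape as the changes below, e.g. Input::indices() and Model::processor() becoming const-qualified.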
10 changes: 5 additions & 5 deletions slimt/Frontend.cc
@@ -111,7 +111,7 @@ std::vector<Response> Blocking::translate(const Ptr<Model> &model,
     promise.set_value(std::move(response));
   };

-  auto &processor = model->processor();
+  const auto &processor = model->processor();
   auto [annotated, segments] =
       processor.process(std::move(source), config_.wrap_length);
   auto request = make_request(id_, model, cache_, std::move(annotated),
@@ -172,7 +172,7 @@ std::vector<Response> Blocking::pivot(const Ptr<Model> &first,
     response = std::move(combined);
   };

-  TextProcessor &processor = second->processor();
+  const TextProcessor &processor = second->processor();
   auto [annotated, segments] = processor.process(source_to_pivot.target);
   auto request = make_request(id_, second, cache_, std::move(annotated),
                               std::move(segments), continuation);
@@ -229,7 +229,7 @@ std::future<Response> Async::translate(const Ptr<Model> &model,
     promise->set_value(std::move(response));
   };

-  TextProcessor &processor = model->processor();
+  const TextProcessor &processor = model->processor();
   auto [annotated, segments] =
       processor.process(std::move(source), config_.wrap_length);
   auto request = make_request(id_, model, cache_, std::move(annotated),
@@ -270,7 +270,7 @@ std::future<Response> Async::pivot(const Ptr<Model> &first,
     promise.set_value(std::move(response));
   };

-  TextProcessor &processor = second->processor();
+  const TextProcessor &processor = second->processor();
   auto [annotated, segments] = processor.process(source_to_pivot.target);

   auto request =
@@ -280,7 +280,7 @@ std::future<Response> Async::pivot(const Ptr<Model> &first,
     batcher_.enqueue(second, request);
   };

-  TextProcessor &processor = first->processor();
+  const TextProcessor &processor = first->processor();
   auto [annotated, segments] =
       processor.process(std::move(source), config_.wrap_length);
   auto request = make_request(id_, first, cache_, std::move(annotated),
2 changes: 1 addition & 1 deletion slimt/Input.cc
@@ -15,7 +15,7 @@ Input::Input(size_t batch_size, size_t sequence_length, uint32_t pad_id,
       pad_id_(pad_id),
       limit_factor_(limit_factor) {}

-void Input::add(std::vector<uint32_t> &words) {
+void Input::add(const std::vector<uint32_t> &words) {
   size_t sequence_length = batch_.dim(-1);
   size_t batch_size = batch_.dim(-2);

8 changes: 4 additions & 4 deletions slimt/Input.hh
@@ -12,11 +12,11 @@ class Input {
   Input(size_t batch_size, size_t sequence_length, uint32_t pad_id,
         size_t limit_factor);

-  void add(std::vector<uint32_t> &words);
-  Tensor &indices() { return batch_; }
+  void add(const std::vector<uint32_t> &words);
+  const Tensor &indices() const { return batch_; }
   Tensor &mask() { return mask_; }
-  std::vector<uint32_t> &words() { return words_; }
-  std::vector<size_t> &lengths() { return lengths_; }
+  const std::vector<uint32_t> &words() const { return words_; }
+  const std::vector<size_t> &lengths() const { return lengths_; }
   size_t index() const { return index_; }
   float occupancy();
   float limit_factor() const;
8 changes: 4 additions & 4 deletions slimt/Model.cc
@@ -84,7 +84,7 @@ void update_alignment(const std::vector<size_t> &lengths,
   }
 }  // namespace

-Histories Model::decode(Tensor &encoder_out, Input &input) {
+Histories Model::decode(Tensor &encoder_out, Input &input) const {
   // Prepare a shortlist for the entire input.
   size_t batch_size = encoder_out.dim(-3);
   size_t source_sequence_length = encoder_out.dim(-2);
@@ -113,7 +113,7 @@ Histories Model::decode(Tensor &encoder_out, Input &input) {
   Sentences sentences(batch_size);
   Alignments alignments(sentences.size());

-  Decoder &decoder = transformer_.decoder();
+  const Decoder &decoder = transformer_.decoder();
   Words previous_slice = {};
   std::vector<Tensor> states = decoder.start_states(batch_size);
   auto [logits, attn] =
@@ -146,8 +146,8 @@ Histories Model::decode(Tensor &encoder_out, Input &input) {
   return histories;
 }

-Histories Model::forward(Input &input) {
-  Tensor &indices = input.indices();
+Histories Model::forward(Input &input) const {
+  const Tensor &indices = input.indices();
   Tensor &mask = input.mask();

   // uint64_t batch_size = indices.dim(-2);
16 changes: 9 additions & 7 deletions slimt/Model.hh
@@ -49,17 +49,19 @@ class Model {
   explicit Model(const Config &config, const Package<std::string> &package);
   explicit Model(const Config &config, const Package<View> &package);

-  Histories forward(Input &input);
+  Histories forward(Input &input) const;

-  Config &config() { return config_; }
-  Vocabulary &vocabulary() { return vocabulary_; }
-  TextProcessor &processor() { return processor_; }
-  Transformer &transformer() { return transformer_; }
+  const Config &config() const { return config_; }
+  const Vocabulary &vocabulary() const { return vocabulary_; }
+  const TextProcessor &processor() const { return processor_; }
+  const Transformer &transformer() const { return transformer_; }
   size_t id() const { return id_; }  // NOLINT
-  ShortlistGenerator &shortlist_generator() { return shortlist_generator_; }
+  const ShortlistGenerator &shortlist_generator() const {
+    return shortlist_generator_;
+  }

  private:
-  Histories decode(Tensor &encoder_out, Input &input);
+  Histories decode(Tensor &encoder_out, Input &input) const;

   size_t id_;
   Config config_;
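
Const-qualifying a method propagates: inside Model::forward(...) const, every data member (transformer_, vocabulary_, and so on) is seen as const, so anything it calls must itself be const-qualified. That is why decode, the Decoder accessor, and the Input accessors all change in the same commit. A simplified sketch of the mechanism, using hypothetical Engine/Core classes that are not part of slimt:

    class Core {
     public:
      // Must be const-qualified for Engine::forward() const to call it.
      int run() const { return 42; }
    };

    class Engine {
     public:
      int forward() const {
        // Inside a const member function, core_ has type `const Core`,
        // so only const member functions of Core are callable here.
        return core_.run();
      }

     private:
      Core core_;
    };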
29 changes: 16 additions & 13 deletions slimt/Modules.cc
@@ -9,9 +9,9 @@

 namespace slimt {

-float retrieve_quantization_multiplier(Tensor &W) {
-  auto *b_end = W.end<int8_t>();
-  float b_quant = *(reinterpret_cast<float *>(b_end));
+float retrieve_quantization_multiplier(const Tensor &W) {
+  const auto *b_end = W.end<int8_t>();
+  float b_quant = *(reinterpret_cast<const float *>(b_end));
   return b_quant;
 }

@@ -135,7 +135,8 @@ Tensor join_heads(Tensor &x) {
   return y;
 }

-Tensor affine(Affine &parameters, Tensor &x, const std::string &name = "") {
+Tensor affine(const Affine &parameters, Tensor &x,
+              const std::string &name = "") {
   Tensor y = qmm::affine(          //
       x,                           //
       parameters.W, parameters.b,  //
@@ -146,7 +147,7 @@ Tensor affine(Affine &parameters, Tensor &x, const std::string &name = "") {
   return y;
 }

-Tensor affine_with_select(Affine &parameters, Tensor &x,
+Tensor affine_with_select(const Affine &parameters, Tensor &x,
                           const std::vector<uint32_t> &indices,
                           const std::string &name /*= ""*/) {
   Tensor y = qmm::affine_with_select(  //
@@ -160,7 +161,8 @@ Tensor affine_with_select(Affine &parameters, Tensor &x,
   return y;
 }

-Tensor linear(Linear &parameters, Tensor &x, const std::string &name = "") {
+Tensor linear(const Linear &parameters, Tensor &x,
+              const std::string &name = "") {
   Tensor y = qmm::dot(                 //
       x, parameters.W,                 //
       parameters.quant.item<float>(),  //
@@ -170,15 +172,15 @@ Tensor linear(Linear &parameters, Tensor &x, const std::string &name = "") {
   return y;
 }

-Tensor SSRU::start_state(size_t batch_size) {
+Tensor SSRU::start_state(size_t batch_size) const {
   // auto start = graph->constant({1, 1, dimBatch, dim}, inits::zeros());
   size_t feature_dim = O_.W.dim(-1);
   Tensor start(Type::f32, Shape({batch_size, feature_dim}), "start");
   start.fill_in_place(0.0F);
   return start;
 }

-Tensor SSRU::forward(Tensor &state, Tensor &x) {
+Tensor SSRU::forward(Tensor &state, Tensor &x) const {
   // From Research to Production and Back: Ludicrously Fast Neural Machine
   // Translation (https://aclanthology.org/D19-5632.pdf) Section 3.1 describes
   // SSRU. SSRU is described by the following recurrent equations - which
@@ -235,7 +237,7 @@ Tensor SSRU::forward(Tensor &state, Tensor &x) {

 std::tuple<Tensor, Tensor> DecoderLayer::forward(Tensor &encoder_out,
                                                  Tensor &mask, Tensor &state,
-                                                 Tensor &x) {
+                                                 Tensor &x) const {
   Tensor decoder_out = rnn_.forward(state, x);

   // Assign query, key, value for cross-attention.
@@ -272,12 +274,12 @@ DecoderLayer::DecoderLayer(size_t depth, size_t ffn_count, size_t num_heads)

 FFN::FFN(size_t depth) : depth_(depth) {}

-Tensor FFN::forward(Tensor &x) {
+Tensor FFN::forward(Tensor &x) const {
   Tensor y = affine(O_, x, "ffn" + std::to_string(depth_));
   return y;
 }

-Tensor LayerNorm::forward(Tensor &x) {
+Tensor LayerNorm::forward(Tensor &x) const {
   Tensor y = x.like("ln_out");
   size_t cols = x.dim(-1);
   size_t rows = x.size() / cols;
@@ -299,7 +301,7 @@ Tensor LayerNorm::forward(Tensor &x) {
 }

 std::tuple<Tensor, Tensor> Attention::forward(Tensor &q, Tensor &k, Tensor &v,
-                                              Tensor &mask) {
+                                              Tensor &mask) const {
   // We have a B x T x H sequence comoing in, for q, k and v.
   Tensor yq = affine(Q_, q, "q");
   Tensor yk = affine(K_, k, "k");
@@ -331,7 +333,8 @@ std::tuple<Tensor, Tensor> Attention::forward(Tensor &q, Tensor &k, Tensor &v,
   return std::make_tuple(std::move(y), std::move(attn));
 }

-std::tuple<Tensor, Tensor> EncoderLayer::forward(Tensor &x, Tensor &mask) {
+std::tuple<Tensor, Tensor> EncoderLayer::forward(Tensor &x,
+                                                 Tensor &mask) const {
   // TODO(fill code):
   auto [out, attention] = attention_.forward(x, x, x, mask);

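
One detail worth noting in Modules.cc: retrieve_quantization_multiplier reads a float stored directly past the int8 weight payload, so W.end<int8_t>() points at the quantization scale. A standalone sketch of that layout assumption follows; it uses a raw byte vector instead of slimt's Tensor, and memcpy instead of reinterpret_cast, which sidesteps alignment and strict-aliasing concerns:

    #include <cstddef>
    #include <cstdint>
    #include <cstring>
    #include <vector>

    // Sketch: weights serialized as `weight_count` int8 values followed
    // immediately by a 4-byte float quantization scale.
    float scale_after_weights(const std::vector<int8_t> &blob,
                              size_t weight_count) {
      float scale = 0.0F;
      std::memcpy(&scale, blob.data() + weight_count, sizeof(float));
      return scale;
    }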
20 changes: 11 additions & 9 deletions slimt/Modules.hh
@@ -25,7 +25,7 @@ class LayerNorm {
  public:
   explicit LayerNorm() = default;
   void register_parameters(const std::string &prefix, ParameterMap &parameters);
-  Tensor forward(Tensor &x);
+  Tensor forward(Tensor &x) const;

  private:
   Tensor bias_;
@@ -37,7 +37,7 @@ class Attention {
   explicit Attention(std::string name, size_t num_heads);
   void register_parameters(const std::string &prefix, ParameterMap &parameters);
   std::tuple<Tensor, Tensor> forward(Tensor &q, Tensor &k, Tensor &v,
-                                     Tensor &mask);
+                                     Tensor &mask) const;

  private:
   std::string name_;
@@ -50,8 +50,8 @@ class SSRU {
  public:
   explicit SSRU() = default;
   void register_parameters(const std::string &prefix, ParameterMap &parameters);
-  Tensor forward(Tensor &state, Tensor &x);
-  Tensor start_state(size_t batch_size);
+  Tensor forward(Tensor &state, Tensor &x) const;
+  Tensor start_state(size_t batch_size) const;

  private:
   Affine F_;
@@ -63,7 +63,7 @@ class FFN {
  public:
   explicit FFN(size_t depth);
   void register_parameters(const std::string &prefix, ParameterMap &parameters);
-  Tensor forward(Tensor &x);
+  Tensor forward(Tensor &x) const;

  private:
   Affine O_;
@@ -74,7 +74,7 @@ class EncoderLayer {
  public:
   EncoderLayer(size_t depth, size_t ffn_count, size_t num_heads);
   void register_parameters(const std::string &prefix, ParameterMap &parameters);
-  std::tuple<Tensor, Tensor> forward(Tensor &x, Tensor &mask);
+  std::tuple<Tensor, Tensor> forward(Tensor &x, Tensor &mask) const;

  private:
   size_t depth_;
@@ -88,8 +88,10 @@ class DecoderLayer {
   explicit DecoderLayer(size_t depth, size_t ffn_count, size_t num_heads);
   void register_parameters(const std::string &prefix, ParameterMap &parameters);
   std::tuple<Tensor, Tensor> forward(Tensor &encoder_out, Tensor &mask,
-                                     Tensor &state, Tensor &x);
-  Tensor start_state(size_t batch_size) { return rnn_.start_state(batch_size); }
+                                     Tensor &state, Tensor &x) const;
+  Tensor start_state(size_t batch_size) const {
+    return rnn_.start_state(batch_size);
+  }

  private:
   size_t depth_;
@@ -99,7 +101,7 @@ class DecoderLayer {
   LayerNorm ffn_ffn_;
 };

-Tensor affine_with_select(Affine &parameters, Tensor &x,
+Tensor affine_with_select(const Affine &parameters, Tensor &x,
                           const std::vector<uint32_t> &indices,
                           const std::string &name = "");

(Diffs for the remaining 8 changed files did not load in this view.)