[API][CPP] Add overloaded func of model.config. (#45)
Signed-off-by: Duyi-Wang <[email protected]>
Duyi-Wang authored Nov 16, 2023
1 parent 5aa13c1 commit 6f5e981
Showing 3 changed files with 41 additions and 13 deletions.

4 changes: 2 additions & 2 deletions examples/cpp/example.cpp

@@ -218,9 +218,9 @@ int main(int argc, char **argv) {
     args.add<std::string>("dtype", 'd', "weight data type", false, "fp16",
             cmdline::oneof<std::string>("fp16", "bf16", "int8", "bf16_fp16", "bf16_int8"));
     args.add<int>("input_len", 'l', "input token size", false, -1);
-    args.add<int>("output_len", '\0', "max tokens can generate excluded input.", false, 100, cmdline::range(1, 4096));
+    args.add<int>("output_len", '\0', "max tokens can generate excluded input.", false, 100, cmdline::range(1, 8192));
     args.add<int>("num_beams", 'n', "number of beam size.", false, 1, cmdline::range(1, 32));
-    args.add<int>("batch_size", 'b', "batch size.", false, 1, cmdline::range(1, 32));
+    args.add<int>("batch_size", 'b', "batch size.", false, 1, cmdline::range(1, 512));
     args.add<int>("loop", '\0', "number of loop.", false, 10);
     args.add<int>("topK", '\0', "number of highest probability tokens to keep for top-k-filtering.", false, 50);
     args.add<float>("temperature", '\0', "value used to modulate the next token probabilities.", false, 1.0);

2 changes: 2 additions & 0 deletions include/models.h

@@ -33,6 +33,8 @@ class Model {
             bool doEarlyStopping_ = false, int eosTokenId_ = -1, int padTokenId_ = -1, bool doSample_ = false,
             float temperature_ = 1.0, int topK_ = 50, float topP_ = 1.0);
 
+    void config(SearcherConfig &config_);
+
     bool isDone();
 
     std::vector<int32_t> generate();
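
Note: the new overload lets a caller fill a SearcherConfig once and pass it directly, instead of threading every generation parameter through the positional config() overload above it. A minimal caller-side sketch (not part of this commit) follows; it only touches the two fields this diff actually reads (numBeams, doSample) — the header name, SearcherConfig's namespace, and any other member names are assumptions.

    // Sketch only: hypothetical caller of the overload added in this commit.
    #include "models.h"   // declares Model::config(SearcherConfig &) per the diff above

    using namespace xft;  // Model lives in namespace xft; SearcherConfig's namespace is assumed

    void configureBeamSearch(Model &model) {
        SearcherConfig cfg;   // definition comes from the searcher headers (assumption)
        cfg.numBeams = 2;     // > 1 makes createSearcher() build a BeamSearch
        cfg.doSample = false; // only consulted when numBeams == 1
        model.config(cfg);    // the overload added in this commit
    }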

48 changes: 37 additions & 11 deletions src/models/models.cpp

@@ -19,15 +19,32 @@
 #include <stdexcept>
 
 #include "INIReader.h"
+#include "baichuan.h"
 #include "chatglm.h"
 #include "chatglm2.h"
 #include "hybrid_model.h"
 #include "llama.h"
-#include "baichuan.h"
 #include "opt_decoder.h"
 #include "searcher.h"
 
 namespace xft {
+enum class GenerationMode { GREEDY_SEARCH, BEAM_SEARCH, SAMPLE };
+
+GenerationMode getGenerationMode(SearcherConfig &config_) {
+    if (config_.numBeams == 1) {
+        if (config_.doSample) {
+            return GenerationMode::SAMPLE;
+        } else {
+            return GenerationMode::GREEDY_SEARCH;
+        }
+    } else if (config_.numBeams > 1) {
+        return GenerationMode::BEAM_SEARCH;
+    } else {
+        printf("numBeams should greater than or equal to 1.\n");
+        exit(-1);
+    }
+}
+
 Model::~Model() {
     exitSlaves();
     if (decoder != nullptr) { delete decoder; }

@@ -84,6 +101,18 @@ void Model::config(int maxLen_, int numBeams_, int numBeamHypsToKeep_, float len
     createSearcher(configuration);
 }
 
+void Model::config(SearcherConfig &config_) {
+    isNewInput = true;
+    if (decoder->getRank() == 0) { configuration = config_; }
+    Messenger &messenger = decoder->getMessenger();
+    messenger.broadcast((int *)&configuration, sizeof(SearcherConfig) / sizeof(int));
+
+    // Slaves get exit flags and exit directly
+    if (decoder->getRank() > 0 && configuration.numBeams == 0) { exit(0); }
+
+    createSearcher(configuration);
+}
+
 bool Model::isDone() {
     if (searcher == nullptr || inputIds.empty()) {
         printf("Please set input and config first.\n");

@@ -112,17 +141,14 @@ std::vector<int32_t> Model::generate() {
 
 void Model::createSearcher(SearcherConfig &config_) {
     if (searcher != nullptr) { delete searcher; }
-    if (config_.numBeams < 1) {
-        printf("numBeams should greater than or equal to 1.\n");
-        exit(-1);
-    } else if (config_.numBeams == 1) {
-        if (config_.doSample) {
-            searcher = new SampleSearch(*decoder, config_);
-        } else {
-            searcher = new GreedySearch(*decoder, config_);
-        }
-    } else {
+
+    GenerationMode genMode = getGenerationMode(config_);
+    if (genMode == GenerationMode::GREEDY_SEARCH) {
+        searcher = new GreedySearch(*decoder, config_);
+    } else if (genMode == GenerationMode::BEAM_SEARCH) {
         searcher = new BeamSearch(*decoder, config_);
+    } else if (genMode == GenerationMode::SAMPLE) {
+        searcher = new SampleSearch(*decoder, config_);
     }
 }
 
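
The refactor in createSearcher replaces the nested numBeams/doSample branches with a single call to the new getGenerationMode helper, so the mode-selection rule is written once and reused. Restated as a standalone sketch (illustrative only — GenerationMode and getGenerationMode are file-local to models.cpp, and MiniConfig below stands in for the two SearcherConfig fields the rule reads):

    // Standalone restatement of the mode-selection rule added in this commit.
    #include <cstdio>
    #include <cstdlib>

    enum class GenerationMode { GREEDY_SEARCH, BEAM_SEARCH, SAMPLE };

    struct MiniConfig {
        int numBeams;
        bool doSample;
    };

    GenerationMode pickMode(const MiniConfig &c) {
        if (c.numBeams == 1) { return c.doSample ? GenerationMode::SAMPLE : GenerationMode::GREEDY_SEARCH; }
        if (c.numBeams > 1) { return GenerationMode::BEAM_SEARCH; }
        std::fprintf(stderr, "numBeams should be greater than or equal to 1.\n");
        std::exit(-1);
    }

    int main() {
        std::printf("%d\n", static_cast<int>(pickMode({1, false}))); // 0 -> GREEDY_SEARCH
        std::printf("%d\n", static_cast<int>(pickMode({1, true})));  // 2 -> SAMPLE
        std::printf("%d\n", static_cast<int>(pickMode({4, false}))); // 1 -> BEAM_SEARCH
        return 0;
    }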
