diff --git a/common/sampling.cpp b/common/sampling.cpp
index 1317024c2c11c..8a4b1f4384f6b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -166,7 +166,7 @@ llama_token llama_sampling_sample(
     }
 
     if (ctx_sampling->grammar != NULL) {
-        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar);
+        llama_sample_grammar(ctx_main, &cur_p, ctx_sampling->grammar, nullptr);
     }
 
     if (temp < 0.0) {
diff --git a/llama.cpp b/llama.cpp
index c5f4053f2fffb..0e87da50a4363 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6420,10 +6420,13 @@ struct llama_grammar_candidate {
 // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
 static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
         const char * src,
+        size_t n_src,
         llama_partial_utf8 partial_start) {
     static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
     const char * pos = src;
     std::vector<uint32_t> code_points;
+    // common English strings have the same number of code points as bytes. `+ 1` for the terminating 0.
+    code_points.reserve(n_src + 1);
     uint32_t value = partial_start.value;
     int n_remain = partial_start.n_remain;
 
@@ -6474,6 +6477,13 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
     return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
 }
 
+static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
+        const std::string & src,
+        llama_partial_utf8 partial_start
+) {
+    return decode_utf8(src.c_str(), src.size(), partial_start);
+}
+
 // returns true iff pos points to the end of one of the definitions of a rule
 static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
     switch (pos->type) {
@@ -7096,7 +7106,11 @@ void llama_sample_repetition_penalties(
     }
 }
 
-void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * candidates, const struct llama_grammar * grammar) {
+void llama_sample_grammar(
+        struct llama_context * ctx,
+        llama_token_data_array * candidates,
+        const struct llama_grammar * grammar,
+        char const * const * pieces) {
     GGML_ASSERT(ctx);
     const int64_t t_start_sample_us = ggml_time_us();
 
@@ -7115,7 +7129,14 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
 
     for (size_t i = 0; i < candidates->size; ++i) {
         const llama_token id = candidates->data[i].id;
-        const std::string piece = llama_token_to_piece(ctx, id);
+        std::string piece;
+
+        if (pieces != nullptr && pieces[id] != nullptr) {
+            piece = std::string(pieces[id]);
+        } else {
+            piece = llama_token_to_piece(ctx, id);
+        }
+
         if (id == eos) {
             if (!allow_eos) {
                 candidates->data[i].logit = -INFINITY;
@@ -7123,7 +7144,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
         } else if (piece.empty() || piece[0] == 0) {
             candidates->data[i].logit = -INFINITY;
         } else {
-            candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8));
+            candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
             candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
         }
     }
@@ -7330,7 +7351,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
     const std::string piece = llama_token_to_piece(ctx, token);
 
     // Note terminating 0 in decoded string
-    const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8);
+    const auto decoded = decode_utf8(piece, grammar->partial_utf8);
     const auto & code_points = decoded.first;
     for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
         grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
diff --git a/llama.h b/llama.h
index 1a62058d1406b..9d6e1599f5988 100644
--- a/llama.h
+++ b/llama.h
@@ -719,10 +719,13 @@ extern "C" {
             "use llama_sample_temp instead");
 
     /// @details Apply constraints from grammar
+    /// @param pieces an array of null-terminated strings, one per vocab token, as obtained by calling llama_token_to_piece for the whole vocab. Can be nullptr, in which case the pieces are computed internally.
     LLAMA_API void llama_sample_grammar(
               struct llama_context * ctx,
             llama_token_data_array * candidates,
-      const struct llama_grammar * grammar);
+      const struct llama_grammar * grammar,
+               char const * const * pieces);
+
     /// @details Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
     /// @param candidates A vector of `llama_token_data` containing the candidate tokens, their probabilities (p), and log-odds (logit) for the current position in the generated text.
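
For callers that sample repeatedly, the new `pieces` argument lets the vocab be detokenized once per model instead of once per candidate per call. Below is a minimal sketch (not part of this patch) of how a caller might build and use such a cache; `build_piece_cache` and `sample_grammar_cached` are illustrative names, and the `std::string llama_token_to_piece(ctx, token)` helper is assumed to be the one from common/common.h:

    #include <string>
    #include <vector>

    #include "common.h"
    #include "llama.h"

    // Detokenize the whole vocab once, after the model is loaded.
    static std::vector<std::string> build_piece_cache(llama_context * ctx) {
        const llama_model * model = llama_get_model(ctx);
        const int32_t n_vocab = llama_n_vocab(model);

        std::vector<std::string> pieces(n_vocab);
        for (llama_token id = 0; id < n_vocab; ++id) {
            pieces[id] = llama_token_to_piece(ctx, id); // std::string helper from common/common.h
        }
        return pieces;
    }

    // Hand the cached pieces to the grammar sampler so it does not have to
    // re-detokenize every candidate on every call.
    static void sample_grammar_cached(
            llama_context * ctx,
            llama_token_data_array * cur_p,
            const llama_grammar * grammar,
            const std::vector<std::string> & pieces) {
        std::vector<const char *> ptrs(pieces.size());
        for (size_t i = 0; i < pieces.size(); ++i) {
            ptrs[i] = pieces[i].c_str();
        }
        llama_sample_grammar(ctx, cur_p, grammar, ptrs.data());
    }

Passing nullptr instead, as common/sampling.cpp does above, keeps the old behavior of computing each piece inside the call.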