feat: minP support (#162)
giladgd authored Feb 11, 2024
1 parent 46235a2 commit 47b476f
Showing 7 changed files with 56 additions and 9 deletions.
8 changes: 7 additions & 1 deletion llama/addon.cpp
@@ -709,6 +709,7 @@ class AddonContextSampleTokenWorker : Napi::AsyncWorker, Napi::Promise::Deferred
bool use_grammar = false;
llama_token result;
float temperature = 0.0f;
float min_p = 0;
int32_t top_k = 40;
float top_p = 0.95f;
float repeat_penalty = 1.10f; // 1.0 = disabled
@@ -732,6 +733,10 @@ class AddonContextSampleTokenWorker : Napi::AsyncWorker, Napi::Promise::Deferred
temperature = options.Get("temperature").As<Napi::Number>().FloatValue();
}

if (options.Has("minP")) {
min_p = options.Get("minP").As<Napi::Number>().FloatValue();
}

if (options.Has("topK")) {
top_k = options.Get("topK").As<Napi::Number>().Int32Value();
}
@@ -833,6 +838,7 @@ class AddonContextSampleTokenWorker : Napi::AsyncWorker, Napi::Promise::Deferred
llama_sample_tail_free(ctx->ctx, &candidates_p, tfs_z, min_keep);
llama_sample_typical(ctx->ctx, &candidates_p, typical_p, min_keep);
llama_sample_top_p(ctx->ctx, &candidates_p, resolved_top_p, min_keep);
llama_sample_min_p(ctx->ctx, &candidates_p, min_p, min_keep);
llama_sample_temp(ctx->ctx, &candidates_p, temperature);
new_token_id = llama_sample_token(ctx->ctx, &candidates_p);
}
@@ -879,7 +885,7 @@ void addonCallJsLogCallback(
) {
bool called = false;

if (env != nullptr && callback != nullptr) {
if (env != nullptr && callback != nullptr && addonJsLoggerCallbackSet) {
try {
callback.Call({
Napi::Number::New(env, data->logLevelNumber),
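The native change forwards `min_p` to llama.cpp's `llama_sample_min_p`, inserted after the top-P step in the sampling chain above. Roughly speaking, min-p discards candidates whose probability is low relative to the most likely candidate. The TypeScript below is only an illustrative sketch of that rule, not the native code; the `TokenCandidate` shape and the `applyMinP` helper are invented for the example, and the real filtering happens inside llama.cpp:

type TokenCandidate = {token: number, probability: number};

// Illustrative only: keep a candidate only if its probability is at least
// `minP` times the probability of the single most likely candidate.
function applyMinP(candidates: TokenCandidate[], minP: number, minKeep: number = 1): TokenCandidate[] {
    if (minP <= 0 || candidates.length <= minKeep)
        return candidates;

    // sort by probability, highest first
    const sorted = [...candidates].sort((a, b) => b.probability - a.probability);

    // the cutoff is relative to the most likely candidate
    const threshold = sorted[0].probability * minP;
    const kept = sorted.filter((candidate) => candidate.probability >= threshold);

    // never drop below `minKeep` candidates
    return kept.length >= minKeep ? kept : sorted.slice(0, minKeep);
}

As the option descriptions later in this commit note, `minP` is only relevant when `temperature` is greater than `0`.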
1 change: 1 addition & 0 deletions package.json
@@ -89,6 +89,7 @@
"function-calling",
"embedding",
"temperature",
"minP",
"topK",
"topP",
"json-schema",
1 change: 1 addition & 0 deletions src/bindings/AddonTypes.ts
@@ -68,6 +68,7 @@ export type AddonContext = {
decodeBatch(): Promise<void>,
sampleToken(batchLogitIndex: BatchLogitIndex, options?: {
temperature?: number,
minP?: number,
topK?: number,
topP?: number,
repeatPenalty?: number,
17 changes: 13 additions & 4 deletions src/cli/commands/ChatCommand.ts
@@ -42,6 +42,7 @@ type ChatCommand = {
jsonSchemaGrammarFile?: string,
threads: number,
temperature: number,
minP: number,
topK: number,
topP: number,
gpuLayers?: number,
@@ -151,6 +152,13 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
description: "Temperature is a hyperparameter that controls the randomness of the generated text. It affects the probability distribution of the model's output tokens. A higher temperature (e.g., 1.5) makes the output more random and creative, while a lower temperature (e.g., 0.5) makes the output more focused, deterministic, and conservative. The suggested temperature is 0.8, which provides a balance between randomness and determinism. At the extreme, a temperature of 0 will always pick the most likely next token, leading to identical outputs in each run. Set to `0` to disable.",
group: "Optional:"
})
.option("minP", {
alias: "mp",
type: "number",
default: 0,
description: "From the next token candidates, discard the percentage of tokens with the lowest probability. For example, if set to `0.05`, 5% of the lowest probability tokens will be discarded. This is useful for generating more high-quality results when using a high temperature. Set to a value between `0` and `1` to enable. Only relevant when `temperature` is set to a value greater than `0`.",
group: "Optional:"
})
.option("topK", {
alias: "k",
type: "number",
@@ -243,15 +251,15 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
async handler({
model, systemInfo, systemPrompt, systemPromptFile, prompt,
promptFile, wrapper, contextSize, batchSize,
grammar, jsonSchemaGrammarFile, threads, temperature, topK, topP,
gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
environmentFunctions, noInfoLog, printTimings
}) {
try {
await RunChat({
model, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, contextSize, batchSize,
grammar, jsonSchemaGrammarFile, threads, temperature, topK, topP, gpuLayers, lastTokensRepeatPenalty,
grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
noHistory, environmentFunctions, noInfoLog, printTimings
});
@@ -265,7 +273,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {

async function RunChat({
model: modelArg, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, contextSize, batchSize,
grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature, topK, topP, gpuLayers,
grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, noHistory, environmentFunctions, noInfoLog, printTimings
}: ChatCommand) {
@@ -425,6 +433,7 @@ async function RunChat({
await session.prompt(input, {
grammar: grammar as undefined, // this is a workaround to allow passing both `functions` and `grammar`
temperature,
minP,
topK,
topP,
repeatPenalty: {
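With `minP` now threaded from the CLI through `RunChat` into `session.prompt`, the flag can be combined with a higher temperature from the command line. As a purely illustrative invocation (model path and other values are placeholders), something like `npx node-llama-cpp chat --model ./models/my-model.gguf --temperature 0.8 --minP 0.05` would enable min-p filtering for the chat, while the default of `0` leaves it disabled.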
14 changes: 13 additions & 1 deletion src/evaluator/LlamaChat/LlamaChat.ts
@@ -45,6 +45,17 @@ export type LLamaChatGenerateResponseOptions<Functions extends ChatModelFunction
*/
temperature?: number,

/**
* From the next token candidates, discard the percentage of tokens with the lowest probability.
* For example, if set to `0.05`, 5% of the lowest probability tokens will be discarded.
* This is useful for generating more high-quality results when using a high temperature.
* Set to a value between `0` and `1` to enable.
*
* Only relevant when `temperature` is set to a value greater than `0`.
* Disabled by default.
*/
minP?: number,

/**
* Limits the model to consider only the K most likely next tokens for sampling at each step of sequence generation.
* An integer number between `1` and the size of the vocabulary.
@@ -260,6 +271,7 @@ export class LlamaChat {
signal,
maxTokens,
temperature,
minP,
topK,
topP,
grammar,
@@ -535,7 +547,7 @@ export class LlamaChat {


const evaluationIterator = this._sequence.evaluate(tokens, removeNullFields({
temperature, topK, topP,
temperature, minP, topK, topP,
grammarEvaluationState: () => {
if (inFunctionEvaluationMode)
return functionsEvaluationState;
16 changes: 15 additions & 1 deletion src/evaluator/LlamaChatSession/LlamaChatSession.ts
@@ -55,6 +55,17 @@ export type LLamaChatPromptOptions<Functions extends ChatSessionModelFunctions |
*/
temperature?: number,

/**
* From the next token candidates, discard the percentage of tokens with the lowest probability.
* For example, if set to `0.05`, 5% of the lowest probability tokens will be discarded.
* This is useful for generating more high-quality results when using a high temperature.
* Set to a value between `0` and `1` to enable.
*
* Only relevant when `temperature` is set to a value greater than `0`.
* Disabled by default.
*/
minP?: number,

/**
* Limits the model to consider only the K most likely next tokens for sampling at each step of sequence generation.
* An integer number between `1` and the size of the vocabulary.
@@ -233,6 +244,7 @@ export class LlamaChatSession {
signal,
maxTokens,
temperature,
minP,
topK,
topP,
grammar,
@@ -244,7 +256,7 @@ export class LlamaChatSession {
functions: functions as undefined,
documentFunctionParams: documentFunctionParams as undefined,

onToken, signal, maxTokens, temperature, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty
onToken, signal, maxTokens, temperature, minP, topK, topP, grammar, trimWhitespaceSuffix, repeatPenalty
});

return responseText;
@@ -261,6 +273,7 @@ export class LlamaChatSession {
signal,
maxTokens,
temperature,
minP,
topK,
topP,
grammar,
@@ -309,6 +322,7 @@ export class LlamaChatSession {
onToken,
signal,
repeatPenalty,
minP,
topK,
topP,
maxTokens,
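At the chat-session level the new option rides alongside the existing sampling options passed to `prompt()`. A minimal sketch of a call using it (assuming `session` is an already-constructed `LlamaChatSession`; the prompt text and values are placeholders):

// `session` is assumed to be an existing LlamaChatSession instance
const response = await session.prompt("Summarize the plot of Hamlet in two sentences.", {
    temperature: 0.8, // a higher temperature is where min-p filtering helps most
    minP: 0.05,       // discard low-probability candidates
    topK: 40,
    topP: 0.95
});

console.log(response);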
8 changes: 6 additions & 2 deletions src/evaluator/LlamaContext/LlamaContext.ts
@@ -610,6 +610,7 @@ export class LlamaContextSequence {
*/
public evaluate(tokens: Token[], {
temperature = 0,
minP = 0,
topK = 40,
topP = 0.95,
grammarEvaluationState,
@@ -621,7 +622,7 @@
} = {},
yieldEosToken = false
}: {
temperature?: number, topK?: number, topP?: number,
temperature?: number, minP?: number, topK?: number, topP?: number,
grammarEvaluationState?: LlamaGrammarEvaluationState | (() => LlamaGrammarEvaluationState | undefined),
repeatPenalty?: LlamaContextSequenceRepeatPenalty,

@@ -648,6 +649,7 @@
} = {}): AsyncGenerator<Token, void> {
return this._evaluate(tokens, {
temperature,
minP,
topK,
topP,
grammarEvaluationState,
@@ -707,6 +709,7 @@ export class LlamaContextSequence {
/** @internal */
private async *_evaluate(tokens: Token[], {
temperature = 0,
minP = 0,
topK = 40,
topP = 0.95,
grammarEvaluationState,
@@ -716,7 +719,7 @@
contextShiftOptions,
yieldEosToken = false
}: {
temperature?: number, topK?: number, topP?: number,
temperature?: number, minP?: number, topK?: number, topP?: number,
grammarEvaluationState?: LlamaGrammarEvaluationState | (() => LlamaGrammarEvaluationState | undefined),
repeatPenalty?: LlamaContextSequenceRepeatPenalty, evaluationPriority?: EvaluationPriority,
generateNewTokens?: boolean, contextShiftOptions: Required<ContextShiftOptions>, yieldEosToken?: boolean
@@ -752,6 +755,7 @@

return this._context._ctx.sampleToken(batchLogitIndex, removeNullFields({
temperature,
minP,
topK,
topP,
repeatPenalty: repeatPenalty?.penalty,
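At the lowest level, `LlamaContextSequence.evaluate()` accepts `minP` and forwards it to the native `sampleToken()` binding via `removeNullFields`. A minimal sketch of driving the generator directly (assuming `sequence` is an existing `LlamaContextSequence` and `tokens` is an already-tokenized prompt):

// `sequence` and `tokens` are assumed to already exist
let generatedTokenCount = 0;

for await (const token of sequence.evaluate(tokens, {
    temperature: 0.8,
    minP: 0.05,
    topK: 40,
    topP: 0.95
})) {
    console.log(token); // each yielded value is a sampled token

    if (++generatedTokenCount >= 100) // stop after 100 generated tokens
        break;
}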
