Token samplers for large language models, written in Rust!
Extremely early in development and poorly tested. See `src/tests.rs` for some examples of use.
There is also a fairly simple example of using Mirostat in my RWKV project: https://github.com/KerfuffleV2/smolrsrwkv/blob/60b8e8bfe64f157f1800445128af3b4adbbc64c1/smolrwkv-cli/src/main.rs#L139-L164
For notes on migrating from 0.0.6 to 0.0.7, see below.
The term "sampler" is used loosely here; perhaps it should be renamed in the future. Right now a "sampler" could manipulate the list of logits (for example, a top-k sampler might prune the list to the top K entries), actually pick a token, or both!
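To make the two roles concrete, here is a minimal std-only sketch. The `Candidate` type and the `top_k` and `greedy` functions are hypothetical illustrations, not this crate's API: `top_k` only manipulates the candidate list, while `greedy` only picks a token.

```rust
/// Hypothetical candidate type for this sketch: a token id and its logit.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    id: u32,
    logit: f32,
}

/// Manipulates the list: keeps only the `k` candidates with the highest logits.
fn top_k(cands: &mut Vec<Candidate>, k: usize) {
    // Sort by logit, highest first. total_cmp gives a total order for f32.
    cands.sort_by(|a, b| b.logit.total_cmp(&a.logit));
    cands.truncate(k);
}

/// Picks a token: returns the id of the highest-logit candidate, if any.
fn greedy(cands: &[Candidate]) -> Option<u32> {
    cands
        .iter()
        .max_by(|a, b| a.logit.total_cmp(&b.logit))
        .map(|c| c.id)
}

fn main() {
    let mut cands: Vec<Candidate> = [0.1f32, 0.4, 0.3, 0.2]
        .iter()
        .enumerate()
        .map(|(i, &logit)| Candidate { id: i as u32, logit })
        .collect();

    top_k(&mut cands, 2);
    // Only token 1 (0.4) and token 2 (0.3) survive, highest first.
    let ids: Vec<u32> = cands.iter().map(|c| c.id).collect();
    assert_eq!(ids, vec![1, 2]);

    // Greedy then picks from whatever is left.
    assert_eq!(greedy(&cands), Some(1));
}
```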
Available samplers:
- Flat bias - biases tokens by the specified amount
- Frequency / presence - applies frequency and presence penalties
- Greedy - picks the token ID with the highest probability
- Locally typical
- Mirostat V1
- Mirostat V2
- Random distribution - picks a token ID based on weighted probabilities
- Repetition - applies a repetition penalty
- Tail free
- Temperature
- Top-K
- Top-P
- Min-P
- Top-A
Real descriptions may (or may not) happen eventually. For now, you can check out the llama.cpp main
example README for a brief overview of some of these sampler types: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/README.md#generation-flags
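As a rough illustration of a few of the list-manipulating samplers, the sketch below implements temperature, Top-P and Min-P over a plain probability vector, following the commonly used llama.cpp-style definitions. The function names are hypothetical, and this crate's implementations may differ in details (for example, how many tokens are always kept).

```rust
/// Temperature: divide every logit by `t` (t < 1 sharpens, t > 1 flattens).
fn temperature(logits: &mut [f32], t: f32) {
    for l in logits.iter_mut() {
        *l /= t;
    }
}

/// Numerically stable softmax: subtract the max logit before exponentiating.
fn softmax(logits: &[f32]) -> Vec<f32> {
    let max = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&l| (l - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

/// Top-P: keep the smallest set of most probable tokens whose cumulative
/// probability reaches `p`. Returns the surviving token ids, most probable first.
fn top_p(probs: &[f32], p: f32) -> Vec<u32> {
    let mut order: Vec<usize> = (0..probs.len()).collect();
    order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));
    let mut kept = Vec::new();
    let mut cumulative = 0.0f32;
    for i in order {
        kept.push(i as u32);
        cumulative += probs[i];
        if cumulative >= p {
            break;
        }
    }
    kept
}

/// Min-P: keep tokens whose probability is at least `ratio` times the
/// probability of the most likely token. Returns ids in original order.
fn min_p(probs: &[f32], ratio: f32) -> Vec<u32> {
    let top = probs.iter().copied().fold(0.0f32, f32::max);
    (0..probs.len())
        .filter(|&i| probs[i] >= ratio * top)
        .map(|i| i as u32)
        .collect()
}

fn main() {
    // ln(1..=4) gives softmax probabilities of about [0.1, 0.2, 0.3, 0.4].
    let mut logits: Vec<f32> = [1.0f32, 2.0, 3.0, 4.0].iter().map(|x| x.ln()).collect();
    let probs = softmax(&logits);

    // 0.4 + 0.3 = 0.7 >= 0.6, so Top-P keeps tokens 3 and 2.
    assert_eq!(top_p(&probs, 0.6), vec![3, 2]);
    // The threshold is 0.6 * 0.4 = 0.24, so Min-P keeps tokens 2 and 3.
    assert_eq!(min_p(&probs, 0.6), vec![2, 3]);

    // Temperature 0.5 sharpens the distribution; token 3 is still the most likely.
    temperature(&mut logits, 0.5);
    let sharpened = softmax(&logits);
    assert!(sharpened[3] > 0.5);
}
```

The last assertion also hints at why a greedy sampler ignores temperature: dividing all logits by a positive temperature changes how peaked the distribution is, but not which token has the highest logit.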
You usually won't want to use individual [Sampler]s. The most typical
use case is chaining a number of samplers together.
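Conceptually, a chain runs each sampler in order over the same candidate list. The sketch below models that with a hypothetical `Stage` trait and hypothetical `FlatBias`, `Temperature`, `Greedy` and `Chain` types; it is not how [SamplerChain] is actually implemented, only a way to picture the data flow. As an assumption of this sketch, the chain returns the token chosen by the last stage that picked one.

```rust
/// Hypothetical stand-in for a sampler: it may rewrite the candidate list
/// (token id, logit) and may also pick a token.
trait Stage {
    fn apply(&mut self, cands: &mut Vec<(u32, f32)>) -> Option<u32>;
}

/// Adds a fixed bias to selected token ids (like a flat bias sampler).
struct FlatBias(Vec<(u32, f32)>);
impl Stage for FlatBias {
    fn apply(&mut self, cands: &mut Vec<(u32, f32)>) -> Option<u32> {
        for (id, logit) in cands.iter_mut() {
            for &(bias_id, bias) in &self.0 {
                if *id == bias_id {
                    *logit += bias;
                }
            }
        }
        None // only manipulates the list
    }
}

/// Divides every logit by a temperature.
struct Temperature(f32);
impl Stage for Temperature {
    fn apply(&mut self, cands: &mut Vec<(u32, f32)>) -> Option<u32> {
        for (_, logit) in cands.iter_mut() {
            *logit /= self.0;
        }
        None
    }
}

/// Picks the highest-logit token without changing the list.
struct Greedy;
impl Stage for Greedy {
    fn apply(&mut self, cands: &mut Vec<(u32, f32)>) -> Option<u32> {
        cands.iter().max_by(|a, b| a.1.total_cmp(&b.1)).map(|c| c.0)
    }
}

/// Runs each stage in order and remembers the most recent pick.
struct Chain(Vec<Box<dyn Stage>>);
impl Chain {
    fn sample(&mut self, cands: &mut Vec<(u32, f32)>) -> Option<u32> {
        let mut picked = None;
        for stage in self.0.iter_mut() {
            if let Some(id) = stage.apply(cands) {
                picked = Some(id);
            }
        }
        picked
    }
}

fn main() {
    let mut cands: Vec<(u32, f32)> = vec![(0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)];
    let stages: Vec<Box<dyn Stage>> = vec![
        Box::new(FlatBias(vec![(3, f32::NEG_INFINITY)])),
        Box::new(Temperature(0.8)),
        Box::new(Greedy),
    ];
    let mut chain = Chain(stages);
    // Token 3 is masked out with -inf, so greedy picks token 2.
    assert_eq!(chain.sample(&mut cands), Some(2));
}
```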
A simple example of constructing a [SamplerChain]:
```rust
use anyhow::Result;

use llm_samplers::prelude::*;

pub fn test_chain1() -> Result<()> {
    let mut logits = Logits::try_from_iter([0.1f32, 0.2, 0.3, 0.4].into_iter())?;

    // Demonstrating the different ways you can build a SamplerChain.
    // These are all equivalent.
    let mut sc = SamplerChain::new()
        + SampleFlatBias::new([(3, f32::NEG_INFINITY)]);
    sc += SampleTemperature::new(0.8);
    sc.push_sampler(SampleGreedy::new());

    // Token 3 is biased to -inf, so greedy picks token 2 (logit 0.3).
    assert_eq!(
        sc.sample_token(
            // These samplers don't actually need any resources.
            &mut NilSamplerResources::default(),
            &mut logits)?,
        Some(2)
    );

    // sample_token modifies `logits` in place, so start from fresh logits.
    let mut logits = Logits::try_from_iter([0.1f32, 0.2, 0.3, 0.4].into_iter())?;
    // () also implements HasSamplerResources
    // so you could use &mut () here.
    assert_eq!(sc.sample_token(&mut (), &mut logits)?, Some(2));
    Ok(())
}
```
The previous example is simple but not very realistic: the greedy sampler doesn't even care about temperature. Now let's look at something a bit more complicated:
```rust
use anyhow::Result;
use rand::{SeedableRng, rngs::StdRng};

use llm_samplers::prelude::*;

fn test_chain2() -> Result<()> {
    let example_logits = vec![0.1f32, 0.2, 0.3, 0.4];
    let mut res = SimpleSamplerResources::new(
        // Optionally include an RNG resource.
        Some(Box::new(StdRng::seed_from_u64(123))),
        // Optionally include a last tokens resource.
        Some(vec![]),
    );
    let mut logits = Logits::try_from_iter(example_logits.into_iter())?;
    let mut logits2 = logits.clone();

    let mut sc = SamplerChain::new()
        // Bias logits (this example sets bias for token id 3 to -inf)
        + SampleFlatBias::new([(3, f32::NEG_INFINITY)])
        // Apply a repetition penalty.
        + SampleRepetition::new(1.1, 64)
        // Apply frequency and presence penalties.
        + SampleFreqPresence::new(0.05, 0.1, 64)
        // Apply temperature to logits.
        + SampleTemperature::new(0.8)
        // Sample a token using Mirostat1
        + SampleMirostat1::new(4, 5.0, 0.1);

    // Put a value into `last_tokens`, this simulates us having already picked
    // that token (3) previously.
    res.with_last_tokens_mut(&mut |tokens| tokens.push(3u32))?;

    assert_eq!(sc.sample_token(&mut res, &mut logits)?, Some(2));

    // Now add the last selected token to the list.
    res.with_last_tokens_mut(&mut |tokens| tokens.push(2u32))?;

    // And pick the next one. *Important*: Note that we don't reuse `logits`.
    // This is because `logits` already has all the filtering/sorting/permutation
    // from the previous sample call applied to it.
    assert_eq!(sc.sample_token(&mut res, &mut logits2)?, Some(1));
    Ok(())
}
```
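The repetition and frequency/presence samplers in this example depend on the last tokens resource. The sketch below shows the penalty formulas as they are commonly described for llama.cpp: the repetition penalty divides positive logits by the penalty and multiplies non-positive ones by it, and the frequency/presence penalty subtracts `count * frequency + presence` from every token that appeared. The `repetition` and `freq_presence` functions and their signatures are hypothetical; it is an assumption, not confirmed here, that `SampleRepetition` and `SampleFreqPresence` behave roughly like this, with `last_n` playing the role of the `64` arguments above.

```rust
use std::collections::{HashMap, HashSet};

/// Repetition penalty (llama.cpp-style, assumed): each distinct token among the
/// last `last_n` tokens is penalized once. Positive logits shrink, others grow more negative.
fn repetition(logits: &mut [f32], last_tokens: &[u32], penalty: f32, last_n: usize) {
    let start = last_tokens.len().saturating_sub(last_n);
    let mut seen = HashSet::new();
    for &tid in &last_tokens[start..] {
        if !seen.insert(tid) {
            continue; // already penalized this token
        }
        if let Some(l) = logits.get_mut(tid as usize) {
            if *l <= 0.0 {
                *l *= penalty;
            } else {
                *l /= penalty;
            }
        }
    }
}

/// Frequency/presence penalties (assumed formula): for each token seen in the
/// window, subtract `count * freq` plus a flat `presence` amount.
fn freq_presence(logits: &mut [f32], last_tokens: &[u32], freq: f32, presence: f32, last_n: usize) {
    let start = last_tokens.len().saturating_sub(last_n);
    let mut counts: HashMap<u32, usize> = HashMap::new();
    for &tid in &last_tokens[start..] {
        *counts.entry(tid).or_insert(0) += 1;
    }
    for (&tid, &count) in &counts {
        if let Some(l) = logits.get_mut(tid as usize) {
            *l -= count as f32 * freq + presence;
        }
    }
}

fn main() {
    let mut logits = vec![1.0f32, 2.0, -1.0, 0.5];
    let last_tokens = [1u32, 2, 1];

    repetition(&mut logits, &last_tokens, 2.0, 64);
    // Token 1: 2.0 / 2.0 = 1.0. Token 2: -1.0 * 2.0 = -2.0.
    assert_eq!(logits, vec![1.0, 1.0, -2.0, 0.5]);

    freq_presence(&mut logits, &last_tokens, 0.5, 0.25, 64);
    // Token 1 appeared twice: 1.0 - (2 * 0.5 + 0.25) = -0.25.
    // Token 2 appeared once: -2.0 - (0.5 + 0.25) = -2.75.
    assert_eq!(logits, vec![1.0, -0.25, -2.75, 0.5]);
}
```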
Migrating from 0.0.6 to 0.0.7: unfortunately, 0.0.7 involved some breaking changes. Basically, the samplers and chains no longer take token ID and logit type parameters. You can have your token IDs in any color you like, as long as it's `u32`. Same for logits: they're always `f32` now.
For example, where previously you would have written `SampleRandDistrib::<u32>::new` or `SampleMirostat2::<u32, f32>::new`, you now only need `SampleRandDistrib::new` or `SampleMirostat2::new`. The same goes for creating chains: `SamplerChain::<u32, f32>::new` becomes just `SamplerChain::new`.
Note: Crate/docs version likely won't match this repo.
The initial version closely followed the samplers in the llama.cpp project (although it is not a line-by-line port). Thanks!