llama : support Mamba Selective State Space Models (ggerganov#5328)
* mamba : begin working on support for Mamba SSM

* mamba : begin figuring out how to (ab)use the kv cache for Mamba

* mamba : recurrent inference almost works, but incoherent

* mamba : recurrent inference WORKS!!!

* convert : optionally use d_conv and d_state from config.json for Mamba

* mamba : refactor recurrent conv, resulting in 20% perf increase

It's still slower than I'd like, but I did not really optimize `ggml_exp` yet.

I also refactored `ggml_exp` to work with tensors with more than 2 dimensions.

* ggml : parallelize ggml_exp

This results in 8% faster token generation for Mamba-130M.

* mamba : simplify the conv step with a self-overlapping view

Turns out the conv_state can be made smaller by one column.
Note that this breaks existing GGUFs of Mamba,
because the key_value_length field is tied to the conv_state size.

Convolution with a self-overlapping view is cool!
And it's much simpler than what I initially thought would be necessary
to make the convolution step work with more than 1 token at a time.

Next step is to make the SSM step work on batches of tokens too,
and thus I need to figure out a way to make a parallel selective scan
which will keep the ssm_state small and won't make it bigger
by a factor of (n_layer * batch_size).
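
For illustration, here is a minimal plain-C++ sketch of what this conv step computes for a single token, assuming a rolling conv_state of d_conv - 1 previous columns per channel (the names, memory layout, and the omitted bias/activation are simplifications for illustration, not the actual ggml implementation):

```cpp
#include <cstddef>
#include <vector>

// x:          d_inner new values (one column) for the current token
// conv_state: d_inner rows of (d_conv - 1) previous columns, oldest first
// conv_w:     d_inner rows of d_conv depthwise convolution weights
// assumes d_conv >= 2 (Mamba uses d_conv = 4)
static std::vector<float> conv_step(std::vector<float> & conv_state,
                                    const std::vector<float> & x,
                                    const std::vector<float> & conv_w,
                                    size_t d_inner, size_t d_conv) {
    std::vector<float> y(d_inner, 0.0f);
    for (size_t c = 0; c < d_inner; ++c) {
        // convolve the window [previous columns..., new column];
        // in the graph, consecutive windows are expressed as a self-overlapping view
        for (size_t k = 0; k + 1 < d_conv; ++k) {
            y[c] += conv_w[c*d_conv + k] * conv_state[c*(d_conv - 1) + k];
        }
        y[c] += conv_w[c*d_conv + d_conv - 1] * x[c];
        // slide the state window: drop the oldest column, append the new one
        for (size_t k = 0; k + 2 < d_conv; ++k) {
            conv_state[c*(d_conv - 1) + k] = conv_state[c*(d_conv - 1) + k + 1];
        }
        conv_state[c*(d_conv - 1) + d_conv - 2] = x[c];
    }
    return y;
}
```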

* llama : fix Mamba KV self size wrongly displaying as f16 instead of f32

Relatedly, I also tried to see if types other than f32 worked for the states,
but they don't, because of the operators used.
It's probably better to keep lots of precision there anyway,
since the states are small.

* mamba : fix self-overlapping view depth stride

* mamba : handle batches of more than 1 token

This means running Mamba no longer crashes when using the default settings!
And probably also slightly faster prompt processing.
Both batched and non-batched processing yield the same output.

Previously, the state was not cleared when starting a sequence.
Next step is to make the KV cache API work as expected for Mamba models.

* ggml : add ggml_ssm_scan to help with parallel selective scan

If the selective scan was implemented without a custom operator,
there would be waaay too many nodes in the graph. For example,
for Mamba-130M, with a batch size of 512 (the default),
a naive selective scan could add at least 24*512=12288 nodes,
which is more than LLAMA_MAX_NODES (8192),
and that's only for the smallest Mamba model.
So it's much cleaner with a custom operator.
Not sure about the name, though.
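
For reference, this is roughly the per-token recurrence the fused operator has to evaluate, written as a standalone plain-C++ sketch (tensor layouts and names are assumptions; the real ggml_ssm_scan processes whole batches and differs in its details):

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// One token of the selective scan, for one sequence.
// s: d_inner * d_state running state    x, dt, y: d_inner
// A: d_inner * d_state                  B, C: d_state (input-dependent)
// The skip connection D*x and the gating with z happen outside this scan.
static void ssm_scan_step(std::vector<float> & s,
                          const std::vector<float> & x, const std::vector<float> & dt,
                          const std::vector<float> & A,
                          const std::vector<float> & B, const std::vector<float> & C,
                          std::vector<float> & y, size_t d_inner, size_t d_state) {
    for (size_t i = 0; i < d_inner; ++i) {
        float acc = 0.0f;
        for (size_t j = 0; j < d_state; ++j) {
            const float dA = std::exp(dt[i] * A[i*d_state + j]); // discretized state transition
            const float dB = dt[i] * B[j];                       // discretized input projection
            s[i*d_state + j] = s[i*d_state + j]*dA + dB*x[i];    // update the recurrent state
            acc += s[i*d_state + j] * C[j];                      // contract the state with C
        }
        y[i] = acc;
    }
}
```

Evaluating this once per token with separate graph nodes is what multiplies the node count by the batch size; fusing the whole loop over the batch into a single ggml_ssm_scan node per layer keeps the graph small.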

* ggml : in ggml_ssm_scan, merge multiple rows in the same vec operation

This will help with performance on CPU if ggml_vec_mul_f32
and ggml_vec_add_f32 are ever optimized with SIMD.

* mamba : very basic quantization support

Mostly works, but there is currently no difference
between the variants of a k-quant (e.g. Q4_K_S and Q4_K_M are the same).
Most of the SSM-specific weights can be kept in f32 without affecting
the size that much, since they are relatively small.
(the linear projection weights are responsible for most of Mamba's size)

Too much quantization seems to make the state degrade quite fast, and
the model begins to output gibberish.
It seems to affect bigger models to a lesser extent than small models,
but I'm not sure by how much.

Experimentation will be needed to figure out which weights are more important
for the _M (and _L?) variants of k-quants for Mamba.

* convert : fix wrong name for layer norm weight of official Mamba models

I was using Q-bert/Mamba-* models before, which have a slightly different
naming scheme for the weights.
(they start with "model.layers" instead of "backbone.layers")

* mamba : fuse more steps of the SSM scan in the ggml_ssm_scan operator

This increases performance on CPU by around 30% for prompt processing,
and by around 20% for text generation.

However, it also makes the ggml_exp and ggml_soft_plus operators unused.
Whether or not they should be kept will be decided later.

* convert : for Mamba, also consider the "MambaLMHeadModel" arch name

It's the name of the class of the official implementation,
though they don't use it (yet) in the "architectures" field of config.json

* mamba : fix vocab size problems with official models

The perplexity was waaaay too high for models with a non-round vocab size.
Not sure why, but it needed to be fixed in the metadata.

Note that this breaks existing GGUF-converted Mamba models,
but **only if** the vocab size was not already rounded.

* ggml : remove ggml_exp and ggml_soft_plus

They did not exist anyway outside of this branch,
and since ggml_ssm_scan fused operations together, they are unused.
It's always possible to bring them back if needed.

* mamba : remove some useless comments

No code change.

* convert : fix flake8 linter errors

* mamba : apply suggestions from code review

* mamba : remove unnecessary branch for row-wise ssm_state and C multiplication

It was previously done to avoid permuting when only one token is processed
at a time (like when generating text), but permuting is cheap,
and dynamically changing the compute graph is not future-proof.

* ggml : in ggml_ssm_scan, use more appropriate asserts

* ggml : rename the destination pointer in ggml_compute_forward_ssm_scan_f32

* mamba : multiple sequences, but one at a time

This is a step towards making this Mamba implementation usable
with the server example (the way the system prompt is kept when clearing
the client slots will need to be changed before this can work, though).

The KV cache size for this kind of model is tied to the maximum number
of sequences kept at any single time.
For now, this number is obtained from n_parallel (plus one,
to have an extra sequence to dedicate to the system prompt),
but there might be a better way to do this which won't also
make the main example use 2 cells even if only 1 is really used.
(for this specific case, --parallel 0 helps)

Simultaneous sequence processing will probably require changes to
ggml_ssm_scan, and possibly a new operator for the conv step.

* mamba : support llama_kv_cache_seq_cp

This (mis)uses the logic around K shifts, because tokens in a state
can't be shifted anyway, and because inp_K_shift has the right shape and type.
Using ggml_get_rows is a nice way to do copies, but copy chains can't work.
Fortunately, copy chains don't really seem to be used in the examples.

Each KV cell is dedicated to the sequence ID corresponding to its own index.

* mamba : use a state mask

It's cleaner than the previous heuristic of
checking for the pos of the first token in the batch.

inp_KQ_mask could not be re-used for this, because it has the wrong shape
and because it seems more suited to the next step of
simultaneous sequence processing (helping with the problem of
remembering which token belongs to which sequence(s)/state(s)).
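
Conceptually, the state mask is just a per-cell 0/1 factor multiplied into the recurrent states before they are used, so cells that start a fresh sequence begin from a cleared state. A minimal sketch, with names and layout assumed for illustration:

```cpp
#include <cstddef>
#include <vector>

// states: n_kv cells, each holding state_size floats
// mask:   n_kv entries, 0.0f to clear a cell's state, 1.0f to keep it
static void apply_state_mask(std::vector<float> & states,
                             const std::vector<float> & mask, size_t state_size) {
    for (size_t cell = 0; cell < mask.size(); ++cell) {
        for (size_t k = 0; k < state_size; ++k) {
            states[cell*state_size + k] *= mask[cell];
        }
    }
}
```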

* llama : replace the usage of n_ctx with kv_self.size in many places

* mamba : use n_tokens directly instead of n_tok

* mamba : in comments, properly refer to KV cells instead of slots

* mamba : reduce memory usage of ggml_ssm_scan

From 290.37 MiB to 140.68 MiB of CPU compute buffer size
with Mamba 3B with a batch size of 512.

The result tensor of ggml_ssm_scan was previously a big part
of the CPU compute buffer size. To make it smaller,
it does not contain the intermediate ssm states anymore.
Both y and the last ssm state are combined in the result tensor,
because it seems only a single tensor can be returned by an operator
with the way the graph is built.
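
Roughly, the idea is to pack both logical outputs into one flat result buffer and let the caller view each part separately; a sketch under assumed sizes and ordering (illustrative only, not the actual tensor layout):

```cpp
#include <cstddef>
#include <vector>

// One flat buffer holding both outputs of the scan:
//   [ y for all n_tokens tokens | final ssm state ]
struct ssm_scan_result_view {
    const float * y;          // n_tokens * d_inner values
    const float * last_state; // d_inner * d_state values
};

static ssm_scan_result_view view_scan_result(const std::vector<float> & result,
                                             size_t n_tokens, size_t d_inner) {
    ssm_scan_result_view v;
    v.y          = result.data();                     // per-token outputs come first
    v.last_state = result.data() + n_tokens*d_inner;  // final state is appended at the end
    return v;
}
```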

* mamba : simultaneous sequence processing

A batch can now contain tokens from multiple sequences.

This is necessary for at least the parallel example, the server example,
and the HellaSwag test in the perplexity example.

However, for this to be useful, uses of llama_kv_cache_seq_rm/cp
will need to be changed to work on whole sequences.

* ggml : add ggml_ssm_conv as a new operator for the conv step of Mamba

This operator makes it possible to use and update the correct states
for each token of the batch in the same way as ggml_ssm_scan.
Other solutions which use existing operators would need loops which would
add too many nodes to the graph (at least the ones I thought of).

Using this operator further reduces the size of the CPU compute buffer
from 140.68 MiB to 103.20 MiB with Mamba 3B with a batch size of 512.
And (at least on CPU), it's a bit faster than before.

Note that "ggml_ssm_conv" is probably not the most appropriate name,
and it could be changed if a better one is found.

* llama : add inp_s_seq as a new input tensor

The most convenient implementation to select the correct state (for Mamba)
for each token is to directly get the correct index from a tensor.
This is why inp_s_seq is storing int32_t and not floats.

The other, less convenient way to select the correct state would be
to have inp_KQ_mask contain 1.0f for each state used by a token
and 0.0f otherwise. This complicates quickly fetching the first used
state of a token, and is also less efficient because a whole row
of the mask would always need to be read for each token.

Using indexes makes it easy to stop searching when there are
no more sequences for a token, and the first sequence assigned
is always very quickly available (it's the first element of each row).
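
A rough sketch of why integer rows make this cheap: the first assigned state is simply row[0], and reading a token's sequences stops at the first pad value (the names and the -1 padding value are illustrative assumptions):

```cpp
#include <cstdint>
#include <vector>

// s_seq_row:   the row of inp_s_seq for one token (n_kv entries, padded with -1)
// cell_states: one recurrent state per KV cell
// new_state:   the updated state produced for this token
static void scatter_token_state(const std::vector<int32_t> & s_seq_row,
                                std::vector<std::vector<float>> & cell_states,
                                const std::vector<float> & new_state) {
    // the first used state is always s_seq_row[0]; reading stops at the first -1
    for (const int32_t cell : s_seq_row) {
        if (cell < 0) {
            break; // no more sequences use this token
        }
        cell_states[cell] = new_state;
    }
}
```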

* mamba : support llama_kv_cache_seq_cp copy chains

* mamba : support shifting and dividing the kv cache pos

* mamba : make the server and parallel examples work with whole sequences

A seq_id is dedicated to the system prompt in both cases.

* llama : make llama_kv_cache_seq_rm return whether it succeeded or not

* mamba : dedicate an input tensor for state copy indices

This is cleaner and makes it easier to adapt when/if token positions
(and by extension, inp_K_shift) are no longer integers.

* mamba : adapt perplexity, batched, and batched-bench examples

* perplexity : limit the max number of sequences

This adapts to what the loaded model can provide.

* llama : add llama_n_max_seq to get the upper limit for seq_ids

Used by the perplexity example.

* batched : pass n_parallel to the model's context params

This should have been there already, but it wasn't.

* batched-bench : reserve sequences to support Mamba

* batched-bench : fix tokens being put in wrong sequences

Generation quality isn't what's measured there anyway,
but at least using the correct sequences avoids non-consecutive
token positions.

* mamba : stop abusing attention metadata

This breaks existing converted-to-GGUF Mamba models,
but will allow supporting mixed architectures like MambaFormer
without needing to break Mamba models.

This will also allow changing the size of Mamba's states
without having to reconvert models in the future.
(e.g. using something else than d_conv - 1 columns for the conv_states
 will not require breaking existing converted Mamba models again)

* gguf-py : add new KV metadata key-value pairs for Mamba

* llama : add new metadata key-value pairs for Mamba

* llama : guard against divisions by zero when n_head is 0

* mamba : rename "unlimited" KV cache property to "recurrent"

* mamba : more correctly update the "used" field of the KV cache

* ggml : in ggml_ssm_scan, use a threshold for soft_plus

This is how the official Mamba implementation does it,
and it's also what torch.nn.Softplus does.
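
In other words, softplus falls back to the identity once its argument is large enough that log(1 + exp(x)) equals x in float precision; a small sketch (the threshold of 20 matches torch.nn.Softplus's default, but treat the exact cutoff used here as an assumption):

```cpp
#include <cmath>

// softplus(x) = log(1 + exp(x)), computed as x itself once x is large enough
// that the two are indistinguishable in float precision (also avoids overflow in exp)
static float soft_plus(float x, float threshold = 20.0f) {
    return x > threshold ? x : std::log1p(std::exp(x));
}
```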

* convert : for Mamba, fallback to internal NeoX tokenizer

The resulting models are exactly the same
as if the tokenizer.json and tokenizer_config.json of GPT-NeoX were there.

* mamba : support state saving and restoring

* ggml : implicitly pass src tensors through dst for Mamba-related ops

* mamba : clarify some comments

* server : fix cache_tokens not getting correctly resized

Otherwise, when the "we have to evaluate at least 1 token" special case
was triggered, an extra token was kept in cache_tokens even if it was
removed from the KV cache.

For Mamba, this caused useless prompt reprocessing when the previous
request triggered the above case.

* convert-hf : support new metadata keys for Mamba

For the models available at
https://huggingface.co/collections/state-spaces/transformers-compatible-mamba-65e7b40ab87e5297e45ae406

* mamba : rename metadata to be more similar to transformers library

This breaks existing converted-to-GGUF models,
but the metadata names are more "standard".

* mamba : support mamba-*-hf models

These models share their token_embd.weight with their output.weight

* mamba : add missing spaces

This is purely a formatting change.

* convert-hf : omit output.weight when identical with token_embd.weight

Only for Mamba for now, but it might be relevant for other models eventually.
Most Mamba models actually share these two tensors, albeit implicitly.

* readme : add Mamba to supported models, and add recent API changes

* mamba : move state_seq and state_mask views outside layer loop

A few tensors were also missing `struct` in front of `ggml_tensor`.
compilade authored and jordankanter committed Mar 13, 2024
1 parent 48e39de commit 5c6dc3d
Showing 15 changed files with 1,342 additions and 74 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -10,6 +10,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)

### Recent API changes

- [2024 Mar 8] `llama_kv_cache_seq_rm()` returns a `bool` instead of `void`, and new `llama_n_max_seq()` returns the upper limit of acceptable `seq_id` in batches (relevant when dealing with multiple sequences) https://github.com/ggerganov/llama.cpp/pull/5328
- [2024 Mar 4] Embeddings API updated https://github.com/ggerganov/llama.cpp/pull/5796
- [2024 Mar 3] `struct llama_context_params` https://github.com/ggerganov/llama.cpp/pull/5849

@@ -110,6 +111,7 @@ Typically finetunes of the base models below are supported as well.
- [x] [InternLM2](https://huggingface.co/models?search=internlm2)
- [x] [CodeShell](https://github.com/WisdomShell/codeshell)
- [x] [Gemma](https://ai.google.dev/gemma)
- [x] [Mamba](https://github.com/state-spaces/mamba)

**Multimodal models:**

1 change: 1 addition & 0 deletions common/common.cpp
@@ -1288,6 +1288,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param

cparams.n_ctx = params.n_ctx;
cparams.n_batch = params.n_batch;
cparams.n_parallel = params.n_parallel;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
cparams.seed = params.seed;
118 changes: 118 additions & 0 deletions convert-hf-to-gguf.py
@@ -1847,6 +1847,124 @@ class StarCoder2Model(Model):
model_arch = gguf.MODEL_ARCH.STARCODER2


@Model.register("MambaForCausalLM", "MambaLMHeadModel")
class MambaModel(Model):
model_arch = gguf.MODEL_ARCH.MAMBA

def set_vocab(self):
vocab_size = self.hparams["vocab_size"]
# Round vocab size to next multiple of 8
pad_vocab = self.hparams.get("pad_vocab_size_multiple", 8)
# pad using ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
self.hparams["vocab_size"] = vocab_size

if (self.dir_model / "tokenizer.json").is_file():
self._set_vocab_gpt2()
else:
# Use the GPT-NeoX tokenizer when no tokenizer files are present
tokenizer_path = Path(sys.path[0]) / "models" / "ggml-vocab-gpt-neox.gguf"
print(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
neox_reader = gguf.GGUFReader(tokenizer_path, "r")

field = neox_reader.get_field(gguf.Keys.Tokenizer.MODEL)
self.gguf_writer.add_tokenizer_model(bytes(field.parts[-1]))
field = neox_reader.get_field(gguf.Keys.Tokenizer.LIST)
self.gguf_writer.add_token_list([bytes(field.parts[i]) for i in field.data][:vocab_size])
field = neox_reader.get_field(gguf.Keys.Tokenizer.TOKEN_TYPE)
self.gguf_writer.add_token_types([field.parts[i].tolist()[0] for i in field.data][:vocab_size])
field = neox_reader.get_field(gguf.Keys.Tokenizer.MERGES)
self.gguf_writer.add_token_merges([bytes(field.parts[i]) for i in field.data])
field = neox_reader.get_field(gguf.Keys.Tokenizer.BOS_ID)
self.gguf_writer.add_bos_token_id(field.parts[-1].tolist()[0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.EOS_ID)
self.gguf_writer.add_eos_token_id(field.parts[-1].tolist()[0])
field = neox_reader.get_field(gguf.Keys.Tokenizer.UNK_ID)
self.gguf_writer.add_unk_token_id(field.parts[-1].tolist()[0])

def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 16
# ceiling division
# ref: https://stackoverflow.com/a/17511341/22827863
# ref: https://github.com/state-spaces/mamba/blob/ce59daea3a090d011d6476c6e5b97f6d58ddad8b/mamba_ssm/modules/mamba_simple.py#L58
dt_rank = self.find_hparam(["time_step_rank", "dt_rank"], optional=True) or -(d_model // -16)
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

# Fail early for models which don't have a block expansion factor of 2
assert d_inner == 2 * d_model

self.gguf_writer.add_name(self.dir_model.name)
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
self.gguf_writer.add_block_count(self.hparams["n_layer"])
self.gguf_writer.add_ssm_conv_kernel(d_conv)
self.gguf_writer.add_ssm_inner_size(d_inner)
self.gguf_writer.add_ssm_state_size(d_state)
self.gguf_writer.add_ssm_time_step_rank(dt_rank)
self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
self.gguf_writer.add_file_type(self.ftype)

def write_tensors(self):
block_count = self.hparams["n_layer"]
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)

tok_embd = None
tok_embd_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.TOKEN_EMBD] + ".weight"
output_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.OUTPUT] + ".weight"

for name, data_torch in self.get_tensors():
old_dtype = data_torch.dtype

# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)

# map tensor names
new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()

if name.endswith(".A_log"):
print("A_log --> A ==> " + new_name)
data_torch = -torch.exp(data_torch)

# assuming token_embd.weight is seen before output.weight
if tok_embd is not None and new_name == output_name:
if torch.equal(tok_embd, data_torch):
print(f"{output_name} is equivalent to {tok_embd_name}, omitting")
continue
if new_name == tok_embd_name:
tok_embd = data_torch

data = data_torch.squeeze().numpy()

n_dims = len(data.shape)
data_dtype = data.dtype

# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)

# TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)

# if f16 desired, convert big float32 2-dim weight tensors to float16
if self.ftype == 1 and data_dtype == np.float32 and new_name.removesuffix(".weight").endswith((".ssm_in", ".ssm_out", "token_embd", "output")) and n_dims == 2:
data = data.astype(np.float16)

print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


13 changes: 8 additions & 5 deletions examples/batched-bench/batched-bench.cpp
@@ -105,6 +105,9 @@ int main(int argc, char ** argv) {
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

// ensure enough sequences are available
ctx_params.n_parallel = *std::max_element(n_pl.begin(), n_pl.end());

llama_context * ctx = llama_new_context_with_model(model, ctx_params);

if (ctx == NULL) {
@@ -174,10 +177,10 @@

llama_batch_clear(batch);

const int n_tokens = is_pp_shared ? pp : pl*pp;

for (int i = 0; i < n_tokens; ++i) {
llama_batch_add(batch, 0, i, { 0 }, false);
for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
llama_batch_add(batch, 0, i, { j }, false);
}
}
batch.logits[batch.n_tokens - 1] = true;

@@ -192,7 +195,7 @@

if (is_pp_shared) {
for (int32_t i = 1; i < pl; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, pp);
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}
}

3 changes: 2 additions & 1 deletion examples/batched/batched.cpp
@@ -80,6 +80,7 @@ int main(int argc, char ** argv) {
ctx_params.seed = 1234;
ctx_params.n_ctx = n_kv_req;
ctx_params.n_batch = std::max(n_len, n_parallel);
ctx_params.n_parallel = n_parallel;
ctx_params.n_threads = params.n_threads;
ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;

@@ -132,7 +133,7 @@
// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (int32_t i = 1; i < n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

if (n_parallel > 1) {
20 changes: 13 additions & 7 deletions examples/parallel/parallel.cpp
@@ -107,6 +107,9 @@ int main(int argc, char ** argv) {
// number of simultaneous "clients" to simulate
const int32_t n_clients = params.n_parallel;

// dedicate one sequence to the system prompt
params.n_parallel += 1;

// requests to simulate
const int32_t n_seq = params.n_sequences;

@@ -196,8 +199,8 @@
}

// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i < n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
for (int32_t i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("\n");
@@ -221,15 +224,17 @@

client.i_batch = batch.n_tokens;

llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id }, true);
llama_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true);

client.n_decoded += 1;
}

if (batch.n_tokens == 0) {
// all sequences have ended - clear the entire KV cache
for (int i = 0; i < n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
for (int i = 1; i <= n_clients; ++i) {
llama_kv_cache_seq_rm(ctx, i, -1, -1);
// but keep the system prompt
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
}

LOG_TEE("%s: clearing the KV cache\n", __func__);
@@ -255,7 +260,7 @@
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);

for (size_t i = 0; i < tokens_prompt.size(); ++i) {
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id }, false);
llama_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false);
}

// extract the logits only for the last token
@@ -366,7 +371,8 @@
}

// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);

const auto t_main_end = ggml_time_us();

9 changes: 6 additions & 3 deletions examples/perplexity/perplexity.cpp
@@ -809,7 +809,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

@@ -1086,7 +1086,7 @@ static void winogrande_score(llama_context * ctx, const gpt_params & params) {
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 128;
const int max_seq = 2*max_tasks_per_batch;
const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

@@ -1438,7 +1438,7 @@ static void multiple_choice_score(llama_context * ctx, const gpt_params & params
const int n_batch = params.n_batch;

const int max_tasks_per_batch = 32;
const int max_seq = 4*max_tasks_per_batch;
const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_max_seq(ctx));

llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);

@@ -1815,6 +1815,9 @@ int main(int argc, char ** argv) {
llama_model * model;
llama_context * ctx;

// ensure there's at least enough seq_ids for HellaSwag
params.n_parallel = std::max(4, params.n_parallel);

// load the model and apply lora adapter, if any
std::tie(model, ctx) = llama_init_from_gpt_params(params);
if (model == NULL) {
(diffs for the remaining changed files are not shown here)
