From 62944df1425016018b5a137ebd0af5495eaa4894 Mon Sep 17 00:00:00 2001 From: Brandon Roberts Date: Mon, 18 Dec 2023 12:27:11 -0700 Subject: [PATCH] Bugfix: Remove f16_kv, add offload_kqv field (#1019) F16_KV appears to have been removed here: https://github.com/ggerganov/llama.cpp/pull/4312/commits/af99c6fbfc815df7dad94d8c1f20d55927b2203a This addresses two issues: - #995 which just requests to add the KV cache offloading param - #1006 a NULL ptr exception when using the embeddings (introduced by leaving f16_kv in the fields struct) --- llama_cpp/llama.py | 5 ----- llama_cpp/llama_cpp.py | 6 +++--- llama_cpp/server/app.py | 2 -- 3 files changed, 3 insertions(+), 10 deletions(-) diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 15307ab3b..491d7cbef 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -751,7 +751,6 @@ def __init__( yarn_beta_slow: float = 1.0, yarn_orig_ctx: int = 0, mul_mat_q: bool = True, - f16_kv: bool = True, logits_all: bool = False, embedding: bool = False, # Sampling Params @@ -817,7 +816,6 @@ def __init__( yarn_beta_fast: YaRN low correction dim yarn_beta_slow: YaRN high correction dim yarn_orig_ctx: YaRN original context size - f16_kv: Use fp16 for KV cache, fp32 otherwise logits_all: Return logits for all tokens, not just the last token. Must be True for completion to return logprobs. embedding: Embedding mode only. last_n_tokens_size: Maximum number of tokens to keep in the last_n_tokens deque. @@ -904,7 +902,6 @@ def __init__( ) self.context_params.yarn_orig_ctx = yarn_orig_ctx if yarn_orig_ctx != 0 else 0 self.context_params.mul_mat_q = mul_mat_q - # self.context_params.f16_kv = f16_kv self.context_params.logits_all = logits_all self.context_params.embedding = embedding @@ -2155,7 +2152,6 @@ def __getstate__(self): yarn_beta_slow=self.context_params.yarn_beta_slow, yarn_orig_ctx=self.context_params.yarn_orig_ctx, mul_mat_q=self.context_params.mul_mat_q, - f16_kv=self.context_params.f16_kv, logits_all=self.context_params.logits_all, embedding=self.context_params.embedding, # Sampling Params @@ -2198,7 +2194,6 @@ def __setstate__(self, state): yarn_beta_slow=state["yarn_beta_slow"], yarn_orig_ctx=state["yarn_orig_ctx"], mul_mat_q=state["mul_mat_q"], - f16_kv=state["f16_kv"], logits_all=state["logits_all"], embedding=state["embedding"], # Sampling Params diff --git a/llama_cpp/llama_cpp.py b/llama_cpp/llama_cpp.py index 82c7187e6..538e3ff16 100644 --- a/llama_cpp/llama_cpp.py +++ b/llama_cpp/llama_cpp.py @@ -432,9 +432,9 @@ class llama_context_params(Structure): type_k (int): data type for K cache type_v (int): data type for V cache mul_mat_q (bool): if true, use experimental mul_mat_q kernels (DEPRECATED - always true) - f16_kv (bool): use fp16 for KV cache, fp32 otherwise logits_all (bool): the llama_eval() call computes all logits, not just the last one (DEPRECATED - set llama_batch.logits instead) - embedding (bool): embedding mode only""" + embedding (bool): embedding mode only + offload_kqv (bool): whether to offload the KQV ops (including the KV cache) to GPU""" _fields_ = [ ("seed", c_uint32), ("n_ctx", c_uint32), @@ -452,9 +452,9 @@ class llama_context_params(Structure): ("type_k", c_int), ("type_v", c_int), ("mul_mat_q", c_bool), - ("f16_kv", c_bool), ("logits_all", c_bool), ("embedding", c_bool), + ("offload_kqv", c_bool), ] diff --git a/llama_cpp/server/app.py b/llama_cpp/server/app.py index 7138cf403..9e76ebdf4 100644 --- a/llama_cpp/server/app.py +++ b/llama_cpp/server/app.py @@ -98,7 +98,6 @@ class Settings(BaseSettings): mul_mat_q: bool = Field( default=True, description="if true, use experimental mul_mat_q kernels" ) - f16_kv: bool = Field(default=True, description="Whether to use f16 key/value.") logits_all: bool = Field(default=True, description="Whether to return logits.") embedding: bool = Field(default=True, description="Whether to use embeddings.") # Sampling Params @@ -408,7 +407,6 @@ def create_app(settings: Optional[Settings] = None): yarn_beta_slow=settings.yarn_beta_slow, yarn_orig_ctx=settings.yarn_orig_ctx, mul_mat_q=settings.mul_mat_q, - f16_kv=settings.f16_kv, logits_all=settings.logits_all, embedding=settings.embedding, # Sampling Params