Skip to content

Commit

Permalink
stablelm : StableLM support (#3586)
Browse files Browse the repository at this point in the history
* Add support for stablelm-3b-4e1t
* Supports GPU offloading of (n-1) layers
  • Loading branch information
Galunid authored Nov 14, 2023
1 parent b46d12f commit 36eed0c
Show file tree
Hide file tree
Showing 6 changed files with 322 additions and 12 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ as the main playground for developing new features for the [ggml](https://github
- [X] [Persimmon 8B](https://github.com/ggerganov/llama.cpp/pull/3410)
- [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417)
- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553)
- [X] [StableLM-3b-4e1t](https://github.com/ggerganov/llama.cpp/pull/3586)


**Bindings:**
Expand Down
30 changes: 19 additions & 11 deletions convert-hf-to-gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ def load_hparams(dir_model):

@staticmethod
def from_model_architecture(model_architecture):
if model_architecture == "StableLMEpochForCausalLM":
return StableLMModel
if model_architecture == "GPTNeoXForCausalLM":
return GPTNeoXModel
if model_architecture == "BloomForCausalLM":
Expand All @@ -168,6 +166,8 @@ def from_model_architecture(model_architecture):
return RefactModel
if model_architecture == "PersimmonForCausalLM":
return PersimmonModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
return Model

def _is_model_safetensors(self) -> bool:
Expand Down Expand Up @@ -201,6 +201,8 @@ def _get_model_architecture(self) -> gguf.MODEL_ARCH:
return gguf.MODEL_ARCH.REFACT
if arch == "PersimmonForCausalLM":
return gguf.MODEL_ARCH.PERSIMMON
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM

raise NotImplementedError(f'Architecture "{arch}" not supported!')

Expand Down Expand Up @@ -294,15 +296,6 @@ def _set_vocab_sentencepiece(self):
special_vocab.add_to_gguf(self.gguf_writer)


class StableLMModel(Model):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_rope_dimension_count(
int(self.hparams["rope_pct"] * (self.hparams["hidden_size"] // self.hparams["num_attention_heads"])),
)
self.gguf_writer.add_layer_norm_eps(1e-5)


class GPTNeoXModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams["num_hidden_layers"]
Expand Down Expand Up @@ -824,6 +817,21 @@ def write_tensors(self):
self.gguf_writer.add_tensor(new_name, data)


class StableLMModel(Model):
def set_gguf_parameters(self):
hparams = self.hparams
block_count = hparams["num_hidden_layers"]

self.gguf_writer.add_name(dir_model.name)
self.gguf_writer.add_context_length(hparams["max_position_embeddings"])
self.gguf_writer.add_embedding_length(hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(int(hparams["rope_pct"]*(hparams["hidden_size"] // hparams["num_attention_heads"])))
self.gguf_writer.add_head_count(hparams["num_attention_heads"])
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
self.gguf_writer.add_layer_norm_eps(1e-5)

###### CONVERSION LOGIC ######

def parse_args() -> argparse.Namespace:
Expand Down
17 changes: 17 additions & 0 deletions gguf-py/gguf/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ class MODEL_ARCH(IntEnum):
REFACT = auto()
BERT = auto()
BLOOM = auto()
STABLELM = auto()


class MODEL_TENSOR(IntEnum):
Expand Down Expand Up @@ -129,6 +130,7 @@ class MODEL_TENSOR(IntEnum):
MODEL_ARCH.REFACT: "refact",
MODEL_ARCH.BERT: "bert",
MODEL_ARCH.BLOOM: "bloom",
MODEL_ARCH.STABLELM: "stablelm",
}

TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
Expand Down Expand Up @@ -299,6 +301,21 @@ class MODEL_TENSOR(IntEnum):
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.STABLELM: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.GPT2: [
# TODO
],
Expand Down
Loading

0 comments on commit 36eed0c

Please sign in to comment.