diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index e41be76db0820d..5af497a3ce3214 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -276,6 +276,11 @@ jobs: run: | xcodebuild -scheme llama -destination "${{ matrix.destination }}" + - name: Build Swift Example + id: make_build_swift_example + run: | + make swift + windows-latest-cmake: runs-on: windows-latest diff --git a/Makefile b/Makefile index 40187c4a25e621..87e7bb604c0c8c 100644 --- a/Makefile +++ b/Makefile @@ -617,6 +617,11 @@ metal: examples/metal/metal.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) endif +ifeq ($(UNAME_S),Darwin) +swift: examples/batched.swift + (cd examples/batched.swift; make build) +endif + build-info.h: $(wildcard .git/index) scripts/build-info.sh @sh scripts/build-info.sh $(CC) > $@.tmp @if ! cmp -s $@.tmp $@; then \ @@ -637,7 +642,7 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o run-benchmark-matmult: benchmark-matmult ./$@ -.PHONY: run-benchmark-matmult +.PHONY: run-benchmark-matmult swift vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) diff --git a/Package.swift b/Package.swift index 1ea414cc149fde..4ab055b19da2e5 100644 --- a/Package.swift +++ b/Package.swift @@ -1,10 +1,10 @@ -// swift-tools-version:5.3 +// swift-tools-version:5.5 import PackageDescription #if arch(arm) || arch(arm64) let platforms: [SupportedPlatform]? = [ - .macOS(.v11), + .macOS(.v12), .iOS(.v14), .watchOS(.v4), .tvOS(.v14) @@ -41,12 +41,13 @@ let package = Package( "ggml.c", "llama.cpp", "ggml-alloc.c", + "ggml-backend.c", "k_quants.c", ] + additionalSources, resources: resources, publicHeadersPath: "spm-headers", cSettings: [ - .unsafeFlags(["-Wno-shorten-64-to-32"]), + .unsafeFlags(["-Wno-shorten-64-to-32", "-O3", "-DNDEBUG"]), .define("GGML_USE_K_QUANTS"), .define("GGML_USE_ACCELERATE") // NOTE: NEW_LAPACK will required iOS version 16.4+ diff --git a/README.md b/README.md index 0562795620e69d..0f1fd756569269 100644 --- a/README.md +++ b/README.md @@ -96,6 +96,8 @@ as the main playground for developing new features for the [ggml](https://github - [X] [Starcoder models](https://github.com/ggerganov/llama.cpp/pull/3187) - [X] [Mistral AI v0.1](https://huggingface.co/mistralai/Mistral-7B-v0.1) - [X] [Refact](https://huggingface.co/smallcloudai/Refact-1_6B-fim) +- [X] [Bloom](https://github.com/ggerganov/llama.cpp/pull/3553) +- [X] [MPT](https://github.com/ggerganov/llama.cpp/pull/3417) **Bindings:** diff --git a/convert-bloom-hf-to-gguf.py b/convert-bloom-hf-to-gguf.py new file mode 100755 index 00000000000000..7bfc95ec11daef --- /dev/null +++ b/convert-bloom-hf-to-gguf.py @@ -0,0 +1,238 @@ +#!/usr/bin/env python3 +# HF bloom --> gguf conversion + +from __future__ import annotations + +import argparse +import json +import os +import re +import struct +import sys +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from transformers import AutoTokenizer # type: ignore[import] + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) +import gguf + + +def count_model_parts(dir_model: Path) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +# Supported Models: +# https://huggingface.co/bigscience/bloom-1b7 +# 
https://huggingface.co/bigscience/bloom-3b +# https://huggingface.co/bigscience/bloom-7b1 +# https://huggingface.co/Langboat/bloom-1b4-zh +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert a Bloom model to a GGML compatible file") + parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab") + parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") + parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.bin)") + parser.add_argument("ftype", type=int, help="output format - use 0 for float32, 1 for float16", choices=[0, 1], default = 1) + return parser.parse_args() + +args = parse_args() + +dir_model = args.model +ftype = args.ftype +if not dir_model.is_dir(): + print(f'Error: {args.model} is not a directory', file = sys.stderr) + sys.exit(1) + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +if args.outfile is not None: + fname_out = args.outfile +else: + # output in the same directory as the model by default + fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf' + +print("gguf: loading model "+dir_model.name) + +with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "BloomForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + sys.exit(1) + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH=gguf.MODEL_ARCH.BLOOM +gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +block_count = hparams["n_layer"] + +gguf_writer.add_name("Bloom") +n_embed = hparams.get("hidden_size", hparams.get("n_embed")) +n_head = hparams.get("n_head", hparams.get("num_attention_heads")) +gguf_writer.add_context_length(hparams.get("seq_length", n_embed)) +gguf_writer.add_embedding_length(n_embed) +gguf_writer.add_feed_forward_length(4 * n_embed) +gguf_writer.add_block_count(block_count) +gguf_writer.add_head_count(n_head) +gguf_writer.add_head_count_kv(n_head) +gguf_writer.add_layer_norm_eps(hparams["layer_norm_epsilon"]) +gguf_writer.add_file_type(ftype) + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: list[bytearray] = [] +scores: list[float] = [] +toktypes: list[int] = [] + +# gpt2 tokenizer +gguf_writer.add_tokenizer_model("gpt2") + +print("gguf: get gpt2 tokenizer vocab") + +# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py +tokenizer = AutoTokenizer.from_pretrained(dir_model) + +# The number of tokens in tokenizer.json can differ from the expected vocab size. 
+# This causes downstream issues with mismatched tensor sizes when running the inference +vocab_size = hparams.get("vocab_size", len(tokenizer.vocab)) +assert max(tokenizer.vocab.values()) < vocab_size + +reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} + +for i in range(vocab_size): + tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]") + scores.append(0.0) # dummy + toktypes.append(gguf.TokenType.NORMAL) + +gguf_writer.add_token_list(tokens) +gguf_writer.add_token_scores(scores) +gguf_writer.add_token_types(toktypes) + +special_vocab = gguf.SpecialVocab(dir_model, load_merges=True) +special_vocab.add_to_gguf(gguf_writer) + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH, block_count) + +# params for qkv transform +n_head_kv = hparams.get("n_head_kv", n_head) +head_dim = n_embed // n_head + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = iter(("pytorch_model.bin",)) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + if args.vocab_only: + break + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(dir_model / part_name, map_location="cpu") + + has_lm_head = True + if "lm_head.weight" not in model_part.keys() and "output.weight" not in model_part.keys(): + has_lm_head = False + + for original_name in model_part.keys(): + data = model_part[original_name] + name = re.sub(r'transformer\.', '', original_name) + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + if re.match(r"h\.\d+\.self_attention\.query_key_value\.weight", name): + # Map bloom-style qkv_linear to gpt-style qkv_linear + # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252 # noqa + # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312 # noqa + qkv_weights = data.reshape((n_head, 3, n_embed // n_head, n_embed)) + data = np.concatenate( + (qkv_weights[:, 0, :, :].reshape((-1, n_embed)), + qkv_weights[:, 1, :, :].reshape((-1, n_embed)), + qkv_weights[:, 2, :, :].reshape((-1, n_embed))), + axis=0 + ) + print("re-format attention.linear_qkv.weight") + elif re.match(r"h\.\d+\.self_attention\.query_key_value\.bias", name): + qkv_bias = data.reshape((n_head, 3, n_embed // n_head)) + data = np.concatenate( + (qkv_bias[:, 0, :].reshape((n_embed,)), + qkv_bias[:, 1, :].reshape((n_embed,)), + qkv_bias[:, 2, :].reshape((n_embed,))), + axis=0 + ) + print("re-format attention.linear_qkv.bias") + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias")) + if new_name is None: + print("Can not map tensor '" + name + "'") + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(name, "=>", new_name + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(new_name, data) + + if not has_lm_head and name == "word_embeddings.weight": + gguf_writer.add_tensor("output.weight", data) + print(name, "=>", "output.weight" + ", shape = " + str(data.shape) + ", " + str(old_dtype) + " --> " + str(data.dtype)) # noqa + + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +if not args.vocab_only: + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + +gguf_writer.close() + +print(f"gguf: model successfully exported to '{fname_out}'") +print("") diff --git a/convert-mpt-hf-to-gguf.py b/convert-mpt-hf-to-gguf.py new file mode 100755 index 00000000000000..73a4932f7c831b --- /dev/null +++ b/convert-mpt-hf-to-gguf.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python3 +# HF mpt--> gguf conversion + +from __future__ import annotations + +import argparse +import json +import os +import struct +import sys +from pathlib import Path +from typing import Any + +import numpy as np +import torch +from transformers import AutoTokenizer # type: ignore[import] + +if 'NO_LOCAL_GGUF' not in os.environ: + sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf')) +import gguf + + +def count_model_parts(dir_model: Path) -> int: + num_parts = 0 + for filename in os.listdir(dir_model): + if filename.startswith("pytorch_model-"): + num_parts += 1 + + if num_parts > 0: + print("gguf: found " + str(num_parts) + " model parts") + return num_parts + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Convert an MPT model to a GGML compatible file") + parser.add_argument( + "--vocab-only", action="store_true", + help="extract only the vocab", + ) + parser.add_argument( + "--outfile", type=Path, + help="path to write to; default: based on input", + ) + parser.add_argument( + "model", type=Path, + help="directory containing model file, or model file itself (*.bin)", + ) + parser.add_argument( + "ftype", type=int, choices=[0, 1], default=1, nargs='?', + help="output format - use 0 for float32, 1 for float16", + ) + return parser.parse_args() + +args = parse_args() + +dir_model = args.model +ftype = args.ftype +if not dir_model.is_dir(): + print(f'Error: {args.model} is not a directory', file = sys.stderr) + sys.exit(1) + +# possible tensor data types +# ftype == 0 -> float32 +# ftype == 1 -> float16 + +# map from ftype to string +ftype_str = ["f32", "f16"] + +if args.outfile is not None: + fname_out = args.outfile +else: + # output in the same directory as the model by default + fname_out = dir_model / f'ggml-model-{ftype_str[ftype]}.gguf' + +print("gguf: loading model "+dir_model.name) + +with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + +if hparams["architectures"][0] != "MPTForCausalLM": + print("Model architecture not supported: " + hparams["architectures"][0]) + + sys.exit() + +# get number of model parts +num_parts = count_model_parts(dir_model) + +ARCH=gguf.MODEL_ARCH.MPT +gguf_writer = gguf.GGUFWriter(fname_out, 
gguf.MODEL_ARCH_NAMES[ARCH]) + +print("gguf: get model metadata") + +block_count = hparams["n_layers"] + +gguf_writer.add_name(dir_model.name) +gguf_writer.add_context_length(hparams["max_seq_len"]) +gguf_writer.add_embedding_length(hparams["d_model"]) +gguf_writer.add_block_count(block_count) +gguf_writer.add_feed_forward_length(4 * hparams["d_model"]) +gguf_writer.add_head_count(hparams["n_heads"]) +gguf_writer.add_layer_norm_eps(1e-05) +if hparams["attn_config"]["clip_qkv"] is not None: + gguf_writer.add_clamp_kqv(hparams["attn_config"]["clip_qkv"]) +gguf_writer.add_max_alibi_bias(hparams["attn_config"]["alibi_bias_max"]) + +# TOKENIZATION + +print("gguf: get tokenizer metadata") + +tokens: list[bytearray] = [] +scores: list[float] = [] +toktypes: list[int] = [] + +# gpt2 tokenizer +gguf_writer.add_tokenizer_model("gpt2") + +print("gguf: get gpt2 tokenizer vocab") + +# MPT token embedding tensors have dimension 50432 (hparams["vocab_size"]), but +# there are only 50254 (len(tokenizer.vocab)) tokens in the vocab, presumably to +# accomodate some "reserved" tokens; this is causing problems down the line in +# llama.cpp, so we pad the vocab with dummy tokens: + +vocab_size = hparams["vocab_size"] + +# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py +tokenizer = AutoTokenizer.from_pretrained(dir_model) + +reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()} + +for i in range(vocab_size): + tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]") + scores.append(0.0) # dummy + toktypes.append(gguf.TokenType.NORMAL) + +gguf_writer.add_token_list(tokens) +gguf_writer.add_token_scores(scores) +gguf_writer.add_token_types(toktypes) + +special_vocab = gguf.SpecialVocab(dir_model, load_merges = True) +special_vocab.add_to_gguf(gguf_writer) + +# TENSORS + +tensor_map = gguf.get_tensor_name_map(ARCH,block_count) + +# tensor info +print("gguf: get tensor metadata") + +if num_parts == 0: + part_names = iter(("pytorch_model.bin",)) +else: + part_names = ( + f"pytorch_model-{n:05}-of-{num_parts:05}.bin" for n in range(1, num_parts + 1) + ) + +for part_name in part_names: + if args.vocab_only: + break + print("gguf: loading model part '" + part_name + "'") + model_part = torch.load(f"{dir_model}/{part_name}", map_location="cpu") + + for name in model_part.keys(): + data = model_part[name] + + old_dtype = data.dtype + + # convert any unsupported data types to float32 + if data.dtype != torch.float16 and data.dtype != torch.float32: + data = data.to(torch.float32) + + data = data.squeeze().numpy() + + # map tensor names + new_name = tensor_map.get_name(name, try_suffixes = (".weight", ".bias")) + if new_name is None: + print("Cannot map tensor '" + name + "'") + continue # for the sake of compatibility with some old published models, don't quit + sys.exit() + + n_dims = len(data.shape) + data_dtype = data.dtype + + # if f32 desired, convert any float16 to float32 + if ftype == 0 and data_dtype == np.float16: + data = data.astype(np.float32) + + # TODO: Why cant we use these float16 as-is? 
There should be not reason to store float16 as float32 + if ftype == 1 and data_dtype == np.float16 and n_dims == 1: + data = data.astype(np.float32) + + # if f16 desired, convert any float32 2-dim weight tensors to float16 + if ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2: + data = data.astype(np.float16) + + print(new_name + ", n_dims = " + str(n_dims) + ", " + str(old_dtype) + " --> " + str(data.dtype)) + + gguf_writer.add_tensor(new_name, data) + + # note: MPT output is tied to (same as) wte in original model; + # for easier implementation in llama.cpp it's duplicated in GGUF, though :/ + if new_name == "token_embd.weight": + gguf_writer.add_tensor("output.weight", data) + +print("gguf: write header") +gguf_writer.write_header_to_file() +print("gguf: write metadata") +gguf_writer.write_kv_data_to_file() +if not args.vocab_only: + print("gguf: write tensors") + gguf_writer.write_tensors_to_file() + +gguf_writer.close() + +print(f"gguf: model successfully exported to '{fname_out}'") +print("") diff --git a/examples/batched.swift/.gitignore b/examples/batched.swift/.gitignore new file mode 100644 index 00000000000000..e1e863bec6d5de --- /dev/null +++ b/examples/batched.swift/.gitignore @@ -0,0 +1,9 @@ +.DS_Store +/.build +/Packages +xcuserdata/ +DerivedData/ +.swiftpm/configuration/registries.json +.swiftpm/xcode/package.xcworkspace/contents.xcworkspacedata +.netrc +batched_swift diff --git a/examples/batched.swift/Makefile b/examples/batched.swift/Makefile new file mode 100755 index 00000000000000..2afb24fb85a1a3 --- /dev/null +++ b/examples/batched.swift/Makefile @@ -0,0 +1,6 @@ +.PHONY: build + +build: + xcodebuild -scheme batched_swift -destination "generic/platform=macOS" -derivedDataPath build + rm -f ./batched_swift + ln -s ./build/Build/Products/Debug/batched_swift ./batched_swift diff --git a/examples/batched.swift/Package.swift b/examples/batched.swift/Package.swift new file mode 100644 index 00000000000000..826491defd8631 --- /dev/null +++ b/examples/batched.swift/Package.swift @@ -0,0 +1,22 @@ +// swift-tools-version: 5.5 +// The swift-tools-version declares the minimum version of Swift required to build this package. + +import PackageDescription + +let package = Package( + name: "batched_swift", + platforms: [.macOS(.v12)], + dependencies: [ + .package(name: "llama", path: "../../"), + ], + targets: [ + // Targets are the basic building blocks of a package, defining a module or a test suite. + // Targets can depend on other targets in this package and products from dependencies. + .executableTarget( + name: "batched_swift", + dependencies: ["llama"], + path: "Sources", + linkerSettings: [.linkedFramework("Foundation"), .linkedFramework("AppKit")] + ), + ] +) diff --git a/examples/batched.swift/README.md b/examples/batched.swift/README.md new file mode 100644 index 00000000000000..464c9079c46608 --- /dev/null +++ b/examples/batched.swift/README.md @@ -0,0 +1,4 @@ +This is a swift clone of `examples/batched`. 
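Since this PR also adds a top-level `make swift` target and the example's Makefile symlinks the built product as `batched_swift`, an equivalent way to build and run from the repository root (macOS only; the paths below are assumed, not taken from this README) should be:

$ `make swift`
$ `./examples/batched.swift/batched_swift MODEL_PATH [PROMPT] [PARALLEL]`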
+ +$ `make` +$ `./swift MODEL_PATH [PROMPT] [PARALLEL]` diff --git a/examples/batched.swift/Sources/main.swift b/examples/batched.swift/Sources/main.swift new file mode 100644 index 00000000000000..938f30512ca6a8 --- /dev/null +++ b/examples/batched.swift/Sources/main.swift @@ -0,0 +1,255 @@ +import Foundation +import llama + +let arguments = CommandLine.arguments + +// Check that we have at least one argument (the model path) +guard arguments.count > 1 else { + print("Usage: swift MODEL_PATH [PROMPT] [PARALLEL]") + exit(1) +} + +let modelPath: String = arguments[1] +let prompt: String = arguments.count > 2 ? arguments[2] : "Hello my name is" +let n_parallel: Int = arguments.count > 3 && Int(arguments[3]) != nil ? Int(arguments[3])! : 1 + +// total length of the sequences including the prompt +let n_len: Int = 32 + +// init LLM +llama_backend_init(false) +defer { + llama_backend_free() +} + +let model_params = llama_model_default_params() +guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), model_params) else { + print("Failed to load model") + exit(1) +} + +defer { + llama_free_model(model) +} + +var tokens = tokenize(text: prompt, add_bos: true) + +let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel) + +var context_params = llama_context_default_params() +context_params.seed = 1234 +context_params.n_ctx = n_kv_req +context_params.n_batch = UInt32(max(n_len, n_parallel)) +context_params.n_threads = 8 +context_params.n_threads_batch = 8 + +let context = llama_new_context_with_model(model, context_params) +guard context != nil else { + print("Failed to initialize context") + exit(1) +} + +defer { + llama_free(context) +} + +let n_ctx = llama_n_ctx(context) + +print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n") + +if n_kv_req > n_ctx { + print("error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", n_kv_req) + exit(1) +} + +var buffer: [CChar] = [] +for id: llama_token in tokens { + print(token_to_piece(token: id, buffer: &buffer) ?? 
"", terminator: "") +} + +print("\n") + +var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0) +defer { + llama_batch_free(batch) +} + +// evaluate the initial prompt +batch.n_tokens = Int32(tokens.count) + +for (i, token) in tokens.enumerated() { + batch.token[i] = token + batch.pos[i] = Int32(i) + batch.seq_id[i] = 0 + batch.logits[i] = 0 +} + +// llama_decode will output logits only for the last token of the prompt +batch.logits[Int(batch.n_tokens) - 1] = 1 + +if llama_decode(context, batch) != 0 { + print("llama_decode() failed") + exit(1) +} + +for i in 1 ..< n_parallel { + llama_kv_cache_seq_cp(context, 0, Int32(i), 0, batch.n_tokens) +} + +if n_parallel > 1 { + print("generating \(n_parallel) sequences ...\n") +} + +var streams: [String] = .init(repeating: "", count: n_parallel) +var streamBuffers: [[CChar]] = .init(repeating: [], count: n_parallel) +var i_batch = [Int32](repeating: batch.n_tokens - 1, count: n_parallel) + +var n_cur = batch.n_tokens +var n_decode = 0 + +let t_main_start = ggml_time_us() + +while n_cur <= n_len { + // prepare the next batch + batch.n_tokens = 0 + + // sample the next token for each parallel sequence / stream + for i in 0 ..< n_parallel { + if i_batch[i] < 0 { + // the stream has already finished + continue + } + + var n_vocab = llama_n_vocab(model) + var logits = llama_get_logits_ith(context, i_batch[i]) + + var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab)) + + for token_id in 0 ..< n_vocab { + candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0)) + } + + var candidates_p: llama_token_data_array = .init( + data: &candidates, + size: candidates.count, + sorted: false + ) + + let top_k: Int32 = 40 + let top_p: Float = 0.9 + let temp: Float = 0.4 + + llama_sample_top_k(context, &candidates_p, top_k, 1) + llama_sample_top_p(context, &candidates_p, top_p, 1) + llama_sample_temp(context, &candidates_p, temp) + + let new_token_id = llama_sample_token(context, &candidates_p) + + // const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p); + + // is it an end of stream? -> mark the stream as finished + if new_token_id == llama_token_eos(context) || n_cur == n_len { + i_batch[i] = -1 + // print("") + if n_parallel > 1 { + print("stream \(i) finished at n_cur = \(n_cur)") + } + + continue + } + + let nextStringPiece = token_to_piece(token: new_token_id, buffer: &streamBuffers[i]) ?? 
"" + + // if there is only one stream, we print immediately to stdout + if n_parallel == 1 { + print(nextStringPiece, terminator: "") + } + streams[i] += nextStringPiece + + // push this new token for next evaluation + batch.token[Int(batch.n_tokens)] = new_token_id + batch.pos[Int(batch.n_tokens)] = n_cur + batch.seq_id[Int(batch.n_tokens)] = Int32(i) + batch.logits[Int(batch.n_tokens)] = 1 + + i_batch[i] = batch.n_tokens + + batch.n_tokens += 1 + + n_decode += 1 + } + + // all streams are finished + if batch.n_tokens == 0 { + break + } + + n_cur += 1 + + // evaluate the current batch with the transformer model + if llama_decode(context, batch) != 0 { + print("llama_decode() failed") + exit(1) + } +} + +if n_parallel > 1 { + print("\n") + for (i, stream) in streams.enumerated() { + print("sequence \(i):\n\n\(prompt)\(stream)\n") + } +} + +let t_main_end = ggml_time_us() + +print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n") + +llama_print_timings(context) + +private func tokenize(text: String, add_bos: Bool) -> [llama_token] { + let n_tokens = text.count + (add_bos ? 1 : 0) + let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) + let tokenCount = llama_tokenize(model, text, Int32(text.count), tokens, Int32(n_tokens), add_bos) + var swiftTokens: [llama_token] = [] + for i in 0 ..< tokenCount { + swiftTokens.append(tokens[Int(i)]) + } + tokens.deallocate() + return swiftTokens +} + +private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? { + var result = [CChar](repeating: 0, count: 8) + let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count)) + if nTokens < 0 { + if result.count >= -Int(nTokens) { + result.removeLast(-Int(nTokens)) + } else { + result.removeAll() + } + let check = llama_token_to_piece( + model, + token, + &result, + Int32(result.count) + ) + assert(check == nTokens) + } else { + result.removeLast(result.count - Int(nTokens)) + } + if buffer.isEmpty, let utfString = String(cString: result + [0], encoding: .utf8) { + return utfString + } else { + buffer.append(contentsOf: result) + let data = Data(buffer.map { UInt8(bitPattern: $0) }) + if buffer.count >= 4 { // 4 bytes is the max length of a utf8 character so if we're here we need to reset the buffer + buffer = [] + } + guard let bufferString = String(data: data, encoding: .utf8) else { + return nil + } + buffer = [] + return bufferString + } + return nil +} diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 9ec75ce425b2a6..d994de5e850c3e 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -233,10 +233,22 @@ int main(int argc, char ** argv) { const bool add_bos = llama_vocab_type(model) == LLAMA_VOCAB_TYPE_SPM; LOG("add_bos: %d\n", add_bos); + bool suff_rm_leading_spc = params.escape; + if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } std::vector embd_inp; - std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos); - std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos); + std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false); + std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false); + const int space_token = 29871; + if (suff_rm_leading_spc && 
inp_sfx[0] == space_token) { + inp_sfx.erase(inp_sfx.begin()); + } inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); + if (add_bos) { + inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx)); + } inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); embd_inp = inp_pfx; embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); @@ -627,10 +639,27 @@ int main(int argc, char ** argv) { buffer.clear(); // done taking input, reset color console::set_display(console::reset); + + if (params.escape) { + //process escape sequences, for the initial prompt this is done in common.cpp when we load the params, but for the interactive mode we need to do it here + process_escapes(params.input_prefix); + process_escapes(params.input_suffix); + } + suff_rm_leading_spc = params.escape; + if (suff_rm_leading_spc && params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } // tokenize new prefix and suffix - std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, add_bos); - std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, add_bos); + std::vector inp_pfx = ::llama_tokenize(ctx, params.input_prefix, false); + std::vector inp_sfx = ::llama_tokenize(ctx, params.input_suffix, false); + if (suff_rm_leading_spc && inp_sfx[0] == space_token) { + inp_sfx.erase(inp_sfx.begin()); + } inp_pfx.insert(inp_pfx.begin(), llama_token_prefix(ctx)); + if (add_bos) { + inp_pfx.insert(inp_pfx.begin(), llama_token_bos(ctx)); + } inp_sfx.insert(inp_sfx.begin(), llama_token_suffix(ctx)); embd_inp = inp_pfx; embd_inp.insert(embd_inp.end(), inp_sfx.begin(), inp_sfx.end()); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index c53a64867336f9..8c5318c650ae8e 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -344,9 +344,20 @@ struct llama_server_context void loadInfill() { - auto prefix_tokens = tokenize(params.input_prefix, true); // always add BOS - auto suffix_tokens = tokenize(params.input_suffix, true); // always add BOS + bool suff_rm_leading_spc = true; + if (params.input_suffix.find_first_of(" ") == 0 && params.input_suffix.size() > 1) { + params.input_suffix.erase(0, 1); + suff_rm_leading_spc = false; + } + + auto prefix_tokens = tokenize(params.input_prefix, false); + auto suffix_tokens = tokenize(params.input_suffix, false); + const int space_token = 29871; + if (suff_rm_leading_spc && suffix_tokens[0] == space_token) { + suffix_tokens.erase(suffix_tokens.begin()); + } prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(ctx)); + prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(ctx)); // always add BOS prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(ctx)); prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end()); prefix_tokens.push_back(llama_token_middle(ctx)); diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 7e92c519741b97..654d3632fc179a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -415,6 +415,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 @@ -4585,6 +4586,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } +static __global__ void clamp_f32(const 
float * x, float * dst, const float min, const float max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); +} template static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { @@ -5475,6 +5485,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } +static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; + clamp_f32<<>>(x, dst, min, max, k); +} + template static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { @@ -6419,12 +6434,12 @@ inline void ggml_cuda_op_alibi( const int64_t ne02 = src0->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - GGML_ASSERT(ne01 + n_past == ne00); + //GGML_ASSERT(ne01 + n_past == ne00); GGML_ASSERT(n_head == ne02); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -6500,6 +6515,24 @@ inline void ggml_cuda_op_scale( (void) src1_dd; } +inline void ggml_cuda_op_clamp( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const float min = ((float *) dst->op_params)[0]; + const float max = ((float *) dst->op_params)[1]; + + clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { const int64_t nrows0 = ggml_nrows(src0); @@ -7061,6 +7094,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } +static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); +} + static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -7470,6 +7507,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_SCALE: func = ggml_cuda_scale; break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return false; + } + func = ggml_cuda_clamp; + break; case GGML_OP_CPY: func = ggml_cuda_cpy; break; diff --git a/ggml-metal.m b/ggml-metal.m index 5a23144d0c8913..87fa172161405a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1299,7 +1299,7 @@ void ggml_metal_graph_compute( const int nth = MIN(1024, ne00); - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; 
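// Illustrative sketch (assumption, not taken from this patch): the max_bias read
// here is the same "max_alibi_bias" value that the new MPT loader stores in
// hparams.f_max_alibi_bias and that the Bloom graph below passes to ggml_alibi()
// as the constant 8. The per-head ALiBi slope ggml derives from it is assumed to
// follow the usual power-of-two scheme (cf. the n_heads_log2_floor computation
// kept in ggml-cuda.cu above); formula shown for orientation only:
//
//   #include <math.h>
//   static float alibi_slope(int head, int n_head, float max_bias) {
//       const int n_heads_log2_floor = 1 << (int) floorf(log2f((float) n_head));
//       const float m0 = powf(2.0f, -max_bias          / n_heads_log2_floor);
//       const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
//       return head < n_heads_log2_floor
//           ? powf(m0, head + 1)                              // first heads
//           : powf(m1, 2 * (head - n_heads_log2_floor) + 1);  // remaining heads
//   }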
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); diff --git a/ggml.c b/ggml.c index 5bb1da31ba624d..1f5598fa6af8f9 100644 --- a/ggml.c +++ b/ggml.c @@ -13059,13 +13059,11 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); - const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int64_t ne1 = src0->ne[1]; // seq_len_without_past const int64_t ne2 = src0->ne[2]; // n_head -> this is k diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py index 8eb7e6372553de..79815bc39f54d2 100644 --- a/gguf-py/gguf/gguf.py +++ b/gguf-py/gguf/gguf.py @@ -88,30 +88,32 @@ class MODEL_ARCH(IntEnum): PERSIMMON : int = auto() REFACT : int = auto() BERT : int = auto() + BLOOM : int = auto() PLAMO : int = auto() class MODEL_TENSOR(IntEnum): - TOKEN_EMBD : int = auto() - TOKEN_TYPES : int = auto() - POS_EMBD : int = auto() - OUTPUT : int = auto() - OUTPUT_NORM : int = auto() - ROPE_FREQS : int = auto() - ATTN_Q : int = auto() - ATTN_K : int = auto() - ATTN_V : int = auto() - ATTN_QKV : int = auto() - ATTN_OUT : int = auto() - ATTN_NORM : int = auto() - ATTN_NORM_2 : int = auto() - ATTN_ROT_EMBD: int = auto() - FFN_GATE : int = auto() - FFN_DOWN : int = auto() - FFN_UP : int = auto() - FFN_NORM : int = auto() - ATTN_Q_NORM : int = auto() - ATTN_K_NORM : int = auto() + TOKEN_EMBD : int = auto() + TOKEN_EMBD_NORM : int = auto() + TOKEN_TYPES : int = auto() + POS_EMBD : int = auto() + OUTPUT : int = auto() + OUTPUT_NORM : int = auto() + ROPE_FREQS : int = auto() + ATTN_Q : int = auto() + ATTN_K : int = auto() + ATTN_V : int = auto() + ATTN_QKV : int = auto() + ATTN_OUT : int = auto() + ATTN_NORM : int = auto() + ATTN_NORM_2 : int = auto() + ATTN_ROT_EMBD : int = auto() + FFN_GATE : int = auto() + FFN_DOWN : int = auto() + FFN_UP : int = auto() + FFN_NORM : int = auto() + ATTN_Q_NORM : int = auto() + ATTN_K_NORM : int = auto() MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { @@ -126,30 +128,32 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.PERSIMMON: "persimmon", MODEL_ARCH.REFACT: "refact", MODEL_ARCH.BERT: "bert", + MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.PLAMO: "plamo", } TENSOR_NAMES: dict[MODEL_TENSOR, str] = { - MODEL_TENSOR.TOKEN_EMBD: "token_embd", - MODEL_TENSOR.TOKEN_TYPES: "token_types", - MODEL_TENSOR.POS_EMBD: "position_embd", - MODEL_TENSOR.OUTPUT_NORM: "output_norm", - MODEL_TENSOR.OUTPUT: "output", - MODEL_TENSOR.ROPE_FREQS: "rope_freqs", - MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", - MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", - MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", - MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", - MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", - MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", - MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", - MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", - MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", - MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", - MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", - MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", - MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", - MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", + MODEL_TENSOR.TOKEN_EMBD: "token_embd", + MODEL_TENSOR.TOKEN_EMBD_NORM: "token_embd_norm", + MODEL_TENSOR.TOKEN_TYPES: "token_types", + MODEL_TENSOR.POS_EMBD: "position_embd", + 
MODEL_TENSOR.OUTPUT_NORM: "output_norm", + MODEL_TENSOR.OUTPUT: "output", + MODEL_TENSOR.ROPE_FREQS: "rope_freqs", + MODEL_TENSOR.ATTN_NORM: "blk.{bid}.attn_norm", + MODEL_TENSOR.ATTN_NORM_2: "blk.{bid}.attn_norm_2", + MODEL_TENSOR.ATTN_QKV: "blk.{bid}.attn_qkv", + MODEL_TENSOR.ATTN_Q: "blk.{bid}.attn_q", + MODEL_TENSOR.ATTN_K: "blk.{bid}.attn_k", + MODEL_TENSOR.ATTN_V: "blk.{bid}.attn_v", + MODEL_TENSOR.ATTN_OUT: "blk.{bid}.attn_output", + MODEL_TENSOR.ATTN_ROT_EMBD: "blk.{bid}.attn_rot_embd", + MODEL_TENSOR.ATTN_Q_NORM: "blk.{bid}.attn_q_norm", + MODEL_TENSOR.ATTN_K_NORM: "blk.{bid}.attn_k_norm", + MODEL_TENSOR.FFN_NORM: "blk.{bid}.ffn_norm", + MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate", + MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down", + MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up", } MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { @@ -284,6 +288,18 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.FFN_DOWN, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.BLOOM: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.TOKEN_EMBD_NORM, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -328,6 +344,7 @@ class TensorNameMap: "gpt_neox.embed_in", # gptneox "transformer.wte", # gpt2 gpt-j mpt refact "transformer.word_embeddings", # falcon + "word_embeddings", # bloom "model.embed_tokens", # llama-hf "tok_embeddings", # llama-pth "embeddings.word_embeddings", # bert @@ -339,6 +356,11 @@ class TensorNameMap: "embeddings.token_type_embeddings", # bert ), + # Normalization of token embeddings + MODEL_TENSOR.TOKEN_EMBD_NORM: ( + "word_embeddings_layernorm", # bloom + ), + # Position embeddings MODEL_TENSOR.POS_EMBD: ( "transformer.wpe", # gpt2 @@ -349,7 +371,7 @@ class TensorNameMap: MODEL_TENSOR.OUTPUT: ( "embed_out", # gptneox "lm_head", # gpt2 mpt falcon llama-hf baichuan - "output", # llama-pth + "output", # llama-pth bloom "word_embeddings_for_head", # persimmon ), @@ -361,7 +383,7 @@ class TensorNameMap: "norm", # llama-pth "embeddings.LayerNorm", # bert "transformer.norm_f", # mpt - "ln_f", # refact + "ln_f", # refact bloom "language_model.encoder.final_layernorm", # persimmon ), @@ -378,6 +400,7 @@ class TensorNameMap: "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact "transformer.blocks.{bid}.norm_1", # mpt "transformer.h.{bid}.input_layernorm", # falcon7b + "h.{bid}.input_layernorm", # bloom "transformer.h.{bid}.ln_mlp", # falcon40b "model.layers.{bid}.input_layernorm", # llama-hf "layers.{bid}.attention_norm", # llama-pth @@ -397,6 +420,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.c_attn", # gpt2 "transformer.blocks.{bid}.attn.Wqkv", # mpt "transformer.h.{bid}.self_attention.query_key_value", # falcon + "h.{bid}.self_attention.query_key_value", # bloom "language_model.encoder.layers.{bid}.self_attention.query_key_value", # persimmon ), @@ -433,6 +457,7 @@ class TensorNameMap: "transformer.h.{bid}.attn.c_proj", # gpt2 refact "transformer.blocks.{bid}.attn.out_proj", # mpt "transformer.h.{bid}.self_attention.dense", # falcon + "h.{bid}.self_attention.dense", # bloom "model.layers.{bid}.self_attn.o_proj", # llama-hf "layers.{bid}.attention.wo", # llama-pth "encoder.layer.{bid}.attention.output.dense", # bert @@ -452,6 +477,7 @@ class TensorNameMap: MODEL_TENSOR.FFN_NORM: ( "gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox "transformer.h.{bid}.ln_2", # gpt2 refact + 
"h.{bid}.post_attention_layernorm", # bloom "transformer.blocks.{bid}.norm_2", # mpt "model.layers.{bid}.post_attention_layernorm", # llama-hf "layers.{bid}.ffn_norm", # llama-pth @@ -465,6 +491,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_fc", # gpt2 "transformer.blocks.{bid}.ffn.up_proj", # mpt "transformer.h.{bid}.mlp.dense_h_to_4h", # falcon + "h.{bid}.mlp.dense_h_to_4h", # bloom "model.layers.{bid}.mlp.up_proj", # llama-hf refact "layers.{bid}.feed_forward.w3", # llama-pth "encoder.layer.{bid}.intermediate.dense", # bert @@ -486,6 +513,7 @@ class TensorNameMap: "transformer.h.{bid}.mlp.c_proj", # gpt2 refact "transformer.blocks.{bid}.ffn.down_proj", # mpt "transformer.h.{bid}.mlp.dense_4h_to_h", # falcon + "h.{bid}.mlp.dense_4h_to_h", # bloom "model.layers.{bid}.mlp.down_proj", # llama-hf "layers.{bid}.feed_forward.w2", # llama-pth "encoder.layer.{bid}.output.dense", # bert diff --git a/llama.cpp b/llama.cpp index 097cabd8e1deea..9868c8184da3e0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -188,6 +188,7 @@ enum llm_arch { LLM_ARCH_STARCODER, LLM_ARCH_PERSIMMON, LLM_ARCH_REFACT, + LLM_ARCH_BLOOM, LLM_ARCH_PLAMO, LLM_ARCH_UNKNOWN, }; @@ -202,8 +203,9 @@ static std::map LLM_ARCH_NAMES = { { LLM_ARCH_BAICHUAN, "baichuan" }, { LLM_ARCH_STARCODER, "starcoder" }, { LLM_ARCH_PERSIMMON, "persimmon" }, - { LLM_ARCH_REFACT, "refact" }, - { LLM_ARCH_PLAMO, "plamo" }, + { LLM_ARCH_REFACT, "refact" }, + { LLM_ARCH_BLOOM, "bloom" }, + { LLM_ARCH_PLAMO, "plamo" }, }; enum llm_kv { @@ -306,6 +308,7 @@ struct LLM_KV { enum llm_tensor { LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_POS_EMBD, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, @@ -426,6 +429,14 @@ static std::map> LLM_TENSOR_NAMES = LLM_ARCH_MPT, { { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, { @@ -460,6 +471,21 @@ static std::map> LLM_TENSOR_NAMES = { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, }, }, + { + LLM_ARCH_BLOOM, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + }, + }, { LLM_ARCH_PLAMO, { @@ -1031,6 +1057,9 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; + float f_clamp_kqv; + float f_max_alibi_bias; + bool operator!=(const llama_hparams & other) const { if (this->vocab_only != other.vocab_only) return true; if (this->n_vocab != other.n_vocab) return true; @@ -1216,6 +1245,8 @@ struct llama_model { struct ggml_tensor * tok_embeddings; struct ggml_tensor * pos_embeddings; + struct ggml_tensor * tok_norm; + struct ggml_tensor * tok_norm_b; struct ggml_tensor * output_norm; struct ggml_tensor * output_norm_b; @@ -2065,13 +2096,13 @@ static void llm_load_hparams( } } break; case LLM_ARCH_PERSIMMON: - { - GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); - switch 
(hparams.n_layer) { - case 36: model.type = e_model::MODEL_8B; break; - default: model.type = e_model::MODEL_UNKNOWN; - } - } break; + { + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + switch (hparams.n_layer) { + case 36: model.type = e_model::MODEL_8B; break; + default: model.type = e_model::MODEL_UNKNOWN; + } + } break; case LLM_ARCH_REFACT: { GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); @@ -2080,11 +2111,30 @@ static void llm_load_hparams( default: model.type = e_model::MODEL_UNKNOWN; } } break; - case LLM_ARCH_PLAMO: + case LLM_ARCH_BLOOM: { - GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS)); + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + switch (hparams.n_layer) { - case 40: model.type = e_model::MODEL_13B; break; + case 24: model.type = e_model::MODEL_1B; break; + case 30: + switch (hparams.n_embd) { + case 2560: model.type = e_model::MODEL_3B; break; + case 4096: model.type = e_model::MODEL_7B; break; + } break; + } + } break; + case LLM_ARCH_MPT: + { + hparams.f_clamp_kqv = 0.0f; + + GGUF_GET_KEY(ctx, hparams.f_norm_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_EPS)); + GGUF_GET_KEY(ctx, hparams.f_clamp_kqv, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_CLAMP_KQV)); + GGUF_GET_KEY(ctx, hparams.f_max_alibi_bias, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_MAX_ALIBI_BIAS)); + + switch (hparams.n_layer) { + case 32: model.type = e_model::MODEL_7B; break; + case 48: model.type = e_model::MODEL_30B; break; default: model.type = e_model::MODEL_UNKNOWN; } } break; @@ -2232,6 +2282,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) { LLAMA_LOG_INFO("%s: n_gqa = %u\n", __func__, hparams.n_gqa()); LLAMA_LOG_INFO("%s: f_norm_eps = %.1e\n", __func__, hparams.f_norm_eps); LLAMA_LOG_INFO("%s: f_norm_rms_eps = %.1e\n", __func__, hparams.f_norm_rms_eps); + LLAMA_LOG_INFO("%s: f_clamp_kqv = %.1e\n", __func__, hparams.f_clamp_kqv); + LLAMA_LOG_INFO("%s: f_max_alibi_bias = %.1e\n", __func__, hparams.f_max_alibi_bias); LLAMA_LOG_INFO("%s: n_ff = %u\n", __func__, hparams.n_ff); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); @@ -2677,6 +2729,155 @@ static void llm_load_tensors( layer.attn_k_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), {64}, backend); } } break; + case LLM_ARCH_BLOOM: + { + // TODO: CPU-only for now + + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + model.tok_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, GGML_BACKEND_CPU); + model.tok_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD_NORM, "bias"), {n_embd}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) 
n_layer + 2 ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + vram_weights += ggml_nbytes(model.output_norm_b); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.attn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, backend); + + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); + layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, backend_split); + + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + layer.bo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + layer.ffn_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, backend); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, backend_split); + layer.b2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, backend_split); + + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + layer.b3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) + + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) + + ggml_nbytes(layer.wo) + ggml_nbytes(layer.bo) + + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_norm_b) + + ggml_nbytes(layer.w3) + ggml_nbytes(layer.b3) + + ggml_nbytes(layer.w2) + ggml_nbytes(layer.b2); + } + } + } break; + case LLM_ARCH_MPT: + { + model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); + + // output + { + ggml_backend_type backend_norm; + ggml_backend_type backend_output; + + if (n_gpu_layers > int(n_layer)) { + // norm is not performance relevant on its own but keeping it in VRAM reduces data copying + // on Windows however this is detrimental unless everything is on the GPU +#ifndef _WIN32 + backend_norm = LLAMA_BACKEND_OFFLOAD; +#else + backend_norm = n_gpu_layers <= (int) n_layer + 2 ? 
GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; +#endif // _WIN32 + + backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT; + } else { + backend_norm = GGML_BACKEND_CPU; + backend_output = GGML_BACKEND_CPU; + } + + model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); + model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); + + if (backend_norm == GGML_BACKEND_GPU) { + vram_weights += ggml_nbytes(model.output_norm); + } + if (backend_output == GGML_BACKEND_GPU_SPLIT) { + vram_weights += ggml_nbytes(model.output); + } + } + + const uint32_t n_ff = hparams.n_ff; + + const int i_gpu_start = n_layer - n_gpu_layers; + + model.layers.resize(n_layer); + + for (uint32_t i = 0; i < n_layer; ++i) { + const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT + const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT + + auto & layer = model.layers[i]; + + layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend); + layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, 3*n_embd}, backend_split); + layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split); + + layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend); + + layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); + layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); + + if (backend == GGML_BACKEND_GPU) { + vram_weights += + ggml_nbytes(layer.attn_norm) + + ggml_nbytes(layer.wqkv) + + ggml_nbytes(layer.wo) + + ggml_nbytes(layer.ffn_norm) + + ggml_nbytes(layer.w2) + + ggml_nbytes(layer.w3); + } + } + } break; case LLM_ARCH_PLAMO: { model.tok_embeddings = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU); @@ -4599,7 +4800,6 @@ static struct ggml_cgraph * llm_build_starcoder( return gf; } - static struct ggml_cgraph * llm_build_persimmon( llama_context & lctx, const llama_batch & batch) { @@ -4997,9 +5197,9 @@ static struct ggml_cgraph * llm_build_persimmon( return gf; } -static struct ggml_cgraph * llm_build_plamo( - llama_context & lctx, - const llama_batch & batch) { +static struct ggml_cgraph * llm_build_bloom( + llama_context & lctx, + const llama_batch & batch) { const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -5018,33 +5218,28 @@ static struct ggml_cgraph * llm_build_plamo( GGML_ASSERT(n_embd_head == hparams.n_rot); - const float freq_base = cparams.rope_freq_base; - const float freq_scale = cparams.rope_freq_scale; - const float norm_rms_eps = hparams.f_norm_rms_eps; - - const int n_gpu_layers = model.n_gpu_layers; + const float norm_eps = hparams.f_norm_eps; const int32_t n_tokens = batch.n_tokens; const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? 
n_ctx - n_tokens : kv_self.head; - const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; - - //printf("n_kv = %d\n", n_kv); - auto & buf_compute = lctx.buf_compute; struct ggml_init_params params = { /*.mem_size =*/ buf_compute.size, /*.mem_buffer =*/ buf_compute.data, - /*.no_alloc =*/ true, + /*.no_alloc =*/ false, }; + params.no_alloc = true; + struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph * gf = ggml_new_graph(ctx0); struct ggml_tensor * cur; + struct ggml_tensor * token; struct ggml_tensor * inpL; if (batch.token) { @@ -5056,53 +5251,30 @@ static struct ggml_cgraph * llm_build_plamo( } ggml_set_name(inp_tokens, "inp_tokens"); - inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + token = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); } else { #ifdef GGML_USE_MPI GGML_ASSERT(false && "not implemented"); #endif - inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + token = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); - ggml_allocr_alloc(lctx.alloc, inpL); + ggml_allocr_alloc(lctx.alloc, token); if (!ggml_allocr_is_measure(lctx.alloc)) { - memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + memcpy(token->data, batch.embd, n_tokens * n_embd * ggml_element_size(token)); } } - - const int i_gpu_start = n_layer - n_gpu_layers; - (void) i_gpu_start; - - // offload functions set the tensor output backend to GPU - // tensors are GPU-accelerated if any input or the output has been offloaded - offload_func_t offload_func_nr = llama_nop; // nr = non-repeating - offload_func_t offload_func_kq = llama_nop; - offload_func_t offload_func_v = llama_nop; - -#ifdef GGML_USE_CUBLAS - if (n_gpu_layers > n_layer) { - offload_func_nr = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 1) { - offload_func_v = ggml_cuda_assign_buffers_no_alloc; - } - if (n_gpu_layers > n_layer + 2) { - offload_func_kq = ggml_cuda_assign_buffers_no_alloc; - } -#endif // GGML_USE_CUBLAS - // KQ_scale struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); ggml_allocr_alloc(lctx.alloc, KQ_scale); if (!ggml_allocr_is_measure(lctx.alloc)) { - ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); } // KQ_mask (mask for 1 head, it will be broadcasted to all heads) struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); - offload_func_kq(KQ_mask); ggml_set_name(KQ_mask, "KQ_mask"); ggml_allocr_alloc(lctx.alloc, KQ_mask); if (!ggml_allocr_is_measure(lctx.alloc)) { @@ -5123,32 +5295,619 @@ static struct ggml_cgraph * llm_build_plamo( } } - // KQ_pos - contains the positions - struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); - offload_func_kq(KQ_pos); - ggml_set_name(KQ_pos, "KQ_pos"); - ggml_allocr_alloc(lctx.alloc, KQ_pos); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int * data = (int *) KQ_pos->data; - for (int i = 0; i < n_tokens; ++i) { - data[i] = batch.pos[i]; - } + // norm + { + inpL = ggml_norm(ctx0, token, norm_eps); + inpL = ggml_add(ctx0, ggml_mul(ctx0, inpL, model.tok_norm), model.tok_norm_b); } - // shift the entire K-cache if needed - if (do_rope_shift) { - struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); - offload_func_kq(K_shift); - ggml_set_name(K_shift, "K_shift"); - ggml_allocr_alloc(lctx.alloc, K_shift); - if (!ggml_allocr_is_measure(lctx.alloc)) { - int 
* data = (int *) K_shift->data; - for (int i = 0; i < n_ctx; ++i) { - data[i] = kv_self.cells[i].delta; - } - } + ggml_set_name(inpL, "inpL"); - for (int il = 0; il < n_layer; ++il) { + for (int il = 0; il < n_layer; ++il) { + { + // Norm + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].attn_norm), model.layers[il].attn_norm_b); + } + + { + // Self Attention + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wqkv, cur), model.layers[il].bqkv); + + struct ggml_tensor * tmpq = ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*n_embd); + struct ggml_tensor * tmpk = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*n_embd); + struct ggml_tensor * tmpv = ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], sizeof(float)*(n_embd + n_embd_gqa)); + + struct ggml_tensor * Qcur = tmpq; + struct ggml_tensor * Kcur = tmpk; + + // store key and value to memory + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = + ggml_permute(ctx0, + ggml_cpy(ctx0, + Qcur, + ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_embd_head, n_head, n_tokens)), + 0, 2, 1, 3); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + ggml_set_name(K, "K"); + + // K * Q + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + ggml_set_name(KQ, "KQ"); + + // KQ_scaled = KQ / sqrt(n_embd_head) + // KQ_scaled shape [n_past + n_tokens, n_tokens, n_head, 1] + struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + struct ggml_tensor * KQ_scaled_alibi = ggml_alibi(ctx0, KQ_scaled, /*n_past*/ kv_head, n_head, 8); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + // KQ_masked = mask_past(KQ_scaled) + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + ggml_set_name(KQ_masked, "KQ_masked"); + + // KQ = soft_max(KQ_masked) + struct ggml_tensor * KQ_soft_max = ggml_soft_max_inplace(ctx0, KQ_masked); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + // split cached V into n_head heads + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + ggml_set_name(KQV, "KQV"); + + // KQV_merged = KQV.permute(0, 2, 1, 3) + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + ggml_set_name(KQV_merged, "KQV_merged"); + + // cur = KQV_merged.contiguous().view(n_embd, n_tokens) + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + ggml_set_name(cur, 
"KQV_merged_contiguous"); + } + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].wo, cur), model.layers[il].bo); + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + + struct ggml_tensor * inpFF = cur; + + // FF + { + // Norm + { + cur = ggml_norm(ctx0, inpFF, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.layers[il].ffn_norm), model.layers[il].ffn_norm_b); + } + + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w3, cur), model.layers[il].b3); + + // GELU activation + cur = ggml_gelu(ctx0, cur); + + // Projection + cur = ggml_add(ctx0, ggml_mul_mat(ctx0, model.layers[il].w2, cur), model.layers[il].b2); + } + + inpL = ggml_add(ctx0, cur, inpFF); + } + + // Output Norm + { + cur = ggml_norm(ctx0, inpL, norm_eps); + cur = ggml_add(ctx0, ggml_mul(ctx0, cur, model.output_norm), model.output_norm_b); + } + ggml_set_name(cur, "result_norm"); + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_mpt( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; // == n_head for MPT, as there's no MQA/GQA + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + const float norm_eps = hparams.f_norm_eps; + const float clamp_kqv = hparams.f_clamp_kqv; + const float max_alibi_bias = hparams.f_max_alibi_bias; + + const int n_gpu_layers = model.n_gpu_layers; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? 
n_ctx - n_tokens : kv_self.head; + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ false, + }; + + params.no_alloc = true; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + //int warmup = 0; + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + //warmup = ((uint32_t*) inp_tokens->data)[0] == 0; + } + + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd)/n_head)); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + for (int il = 0; il < n_layer; ++il) { + struct ggml_tensor * attn_norm; + + offload_func_t offload_func = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (il >= i_gpu_start) { + offload_func = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // self-attention + // TODO: refactor into common function (shared with LLaMA) + { + attn_norm = ggml_norm(ctx0, inpL, norm_eps); + offload_func(attn_norm); + + attn_norm = ggml_mul(ctx0, attn_norm, model.layers[il].attn_norm); + offload_func(attn_norm); + + if (1) { + cur = attn_norm; + } + + // compute QKV + + cur = ggml_mul_mat(ctx0, 
model.layers[il].wqkv, cur); + offload_func_kq(cur); + + if (clamp_kqv > 0.0f) { + cur = ggml_clamp(ctx0, cur, -clamp_kqv, clamp_kqv); + offload_func_kq(cur); + } + + const size_t wsize = ggml_type_size(cur->type); + + struct ggml_tensor * Qcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + 0); + offload_func_kq(Qcur); + + struct ggml_tensor * Kcur = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * n_head); + offload_func_kq(Kcur); + + struct ggml_tensor * tmpv = ggml_view_3d( + ctx0, cur, n_embd_head, n_head_kv, n_tokens, + wsize * n_embd_head, + wsize * n_embd_head * (n_head + 2 * n_head_kv), + wsize * n_embd_head * (n_head + n_head_kv)); + offload_func_kq(Kcur); + + ggml_set_name(Qcur, "Qcur"); + ggml_set_name(Kcur, "Kcur"); + + { + struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_cont(ctx0, tmpv), n_embd_gqa, n_tokens)); + offload_func_v(Vcur); + offload_func_v(Vcur->src[0]->src[0]); + ggml_set_name(Vcur, "Vcur"); + + struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, n_tokens*n_embd_gqa, (ggml_element_size(kv_self.k)*n_embd_gqa)*(il*n_ctx + kv_head)); + offload_func_kq(k); + ggml_set_name(k, "k"); + + struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, n_tokens, n_embd_gqa, + ( n_ctx)*ggml_element_size(kv_self.v), + (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd_gqa + kv_head*ggml_element_size(kv_self.v)); + offload_func_v(v); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); + } + + struct ggml_tensor * Q = ggml_permute(ctx0, Qcur, 0, 2, 1, 3); + offload_func_kq(Q); + ggml_set_name(Q, "Q"); + + struct ggml_tensor * K = + ggml_view_3d(ctx0, kv_self.k, + n_embd_head, n_kv, n_head_kv, + ggml_element_size(kv_self.k)*n_embd_gqa, + ggml_element_size(kv_self.k)*n_embd_head, + ggml_element_size(kv_self.k)*n_embd_gqa*n_ctx*il); + offload_func_kq(K); + ggml_set_name(K, "K"); + + struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q); + offload_func_kq(KQ); + ggml_set_name(KQ, "KQ"); + + struct ggml_tensor * KQ_scaled = ggml_scale(ctx0, KQ, KQ_scale); + offload_func_kq(KQ_scaled); + ggml_set_name(KQ_scaled, "KQ_scaled"); + + // TODO: replace with ggml_add() + struct ggml_tensor * KQ_scaled_alibi = + ggml_alibi(ctx0, KQ_scaled, 0, n_head, max_alibi_bias); + offload_func_kq(KQ_scaled_alibi); + ggml_set_name(KQ_scaled_alibi, "KQ_scaled_alibi"); + + struct ggml_tensor * KQ_masked = ggml_add(ctx0, KQ_scaled_alibi, KQ_mask); + offload_func_kq(KQ_masked); + ggml_set_name(KQ_masked, "KQ_masked"); + + struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked); + offload_func_v(KQ_soft_max); + ggml_set_name(KQ_soft_max, "KQ_soft_max"); + + struct ggml_tensor * V = + ggml_view_3d(ctx0, kv_self.v, + n_kv, n_embd_head, n_head_kv, + ggml_element_size(kv_self.v)*n_ctx, + ggml_element_size(kv_self.v)*n_ctx*n_embd_head, + ggml_element_size(kv_self.v)*n_ctx*n_embd_gqa*il); + offload_func_v(V); + ggml_set_name(V, "V"); + + struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max); + offload_func_v(KQV); + ggml_set_name(KQV, "KQV"); + + struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3); + offload_func_v(KQV_merged); + ggml_set_name(KQV_merged, "KQV_merged"); + + cur = ggml_cont_2d(ctx0, KQV_merged, n_embd, n_tokens); + offload_func_v(cur); + ggml_set_name(cur, "KQV_merged_contiguous"); + + cur 
= ggml_mul_mat(ctx0, model.layers[il].wo, cur); + offload_func(cur); + ggml_set_name(cur, "result_wo"); + } + + // Add the input + cur = ggml_add(ctx0, cur, inpL); + offload_func(cur); + + struct ggml_tensor * attn_out = cur; + + // feed forward + { + // Norm + { + cur = ggml_norm(ctx0, attn_out, norm_eps); + offload_func(cur); + + cur = ggml_mul(ctx0, cur, model.layers[il].ffn_norm); + offload_func(cur); + } + + cur = ggml_mul_mat(ctx0, model.layers[il].w3, cur); + offload_func(cur); + + cur = ggml_gelu(ctx0, cur); + offload_func(cur); + cur = ggml_mul_mat(ctx0, model.layers[il].w2, cur); + offload_func(cur); + } + + cur = ggml_add(ctx0, cur, attn_out); + offload_func(cur); + // input for next layer + inpL = cur; + } + + cur = inpL; + + // norm + { + cur = ggml_norm(ctx0, cur, norm_eps); + offload_func_nr(cur); + + cur = ggml_mul(ctx0, cur, model.output_norm); + ggml_set_name(cur, "result_norm"); + } + + cur = ggml_mul_mat(ctx0, model.output, cur); + ggml_set_name(cur, "result_output"); + + ggml_build_forward_expand(gf, cur); + + ggml_free(ctx0); + + return gf; +} + +static struct ggml_cgraph * llm_build_plamo( + llama_context & lctx, + const llama_batch & batch) { + const auto & model = lctx.model; + const auto & hparams = model.hparams; + const auto & cparams = lctx.cparams; + + const auto & kv_self = lctx.kv_self; + + GGML_ASSERT(!!kv_self.ctx); + + const int64_t n_embd = hparams.n_embd; + const int64_t n_layer = hparams.n_layer; + const int64_t n_ctx = cparams.n_ctx; + const int64_t n_head = hparams.n_head; + const int64_t n_head_kv = hparams.n_head_kv; + const int64_t n_embd_head = hparams.n_embd_head(); + const int64_t n_embd_gqa = hparams.n_embd_gqa(); + + GGML_ASSERT(n_embd_head == hparams.n_rot); + + const float freq_base = cparams.rope_freq_base; + const float freq_scale = cparams.rope_freq_scale; + const float norm_rms_eps = hparams.f_norm_rms_eps; + + const int n_gpu_layers = model.n_gpu_layers; + + const int32_t n_tokens = batch.n_tokens; + const int32_t n_kv = ggml_allocr_is_measure(lctx.alloc) ? n_ctx : kv_self.n; + const int32_t kv_head = ggml_allocr_is_measure(lctx.alloc) ? 
n_ctx - n_tokens : kv_self.head; + + const bool do_rope_shift = ggml_allocr_is_measure(lctx.alloc) || kv_self.has_shift; + + //printf("n_kv = %d\n", n_kv); + + auto & buf_compute = lctx.buf_compute; + + struct ggml_init_params params = { + /*.mem_size =*/ buf_compute.size, + /*.mem_buffer =*/ buf_compute.data, + /*.no_alloc =*/ true, + }; + + struct ggml_context * ctx0 = ggml_init(params); + + ggml_cgraph * gf = ggml_new_graph(ctx0); + + struct ggml_tensor * cur; + struct ggml_tensor * inpL; + + if (batch.token) { + struct ggml_tensor * inp_tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inp_tokens); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inp_tokens->data, batch.token, n_tokens*ggml_element_size(inp_tokens)); + } + ggml_set_name(inp_tokens, "inp_tokens"); + + inpL = ggml_get_rows(ctx0, model.tok_embeddings, inp_tokens); + } else { +#ifdef GGML_USE_MPI + GGML_ASSERT(false && "not implemented"); +#endif + + inpL = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_tokens); + + ggml_allocr_alloc(lctx.alloc, inpL); + if (!ggml_allocr_is_measure(lctx.alloc)) { + memcpy(inpL->data, batch.embd, n_tokens * n_embd * ggml_element_size(inpL)); + } + } + + + const int i_gpu_start = n_layer - n_gpu_layers; + (void) i_gpu_start; + + // offload functions set the tensor output backend to GPU + // tensors are GPU-accelerated if any input or the output has been offloaded + offload_func_t offload_func_nr = llama_nop; // nr = non-repeating + offload_func_t offload_func_kq = llama_nop; + offload_func_t offload_func_v = llama_nop; + +#ifdef GGML_USE_CUBLAS + if (n_gpu_layers > n_layer) { + offload_func_nr = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 1) { + offload_func_v = ggml_cuda_assign_buffers_no_alloc; + } + if (n_gpu_layers > n_layer + 2) { + offload_func_kq = ggml_cuda_assign_buffers_no_alloc; + } +#endif // GGML_USE_CUBLAS + + // KQ_scale + struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); + ggml_set_name(KQ_scale, "1/sqrt(n_embd_head)"); + ggml_allocr_alloc(lctx.alloc, KQ_scale); + if (!ggml_allocr_is_measure(lctx.alloc)) { + ggml_set_f32(KQ_scale, 1.0f/sqrtf(float(n_embd_head))); + } + + // KQ_mask (mask for 1 head, it will be broadcasted to all heads) + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1); + offload_func_kq(KQ_mask); + ggml_set_name(KQ_mask, "KQ_mask"); + ggml_allocr_alloc(lctx.alloc, KQ_mask); + if (!ggml_allocr_is_measure(lctx.alloc)) { + float * data = (float *) KQ_mask->data; + memset(data, 0, ggml_nbytes(KQ_mask)); + + for (int h = 0; h < 1; ++h) { + for (int j = 0; j < n_tokens; ++j) { + const llama_pos pos = batch.pos[j]; + const llama_seq_id seq_id = batch.seq_id[j]; + + for (int i = 0; i < n_kv; ++i) { + if (!kv_self.cells[i].has_seq_id(seq_id) || kv_self.cells[i].pos > pos) { + data[h*(n_kv*n_tokens) + j*n_kv + i] = -INFINITY; + } + } + } + } + } + + // KQ_pos - contains the positions + struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens); + offload_func_kq(KQ_pos); + ggml_set_name(KQ_pos, "KQ_pos"); + ggml_allocr_alloc(lctx.alloc, KQ_pos); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) KQ_pos->data; + for (int i = 0; i < n_tokens; ++i) { + data[i] = batch.pos[i]; + } + } + + // shift the entire K-cache if needed + if (do_rope_shift) { + struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_ctx); + offload_func_kq(K_shift); + ggml_set_name(K_shift, "K_shift"); 
+ ggml_allocr_alloc(lctx.alloc, K_shift); + if (!ggml_allocr_is_measure(lctx.alloc)) { + int * data = (int *) K_shift->data; + for (int i = 0; i < n_ctx; ++i) { + data[i] = kv_self.cells[i].delta; + } + } + + for (int il = 0; il < n_layer; ++il) { struct ggml_tensor * tmp = ggml_rope_custom_inplace(ctx0, ggml_view_3d(ctx0, kv_self.k, @@ -5411,6 +6170,14 @@ static struct ggml_cgraph * llama_build_graph( { result = llm_build_refact(lctx, batch); } break; + case LLM_ARCH_BLOOM: + { + result = llm_build_bloom(lctx, batch); + } break; + case LLM_ARCH_MPT: + { + result = llm_build_mpt(lctx, batch); + } break; case LLM_ARCH_PLAMO: { result = llm_build_plamo(lctx, batch); @@ -5545,7 +6312,8 @@ static int llama_decode_internal( const bool full_offload_supported = model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_BAICHUAN || model.arch == LLM_ARCH_FALCON || - model.arch == LLM_ARCH_REFACT; + model.arch == LLM_ARCH_REFACT || + model.arch == LLM_ARCH_MPT; const bool fully_offloaded = model.n_gpu_layers >= (int) hparams.n_layer + 3; if (ggml_cpu_has_cublas() && full_offload_supported && fully_offloaded) { n_threads = 1; @@ -6046,7 +6814,6 @@ struct llm_tokenizer_bpe { for (int i = 0; i < (int)text_utf.size(); i++) { const std::string & utf_char = text_utf[i]; bool split_condition = false; - // const char* text_pos = raw_text_p + utf_char.seq_offset_bytes; int bytes_remain = text_utf.size() - i; // forward backward lookups const std::string & utf_char_next = (i + 1 < (int)text_utf.size()) ? text_utf[i + 1] : ""; @@ -6072,9 +6839,9 @@ struct llm_tokenizer_bpe { if (!split_condition && bytes_remain >= 3) { // 're|'ve|'ll if (utf_char == "\'" && ( - (utf_char_next == "r" || utf_char_next_next == "e") || - (utf_char_next == "v" || utf_char_next_next == "e") || - (utf_char_next == "l" || utf_char_next_next == "l")) + (utf_char_next == "r" && utf_char_next_next == "e") || + (utf_char_next == "v" && utf_char_next_next == "e") || + (utf_char_next == "l" && utf_char_next_next == "l")) ) { split_condition = true; } @@ -6125,7 +6892,7 @@ struct llm_tokenizer_bpe { else if (collecting_special && (codepoint_type(utf_char) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char) == CODEPOINT_TYPE_DIGIT || codepoint_type(utf_char) == CODEPOINT_TYPE_WHITESPACE)) { split_condition = true; } - else if (collecting_whitespace_lookahead && codepoint_type(utf_char_next) != CODEPOINT_TYPE_WHITESPACE) { + else if (collecting_whitespace_lookahead && (codepoint_type(utf_char_next) == CODEPOINT_TYPE_LETTER || codepoint_type(utf_char_next) == CODEPOINT_TYPE_DIGIT)) { split_condition = true; } } @@ -7641,7 +8408,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s const std::string name = ggml_get_name(meta); // TODO: avoid hardcoded tensor names - use the TN_* constants - if (name.find("attn_v.weight") != std::string::npos) { + if (name.find("attn_v.weight") != std::string::npos || name.find("attn_qkv.weight") != std::string::npos) { ++n_attention_wv; } else if (name.find("ffn_down.weight") != std::string::npos) { diff --git a/tests/test-tokenizer-0-falcon.cpp b/tests/test-tokenizer-0-falcon.cpp index 0f3c50bce8ae9d..a4e9d2b9127287 100644 --- a/tests/test-tokenizer-0-falcon.cpp +++ b/tests/test-tokenizer-0-falcon.cpp @@ -36,6 +36,8 @@ static const std::map> & k_tests() { { " Hello" , { 258, 23090, }, }, { " Hello" , { 466, 23090, }, }, { " Hello\n Hello" , { 466, 23090, 742, 23090, }, }, + { "\n =" , { 1212, 40, }, }, + { "' era" , { 18, 4932, }, }, }; return _k_tests; @@ -155,7 +157,7 @@ int 
main(int argc, char **argv) { fprintf(stderr, "%s : text size: %zu\n", __func__, text.size()); - const std::vector res = llama_tokenize(ctx, text, true); + const std::vector res = llama_tokenize(ctx, text, false); fprintf(stderr, "%s : tokens: %zu\n", __func__, res.size()); @@ -169,10 +171,8 @@ int main(int argc, char **argv) { } for (const auto & tok : res) { - ofs << tok << " "; + ofs << tok << " '" << llama_detokenize_bpe(ctx, std::vector{tok}) << "'" << std::endl; } - - ofs << "\n"; } fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str()); diff --git a/tests/test-tokenizer-0-falcon.py b/tests/test-tokenizer-0-falcon.py index 9c8c1c7d1d3ca4..cf65a3f65d72cc 100644 --- a/tests/test-tokenizer-0-falcon.py +++ b/tests/test-tokenizer-0-falcon.py @@ -41,6 +41,8 @@ " Hello", " Hello", " Hello\n Hello", + "\n =", + "' era", ] for text in tests: @@ -69,15 +71,14 @@ if fname_tok: print('tokenizing file: ', fname_tok) fname_out = fname_tok + '.tok' - with open(fname_tok, 'r') as f: + with open(fname_tok, 'r', encoding='utf-8') as f: lines = f.readlines() s = ''.join(lines) res = tokenizer.encode(s) # write to file - with open(fname_out, 'w') as f: + with open(fname_out, 'w', encoding='utf-8') as f: for x in res: - f.write(str(x) + ' ') - f.write('\n') + f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n') print('len(res): ', len(res)) print('len(lines): ', len(lines)) print('results written to: ', fname_out) diff --git a/tests/test-tokenizer-0-llama.cpp b/tests/test-tokenizer-0-llama.cpp index 91c841f7bba8f6..39c8d188c90861 100644 --- a/tests/test-tokenizer-0-llama.cpp +++ b/tests/test-tokenizer-0-llama.cpp @@ -174,10 +174,8 @@ int main(int argc, char **argv) { } for (const auto & tok : res) { - ofs << tok << " "; + ofs << tok << " '" << llama_detokenize_spm(ctx, std::vector{tok}) << "'" << std::endl; } - - ofs << "\n"; } fprintf(stderr, "%s : tokens written to '%s'\n", __func__, (fname_text + ".tokcpp").c_str()); diff --git a/tests/test-tokenizer-0-llama.py b/tests/test-tokenizer-0-llama.py index bc164ee296cb1d..078f680b165ca1 100644 --- a/tests/test-tokenizer-0-llama.py +++ b/tests/test-tokenizer-0-llama.py @@ -81,15 +81,14 @@ if fname_tok: print('tokenizing file: ', fname_tok) fname_out = fname_tok + '.tok' - with open(fname_tok, 'r') as f: + with open(fname_tok, 'r', encoding='utf-8') as f: lines = f.readlines() s = ''.join(lines) res = tokenizer.encode(s, add_bos=True) # write to file - with open(fname_out, 'w') as f: + with open(fname_out, 'w', encoding='utf-8') as f: for x in res: - f.write(str(x) + ' ') - f.write('\n') + f.write(str(x) + ' \'' + tokenizer.decode(x) + '\'\n') print('len(res): ', len(res)) print('len(lines): ', len(lines)) print('results written to: ', fname_out)
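
Note on the ALiBi step used in the graph builders above: both llm_build_bloom and llm_build_mpt skip rotary embeddings and instead call ggml_alibi on the scaled KQ matrix (with a fixed max bias of 8 for Bloom and hparams.f_max_alibi_bias for MPT). The standalone C++ sketch below illustrates the idea behind that call, head-dependent linear biases added to the pre-softmax attention scores; it is not the ggml implementation, the values of n_head, n_kv and max_bias are arbitrary illustrative choices, and the slope formula assumes n_head is a power of two.

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    const int   n_head   = 8;    // illustrative head count (assumed to be a power of two)
    const int   n_kv     = 4;    // number of key positions visible to the query
    const float max_bias = 8.0f; // max ALiBi bias (8 is what the Bloom graph passes)

    // head-specific slopes form a geometric sequence: 2^(-max_bias*(h+1)/n_head)
    std::vector<float> slopes(n_head);
    for (int h = 0; h < n_head; ++h) {
        slopes[h] = std::pow(2.0f, -max_bias * float(h + 1) / float(n_head));
    }

    // bias added to the score of key position j for head h; because softmax is
    // invariant to a per-row constant, adding slope*j is equivalent to the usual
    // -slope*(i - j) distance penalty from the ALiBi paper
    for (int h = 0; h < n_head; ++h) {
        for (int j = 0; j < n_kv; ++j) {
            printf("head %d key %d bias %+.4f\n", h, j, slopes[h] * j);
        }
    }
    return 0;
}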
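
Note on the llm_tokenizer_bpe hunk: the 're|'ve|'ll check is tightened from '||' to '&&', so the split is only taken when both characters after the apostrophe match one of the contractions. With '||', the old code also split on inputs such as "' era" (next character " ", character after that "e"), which is presumably why that string was added to the Falcon tokenizer tests in this diff. Below is a minimal standalone sketch of the corrected predicate; is_contraction is a hypothetical helper for illustration, not a llama.cpp function, and it covers only the inner condition (the surrounding code additionally requires the current character to be an apostrophe).

#include <cassert>
#include <string>

// true only for the exact contractions "'re", "'ve", "'ll" after an apostrophe
static bool is_contraction(const std::string & next, const std::string & next_next) {
    return (next == "r" && next_next == "e") ||
           (next == "v" && next_next == "e") ||
           (next == "l" && next_next == "l");
}

int main() {
    assert( is_contraction("r", "e")); // "'re": split
    assert(!is_contraction(" ", "e")); // "' e": no split (the old '||' logic split here)
    assert(!is_contraction("l", "e")); // "'le": no split
    return 0;
}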