Replay changes from #3871
Credit to @cebtenzzre for that pull
KerfuffleV2 committed Nov 7, 2023
1 parent b8c80df commit 8047aa1
Showing 2 changed files with 54 additions and 31 deletions.
gguf-py/gguf/gguf_writer.py: 49 additions & 26 deletions
@@ -5,7 +5,8 @@
 import struct
 import tempfile
 from io import BufferedWriter
-from typing import Any, BinaryIO, Sequence
+from enum import Enum, auto
+from typing import Any, IO, Sequence
 
 import numpy as np
 
@@ -21,18 +22,16 @@
     TokenType,
 )
 
+class WriterState(Enum):
+    EMPTY = auto()
+    HEADER = auto()
+    KV_DATA = auto()
+    TI_DATA = auto()
+
 class GGUFWriter:
     fout: BufferedWriter
-    arch: str
-    offset_tensor = 0
-    data_alignment = GGUF_DEFAULT_ALIGNMENT
-    kv_data = b""
-    kv_data_count = 0
-    ti_data = b""
-    ti_data_count = 0
-    use_temp_file: bool
-    temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
-    tensors: list[tuple[np.ndarray[Any, Any], int]]
+    temp_file: tempfile.SpooledTemporaryFile[bytes] | None
+    tensors: list[np.ndarray[Any, Any]]
     _simple_value_packing = {
         GGUFValueType.UINT8: "B",
         GGUFValueType.INT8: "b",
@@ -60,27 +59,47 @@ def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file: bool
         self.fout = open(path, "wb")
         self.arch = arch
         self.endianess = endianess
-        self.add_architecture()
+        self.offset_tensor = 0
+        self.data_alignment = GGUF_DEFAULT_ALIGNMENT
+        self.kv_data = b""
+        self.kv_data_count = 0
+        self.ti_data = b""
+        self.ti_data_count = 0
         self.use_temp_file = use_temp_file
         self.temp_file = None
         self.tensors = []
         print("gguf: This GGUF file is for {0} Endian only"
               .format("Big" if self.endianess == GGUFEndian.BIG else "Little"))
+        self.state = WriterState.EMPTY
+
+        self.add_architecture()
 
     def write_header_to_file(self) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         self._write_packed("<I", GGUF_MAGIC, skip_pack_prefix = True)
         self._write_packed("I", GGUF_VERSION)
         self._write_packed("Q", self.ti_data_count)
         self._write_packed("Q", self.kv_data_count)
         self.flush()
         # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
+        self.state = WriterState.HEADER
 
     def write_kv_data_to_file(self) -> None:
+        if self.state is not WriterState.HEADER:
+            raise ValueError(f'Expected output file to contain the header, got {self.state}')
+
         self.fout.write(self.kv_data)
         self.flush()
+        self.state = WriterState.KV_DATA
 
     def write_ti_data_to_file(self) -> None:
+        if self.state is not WriterState.KV_DATA:
+            raise ValueError(f'Expected output file to contain KV data, got {self.state}')
+
         self.fout.write(self.ti_data)
         self.flush()
+        self.state = WriterState.TI_DATA
 
     def add_key(self, key: str) -> None:
         self.add_val(key, GGUFValueType.STRING, add_vtype=False)
@@ -173,6 +192,9 @@ def ggml_pad(x: int, n: int) -> int:
         return ((x + n - 1) // n) * n
 
     def add_tensor_info(self, name: str, tensor_shape: Sequence[int], tensor_dtype: np.dtype[np.float16] | np.dtype[np.float32], tensor_nbytes: int, raw_dtype: GGMLQuantizationType | None = None) -> None:
+        if self.state is not WriterState.EMPTY:
+            raise ValueError(f'Expected output file to be empty, got {self.state}')
+
         if raw_dtype is None and tensor_dtype not in (np.float32, np.float16):
             raise ValueError("Only F32 and F16 tensors are supported for now")
 
@@ -203,23 +225,21 @@ def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequenc
         shape: Sequence[int] = raw_shape if raw_shape is not None else tensor.shape
         self.add_tensor_info(name, shape, tensor.dtype, tensor.nbytes, raw_dtype = raw_dtype)
 
-        pad = GGUFWriter.ggml_pad(tensor.nbytes, self.data_alignment) - tensor.nbytes
-
-        if self.temp_file is None:
-            self.tensors.append((tensor, pad))
-            return
+        if self.temp_file is None:
+            self.tensors.append(tensor)
 
         tensor.tofile(self.temp_file)
+        self.write_padding(self.temp_file, tensor.nbytes)
 
-        if pad != 0:
-            self.temp_file.write(bytes([0] * pad))
-
-    def write_padding(self, fp: BinaryIO, n: int, align: int | None = None) -> None:
+    def write_padding(self, fp: IO[bytes], n: int, align: int | None = None):
         pad = GGUFWriter.ggml_pad(n, align if align is not None else self.data_alignment) - n
         if pad != 0:
             fp.write(bytes([0] * pad))
 
     def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
+        if self.state is not WriterState.TI_DATA:
+            raise ValueError(f'Expected output file to contain tensor info, got {self.state}')
+
         if self.endianess==GGUFEndian.BIG:
             tensor.byteswap(inplace=True)
         self.write_padding(self.fout, self.fout.tell())
@@ -232,10 +252,13 @@ def write_tensors_to_file(self) -> None:
         self.write_padding(self.fout, self.fout.tell())
 
         if self.temp_file is None:
-            for (currtensor, currpad) in self.tensors:
-                currtensor.tofile(self.fout)
-                if currpad != 0:
-                    self.fout.write(bytes([0] * currpad))
+            while True:
+                try:
+                    tensor = self.tensors.pop(0)
+                except IndexError:
+                    break
+                tensor.tofile(self.fout)
+                self.write_padding(self.fout, tensor.nbytes)
             return
 
         self.temp_file.seek(0)
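The new WriterState checks make the required write order explicit: header first, then KV data, then tensor info, then tensor data. Below is a minimal sketch of a conforming caller. It assumes write_tensors_to_file() flushes the tensor-info section itself before the tensor data (which the TI_DATA check above implies); the file name, KV key, and tensor name are illustrative placeholders, not part of this commit.

import numpy as np
from gguf import GGUFWriter  # assuming the gguf-py package layout used here

writer = GGUFWriter("example.gguf", "llama")  # illustrative path and arch
writer.add_uint32("example.block_count", 1)   # buffered into kv_data
writer.add_tensor("blk.0.weight", np.zeros((16, 16), dtype=np.float32))

writer.write_header_to_file()   # EMPTY   -> HEADER
writer.write_kv_data_to_file()  # HEADER  -> KV_DATA
writer.write_tensors_to_file()  # KV_DATA -> TI_DATA, then the tensor data
writer.close()

# Out-of-order calls now fail fast instead of silently producing a corrupt
# file; e.g. calling write_kv_data_to_file() first raises:
#   ValueError: Expected output file to contain the header, got WriterState.EMPTY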
gguf-py/gguf/vocab.py: 5 additions & 5 deletions
@@ -9,11 +9,8 @@
 from .gguf_writer import GGUFWriter
 
 class SpecialVocab:
-    load_merges: bool = False
-    merges: list[str] = []
-    special_token_types: tuple[str, ...] = ('bos', 'eos', 'unk', 'sep', 'pad')
-    special_token_ids: dict[str, int] = {}
-    n_vocab: int | None = None
+    merges: list[str]
+    special_token_ids: dict[str, int]
 
     def __init__(
         self, path: str | os.PathLike[str], load_merges: bool = False,
@@ -23,8 +20,11 @@ def __init__(
         self.special_token_ids = {}
         self.n_vocab = n_vocab
         self.load_merges = load_merges
+        self.merges = []
         if special_token_types is not None:
             self.special_token_types = special_token_types
+        else:
+            self.special_token_types = ('bos', 'eos', 'unk', 'sep', 'pad')
         self._load(Path(path))
 
     def _load(self, path: Path) -> None:
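The vocab.py change applies the same cleanup pattern: mutable defaults move from the class body into __init__. A class-level list or dict is created once and shared by every instance, so one SpecialVocab's mutations would leak into the next instance created in the same process. A standalone toy sketch of the pitfall (illustrative only, not the gguf API):

class Shared:
    items: list[str] = []           # one list object, shared by all instances

class PerInstance:
    def __init__(self) -> None:
        self.items: list[str] = []  # a fresh list per instance

a, b = Shared(), Shared()
a.items.append("x")
print(b.items)  # ['x'] -- b observes a's mutation through the shared default

c, d = PerInstance(), PerInstance()
c.items.append("x")
print(d.items)  # [] -- each instance owns its own list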
