gguf: better type usage #655

Merged: 11 commits, merged May 7, 2024

Changes from 4 commits
70 changes: 46 additions & 24 deletions packages/gguf/scripts/generate-llm.ts
@@ -8,27 +8,53 @@ import { writeFileSync } from "node:fs";
const SOURCE_CPP_URL = "https://raw.githubusercontent.com/ggerganov/llama.cpp/master/llama.cpp";
const DEST_FILE_PATH = "./src/transformer-llm.ts";
const DEST_COMMON_SOURCE = `
type Attention<TArchitecture extends string> =
& { [K in \`\${TArchitecture}.attention.head_count\`]: number }
& { [K in \`\${TArchitecture}.attention.head_count_kv\`]: number }
& { [K in \`\${TArchitecture}.attention.layer_norm_epsilon\`]: number }
& { [K in \`\${TArchitecture}.attention.layer_norm_rms_epsilon\`]: number }
& { [K in \`\${TArchitecture}.attention.alibi_bias_max\`]: number }
& { [K in \`\${TArchitecture}.attention.clip_kqv\`]: number }
& { [K in \`\${TArchitecture}.attention.use_norm\`]: number };

type Rope<TArchitecture extends LLMArchitecture> =
& { [K in \`\${TArchitecture}.rope.dimension_count\`]: number }
& { [K in \`\${TArchitecture}.rope.freq_base\`]: number }
& { [K in \`\${TArchitecture}.rope.scale\`]: number }
& { [K in \`\${TArchitecture}.rope.scale_linear\`]: number };

type MOE<TArchitecture extends LLMArchitecture> =
& { [K in \`\${TArchitecture}.expert_count\`]: number }
& { [K in \`\${TArchitecture}.expert_used_count\`]: number };
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<Record<
\`\${TArchitecture}.vocab_size\`
| \`\${TArchitecture}.use_parallel_residual\`
| \`\${TArchitecture}.tensor_data_layout\`,
number
>>;

type Attention<TArchitecture extends string> = Record<
\`\${TArchitecture}.attention.head_count\`,
number
> & Partial<Record<
\`\${TArchitecture}.attention.head_count_kv\`
| \`\${TArchitecture}.attention.key_length\`
| \`\${TArchitecture}.attention.value_length\`,
number
>>;

type RopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
Record<
\`\${TArchitecture}.rope.dimension_count\`
| \`\${TArchitecture}.rope.freq_base\`
| \`\${TArchitecture}.rope.scale_linear\`
| \`\${TArchitecture}.rope.scaling.factor\`
| \`\${TArchitecture}.rope.scaling.original_context_length\`,
number
>
& Record<\`\${TArchitecture}.rope.scaling.type\`, RopeScalingType>
& Record<\`\${TArchitecture}.rope.finetuned\`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
Record<
\`\${TArchitecture}.expert_count\`
| \`\${TArchitecture}.expert_used_count\`,
number
>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture>
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture>
& LLMBase<TArchitecture>
& ModelBase<TArchitecture>
& MOE<TArchitecture>
& Attention<TArchitecture>
& Rope<TArchitecture>;
@@ -163,15 +189,11 @@ async function main() {
/////////////////////////////////////
// write result to file
const content = [
"/** This file is auto-generated by generate-llm.ts */",
"",
'import type { ModelBase } from "./types";',
"",
DEST_COMMON_SOURCE,
"export const LLM_ARCHITECTURES = [",
...archList.map((a) => `\t${JSON.stringify(a.name)},`),
"] as const;",
"type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];",
DEST_COMMON_SOURCE,
...archList.map((a) => {
let code = `export type ${a.tsName} = TransformerLLMBase<${JSON.stringify(a.name)}>`;
if (a.hparams.length) {
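Side note (not part of the PR): a minimal sketch of how the template-literal Record pattern used in these generated types expands for a concrete architecture. The type below is a trimmed copy of the Attention type above; the numeric values are invented for illustration.

// Trimmed copy of the generated Attention<TArchitecture> type above.
type Attention<TArchitecture extends string> = Record<`${TArchitecture}.attention.head_count`, number> &
	Partial<Record<`${TArchitecture}.attention.head_count_kv` | `${TArchitecture}.attention.key_length`, number>>;

// Attention<"llama"> resolves to fully spelled-out GGUF key names:
// { "llama.attention.head_count": number } &
//   Partial<{ "llama.attention.head_count_kv": number; "llama.attention.key_length": number }>
const llamaAttention: Attention<"llama"> = {
	"llama.attention.head_count": 32, // required
	"llama.attention.head_count_kv": 32, // optional
};
console.log(llamaAttention);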
33 changes: 18 additions & 15 deletions packages/gguf/src/gguf.spec.ts
@@ -37,22 +37,25 @@ describe("gguf", () => {
"llama.rope.dimension_count": 128,
});

const tokens = metadata["tokenizer.ggml.tokens"];
if (!Array.isArray(tokens)) {
throw new Error();
expect(metadata["tokenizer.ggml.model"]);
if (metadata["tokenizer.ggml.model"]) {
const tokens = metadata["tokenizer.ggml.tokens"];
if (!Array.isArray(tokens)) {
throw new Error();
}
expect(tokens.slice(0, 10)).toEqual([
"<unk>",
"<s>",
"</s>",
"<0x00>",
"<0x01>",
"<0x02>",
"<0x03>",
"<0x04>",
"<0x05>",
"<0x06>",
]);
}
expect(tokens.slice(0, 10)).toEqual([
"<unk>",
"<s>",
"</s>",
"<0x00>",
"<0x01>",
"<0x02>",
"<0x03>",
"<0x04>",
"<0x05>",
"<0x06>",
]);

/// Tensor infos
/// By convention we test the first and last tensor.
3 changes: 3 additions & 0 deletions packages/gguf/src/gguf.ts
@@ -308,6 +308,9 @@ export async function gguf(
}
}
offset += valueResult.length;
/// TODO(fix typing)
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
// @ts-ignore
Member:

Let's use // @ts-expect-error instead of // @ts-ignore in general

(no need for eslint-disable this way)

Here I think you can change const metadata: GGUFMetadata to const metadata: GGUFMetadata<GGUFType.NON_STRICT> to remove the error (not sure if it's the best fix)

Member Author:

Fixed in 31bac8b

metadata[keyResult.value] = valueResult.value;
}

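To make the review suggestion above concrete, here is a minimal sketch, not taken from the merged code. It assumes the GGUFMetadata and GGUFNonStrictType types introduced later in this diff (types.ts); the parsed key/value pair stands in for keyResult/valueResult and is invented for illustration.

import type { GGUFMetadata, GGUFNonStrictType } from "./types";

// Hypothetical parsed key/value pair, standing in for keyResult/valueResult above.
const parsedKey = "llama.attention.head_count";
const parsedValue = 32;

// Typing the accumulator as non-strict allows arbitrary string keys while parsing,
// so neither // @ts-ignore nor // @ts-expect-error is needed at the assignment.
// (When a suppression is unavoidable, // @ts-expect-error is preferable: it errors
// once the underlying typing issue is fixed, and needs no eslint-disable.)
const metadata: GGUFMetadata<GGUFNonStrictType> = {
	version: 2, // assumed to be a valid Version value
	tensor_count: 0n,
	kv_count: 0n,
};
metadata[parsedKey] = parsedValue;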
82 changes: 51 additions & 31 deletions packages/gguf/src/transformer-llm.ts
@@ -1,6 +1,56 @@
/** This file is auto-generated by generate-llm.ts */

import type { ModelBase } from "./types";
import type { ModelBase, GGUFGeneralInfo } from "./types";

type LLMBase<TArchitecture extends string> = Partial<
Record<
`${TArchitecture}.vocab_size` | `${TArchitecture}.use_parallel_residual` | `${TArchitecture}.tensor_data_layout`,
number
>
>;

type Attention<TArchitecture extends string> = Record<`${TArchitecture}.attention.head_count`, number> &
Partial<
Record<
| `${TArchitecture}.attention.head_count_kv`
| `${TArchitecture}.attention.key_length`
| `${TArchitecture}.attention.value_length`,
number
>
>;

type RopeScalingType = "none" | "linear" | "yarn";
type Rope<TArchitecture extends LLMArchitecture> = Partial<
Record<
| `${TArchitecture}.rope.dimension_count`
| `${TArchitecture}.rope.freq_base`
| `${TArchitecture}.rope.scale_linear`
| `${TArchitecture}.rope.scaling.factor`
| `${TArchitecture}.rope.scaling.original_context_length`,
number
> &
Record<`${TArchitecture}.rope.scaling.type`, RopeScalingType> &
Record<`${TArchitecture}.rope.finetuned`, boolean>
>;

type MOE<TArchitecture extends LLMArchitecture> = Partial<
Record<`${TArchitecture}.expert_count` | `${TArchitecture}.expert_used_count`, number>
>;

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = GGUFGeneralInfo<TArchitecture> &
LLMBase<TArchitecture> &
ModelBase<TArchitecture> &
MOE<TArchitecture> &
Attention<TArchitecture> &
Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
UNSPECIFIED = -1,
NONE = 0,
MEAN = 1,
CLS = 2,
}

export const LLM_ARCHITECTURES = [
"llama",
@@ -37,36 +87,6 @@ export const LLM_ARCHITECTURES = [
"olmo",
] as const;
type LLMArchitecture = (typeof LLM_ARCHITECTURES)[number];

type Attention<TArchitecture extends string> = { [K in `${TArchitecture}.attention.head_count`]: number } & {
[K in `${TArchitecture}.attention.head_count_kv`]: number;
} & { [K in `${TArchitecture}.attention.layer_norm_epsilon`]: number } & {
[K in `${TArchitecture}.attention.layer_norm_rms_epsilon`]: number;
} & { [K in `${TArchitecture}.attention.alibi_bias_max`]: number } & {
[K in `${TArchitecture}.attention.clip_kqv`]: number;
} & { [K in `${TArchitecture}.attention.use_norm`]: number };

type Rope<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.rope.dimension_count`]: number } & {
[K in `${TArchitecture}.rope.freq_base`]: number;
} & { [K in `${TArchitecture}.rope.scale`]: number } & { [K in `${TArchitecture}.rope.scale_linear`]: number };

type MOE<TArchitecture extends LLMArchitecture> = { [K in `${TArchitecture}.expert_count`]: number } & {
[K in `${TArchitecture}.expert_used_count`]: number;
};

export type TransformerLLMArchitecture = LLMArchitecture; // type alias
export type TransformerLLMBase<TArchitecture extends LLMArchitecture> = ModelBase<TArchitecture> &
MOE<TArchitecture> &
Attention<TArchitecture> &
Rope<TArchitecture>;

export enum TransformerLLMPoolingType {
UNSPECIFIED = -1,
NONE = 0,
MEAN = 1,
CLS = 2,
}

export type ArchLlama = TransformerLLMBase<"llama"> & {
"llama.attention.layer_norm_rms_epsilon": number;
};
42 changes: 42 additions & 0 deletions packages/gguf/src/types.spec.ts
@@ -0,0 +1,42 @@
import { describe, it } from "vitest";
import type { GGUFStrictType, GGUFMetadata, GGUFNonStrictType } from "./types";

describe("gguf-types", () => {
it("GGUFNonStrictType should be correct (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const model: GGUFMetadata<GGUFNonStrictType> = null as any;
model.kv_count = 123n;
model.abc = 456; // PASS, because it can be anything
});

it("GGUFStrictType should be correct (at compile time)", async () => {
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const model: GGUFMetadata<GGUFStrictType> = null as any;

if (model["general.architecture"] === "whisper") {
model["encoder.whisper.block_count"] = 0;
// @ts-expect-error because it must be a number
model["encoder.whisper.block_count"] = "abc";
}

if (model["tokenizer.ggml.model"] === undefined) {
// @ts-expect-error because it's undefined
model["tokenizer.ggml.eos_token_id"] = 1;
}
if (model["tokenizer.ggml.model"] === "gpt2") {
// @ts-expect-error because it must be a number
model["tokenizer.ggml.eos_token_id"] = undefined;
model["tokenizer.ggml.eos_token_id"] = 1;
}

if (model["general.architecture"] === "mamba") {
model["mamba.ssm.conv_kernel"] = 0;
// @ts-expect-error because it must be a number
model["mamba.ssm.conv_kernel"] = "abc";
}
if (model["general.architecture"] === "llama") {
// @ts-expect-error llama does not have ssm.* keys
model["mamba.ssm.conv_kernel"] = 0;
}
});
});
68 changes: 49 additions & 19 deletions packages/gguf/src/types.ts
@@ -50,21 +50,32 @@ export enum GGUFValueType {
const ARCHITECTURES = [...LLM_ARCHITECTURES, "rwkv", "whisper"] as const;
export type Architecture = (typeof ARCHITECTURES)[number];

interface General {
"general.architecture": Architecture;
"general.name": string;
"general.file_type": number;
"general.quantization_version": number;
export interface GGUFGeneralInfo<TArchitecture extends Architecture> {
"general.architecture": TArchitecture;
"general.name"?: string;
"general.file_type"?: number;
"general.quantization_version"?: number;
}

type ModelMetadata = Whisper | RWKV | TransformerLLM;
interface NoModelMetadata {
"general.architecture"?: undefined;
}

export type ModelBase<
TArchitecture extends
| Architecture
| `encoder.${Extract<Architecture, "whisper">}`
| `decoder.${Extract<Architecture, "whisper">}`,
> = { [K in `${TArchitecture}.layer_count`]: number } & { [K in `${TArchitecture}.feed_forward_length`]: number } & {
[K in `${TArchitecture}.context_length`]: number;
} & { [K in `${TArchitecture}.embedding_length`]: number } & { [K in `${TArchitecture}.block_count`]: number };
> = Record<
| `${TArchitecture}.context_length`
| `${TArchitecture}.block_count`
| `${TArchitecture}.embedding_length`
| `${TArchitecture}.feed_forward_length`,
number
>;

/// Tokenizer

type TokenizerModel = "no_vocab" | "llama" | "gpt2" | "bert";
interface Tokenizer {
@@ -75,21 +86,40 @@ interface Tokenizer {
"tokenizer.ggml.bos_token_id": number;
"tokenizer.ggml.eos_token_id": number;
"tokenizer.ggml.add_bos_token": boolean;
"tokenizer.chat_template": string;
"tokenizer.chat_template"?: string;
}
interface NoTokenizer {
"tokenizer.ggml.model"?: undefined;
}

/// Models outside of llama.cpp: "rwkv" and "whisper"

export type RWKV = ModelBase<"rwkv"> & { "rwkv.architecture_version": number };
export type LLM = TransformerLLM | RWKV;
export type Whisper = ModelBase<"encoder.whisper"> & ModelBase<"decoder.whisper">;
export type Model = (LLM | Whisper) & Partial<Tokenizer>;
export type RWKV = GGUFGeneralInfo<"rwkv"> &
ModelBase<"rwkv"> & {
"rwkv.architecture_version": number;
};

export type GGUFMetadata = {
// TODO: whisper.cpp doesn't yet support gguf. This maybe changed in the future.
export type Whisper = GGUFGeneralInfo<"whisper"> &
ModelBase<"encoder.whisper"> &
ModelBase<"decoder.whisper"> & {
"whisper.encoder.mels_count": number;
"whisper.encoder.attention.head_count": number;
"whisper.decoder.attention.head_count": number;
};

/// Types for parse output

export type GGUFStrictType = true;
export type GGUFNonStrictType = false;

export type GGUFMetadata<T extends GGUFStrictType | GGUFNonStrictType = GGUFStrictType> = {
version: Version;
tensor_count: bigint;
kv_count: bigint;
} & Partial<General> &
Partial<Model> &
Record<string, MetadataValue>;
} & (T extends GGUFStrictType ? GGUFModelKV : Record<string, MetadataValue>);

export type GGUFModelKV = (NoModelMetadata | ModelMetadata) & (NoTokenizer | Tokenizer);

export interface GGUFTensorInfo {
name: string;
@@ -99,7 +129,7 @@ export interface GGUFTensorInfo {
offset: bigint;
}

export interface GGUFParseOutput {
metadata: GGUFMetadata;
export interface GGUFParseOutput<T extends GGUFStrictType | GGUFNonStrictType = GGUFStrictType> {
metadata: GGUFMetadata<T>;
tensorInfos: GGUFTensorInfo[];
}
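A usage-level sketch (not part of the PR) of the strict/non-strict split defined above, written in the same compile-time style as types.spec.ts; the null-as-any values and the custom key are placeholders.

import type { GGUFParseOutput, GGUFStrictType, GGUFNonStrictType } from "./types";

// Strict output (the default): architecture-specific keys are only visible
// after narrowing on "general.architecture".
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const strictOutput: GGUFParseOutput<GGUFStrictType> = null as any;
const metadata = strictOutput.metadata;
if (metadata["general.architecture"] === "llama") {
	// "llama.attention.head_count" is a required number key for the llama arch.
	const headCount: number = metadata["llama.attention.head_count"];
	console.log(headCount);
}

// Non-strict output: any string key is accepted, useful for unknown models.
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const looseOutput: GGUFParseOutput<GGUFNonStrictType> = null as any;
looseOutput.metadata["my.custom.key"] = 123;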