feat: flash attention (#264)
* feat: flash attention
* feat: exclude GPU types from `gpu: "auto"`
giladgd authored Jul 5, 2024
1 parent 81e0575 commit c2e322c
Showing 20 changed files with 307 additions and 81 deletions.
4 changes: 4 additions & 0 deletions llama/addon.cpp
@@ -987,6 +987,10 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {
context_params.embeddings = options.Get("embeddings").As<Napi::Boolean>().Value();
}

if (options.Has("flashAttention")) {
context_params.flash_attn = options.Get("flashAttention").As<Napi::Boolean>().Value();
}

if (options.Has("threads")) {
const auto n_threads = options.Get("threads").As<Napi::Number>().Uint32Value();
const auto resolved_n_threads = n_threads == 0 ? std::thread::hardware_concurrency() : n_threads;
1 change: 1 addition & 0 deletions src/bindings/AddonTypes.ts
@@ -20,6 +20,7 @@ export type BindingModule = {
contextSize?: number,
batchSize?: number,
sequences?: number,
flashAttention?: boolean,
logitsAll?: boolean,
embeddings?: boolean,
threads?: number
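
Taken together, the two changes above plumb a new optional flashAttention boolean through the binding layer: the TypeScript type declares it, and llama/addon.cpp maps it onto llama.cpp's flash_attn context parameter. A minimal sketch of a binding-level context-options object (illustrative values; how this object is passed to the addon's context constructor is not shown in this diff and is assumed here):

const addonContextOptions = {
    contextSize: 4096,
    batchSize: 512,
    sequences: 1,
    flashAttention: true, // read by llama/addon.cpp and assigned to context_params.flash_attn
    threads: 6
};
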
9 changes: 6 additions & 3 deletions src/bindings/Llama.ts
@@ -7,7 +7,7 @@ import {GbnfJsonSchema} from "../utils/gbnfJson/types.js";
import {LlamaJsonSchemaGrammar} from "../evaluator/LlamaJsonSchemaGrammar.js";
import {LlamaGrammar, LlamaGrammarOptions} from "../evaluator/LlamaGrammar.js";
import {BindingModule} from "./AddonTypes.js";
import {BuildGpu, BuildMetadataFile, LlamaLocks, LlamaLogLevel} from "./types.js";
import {BuildGpu, BuildMetadataFile, LlamaGpuType, LlamaLocks, LlamaLogLevel} from "./types.js";
import {MemoryOrchestrator, MemoryReservation} from "./utils/MemoryOrchestrator.js";

const LlamaLogLevelToAddonLogLevel: ReadonlyMap<LlamaLogLevel, number> = new Map([
@@ -31,7 +31,7 @@ export class Llama {
/** @internal */ public readonly _vramOrchestrator: MemoryOrchestrator;
/** @internal */ public readonly _vramPadding: MemoryReservation;
/** @internal */ public readonly _debug: boolean;
/** @internal */ private readonly _gpu: BuildGpu;
/** @internal */ private readonly _gpu: LlamaGpuType;
/** @internal */ private readonly _buildType: "localBuild" | "prebuilt";
/** @internal */ private readonly _cmakeOptions: Readonly<Record<string, string>>;
/** @internal */ private readonly _supportsGpuOffloading: boolean;
@@ -244,7 +244,10 @@ export class Llama {
await this._bindings.init();
}

/** @internal */
/**
* Log messages related to the Llama instance
* @internal
*/
public _log(level: LlamaLogLevel, message: string) {
this._onAddonLog(LlamaLogLevelToAddonLogLevel.get(level) ?? defaultLogLevel, message + "\n");
}
10 changes: 8 additions & 2 deletions src/bindings/getLlama.ts
@@ -16,7 +16,7 @@ import {
} from "./utils/compileLLamaCpp.js";
import {getLastBuildInfo} from "./utils/lastBuildInfo.js";
import {getClonedLlamaCppRepoReleaseInfo, isLlamaCppRepoCloned} from "./utils/cloneLlamaCppRepo.js";
import {BuildGpu, BuildMetadataFile, BuildOptions, LlamaLogLevel} from "./types.js";
import {BuildGpu, BuildMetadataFile, BuildOptions, LlamaGpuType, LlamaLogLevel} from "./types.js";
import {BinaryPlatform, getPlatform} from "./utils/getPlatform.js";
import {getBuildFolderNameForBuildOptions} from "./utils/getBuildFolderNameForBuildOptions.js";
import {resolveCustomCmakeOptions} from "./utils/resolveCustomCmakeOptions.js";
@@ -46,7 +46,10 @@ export type LlamaOptions = {
*
* `"auto"` by default.
*/
gpu?: "auto" | "metal" | "cuda" | "vulkan" | false,
gpu?: "auto" | LlamaGpuType | {
type: "auto",
exclude?: LlamaGpuType[]
},

/**
* Set the minimum log level for llama.cpp.
@@ -298,6 +301,9 @@ export async function getLlamaForOptions({
}
}

if (buildGpusToTry.length === 0)
throw new Error("No GPU types available to try building with");

if (build === "auto" || build === "never") {
for (let i = 0; i < buildGpusToTry.length; i++) {
const gpu = buildGpusToTry[i];
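
With the reworked gpu option, "auto" can now be expressed as an object that excludes specific GPU types from auto-detection, and getLlamaForOptions now throws if the exclusions leave nothing to try. A minimal usage sketch (assuming the public getLlama export of the node-llama-cpp package; the excluded type is just an example):

import {getLlama} from "node-llama-cpp";

// Auto-detect the best available GPU support, but never try Vulkan.
const llama = await getLlama({
    gpu: {
        type: "auto",
        exclude: ["vulkan"]
    }
});

// Assumed public getter; reports the GPU type actually used, e.g. "metal", "cuda", or false for CPU-only.
console.log("Using GPU type:", llama.gpu);

Excluding every available type (for example "metal", "cuda", "vulkan", and false) leaves an empty list, which triggers the new "No GPU types available to try building with" error shown above.
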
1 change: 1 addition & 0 deletions src/bindings/types.ts
@@ -3,6 +3,7 @@ import {BinaryPlatform} from "./utils/getPlatform.js";
import {BinaryPlatformInfo} from "./utils/getPlatformInfo.js";

export const buildGpuOptions = ["metal", "cuda", "vulkan", false] as const;
export type LlamaGpuType = "metal" | "cuda" | "vulkan" | false;
export const nodeLlamaCppGpuOptions = [
"auto",
...buildGpuOptions
23 changes: 18 additions & 5 deletions src/bindings/utils/getGpuTypesToUseForOption.ts
@@ -1,28 +1,41 @@
import process from "process";
import {BuildGpu, buildGpuOptions} from "../types.js";
import {LlamaOptions} from "../getLlama.js";
import {BinaryPlatform, getPlatform} from "./getPlatform.js";
import {getBestComputeLayersAvailable} from "./getBestComputeLayersAvailable.js";

export async function getGpuTypesToUseForOption(gpu: BuildGpu | "auto", {
export async function getGpuTypesToUseForOption(gpu: Required<LlamaOptions>["gpu"], {
platform = getPlatform(),
arch = process.arch
}: {
platform?: BinaryPlatform,
arch?: typeof process.arch
} = {}): Promise<BuildGpu[]> {
const resolvedGpu = resolveValidGpuOptionForPlatform(gpu, {
const resolvedGpuOption = typeof gpu === "object"
? gpu.type
: gpu;

function withExcludedGpuTypesRemoved(gpuTypes: BuildGpu[]) {
const resolvedExcludeTypes = typeof gpu === "object"
? new Set(gpu.exclude ?? [])
: new Set();

return gpuTypes.filter(gpuType => !resolvedExcludeTypes.has(gpuType));
}

const resolvedGpu = resolveValidGpuOptionForPlatform(resolvedGpuOption, {
platform,
arch
});

if (resolvedGpu === "auto") {
if (arch === process.arch)
return await getBestComputeLayersAvailable();
return withExcludedGpuTypesRemoved(await getBestComputeLayersAvailable());

return [false];
return withExcludedGpuTypesRemoved([false]);
}

return [resolvedGpu];
return withExcludedGpuTypesRemoved([resolvedGpu]);
}

export function resolveValidGpuOptionForPlatform(gpu: BuildGpu | "auto", {
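
To illustrate the new exclusion filter, here is a hypothetical set of calls. It assumes a Linux x64 machine on which getBestComputeLayersAvailable() resolves to ["cuda", "vulkan", false] (CPU fallback last); the relative import path is illustrative:

import {getGpuTypesToUseForOption} from "./getGpuTypesToUseForOption.js";

// "auto" with an exclusion: the excluded type is filtered out of the detected list.
await getGpuTypesToUseForOption({type: "auto", exclude: ["vulkan"]}); // -> ["cuda", false]

// A plain GPU type still behaves as before.
await getGpuTypesToUseForOption("cuda"); // -> ["cuda"]

// Excluding everything yields an empty list, which getLlamaForOptions() turns into
// the "No GPU types available to try building with" error (see getLlama.ts above).
await getGpuTypesToUseForOption({type: "auto", exclude: ["cuda", "vulkan", false]}); // -> []
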
22 changes: 15 additions & 7 deletions src/cli/commands/ChatCommand.ts
@@ -41,6 +41,7 @@ type ChatCommand = {
noJinja?: boolean,
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
noTrimWhitespace: boolean,
grammar: "text" | Parameters<typeof LlamaGrammar.getFor>[1],
jsonSchemaGrammarFile?: string,
@@ -149,6 +150,12 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
type: "number",
description: "Batch size to use for the model context. The default value is the context size"
})
.option("flashAttention", {
alias: "fa",
type: "boolean",
default: false,
description: "Enable flash attention"
})
.option("noTrimWhitespace", {
type: "boolean",
alias: ["noTrim"],
@@ -269,7 +276,7 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt,
promptFile, wrapper, noJinja, contextSize, batchSize,
promptFile, wrapper, noJinja, contextSize, batchSize, flashAttention,
noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory,
@@ -278,9 +285,9 @@ export const ChatCommand: CommandModule<object, ChatCommand> = {
try {
await RunChat({
modelPath, header, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja, contextSize,
batchSize, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
noHistory, environmentFunctions, debug, meter, printTimings
batchSize, flashAttention, noTrimWhitespace, grammar, jsonSchemaGrammarFile, threads, temperature, minP, topK, topP,
gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
});
} catch (err) {
await new Promise((accept) => setTimeout(accept, 0)); // wait for logs to finish printing
@@ -293,9 +300,9 @@

async function RunChat({
modelPath: modelArg, header: headerArg, gpu, systemInfo, systemPrompt, systemPromptFile, prompt, promptFile, wrapper, noJinja,
contextSize, batchSize, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath, threads, temperature,
minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty,
repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
contextSize, batchSize, flashAttention, noTrimWhitespace, grammar: grammarArg, jsonSchemaGrammarFile: jsonSchemaGrammarFilePath,
threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens, noHistory, environmentFunctions, debug, meter, printTimings
}: ChatCommand) {
if (contextSize === -1) contextSize = undefined;
if (gpuLayers === -1) gpuLayers = undefined;
@@ -360,6 +367,7 @@ async function RunChat({
: contextSize != null
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
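
The new --flashAttention flag (alias --fa) is forwarded through RunChat into the model-load options as defaultContextFlashAttention, so every context created for the chat session uses flash attention. A rough programmatic equivalent (a sketch assuming the public getLlama/loadModel/createContext API; the model path is a placeholder):

import {getLlama} from "node-llama-cpp";

const llama = await getLlama();
const model = await llama.loadModel({
    modelPath: "path/to/model.gguf", // placeholder
    // Mirrors what the chat CLI now does when --flashAttention is passed:
    // contexts created from this model default to flash attention.
    defaultContextFlashAttention: true
});
const context = await model.createContext();
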
14 changes: 11 additions & 3 deletions src/cli/commands/CompleteCommand.ts
@@ -28,6 +28,7 @@ type CompleteCommand = {
textFile?: string,
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
threads: number,
temperature: number,
minP: number,
@@ -104,6 +105,12 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
type: "number",
description: "Batch size to use for the model context. The default value is the context size"
})
.option("flashAttention", {
alias: "fa",
type: "boolean",
default: false,
description: "Enable flash attention"
})
.option("threads", {
type: "number",
default: 6,
@@ -194,14 +201,14 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
threads, temperature, minP, topK,
flashAttention, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
}) {
try {
await RunCompletion({
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize,
modelPath, header, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
@@ -216,7 +223,7 @@ export const CompleteCommand: CommandModule<object, CompleteCommand> = {


async function RunCompletion({
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize,
modelPath: modelArg, header: headerArg, gpu, systemInfo, text, textFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, debug, meter, printTimings
@@ -276,6 +283,7 @@ async function RunCompletion({
: contextSize != null
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
14 changes: 11 additions & 3 deletions src/cli/commands/InfillCommand.ts
@@ -30,6 +30,7 @@ type InfillCommand = {
suffixFile?: string,
contextSize?: number,
batchSize?: number,
flashAttention?: boolean,
threads: number,
temperature: number,
minP: number,
@@ -114,6 +115,12 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
type: "number",
description: "Batch size to use for the model context. The default value is the context size"
})
.option("flashAttention", {
alias: "fa",
type: "boolean",
default: false,
description: "Enable flash attention"
})
.option("threads", {
type: "number",
default: 6,
@@ -204,14 +211,14 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {
},
async handler({
modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
threads, temperature, minP, topK,
flashAttention, threads, temperature, minP, topK,
topP, gpuLayers, repeatPenalty, lastTokensRepeatPenalty, penalizeRepeatingNewLine,
repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
}) {
try {
await RunInfill({
modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
modelPath, header, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers, lastTokensRepeatPenalty,
repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty, maxTokens,
debug, meter, printTimings
@@ -226,7 +233,7 @@ export const InfillCommand: CommandModule<object, InfillCommand> = {


async function RunInfill({
modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize,
modelPath: modelArg, header: headerArg, gpu, systemInfo, prefix, prefixFile, suffix, suffixFile, contextSize, batchSize, flashAttention,
threads, temperature, minP, topK, topP, gpuLayers,
lastTokensRepeatPenalty, repeatPenalty, penalizeRepeatingNewLine, repeatFrequencyPenalty, repeatPresencePenalty,
maxTokens, debug, meter, printTimings
@@ -300,6 +307,7 @@ async function RunInfill({
: contextSize != null
? {fitContext: {contextSize}}
: undefined,
defaultContextFlashAttention: flashAttention,
ignoreMemorySafetyChecks: gpuLayers != null,
onLoadProgress(loadProgress: number) {
progressUpdater.setProgress(loadProgress);
(The diffs of the remaining 11 changed files are not shown here.)
