Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Precompile metal kernels into .metallib files #2335

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions candle-metal-kernels/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ half = { version = "2.3.1", features = [
"rand_distr",
] }
rand = "0.8.5"

[build-dependencies]
anyhow = "1.0.44"
convert_case = "0.6.0"
115 changes: 115 additions & 0 deletions candle-metal-kernels/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
use anyhow::Result;
use std::io::Write;
use std::path::{Path, PathBuf};
use std::{env, fs};

/// Build-script entry point.
///
/// Each `.metal` kernel source is compiled to an intermediate file, that file
/// is linked into a `.metallib`, and finally a Rust source file exposing every
/// resulting library is generated. Processing stops at the first failing step.
fn main() -> Result<()> {
    let metallib_files = kernel_source_files()?
        .into_iter()
        .map(|kernel_file| {
            // Compile and link one kernel at a time, in discovery order.
            let ir_path = compile_kernel(kernel_file)?;
            link_kernel(ir_path)
        })
        .collect::<Result<Vec<PathBuf>>>()?;

    gen_metallibs_rs(metallib_files)
}

/// Lists every `*.metal` file located directly inside `$CARGO_MANIFEST_DIR/src`.
///
/// The returned paths are sorted: `fs::read_dir` yields entries in an
/// unspecified order, and sorting keeps the build steps and the generated
/// source file deterministic across runs.
///
/// # Errors
///
/// Fails if `CARGO_MANIFEST_DIR` is not set, if the `src` directory cannot be
/// read, or if reading any individual directory entry fails.
fn kernel_source_files() -> Result<Vec<PathBuf>> {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR")?;
    let src_dir = Path::new(&manifest_dir).join("src");

    let mut paths = Vec::new();
    for entry in fs::read_dir(src_dir)? {
        // Propagate I/O errors instead of panicking on them.
        let path = entry?.path();
        // Compare as `OsStr` so a file with a non-UTF-8 extension is simply
        // skipped rather than aborting the build with a panic.
        if path.extension().map_or(false, |ext| ext == "metal") {
            paths.push(path);
        }
    }
    paths.sort();

    Ok(paths)
}

/// Compiles one `.metal` source file with `xcrun metal -c`.
///
/// The output is written to `$OUT_DIR/<stem>.ir`, where `<stem>` is the file
/// stem of `kernel_path`. A `cargo:rerun-if-changed` line is printed for the
/// source so cargo reruns the build script when the kernel is edited.
///
/// # Errors
///
/// Fails if `OUT_DIR` is not set, if `kernel_path` has no UTF-8 file stem, if
/// `xcrun` cannot be spawned, or if the compiler exits with a non-zero status.
fn compile_kernel(kernel_path: impl AsRef<Path>) -> Result<PathBuf> {
    let kernel_path = kernel_path.as_ref();
    let out_dir = std::env::var("OUT_DIR")?;

    let file_stem = kernel_path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .ok_or_else(|| {
            anyhow::anyhow!(
                "Kernel file has no valid UTF-8 file stem: {:?}",
                kernel_path
            )
        })?;

    println!("cargo:rerun-if-changed=src/{}.metal", file_stem);

    let output_file = Path::new(&out_dir).join(format!("{}.ir", file_stem));

    // Paths are handed over as `OsStr` rather than through `display()`, which
    // is lossy for paths that are not valid UTF-8.
    let status = std::process::Command::new("xcrun")
        .arg("metal")
        .arg("-c")
        .arg(kernel_path)
        .arg("-o")
        .arg(&output_file)
        .status()?;

    if !status.success() {
        return Err(anyhow::anyhow!(
            "Failed to compile kernel file: {:?}",
            kernel_path
        ));
    }

    Ok(output_file)
}

/// Links one compiled kernel file into a `.metallib` with `xcrun metallib`.
///
/// The output is written to `$OUT_DIR/<stem>.metallib`, where `<stem>` is the
/// file stem of `ir_path`.
///
/// # Errors
///
/// Fails if `OUT_DIR` is not set, if `ir_path` has no UTF-8 file stem, if
/// `xcrun` cannot be spawned, or if the linker exits with a non-zero status.
fn link_kernel(ir_path: impl AsRef<Path>) -> Result<PathBuf> {
    let ir_path = ir_path.as_ref();
    let out_dir = std::env::var("OUT_DIR")?;

    let file_stem = ir_path
        .file_stem()
        .and_then(|stem| stem.to_str())
        .ok_or_else(|| {
            anyhow::anyhow!("Kernel file has no valid UTF-8 file stem: {:?}", ir_path)
        })?;

    let output_file = Path::new(&out_dir).join(format!("{}.metallib", file_stem));

    // Paths are handed over as `OsStr` rather than through `display()`, which
    // is lossy for paths that are not valid UTF-8.
    let status = std::process::Command::new("xcrun")
        .arg("metallib")
        .arg(ir_path)
        .arg("-o")
        .arg(&output_file)
        .status()?;

    if !status.success() {
        return Err(anyhow::anyhow!(
            "Failed to link kernel file: {:?}",
            ir_path
        ));
    }

    Ok(output_file)
}

fn gen_metallibs_rs(metallibs: Vec<PathBuf>) -> Result<()> {
use convert_case::{Case, Casing};

// generate a rust source file that contains an include_bytes constant
// for every metallib file
let out_dir = std::env::var("OUT_DIR")?;
let out_file = Path::new(&out_dir).join("candle_metallibs.rs");

let mut file = fs::File::create(&out_file)?;

for metallib in metallibs {
let name = metallib.file_stem().unwrap().to_str().unwrap();
writeln!(
file,
"pub const {}: &'static [u8] = include_bytes!(\"{}\");",
name.to_case(Case::ScreamingSnake),
metallib.display()
)?;
}

Ok(())
}
65 changes: 44 additions & 21 deletions candle-metal-kernels/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib");
const QUANTIZED: &str = include_str!("quantized.metal");
const SORT: &str = include_str!("sort.metal");

// Precompiled kernel libraries. The included file is written to `OUT_DIR` by
// `build.rs` and holds one `pub const` byte slice (via `include_bytes!`) per
// `.metallib` produced from the `.metal` sources.
mod metallibs {
include!(concat!(env!("OUT_DIR"), "/candle_metallibs.rs"));
}

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Source {
Affine,
Expand Down Expand Up @@ -187,21 +191,29 @@ impl Kernels {
}
}

fn get_library_source(&self, source: Source) -> &'static str {
match source {
Source::Affine => AFFINE,
Source::Unary => UNARY,
Source::Binary => BINARY,
Source::Ternary => TERNARY,
Source::Indexing => INDEXING,
Source::Cast => CAST,
Source::Reduce => REDUCE,
Source::Conv => CONV,
Source::Random => RANDOM,
Source::Quantized => QUANTIZED,
Source::Sort => SORT,
Source::Mfa => panic!("Invalid lib"),
}
// fn get_library_source(&self, source: Source) -> &'static str {
// match source {
// Source::Affine => AFFINE,
// Source::Unary => UNARY,
// Source::Binary => BINARY,
// Source::Ternary => TERNARY,
// Source::Indexing => INDEXING,
// Source::Cast => CAST,
// Source::Reduce => REDUCE,
// Source::Conv => CONV,
// Source::Random => RANDOM,
// Source::Quantized => QUANTIZED,
// Source::Sort => SORT,
// Source::Mfa => panic!("Invalid lib"),
// }
// }

/// Creates a Metal [`Library`] on `device` from precompiled `.metallib` bytes.
///
/// # Errors
///
/// Returns [`MetalKernelError::LoadLibraryError`] when
/// `new_library_with_data` rejects `data`.
fn load_metallib(device: &Device, data: &[u8]) -> Result<Library, MetalKernelError> {
    device.new_library_with_data(data).map_err(|e| {
        // This helper loads every precompiled kernel library, so the message
        // must not single out the MFA library.
        MetalKernelError::LoadLibraryError(format!(
            "Candle metal requires macosx > 13.0 or higher, cannot load metallib: {e}"
        ))
    })
}

/// Load the given library from its [`source`].
Expand All @@ -224,12 +236,23 @@ impl Kernels {
))
})?
}
source => {
let source_content = self.get_library_source(source);
device
.new_library_with_source(source_content, &CompileOptions::new())
.map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
}
Source::Affine => Self::load_metallib(device, metallibs::AFFINE)?,
Source::Indexing => Self::load_metallib(device, metallibs::INDEXING)?,
Source::Unary => Self::load_metallib(device, metallibs::UNARY)?,
Source::Binary => Self::load_metallib(device, metallibs::BINARY)?,
Source::Ternary => Self::load_metallib(device, metallibs::TERNARY)?,
Source::Cast => Self::load_metallib(device, metallibs::CAST)?,
Source::Reduce => Self::load_metallib(device, metallibs::REDUCE)?,
Source::Conv => Self::load_metallib(device, metallibs::CONV)?,
Source::Random => Self::load_metallib(device, metallibs::RANDOM)?,
Source::Quantized => Self::load_metallib(device, metallibs::QUANTIZED)?,
Source::Sort => Self::load_metallib(device, metallibs::SORT)?,
// source => {
// let source_content = self.get_library_source(source);
// device
// .new_library_with_source(source_content, &CompileOptions::new())
// .map_err(|e| MetalKernelError::LoadLibraryError(e.to_string()))?
// }
};
libraries.insert(source, lib.clone());
Ok(lib)
Expand Down