From 9076dee4328923c95a58610cf3341f1ed70aa50b Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 3 Oct 2024 08:43:00 +0200 Subject: [PATCH 1/3] Cuda graph experiments. --- candle-core/Cargo.toml | 4 ++ candle-core/examples/cuda_basics.rs | 57 ++++++++++++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/candle-core/Cargo.toml b/candle-core/Cargo.toml index cbf8f2007f..a05f966a3f 100644 --- a/candle-core/Cargo.toml +++ b/candle-core/Cargo.toml @@ -52,3 +52,7 @@ harness = false [[example]] name = "metal_basics" required-features = ["metal"] + +[[example]] +name = "cuda_basics" +required-features = ["cuda"] diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 9af1b006e3..31db9e8112 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -7,8 +7,63 @@ extern crate intel_mkl_src; use anyhow::Result; use candle_core::{Device, Tensor}; +fn cuda_graph() -> Result<()> { + let device = Device::new_cuda_with_stream(0)?; + let cu_device = match &device { + Device::Cuda(dev) => dev, + _ => unreachable!(), + }; + let cu_stream = cu_device.cu_stream(); + { + // load_ptx cannot be called while capturing the stream so we need this to happen + // beforehand. + let x = Tensor::zeros(16, candle_core::DType::F32, &device)?; + let y = Tensor::zeros(16, candle_core::DType::F32, &device)?; + y.slice_set(&x, 0, 0)?; + device.synchronize()?; + } + unsafe { + cudarc::driver::sys::lib() + .cuStreamBeginCapture_v2( + *cu_stream, + cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL, + ) + .result()? + }; + { + let x = Tensor::zeros(16, candle_core::DType::F32, &device)?; + let y = Tensor::zeros(16, candle_core::DType::F32, &device)?; + y.slice_set(&x, 0, 0)?; + // let y = x.affine(2., 1.)?; + } + let cu_graph = unsafe { + let mut cu_graph = std::mem::MaybeUninit::uninit(); + cudarc::driver::sys::lib() + .cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr()) + .result()?; + cu_graph.assume_init() + }; + let cu_graph_e = unsafe { + let mut cu_graph_e = std::mem::MaybeUninit::uninit(); + cudarc::driver::sys::lib() + .cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0) + .result()?; + cu_graph_e.assume_init() + }; + for _i in 0..100 { + unsafe { + cudarc::driver::sys::lib() + .cuGraphLaunch(cu_graph_e, *cu_stream) + .result()? + } + } + Ok(()) +} + fn main() -> Result<()> { - let device = Device::new_cuda(0)?; + cuda_graph()?; + return Ok(()); + let device = Device::new_cuda_with_stream(0)?; let x = Tensor::randn(0f32, 1.0, (8 * 4096, 8 * 4096), &device)? .to_dtype(candle_core::DType::BF16)?; candle_core::cuda::set_gemm_reduced_precision_f32(false); From b2956857efcd7aecc6e53f53d761503ef118d3be Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 3 Oct 2024 12:43:08 +0200 Subject: [PATCH 2/3] More cuda graph attempts. --- candle-core/examples/cuda_basics.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 31db9e8112..0cd933b597 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -17,9 +17,13 @@ fn cuda_graph() -> Result<()> { { // load_ptx cannot be called while capturing the stream so we need this to happen // beforehand. - let x = Tensor::zeros(16, candle_core::DType::F32, &device)?; - let y = Tensor::zeros(16, candle_core::DType::F32, &device)?; - y.slice_set(&x, 0, 0)?; + let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let _x = x.mul(&u)?.broadcast_add(&v)?; device.synchronize()?; } unsafe { @@ -31,10 +35,15 @@ fn cuda_graph() -> Result<()> { .result()? }; { - let x = Tensor::zeros(16, candle_core::DType::F32, &device)?; - let y = Tensor::zeros(16, candle_core::DType::F32, &device)?; - y.slice_set(&x, 0, 0)?; - // let y = x.affine(2., 1.)?; + let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? + .to_dtype(candle_core::DType::BF16)?; + for _i in 0..1 { + x = x.mul(&u)?.broadcast_add(&v)?; + } } let cu_graph = unsafe { let mut cu_graph = std::mem::MaybeUninit::uninit(); From 1bb68854d3a37c05a452312b633174b7b9c6d633 Mon Sep 17 00:00:00 2001 From: laurent Date: Thu, 3 Oct 2024 17:12:52 +0200 Subject: [PATCH 3/3] Tweaks to the graph experiment. --- candle-core/examples/cuda_basics.rs | 68 +++++++++++++++++++---------- 1 file changed, 44 insertions(+), 24 deletions(-) diff --git a/candle-core/examples/cuda_basics.rs b/candle-core/examples/cuda_basics.rs index 0cd933b597..315fe0a282 100644 --- a/candle-core/examples/cuda_basics.rs +++ b/candle-core/examples/cuda_basics.rs @@ -7,6 +7,8 @@ extern crate intel_mkl_src; use anyhow::Result; use candle_core::{Device, Tensor}; +const USE_CUDA_GRAPH: bool = true; + fn cuda_graph() -> Result<()> { let device = Device::new_cuda_with_stream(0)?; let cu_device = match &device { @@ -24,47 +26,65 @@ fn cuda_graph() -> Result<()> { let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? .to_dtype(candle_core::DType::BF16)?; let _x = x.mul(&u)?.broadcast_add(&v)?; + let _x = x.affine(1., 0.5)?; + x.slice_set(&u, 0, 0)?; device.synchronize()?; } - unsafe { - cudarc::driver::sys::lib() + if USE_CUDA_GRAPH { + unsafe { + cudarc::driver::sys::lib() .cuStreamBeginCapture_v2( *cu_stream, cudarc::driver::sys::CUstreamCaptureMode_enum::CU_STREAM_CAPTURE_MODE_THREAD_LOCAL, ) .result()? - }; + } + } { let u = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? .to_dtype(candle_core::DType::BF16)?; let mut x = Tensor::zeros((4096, 4096), candle_core::DType::F32, &device)? .to_dtype(candle_core::DType::BF16)?; - let v = Tensor::zeros(4096, candle_core::DType::F32, &device)? + let v = Tensor::zeros((4096, 1), candle_core::DType::F32, &device)? .to_dtype(candle_core::DType::BF16)?; - for _i in 0..1 { - x = x.mul(&u)?.broadcast_add(&v)?; + for _i in 0..100 { + // x.slice_set(&u, 0, 0)?; + // x.broadcast_add(&v)?; + x = x.affine(1., 0.5)?; + // x = (&u + &x)?; } } - let cu_graph = unsafe { - let mut cu_graph = std::mem::MaybeUninit::uninit(); - cudarc::driver::sys::lib() - .cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr()) - .result()?; - cu_graph.assume_init() - }; - let cu_graph_e = unsafe { - let mut cu_graph_e = std::mem::MaybeUninit::uninit(); - cudarc::driver::sys::lib() - .cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0) - .result()?; - cu_graph_e.assume_init() - }; - for _i in 0..100 { - unsafe { + if USE_CUDA_GRAPH { + let cu_graph: cudarc::driver::sys::CUgraph = unsafe { + let mut cu_graph = std::mem::MaybeUninit::uninit(); + cudarc::driver::sys::lib() + .cuStreamEndCapture(*cu_stream, cu_graph.as_mut_ptr()) + .result()?; + cu_graph.assume_init() + }; + let cu_graph_e: cudarc::driver::sys::CUgraphExec = unsafe { + let mut cu_graph_e = std::mem::MaybeUninit::uninit(); cudarc::driver::sys::lib() - .cuGraphLaunch(cu_graph_e, *cu_stream) - .result()? + .cuGraphInstantiateWithFlags(cu_graph_e.as_mut_ptr(), cu_graph, 0) + .result()?; + cu_graph_e.assume_init() + }; + println!("graph captured!"); + for i in 1..100 { + println!("graph exec {i}"); + unsafe { + cudarc::driver::sys::lib() + .cuGraphLaunch(cu_graph_e, *cu_stream) + .result()? + } + println!("sync"); + if let Err(err) = device.synchronize() { + println!("err: {err:?}") + } + println!("done syncing"); } + } else { + device.synchronize()?; } Ok(()) }