Skip to content

Commit

Permalink
Simplify GPU plotting logic
Browse files Browse the repository at this point in the history
  • Loading branch information
nazar-pc committed Nov 20, 2024
1 parent aff971e commit 21362c4
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,7 @@ where
cuda_devices
.into_iter()
.map(|cuda_device| CudaRecordsEncoder::new(cuda_device, Arc::clone(&global_mutex)))
.collect::<Result<_, _>>()
.map_err(|error| {
anyhow::anyhow!("Failed to create CUDA records encoder: {error}")
})?,
.collect(),
global_mutex,
kzg,
erasure_coding,
Expand Down Expand Up @@ -480,10 +477,7 @@ where
rocm_devices
.into_iter()
.map(|rocm_device| RocmRecordsEncoder::new(rocm_device, Arc::clone(&global_mutex)))
.collect::<Result<_, _>>()
.map_err(|error| {
anyhow::anyhow!("Failed to create ROCm records encoder: {error}")
})?,
.collect(),
global_mutex,
kzg,
erasure_coding,
Expand Down
10 changes: 2 additions & 8 deletions crates/subspace-farmer/src/bin/subspace-farmer/commands/farm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,10 +1072,7 @@ where
cuda_devices
.into_iter()
.map(|cuda_device| CudaRecordsEncoder::new(cuda_device, Arc::clone(&global_mutex)))
.collect::<Result<_, _>>()
.map_err(|error| {
anyhow::anyhow!("Failed to create CUDA records encoder: {error}")
})?,
.collect(),
global_mutex,
kzg,
erasure_coding,
Expand Down Expand Up @@ -1154,10 +1151,7 @@ where
rocm_devices
.into_iter()
.map(|rocm_device| RocmRecordsEncoder::new(rocm_device, Arc::clone(&global_mutex)))
.collect::<Result<_, _>>()
.map_err(|error| {
anyhow::anyhow!("Failed to create ROCm records encoder: {error}")
})?,
.collect(),
global_mutex,
kzg,
erasure_coding,
Expand Down
91 changes: 16 additions & 75 deletions crates/subspace-farmer/src/plotter/gpu/cuda.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
use crate::plotter::gpu::GpuRecordsEncoder;
use async_lock::Mutex as AsyncMutex;
use parking_lot::Mutex;
use rayon::{current_thread_index, ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::process::exit;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use subspace_core_primitives::pieces::{PieceOffset, Record};
Expand All @@ -17,7 +14,6 @@ use subspace_proof_of_space_gpu::cuda::CudaDevice;
#[derive(Debug)]
pub struct CudaRecordsEncoder {
cuda_device: CudaDevice,
thread_pool: ThreadPool,
global_mutex: Arc<AsyncMutex<()>>,
}

Expand All @@ -38,89 +34,34 @@ impl RecordsEncoder for CudaRecordsEncoder {
.map_err(|error| anyhow::anyhow!("Failed to convert pieces in sector: {error}"))?;
let mut sector_contents_map = SectorContentsMap::new(pieces_in_sector);

self.thread_pool.install(|| {
let iter = Mutex::new(
(PieceOffset::ZERO..)
.zip(records.iter_mut())
.zip(sector_contents_map.iter_record_bitfields_mut()),
);
let plotting_error = Mutex::new(None::<String>);
for ((piece_offset, record), mut encoded_chunks_used) in (PieceOffset::ZERO..)
.zip(records.iter_mut())
.zip(sector_contents_map.iter_record_bitfields_mut())
{
// Take mutex briefly to make sure encoding is allowed right now
self.global_mutex.lock_blocking();

rayon::scope(|scope| {
scope.spawn_broadcast(|_scope, _ctx| loop {
// Take mutex briefly to make sure encoding is allowed right now
self.global_mutex.lock_blocking();
let pos_seed = sector_id.derive_evaluation_seed(piece_offset);

// This instead of `while` above because otherwise mutex will be held for the
// duration of the loop and will limit concurrency to 1 record
let Some(((piece_offset, record), mut encoded_chunks_used)) =
iter.lock().next()
else {
return;
};
let pos_seed = sector_id.derive_evaluation_seed(piece_offset);
self.cuda_device
.generate_and_encode_pospace(&pos_seed, record, encoded_chunks_used.iter_mut())
.map_err(anyhow::Error::msg)?;

if let Err(error) = self.cuda_device.generate_and_encode_pospace(
&pos_seed,
record,
encoded_chunks_used.iter_mut(),
) {
plotting_error.lock().replace(error);
return;
}

if abort_early.load(Ordering::Relaxed) {
return;
}
});
});

let plotting_error = plotting_error.lock().take();
if let Some(error) = plotting_error {
return Err(anyhow::Error::msg(error));
if abort_early.load(Ordering::Relaxed) {
break;
}

Ok(())
})?;
}

Ok(sector_contents_map)
}
}

impl CudaRecordsEncoder {
/// Create new instance
pub fn new(
cuda_device: CudaDevice,
global_mutex: Arc<AsyncMutex<()>>,
) -> Result<Self, ThreadPoolBuildError> {
let id = cuda_device.id();
let thread_name = move |thread_index| format!("cuda-{id}.{thread_index}");
// TODO: remove this panic handler when rayon logs panic_info
// https://github.com/rayon-rs/rayon/issues/1208
let panic_handler = move |panic_info| {
if let Some(index) = current_thread_index() {
eprintln!("panic on thread {}: {:?}", thread_name(index), panic_info);
} else {
// We want to guarantee exit, rather than panicking in a panic handler.
eprintln!(
"rayon panic handler called on non-rayon thread: {:?}",
panic_info
);
}
exit(1);
};

let thread_pool = ThreadPoolBuilder::new()
.thread_name(thread_name)
.panic_handler(panic_handler)
// Make sure there is overlap between records, so GPU is almost always busy
.num_threads(2)
.build()?;

Ok(Self {
pub fn new(cuda_device: CudaDevice, global_mutex: Arc<AsyncMutex<()>>) -> Self {
Self {
cuda_device,
thread_pool,
global_mutex,
})
}
}
}
73 changes: 16 additions & 57 deletions crates/subspace-farmer/src/plotter/gpu/rocm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
use crate::plotter::gpu::GpuRecordsEncoder;
use async_lock::Mutex as AsyncMutex;
use parking_lot::Mutex;
use rayon::{ThreadPool, ThreadPoolBuildError, ThreadPoolBuilder};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use subspace_core_primitives::pieces::{PieceOffset, Record};
Expand All @@ -16,7 +14,6 @@ use subspace_proof_of_space_gpu::rocm::RocmDevice;
#[derive(Debug)]
pub struct RocmRecordsEncoder {
rocm_device: RocmDevice,
thread_pool: ThreadPool,
global_mutex: Arc<AsyncMutex<()>>,
}

Expand All @@ -37,72 +34,34 @@ impl RecordsEncoder for RocmRecordsEncoder {
.map_err(|error| anyhow::anyhow!("Failed to convert pieces in sector: {error}"))?;
let mut sector_contents_map = SectorContentsMap::new(pieces_in_sector);

self.thread_pool.install(|| {
let iter = Mutex::new(
(PieceOffset::ZERO..)
.zip(records.iter_mut())
.zip(sector_contents_map.iter_record_bitfields_mut()),
);
let plotting_error = Mutex::new(None::<String>);
for ((piece_offset, record), mut encoded_chunks_used) in (PieceOffset::ZERO..)
.zip(records.iter_mut())
.zip(sector_contents_map.iter_record_bitfields_mut())
{
// Take mutex briefly to make sure encoding is allowed right now
self.global_mutex.lock_blocking();

rayon::scope(|scope| {
scope.spawn_broadcast(|_scope, _ctx| loop {
// Take mutex briefly to make sure encoding is allowed right now
self.global_mutex.lock_blocking();
let pos_seed = sector_id.derive_evaluation_seed(piece_offset);

// This instead of `while` above because otherwise mutex will be held for the
// duration of the loop and will limit concurrency to 1 record
let Some(((piece_offset, record), mut encoded_chunks_used)) =
iter.lock().next()
else {
return;
};
let pos_seed = sector_id.derive_evaluation_seed(piece_offset);
self.rocm_device
.generate_and_encode_pospace(&pos_seed, record, encoded_chunks_used.iter_mut())
.map_err(anyhow::Error::msg)?;

if let Err(error) = self.rocm_device.generate_and_encode_pospace(
&pos_seed,
record,
encoded_chunks_used.iter_mut(),
) {
plotting_error.lock().replace(error);
return;
}

if abort_early.load(Ordering::Relaxed) {
return;
}
});
});

let plotting_error = plotting_error.lock().take();
if let Some(error) = plotting_error {
return Err(anyhow::Error::msg(error));
if abort_early.load(Ordering::Relaxed) {
break;
}

Ok(())
})?;
}

Ok(sector_contents_map)
}
}

impl RocmRecordsEncoder {
/// Create new instance
pub fn new(
rocm_device: RocmDevice,
global_mutex: Arc<AsyncMutex<()>>,
) -> Result<Self, ThreadPoolBuildError> {
let id = rocm_device.id();
let thread_pool = ThreadPoolBuilder::new()
.thread_name(move |thread_index| format!("rocm-{id}.{thread_index}"))
// Make sure there is overlap between records, so GPU is almost always busy
.num_threads(2)
.build()?;

Ok(Self {
pub fn new(rocm_device: RocmDevice, global_mutex: Arc<AsyncMutex<()>>) -> Self {
Self {
rocm_device,
thread_pool,
global_mutex,
})
}
}
}

0 comments on commit 21362c4

Please sign in to comment.