Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 17, 2023
1 parent 84daf64 commit dd6b25a
Show file tree
Hide file tree
Showing 9 changed files with 78 additions and 82 deletions.
2 changes: 1 addition & 1 deletion luisa_compute/examples/path_tracer_cutout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ fn main() {
on_triangle_hit: |c: TriangleCandidate| {
if_!(filter(&c), { c.commit(); });
},
on_procedural_hit: |c| {}
on_procedural_hit: |_c| {}
});
let occluded = !occluded.miss();
let cos_wi_light = wi_light.dot(n);
Expand Down
8 changes: 6 additions & 2 deletions luisa_compute/src/lang/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,8 @@ impl<T: Value> CpuFn<T> {
r.device
.as_ref()
.unwrap()
.upgrade()
.unwrap()
.inner
.query("device_name")
.unwrap(),
Expand Down Expand Up @@ -600,7 +602,7 @@ pub(crate) struct Recorder {
pub(crate) cpu_custom_ops: IndexMap<u64, (usize, CArc<CpuCustomOp>)>,
pub(crate) callables: IndexMap<u64, CallableModuleRef>,
pub(crate) shared: Vec<NodeRef>,
pub(crate) device: Option<Device>,
pub(crate) device: Option<WeakDevice>,
pub(crate) block_size: Option<[u32; 3]>,
pub(crate) building_kernel: bool,
pub(crate) pools: Option<CArc<ModulePools>>,
Expand Down Expand Up @@ -1855,7 +1857,7 @@ impl KernelBuilder {
"Cannot record multiple kernels at the same time"
);
r.lock = true;
r.device = device.clone();
r.device = device.as_ref().map(|d| WeakDevice::new(d));
r.pools = Some(CArc::new(ModulePools::new()));
r.scopes.clear();
r.building_kernel = is_kernel;
Expand Down Expand Up @@ -3006,6 +3008,8 @@ pub fn is_cpu_backend() -> bool {
r.device
.as_ref()
.unwrap()
.upgrade()
.unwrap()
.inner
.query("device_name")
.map(|s| s == "cpu")
Expand Down
2 changes: 0 additions & 2 deletions luisa_compute/src/lang/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use std::sync::atomic::AtomicBool;

pub type LogFn = Box<dyn Fn(&[*const u32]) + Send + Sync>;
struct PrinterItem {
level: log::Level,
log_fn: LogFn,
count: usize,
count_per_arg: Vec<usize>,
Expand Down Expand Up @@ -136,7 +135,6 @@ impl Printer {
);

items.push(PrinterItem {
level,
log_fn,
count: args.count + 1,
count_per_arg: args.count_per_arg,
Expand Down
9 changes: 5 additions & 4 deletions luisa_compute/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub mod prelude {
pub use crate::lang::traits::{CommonVarOp, FloatVarTrait, IntVarTrait, VarCmp, VarCmpEq};
pub use crate::lang::{
Aggregate, ExprProxy, FromNode, IndexRead, IndexWrite, KernelBuildFn, KernelParameter,
KernelSignature, Value, VarProxy, Mask,
KernelSignature, Mask, Value, VarProxy,
};
pub use crate::lang::{
__compose, __cpu_dbg, __current_scope, __env_need_backtrace, __extract, __insert,
Expand All @@ -34,14 +34,14 @@ pub use api::{
AccelBuildModificationFlags, AccelBuildRequest, AccelOption, AccelUsageHint, MeshType,
PixelFormat, PixelStorage,
};
pub use log;
pub use glam;
pub use lang::math;
pub use lang::math::*;
pub use lang::poly;
pub use lang::poly::*;
pub use lang::traits::*;
pub use lang::*;
pub use log;
pub use luisa_compute_derive as derive;
pub use luisa_compute_derive::*;
pub use luisa_compute_ir::ir::UserNodeData;
Expand Down Expand Up @@ -84,8 +84,8 @@ lazy_static! {
Mutex::new(HashMap::new());
}
impl Context {
// path to libluisa-*
// if the current_exe() is in the same directory as libluisa-*, then passing current_exe() is enough
/// path to libluisa-*
/// if the current_exe() is in the same directory as libluisa-*, then passing current_exe() is enough
pub fn new(lib_path: impl AsRef<Path>) -> Self {
let mut lib_path = lib_path.as_ref().to_path_buf();
lib_path = lib_path.canonicalize().unwrap();
Expand Down Expand Up @@ -124,6 +124,7 @@ impl Context {
device: weak.clone(),
mutex: RawMutex::INIT,
})),
ctx: self.inner.clone(),
}),
}
}
Expand Down
26 changes: 22 additions & 4 deletions luisa_compute/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,20 @@ use winit::window::Window;
pub struct Device {
pub(crate) inner: Arc<DeviceHandle>,
}

#[derive(Clone)]
pub struct WeakDevice {
pub(crate) inner: Weak<DeviceHandle>,
}
impl WeakDevice {
pub fn new(device: &Device) -> Self {
Self {
inner: Arc::downgrade(&device.inner),
}
}
pub fn upgrade(&self) -> Option<Device> {
self.inner.upgrade().map(|inner| Device { inner })
}
}
impl Hash for Device {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
let ptr = Arc::as_ptr(&self.inner);
Expand All @@ -46,6 +59,8 @@ impl Eq for Device {}
pub(crate) struct DeviceHandle {
pub(crate) backend: ProxyBackend,
pub(crate) default_stream: Option<Arc<StreamHandle>>,
#[allow(dead_code)]
pub(crate) ctx: Arc<backend::Context>,
}

unsafe impl Send for DeviceHandle {}
Expand Down Expand Up @@ -367,7 +382,7 @@ impl Device {
modifications: RwLock::new(HashMap::new()),
}
}
pub fn create_callable<'a, S: CallableSignature<'a>>(&self, f:S::Fn) -> S::Callable {
pub fn create_callable<'a, S: CallableSignature<'a>>(&self, f: S::Fn) -> S::Callable {
let mut builder = KernelBuilder::new(Some(self.clone()), false);
let raw_callable = CallableBuildFn::build_callable(&f, None, &mut builder);
S::wrap_raw_callable(raw_callable)
Expand Down Expand Up @@ -1192,7 +1207,7 @@ impl<S: CallableSignature<'static>> DynCallable<S> {
RECORDER.with(|r| {
if let Some(device) = r.borrow().device.as_ref() {
assert!(
Arc::ptr_eq(&device.inner, &self.device.inner),
Arc::ptr_eq(&device.inner.upgrade().unwrap(), &self.device.inner),
"Callable created on a different device than the one it is called on"
);
}
Expand All @@ -1217,7 +1232,10 @@ impl<S: CallableSignature<'static>> DynCallable<S> {
let (r_backup, device) = RECORDER.with(|r| {
let mut r = r.borrow_mut();
let device = r.device.clone().unwrap();
(std::mem::replace(&mut *r, Recorder::new()), device)
(
std::mem::replace(&mut *r, Recorder::new()),
device.upgrade().unwrap(),
)
});
let mut builder = KernelBuilder::new(Some(device), false);
let new_callable = (inner.builder)(args, &mut builder);
Expand Down
34 changes: 5 additions & 29 deletions luisa_compute/tests/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{env::current_exe, ops::Range};
use std::ops::Range;

use luisa::prelude::*;
use luisa::*;
Expand All @@ -8,34 +8,10 @@ use rayon::{
prelude::{IntoParallelIterator, ParallelIterator},
slice::ParallelSliceMut,
};
fn _signal_handler(signal: libc::c_int) {
if signal == libc::SIGSEGV {
panic!("segfault detected");
}
}
static ONCE: std::sync::Once = std::sync::Once::new();
fn get_device() -> Device {
let show_log = match std::env::var("LUISA_TEST_LOG") {
Ok(log) => log == "1",
Err(_) => false,
};
ONCE.call_once(|| {
if show_log {
init_logger_verbose();
}
unsafe {
libc::signal(libc::SIGSEGV, _signal_handler as usize);
}
});
let curr_exe = current_exe().unwrap();
let runtime_dir = curr_exe.parent().unwrap().parent().unwrap();
let ctx = Context::new(runtime_dir);
let device = match std::env::var("LUISA_TEST_DEVICE") {
Ok(device) => device,
Err(_) => "cpu".to_string(),
};
ctx.create_device(&device)
}
#[path = "common.rs"]
mod common;
use common::*;

fn finite_difference(inputs: &[Float], f: impl Fn(&[Float]) -> Float) -> Vec<Float> {
let eps = 1e-4;

Expand Down
34 changes: 34 additions & 0 deletions luisa_compute/tests/common.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
use std::env::current_exe;
use luisa::*;
use luisa_compute as luisa;
fn _signal_handler(signal: libc::c_int) {
if signal == libc::SIGSEGV {
panic!("segfault detected");
}
}
static ONCE: std::sync::Once = std::sync::Once::new();
pub fn device_name() -> String {
match std::env::var("LUISA_TEST_DEVICE") {
Ok(device) => device,
Err(_) => "cpu".to_string(),
}
}
pub fn get_device() -> Device {
let show_log = match std::env::var("LUISA_TEST_LOG") {
Ok(log) => log == "1",
Err(_) => false,
};
ONCE.call_once(|| unsafe {
if show_log {
init_logger_verbose();
}
libc::signal(libc::SIGSEGV, _signal_handler as usize);
});
let curr_exe = current_exe().unwrap();
let runtime_dir = curr_exe.parent().unwrap().parent().unwrap();
let ctx = Context::new(runtime_dir);
let device = device_name();
let device = ctx.create_device(&device);
device.create_buffer_from_slice(&[1.0f32]);
device
}
43 changes: 4 additions & 39 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
@@ -1,41 +1,12 @@
use std::env::current_exe;

use luisa::prelude::*;
use luisa::*;
use luisa_compute as luisa;
use luisa_compute_api_types::StreamTag;
use rand::prelude::*;
#[path = "common.rs"]
mod common;
use common::*;

fn _signal_handler(signal: libc::c_int) {
if signal == libc::SIGSEGV {
panic!("segfault detected");
}
}
static ONCE: std::sync::Once = std::sync::Once::new();
fn device_name() -> String {
match std::env::var("LUISA_TEST_DEVICE") {
Ok(device) => device,
Err(_) => "cpu".to_string(),
}
}
fn get_device() -> Device {
let show_log = match std::env::var("LUISA_TEST_LOG") {
Ok(log) => log == "1",
Err(_) => false,
};
ONCE.call_once(|| unsafe {
if show_log {
init_logger_verbose();
}
libc::signal(libc::SIGSEGV, _signal_handler as usize);
});
let curr_exe = current_exe().unwrap();
let runtime_dir = curr_exe.parent().unwrap().parent().unwrap();
let ctx = Context::new(runtime_dir);
let device = device_name();
let device = ctx.create_device(&device);
device
}
#[test]
fn event() {
let device = get_device();
Expand Down Expand Up @@ -89,10 +60,6 @@ fn event() {
#[test]
#[should_panic]
fn callable_return_mismatch() {
// Cpp backends cannot recover from panic
if device_name() != "cpu" {
panic!();
}
let device = get_device();
let _abs = device.create_callable::<fn(Expr<f32>) -> Expr<f32>>(&|x| {
if_!(x.cmpgt(0.0), {
Expand All @@ -101,12 +68,10 @@ fn callable_return_mismatch() {
-x
});
}

#[test]
#[should_panic]
fn callable_return_void_mismatch() {
if device_name() != "cpu" {
panic!();
}
let device = get_device();
let _abs = device.create_callable::<fn(Var<f32>)>(&|x| {
if_!(x.cmpgt(0.0), {
Expand Down

0 comments on commit dd6b25a

Please sign in to comment.