From cbab1f9c3b032cebb4d166d0d6a22ac339699f6d Mon Sep 17 00:00:00 2001
From: Xiaochun Tong
Date: Sun, 3 Sep 2023 16:22:52 -0400
Subject: [PATCH] fixed ssa

---
Review notes (ignored by `git am`):
- autodiff_vec3_cross_y now returns `*v.y()`; the earlier `*v.x()` was copied
  from the cross_x hunk and contradicted both the test name and the removed
  `v.y()` line.
- Line breaks were rebuilt from the hunk counts; indentation is reconstructed
  to rustfmt convention -- TODO confirm against the tree (or apply with
  `git apply --ignore-whitespace`).
- `Buffer<f32>` / `switch::<Expr<f32>>` type arguments were missing from this
  copy of the patch; `f32` follows from `var!(f32, ...)`, `Expr` is assumed --
  TODO confirm.

 luisa_compute/tests/autodiff.rs | 91 ++++++++++++++++++++++++++++-----
 luisa_compute_sys/LuisaCompute  |  2 +-
 2 files changed, 80 insertions(+), 13 deletions(-)

diff --git a/luisa_compute/tests/autodiff.rs b/luisa_compute/tests/autodiff.rs
index 5dcc060..74d04fb 100644
--- a/luisa_compute/tests/autodiff.rs
+++ b/luisa_compute/tests/autodiff.rs
@@ -380,13 +380,13 @@ fn autodiff_vec3_cross_x() {
         let ax = inputs[0];
         let ay = inputs[1];
         let az = inputs[2];
-        let a = make_float3(ax, ay, az);
+        let a = def(make_float3(ax, ay, az));
         let bx = inputs[3];
         let by = inputs[4];
         let bz = inputs[5];
-        let b = make_float3(bx, by, bz);
-        let v = a.cross(b);
-        v.x()
+        let b = def(make_float3(bx, by, bz));
+        let v = def(a.cross(*b));
+        *v.x()
     });
 }
 #[test]
@@ -395,13 +395,13 @@ fn autodiff_vec3_cross_y() {
         let ax = inputs[0];
         let ay = inputs[1];
         let az = inputs[2];
-        let a = make_float3(ax, ay, az);
+        let a = def(make_float3(ax, ay, az));
         let bx = inputs[3];
         let by = inputs[4];
         let bz = inputs[5];
-        let b = make_float3(bx, by, bz);
-        let v = a.cross(b);
-        v.y()
+        let b = def(make_float3(bx, by, bz));
+        let v = def(a.cross(*b));
+        *v.y()
     });
 }
 
@@ -918,15 +918,82 @@ fn autodiff_if_phi3() {
         let tid = dispatch_id().x();
         let x = buf_x.read(tid);
         let y = buf_y.read(tid);
+        let const_two = var!(f32, 2.0);
+        let const_three = var!(f32, 3.0);
+        let const_four = var!(f32);
+
+        autodiff(|| {
+            requires_grad(x);
+            requires_grad(y);
+            const_four.store(4.0);
+            let c = x.cmpgt(*const_three).int();
+            let z = if_!(x.cmpgt(y), {
+                switch::<Expr<f32>>(c)
+                    .case(0, || x * *const_two)
+                    .default(|| x * *const_four)
+                    .finish() * *const_two
+            }, else {
+                y * 0.5
+            });
+            backward(z);
+            buf_dx.write(tid, gradient(x));
+            buf_dy.write(tid, gradient(y));
+        });
+    });
+    kernel.dispatch([1024, 1, 1]);
+    let dx = dx.view(..).copy_to_vec();
+    let dy = dy.view(..).copy_to_vec();
+    let x = x.view(..).copy_to_vec();
+    let y = y.view(..).copy_to_vec();
+    let cache_dir = kernel.cache_dir();
+    for i in 0..1024 {
+        if x[i] > y[i] {
+            if x[i] > 3.0 {
+                assert_eq!(dx[i], 8.0, "{} cache_dir: {:?}", dx[i], cache_dir);
+                assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
+            } else {
+                assert_eq!(dx[i], 4.0, "{} cache_dir: {:?}", dx[i], cache_dir);
+                assert_eq!(dy[i], 0.0, "{} cache_dir: {:?}", dy[i], cache_dir);
+            }
+        } else {
+            assert_eq!(dx[i], 0.0, "{} cache_dir: {:?}", dx[i], cache_dir);
+            assert_eq!(dy[i], 0.5, "{} cache_dir: {:?}", dy[i], cache_dir);
+        }
+    }
+}
+#[test]
+fn autodiff_if_phi4() {
+    let device = get_device();
+    let x: Buffer<f32> = device.create_buffer(1024);
+    let y: Buffer<f32> = device.create_buffer(1024);
+    let dx: Buffer<f32> = device.create_buffer(1024);
+    let dy: Buffer<f32> = device.create_buffer(1024);
+    let mut rng = rand::thread_rng();
+    x.view(..).fill_fn(|_| rng.gen());
+    y.view(..).fill_fn(|_| rng.gen());
+    let kernel = device.create_kernel::<()>(&|| {
+        let buf_x = x.var();
+        let buf_y = y.var();
+        let buf_dx = dx.var();
+        let buf_dy = dy.var();
+        let tid = dispatch_id().x();
+        let x = buf_x.read(tid);
+        let y = buf_y.read(tid);
+
+        let consts = var!(Float3);
         autodiff(|| {
             requires_grad(x);
             requires_grad(y);
-            let c = x.cmpgt(3.0).int();
+            consts.store(make_float3(2.0,3.0,4.0));
+            let const_two = consts.x();
+            let const_three = consts.y();
+            let const_four = consts.z();
+            let c = x.cmpgt(*const_three).int();
             let z = if_!(x.cmpgt(y), {
                 switch::<Expr<f32>>(c)
-                    .case(0, || x * 2.0)
-                    .default(|| x * 4.0)
-                    .finish() * 2.0
+                    .case(0, || x * *const_two)
+                    .default(|| x * *const_four)
+                    .finish() * *const_two
             }, else {
                 y * 0.5
             });
diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute
index 967c96f..6fbf287 160000
--- a/luisa_compute_sys/LuisaCompute
+++ b/luisa_compute_sys/LuisaCompute
@@ -1 +1 @@
-Subproject commit 967c96f68d1d51dde2ec49660591d13ec6ad31ba
+Subproject commit 6fbf28719f52db727be4a0ee43fd3c195e935046