From 11af866832a73f0992492b88aba702704a7b0d7b Mon Sep 17 00:00:00 2001 From: Xiaochun Tong Date: Wed, 24 Apr 2024 16:23:34 -0400 Subject: [PATCH] fix generics --- luisa_compute/examples/test_buffer.rs | 69 +++++++++++++++++++++++++++ luisa_compute/tests/misc.rs | 8 ++++ luisa_compute_derive_impl/src/lib.rs | 20 ++++---- luisa_compute_sys/LuisaCompute | 2 +- 4 files changed, 88 insertions(+), 11 deletions(-) create mode 100644 luisa_compute/examples/test_buffer.rs diff --git a/luisa_compute/examples/test_buffer.rs b/luisa_compute/examples/test_buffer.rs new file mode 100644 index 0000000..3accebd --- /dev/null +++ b/luisa_compute/examples/test_buffer.rs @@ -0,0 +1,69 @@ +use luisa::prelude::*; +use luisa_compute as luisa; +use std::env::current_exe; + +fn main() { + luisa::init_logger_verbose(); + let args: Vec = std::env::args().collect(); + assert!( + args.len() <= 2, + "Usage: {} . : cpu, cuda, dx, metal, remote", + args[0] + ); + + let ctx = Context::new(current_exe().unwrap()); + let device = ctx.create_device(if args.len() == 2 { + args[1].as_str() + } else { + "cpu" + }); + let x = device.create_buffer::(1024); + let y = device.create_buffer::(1024); + let z = device.create_buffer::(1024); + x.view(..).fill_fn(|i| i as f32); + y.view(..).fill_fn(|i| 1000.0 * i as f32); + + let kernel = device.create_kernel_with_options::)>( + KernelBuildOptions { + name: Some("vecadd".into()), + ..Default::default() + }, + &track!(|buf_z| { + // z is pass by arg + let buf_x = &x; // x and y are captured + let buf_y = &y; + let tid = dispatch_id().x; + let x = buf_x.read(tid); + let y = buf_y.read(tid); + buf_z.write(tid, x + y); + }), + ); + let mut z_data = vec![123.0f32; 1024]; + + unsafe { + let s = device.default_stream().scope(); + let z_data_ptr = z_data.as_mut_ptr(); + s.submit([ + z.copy_from_async(std::slice::from_raw_parts_mut(z_data_ptr, 1024)), + kernel.dispatch_async([1024, 1, 1], &z), + z.copy_to_async(std::slice::from_raw_parts_mut(z_data_ptr, 1024)), + z.copy_from_async(std::slice::from_raw_parts_mut(z_data_ptr, 1024)), + z.copy_to_buffer_async(&x) + ]); + } + + // this should produce the expected behavior + // unsafe { + // let z_data_ptr = z_data.as_mut_ptr(); + + // z.copy_from(std::slice::from_raw_parts_mut(z_data_ptr, 1024)); + // kernel.dispatch([1024, 1, 1], &z); + // z.copy_to(std::slice::from_raw_parts_mut(z_data_ptr, 1024)); + // z.copy_from(std::slice::from_raw_parts_mut(z_data_ptr, 1024)); + // z.copy_to_buffer(&x); + // } + + println!("{:?}", &z_data[0..16]); + let x_data = x.copy_to_vec(); + println!("{:?}", &x_data[0..16]); +} diff --git a/luisa_compute/tests/misc.rs b/luisa_compute/tests/misc.rs index a0d36d2..0cd3d3a 100644 --- a/luisa_compute/tests/misc.rs +++ b/luisa_compute/tests/misc.rs @@ -1472,6 +1472,14 @@ pub struct Bar { a: [i32; 4], f: Foo, } +#[derive(Clone, Copy, Debug, Value)] +#[repr(C)] +pub struct Foo2 { + i: T, + v: Float2, + a: [T; 4], + m: Mat2, +} #[test] fn soa() { let device = get_device(); diff --git a/luisa_compute_derive_impl/src/lib.rs b/luisa_compute_derive_impl/src/lib.rs index d6b9c17..3e34dbf 100644 --- a/luisa_compute_derive_impl/src/lib.rs +++ b/luisa_compute_derive_impl/src/lib.rs @@ -545,7 +545,7 @@ impl Compiler { #vis fn from_comps_expr(ctor: #ctor_proxy_name #ty_generics) -> #lang_path::types::Expr<#name #ty_generics> { use #lang_path::*; let node = #lang_path::__compose::<#name #ty_generics>(&[ #( #lang_path::ToNode::node(&ctor.#field_names.as_expr()).get() ),* ]); - <#lang_path::types::Expr::<#name> as #lang_path::FromNode>::from_node(node.into()) + <#lang_path::types::Expr::<#name #ty_generics> as #lang_path::FromNode>::from_node(node.into()) } } ) @@ -575,7 +575,7 @@ impl Compiler { #[allow(dead_code)] #vis struct #expr_proxy_name #generics #where_clause{ _marker: std::marker::PhantomData<(#marker_args)>, - self_: #lang_path::types::Expr<#name>, + self_: #lang_path::types::Expr<#name #ty_generics>, #(#field_vis #field_names: #lang_path::types::Expr<#field_types>),* } @@ -584,7 +584,7 @@ impl Compiler { #[allow(dead_code)] #vis struct #var_proxy_name #generics #where_clause{ _marker: std::marker::PhantomData<(#marker_args)>, - self_: #lang_path::types::Var<#name>, + self_: #lang_path::types::Var<#name #ty_generics>, #(#field_vis #field_names: #lang_path::types::Var<#field_types>),*, } #[derive(Clone, Copy)] @@ -592,7 +592,7 @@ impl Compiler { #[allow(dead_code)] #vis struct #atomic_ref_proxy_name #generics #where_clause{ _marker: std::marker::PhantomData<(#marker_args)>, - self_: #lang_path::types::AtomicRef<#name>, + self_: #lang_path::types::AtomicRef<#name #ty_generics>, #(#field_vis #field_names: #lang_path::types::AtomicRef<#field_types>),*, } #[allow(unused_parens)] @@ -609,7 +609,7 @@ impl Compiler { } } - fn as_expr_from_proxy(&self) -> &#lang_path::types::Expr<#name> { + fn as_expr_from_proxy(&self) -> &#lang_path::types::Expr<#name #ty_generics> { &self.self_ } } @@ -626,7 +626,7 @@ impl Compiler { #(#field_names),* } } - fn as_var_from_proxy(&self) -> &#lang_path::types::Var<#name> { + fn as_var_from_proxy(&self) -> &#lang_path::types::Var<#name #ty_generics> { &self.self_ } } @@ -643,13 +643,13 @@ impl Compiler { #(#field_names),* } } - fn as_atomic_ref_from_proxy(&self) -> &#lang_path::types::AtomicRef<#name> { + fn as_atomic_ref_from_proxy(&self) -> &#lang_path::types::AtomicRef<#name #ty_generics> { &self.self_ } } #[allow(unused_parens)] impl #impl_generics std::ops::Deref for #var_proxy_name #ty_generics #where_clause { - type Target = #lang_path::types::Expr<#name> #ty_generics; + type Target = #lang_path::types::Expr<#name #ty_generics>; fn deref(&self) -> &Self::Target { #lang_path::types::_deref_proxy(self) } @@ -689,10 +689,10 @@ impl Compiler { quote_spanned! { span => impl #impl_generics #name #ty_generics #where_clause { - #vis fn new_expr(#(#field_names: impl #lang_path::types::AsExpr),*) -> #lang_path::types::Expr::<#name> { + #vis fn new_expr(#(#field_names: impl #lang_path::types::AsExpr),*) -> #lang_path::types::Expr::<#name #ty_generics> { use #lang_path::*; let node = #lang_path::__compose::<#name #ty_generics>(&[ #( #lang_path::ToNode::node(&#field_names.as_expr()).get() ),* ]); - <#lang_path::types::Expr::<#name> as #lang_path::FromNode>::from_node(node.into()) + <#lang_path::types::Expr::<#name #ty_generics> as #lang_path::FromNode>::from_node(node.into()) } } } diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index 815126c..4375ca6 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit 815126c5c401b2f2c3094b8a839f2df8ea9b0c4b +Subproject commit 4375ca6d4fb193a2a1ce203604b232b3c193379d