Skip to content

Commit

Permalink
fix generics
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Apr 24, 2024
1 parent 5923bf7 commit 11af866
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 11 deletions.
69 changes: 69 additions & 0 deletions luisa_compute/examples/test_buffer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use luisa::prelude::*;
use luisa_compute as luisa;
use std::env::current_exe;

fn main() {
luisa::init_logger_verbose();
let args: Vec<String> = std::env::args().collect();
assert!(
args.len() <= 2,
"Usage: {} <backend>. <backend>: cpu, cuda, dx, metal, remote",
args[0]
);

let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device(if args.len() == 2 {
args[1].as_str()
} else {
"cpu"
});
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);

let kernel = device.create_kernel_with_options::<fn(Buffer<f32>)>(
KernelBuildOptions {
name: Some("vecadd".into()),
..Default::default()
},
&track!(|buf_z| {
// z is pass by arg
let buf_x = &x; // x and y are captured
let buf_y = &y;
let tid = dispatch_id().x;
let x = buf_x.read(tid);
let y = buf_y.read(tid);
buf_z.write(tid, x + y);
}),
);
let mut z_data = vec![123.0f32; 1024];

unsafe {
let s = device.default_stream().scope();
let z_data_ptr = z_data.as_mut_ptr();
s.submit([
z.copy_from_async(std::slice::from_raw_parts_mut(z_data_ptr, 1024)),
kernel.dispatch_async([1024, 1, 1], &z),
z.copy_to_async(std::slice::from_raw_parts_mut(z_data_ptr, 1024)),
z.copy_from_async(std::slice::from_raw_parts_mut(z_data_ptr, 1024)),
z.copy_to_buffer_async(&x)
]);
}

// this should produce the expected behavior
// unsafe {
// let z_data_ptr = z_data.as_mut_ptr();

// z.copy_from(std::slice::from_raw_parts_mut(z_data_ptr, 1024));
// kernel.dispatch([1024, 1, 1], &z);
// z.copy_to(std::slice::from_raw_parts_mut(z_data_ptr, 1024));
// z.copy_from(std::slice::from_raw_parts_mut(z_data_ptr, 1024));
// z.copy_to_buffer(&x);
// }

println!("{:?}", &z_data[0..16]);
let x_data = x.copy_to_vec();
println!("{:?}", &x_data[0..16]);
}
8 changes: 8 additions & 0 deletions luisa_compute/tests/misc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1472,6 +1472,14 @@ pub struct Bar {
a: [i32; 4],
f: Foo,
}
#[derive(Clone, Copy, Debug, Value)]
#[repr(C)]
pub struct Foo2<T: Value> {
i: T,
v: Float2,
a: [T; 4],
m: Mat2,
}
#[test]
fn soa() {
let device = get_device();
Expand Down
20 changes: 10 additions & 10 deletions luisa_compute_derive_impl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ impl Compiler {
#vis fn from_comps_expr(ctor: #ctor_proxy_name #ty_generics) -> #lang_path::types::Expr<#name #ty_generics> {
use #lang_path::*;
let node = #lang_path::__compose::<#name #ty_generics>(&[ #( #lang_path::ToNode::node(&ctor.#field_names.as_expr()).get() ),* ]);
<#lang_path::types::Expr::<#name> as #lang_path::FromNode>::from_node(node.into())
<#lang_path::types::Expr::<#name #ty_generics> as #lang_path::FromNode>::from_node(node.into())
}
}
)
Expand Down Expand Up @@ -575,7 +575,7 @@ impl Compiler {
#[allow(dead_code)]
#vis struct #expr_proxy_name #generics #where_clause{
_marker: std::marker::PhantomData<(#marker_args)>,
self_: #lang_path::types::Expr<#name>,
self_: #lang_path::types::Expr<#name #ty_generics>,
#(#field_vis #field_names: #lang_path::types::Expr<#field_types>),*

}
Expand All @@ -584,15 +584,15 @@ impl Compiler {
#[allow(dead_code)]
#vis struct #var_proxy_name #generics #where_clause{
_marker: std::marker::PhantomData<(#marker_args)>,
self_: #lang_path::types::Var<#name>,
self_: #lang_path::types::Var<#name #ty_generics>,
#(#field_vis #field_names: #lang_path::types::Var<#field_types>),*,
}
#[derive(Clone, Copy)]
#[allow(unused_parens)]
#[allow(dead_code)]
#vis struct #atomic_ref_proxy_name #generics #where_clause{
_marker: std::marker::PhantomData<(#marker_args)>,
self_: #lang_path::types::AtomicRef<#name>,
self_: #lang_path::types::AtomicRef<#name #ty_generics>,
#(#field_vis #field_names: #lang_path::types::AtomicRef<#field_types>),*,
}
#[allow(unused_parens)]
Expand All @@ -609,7 +609,7 @@ impl Compiler {

}
}
fn as_expr_from_proxy(&self) -> &#lang_path::types::Expr<#name> {
fn as_expr_from_proxy(&self) -> &#lang_path::types::Expr<#name #ty_generics> {
&self.self_
}
}
Expand All @@ -626,7 +626,7 @@ impl Compiler {
#(#field_names),*
}
}
fn as_var_from_proxy(&self) -> &#lang_path::types::Var<#name> {
fn as_var_from_proxy(&self) -> &#lang_path::types::Var<#name #ty_generics> {
&self.self_
}
}
Expand All @@ -643,13 +643,13 @@ impl Compiler {
#(#field_names),*
}
}
fn as_atomic_ref_from_proxy(&self) -> &#lang_path::types::AtomicRef<#name> {
fn as_atomic_ref_from_proxy(&self) -> &#lang_path::types::AtomicRef<#name #ty_generics> {
&self.self_
}
}
#[allow(unused_parens)]
impl #impl_generics std::ops::Deref for #var_proxy_name #ty_generics #where_clause {
type Target = #lang_path::types::Expr<#name> #ty_generics;
type Target = #lang_path::types::Expr<#name #ty_generics>;
fn deref(&self) -> &Self::Target {
#lang_path::types::_deref_proxy(self)
}
Expand Down Expand Up @@ -689,10 +689,10 @@ impl Compiler {
quote_spanned! {
span =>
impl #impl_generics #name #ty_generics #where_clause {
#vis fn new_expr(#(#field_names: impl #lang_path::types::AsExpr<Value = #field_types>),*) -> #lang_path::types::Expr::<#name> {
#vis fn new_expr(#(#field_names: impl #lang_path::types::AsExpr<Value = #field_types>),*) -> #lang_path::types::Expr::<#name #ty_generics> {
use #lang_path::*;
let node = #lang_path::__compose::<#name #ty_generics>(&[ #( #lang_path::ToNode::node(&#field_names.as_expr()).get() ),* ]);
<#lang_path::types::Expr::<#name> as #lang_path::FromNode>::from_node(node.into())
<#lang_path::types::Expr::<#name #ty_generics> as #lang_path::FromNode>::from_node(node.into())
}
}
}
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute_sys/LuisaCompute
Submodule LuisaCompute updated 88 files
+2 −1 .clangd
+0 −14 .devcontainer/Dockerfile
+0 −17 .devcontainer/devcontainer.json
+8 −9 .github/workflows/build-cmake.yml
+3 −3 .github/workflows/build-wheels.yml
+12 −9 .github/workflows/build-xmake.yml
+4 −0 .gitignore
+1 −4 .gitmodules
+0 −0 config/test/dev.json
+3 −0 include/luisa/ast/attribute.h
+2 −2 include/luisa/ast/type.h
+1 −1 include/luisa/backends/ext/dx_custom_cmd.h
+1 −1 include/luisa/core/dynamic_module.h
+2 −2 include/luisa/core/fiber.h
+2 −0 include/luisa/core/mathematics.h
+1 −0 include/luisa/core/stl/functional.h
+26 −0 include/luisa/core/stl/type_traits.h
+3 −2 include/luisa/core/stl/unordered_dense.h
+1 −0 include/luisa/core/stl/vector.h
+13 −11 include/luisa/dsl/operators.h
+0 −213 include/luisa/dsl/printer.h
+0 −1 include/luisa/luisa-compute.h
+6 −0 include/luisa/runtime/rtx/accel.h
+5 −4 include/luisa/vstl/arena_hash_map.h
+4 −3 include/luisa/vstl/hash_map.h
+3 −2 include/luisa/vstl/memory.h
+16 −18 include/luisa/vstl/meta_lib.h
+5 −4 include/luisa/vstl/pool.h
+2 −1 include/luisa/vstl/unique_ptr.h
+5 −4 include/luisa/vstl/vector.h
+2 −0 scripts/download_sdk.cmake
+111 −0 scripts/find_sdk.lua
+153 −136 scripts/lib.lua
+34 −0 scripts/packages.lua
+84 −25 scripts/xmake_func.lua
+18 −0 setup.lua
+55 −7 src/ast/type.cpp
+2 −0 src/backends/CMakeLists.txt
+33 −2 src/backends/common/CMakeLists.txt
+16 −1 src/backends/common/hlsl/hlsl_codegen_util.cpp
+10 −0 src/backends/cuda/cuda_builtin/cuda_device_math.h
+5,674 −5,615 src/backends/cuda/cuda_builtin_embedded.cpp
+1 −1 src/backends/cuda/cuda_builtin_embedded.h
+2 −0 src/backends/cuda/cuda_codegen_ast.cpp
+38 −3 src/backends/cuda/cuda_compiler.cpp
+9 −2 src/backends/cuda/cuda_nvrtc_compiler.cpp
+2 −0 src/backends/cuda/generate_device_library.py
+16 −16 src/backends/cuda/xmake.lua
+41 −5 src/backends/dx/CMakeLists.txt
+1 −1 src/backends/dx/DXRuntime/CommandAllocator.h
+1 −1 src/backends/dx/DXRuntime/CommandQueue.h
+1 −1 src/backends/dx/DXRuntime/DStorageCommandQueue.h
+1 −1 src/backends/dx/Resource/BottomAccel.h
+0 −1 src/backends/dx/dx_support
+1 −6 src/backends/dx/xmake.lua
+36 −20 src/backends/xmake.lua
+17 −7 src/clangcxx/src/llvm/ASTConsumer.cpp
+21 −8 src/clangcxx/src/llvm/TypeDatabase.cpp
+9 −2 src/core/platform.cpp
+0 −1 src/dsl/CMakeLists.txt
+0 −55 src/dsl/printer.cpp
+4 −4 src/ext/CMakeLists.txt
+1 −1 src/ext/EASTL
+1 −1 src/ext/glfw
+1 −1 src/ext/imgui
+1 −1 src/ext/marl
+1 −1 src/ext/reproc
+1 −1 src/ext/spdlog
+1 −1 src/ext/xxHash
+1 −1 src/gui/xmake.lua
+0 −18 src/ir/xmake.lua
+2 −2 src/py/export_gui.cpp
+6 −1 src/runtime/bindless_array.cpp
+5 −0 src/runtime/rtx/accel.cpp
+1 −1 src/rust/luisa_compute_backend_impl/src/cpu/codegen/cpp.rs
+1 −1 src/rust/luisa_compute_backend_impl/src/cpu/codegen/cpp_v2.rs
+2 −0 src/rust/luisa_compute_backend_impl/src/cpu/codegen/cpu_libm_def.h
+10 −0 src/rust/luisa_compute_backend_impl/src/cpu/codegen/device_math.h
+3 −3 src/rust/luisa_compute_backend_impl/src/lib.rs
+4 −4 src/tests/CMakeLists.txt
+2 −6 src/tests/next/example/use/use_printer.cpp
+4 −23 src/tests/test_ast.cpp
+2 −6 src/tests/test_buffer_io.cpp
+0 −1 src/tests/test_printer.cpp
+0 −1 src/tests/test_printer_custom_callback.cpp
+1 −1 src/tests/test_raster.cpp
+7 −3 src/tests/xmake.lua
+15 −17 xmake.lua

0 comments on commit 11af866

Please sign in to comment.