Skip to content

Commit

Permalink
update submod and readme
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Sep 12, 2023
1 parent e5e92e9 commit f0ec52d
Show file tree
Hide file tree
Showing 27 changed files with 256 additions and 159 deletions.
23 changes: 11 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,14 @@ let result = my_add.call(args);

```
### Callable
Users can define device-only functions using Callables. Callables have similar type signature to kernels: `Callable<ArgsTuple, Ret>`.
Users can define device-only functions using Callables. Callables have similar type signature to kernels: `Callable<fn(Args)->Ret>`.
The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accepts arguments such as resources (`BufferVar<T>`, .etc), expressions and references (pass a `Var<T>` to the callable). For example:
```rust
let add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| {
let add = device.create_callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>(&|a, b| {
a + b
});
let z = add.call(x, y);
let pass_by_ref = device.create_callable::<(Var<f32>,), ()>(&|a| {
let pass_by_ref = device.create_callable::<fn(Var<f32>)>(&|a| {
*a.get_mut() += 1.0;
});
let a = var!(f32, 1.0);
Expand All @@ -314,9 +314,9 @@ cpu_dbg!(*a); // prints 2.0
```
***Note***: You cannot record a callable when recording another kernel or callables. This is because a callable can capture outer variables such as buffers. However, capturing local variables define in another callable is undefined behavior. To avoid this, we disallow recording a callable when recording another callable or kernel.
```rust
let add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| {
let add = device.create_callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>(&|a, b| {
// runtime error!
let another_add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| {
let another_add = device.create_callable::<fn(Expr<f32>, Expr<f32>)-> Expr<f32>>(&|a, b| {
a + b
});
a + b
Expand All @@ -327,7 +327,7 @@ let add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| {
1. Use static callables. A static callable does not capture any resources and thus can be safely recorded inside any callable/kernel. To create a static callable, use `create_static_callable(fn)`. For example,
```rust
lazy_static! {
static ref ADD:Callable<(Expr<f32>, Expr<f32>), Expr<f32>> = create_static_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(|a, b| {
static ref ADD:Callable<fn(Expr<f32>, Expr<f32>)->Expr<f32>> = create_static_callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>(|a, b| {
a + b
});
}
Expand All @@ -337,9 +337,9 @@ ADD.call(x, y);
2. Use `DynCallable`. These are callables that defer recording until being called. As a result, it requires you to pass a `'static` closure, avoiding the capture issue. To create a `DynCallable`, use `Device::create_dyn_callable(Box::new(fn))`. The syntax is the same as `create_callable`. Furthermore, `DynCallable` supports `DynExpr` and `DynVar`, which provides some capablitiy of implementing template/overloading inside EDSL.

```rust
let add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| {
let add = device.create_callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>(&|a, b| {
// no error!
let another_add = device.create_dyn_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(Box::new(|a, b| {
let another_add = device.create_dyn_callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>(Box::new(|a, b| {
a + b
}));
a + b
Expand All @@ -349,18 +349,17 @@ let add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| {
### Kernel
A kernel can be written in a closure or a function. The closure/function should have a `Fn(/*args*/)->()` signature, where the args are taking the `Var` type of resources, such as `BufferVar<T>`, `Tex2D<T>`, etc.

Note: `Device::create_kernel` takes a tuple of types as its generic parameter. If the kernel takes a single argument, it is required to use `create_kernel::<(Type,)>` instead of `create_kernel::<Type>`.

```rust
let kernel = device.create_kernel::<(Arg0, Arg1, ...)>(&|/*args*/| {
let kernel = device.create_kernel::<fn(Arg0, Arg1, ...)>(&|/*args*/| {
/*body*/
});
kernel.dispatch([/*dispatch size*/], &arg0, &arg1, ...);
```
There are two ways to pass arguments to a kernel: by arguments or by capture.
```rust
let captured:Buffer<f32> = device.create_buffer(...);
let kernel = device.create_kernel::<(BufferVar<f32>, )>(arg| {
let kernel = device.create_kernel::<fn(BufferVar<f32>>(arg| {
let v = arg.read(..);
let u = captured.var().read(..);
}));
Expand All @@ -372,7 +371,7 @@ pub struct BufferPair {
a:Buffer<f32>,
b:Buffer<f32>
}
let kernel = device.create_kernel::<(BufferPair, )>(&|| {
let kernel = device.create_kernel::<fn(BufferPair)>(&|| {
// ...
});
let a = device.create_buffer(...);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/atomic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ fn main() {
let sum = device.create_buffer::<f32>(1);
x.view(..).fill_fn(|i| i as f32);
sum.view(..).fill(0.0);
let shader = device.create_kernel::<()>(&|| {
let shader = device.create_kernel::<fn()>(&|| {
let buf_x = x.var();
let buf_sum = sum.var();
let tid = luisa::dispatch_id().x();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/autodiff.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ fn main() {
let dy = device.create_buffer::<f32>(1024);
x.fill_fn(|i| i as f32);
y.fill_fn(|i| 1.0 + i as f32);
let shader = device.create_kernel::<(Buffer<f32>, Buffer<f32>, Buffer<f32>, Buffer<f32>)>(
let shader = device.create_kernel::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>, Buffer<f32>)>(
&|buf_x: BufferVar<f32>,
buf_y: BufferVar<f32>,
buf_dx: BufferVar<f32>,
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/backtrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ fn main() {
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<(Buffer<f32>,)>(&|buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&|buf_z| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindgroup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,6 @@ fn main() {
y,
exclude: 42.0,
};
let shader = device.create_kernel::<(MyArgStruct<f32>,)>(&|_args| {});
let shader = device.create_kernel::<fn(MyArgStruct<f32>)>(&|_args| {});
shader.dispatch([1024, 1, 1], &my_args);
}
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ fn main() {
bindless.emplace_buffer_async(1, &y);
bindless.emplace_tex2d_async(0, &img, Sampler::default());
bindless.update();
let kernel = device.create_kernel::<(BufferView<f32>,)>(&|buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&|buf_z| {
let bindless = bindless.var();
let tid = dispatch_id().x();
let buf_x = bindless.buffer::<f32>(Uint::from(0));
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/examples/callable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ fn main() {
init_logger();
let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device("cpu");
let add = device.create_callable::<(Expr<f32>, Expr<f32>), Expr<f32>>(&|a, b| a + b);
let add = device.create_callable::<fn(Expr<f32>, Expr<f32>)->Expr<f32>>(&|a, b| a + b);
let x = device.create_buffer::<f32>(1024);
let y = device.create_buffer::<f32>(1024);
let z = device.create_buffer::<f32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<(Buffer<f32>,)>(&|buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x();
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/examples/callable_advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ fn main() {
init_logger();
let ctx = Context::new(current_exe().unwrap());
let device = ctx.create_device("cpu");
let add = device.create_dyn_callable::<(DynExpr, DynExpr), DynExpr>(Box::new(
let add = device.create_dyn_callable::<fn(DynExpr, DynExpr) -> DynExpr>(Box::new(
|a: DynExpr, b: DynExpr| -> DynExpr {
if let Some(a) = a.downcast::<f32>() {
let b = b.downcast::<f32>().unwrap();
Expand All @@ -25,7 +25,7 @@ fn main() {
let w = device.create_buffer::<i32>(1024);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<(Buffer<f32>,)>(&|buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x();
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/custom_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ fn main() {
}
});
let shader = device
.create_kernel::<(Buffer<f32>,)>(&|buf_z: BufferVar<f32>| {
.create_kernel::<fn(Buffer<f32>)>(&|buf_z: BufferVar<f32>| {
// z is pass by arg
let buf_x = x.var(); // x and y are captured
let buf_y = y.var();
Expand Down
4 changes: 2 additions & 2 deletions luisa_compute/examples/find_leak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ fn main() {
let z = device.create_buffer::<f32>(count);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<(Buffer<f32>, Buffer<f32>, Buffer<f32>)>(
let kernel = device.create_kernel::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>(
&|buf_x, buf_y, buf_z| {
let tid = dispatch_id().x();
let x = buf_x.read(tid);
Expand All @@ -47,7 +47,7 @@ fn main() {
let z = device.create_buffer::<f32>(count);
x.view(..).fill_fn(|i| i as f32);
y.view(..).fill_fn(|i| 1000.0 * i as f32);
let kernel = device.create_kernel::<(Buffer<f32>,)>(&|buf_z| {
let kernel = device.create_kernel::<fn(Buffer<f32>)>(&|buf_z| {
let buf_x = x.var();
let buf_y = y.var();
let tid = dispatch_id().x();
Expand Down
18 changes: 9 additions & 9 deletions luisa_compute/examples/fluid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ fn main() {
};

let advect = device
.create_kernel_async::<(Buffer<Float2>, Buffer<Float2>, Buffer<f32>, Buffer<f32>)>(
.create_kernel_async::<fn(Buffer<Float2>, Buffer<Float2>, Buffer<f32>, Buffer<f32>)>(
&|u0, u1, rho0, rho1| {
let coord = dispatch_id().xy();
let u = u0.read(index(coord));
Expand All @@ -129,7 +129,7 @@ fn main() {
},
);

let divergence = device.create_kernel_async::<(Buffer<Float2>, Buffer<f32>)>(&|u, div| {
let divergence = device.create_kernel_async::<fn(Buffer<Float2>, Buffer<f32>)>(&|u, div| {
let coord = dispatch_id().xy();
if_!(coord.x().cmplt(N_GRID - 1) & coord.y().cmplt(N_GRID - 1), {
let dx = (u.read(index(make_uint2(coord.x() + 1, coord.y()))).x()
Expand All @@ -143,7 +143,7 @@ fn main() {
});

let pressure_solve =
device.create_kernel_async::<(Buffer<f32>, Buffer<f32>, Buffer<f32>)>(&|p0, p1, div| {
device.create_kernel_async::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>(&|p0, p1, div| {
let coord = dispatch_id().xy();
let i = coord.x().int();
let j = coord.y().int();
Expand All @@ -159,7 +159,7 @@ fn main() {
p1.write(ij, err * 0.25f32);
});

let pressure_apply = device.create_kernel_async::<(Buffer<f32>, Buffer<Float2>)>(&|p, u| {
let pressure_apply = device.create_kernel_async::<fn(Buffer<f32>, Buffer<Float2>)>(&|p, u| {
let coord = dispatch_id().xy();
let i = coord.x().int();
let j = coord.y().int();
Expand All @@ -181,7 +181,7 @@ fn main() {
);
});

let integrate = device.create_kernel_async::<(Buffer<Float2>, Buffer<f32>)>(&|u, rho| {
let integrate = device.create_kernel_async::<fn(Buffer<Float2>, Buffer<f32>)>(&|u, rho| {
let coord = dispatch_id().xy();
let ij = index(coord);

Expand All @@ -196,7 +196,7 @@ fn main() {
});

let init =
device.create_kernel_async::<(Buffer<f32>, Buffer<Float2>, Float2)>(&|rho, u, dir| {
device.create_kernel_async::<fn(Buffer<f32>, Buffer<Float2>, Float2)>(&|rho, u, dir| {
let coord = dispatch_id().xy();
let i = coord.x().int();
let j = coord.y().int();
Expand All @@ -210,7 +210,7 @@ fn main() {
});
});

let init_grid = device.create_kernel_async::<()>(&|| {
let init_grid = device.create_kernel_async::<fn()>(&|| {
let idx = index(dispatch_id().xy());
u0.var().write(idx, make_float2(0.0f32, 0.0f32));
u1.var().write(idx, make_float2(0.0f32, 0.0f32));
Expand All @@ -223,13 +223,13 @@ fn main() {
div.var().write(idx, 0.0f32);
});

let clear_pressure = device.create_kernel_async::<()>(&|| {
let clear_pressure = device.create_kernel_async::<fn()>(&|| {
let idx = index(dispatch_id().xy());
p0.var().write(idx, 0.0f32);
p1.var().write(idx, 0.0f32);
});

let draw_rho = device.create_kernel_async::<()>(&|| {
let draw_rho = device.create_kernel_async::<fn()>(&|| {
let coord = dispatch_id().xy();
let ij = index(coord);
let value = rho0.var().read(ij);
Expand Down
12 changes: 6 additions & 6 deletions luisa_compute/examples/mpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,14 +93,14 @@ fn main() {
p.x() + p.y() * N_GRID as u32
};

let clear_grid = device.create_kernel_async::<()>(&|| {
let clear_grid = device.create_kernel_async::<fn()>(&|| {
let idx = index(dispatch_id().xy());
grid_v.var().write(idx * 2, 0.0f32);
grid_v.var().write(idx * 2 + 1, 0.0f32);
grid_m.var().write(idx, 0.0f32);
});

let point_to_grid = device.create_kernel_async::<()>(&|| {
let point_to_grid = device.create_kernel_async::<fn()>(&|| {
let p = dispatch_id().x();
let xp = x.var().read(p) / DX;
let base = (xp - 0.5f32).int();
Expand Down Expand Up @@ -128,7 +128,7 @@ fn main() {
}
});

let simulate_grid = device.create_kernel_async::<()>(&|| {
let simulate_grid = device.create_kernel_async::<fn()>(&|| {
let coord = dispatch_id().xy();
let i = index(coord);
let v = var!(Float2);
Expand Down Expand Up @@ -157,7 +157,7 @@ fn main() {
grid_v.var().write(i * 2 + 1, vy);
});

let grid_to_point = device.create_kernel_async::<()>(&|| {
let grid_to_point = device.create_kernel_async::<fn()>(&|| {
let p = dispatch_id().x();
let xp = x.var().read(p) / DX;
let base = (xp - 0.5f32).int();
Expand Down Expand Up @@ -192,13 +192,13 @@ fn main() {
C.var().write(p, new_C);
});

let clear_display = device.create_kernel_async::<()>(&|| {
let clear_display = device.create_kernel_async::<fn()>(&|| {
display.var().write(
dispatch_id().xy(),
make_float4(0.1f32, 0.2f32, 0.3f32, 1.0f32),
);
});
let draw_particles = device.create_kernel_async::<()>(&|| {
let draw_particles = device.create_kernel_async::<fn()>(&|| {
let p = dispatch_id().x();
for i in -1..=1 {
for j in -1..=1 {
Expand Down
6 changes: 3 additions & 3 deletions luisa_compute/examples/path_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ fn main() {

// use create_kernel_async to compile multiple kernels in parallel
let path_tracer = device
.create_kernel_async::<(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>(
.create_kernel_async::<fn(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>(
&|image: Tex2dVar<Float4>,
seed_image: Tex2dVar<u32>,
accel: AccelVar,
Expand All @@ -265,7 +265,7 @@ fn main() {
]);

let lcg = |state: Var<u32>| -> Expr<f32> {
let lcg = create_static_callable::<(Var<u32>,), Expr<f32>>(|state:Var<u32>|{
let lcg = create_static_callable::<fn(Var<u32>)-> Expr<f32>>(|state:Var<u32>|{
const LCG_A: u32 = 1664525u32;
const LCG_C: u32 = 1013904223u32;
*state.get_mut() = LCG_A * *state + LCG_C;
Expand Down Expand Up @@ -441,7 +441,7 @@ fn main() {
},
)
;
let display = device.create_kernel_async::<(Tex2d<Float4>, Tex2d<Float4>)>(&|acc, display| {
let display = device.create_kernel_async::<fn(Tex2d<Float4>, Tex2d<Float4>)>(&|acc, display| {
set_block_size([16, 16, 1]);
let coord = dispatch_id().xy();
let radiance = acc.read(coord);
Expand Down
6 changes: 3 additions & 3 deletions luisa_compute/examples/path_tracer_cutout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ fn main() {

// use create_kernel_async to compile multiple kernels in parallel
let path_tracer = device
.create_kernel_async::<(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>(
.create_kernel_async::<fn(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>(
&|image: Tex2dVar<Float4>,
seed_image: Tex2dVar<u32>,
accel: AccelVar,
Expand All @@ -271,7 +271,7 @@ fn main() {
]);

let lcg = |state: Var<u32>| -> Expr<f32> {
let lcg = create_static_callable::<(Var<u32>, ), Expr<f32>>(|state: Var<u32>| {
let lcg = create_static_callable::<fn(Var<u32>)->Expr<f32>>(|state: Var<u32>| {
const LCG_A: u32 = 1664525u32;
const LCG_C: u32 = 1013904223u32;
*state.get_mut() = LCG_A * *state + LCG_C;
Expand Down Expand Up @@ -470,7 +470,7 @@ fn main() {
},
)
;
let display = device.create_kernel_async::<(Tex2d<Float4>, Tex2d<Float4>)>(&|acc, display| {
let display = device.create_kernel_async::<fn(Tex2d<Float4>, Tex2d<Float4>)>(&|acc, display| {
set_block_size([16, 16, 1]);
let coord = dispatch_id().xy();
let radiance = acc.read(coord);
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/polymorphism.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ fn main() {
poly_area.register((), &circles);
poly_area.register((), &squares);
let areas = device.create_buffer::<f32>(4);
let shader = device.create_kernel::<()>(&|| {
let shader = device.create_kernel::<fn()>(&|| {
let tid = dispatch_id().x();
let tag = tid / 2;
let index = tid % 2;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/polymorphism_advanced.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ fn main() {
);
let poly_shader = builder.build();
let result = device.create_buffer::<f32>(100);
let kernel = device.create_kernel::<()>(&|| {
let kernel = device.create_kernel::<fn()>(&|| {
let i = dispatch_id().x();
let x = i.float() / 100.0 * PI;
let ctx = ShaderEvalContext {
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/printer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ fn main() {
"cpu"
});
let printer = Printer::new(&device, 65536);
let kernel = device.create_kernel::<()>(&|| {
let kernel = device.create_kernel::<fn()>(&|| {
let id = dispatch_id().xy();
if_!(id.x().cmpeq(id.y()), {
lc_info!(printer, "id = {:?}", id);
Expand Down
Loading

0 comments on commit f0ec52d

Please sign in to comment.