Skip to content

Commit

Permalink
Merge pull request #10 from iMplode-nZ/refactor
Browse files Browse the repository at this point in the history
Complete Refactor, also add `track` macro and `::expr`, `.var` commands.
  • Loading branch information
shiinamiyuki authored Sep 18, 2023
2 parents dd6b25a + d60f4d4 commit dc6338c
Show file tree
Hide file tree
Showing 58 changed files with 4,942 additions and 4,505 deletions.
4 changes: 3 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ members = [
"luisa_compute_sys",
"luisa_compute_derive_impl",
"luisa_compute_derive",
"luisa_compute_track",
]
resolver = "2"
# exclude = [
# "luisa_compute_sys/LuisaCompute/src/api/luisa_compute_api_types",
# ]
# ]
144 changes: 96 additions & 48 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,33 @@ Rust frontend to LuisaCompute and more! Unified API and embedded DSL for high pe

To see the use of `luisa-compute-rs` in a high performance offline rendering system, checkout [our research renderer](https://github.com/shiinamiyuki/akari_render)
## Table of Contents
* [Overview](#overview)
+ [Embedded Domain-Specific Language](#embedded-domain-specific-language)
+ [Automatic Differentiation](#automatic-differentiation)
+ [A CPU backend](#cpu-backend)
+ [IR Module for EDSL](#ir-module)
+ [Debuggability](#debuggability)
* [Usage](#usage)
+ [Building](#building)
+ [Variables and Expressions](#variables-and-expressions)
+ [Builtin Functions](#builtin-functions)
+ [Control Flow](#control-flow)
+ [Custom Data Types](#custom-data-types)
+ [Polymorphism](#polymorphism)
+ [Autodiff](#autodiff)
+ [Custom Operators](#custom-operators)
+ [Callable](#callable)
+ [Kernel](#kernel)
* [Advanced Usage](#advanced-usage)
* [Safety](#safety)
* [Citation](#citation)
- [luisa-compute-rs](#luisa-compute-rs)
- [Table of Contents](#table-of-contents)
- [Example](#example)
- [Vecadd](#vecadd)
- [Overview](#overview)
- [Embedded Domain-Specific Language](#embedded-domain-specific-language)
- [Automatic Differentiation](#automatic-differentiation)
- [CPU Backend](#cpu-backend)
- [IR Module](#ir-module)
- [Debuggability](#debuggability)
- [Usage](#usage)
- [Building](#building)
- [Variables and Expressions](#variables-and-expressions)
- [Builtin Functions](#builtin-functions)
- [Control Flow](#control-flow)
- [`track!` Mcro](#track-mcro)
- [Custom Data Types](#custom-data-types)
- [Polymorphism](#polymorphism)
- [Autodiff](#autodiff)
- [Custom Operators](#custom-operators)
- [Callable](#callable)
- [Kernel](#kernel)
- [Advanced Usage](#advanced-usage)
- [Safety](#safety)
- [API](#api)
- [Backend](#backend)
- [Citation](#citation)

## Example
Try `cargo run --release --example path_tracer -- [cpu|cuda|dx|metal]`!
Expand Down Expand Up @@ -60,7 +67,7 @@ fn main() {
let tid = dispatch_id().x();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let vx = var!(f32); // create a local mutable variable
let vx = 0.0f32.var(); // create a local mutable variable
*vx.get_mut() += x;
buf_z.write(tid, vx.load() + y);
});
Expand Down Expand Up @@ -125,52 +132,44 @@ For each type, there are two EDSL proxy objects `Expr<T>` and `Var<T>`. `Expr<T>
*Note*: Every DSL object in host code **must** be immutable due to Rust unable to overload `operator =`. For example:
```rust
// **no good**
let mut v = const_(0.0f32);
let mut v = 0.0f32.expr();
if_!(cond, {
v += 1.0;
});

// also **not good**
let v = Cell::new(const_(0.0f32));
let v = Cell::new(0.0f32.expr());
if_!(cond, {
v.set(v.get() + 1.0);
});

// **good**
let v = var!(f32);
let v = 0.0f32.var();
if_!(cond, {
*v.get_mut() += 1.0;
});
```
*Note*: You should not store the referene obtained by `v.get_mut()` for repeated use, as the assigned value is only updated when `v.get_mut()` is dropped. For example,:
```rust
let v = var!(f32);
let v = 0.0f32.var();
let bad = v.get_mut();
*bad = 1.0;
let u = *v;
drop(bad);
cpu_dbg!(u); // prints 0.0
cpu_dbg!(*v); // prints now 1.0
```
All operations except load/store should be performed on `Expr<T>`. `Var<T>` can only be used to load/store values. While `Expr<T>` and `Var<T>` are sufficent in most cases, it cannot be placed in an `impl` block. To do so, the exact name of these proxies are needed.
```rust
Expr<Bool> == Bool, Var<Bool> == BoolVar
Expr<f32> == Float32, Var<f32> == Float32Var
Expr<i32> == Int32, Var<i32> == Int32Var
Expr<u32> == UInt32, Var<u32> == UInt32Var
Expr<i64> == Int64, Var<i64> == Int64Var
Expr<u64> == UInt64, Var<u64> == UInt64Var
```
All operations except load/store should be performed on `Expr<T>`. `Var<T>` can only be used to load/store values.

As in the C++ EDSL, we additionally supports the following vector/matrix types. Their proxy types are `XXXExpr` and `XXXVar`:

```rust
Bool2 // bool2 in C++
Bool3 // bool3 in C++
Bool4 // bool4 in C++
Vec2 // float2 in C++
Vec3 // float3 in C++
Vec4 // float4 in C++
Float2 // float2 in C++
Float3 // float3 in C++
Float4 // float4 in C++
Int2 // int2 in C++
Int3 // int3 in C++
Int4 // int4 in C++
Expand All @@ -181,20 +180,19 @@ Mat2 // float2x2 in C++
Mat3 // float3x3 in C++
Mat4 // float4x4 in C++
```
Array types `[T;N]` are also supported and their proxy types are `ArrayExpr<T, N>` and `ArrayVar<T, N>`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar<T, N>` for element access. `ArrayExpr<T,N>` can be stored to and loaded from `ArrayVar<T, N>`. The limitation is however the array length must be determined during host compile time. If runtime length is required, use `VLArrayVar<T>`. `VLArrayVar<T>::zero(length: usize` would create a zero initialized array. Similarly you can use `read` and `write` methods as well. To query the length of a `VLArrayVar<T>` in host, use ``VLArrayVar<T>::static_len()->usize`. To query the length in kernel, use ``VLArrayVar<T>::len()->Expr<u32>`
Array types `[T;N]` are also supported and their proxy types are `ArrayExpr<T, N>` and `ArrayVar<T, N>`. Call `arr.read(i)` and `arr.write(i, value)` on `ArrayVar<T, N>` for element access. `ArrayExpr<T,N>` can be stored to and loaded from `ArrayVar<T, N>`. The limitation is however the array length must be determined during host compile time. If runtime length is required, use `VLArrayVar<T>`. `VLArrayVar<T>::zero(length: usize)` would create a zero initialized array. Similarly you can use `read` and `write` methods as well. To query the length of a `VLArrayVar<T>` in host, use `VLArrayVar<T>::static_len()->usize`. To query the length in kernel, use `VLArrayVar<T>::len()->Expr<u32>`

Most operators are already overloaded with the only exception is comparision. We cannot overload comparision operators as `PartialOrd` cannot return a DSL type. Instead, use `cmpxx` methods such as `cmpgt, cmpeq`, etc. To cast a primitive/vector into another type, use `v.type()`. For example:
```rust
let iv = make_int2(1,1,1);
let iv = Int2::expr(1, 1, 1);
let fv = iv.float(); //fv is Expr<Float2>
let bv = fv.bool(); // bv is Expr<Bool2>
```
To perform a bitwise cast, use the `bitcast` function. `let fv:Expr<f32> = bitcast::<u32, f32>(const_(0u32));`
To perform a bitwise cast, use the `bitcast` function. `let fv:Expr<f32> = bitcast::<u32, f32>(0u32);`

### Builtin Functions

We have extentded primitive types with methods similar to their host counterpart: `v.sin(), v.max(u)`, etc. Most methods accepts both a `Expr<T>` or a literal like `0.0`. However, the `select` function is slightly different as it do not accept literals. You need to use `select(cond, f_var, const_(1.0f32))`.

We have extentded primitive types with methods similar to their host counterpart: `v.sin(), v.max(u)`, etc. Most methods accepts both a `Expr<T>` or a literal like `0.0`. However, the `select` function is slightly different as it does not accept literals. You need to use `select(cond, f_var, 1.0f32.expr())`.

### Control Flow
*Note*, you cannot modify outer scope variables inside a control flow block by declaring the variable as `mut`. To modify outer scope variables, use `Var<T>` instead and call *var.get_mut() = value` to store the value back to the outer scope.
Expand Down Expand Up @@ -223,8 +221,60 @@ let (x,y) = switch::<(Expr<i32>, Expr<f32>)>(value)
.finish();
```

### `track!` Mcro

We also offer a `track!` macro that automatically rewrites control flow primitves and comparison operators. For example (from [`examples/mpm.rs`](luisa_compute/examples/mpm.rs)):

```rust
track!(|| {
// ...
let vx = select(
coord.x() < BOUND && (vx < 0.0f32)
|| coord.x() + BOUND > N_GRID as u32 && (vx > 0.0f32),
0.0f32.into(),
vx,
);
let vy = select(
coord.y() < BOUND && (vy < 0.0f32)
|| coord.y() + BOUND > N_GRID as u32 && (vy > 0.0f32),
0.0f32.into(),
vy,
);
// ...
})
```
is equivalent to:
```rust
|| {
// ...
let vx = select(
(coord.x().cmplt(BOUND) & vx.cmplt(0.0f32))
| (coord.x() + BOUND).cmpgt(N_GRID as u32) & vx.cmpgt(0.0f32),
0.0f32.into(),
vx,
);
let vy = select(
(coord.y().cmplt(BOUND) & vy.cmplt(0.0f32))
| (coord.y() + BOUND).cmpgt(N_GRID as u32) & vy.cmpgt(0.0f32),
0.0f32.into(),
vy,
);
// ...
}
```
Similarily,
```rust
track!(if cond { foo } else if bar { baz } else { qux })
```
will be converted to
```rust
if_!(cond, { foo }, { if_!(bar, { baz }, { qux }) })
```

Note that this macro will rewrite `while`, `for _ in x..y`, and `loop` expressions to versions using functions, which will then break the `break` and `continue` expressions. In order to avoid this, it's possible to use the `escape!` macro within a `track!` context to disable rewriting for an expression.

### Custom Data Types
To add custom data types to the EDSL, simply derive from `luisa::Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with C ABI. The proxy types are `XXXExpr` and `XXXVar`:
To add custom data types to the EDSL, simply derive from `Value` macro. Note that `#[repr(C)]` is required for the struct to be compatible with C ABI. The proxy types are `XXXExpr` and `XXXVar`:

```rust
#[derive(Copy, Clone, Default, Debug, Value)]
Expand All @@ -234,7 +284,7 @@ pub struct MyVec2 {
pub y: f32,
}

let v = var!(MyVec2);
let v = MyVec2.var();
let sum = *v.x() + *v.y();
*v.x().get_mut() += 1.0;
```
Expand Down Expand Up @@ -282,8 +332,6 @@ autodiff(||{
buf_dv.write(.., dv);
buf_dm.write(.., dm);
});


```

### Custom Operators
Expand All @@ -304,8 +352,8 @@ let my_add = CpuFn::new(|args: &mut MyAddArgs| {

let args = MyAddArgsExpr::new(x, y, Float32::zero());
let result = my_add.call(args);

```

### Callable
Users can define device-only functions using Callables. Callables have similar type signature to kernels: `Callable<fn(Args)->Ret>`.
The difference is that Callables are not dispatchable and can only be called from other Callables or Kernels. Callables can be created using `Device::create_callable`. To invoke a Callable, use `Callable::call(args...)`. Callables accepts arguments such as resources (`BufferVar<T>`, .etc), expressions and references (pass a `Var<T>` to the callable). For example:
Expand All @@ -317,7 +365,7 @@ let z = add.call(x, y);
let pass_by_ref = device.create_callable::<fn(Var<f32>)>(&|a| {
*a.get_mut() += 1.0;
});
let a = var!(f32, 1.0);
let a = 1.0f32.var();
pass_by_ref.call(a);
cpu_dbg!(*a); // prints 2.0
```
Expand Down
17 changes: 9 additions & 8 deletions luisa_compute/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ name = "luisa_compute"
version = "0.1.1-alpha.1"

[dependencies]
base64ct = {version = "1.5.0", features = ["alloc"]}
base64ct = { version = "1.5.0", features = ["alloc"] }
bumpalo = "3.12.0"
env_logger = "0.10.0"
glam = "0.24.0"
Expand All @@ -14,15 +14,16 @@ lazy_static = "1.4.0"
libc = "0.2"
libloading = "0.8"
log = "0.4"
luisa_compute_api_types = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version="0.1.1-alpha.1"}
luisa_compute_backend = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version="0.1.1-alpha.1"}
luisa_compute_derive = {path = "../luisa_compute_derive", version="0.1.1-alpha.1"}
luisa_compute_derive_impl = {path = "../luisa_compute_derive_impl", version="0.1.1-alpha.1"}
luisa_compute_ir = {path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version="0.1.1-alpha.1"}
luisa_compute_sys = {path = "../luisa_compute_sys", version="0.1.1-alpha.1"}
luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" }
luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" }
luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" }
luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" }
luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" }
luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" }
luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" }
parking_lot = "0.12.1"
rayon = "1.6.0"
serde = {version = "1.0", features = ["derive"]}
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.10"
winit = "0.28.3"
Expand Down
7 changes: 3 additions & 4 deletions luisa_compute/examples/atomic.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use std::env::current_exe;

use luisa::prelude::*;
use luisa::Context;
use luisa_compute as luisa;

fn main() {
Expand All @@ -11,12 +10,12 @@ fn main() {
let sum = device.create_buffer::<f32>(1);
x.view(..).fill_fn(|i| i as f32);
sum.view(..).fill(0.0);
let shader = device.create_kernel::<fn()>(&|| {
let shader = device.create_kernel::<fn()>(&track!(|| {
let buf_x = x.var();
let buf_sum = sum.var();
let tid = luisa::dispatch_id().x();
let tid = dispatch_id().x();
buf_sum.atomic_fetch_add(0, buf_x.read(tid));
});
}));
shader.dispatch([x.len() as u32, 1, 1]);
let mut sum_data = vec![0.0];
sum.view(..).copy_to(&mut sum_data);
Expand Down
20 changes: 12 additions & 8 deletions luisa_compute/examples/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{env::current_exe, f32::consts::PI};
use std::env::current_exe;
use std::f32::consts::PI;

use luisa::*;
use luisa::lang::diff::*;
use luisa::prelude::*;
use luisa_compute as luisa;
fn main() {
luisa::init_logger_verbose();
Expand Down Expand Up @@ -31,11 +33,13 @@ fn main() {
let buf_y = y.var();
let x = buf_x.read(tid);
let y = buf_y.read(tid);
let f = |x: Expr<f32>, y: Expr<f32>| {
if_!(x.cmpgt(y), { x * y }, else, {
let f = track!(|x: Expr<f32>, y: Expr<f32>| {
if x > y {
x * y
} else {
y * x + (x / 32.0 * PI).sin()
})
};
}
});
autodiff(|| {
requires_grad(x);
requires_grad(y);
Expand All @@ -45,8 +49,8 @@ fn main() {
dy_rev.write(tid, gradient(y));
});
forward_autodiff(2, || {
propagate_gradient(x, &[const_(1.0f32), const_(0.0f32)]);
propagate_gradient(y, &[const_(0.0f32), const_(1.0f32)]);
propagate_gradient(x, &[1.0f32.expr(), 0.0f32.expr()]);
propagate_gradient(y, &[0.0f32.expr(), 1.0f32.expr()]);
let z = f(x, y);
let dx = output_gradients(z)[0];
let dy = output_gradients(z)[1];
Expand Down
8 changes: 4 additions & 4 deletions luisa_compute/examples/backtrace.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::env::{current_exe, self};
use std::env::{self, current_exe};

use luisa::prelude::*;
use luisa_compute as luisa;

fn main() {
use luisa::*;
init_logger();
luisa::init_logger();
let ctx = Context::new(current_exe().unwrap());
env::set_var("LUISA_DEBUG", "1");
let device = ctx.create_device("cpu");
Expand All @@ -20,7 +20,7 @@ fn main() {
let tid = dispatch_id().x();
let x = buf_x.read(tid + 123);
let y = buf_y.read(tid);
let vx = var!(f32); // create a local mutable variable
let vx = Var::<f32>::zeroed(); // create a local mutable variable
vx.store(x);
buf_z.write(tid, vx.load() + y);
});
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/bindgroup.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::env::current_exe;

use luisa::*;
use luisa::prelude::*;
use luisa_compute as luisa;
#[derive(BindGroup)]
struct MyArgStruct<T: Value> {
Expand Down
Loading

0 comments on commit dc6338c

Please sign in to comment.