Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace make_float4 with Float4::expr, etc. #7

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ As in the C++ EDSL, we additionally supports the following vector/matrix types.
Bool2 // bool2 in C++
Bool3 // bool3 in C++
Bool4 // bool4 in C++
Vec2 // float2 in C++
Vec3 // float3 in C++
Vec4 // float4 in C++
Float2 // float2 in C++
Float3 // float3 in C++
Float4 // float4 in C++
Int2 // int2 in C++
Int3 // int3 in C++
Int4 // int4 in C++
Expand All @@ -185,7 +185,7 @@ Array types `[T;N]` are also supported and their proxy types are `ArrayExpr<T, N

Most operators are already overloaded with the only exception is comparision. We cannot overload comparision operators as `PartialOrd` cannot return a DSL type. Instead, use `cmpxx` methods such as `cmpgt, cmpeq`, etc. To cast a primitive/vector into another type, use `v.type()`. For example:
```rust
let iv = make_int2(1,1,1);
let iv = Int2::expr(1, 1, 1);
let fv = iv.float(); //fv is Expr<Float2>
let bv = fv.bool(); // bv is Expr<Bool2>
```
Expand Down
42 changes: 21 additions & 21 deletions luisa_compute/examples/fluid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,14 @@ fn main() {

let index = |xy: Expr<Uint2>| -> Expr<u32> {
let p = xy.clamp(
make_uint2(0, 0),
make_uint2(N_GRID as u32 - 1, N_GRID as u32 - 1),
Uint2::expr(0, 0),
Uint2::expr(N_GRID as u32 - 1, N_GRID as u32 - 1),
);
p.x() + p.y() * N_GRID as u32
};

let lookup_float = |f: &BufferVar<f32>, x: Int, y: Int| -> Float {
return f.read(index(make_uint2(x.uint(), y.uint())));
return f.read(index(Uint2::expr(x.uint(), y.uint())));
};

let sample_float = |f: BufferVar<f32>, x: Float, y: Float| -> Float {
Expand All @@ -96,7 +96,7 @@ fn main() {
};

let lookup_vel = |f: &BufferVar<Float2>, x: Int, y: Int| -> Float2Expr {
return f.read(index(make_uint2(x.uint(), y.uint())));
return f.read(index(Uint2::expr(x.uint(), y.uint())));
};

let sample_vel = |f: BufferVar<Float2>, x: Float, y: Float| -> Float2Expr {
Expand All @@ -106,11 +106,11 @@ fn main() {
let tx = x - lx.float();
let ty = y - ly.float();

let s0 = lookup_vel(&f, lx, ly).lerp(lookup_vel(&f, lx + 1, ly), make_float2(tx, tx));
let s0 = lookup_vel(&f, lx, ly).lerp(lookup_vel(&f, lx + 1, ly), Float2::expr(tx, tx));
let s1 =
lookup_vel(&f, lx, ly + 1).lerp(lookup_vel(&f, lx + 1, ly + 1), make_float2(tx, tx));
lookup_vel(&f, lx, ly + 1).lerp(lookup_vel(&f, lx + 1, ly + 1), Float2::expr(tx, tx));

return s0.lerp(s1, make_float2(ty, ty));
return s0.lerp(s1, Float2::expr(ty, ty));
};

let advect = device
Expand All @@ -120,7 +120,7 @@ fn main() {
let u = u0.read(index(coord));

// trace backward
let mut p = make_float2(coord.x().float(), coord.y().float());
let mut p = Float2::expr(coord.x().float(), coord.y().float());
p = p - u * dt;

// advect
Expand All @@ -132,10 +132,10 @@ fn main() {
let divergence = device.create_kernel_async::<fn(Buffer<Float2>, Buffer<f32>)>(&|u, div| {
let coord = dispatch_id().xy();
if_!(coord.x().cmplt(N_GRID - 1) & coord.y().cmplt(N_GRID - 1), {
let dx = (u.read(index(make_uint2(coord.x() + 1, coord.y()))).x()
let dx = (u.read(index(Uint2::expr(coord.x() + 1, coord.y()))).x()
- u.read(index(coord)).x())
* 0.5;
let dy = (u.read(index(make_uint2(coord.x(), coord.y() + 1))).y()
let dy = (u.read(index(Uint2::expr(coord.x(), coord.y() + 1))).y()
- u.read(index(coord)).y())
* 0.5;
div.write(index(coord), dx + dy);
Expand Down Expand Up @@ -169,11 +169,11 @@ fn main() {
i.cmpgt(0) & i.cmplt(N_GRID - 1) & j.cmpgt(0) & j.cmplt(N_GRID - 1),
{
// pressure gradient
let f_p = make_float2(
p.read(index(make_uint2(i.uint() + 1, j.uint())))
- p.read(index(make_uint2(i.uint() - 1, j.uint()))),
p.read(index(make_uint2(i.uint(), j.uint() + 1)))
- p.read(index(make_uint2(i.uint(), j.uint() - 1))),
let f_p = Float2::expr(
p.read(index(Uint2::expr(i.uint() + 1, j.uint())))
- p.read(index(Uint2::expr(i.uint() - 1, j.uint()))),
p.read(index(Uint2::expr(i.uint(), j.uint() + 1)))
- p.read(index(Uint2::expr(i.uint(), j.uint() - 1))),
) * 0.5f32;

u.write(ij, u.read(ij) - f_p);
Expand All @@ -186,7 +186,7 @@ fn main() {
let ij = index(coord);

// gravity
let f_g = make_float2(-90.8f32, 0.0f32) * rho.read(ij);
let f_g = Float2::expr(-90.8f32, 0.0f32) * rho.read(ij);

// integrate
u.write(ij, u.read(ij) + dt * f_g);
Expand All @@ -201,7 +201,7 @@ fn main() {
let i = coord.x().int();
let j = coord.y().int();
let ij = index(coord);
let d = make_float2((i - N_GRID / 2).float(), (j - N_GRID / 2).float()).length();
let d = Float2::expr((i - N_GRID / 2).float(), (j - N_GRID / 2).float()).length();

let radius = 5.0f32;
if_!(d.cmplt(radius), {
Expand All @@ -212,8 +212,8 @@ fn main() {

let init_grid = device.create_kernel_async::<fn()>(&|| {
let idx = index(dispatch_id().xy());
u0.var().write(idx, make_float2(0.0f32, 0.0f32));
u1.var().write(idx, make_float2(0.0f32, 0.0f32));
u0.var().write(idx, Float2::expr(0.0f32, 0.0f32));
u1.var().write(idx, Float2::expr(0.0f32, 0.0f32));

rho0.var().write(idx, 0.0f32);
rho1.var().write(idx, 0.0f32);
Expand All @@ -234,8 +234,8 @@ fn main() {
let ij = index(coord);
let value = rho0.var().read(ij);
display.var().write(
make_uint2(coord.x(), (N_GRID - 1) as u32 - coord.y()),
make_float4(value, 0.0f32, 0.0f32, 1.0f32),
Uint2::expr(coord.x(), (N_GRID - 1) as u32 - coord.y()),
Float4::expr(value, 0.0f32, 0.0f32, 1.0f32),
);
});

Expand Down
26 changes: 13 additions & 13 deletions luisa_compute/examples/mpm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ fn main() {

let index = |xy: Expr<Uint2>| -> Expr<u32> {
let p = xy.clamp(
make_uint2(0, 0),
make_uint2(N_GRID as u32 - 1, N_GRID as u32 - 1),
Uint2::expr(0, 0),
Uint2::expr(N_GRID as u32 - 1, N_GRID as u32 - 1),
);
p.x() + p.y() * N_GRID as u32
};
Expand All @@ -113,11 +113,11 @@ fn main() {
];
let stress = -4.0f32 * DT * E * P_VOL * (J.var().read(p) - 1.0f32) / (DX * DX);
let affine =
Expr::<Mat2>::eye(make_float2(stress, stress)) + P_MASS as f32 * C.var().read(p);
Expr::<Mat2>::eye(Float2::expr(stress, stress)) + P_MASS as f32 * C.var().read(p);
let vp = v.var().read(p);
for ii in 0..9 {
let (i, j) = (ii % 3, ii / 3);
let offset = make_int2(i as i32, j as i32);
let offset = Int2::expr(i as i32, j as i32);
let dpos = (offset.float() - fx) * DX;
let weight = w[i].x() * w[j].y();
let vadd = weight * (P_MASS * vp + affine * dpos);
Expand All @@ -132,7 +132,7 @@ fn main() {
let coord = dispatch_id().xy();
let i = index(coord);
let v = var!(Float2);
v.store(make_float2(
v.store(Float2::expr(
grid_v.var().read(i * 2u32),
grid_v.var().read(i * 2u32 + 1u32),
));
Expand Down Expand Up @@ -170,15 +170,15 @@ fn main() {
];
let new_v = var!(Float2);
let new_C = var!(Mat2);
new_v.store(make_float2(0.0f32, 0.0f32));
new_C.store(make_float2x2(make_float2(0., 0.), make_float2(0., 0.)));
new_v.store(Float2::expr(0.0f32, 0.0f32));
new_C.store(Mat2::expr(Float2::expr(0., 0.), Float2::expr(0., 0.)));
for ii in 0..9 {
let (i, j) = (ii % 3, ii / 3);
let offset = make_int2(i as i32, j as i32);
let offset = Int2::expr(i as i32, j as i32);
let dpos = (offset.float() - fx) * DX;
let weight = w[i].x() * w[j].y();
let idx = index((base + offset).uint());
let g_v = make_float2(
let g_v = Float2::expr(
grid_v.var().read(idx * 2u32),
grid_v.var().read(idx * 2u32 + 1u32),
);
Expand All @@ -195,23 +195,23 @@ fn main() {
let clear_display = device.create_kernel_async::<fn()>(&|| {
display.var().write(
dispatch_id().xy(),
make_float4(0.1f32, 0.2f32, 0.3f32, 1.0f32),
Float4::expr(0.1f32, 0.2f32, 0.3f32, 1.0f32),
);
});
let draw_particles = device.create_kernel_async::<fn()>(&|| {
let p = dispatch_id().x();
for i in -1..=1 {
for j in -1..=1 {
let pos = (x.var().read(p) * RESOLUTION as f32).int() + make_int2(i, j);
let pos = (x.var().read(p) * RESOLUTION as f32).int() + Int2::expr(i, j);
if_!(
pos.x().cmpge(0i32)
& pos.x().cmplt(RESOLUTION as i32)
& pos.y().cmpge(0i32)
& pos.y().cmplt(RESOLUTION as i32),
{
display.var().write(
make_uint2(pos.x().uint(), RESOLUTION - 1u32 - pos.y().uint()),
make_float4(0.4f32, 0.6f32, 0.6f32, 1.0f32),
Uint2::expr(pos.x().uint(), RESOLUTION - 1u32 - pos.y().uint()),
Float4::expr(0.4f32, 0.6f32, 0.6f32, 1.0f32),
);
}
);
Expand Down
61 changes: 31 additions & 30 deletions luisa_compute/examples/path_tracer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,10 @@ fn main() {

let generate_ray = |p: Expr<Float2>| -> Expr<Ray> {
const FOV: f32 = 27.8f32 * std::f32::consts::PI / 180.0f32;
let origin = make_float3(-0.01f32, 0.995f32, 5.0f32);
let origin = Float3::expr(-0.01f32, 0.995f32, 5.0f32);

let pixel = origin
+ make_float3(
+ Float3::expr(
p.x() * f32::tan(0.5f32 * FOV),
p.y() * f32::tan(0.5f32 * FOV),
-1.0f32,
Expand All @@ -304,9 +304,9 @@ fn main() {
let make_onb = |normal: Expr<Float3>| -> Expr<Onb> {
let binormal = if_!(
normal.x().abs().cmpgt(normal.z().abs()), {
make_float3(-normal.y(), normal.x(), 0.0f32)
Float3::expr(-normal.y(), normal.x(), 0.0f32)
}, else {
make_float3(0.0f32, -normal.z(), normal.y())
Float3::expr(0.0f32, -normal.z(), normal.y())
}
);
let tangent = binormal.cross(normal).normalize();
Expand All @@ -316,7 +316,7 @@ fn main() {
let cosine_sample_hemisphere = |u: Expr<Float2>| {
let r = u.x().sqrt();
let phi = 2.0f32 * std::f32::consts::PI * u.y();
make_float3(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt())
Float3::expr(r * phi.cos(), r * phi.sin(), (1.0f32 - u.x()).sqrt())
};

let coord = dispatch_id().xy();
Expand All @@ -327,24 +327,24 @@ fn main() {
let rx = lcg(state);
let ry = lcg(state);

let pixel = (coord.float() + make_float2(rx, ry)) / frame_size * 2.0f32 - 1.0f32;
let pixel = (coord.float() + Float2::expr(rx, ry)) / frame_size * 2.0f32 - 1.0f32;

let radiance = var!(Float3);
radiance.store(make_float3(0.0f32, 0.0f32, 0.0f32));
radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32));
for_range(0..SPP_PER_DISPATCH as u32, |_| {
let init_ray = generate_ray(pixel * make_float2(1.0f32, -1.0f32));
let init_ray = generate_ray(pixel * Float2::expr(1.0f32, -1.0f32));
let ray = var!(Ray);
ray.store(init_ray);

let beta = var!(Float3);
beta.store(make_float3(1.0f32, 1.0f32, 1.0f32));
beta.store(Float3::expr(1.0f32, 1.0f32, 1.0f32));
let pdf_bsdf = var!(f32);
pdf_bsdf.store(0.0f32);

let light_position = make_float3(-0.24f32, 1.98f32, 0.16f32);
let light_u = make_float3(-0.24f32, 1.98f32, -0.22f32) - light_position;
let light_v = make_float3(0.23f32, 1.98f32, 0.16f32) - light_position;
let light_emission = make_float3(17.0f32, 12.0f32, 4.0f32);
let light_position = Float3::expr(-0.24f32, 1.98f32, 0.16f32);
let light_u = Float3::expr(-0.24f32, 1.98f32, -0.22f32) - light_position;
let light_v = Float3::expr(0.23f32, 1.98f32, 0.16f32) - light_position;
let light_emission = Float3::expr(17.0f32, 12.0f32, 4.0f32);
let light_area = light_u.cross(light_v).length();
let light_normal = light_u.cross(light_v).normalize();

Expand Down Expand Up @@ -414,13 +414,13 @@ fn main() {
let onb = make_onb(n);
let ux = lcg(state);
let uy = lcg(state);
let new_direction = onb.to_world(cosine_sample_hemisphere(make_float2(ux, uy)));
let new_direction = onb.to_world(cosine_sample_hemisphere(Float2::expr(ux, uy)));
*ray.get_mut() = make_ray(pp, new_direction, 0.0f32.into(), std::f32::MAX.into());
*beta.get_mut() *= albedo;
pdf_bsdf.store(cos_wi * std::f32::consts::FRAC_1_PI);

// russian roulette
let l = make_float3(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta);
let l = Float3::expr(0.212671f32, 0.715160f32, 0.072169f32).dot(*beta);
if_!(l.cmpeq(0.0f32), { break_(); });
let q = l.max(0.05f32);
let r = lcg(state);
Expand All @@ -432,28 +432,29 @@ fn main() {
});
radiance.store(radiance.load() / SPP_PER_DISPATCH as f32);
seed_image.write(coord, *state);
if_!(radiance.load().is_nan().any(), { radiance.store(make_float3(0.0f32, 0.0f32, 0.0f32)); });
if_!(radiance.load().is_nan().any(), { radiance.store(Float3::expr(0.0f32, 0.0f32, 0.0f32)); });
let radiance = radiance.load().clamp(0.0f32, 30.0f32);
let old = image.read(dispatch_id().xy());
let spp = old.w();
let radiance = radiance + old.xyz();
image.write(dispatch_id().xy(), make_float4(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32));
image.write(dispatch_id().xy(), Float4::expr(radiance.x(), radiance.y(), radiance.z(), spp + 1.0f32));
},
)
;
let display = device.create_kernel_async::<fn(Tex2d<Float4>, Tex2d<Float4>)>(&|acc, display| {
set_block_size([16, 16, 1]);
let coord = dispatch_id().xy();
let radiance = acc.read(coord);
let spp = radiance.w();
let radiance = radiance.xyz() / spp;

// workaround a rust-analyzer bug
let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055;

let srgb = Float3Expr::select(radiance.cmplt(0.0031308), radiance * 12.92, r);
display.write(coord, make_float4(srgb.x(), srgb.y(), srgb.z(), 1.0f32));
});
let display =
device.create_kernel_async::<fn(Tex2d<Float4>, Tex2d<Float4>)>(&|acc, display| {
set_block_size([16, 16, 1]);
let coord = dispatch_id().xy();
let radiance = acc.read(coord);
let spp = radiance.w();
let radiance = radiance.xyz() / spp;

// workaround a rust-analyzer bug
let r = 1.055f32 * radiance.powf(1.0 / 2.4) - 0.055;

let srgb = Float3Expr::select(radiance.cmplt(0.0031308), radiance * 12.92, r);
display.write(coord, Float4::expr(srgb.x(), srgb.y(), srgb.z(), 1.0f32));
});
let img_w = 1024;
let img_h = 1024;
let acc_img = device.create_tex2d::<Float4>(PixelStorage::Float4, img_w, img_h, 1);
Expand Down
Loading