Expr Structs Refactor #12

Merged · 15 commits · Sep 21, 2023
25 changes: 15 additions & 10 deletions luisa_compute/Cargo.toml
@@ -7,33 +7,36 @@ version = "0.1.1-alpha.1"
base64ct = { version = "1.5.0", features = ["alloc"] }
bumpalo = "3.12.0"
env_logger = "0.10.0"
glam = "0.24.0"
half = "2.2.1"

lazy_static = "1.4.0"
libc = "0.2"
libloading = "0.8"
log = "0.4"
luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" }
luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" }
luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" }
luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" }
luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" }
luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" }
luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" }
parking_lot = "0.12.1"
rayon = "1.6.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sha2 = "0.10"
winit = "0.28.3"
raw-window-handle = "0.5.1"
indexmap = "1.9.3"
indexmap = "2.0.0"

luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" }
luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" }
luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" }
luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" }
luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" }
luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" }
luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" }

glam = { version = "0.24.0", optional = true }
nalgebra = { version = "0.32.3", optional = true }

[dev-dependencies]
rand = "0.8.5"
image = "0.24.5"
tobj = "3.2.5"
tobj = "4.0.0"

[features]
default = ["remote", "cuda", "cpu", "metal", "dx"]
@@ -43,3 +46,5 @@ dx = ["luisa_compute_sys/dx"]
strict = ["luisa_compute_sys/strict"]
remote = ["luisa_compute_sys/remote"]
cpu = ["luisa_compute_sys/cpu"]
glam = ["dep:glam"]
nalgebra = ["dep:nalgebra"]
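The `dep:` syntax keeps `glam` and `nalgebra` as pure opt-ins: neither crate is compiled unless a downstream user enables the matching feature. A minimal sketch of feature-gated interop code, with a hypothetical helper that is not necessarily what this PR ships:

// Hypothetical sketch: compiled only when the `glam` feature (and thus
// the optional glam dependency) is enabled.
#[cfg(feature = "glam")]
pub fn vec3_to_array(v: glam::Vec3) -> [f32; 3] {
    v.to_array() // to_array is glam's own public API
}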
73 changes: 67 additions & 6 deletions luisa_compute/src/lang.rs
@@ -29,12 +29,65 @@ pub mod debug;
pub mod diff;
pub mod functions;
pub mod index;
pub mod maybe_expr;
pub mod ops;
pub mod poly;
pub mod swizzle;
pub mod types;

pub(crate) trait CallFuncTrait {
fn call<T: Value, S: Value>(self, x: Expr<T>) -> Expr<S>;
fn call2<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>) -> Expr<U>;
fn call3<T: Value, S: Value, U: Value, V: Value>(
self,
x: Expr<T>,
y: Expr<S>,
z: Expr<U>,
) -> Expr<V>;
fn call_void<T: Value>(self, x: Expr<T>);
fn call2_void<T: Value, S: Value>(self, x: Expr<T>, y: Expr<S>);
fn call3_void<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>, z: Expr<U>);
}
impl CallFuncTrait for Func {
fn call<T: Value, S: Value>(self, x: Expr<T>) -> Expr<S> {
Expr::<S>::from_node(__current_scope(|b| {
b.call(self, &[x.node()], <S as TypeOf>::type_())
}))
}
fn call2<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>) -> Expr<U> {
Expr::<U>::from_node(__current_scope(|b| {
b.call(self, &[x.node(), y.node()], <U as TypeOf>::type_())
}))
}
fn call3<T: Value, S: Value, U: Value, V: Value>(
self,
x: Expr<T>,
y: Expr<S>,
z: Expr<U>,
) -> Expr<V> {
Expr::<V>::from_node(__current_scope(|b| {
b.call(
self,
&[x.node(), y.node(), z.node()],
<V as TypeOf>::type_(),
)
}))
}
fn call_void<T: Value>(self, x: Expr<T>) {
__current_scope(|b| {
b.call(self, &[x.node()], Type::void());
});
}
fn call2_void<T: Value, S: Value>(self, x: Expr<T>, y: Expr<S>) {
__current_scope(|b| {
b.call(self, &[x.node(), y.node()], Type::void());
});
}
fn call3_void<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>, z: Expr<U>) {
__current_scope(|b| {
b.call(self, &[x.node(), y.node(), z.node()], Type::void());
});
}
}
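These helpers fold the repeated `__current_scope`/`b.call` plumbing into a single trait, so lowering an intrinsic is one line per arity. A sketch of the intended call-site shape; `Func::Sqrt` is an assumed unary variant of `luisa_compute_ir`'s `Func`, and any other unary variant lowers the same way:

fn sqrt_expr(x: Expr<f32>) -> Expr<f32> {
    // Builds the IR call in the current scope and wraps the resulting
    // node back into a typed Expr, as `call` does above.
    Func::Sqrt.call(x)
}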

#[allow(dead_code)]
pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0);
// prevent node being shared across kernels
@@ -135,10 +188,19 @@ pub trait ToNode {
fn node(&self) -> NodeRef;
}

pub trait FromNode: ToNode {
pub trait NodeLike: FromNode + ToNode {}
impl<T> NodeLike for T where T: FromNode + ToNode {}

pub trait FromNode {
fn from_node(node: NodeRef) -> Self;
}

impl<T: Default> FromNode for T {
fn from_node(_: NodeRef) -> Self {
Default::default()
}
}

fn _store<T1: Aggregate, T2: Aggregate>(var: &T1, value: &T2) {
let value_nodes = value.to_vec_nodes();
let self_nodes = var.to_vec_nodes();
@@ -407,12 +469,11 @@ pub const fn packed_size<T: Value>() -> usize {
(std::mem::size_of::<T>() + 3) / 4
}

pub fn pack_to<E, B>(expr: E, buffer: &B, index: impl Into<Expr<u32>>)
pub fn pack_to<V: Value, B>(expr: Expr<V>, buffer: &B, index: impl AsExpr<Value = u32>)
where
E: ExprProxy,
B: IndexWrite<Element = u32>,
{
let index = index.into();
let index = index.as_expr();
__current_scope(|b| {
b.call(
Func::Pack,
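With the new signature, callers pass a typed `Expr<V>` plus anything implementing `AsExpr<Value = u32>` as the index, so no explicit conversion is needed. A hedged usage sketch; `BufferVar<u32>` standing in as the `IndexWrite` target and literals satisfying `AsExpr` are both assumptions:

fn store_packed(value: Expr<f32>, out: &BufferVar<u32>) {
    // The literal index is lifted to Expr<u32> by as_expr() internally.
    pack_to(value, out, 0u32);
}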
37 changes: 12 additions & 25 deletions luisa_compute/src/lang/control_flow.rs
@@ -4,45 +4,30 @@ use crate::internal_prelude::*;
use ir::SwitchCase;

/**
* If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) or if_!(cond, { .. }, else, {...})
* instead of if_!(cond, { .. }, else {...}).
* If you want rustfmt to format your code, use if_!(cond, { .. }, { .. })
* or if_!(cond, { .. }, else, {...}) instead of if_!(cond, { .. }, else
* {...}).
*
*/
#[macro_export]
macro_rules! if_ {
($cond:expr, $then:block, else $else_:block) => {
<_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else(
$cond,
|| $then,
|| $else_,
)
<_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_)
};
($cond:expr, $then:block, else, $else_:block) => {
<_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else(
$cond,
|| $then,
|| $else_,
)
<_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_)
};
($cond:expr, $then:block, $else_:block) => {
<_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else(
$cond,
|| $then,
|| $else_,
)
<_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_)
};
($cond:expr, $then:block) => {
<_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else(
$cond,
|| $then,
|| {},
)
<_ as $crate::lang::ops::ActivateMaybeExpr>::activate($cond, || $then)
};
}
#[macro_export]
macro_rules! while_ {
($cond:expr,$body:block) => {
<_ as $crate::lang::maybe_expr::BoolWhileMaybeExpr>::while_loop(|| $cond, || $body)
<_ as $crate::lang::ops::LoopMaybeExpr>::while_loop(|| $cond, || $body)
};
}
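Both macros now dispatch through the `SelectMaybeExpr`, `ActivateMaybeExpr`, and `LoopMaybeExpr` traits in `lang::ops`, so one surface syntax covers host `bool` and device `Expr<bool>` conditions alike. A minimal sketch, reusing `true.expr()` as `loop_` does below; `break_()` is assumed to exist alongside the `continue_()` shown further down:

if_!(true.expr(), {
    // then-branch, recorded into the kernel
}, else, {
    // else-branch; the comma after `else` keeps rustfmt working,
    // per the doc comment above
});
while_!(true.expr(), {
    break_(); // placeholder body: exit on the first iteration
});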
#[macro_export]
@@ -66,7 +51,7 @@ pub fn continue_() {
});
}

pub fn return_v<T: FromNode>(v: T) {
pub fn return_v<T: NodeLike>(v: T) {
RECORDER.with(|r| {
let mut r = r.borrow_mut();
if r.callable_ret_type.is_none() {
@@ -296,7 +281,9 @@ impl_range!(u32);
impl_range!(u64);

pub fn loop_(body: impl Fn()) {
while_!(true.expr(), { body(); });
while_!(true.expr(), {
body();
});
}

pub fn for_range<R: ForLoopRange>(r: R, body: impl Fn(Expr<R::Element>)) {
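Because `impl_range!` provides `ForLoopRange` for the integer types above, a counted device-side loop reads like a host loop. A small sketch, assuming plain `Range<u32>` is the type those impls target:

for_range(0u32..16u32, |i| {
    // i: Expr<u32>, the device-side loop counter
    let _ = i;
});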
11 changes: 4 additions & 7 deletions luisa_compute/src/lang/debug.rs
@@ -31,7 +31,7 @@ impl<T: Value> CpuFn<T> {
_marker: PhantomData,
}
}
pub fn call(&self, arg: impl ExprProxy<Value = T>) -> Expr<T> {
pub fn call(&self, arg: impl AsExpr<Value = T>) -> Expr<T> {
RECORDER.with(|r| {
let mut r = r.borrow_mut();
assert!(r.lock);
@@ -58,7 +58,7 @@
Expr::<T>::from_node(__current_scope(|b| {
b.call(
Func::CpuCustomOp(self.op.clone()),
&[arg.node()],
&[arg.as_expr().node()],
T::type_(),
)
}))
@@ -107,14 +107,11 @@ macro_rules! lc_assert {
$crate::lang::debug::__assert($arg, $msg, file!(), line!(), column!())
};
}
pub fn __cpu_dbg<T: ExprProxy>(arg: T, file: &'static str, line: u32)
where
T::Value: Debug,
{
pub fn __cpu_dbg<V: Value + Debug>(arg: Expr<V>, file: &'static str, line: u32) {
if !is_cpu_backend() {
return;
}
let f = CpuFn::new(move |x: &mut T::Value| {
let f = CpuFn::new(move |x: &mut V| {
println!("[{}:{}] {:?}", file, line, x);
});
let _ = f.call(arg);
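`CpuFn::call` now accepts `impl AsExpr<Value = T>`, matching the `__cpu_dbg` rewrite above. A sketch of a custom host callback in the same style (CPU backend assumed, as the `is_cpu_backend` guard suggests):

fn double_on_host(x: Expr<f32>) -> Expr<f32> {
    // The closure runs on the host when the kernel reaches this op,
    // mutating the value in place, like the println! in __cpu_dbg.
    let f = CpuFn::new(|v: &mut f32| {
        *v *= 2.0;
    });
    f.call(x)
}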
31 changes: 16 additions & 15 deletions luisa_compute/src/lang/diff.rs
@@ -34,7 +34,7 @@ impl AdContext {
thread_local! {
static AD_CONTEXT:RefCell<AdContext> = RefCell::new(AdContext::new_rev());
}
pub fn requires_grad(var: impl ExprProxy) {
pub fn requires_grad<V: Value>(var: Expr<V>) {
AD_CONTEXT.with(|c| {
let c = c.borrow();
assert!(c.started, "autodiff section is not started");
@@ -49,15 +49,15 @@ pub fn requires_grad(var: impl ExprProxy) {
});
}

pub fn backward<T: ExprProxy>(out: T) {
pub fn backward<V: Value>(out: Expr<V>) {
backward_with_grad(
out,
FromNode::from_node(__current_scope(|b| {
let one = new_node(
b.pools(),
Node::new(
CArc::new(Instruction::Const(Const::One(<T::Value>::type_()))),
<T::Value>::type_(),
CArc::new(Instruction::Const(Const::One(V::type_()))),
V::type_(),
),
);
b.append(one);
@@ -66,7 +66,7 @@ pub fn backward<T: ExprProxy>(out: T) {
);
}

pub fn backward_with_grad<T: ExprProxy>(out: T, grad: T) {
pub fn backward_with_grad<V: Value>(out: Expr<V>, grad: Expr<V>) {
AD_CONTEXT.with(|c| {
let mut c = c.borrow_mut();
assert!(c.started, "autodiff section is not started");
@@ -83,19 +83,19 @@ pub fn backward_with_grad<T: ExprProxy>(out: T, grad: T) {
}

/// Gradient of a value in *Reverse mode* AD
pub fn gradient<T: ExprProxy>(var: T) -> T {
pub fn gradient<V: Value>(var: Expr<V>) -> Expr<V> {
AD_CONTEXT.with(|c| {
let c = c.borrow();
assert!(c.started, "autodiff section is not started");
assert!(!c.is_forward_mode, "gradient() is called in forward mode");
assert!(c.backward_called, "backward is not called");
});
T::from_node(__current_scope(|b| {
Expr::<V>::from_node(__current_scope(|b| {
b.call(Func::Gradient, &[var.node()], var.node().type_().clone())
}))
}
/// Gradient of a value in *Reverse mode* AD
pub fn grad<T: ExprProxy>(var: T) -> T {
pub fn grad<V: Value>(var: Expr<V>) -> Expr<V> {
gradient(var)
}
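With the `ExprProxy` bounds replaced by plain `Expr<V>`, a reverse-mode section reads directly off typed expressions. A hedged sketch: the `autodiff(..)` entry point is an assumption implied by the `c.started` asserts, `x` is a hypothetical `Expr<f32>`, and the arithmetic sugar on `Expr` is assumed from the rest of the crate:

autodiff(|| {
    requires_grad(x);
    let y = x * x;
    backward(y); // seeds the output gradient with the Const::One node above
    let dx = gradient(x); // identical to grad(x)
    let _ = dx;
});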

@@ -108,8 +108,8 @@ pub fn grad<T: ExprProxy>(var: T) -> T {
// let ret = body();
// let fwd = pop_scope();
// __current_scope(|b| {
// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), Type::void()));
// b.append(node);
// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)),
// Type::void())); b.append(node);
// });
// let nodes = ret.to_vec_nodes();
// let nodes: Vec<_> = nodes
@@ -118,13 +118,14 @@ pub fn grad<T: ExprProxy>(var: T) -> T {
// .collect();
// R::from_vec_nodes(nodes)
// }
pub fn detach<T: FromNode>(v: T) -> T {
pub fn detach<T: NodeLike>(v: T) -> T {
let v = v.node();
let node = __current_scope(|b| b.call(Func::Detach, &[v], v.type_().clone()));
T::from_node(node)
}

/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input variable
/// Start a *Forward mode* AD section that propagates N gradients w.r.t. the
/// input variable
pub fn forward_autodiff(n_grads: usize, body: impl Fn()) {
AD_CONTEXT.with(|c| {
let mut c = c.borrow_mut();
@@ -152,7 +153,7 @@ pub fn forward_autodiff(n_grads: usize, body: impl Fn()) {
}

/// Propagate N gradients w.r.t to input variable using *Forward mode* AD
pub fn propagate_gradient<T: ExprProxy>(v: T, grads: &[T]) {
pub fn propagate_gradient<V: Value>(v: Expr<V>, grads: &[Expr<V>]) {
AD_CONTEXT.with(|c| {
let c = c.borrow();
assert_eq!(grads.len(), c.n_forward_grads);
@@ -169,7 +170,7 @@ pub fn propagate_gradient<T: ExprProxy>(v: T, grads: &[T]) {
});
}

pub fn output_gradients<T: ExprProxy>(v: T) -> Vec<T> {
pub fn output_gradients<V: Value>(v: Expr<V>) -> Vec<Expr<V>> {
let n = AD_CONTEXT.with(|c| {
let c = c.borrow();
assert!(c.started, "autodiff section is not started");
@@ -183,7 +184,7 @@ pub fn output_gradients<T: ExprProxy>(v: T) -> Vec<T> {
let mut grads = vec![];
for i in 0..n {
let idx = b.const_(Const::Int32(i as i32));
grads.push(T::from_node(b.call(
grads.push(Expr::<V>::from_node(b.call(
Func::OutputGrad,
&[v.node(), idx],
v.node().type_().clone(),
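Forward mode mirrors that flow: seed `n_grads` gradients on an input with `propagate_gradient`, then read the per-output gradients back. A hedged sketch; `x`, `dx0`, and `dx1` are hypothetical `Expr<f32>` values from the enclosing kernel:

forward_autodiff(2, || {
    propagate_gradient(x, &[dx0, dx1]); // two seeds, matching n_grads
    let y = x * x;
    let dys = output_gradients(y); // Vec<Expr<f32>> of length 2
    let _ = dys;
});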