diff --git a/luisa_compute/Cargo.toml b/luisa_compute/Cargo.toml index dbf69d2..1a0ca36 100644 --- a/luisa_compute/Cargo.toml +++ b/luisa_compute/Cargo.toml @@ -7,20 +7,12 @@ version = "0.1.1-alpha.1" base64ct = { version = "1.5.0", features = ["alloc"] } bumpalo = "3.12.0" env_logger = "0.10.0" -glam = "0.24.0" half = "2.2.1" lazy_static = "1.4.0" libc = "0.2" libloading = "0.8" log = "0.4" -luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" } -luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" } -luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" } -luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" } -luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" } -luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" } -luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" } parking_lot = "0.12.1" rayon = "1.6.0" serde = { version = "1.0", features = ["derive"] } @@ -28,12 +20,23 @@ serde_json = "1.0" sha2 = "0.10" winit = "0.28.3" raw-window-handle = "0.5.1" -indexmap = "1.9.3" +indexmap = "2.0.0" + +luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" } +luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" } +luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" } +luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" } +luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" } +luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" } +luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" } + +glam = { version = "0.24.0", optional = true } +nalgebra = { version = "0.32.3", optional = true } [dev-dependencies] rand = "0.8.5" image = "0.24.5" -tobj = "3.2.5" +tobj = "4.0.0" [features] default = ["remote", "cuda", "cpu", "metal", "dx"] @@ -43,3 +46,5 @@ dx = ["luisa_compute_sys/dx"] strict = ["luisa_compute_sys/strict"] remote = ["luisa_compute_sys/remote"] cpu = ["luisa_compute_sys/cpu"] +glam = ["dep:glam"] +nalgebra = ["dep:nalgebra"] diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index dae2d7e..4fdf9cb 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -29,12 +29,65 @@ pub mod debug; pub mod diff; pub mod functions; pub mod index; -pub mod maybe_expr; pub mod ops; pub mod poly; -pub mod swizzle; pub mod types; +pub(crate) trait CallFuncTrait { + fn call(self, x: Expr) -> Expr; + fn call2(self, x: Expr, y: Expr) -> Expr; + fn call3( + self, + x: Expr, + y: Expr, + z: Expr, + ) -> Expr; + fn call_void(self, x: Expr); + fn call2_void(self, x: Expr, y: Expr); + fn call3_void(self, x: Expr, y: Expr, z: Expr); +} +impl CallFuncTrait for Func { + fn call(self, x: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(self, &[x.node()], ::type_()) + })) + } + fn call2(self, x: Expr, y: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(self, &[x.node(), y.node()], ::type_()) + })) + } + fn call3( + self, + x: Expr, + y: Expr, + z: Expr, + ) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call( + self, + &[x.node(), y.node(), z.node()], + ::type_(), + ) + })) + } + fn call_void(self, x: Expr) { + __current_scope(|b| { + b.call(self, &[x.node()], Type::void()); + }); + } + fn call2_void(self, x: Expr, y: Expr) { + __current_scope(|b| { + b.call(self, &[x.node(), y.node()], Type::void()); + }); + } + fn call3_void(self, x: Expr, y: Expr, z: Expr) { + __current_scope(|b| { + b.call(self, &[x.node(), y.node(), z.node()], Type::void()); + }); + } +} + #[allow(dead_code)] pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0); // prevent node being shared across kernels @@ -135,10 +188,19 @@ pub trait ToNode { fn node(&self) -> NodeRef; } -pub trait FromNode: ToNode { +pub trait NodeLike: FromNode + ToNode {} +impl NodeLike for T where T: FromNode + ToNode {} + +pub trait FromNode { fn from_node(node: NodeRef) -> Self; } +impl FromNode for T { + fn from_node(_: NodeRef) -> Self { + Default::default() + } +} + fn _store(var: &T1, value: &T2) { let value_nodes = value.to_vec_nodes(); let self_nodes = var.to_vec_nodes(); @@ -407,12 +469,11 @@ pub const fn packed_size() -> usize { (std::mem::size_of::() + 3) / 4 } -pub fn pack_to(expr: E, buffer: &B, index: impl Into>) +pub fn pack_to(expr: Expr, buffer: &B, index: impl AsExpr) where - E: ExprProxy, B: IndexWrite, { - let index = index.into(); + let index = index.as_expr(); __current_scope(|b| { b.call( Func::Pack, diff --git a/luisa_compute/src/lang/control_flow.rs b/luisa_compute/src/lang/control_flow.rs index 9324695..06ca73f 100644 --- a/luisa_compute/src/lang/control_flow.rs +++ b/luisa_compute/src/lang/control_flow.rs @@ -4,45 +4,30 @@ use crate::internal_prelude::*; use ir::SwitchCase; /** - * If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) or if_!(cond, { .. }, else, {...}) - * instead of if_!(cond, { .. }, else {...}). + * If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) + * or if_!(cond, { .. }, else, {...}) instead of if_!(cond, { .. }, else + * {...}). * */ #[macro_export] macro_rules! if_ { ($cond:expr, $then:block, else $else_:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || $else_, - ) + <_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_) }; ($cond:expr, $then:block, else, $else_:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || $else_, - ) + <_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_) }; ($cond:expr, $then:block, $else_:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || $else_, - ) + <_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_) }; ($cond:expr, $then:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || {}, - ) + <_ as $crate::lang::ops::ActivateMaybeExpr>::activate($cond, || $then) }; } #[macro_export] macro_rules! while_ { ($cond:expr,$body:block) => { - <_ as $crate::lang::maybe_expr::BoolWhileMaybeExpr>::while_loop(|| $cond, || $body) + <_ as $crate::lang::ops::LoopMaybeExpr>::while_loop(|| $cond, || $body) }; } #[macro_export] @@ -66,7 +51,7 @@ pub fn continue_() { }); } -pub fn return_v(v: T) { +pub fn return_v(v: T) { RECORDER.with(|r| { let mut r = r.borrow_mut(); if r.callable_ret_type.is_none() { @@ -296,7 +281,9 @@ impl_range!(u32); impl_range!(u64); pub fn loop_(body: impl Fn()) { - while_!(true.expr(), { body(); }); + while_!(true.expr(), { + body(); + }); } pub fn for_range(r: R, body: impl Fn(Expr)) { diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs index 1f2c362..d2213e0 100644 --- a/luisa_compute/src/lang/debug.rs +++ b/luisa_compute/src/lang/debug.rs @@ -31,7 +31,7 @@ impl CpuFn { _marker: PhantomData, } } - pub fn call(&self, arg: impl ExprProxy) -> Expr { + pub fn call(&self, arg: impl AsExpr) -> Expr { RECORDER.with(|r| { let mut r = r.borrow_mut(); assert!(r.lock); @@ -58,7 +58,7 @@ impl CpuFn { Expr::::from_node(__current_scope(|b| { b.call( Func::CpuCustomOp(self.op.clone()), - &[arg.node()], + &[arg.as_expr().node()], T::type_(), ) })) @@ -107,14 +107,11 @@ macro_rules! lc_assert { $crate::lang::debug::__assert($arg, $msg, file!(), line!(), column!()) }; } -pub fn __cpu_dbg(arg: T, file: &'static str, line: u32) -where - T::Value: Debug, -{ +pub fn __cpu_dbg(arg: Expr, file: &'static str, line: u32) { if !is_cpu_backend() { return; } - let f = CpuFn::new(move |x: &mut T::Value| { + let f = CpuFn::new(move |x: &mut V| { println!("[{}:{}] {:?}", file, line, x); }); let _ = f.call(arg); diff --git a/luisa_compute/src/lang/diff.rs b/luisa_compute/src/lang/diff.rs index cb37fbb..ca454d9 100644 --- a/luisa_compute/src/lang/diff.rs +++ b/luisa_compute/src/lang/diff.rs @@ -34,7 +34,7 @@ impl AdContext { thread_local! { static AD_CONTEXT:RefCell = RefCell::new(AdContext::new_rev()); } -pub fn requires_grad(var: impl ExprProxy) { +pub fn requires_grad(var: Expr) { AD_CONTEXT.with(|c| { let c = c.borrow(); assert!(c.started, "autodiff section is not started"); @@ -49,15 +49,15 @@ pub fn requires_grad(var: impl ExprProxy) { }); } -pub fn backward(out: T) { +pub fn backward(out: Expr) { backward_with_grad( out, FromNode::from_node(__current_scope(|b| { let one = new_node( b.pools(), Node::new( - CArc::new(Instruction::Const(Const::One(::type_()))), - ::type_(), + CArc::new(Instruction::Const(Const::One(V::type_()))), + V::type_(), ), ); b.append(one); @@ -66,7 +66,7 @@ pub fn backward(out: T) { ); } -pub fn backward_with_grad(out: T, grad: T) { +pub fn backward_with_grad(out: Expr, grad: Expr) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); assert!(c.started, "autodiff section is not started"); @@ -83,19 +83,19 @@ pub fn backward_with_grad(out: T, grad: T) { } /// Gradient of a value in *Reverse mode* AD -pub fn gradient(var: T) -> T { +pub fn gradient(var: Expr) -> Expr { AD_CONTEXT.with(|c| { let c = c.borrow(); assert!(c.started, "autodiff section is not started"); assert!(!c.is_forward_mode, "gradient() is called in forward mode"); assert!(c.backward_called, "backward is not called"); }); - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call(Func::Gradient, &[var.node()], var.node().type_().clone()) })) } /// Gradient of a value in *Reverse mode* AD -pub fn grad(var: T) -> T { +pub fn grad(var: Expr) -> Expr { gradient(var) } @@ -108,8 +108,8 @@ pub fn grad(var: T) -> T { // let ret = body(); // let fwd = pop_scope(); // __current_scope(|b| { -// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), Type::void())); -// b.append(node); +// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), +// Type::void())); b.append(node); // }); // let nodes = ret.to_vec_nodes(); // let nodes: Vec<_> = nodes @@ -118,13 +118,14 @@ pub fn grad(var: T) -> T { // .collect(); // R::from_vec_nodes(nodes) // } -pub fn detach(v: T) -> T { +pub fn detach(v: T) -> T { let v = v.node(); let node = __current_scope(|b| b.call(Func::Detach, &[v], v.type_().clone())); T::from_node(node) } -/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input variable +/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input +/// variable pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); @@ -152,7 +153,7 @@ pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { } /// Propagate N gradients w.r.t to input variable using *Forward mode* AD -pub fn propagate_gradient(v: T, grads: &[T]) { +pub fn propagate_gradient(v: Expr, grads: &[Expr]) { AD_CONTEXT.with(|c| { let c = c.borrow(); assert_eq!(grads.len(), c.n_forward_grads); @@ -169,7 +170,7 @@ pub fn propagate_gradient(v: T, grads: &[T]) { }); } -pub fn output_gradients(v: T) -> Vec { +pub fn output_gradients(v: Expr) -> Vec> { let n = AD_CONTEXT.with(|c| { let c = c.borrow(); assert!(c.started, "autodiff section is not started"); @@ -183,7 +184,7 @@ pub fn output_gradients(v: T) -> Vec { let mut grads = vec![]; for i in 0..n { let idx = b.const_(Const::Int32(i as i32)); - grads.push(T::from_node(b.call( + grads.push(Expr::::from_node(b.call( Func::OutputGrad, &[v.node(), idx], v.node().type_().clone(), diff --git a/luisa_compute/src/lang/functions.rs b/luisa_compute/src/lang/functions.rs index 23be6fa..a5d057d 100644 --- a/luisa_compute/src/lang/functions.rs +++ b/luisa_compute/src/lang/functions.rs @@ -1,5 +1,7 @@ use crate::internal_prelude::*; +use super::types::core::{Integral, Numeric}; + pub fn thread_id() -> Expr { Expr::::from_node(__current_scope(|b| { b.call(Func::ThreadId, &[], Uint3::type_()) @@ -51,10 +53,10 @@ pub fn sync_block() { pub fn warp_is_first_active_lane() -> Expr { Expr::::from_node(__current_scope(|b| { - b.call(Func::WarpIsFirstActiveLane, &[], Expr::::type_()) + b.call(Func::WarpIsFirstActiveLane, &[], bool::type_()) })) } -pub fn warp_active_all_equal(v: impl ScalarOrVector) -> Expr { +pub fn warp_active_all_equal(v: Expr) -> Expr { Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveAllEqual, @@ -63,8 +65,11 @@ pub fn warp_active_all_equal(v: impl ScalarOrVector) -> Expr { ) })) } -pub fn warp_active_bit_and, E: IntVarTrait>(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_active_bit_and(v: Expr) -> Expr +where + T::Scalar: Integral + Numeric, +{ + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitAnd, &[v.node()], @@ -73,8 +78,11 @@ pub fn warp_active_bit_and, E: IntVarTrait>(v: T) })) } -pub fn warp_active_bit_or, E: IntVarTrait>(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_active_bit_or(v: Expr) -> Expr +where + T::Scalar: Integral + Numeric, +{ + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitOr, &[v.node()], @@ -83,8 +91,11 @@ pub fn warp_active_bit_or, E: IntVarTrait>(v: T) })) } -pub fn warp_active_bit_xor, E: IntVarTrait>(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_active_bit_xor(v: Expr) -> Expr +where + T::Scalar: Integral + Numeric, +{ + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitXor, &[v.node()], @@ -93,37 +104,33 @@ pub fn warp_active_bit_xor, E: IntVarTrait>(v: T) })) } -pub fn warp_active_count_bits(v: impl Into>) -> Expr { +pub fn warp_active_count_bits(v: impl AsExpr) -> Expr { Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveCountBits, - &[v.into().node()], + &[v.as_expr().node()], ::type_(), ) })) } -pub fn warp_active_max(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveMax, &[v.node()], ::type_()) +pub fn warp_active_max(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveMax, &[v.node()], ::type_()) })) } -pub fn warp_active_min(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveMin, &[v.node()], ::type_()) +pub fn warp_active_min(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveMin, &[v.node()], ::type_()) })) } -pub fn warp_active_product(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveProduct, - &[v.node()], - ::type_(), - ) +pub fn warp_active_product(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveProduct, &[v.node()], ::type_()) })) } -pub fn warp_active_sum(v: T) -> T::Element { - ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveSum, &[v.node()], ::type_()) +pub fn warp_active_sum(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveSum, &[v.node()], ::type_()) })) } pub fn warp_active_all(v: Expr) -> Expr { @@ -150,13 +157,13 @@ pub fn warp_prefix_count_bits(v: Expr) -> Expr { ) })) } -pub fn warp_prefix_sum_exclusive(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_prefix_sum_exclusive(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { b.call(Func::WarpPrefixSum, &[v.node()], v.node().type_().clone()) })) } -pub fn warp_prefix_product_exclusive(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_prefix_product_exclusive(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpPrefixProduct, &[v.node()], @@ -164,9 +171,10 @@ pub fn warp_prefix_product_exclusive(v: T) -> T { ) })) } -pub fn warp_read_lane_at(v: T, index: impl Into>) -> T { - let index = index.into(); - T::from_node(__current_scope(|b| { +// TODO: Difference between `Linear` and BuiltinVarTrait? +pub fn warp_read_lane_at(v: Expr, index: impl AsExpr) -> Expr { + let index = index.as_expr(); + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpReadLaneAt, &[v.node(), index.node()], @@ -174,8 +182,8 @@ pub fn warp_read_lane_at(v: T, index: impl Into>) ) })) } -pub fn warp_read_first_active_lane(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_read_first_active_lane(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpReadFirstLane, &[v.node()], diff --git a/luisa_compute/src/lang/index.rs b/luisa_compute/src/lang/index.rs index 31d84d3..d83a0c2 100644 --- a/luisa_compute/src/lang/index.rs +++ b/luisa_compute/src/lang/index.rs @@ -25,7 +25,7 @@ impl IntoIndex for u64 { } impl IntoIndex for Expr { fn to_u64(&self) -> Expr { - self.ulong() + self.cast::() } } impl IntoIndex for Expr { @@ -40,5 +40,5 @@ pub trait IndexRead: ToNode { } pub trait IndexWrite: IndexRead { - fn write>>(&self, i: I, value: V); + fn write>(&self, i: I, value: V); } diff --git a/luisa_compute/src/lang/maybe_expr.rs b/luisa_compute/src/lang/maybe_expr.rs deleted file mode 100644 index 2654889..0000000 --- a/luisa_compute/src/lang/maybe_expr.rs +++ /dev/null @@ -1,256 +0,0 @@ -//! The purpose of this module is to provide traits to represent things that may -//! either be an expression or a normal value. This is necessary for making the -//! trace macro work for both types of value. - -use super::control_flow::{generic_loop, if_then_else}; -use super::types::core::*; -use crate::internal_prelude::*; - -pub trait BoolIfElseMaybeExpr { - fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R; -} -impl BoolIfElseMaybeExpr for bool { - fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R { - if self { - then() - } else { - else_() - } - } -} -impl BoolIfElseMaybeExpr for Bool { - fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R { - if_then_else(self, then, else_) - } -} - -pub trait BoolIfMaybeExpr { - fn if_then(self, then: impl FnOnce()); -} -impl BoolIfMaybeExpr for bool { - fn if_then(self, then: impl FnOnce()) { - if self { - then() - } - } -} -impl BoolIfMaybeExpr for Bool { - fn if_then(self, then: impl FnOnce()) { - if_then_else(self, then, || {}) - } -} - -pub trait BoolWhileMaybeExpr { - fn while_loop(this: impl FnMut() -> Self, body: impl FnMut()); -} -impl BoolWhileMaybeExpr for bool { - fn while_loop(mut this: impl FnMut() -> Self, mut body: impl FnMut()) { - while this() { - body() - } - } -} -impl BoolWhileMaybeExpr for Bool { - fn while_loop(this: impl FnMut() -> Self, body: impl FnMut()) { - generic_loop(this, body, || {}); - } -} - -// TODO: Support lazy expressions if that isn't done already? -pub trait BoolLazyOpsMaybeExpr { - type Ret; - fn and(self, other: impl FnOnce() -> R) -> Self::Ret; - fn or(self, other: impl FnOnce() -> R) -> Self::Ret; -} -impl BoolLazyOpsMaybeExpr for bool { - type Ret = bool; - fn and(self, other: impl FnOnce() -> bool) -> Self::Ret { - self && other() - } - fn or(self, other: impl FnOnce() -> bool) -> Self::Ret { - self || other() - } -} -impl BoolLazyOpsMaybeExpr for bool { - type Ret = Bool; - fn and(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self & other() - } - fn or(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self | other() - } -} -impl BoolLazyOpsMaybeExpr for Bool { - type Ret = Bool; - fn and(self, other: impl FnOnce() -> bool) -> Self::Ret { - self & other() - } - fn or(self, other: impl FnOnce() -> bool) -> Self::Ret { - self | other() - } -} -impl BoolLazyOpsMaybeExpr for Bool { - type Ret = Bool; - fn and(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self & other() - } - fn or(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self | other() - } -} - -pub trait EqMaybeExpr { - type Bool; - fn eq(self, other: X) -> Self::Bool; - fn ne(self, other: X) -> Self::Bool; -} -impl> EqMaybeExpr for A { - type Bool = bool; - fn eq(self, other: R) -> Self::Bool { - self == other - } - fn ne(self, other: R) -> Self::Bool { - self != other - } -} -macro_rules! impl_eme { - ($t: ty, $s: ty) => { - impl EqMaybeExpr<$s> for $t { - type Bool = <$t as VarTrait>::Bool; - fn eq(self, other: $s) -> Self::Bool { - self.cmpeq(other) - } - fn ne(self, other: $s) -> Self::Bool { - self.cmpne(other) - } - } - }; -} -macro_rules! impl_mem { - ($t: ty, $s: ty) => { - impl EqMaybeExpr<$s> for $t { - type Bool = <$s as VarTrait>::Bool; - fn eq(self, other: $s) -> Self::Bool { - other.cmpeq(self) - } - fn ne(self, other: $s) -> Self::Bool { - other.cmpne(self) - } - } - }; -} -macro_rules! emes { - ($x: ty $(, $y: ty)*) => { - impl_eme!(Expr<$x>, Expr<$x>); - impl_eme!(Expr<$x>, $x); - impl_mem!($x, Expr<$x>); - $(impl_eme!(Expr<$x>, $y); - impl_mem!($y, Expr<$x>);)* - }; -} -emes!(bool); -emes!(Bool2); -emes!(Bool3); -emes!(Bool4); - -pub trait PartialOrdMaybeExpr { - type Bool; - fn lt(self, other: R) -> Self::Bool; - fn le(self, other: R) -> Self::Bool; - fn ge(self, other: R) -> Self::Bool; - fn gt(self, other: R) -> Self::Bool; -} -impl> PartialOrdMaybeExpr for A { - type Bool = bool; - fn lt(self, other: R) -> Self::Bool { - self < other - } - fn le(self, other: R) -> Self::Bool { - self <= other - } - fn ge(self, other: R) -> Self::Bool { - self >= other - } - fn gt(self, other: R) -> Self::Bool { - self > other - } -} -macro_rules! impl_pome { - ($t: ty, $s: ty) => { - impl_eme!($t, $s); - impl PartialOrdMaybeExpr<$s> for $t { - type Bool = <$t as VarTrait>::Bool; - fn lt(self, other: $s) -> Self::Bool { - self.cmplt(other) - } - fn le(self, other: $s) -> Self::Bool { - self.cmple(other) - } - fn ge(self, other: $s) -> Self::Bool { - self.cmpge(other) - } - fn gt(self, other: $s) -> Self::Bool { - self.cmpgt(other) - } - } - }; -} -macro_rules! impl_emop { - ($t: ty, $s: ty) => { - impl_mem!($t, $s); - impl PartialOrdMaybeExpr<$s> for $t { - type Bool = <$s as VarTrait>::Bool; - fn lt(self, other: $s) -> Self::Bool { - other.cmpgt(self) - } - fn le(self, other: $s) -> Self::Bool { - other.cmpge(self) - } - fn ge(self, other: $s) -> Self::Bool { - other.cmplt(self) - } - fn gt(self, other: $s) -> Self::Bool { - other.cmplt(self) - } - } - }; -} -macro_rules! pomes { - ($x: ty $(, $y:ty)*) => { - impl_pome!(Expr<$x>, Expr<$x>); - impl_pome!(Expr<$x>, $x); - impl_emop!($x, Expr<$x>); - impl_pome!(Expr<$x>, Var<$x>); - impl_emop!(Var<$x>, Expr<$x>); - $(impl_pome!(Expr<$x>, $y); - impl_emop!($y, Expr<$x>);)* - }; -} -pomes!(f16); -pomes!(f32); -pomes!(f64); -pomes!(i16); -pomes!(i32); -pomes!(i64); -pomes!(u16); -pomes!(u32); -pomes!(u64); - -pomes!(Float2, Expr, f32); -pomes!(Float3, Expr, f32); -pomes!(Float4, Expr, f32); -pomes!(Double2); -pomes!(Double3); -pomes!(Double4); -pomes!(Int2, Expr); -pomes!(Int3, Expr); -pomes!(Int4, Expr); -pomes!(Uint2, Expr); -pomes!(Uint3, Expr); -pomes!(Uint4, Expr); - -#[allow(dead_code)] -fn tests() { - <_ as BoolWhileMaybeExpr>::while_loop(|| true, || {}); - <_ as BoolWhileMaybeExpr>::while_loop(|| Bool::from(true), || {}); -} diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index 4920bf7..3758647 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -1,490 +1,48 @@ use crate::internal_prelude::*; use std::ops::*; -pub mod impls; - -pub trait VarTrait: Copy + Clone + 'static + FromNode { - type Value: Value; - type Short: VarTrait; - type Ushort: VarTrait; - type Int: VarTrait; - type Uint: VarTrait; - type Long: VarTrait; - type Ulong: VarTrait; - type Half: VarTrait; - type Float: VarTrait; - type Double: VarTrait; - type Bool: VarTrait + Not + BitAnd; - fn type_() -> CArc { - ::type_() - } -} +use super::types::core::{Floating, Integral, Numeric, Primitive, Signed}; +use super::types::vector::{VectorAlign, VectorElement}; -fn _cast(expr: T) -> U { - let node = expr.node(); - __current_scope(|s| { - let ret = s.call(Func::Cast, &[node], U::type_()); - U::from_node(ret) - }) -} -pub trait CommonVarOp: VarTrait { - fn max>(&self, other: A) -> Self { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Max, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } - fn min>(&self, other: A) -> Self { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Min, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } - fn clamp, B: Into>(&self, min: A, max: B) -> Self { - let min = min.into().node(); - let max = max.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Clamp, &[self.node(), min, max], Self::type_()); - Self::from_node(ret) - }) - } - fn abs(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Abs, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn bitcast(&self) -> Expr { - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - let ty = ::type_(); - let node = __current_scope(|s| s.bitcast(self.node(), ty)); - Expr::::from_node(node) - } - fn uint(&self) -> Self::Uint { - _cast(*self) - } - fn int(&self) -> Self::Int { - _cast(*self) - } - fn ulong(&self) -> Self::Ulong { - _cast(*self) - } - fn long(&self) -> Self::Long { - _cast(*self) - } - fn float(&self) -> Self::Float { - _cast(*self) - } - fn short(&self) -> Self::Short { - _cast(*self) - } - fn ushort(&self) -> Self::Ushort { - _cast(*self) - } - fn half(&self) -> Self::Half { - _cast(*self) - } - fn double(&self) -> Self::Double { - _cast(*self) - } - fn bool_(&self) -> Self::Bool { - _cast(*self) - } -} -pub trait VarCmpEq: VarTrait { - fn cmpeq>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Eq, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmpne>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Ne, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } -} -pub trait VarCmp: VarTrait + VarCmpEq { - fn cmplt>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Lt, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmple>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Le, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmpgt>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Gt, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmpge>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Ge, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } -} -pub trait IntVarTrait: - VarTrait - + CommonVarOp - + VarCmp - + Add - + Sub - + Mul - + Div - + Rem - + Shl - + Shr - + BitAnd - + BitOr - + BitXor - + AddAssign - + SubAssign - + MulAssign - + DivAssign - + RemAssign - + ShlAssign - + ShrAssign - + BitAndAssign - + BitOrAssign - + BitXorAssign - + Neg - + Clone - + Not - + From - + From -{ - fn one() -> Self { - Self::from(1i64) - } - fn zero() -> Self { - Self::from(0i64) - } - fn rotate_right(&self, n: Expr) -> Self { - let lhs = self.node(); - let rhs = Expr::::node(&n); - __current_scope(|s| { - let ret = s.call(Func::RotRight, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } - fn rotate_left(&self, n: Expr) -> Self { - let lhs = self.node(); - let rhs = Expr::::node(&n); - __current_scope(|s| { - let ret = s.call(Func::RotLeft, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } -} -pub trait FloatVarTrait: - VarTrait - + CommonVarOp - + VarCmp - + Add - + Sub - + Mul - + Div - + Rem - + Neg - + AddAssign - + SubAssign - + MulAssign - + DivAssign - + RemAssign - + Clone - + From - + From -{ - fn one() -> Self { - Self::from(1.0f32) - } - fn zero() -> Self { - Self::from(0.0f32) - } - fn mul_add, B: Into>(&self, a: A, b: B) -> Self { - let a: Self = a.into(); - let b: Self = b.into(); - let node = __current_scope(|s| { - s.call(Func::Fma, &[self.node(), a.node(), b.node()], Self::type_()) - }); - Self::from_node(node) - } - fn ceil(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Ceil, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn floor(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Floor, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn round(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Round, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn trunc(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Trunc, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn copysign>(&self, other: A) -> Self { - __current_scope(|s| { - let ret = s.call( - Func::Copysign, - &[self.node(), other.into().node()], - Self::type_(), - ); - Self::from_node(ret) - }) - } - fn sqrt(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Sqrt, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn rsqrt(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Rsqrt, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn fract(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Fract, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - - // x.step(edge) - fn step(&self, edge: Self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Step, &[edge.node(), self.node()], Self::type_()); - Self::from_node(ret) - }) - } - - fn smooth_step(&self, edge0: Self, edge1: Self) -> Self { - __current_scope(|s| { - let ret = s.call( - Func::SmoothStep, - &[edge0.node(), edge1.node(), self.node()], - Self::type_(), - ); - Self::from_node(ret) - }) - } - - fn saturate(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Saturate, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - - fn sin(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Sin, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - // crate::math::approx_sin_cos(self.clone(), true, false).0 - } - fn cos(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Cos, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - // crate::math::approx_sin_cos(self.clone(), false, true).1 - } - fn tan(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Tan, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn asin(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Asin, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn acos(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Acos, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn atan(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Atan, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn atan2(&self, other: Self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Atan2, &[self.node(), other.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn sinh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Sinh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn cosh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Cosh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn tanh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Tanh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn asinh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Asinh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn acosh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Acosh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn atanh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Atanh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn exp(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Exp, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn exp2(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Exp2, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn is_finite(&self) -> Self::Bool { - !self.is_infinite() & !self.is_nan() - } - fn is_infinite(&self) -> Self::Bool { - __current_scope(|s| { - let ret = s.call(Func::IsInf, &[self.node()], ::type_()); - FromNode::from_node(ret) - }) - } - fn is_nan(&self) -> Self::Bool { - __current_scope(|s| { - let ret = s.call(Func::IsNan, &[self.node()], ::type_()); - FromNode::from_node(ret) - }) - } - fn ln(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Log, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn log(&self, base: impl Into) -> Self { - self.ln() / base.into().ln() - } - fn log2(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Log2, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn log10(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Log10, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn powf(&self, exp: impl Into) -> Self { - let exp = exp.into(); - __current_scope(|s| { - let ret = s.call(Func::Powf, &[self.node(), exp.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn sqr(&self) -> Self { - *self * *self - } - fn cube(&self) -> Self { - *self * *self * *self - } - fn powi(&self, exp: impl Into) -> Self { - let exp = exp.into(); - __current_scope(|s| { - let ret = s.call(Func::Powi, &[self.node(), exp.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn lerp(&self, other: impl Into, frac: impl Into) -> Self { - let other = other.into(); - let frac = frac.into(); - __current_scope(|s| { - let ret = s.call( - Func::Lerp, - &[self.node(), other.node(), frac.node()], - Self::type_(), - ); - Self::from_node(ret) - }) - } - fn recip(&self) -> Self { - Self::one() / self.clone() - } - fn sin_cos(&self) -> (Self, Self) { - (self.sin(), self.cos()) - } -} - -pub trait ScalarVarTrait: ToNode + FromNode {} -pub trait VectorVarTrait: ToNode + FromNode {} -pub trait MatrixVarTrait: ToNode + FromNode {} -pub trait ScalarOrVector: ToNode + FromNode { - type Element: ScalarVarTrait; - type ElementHost: Value; +pub mod impls; +pub mod spread; +pub mod traits; + +pub use traits::*; + +trait CastFrom: Primitive {} +impl CastFrom for T {} +impl CastFrom for T {} + +pub trait Linear: Value { + // Note that without #![feature(generic_const_exprs)], I can't use this within + // the WithScalar restriction. As such, we can't support higher dimensional + // vector operations. If that ever becomes necessary, check commit + // 9e6eacf6b0c2b59a2646f45a727e4d82e84a46cd. + const N: usize; + type Scalar: VectorElement; + type WithScalar: Linear; + // We don't actually know that the vector has equivalent vectors of every + // primitive type. +} +impl Linear for T { + const N: usize = 1; + type Scalar = T; + type WithScalar = S; +} + +impl Linear for Vector { + const N: usize = 2; + type Scalar = T; + type WithScalar = Vector; +} +impl Linear for Vector { + const N: usize = 3; + type Scalar = T; + type WithScalar = Vector; +} +impl Linear for Vector { + const N: usize = 4; + type Scalar = T; + type WithScalar = Vector; } -pub trait BuiltinVarTrait: ToNode + FromNode {} diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index db3af67..0cf2fe4 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -1,338 +1,517 @@ use super::*; -use crate::lang::types::core::*; -use crate::lang::types::VarDerefProxy; - -macro_rules! impl_var_trait { - ($t:ty) => { - impl VarTrait for prim::Expr<$t> { - type Value = $t; - type Short = prim::Expr; - type Ushort = prim::Expr; - type Int = prim::Expr; - type Uint = prim::Expr; - type Long = prim::Expr; - type Ulong = prim::Expr; - type Half = prim::Expr; - type Float = prim::Expr; - type Double = prim::Expr; - type Bool = prim::Expr; + +impl Expr { + pub fn as_(self) -> Expr + where + Y::Scalar: CastFrom, + { + assert_eq!(X::N, Y::N); + Func::Cast.call(self) + } + pub fn cast(self) -> Expr> + where + S: CastFrom, + { + self.as_::<::WithScalar>() + } +} + +macro_rules! impl_ops_trait { + ( + [$($bounds:tt)*] $TraitExpr:ident [$TraitThis:ident] for $T:ty where [$($where:tt)*] { + $( + fn $fn:ident [$fn_this:ident] ($sl:ident, $($arg:ident),*) { $body:expr } + )* } - impl ScalarVarTrait for prim::Expr<$t> {} - impl ScalarOrVector for prim::Expr<$t> { - type Element = prim::Expr<$t>; - type ElementHost = $t; + ) => { + impl<$($bounds)*> $TraitThis for $T where $($where)* { + $( + fn $fn_this($sl, $($arg: Self),*) -> Self { + $body + } + )* } - impl BuiltinVarTrait for prim::Expr<$t> {} - }; -} -impl_var_trait!(f16); -impl_var_trait!(f32); -impl_var_trait!(f64); -impl_var_trait!(i16); -impl_var_trait!(u16); -impl_var_trait!(i32); -impl_var_trait!(u32); -impl_var_trait!(i64); -impl_var_trait!(u64); -impl_var_trait!(bool); - -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} - -impl VarCmpEq for prim::Expr {} - -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} - -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} - -impl CommonVarOp for prim::Expr {} - -impl FloatVarTrait for prim::Expr {} -impl FloatVarTrait for prim::Expr {} -impl FloatVarTrait for prim::Expr {} - -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} - -macro_rules! impl_from { - ($from:ty, $to:ty) => { - impl From<$from> for prim::Expr<$to> { - fn from(x: $from) -> Self { - let y: $to = (x.try_into().unwrap()); - y.expr() - } + impl<$($bounds)*> $TraitExpr for $T where $($where)* { + type Output = Self; + + $( + fn $fn($sl, $($arg: Self),*) -> Self { + <$T as $TraitThis>::$fn_this($sl, $($arg),*) + } + )* } }; + ( + [$($bounds:tt)*] $TraitExpr:ident [$TraitThis:ident] for $T:ty where [$($where:tt)*] { + type Output = $Output:ty; + $( + fn $fn:ident [$fn_this:ident] ($sl:ident, $($arg:ident),*) { $body:expr } + )* + } + ) => { + impl<$($bounds)*> $TraitThis for $T where $($where)* { + type Output = $Output; + $( + fn $fn_this($sl, $($arg: Self),*) -> Self::Output { + $body + } + )* + } + impl<$($bounds)*> $TraitExpr for $T where $($where)* { + type Output = $Output; + + $( + fn $fn($sl, $($arg: Self),*) -> Self::Output { + <$T as $TraitThis>::$fn_this($sl, $($arg),*) + } + )* + } + } } +impl_ops_trait!([X: Linear] MinMaxExpr[MinMaxThis] for Expr where [X::Scalar: Numeric] { + fn max[_max](self, other) { Func::Max.call2(self, other) } + fn min[_min](self, other) { Func::Min.call2(self, other) } +}); + +impl_ops_trait!([X: Linear] ClampExpr[ClampThis] for Expr where [X::Scalar: Numeric] { + fn clamp[_clamp](self, min, max) { Func::Clamp.call3(self, min, max) } +}); -impl_from!(i16, u16); -impl_from!(i16, i32); -impl_from!(i16, u32); -impl_from!(i16, i64); -impl_from!(i16, u64); +impl AbsExpr for Expr +where + X::Scalar: Signed, +{ + fn abs(&self) -> Self { + Func::Abs.call(self.clone()) + } +} -impl_from!(u16, i16); -impl_from!(u16, i32); -impl_from!(u16, u32); -impl_from!(u16, i64); -impl_from!(u16, u64); +impl_ops_trait!([X: Linear] EqExpr[EqThis] for Expr where [X::Scalar: VectorElement] { + type Output = Expr>; -impl_from!(i32, u16); -impl_from!(i32, i16); -impl_from!(i32, u32); -impl_from!(i32, i64); -impl_from!(i32, u64); + fn eq[_eq](self, other) { Func::Eq.call2(self, other) } + fn ne[_ne](self, other) { Func::Ne.call2(self, other) } +}); -impl_from!(i64, u16); -impl_from!(i64, i16); -impl_from!(i64, u64); -impl_from!(i64, i32); -impl_from!(i64, u32); +impl_ops_trait!([X: Linear] CmpExpr[CmpThis] for Expr where [X::Scalar: Numeric] { + type Output = Expr>; -impl_from!(u32, u16); -impl_from!(u32, i16); -impl_from!(u32, i32); -impl_from!(u32, i64); -impl_from!(u32, u64); + fn lt[_lt](self, other) { Func::Lt.call2(self, other) } + fn le[_le](self, other) { Func::Le.call2(self, other) } + fn gt[_gt](self, other) { Func::Gt.call2(self, other) } + fn ge[_ge](self, other) { Func::Ge.call2(self, other) } +}); -impl_from!(u64, u16); -impl_from!(u64, i16); -impl_from!(u64, i64); -impl_from!(u64, i32); -impl_from!(u64, u32); +impl Add for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn add(self, other: Self) -> Self { + Func::Add.call2(self, other) + } +} +impl Sub for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn sub(self, other: Self) -> Self { + Func::Sub.call2(self, other) + } +} +impl Mul for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn mul(self, other: Self) -> Self { + Func::Mul.call2(self, other) + } +} +impl Div for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn div(self, other: Self) -> Self { + Func::Div.call2(self, other) + } +} +impl Rem for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn rem(self, other: Self) -> Self { + Func::Rem.call2(self, other) + } +} -impl From for prim::Expr { - fn from(x: f64) -> Self { - (x as f32).into() +impl BitAnd for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn bitand(self, other: Self) -> Self { + Func::BitAnd.call2(self, other) } } -impl From for prim::Expr { - fn from(x: f32) -> Self { - (x as f64).into() +impl BitOr for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn bitor(self, other: Self) -> Self { + Func::BitOr.call2(self, other) } } -impl From for prim::Expr { - fn from(x: f64) -> Self { - f16::from_f64(x).into() +impl BitXor for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn bitxor(self, other: Self) -> Self { + Func::BitXor.call2(self, other) } } -impl From for prim::Expr { - fn from(x: f32) -> Self { - f16::from_f32(x).into() +impl Shl for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn shl(self, other: Self) -> Self { + Func::Shl.call2(self, other) + } +} +impl Shr for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn shr(self, other: Self) -> Self { + Func::Shr.call2(self, other) } } -macro_rules! impl_binop { - ($t:ty, $proxy:ty, $tr_assign:ident, $method_assign:ident, $tr:ident, $method:ident) => { - impl $tr_assign> for $proxy { - fn $method_assign(&mut self, rhs: prim::Expr<$t>) { - *self = self.clone().$method(rhs); - } - } - impl $tr_assign<$t> for $proxy { - fn $method_assign(&mut self, rhs: $t) { - *self = self.clone().$method(rhs); - } - } - impl $tr> for $proxy { - type Output = prim::Expr<$t>; - fn $method(self, rhs: prim::Expr<$t>) -> Self::Output { - __current_scope(|s| { - let lhs = ToNode::node(&self); - let rhs = ToNode::node(&rhs); - let ret = s.call(Func::$tr, &[lhs, rhs], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } +impl Neg for Expr +where + X::Scalar: Signed, +{ + type Output = Self; + fn neg(self) -> Self { + Func::Neg.call(self) + } +} +impl Not for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn not(self) -> Self { + Func::BitNot.call(self) + } +} + +impl IntExpr for Expr +where + X::Scalar: Integral + Numeric, +{ + fn rotate_left(&self, n: Expr) -> Self { + Func::RotRight.call2(self.clone(), n) + } + fn rotate_right(&self, n: Expr) -> Self { + Func::RotLeft.call2(self.clone(), n) + } +} + +macro_rules! impl_simple_fns { + ($($fname:ident => $func:ident),+) => {$( + fn $fname(&self) -> Self { + Func::$func.call(self.clone()) } + )+}; +} + +impl FloatExpr for Expr +where + X::Scalar: Floating, +{ + type Bool = Expr>; + impl_simple_fns! { + ceil => Ceil, + floor => Floor, + round => Round, + trunc => Trunc, + sqrt => Sqrt, + rsqrt => Rsqrt, + fract => Fract, + saturate => Saturate, + sin => Sin, + cos => Cos, + tan => Tan, + asin => Asin, + acos => Acos, + atan => Atan, + sinh => Sinh, + cosh => Cosh, + tanh => Tanh, + asinh => Asinh, + acosh => Acosh, + atanh => Atanh, + exp => Exp, + exp2 => Exp2, + ln => Log, + log2 => Log2, + log10 => Log10 + } + fn is_finite(&self) -> Self::Bool { + !self.is_infinite() & !self.is_nan() + } + fn is_infinite(&self) -> Self::Bool { + Func::IsInf.call(self.clone()) + } + fn is_nan(&self) -> Self::Bool { + Func::IsNan.call(self.clone()) + } + fn sqr(&self) -> Self { + self.clone() * self.clone() + } + fn cube(&self) -> Self { + self.clone() * self.clone() * self.clone() + } + fn recip(&self) -> Self { + todo!() + // 1.0 / self.clone() + } + fn sin_cos(&self) -> (Self, Self) { + (self.sin(), self.cos()) + } +} +impl_ops_trait!([X: Linear] FloatMulAddExpr[FloatMulAddThis] for Expr where [X::Scalar: Floating] { + fn mul_add[_mul_add](self, a, b) { Func::Fma.call3(self, a, b) } +}); + +impl_ops_trait!([X: Linear] FloatCopySignExpr[FloatCopySignThis] for Expr where [X::Scalar: Floating] { + fn copy_sign[_copy_sign](self, sign) { Func::Copysign.call2(self, sign) } +}); + +impl_ops_trait!([X: Linear] FloatStepExpr[FloatStepThis] for Expr where [X::Scalar: Floating] { + fn step[_step](self, edge) { Func::Step.call2(edge, self) } +}); + +impl_ops_trait!([X: Linear] FloatSmoothStepExpr[FloatSmoothStepThis] for Expr where [X::Scalar: Floating] { + fn smooth_step[_smooth_step](self, edge0, edge1) { Func::SmoothStep.call3(edge0, edge1, self) } +}); + +impl_ops_trait!([X: Linear] FloatArcTan2Expr[FloatArcTan2This] for Expr where [X::Scalar: Floating] { + fn atan2[_atan2](self, other) { Func::Atan2.call2(self, other) } +}); + +impl_ops_trait!([X: Linear] FloatLogExpr[FloatLogThis] for Expr where [X::Scalar: Floating] { + fn log[_log](self, base) { self.ln() / base.ln()} +}); + +impl_ops_trait!([X: Linear] FloatPowfExpr[FloatPowfThis] for Expr where [X::Scalar: Floating] { + fn powf[_powf](self, exponent) { Func::Powf.call2(self, exponent) } +}); + +impl> FloatPowiExpr> for Expr +where + X::Scalar: Floating, +{ + type Output = Self; + + fn powi(self, exponent: Expr) -> Self::Output { + Func::Powi.call2(self, exponent) + } +} - impl $tr<$t> for $proxy { - type Output = prim::Expr<$t>; - fn $method(self, rhs: $t) -> Self::Output { - $tr::$method(self, rhs.expr()) - } +impl_ops_trait!([X: Linear] FloatLerpExpr[FloatLerpThis] for Expr where [X::Scalar: Floating] { + fn lerp[_lerp](self, other, frac) { Func::Lerp.call3(self, other, frac) } +}); + +// Traits for `track!`. + +impl StoreMaybeExpr for &mut T { + fn store(self, value: T) { + *self = value; + } +} +impl> StoreMaybeExpr for &Var { + fn store(self, value: E) { + crate::lang::_store(self, &value.as_expr()); + } +} +impl> StoreMaybeExpr for Var { + fn store(self, value: E) { + crate::lang::_store(&self, &value.as_expr()); + } +} + +impl SelectMaybeExpr for bool { + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { + if self { + on() + } else { + off() } - impl $tr<$proxy> for $t { - type Output = prim::Expr<$t>; - fn $method(self, rhs: $proxy) -> Self::Output { - $tr::$method(self.expr(), rhs) - } + } + fn select(self, on: R, off: R) -> R { + if self { + on + } else { + off } - }; -} -macro_rules! impl_common_binop { - ($t:ty,$proxy:ty) => { - impl_binop!($t, $proxy, AddAssign, add_assign, Add, add); - impl_binop!($t, $proxy, SubAssign, sub_assign, Sub, sub); - impl_binop!($t, $proxy, MulAssign, mul_assign, Mul, mul); - impl_binop!($t, $proxy, DivAssign, div_assign, Div, div); - impl_binop!($t, $proxy, RemAssign, rem_assign, Rem, rem); - }; + } } -macro_rules! impl_int_binop { - ($t:ty,$proxy:ty) => { - impl_binop!($t, $proxy, ShlAssign, shl_assign, Shl, shl); - impl_binop!($t, $proxy, ShrAssign, shr_assign, Shr, shr); - impl_binop!($t, $proxy, BitAndAssign, bitand_assign, BitAnd, bitand); - impl_binop!($t, $proxy, BitOrAssign, bitor_assign, BitOr, bitor); - impl_binop!($t, $proxy, BitXorAssign, bitxor_assign, BitXor, bitxor); - }; +impl SelectMaybeExpr for Expr { + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { + crate::lang::control_flow::if_then_else(self, on, off) + } + fn select(self, on: R, off: R) -> R { + crate::lang::control_flow::select(self, on, off) + } } -macro_rules! impl_not { - ($t:ty,$proxy:ty) => { - impl Not for $proxy { - type Output = prim::Expr<$t>; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } +impl ActivateMaybeExpr for bool { + fn activate(self, then: impl FnOnce()) { + if self { + then() } - }; + } } -macro_rules! impl_neg { - ($t:ty,$proxy:ty) => { - impl Neg for $proxy { - type Output = prim::Expr<$t>; - fn neg(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } +impl ActivateMaybeExpr for Expr { + fn activate(self, then: impl FnOnce()) { + crate::lang::control_flow::if_then_else(self, then, || {}) + } +} + +impl LoopMaybeExpr for bool { + fn while_loop(mut cond: impl FnMut() -> Self, mut body: impl FnMut()) { + while cond() { + body() } - }; + } +} + +impl LoopMaybeExpr for Expr { + fn while_loop(cond: impl FnMut() -> Self, body: impl FnMut()) { + crate::lang::control_flow::generic_loop(cond, body, || {}) + } +} + +impl LazyBoolMaybeExpr for bool { + type Bool = bool; + fn and(self, other: impl FnOnce() -> bool) -> bool { + self && other() + } + fn or(self, other: impl FnOnce() -> bool) -> bool { + self || other() + } } -macro_rules! impl_fneg { - ($t:ty, $proxy:ty) => { - impl Neg for $proxy { - type Output = prim::Expr<$t>; - fn neg(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } +impl LazyBoolMaybeExpr> for bool { + type Bool = Expr; + fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { + if self { + other() + } else { + false.expr() } - }; + } + fn or(self, other: impl FnOnce() -> Expr) -> Self::Bool { + if self { + true.expr() + } else { + other() + } + } } -impl Not for prim::Expr { - type Output = prim::Expr; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - FromNode::from_node(ret) - }) - } -} -impl_common_binop!(f16, prim::Expr); -impl_common_binop!(f32, prim::Expr); -impl_common_binop!(f64, prim::Expr); -impl_common_binop!(i16, prim::Expr); -impl_common_binop!(i32, prim::Expr); -impl_common_binop!(i64, prim::Expr); -impl_common_binop!(u16, prim::Expr); -impl_common_binop!(u32, prim::Expr); -impl_common_binop!(u64, prim::Expr); - -impl_binop!( - bool, - prim::Expr, - BitAndAssign, - bitand_assign, - BitAnd, - bitand -); -impl_binop!( - bool, - prim::Expr, - BitOrAssign, - bitor_assign, - BitOr, - bitor -); -impl_binop!( - bool, - prim::Expr, - BitXorAssign, - bitxor_assign, - BitXor, - bitxor -); -impl_int_binop!(i16, prim::Expr); -impl_int_binop!(i32, prim::Expr); -impl_int_binop!(i64, prim::Expr); -impl_int_binop!(u16, prim::Expr); -impl_int_binop!(u32, prim::Expr); -impl_int_binop!(u64, prim::Expr); - -impl_not!(i16, prim::Expr); -impl_not!(i32, prim::Expr); -impl_not!(i64, prim::Expr); -impl_not!(u16, prim::Expr); -impl_not!(u32, prim::Expr); -impl_not!(u64, prim::Expr); - -impl_neg!(i16, prim::Expr); -impl_neg!(i32, prim::Expr); -impl_neg!(i64, prim::Expr); -impl_neg!(u16, prim::Expr); -impl_neg!(u32, prim::Expr); -impl_neg!(u64, prim::Expr); - -impl_fneg!(f16, prim::Expr); -impl_fneg!(f32, prim::Expr); -impl_fneg!(f64, prim::Expr); - -macro_rules! impl_assign_ops { - ($ass:ident, $ass_m:ident, $o:ident, $o_m:ident) => { - impl std::ops::$ass for VarDerefProxy - where - P: VarProxy, - Expr: std::ops::$o>, - { - fn $ass_m(&mut self, rhs: Rhs) { - *self.deref_mut() = std::ops::$o::$o_m(**self, rhs); - } +impl LazyBoolMaybeExpr for Expr { + type Bool = Expr; + fn and(self, other: impl FnOnce() -> bool) -> Self::Bool { + if other() { + self + } else { + false.expr() } - }; + } + fn or(self, other: impl FnOnce() -> bool) -> Self::Bool { + if other() { + true.expr() + } else { + self + } + } +} +impl LazyBoolMaybeExpr for Expr { + type Bool = Expr; + fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { + crate::lang::control_flow::if_then_else(self, other, || false.expr()) + } + fn or(self, other: impl FnOnce() -> Expr) -> Self::Bool { + crate::lang::control_flow::if_then_else(self, || true.expr(), other) + } +} + +impl EqMaybeExpr for T +where + T: EqExpr, +{ + type Bool = >::Output; + fn __eq(self, other: S) -> Self::Bool { + self.eq(other) + } + fn __ne(self, other: S) -> Self::Bool { + self.ne(other) + } +} +impl EqMaybeExpr for T +where + T: PartialEq, +{ + type Bool = bool; + fn __eq(self, other: S) -> Self::Bool { + self == other + } + fn __ne(self, other: S) -> Self::Bool { + self != other + } +} + +impl CmpMaybeExpr for T +where + T: CmpExpr, +{ + type Bool = >::Output; + fn __lt(self, other: S) -> Self::Bool { + self.lt(other) + } + fn __le(self, other: S) -> Self::Bool { + self.le(other) + } + fn __gt(self, other: S) -> Self::Bool { + self.gt(other) + } + fn __ge(self, other: S) -> Self::Bool { + self.ge(other) + } +} +impl CmpMaybeExpr for T +where + T: PartialOrd, +{ + type Bool = bool; + fn __lt(self, other: S) -> Self::Bool { + self < other + } + fn __le(self, other: S) -> Self::Bool { + self <= other + } + fn __gt(self, other: S) -> Self::Bool { + self > other + } + fn __ge(self, other: S) -> Self::Bool { + self >= other + } } -impl_assign_ops!(AddAssign, add_assign, Add, add); -impl_assign_ops!(SubAssign, sub_assign, Sub, sub); -impl_assign_ops!(MulAssign, mul_assign, Mul, mul); -impl_assign_ops!(DivAssign, div_assign, Div, div); -impl_assign_ops!(RemAssign, rem_assign, Rem, rem); -impl_assign_ops!(BitAndAssign, bitand_assign, BitAnd, bitand); -impl_assign_ops!(BitOrAssign, bitor_assign, BitOr, bitor); -impl_assign_ops!(BitXorAssign, bitxor_assign, BitXor, bitxor); -impl_assign_ops!(ShlAssign, shl_assign, Shl, shl); -impl_assign_ops!(ShrAssign, shr_assign, Shr, shr); diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs new file mode 100644 index 0000000..2abab65 --- /dev/null +++ b/luisa_compute/src/lang/ops/spread.rs @@ -0,0 +1,356 @@ +use super::*; +use traits::*; + +pub trait SpreadOps { + type Join: Value; + fn lift_self(x: Self) -> Expr; + fn lift_other(x: Other) -> Expr; +} + +macro_rules! impl_spread { + (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl<$($bounds)*> SpreadOps<$S> for $T { + type Join = $J; + fn lift_self($x: $T) -> Expr { + $f + } + fn lift_other($y: $S) -> Expr { + $g + } + } + }; + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl<$($bounds)*> SpreadOps<$S> for $T { + type Join = $J; + fn lift_self($x: $T) -> Expr { + $f + } + fn lift_other($x: $S) -> Expr { + $g + } + } + impl<$($bounds)*> SpreadOps<$T> for $S { + type Join = $J; + fn lift_self($y: $S) -> Expr { + $g + } + fn lift_other($y: $T) -> Expr { + $f + } + } + }; +} + +macro_rules! call_linear_fn_spread { + ($f:ident [$($bounds:tt)*]($T:ty)) => { + $f!([$($bounds)*] $T: |x| x.expr(), Expr<$T>: |x| x => Expr<$T>); + $f!(['a, $($bounds)*] &'a $T: |x| x.expr(), Expr<$T>: |x| x => Expr<$T>); + $f!(['b, $($bounds)*] $T: |x| x.expr(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); + + $f!(['b, $($bounds)*] Expr<$T>: |x| x, &'b Expr<$T>: |x| x.clone() => Expr<$T>); + $f!(['b, $($bounds)*] Var<$T>: |x| x.load(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!(@sym ['a, 'b, $($bounds)*] &'a Expr<$T>: |x| x.clone(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); + $f!(@sym [$($bounds)*] Var<$T>: |x| x.load(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(@sym ['a, 'b, $($bounds)*] &'a Var<$T>: |x| x.load(), &'b Var<$T>: |x| x.load() => Expr<$T>); + + $f!([$($bounds)*] $T: |x| x.expr(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, $($bounds)*] &'a $T: |x| x.expr(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(['b, $($bounds)*] $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); + + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| x.clone(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| x.clone(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!([$($bounds)*] Expr<$T>: |x| x, Var<$T>: |x| x.load() => Expr<$T>); + $f!(['b, $($bounds)*] Expr<$T>: |x| x, &'b Var<$T>: |x| x.load() => Expr<$T>); + }; + ($f:ident [$T:ident]) => { + call_linear_fn_spread!($f [$T: Linear]($T)); + } +} + +call_linear_fn_spread!(impl_spread[T]); + +macro_rules! call_vector_fn_spread { + ($f:ident [$($bounds:tt)*]($N:tt, $T:ty) $Vt:ty, $Vsplat:path) => { + $f!([$($bounds)*] $T: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a $T: |x| $Vsplat(*x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!(['b, $($bounds)*] $T: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| $Vsplat(*x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + + $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), $Vt: |x| x.expr() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), $Vt: |x| x.expr() => Expr<$Vt>); + $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b $Vt: |x| x.expr() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b $Vt: |x| x.expr() => Expr<$Vt>); + + $f!([$($bounds)*] $T: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a $T: |x| $Vsplat(*x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['b, $($bounds)*] $T: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| $Vsplat(*x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + + $f!([$($bounds)*] Var<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Var<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['b, $($bounds)*] Var<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Var<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + }; + ($f:ident [$($bounds:tt)*]($N:tt, $T:ty)) => { + call_vector_fn_spread!($f[$($bounds)*]($N, $T) Vector<$T, $N>, Vector::<$T, $N>::splat_expr); + }; + ($f:ident[$N:ident, $T:ident]) => { + call_vector_fn_spread!($f[const $N: usize, $T: VectorAlign<$N>]($N, $T)); + } +} + +call_vector_fn_spread!(impl_spread[N, T]); + +mod trait_impls { + use super::*; + impl MinMaxExpr for T + where + T: SpreadOps, + Expr: MinMaxThis, + { + type Output = Expr; + fn max(self, other: S) -> Self::Output { + Expr::::_max(Self::lift_self(self), Self::lift_other(other)) + } + fn min(self, other: S) -> Self::Output { + Expr::::_min(Self::lift_self(self), Self::lift_other(other)) + } + } + impl ClampExpr for T + where + S: SpreadOps, + T: SpreadOps>, + Expr: ClampThis, + { + /// T::Join + /// / \ + /// / \ + /// / \ + /// / \ + /// / S::Join + /// / / \ + /// / / \ + /// / / \ + /// / / \ + /// T S U + + type Output = Expr; + fn clamp(self, min: S, max: U) -> Self::Output { + Expr::::_clamp( + Self::lift_self(self), + Self::lift_other(S::lift_self(min)), + Self::lift_other(S::lift_other(max)), + ) + } + } + impl EqExpr for T + where + T: SpreadOps, + Expr: EqThis, + { + type Output = as EqThis>::Output; + fn eq(self, other: S) -> Self::Output { + Expr::::_eq(Self::lift_self(self), Self::lift_other(other)) + } + fn ne(self, other: S) -> Self::Output { + Expr::::_ne(Self::lift_self(self), Self::lift_other(other)) + } + } + impl CmpExpr for T + where + T: SpreadOps, + Expr: CmpThis, + { + type Output = as CmpThis>::Output; + fn lt(self, other: S) -> Self::Output { + Expr::::_lt(Self::lift_self(self), Self::lift_other(other)) + } + fn le(self, other: S) -> Self::Output { + Expr::::_le(Self::lift_self(self), Self::lift_other(other)) + } + fn gt(self, other: S) -> Self::Output { + Expr::::_gt(Self::lift_self(self), Self::lift_other(other)) + } + fn ge(self, other: S) -> Self::Output { + Expr::::_ge(Self::lift_self(self), Self::lift_other(other)) + } + } + impl FloatMulAddExpr for T + where + S: SpreadOps, + T: SpreadOps>, + Expr: FloatMulAddThis, + { + type Output = Expr; + fn mul_add(self, mul: S, add: U) -> Self::Output { + Expr::::_mul_add( + Self::lift_self(self), + Self::lift_other(S::lift_self(mul)), + Self::lift_other(S::lift_other(add)), + ) + } + } + impl FloatCopySignExpr for T + where + T: SpreadOps, + Expr: FloatCopySignThis, + { + type Output = Expr; + fn copy_sign(self, sign: S) -> Self::Output { + Expr::::_copy_sign(Self::lift_self(self), Self::lift_other(sign)) + } + } + impl FloatStepExpr for T + where + T: SpreadOps, + Expr: FloatStepThis, + { + type Output = Expr; + fn step(self, edge: S) -> Self::Output { + Expr::::_step(Self::lift_self(self), Self::lift_other(edge)) + } + } + impl FloatSmoothStepExpr for T + where + S: SpreadOps, + T: SpreadOps>, + Expr: FloatSmoothStepThis, + { + type Output = Expr; + fn smooth_step(self, edge0: S, edge1: U) -> Self::Output { + Expr::::_smooth_step( + Self::lift_self(self), + Self::lift_other(S::lift_self(edge0)), + Self::lift_other(S::lift_other(edge1)), + ) + } + } + impl FloatArcTan2Expr for T + where + T: SpreadOps, + Expr: FloatArcTan2This, + { + type Output = Expr; + fn atan2(self, other: S) -> Self::Output { + Expr::::_atan2(Self::lift_self(self), Self::lift_other(other)) + } + } + impl FloatLogExpr for T + where + T: SpreadOps, + Expr: FloatLogThis, + { + type Output = Expr; + fn log(self, base: S) -> Self::Output { + Expr::::_log(Self::lift_self(self), Self::lift_other(base)) + } + } + impl FloatPowfExpr for T + where + T: SpreadOps, + Expr: FloatPowfThis, + { + type Output = Expr; + fn powf(self, exponent: S) -> Self::Output { + Expr::::_powf(Self::lift_self(self), Self::lift_other(exponent)) + } + } + impl FloatLerpExpr for T + where + S: SpreadOps, + T: SpreadOps>, + Expr: FloatLerpThis, + { + type Output = Expr; + fn lerp(self, other: S, frac: U) -> Self::Output { + Expr::::_lerp( + Self::lift_self(self), + Self::lift_other(S::lift_self(other)), + Self::lift_other(S::lift_other(frac)), + ) + } + } +} +macro_rules! impl_spread_op { + ([ $($bounds:tt)* ]: $Op:ident::$op_fn:ident for $T:ty, $S:ty) => { + impl<$($bounds)*> $Op <$S> for $T where $T: SpreadOps<$S>, Expr<<$T as SpreadOps<$S>>::Join>: $Op { + type Output = >::Join> as $Op>::Output; + fn $op_fn (self, other: $S) -> Self::Output { + >::Join> as $Op>::$op_fn (<$T as SpreadOps<$S>>::lift_self(self), <$T as SpreadOps<$S>>::lift_other(other)) + } + } + } +} + +macro_rules! impl_num_spread_single { + ([ $($bounds:tt)* ] $T:ty, $S:ty) => { + impl_spread_op!( [ $($bounds)* ]: Add::add for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Sub::sub for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Mul::mul for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Div::div for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Rem::rem for $T, $S); + } +} +macro_rules! impl_int_spread_single { + ([ $($bounds:tt)* ] $T:ty, $S:ty) => { + impl_spread_op!([ $($bounds)* ]: BitAnd::bitand for $T, $S); + impl_spread_op!([ $($bounds)* ]: BitOr::bitor for $T, $S); + impl_spread_op!([ $($bounds)* ]: BitXor::bitxor for $T, $S); + impl_spread_op!([ $($bounds)* ]: Shl::shl for $T, $S); + impl_spread_op!([ $($bounds)* ]: Shr::shr for $T, $S); + } +} + +macro_rules! impl_num_spread { + (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_num_spread_single!([$($bounds)*] $T, $S); + }; + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_num_spread_single!([$($bounds)*] $T, $S); + impl_num_spread_single!([$($bounds)*] $S, $T); + } +} +macro_rules! impl_int_spread { + (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_int_spread_single!([$($bounds)*] $T, $S); + }; + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_int_spread_single!([$($bounds)*] $T, $S); + impl_int_spread_single!([$($bounds)*] $S, $T); + } +} +macro_rules! call_spreads { + ($f:ident: $($T:ty),+) => { + $( + call_linear_fn_spread!($f []($T)); + call_vector_fn_spread!($f [](2, $T)); + call_vector_fn_spread!($f [](3, $T)); + call_vector_fn_spread!($f [](4, $T)); + )+ + }; +} +call_spreads!(impl_num_spread: f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); +call_spreads!(impl_int_spread: bool, i8, i16, i32, i64, u8, u16, u32, u64); + +#[allow(dead_code)] +mod tests { + use super::*; + fn test() { + let x = 10.0f32; + let y = Vector::<_, 2>::splat(20.0f32); + let x = x.expr(); + + let w = (&x.var()).min(&0.0_f32.expr()); + println!("{:?}", w); + } +} diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs new file mode 100644 index 0000000..a4561e1 --- /dev/null +++ b/luisa_compute/src/lang/ops/traits.rs @@ -0,0 +1,194 @@ +use super::*; + +// The double trait implementation is necessary as the compiler infinite loops +// when trying to resolve the Expr: SpreadOps>> bound. +macro_rules! ops_trait { + ( + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { + $( + fn $fn:ident [$fn_this:ident] (self, $($arg:ident: $S:ident),*); + )+ + } + ) => { + pub trait $TraitThis { + $( + fn $fn_this(self, $($arg: Self),*) -> Self; + )* + } + pub trait $TraitExpr<$($T = Self),*> { + type Output; + + $( + fn $fn(self, $($arg: $S),*) -> Self::Output; + )* + } + }; + ( + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { + type Output; + $( + fn $fn:ident [$fn_this:ident] (self, $($arg:ident: $S:ident),*); + )+ + } + ) => { + pub trait $TraitThis { + type Output; + $( + fn $fn_this(self, $($arg: Self),*) -> Self::Output; + )* + } + pub trait $TraitExpr<$($T = Self),*> { + type Output; + + $( + fn $fn(self, $($arg: $S),*) -> Self::Output; + )* + } + } +} + +ops_trait!(MinMaxExpr[MinMaxThis] { + fn max[_max](self, other: T); + fn min[_min](self, other: T); +}); + +ops_trait!(ClampExpr[ClampThis] { + fn clamp[_clamp](self, min: A, max: B); +}); + +pub trait AbsExpr { + fn abs(&self) -> Self; +} + +ops_trait!(EqExpr[EqThis] { + type Output; + + fn eq[_eq](self, other: T); + fn ne[_ne](self, other: T); +}); + +ops_trait!(CmpExpr[CmpThis] { + type Output; + + fn lt[_lt](self, other: T); + fn le[_le](self, other: T); + fn gt[_gt](self, other: T); + fn ge[_ge](self, other: T); +}); + +pub trait IntExpr { + fn rotate_right(&self, n: Expr) -> Self; + fn rotate_left(&self, n: Expr) -> Self; +} + +pub trait FloatExpr: Sized { + type Bool; + + fn ceil(&self) -> Self; + fn floor(&self) -> Self; + fn round(&self) -> Self; + fn trunc(&self) -> Self; + fn sqrt(&self) -> Self; + fn rsqrt(&self) -> Self; + fn fract(&self) -> Self; + fn saturate(&self) -> Self; + fn sin(&self) -> Self; + fn cos(&self) -> Self; + fn tan(&self) -> Self; + fn asin(&self) -> Self; + fn acos(&self) -> Self; + fn atan(&self) -> Self; + fn sinh(&self) -> Self; + fn cosh(&self) -> Self; + fn tanh(&self) -> Self; + fn asinh(&self) -> Self; + fn acosh(&self) -> Self; + fn atanh(&self) -> Self; + fn exp(&self) -> Self; + fn exp2(&self) -> Self; + fn is_finite(&self) -> Self::Bool; + fn is_infinite(&self) -> Self::Bool; + fn is_nan(&self) -> Self::Bool; + fn ln(&self) -> Self; + fn log2(&self) -> Self; + fn log10(&self) -> Self; + fn sqr(&self) -> Self; + fn cube(&self) -> Self; + fn recip(&self) -> Self; + fn sin_cos(&self) -> (Self, Self); +} + +ops_trait!(FloatMulAddExpr[FloatMulAddThis] { + fn mul_add[_mul_add](self, a: A, b: B); +}); + +ops_trait!(FloatCopySignExpr[FloatCopySignThis] { + fn copy_sign[_copy_sign](self, sign: T); +}); + +ops_trait!(FloatStepExpr[FloatStepThis] { + fn step[_step](self, edge: T); +}); + +ops_trait!(FloatSmoothStepExpr[FloatSmoothStepThis] { + fn smooth_step[_smooth_step](self, edge0: T, edge1: S); +}); + +ops_trait!(FloatArcTan2Expr[FloatArcTan2This] { + fn atan2[_atan2](self, other: T); +}); + +ops_trait!(FloatLogExpr[FloatLogThis] { + fn log[_log](self, base: T); +}); + +ops_trait!(FloatPowfExpr[FloatPowfThis] { + fn powf[_powf](self, exponent: T); +}); + +pub trait FloatPowiExpr { + type Output; + + fn powi(self, exponent: T) -> Self::Output; +} + +ops_trait!(FloatLerpExpr[FloatLerpThis] { + fn lerp[_lerp](self, other: A, frac: B); +}); + +pub trait StoreMaybeExpr { + fn store(self, value: V); +} + +pub trait SelectMaybeExpr { + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R; + fn select(self, on: R, off: R) -> R; +} + +pub trait ActivateMaybeExpr { + fn activate(self, then: impl FnOnce()); +} + +pub trait LoopMaybeExpr { + fn while_loop(cond: impl FnMut() -> Self, body: impl FnMut()); +} + +pub trait LazyBoolMaybeExpr { + type Bool; + fn and(self, other: impl FnOnce() -> T) -> Self::Bool; + fn or(self, other: impl FnOnce() -> T) -> Self::Bool; +} + +pub trait EqMaybeExpr { + type Bool; + fn __eq(self, other: T) -> Self::Bool; + fn __ne(self, other: T) -> Self::Bool; +} + +pub trait CmpMaybeExpr { + type Bool; + fn __lt(self, other: T) -> Self::Bool; + fn __le(self, other: T) -> Self::Bool; + fn __gt(self, other: T) -> Self::Bool; + fn __ge(self, other: T) -> Self::Bool; +} diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index dc8f3f1..27b11af 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -1,168 +1,330 @@ -use std::any::Any; -use std::cell::Cell; -use std::ops::{Deref, DerefMut}; +use std::ops::Deref; use crate::internal_prelude::*; +pub mod alignment; pub mod array; pub mod core; pub mod dynamic; pub mod shared; pub mod vector; -pub type Expr = ::Expr; -pub type Var = ::Var; +// TODO: Check up on comments. -pub trait Value: Copy + ir::TypeOf + 'static { +/// A value that can be used in a [`Kernel`] or [`Callable`]. Call [`expr`] or +/// [`var`] to convert into a kernel-trackable type. +pub trait Value: Copy + TypeOf + 'static { + /// A proxy for additional impls on [`Expr`]. type Expr: ExprProxy; + /// A proxy for additional impls on [`Var`]. type Var: VarProxy; - fn fields() -> Vec; - fn expr(self) -> Self::Expr { - const_(self) + /// The type of the custom data within an [`Expr`]. + type ExprData: Clone + FromNode + 'static; + /// The type of the custom data within an [`Var`]. + type VarData: Clone + FromNode + 'static; + + fn expr(self) -> Expr { + let node = __current_scope(|s| -> NodeRef { + let mut buf = vec![0u8; std::mem::size_of::()]; + unsafe { + std::ptr::copy_nonoverlapping( + &self as *const Self as *const u8, + buf.as_mut_ptr(), + buf.len(), + ); + } + s.const_(Const::Generic(CBoxedSlice::new(buf), Self::type_())) + }); + Expr::::from_node(node) } - fn var(self) -> Self::Var { - local::(self.expr()) + fn var(self) -> Var { + self.expr().var() } } -pub trait ExprProxy: Copy + Aggregate + FromNode { +/// A trait for implementing remote impls on top of an [`Expr`] using [`Deref`]. +/// +/// For example, `Expr<[f32; 4]>` dereferences to `ArrayExpr`, which +/// exposes an [`Index`](std::ops::Index) impl. +pub trait ExprProxy: Clone + HasExprLayout<::ExprData> + 'static { type Value: Value; +} - fn var(self) -> Var { - def(self) - } +/// A trait for implementing remote impls on top of an [`Var`] using [`Deref`]. +/// +/// For example, `Var<[f32; 4]>` dereferences to `ArrayVar`, which +/// exposes [`Index`](std::ops::Index) and [`IndexMut`](std::ops::IndexMut) +/// impls. +pub trait VarProxy: + Clone + HasVarLayout<::VarData> + Deref> + 'static +{ + type Value: Value; +} +/// This marker trait states that `Self` has the same layout as an [`Expr`] +/// with `T::ExprData = X`. +pub unsafe trait HasExprLayout {} +/// This marker trait states that `Self` has the same layout as an [`Var`] +/// with `T::VarData = X`. +pub unsafe trait HasVarLayout {} - fn zeroed() -> Self { - zeroed::() +/// An expression within a [`Kernel`] or [`Callable`]. Created from a raw value +/// using [`Value::expr`]. +/// +/// Note that this does not store the value, and in order to get the result of a +/// function returning an `Expr`, you must call [`Kernel::dispatch`]. +#[derive(Debug, Clone)] +#[repr(C)] +pub struct Expr { + pub(crate) node: NodeRef, + _marker: PhantomData, + /// Custom data stored within the expression. + pub data: T::ExprData, +} + +/// A variable within a [`Kernel`] or [`Callable`]. Created using [`Expr::var`] +/// and [`Value::var`]. +/// +/// Note that setting a `Var` using direct assignment will not work. Instead, +/// either use the [`store`](Var::store) method or the `track!` macro and `*var +/// = expr` syntax. +#[derive(Debug, Clone)] +#[repr(C)] +pub struct Var { + pub(crate) node: NodeRef, + _marker: PhantomData, + /// Custom data stored within the variable. + pub data: T::VarData, +} + +impl Copy for Expr where T::ExprData: Copy {} +impl Copy for Var where T::VarData: Copy {} + +impl Aggregate for Expr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + let node = iter.next().unwrap(); + Self::from_node(node) + } +} +impl FromNode for Expr { + fn from_node(node: NodeRef) -> Self { + Self { + node, + _marker: PhantomData, + data: T::ExprData::from_node(node), + } + } +} +impl ToNode for Expr { + fn node(&self) -> NodeRef { + self.node } } -pub trait VarProxy: Copy + Aggregate + FromNode { - type Value: Value; - fn store>>(&self, value: U) { - let value = value.into(); - super::_store(self, &value); +impl Aggregate for Var { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); } - fn load(&self) -> Expr { - __current_scope(|b| { - let nodes = self.to_vec_nodes(); - let mut ret = vec![]; - for node in nodes { - ret.push(b.call(Func::Load, &[node], node.type_().clone())); - } - Expr::::from_nodes(&mut ret.into_iter()) - }) + fn from_nodes>(iter: &mut I) -> Self { + let node = iter.next().unwrap(); + Self::from_node(node) } - fn get_mut(&self) -> VarDerefProxy { - VarDerefProxy { - var: *self, - dirty: Cell::new(false), - assigned: self.load(), - _phantom: PhantomData, +} +impl FromNode for Var { + fn from_node(node: NodeRef) -> Self { + Self { + node, + _marker: PhantomData, + data: T::VarData::from_node(node), } } - fn _deref<'a>(&'a self) -> &'a Expr { +} +impl ToNode for Var { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Deref for Expr { + type Target = T::Expr; + fn deref(&self) -> &Self::Target { + unsafe { &*(self as *const Self as *const T::Expr) } + } +} +impl Deref for Var { + type Target = T::Var; + fn deref(&self) -> &Self::Target { + unsafe { &*(self as *const Self as *const T::Var) } + } +} + +impl Expr { + pub fn var(self) -> Var { + Var::::from_node(__current_scope(|b| b.local(self.node()))) + } + pub fn zeroed() -> Self { + FromNode::from_node(__current_scope(|b| b.zero_initializer(T::type_()))) + } + pub fn _ref<'a>(self) -> &'a Self { RECORDER.with(|r| { - let v: Expr = self.load(); let r = r.borrow(); - let v: &Expr = r.arena.alloc(v); + let v: &Expr = r.arena.alloc(self); unsafe { - let v: &'a Expr = std::mem::transmute(v); + let v: &'a Expr = std::mem::transmute(v); v } }) } - fn zeroed() -> Self { - local_zeroed::() + pub unsafe fn bitcast(self) -> Expr { + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + let ty = S::type_(); + let node = __current_scope(|s| s.bitcast(self.node(), ty)); + Expr::::from_node(node) } } -pub struct VarDerefProxy -where - P: VarProxy, -{ - pub(crate) var: P, - pub(crate) dirty: Cell, - pub(crate) assigned: Expr, - pub(crate) _phantom: PhantomData, +impl Var { + pub fn zeroed() -> Self { + Self::from_node(__current_scope(|b| { + b.local_zero_init(::type_()) + })) + } + pub fn load(&self) -> Expr { + __current_scope(|b| { + let nodes = self.to_vec_nodes(); + let mut ret = vec![]; + for node in nodes { + ret.push(b.call(Func::Load, &[node], node.type_().clone())); + } + Expr::::from_nodes(&mut ret.into_iter()) + }) + } } -impl Deref for VarDerefProxy -where - P: VarProxy, -{ - type Target = Expr; - fn deref(&self) -> &Self::Target { - &self.assigned - } +pub fn _deref_proxy(proxy: &P) -> &Expr { + unsafe { &*(proxy as *const P as *const Var) } + .load() + ._ref() } -impl DerefMut for VarDerefProxy -where - P: VarProxy, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - self.dirty.set(true); - &mut self.assigned +#[macro_export] +macro_rules! impl_simple_expr_proxy { + ($([ $($bounds:tt)* ])? $name: ident $([ $($qualifiers:tt)* ])? for $t: ty $(where $($where_bounds:tt)+)?) => { + #[derive(Debug, Clone, Copy)] + #[repr(transparent)] + pub struct $name $(< $($bounds)* >)? ($crate::lang::types::Expr<$t>) $(where $($where_bounds)+)?; + unsafe impl $(< $($bounds)* >)? $crate::lang::types::HasExprLayout< <$t as $crate::lang::types::Value>::ExprData > + for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? {} + impl $(< $($bounds)* >)? $crate::lang::types::ExprProxy for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { + type Value = $t; + } } } -impl Drop for VarDerefProxy -where - P: VarProxy, -{ - fn drop(&mut self) { - if self.dirty.get() { - self.var.store(self.assigned) +#[macro_export] +macro_rules! impl_simple_var_proxy { + ($([ $($bounds:tt)* ])? $name: ident $([ $($qualifiers:tt)* ])? for $t: ty $(where $($where_bounds:tt)+)?) => { + #[derive(Debug, Clone, Copy)] + #[repr(transparent)] + pub struct $name $(< $($bounds)* >)? ($crate::lang::types::Var<$t>) $(where $($where_bounds)+)?; + unsafe impl $(< $($bounds)* >)? $crate::lang::types::HasVarLayout< <$t as $crate::lang::types::Value>::VarData > + for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? {} + impl $(< $($bounds)* >)? $crate::lang::types::VarProxy for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { + type Value = $t; + } + impl $(< $($bounds)* >)? std::ops::Deref for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { + type Target = $crate::lang::types::Expr<$t>; + fn deref(&self) -> &Self::Target { + $crate::lang::types::_deref_proxy(self) + } } } } -fn def, T: Value>(init: E) -> Var { - Var::::from_node(__current_scope(|b| b.local(init.node()))) +#[macro_export] +macro_rules! impl_marker_trait { + ($T:ident for $($t:ty),*) => { + $(impl $T for $t {})* + }; +} + +mod private { + use super::*; + + pub trait Sealed {} + impl Sealed for T {} + impl Sealed for Expr {} + impl Sealed for &Expr {} + impl Sealed for Var {} + impl Sealed for &Var {} + + impl Sealed for ValueType {} + impl Sealed for ExprType {} + impl Sealed for VarType {} } -fn local(init: Expr) -> Var { - Var::::from_node(__current_scope(|b| b.local(init.node()))) + +pub trait Tracked: private::Sealed { + type Type: TrackingType; + type Value: Value; } -fn local_zeroed() -> Var { - Var::::from_node(__current_scope(|b| { - b.local_zero_init(::type_()) - })) +pub trait TrackingType: private::Sealed {} +pub struct ValueType; +impl TrackingType for ValueType {} +pub struct ExprType; +impl TrackingType for ExprType {} +pub struct VarType; +impl TrackingType for VarType {} + +impl Tracked for T { + type Type = ValueType; + type Value = T; +} +impl Tracked for Expr { + type Type = ExprType; + type Value = T; +} +impl Tracked for &Expr { + type Type = ExprType; + type Value = T; +} +impl Tracked for Var { + type Type = VarType; + type Value = T; +} +impl Tracked for &Var { + type Type = VarType; + type Value = T; } -fn zeroed() -> T::Expr { - FromNode::from_node(__current_scope(|b| b.zero_initializer(T::type_()))) +pub trait AsExpr: Tracked { + fn as_expr(&self) -> Expr; } -fn const_(value: T) -> T::Expr { - let node = __current_scope(|s| -> NodeRef { - let any = &value as &dyn Any; - if let Some(value) = any.downcast_ref::() { - s.const_(Const::Bool(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Int32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Uint32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Int64(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Uint64(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Float32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Float64(*value)) - } else { - let mut buf = vec![0u8; std::mem::size_of::()]; - unsafe { - std::ptr::copy_nonoverlapping( - &value as *const T as *const u8, - buf.as_mut_ptr(), - buf.len(), - ); - } - s.const_(Const::Generic(CBoxedSlice::new(buf), T::type_())) - } - }); - FromNode::from_node(node) +impl AsExpr for T { + fn as_expr(&self) -> Expr { + self.expr() + } +} +impl AsExpr for Expr { + fn as_expr(&self) -> Expr { + self.clone() + } +} +impl AsExpr for &Expr { + fn as_expr(&self) -> Expr { + (*self).clone() + } +} +impl AsExpr for Var { + fn as_expr(&self) -> Expr { + self.load() + } +} +impl AsExpr for &Var { + fn as_expr(&self) -> Expr { + self.load() + } } diff --git a/luisa_compute/src/lang/types/alignment.rs b/luisa_compute/src/lang/types/alignment.rs new file mode 100644 index 0000000..2b6977d --- /dev/null +++ b/luisa_compute/src/lang/types/alignment.rs @@ -0,0 +1,27 @@ +use super::*; +use std::hash::Hash; + +pub(crate) trait Alignment: Default + Copy + Hash + Eq + 'static { + const ALIGNMENT: usize; +} + +macro_rules! alignment { + ($T:ident, $align:literal) => { + #[derive(Copy, Clone, Debug, Hash, Default, PartialEq, Eq)] + #[repr(align($align))] + pub struct $T; + impl Alignment for $T { + const ALIGNMENT: usize = $align; + } + }; +} + +alignment!(Align1, 1); +alignment!(Align2, 2); +alignment!(Align4, 4); +alignment!(Align8, 8); +alignment!(Align16, 16); +alignment!(Align32, 32); +alignment!(Align64, 64); +alignment!(Align128, 128); +alignment!(Align256, 256); diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index e367966..04c6029 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -1,140 +1,37 @@ +use std::ops::Index; + use super::*; use crate::lang::index::IntoIndex; use ir::ArrayType; -#[derive(Clone, Copy, Debug)] -pub struct ArrayExpr { - marker: PhantomData, - node: NodeRef, -} - -#[derive(Clone, Copy, Debug)] -pub struct ArrayVar { - marker: PhantomData, - node: NodeRef, -} - -impl FromNode for ArrayExpr { - fn from_node(node: NodeRef) -> Self { - Self { - marker: PhantomData, - node, - } - } -} - -impl ToNode for ArrayExpr { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for ArrayExpr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl FromNode for ArrayVar { - fn from_node(node: NodeRef) -> Self { - Self { - marker: PhantomData, - node, - } - } -} - -impl ToNode for ArrayVar { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for ArrayVar { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl ExprProxy for ArrayExpr { - type Value = [T; N]; +impl Value for [T; N] { + type Expr = ArrayExpr; + type Var = ArrayVar; + type ExprData = (); + type VarData = (); } -impl VarProxy for ArrayVar { - type Value = [T; N]; -} - -impl ArrayVar { - pub fn len(&self) -> Expr { - (N as u32).expr() - } -} +impl_simple_expr_proxy!([T: Value, const N: usize] ArrayExpr[T, N] for [T; N]); +impl_simple_var_proxy!([T: Value, const N: usize] ArrayVar[T, N] for [T; N]); impl ArrayExpr { - pub fn zero() -> Self { - let node = __current_scope(|b| b.call(Func::ZeroInitializer, &[], <[T; N]>::type_())); - Self::from_node(node) - } pub fn len(&self) -> Expr { (N as u32).expr() } } -impl IndexRead for ArrayExpr { - type Element = T; - fn read(&self, i: I) -> Expr { +impl Index for ArrayExpr { + type Output = Expr; + fn index(&self, i: X) -> &Self::Output { let i = i.to_u64(); - lc_assert!(i.cmplt((N as u64).expr())); - - Expr::::from_node(__current_scope(|b| { - b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) - })) - } -} - -impl IndexRead for ArrayVar { - type Element = T; - fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - if need_runtime_check() { - lc_assert!(i.cmplt((N as u64).expr())); - } + // TODO: Add need_runtime_check()? + lc_assert!(i.lt((N as u64).expr())); Expr::::from_node(__current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.call(Func::Load, &[gep], T::type_()) + b.call(Func::ExtractElement, &[self.0.node, i.node()], T::type_()) })) - } -} - -impl IndexWrite for ArrayVar { - fn write>>(&self, i: I, value: V) { - let i = i.to_u64(); - let value = value.into(); - - if need_runtime_check() { - lc_assert!(i.cmplt((N as u64).expr())); - } - - __current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.update(gep, value.node()); - }); - } -} - -impl Value for [T; N] { - type Expr = ArrayExpr; - type Var = ArrayVar; - fn fields() -> Vec { - todo!("why this method exists?") + ._ref() } } @@ -202,7 +99,7 @@ impl VLArrayVar { pub fn read>>(&self, i: I) -> Expr { let i = i.into(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); } Expr::::from_node(__current_scope(|b| { @@ -227,7 +124,7 @@ impl VLArrayVar { let value = value.into(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); } __current_scope(|b| { @@ -278,7 +175,7 @@ impl VLArrayExpr { pub fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::::from_node(__current_scope(|b| { diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index 9898c3e..6fd624a 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -1,163 +1,239 @@ use super::*; -use std::ops::Deref; -// This is a hack in order to get rust-analyzer to display type hints as Expr -// instead of Expr, which is rather redundant and generally clutters things up. -pub(crate) mod prim { +mod private { use super::*; + pub trait Sealed {} + impl Sealed for bool {} + impl Sealed for f16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} +} + +pub trait Primitive: private::Sealed + Copy + TypeOf + 'static { + fn const_(&self) -> Const; + fn primitive() -> ir::Primitive; +} +impl Value for T { + type Expr = PrimitiveExpr; + type Var = PrimitiveVar; + type ExprData = (); + type VarData = (); + + fn expr(self) -> Expr { + let node = __current_scope(|s| -> NodeRef { s.const_(self.const_()) }); + Expr::::from_node(node) + } +} - #[derive(Clone, Copy, Debug)] - pub struct Expr { - pub(crate) node: NodeRef, - pub(crate) _phantom: PhantomData, - } - - #[derive(Clone, Copy, Debug)] - pub struct Var { - pub(crate) node: NodeRef, - pub(crate) _phantom: PhantomData, - } -} - -impl Aggregate for prim::Expr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - _phantom: PhantomData, - } - } -} - -impl Aggregate for prim::Var { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - _phantom: PhantomData, - } - } -} - -impl FromNode for prim::Expr { - fn from_node(node: NodeRef) -> Self { - Self { - node, - _phantom: PhantomData, - } - } -} -impl ToNode for prim::Expr { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Deref for prim::Var -where - prim::Var: VarProxy, -{ - type Target = T::Expr; - fn deref(&self) -> &Self::Target { - self._deref() - } -} - -macro_rules! impl_prim { - ($t:ty) => { - impl From<$t> for prim::Expr<$t> { - fn from(v: $t) -> Self { - (v).expr() - } - } - impl From> for prim::Expr<$t> { - fn from(v: Var<$t>) -> Self { - v.load() - } - } - impl FromNode for prim::Var<$t> { - fn from_node(node: NodeRef) -> Self { - Self { - node, - _phantom: PhantomData, - } - } - } - impl ToNode for prim::Var<$t> { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for prim::Expr<$t> { - type Value = $t; - } - impl VarProxy for prim::Var<$t> { - type Value = $t; - } - impl Value for $t { - type Expr = prim::Expr<$t>; - type Var = prim::Var<$t>; - fn fields() -> Vec { - vec![] - } - } - impl_callable_param!($t, prim::Expr<$t>, prim::Var<$t>); +impl_simple_expr_proxy!([T: Primitive] PrimitiveExpr[T] for T); +impl_simple_var_proxy!([T: Primitive] PrimitiveVar[T] for T); + +impl Primitive for bool { + fn const_(&self) -> Const { + Const::Bool(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Bool + } +} + +impl Primitive for f16 { + fn const_(&self) -> Const { + Const::Float16(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Float16 + } +} +impl Primitive for f32 { + fn const_(&self) -> Const { + Const::Float32(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Float32 + } +} +impl Primitive for f64 { + fn const_(&self) -> Const { + Const::Float64(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Float64 + } +} + +impl Primitive for i8 { + fn const_(&self) -> Const { + Const::Int8(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Int8 + } +} +impl Primitive for i16 { + fn const_(&self) -> Const { + Const::Int16(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Int16 + } +} +impl Primitive for i32 { + fn const_(&self) -> Const { + Const::Int32(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Int32 + } +} +impl Primitive for i64 { + fn const_(&self) -> Const { + Const::Int64(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Int64 + } +} + +impl Primitive for u8 { + fn const_(&self) -> Const { + Const::Uint8(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Uint8 + } +} +impl Primitive for u16 { + fn const_(&self) -> Const { + Const::Uint16(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Uint16 + } +} +impl Primitive for u32 { + fn const_(&self) -> Const { + Const::Uint32(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Uint32 + } +} +impl Primitive for u64 { + fn const_(&self) -> Const { + Const::Uint64(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Uint64 + } +} + +macro_rules! impls { + ($T:ident for $($t:ty),*) => { + $(impl $T for $t {})* }; } -impl_prim!(bool); -impl_prim!(u32); -impl_prim!(u64); -impl_prim!(i32); -impl_prim!(i64); -impl_prim!(i16); -impl_prim!(u16); -impl_prim!(f16); -impl_prim!(f32); -impl_prim!(f64); - -pub type Bool = prim::Expr; -pub type F16 = prim::Expr; -pub type F32 = prim::Expr; -pub type F64 = prim::Expr; -pub type I16 = prim::Expr; -pub type I32 = prim::Expr; -pub type I64 = prim::Expr; -pub type U16 = prim::Expr; -pub type U32 = prim::Expr; -pub type U64 = prim::Expr; - -pub type F16Var = prim::Var; -pub type F32Var = prim::Var; -pub type F64Var = prim::Var; -pub type I16Var = prim::Var; -pub type I32Var = prim::Var; -pub type I64Var = prim::Var; -pub type U16Var = prim::Var; -pub type U32Var = prim::Var; -pub type U64Var = prim::Var; - -pub type Half = prim::Expr; -pub type Float = prim::Expr; -pub type Double = prim::Expr; -pub type Int = prim::Expr; -pub type Long = prim::Expr; -pub type Uint = prim::Expr; -pub type Ulong = prim::Expr; -pub type Short = prim::Expr; -pub type Ushort = prim::Expr; - -pub type BoolVar = prim::Var; -pub type HalfVar = prim::Var; -pub type FloatVar = prim::Var; -pub type DoubleVar = prim::Var; -pub type IntVar = prim::Var; -pub type LongVar = prim::Var; -pub type UintVar = prim::Var; -pub type UlongVar = prim::Var; -pub type ShortVar = prim::Var; -pub type UshortVar = prim::Var; +pub trait Integral: Primitive {} +impls!(Integral for bool, i8, i16, i32, i64, u8, u16, u32, u64); + +pub trait Numeric: Primitive {} +impls!(Numeric for f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); + +pub trait Floating: Numeric {} +impls!(Floating for f16, f32, f64); + +pub trait Signed: Numeric {} +impls!(Signed for f16, f32, f64, i8, i16, i32, i64); + +mod legacy { + use super::*; + + #[deprecated] + pub type Bool = Expr; + #[deprecated] + pub type F16 = Expr; + #[deprecated] + pub type F32 = Expr; + #[deprecated] + pub type F64 = Expr; + #[deprecated] + pub type I16 = Expr; + #[deprecated] + pub type I32 = Expr; + #[deprecated] + pub type I64 = Expr; + #[deprecated] + pub type U16 = Expr; + #[deprecated] + pub type U32 = Expr; + #[deprecated] + pub type U64 = Expr; + + #[deprecated] + pub type F16Var = Var; + #[deprecated] + pub type F32Var = Var; + #[deprecated] + pub type F64Var = Var; + #[deprecated] + pub type I16Var = Var; + #[deprecated] + pub type I32Var = Var; + #[deprecated] + pub type I64Var = Var; + #[deprecated] + pub type U16Var = Var; + #[deprecated] + pub type U32Var = Var; + #[deprecated] + pub type U64Var = Var; + + #[deprecated] + pub type Half = Expr; + #[deprecated] + pub type Float = Expr; + #[deprecated] + pub type Double = Expr; + #[deprecated] + pub type Int = Expr; + #[deprecated] + pub type Long = Expr; + #[deprecated] + pub type Uint = Expr; + #[deprecated] + pub type Ulong = Expr; + #[deprecated] + pub type Short = Expr; + #[deprecated] + pub type Ushort = Expr; + + #[deprecated] + pub type BoolVar = Var; + #[deprecated] + pub type HalfVar = Var; + #[deprecated] + pub type FloatVar = Var; + #[deprecated] + pub type DoubleVar = Var; + #[deprecated] + pub type IntVar = Var; + #[deprecated] + pub type LongVar = Var; + #[deprecated] + pub type UintVar = Var; + #[deprecated] + pub type UlongVar = Var; + #[deprecated] + pub type ShortVar = Var; + #[deprecated] + pub type UshortVar = Var; +} diff --git a/luisa_compute/src/lang/types/dynamic.rs b/luisa_compute/src/lang/types/dynamic.rs index 1744779..595520d 100644 --- a/luisa_compute/src/lang/types/dynamic.rs +++ b/luisa_compute/src/lang/types/dynamic.rs @@ -9,14 +9,14 @@ pub struct DynExpr { node: NodeRef, } -impl From for DynExpr { - fn from(value: T) -> Self { +impl From> for DynExpr { + fn from(value: Expr) -> Self { Self { node: value.node() } } } -impl From for DynVar { - fn from(value: T) -> Self { +impl From> for DynVar { + fn from(value: Var) -> Self { Self { node: value.node() } } } @@ -62,7 +62,7 @@ impl DynExpr { ) }) } - pub fn new(expr: E) -> Self { + pub fn new(expr: Expr) -> Self { Self { node: expr.node() } } } @@ -204,7 +204,7 @@ impl DynVar { __current_scope(|b| b.update(self.node, value.node)); } pub fn zero() -> Self { - let v = local_zeroed::(); + let v = Var::::zeroed(); Self { node: v.node() } } } diff --git a/luisa_compute/src/lang/types/shared.rs b/luisa_compute/src/lang/types/shared.rs index a374d69..0b14635 100644 --- a/luisa_compute/src/lang/types/shared.rs +++ b/luisa_compute/src/lang/types/shared.rs @@ -47,7 +47,7 @@ impl Shared { let value = value.into(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); } __current_scope(|b| { diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index a1505ea..1381f18 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -1,1825 +1,377 @@ +use super::alignment::*; use super::core::*; use super::*; -use ir::{MatrixType, Primitive, VectorElementType, VectorType}; -use serde::{Deserialize, Serialize}; -use std::ops::Mul; +use ir::{VectorElementType, VectorType}; +use std::fmt::Debug; -macro_rules! def_vec { - ($name:ident, $glam_type:ident, $scalar:ty, $align:literal, $($comp:ident), *) => { - #[repr(C, align($align))] - #[derive(Copy, Clone, Debug, Default, PartialEq, Serialize, Deserialize)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub const fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub const fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - impl From<$name> for glam::$glam_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From for $name { - #[inline] - fn from(v: glam::$glam_type) -> Self { - Self::new($(v.$comp), *) - } - } - }; -} -macro_rules! def_packed_vec { - ($name:ident, $vec_type:ident, $glam_type:ident, $scalar:ty, $($comp:ident), *) => { - #[repr(C)] - #[derive(Copy, Clone, Debug, Default, Value, PartialEq, Serialize, Deserialize)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub const fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub const fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - impl From<$name> for glam::$glam_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From for $name { - #[inline] - fn from(v: glam::$glam_type) -> Self { - Self::new($(v.$comp), *) - } - } - impl From<$name> for $vec_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From<$vec_type> for $name { - #[inline] - fn from(v: $vec_type) -> Self { - Self::new($(v.$comp), *) - } - } - }; -} -macro_rules! def_packed_vec_no_glam { - ($name:ident, $vec_type:ident, $scalar:ty, $($comp:ident), *) => { - #[repr(C)] - #[derive(Copy, Clone, Debug, Default, Value)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub const fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub const fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - impl From<$name> for $vec_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From<$vec_type> for $name { - #[inline] - fn from(v: $vec_type) -> Self { - Self::new($(v.$comp), *) - } - } - }; -} -macro_rules! def_vec_no_glam { - ($name:ident, $scalar:ty, $align:literal, $($comp:ident), *) => { - #[repr(C, align($align))] - #[derive(Copy, Clone, Debug, Default)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - }; -} -def_vec!(Float2, Vec2, f32, 8, x, y); -def_vec!(Float3, Vec3, f32, 16, x, y, z); -def_vec!(Float4, Vec4, f32, 16, x, y, z, w); - -def_packed_vec!(PackedFloat2, Float2, Vec2, f32, x, y); -def_packed_vec!(PackedFloat3, Float3, Vec3, f32, x, y, z); -def_packed_vec!(PackedFloat4, Float4, Vec4, f32, x, y, z, w); - -def_vec!(Uint2, UVec2, u32, 8, x, y); -def_vec!(Uint3, UVec3, u32, 16, x, y, z); -def_vec!(Uint4, UVec4, u32, 16, x, y, z, w); - -def_packed_vec!(PackedUint2, Uint2, UVec2, u32, x, y); -def_packed_vec!(PackedUint3, Uint3, UVec3, u32, x, y, z); -def_packed_vec!(PackedUint4, Uint4, UVec4, u32, x, y, z, w); - -def_vec!(Int2, IVec2, i32, 8, x, y); -def_vec!(Int3, IVec3, i32, 16, x, y, z); -def_vec!(Int4, IVec4, i32, 16, x, y, z, w); - -def_packed_vec!(PackedInt2, Int2, IVec2, i32, x, y); -def_packed_vec!(PackedInt3, Int3, IVec3, i32, x, y, z); -def_packed_vec!(PackedInt4, Int4, IVec4, i32, x, y, z, w); - -def_vec!(Double2, DVec2, f64, 16, x, y); -def_vec!(Double3, DVec3, f64, 32, x, y, z); -def_vec!(Double4, DVec4, f64, 32, x, y, z, w); - -def_vec!(Bool2, BVec2, bool, 2, x, y); -def_vec!(Bool3, BVec3, bool, 4, x, y, z); -def_vec!(Bool4, BVec4, bool, 4, x, y, z, w); - -def_packed_vec!(PackedBool2, Bool2, BVec2, bool, x, y); -def_packed_vec!(PackedBool3, Bool3, BVec3, bool, x, y, z); -def_packed_vec!(PackedBool4, Bool4, BVec4, bool, x, y, z, w); - -def_vec_no_glam!(Ulong2, u64, 16, x, y); -def_vec_no_glam!(Ulong3, u64, 32, x, y, z); -def_vec_no_glam!(Ulong4, u64, 32, x, y, z, w); - -def_packed_vec_no_glam!(PackedUlong2, Ulong2, u64, x, y); -def_packed_vec_no_glam!(PackedUlong3, Ulong3, u64, x, y, z); -def_packed_vec_no_glam!(PackedUlong4, Ulong4, u64, x, y, z, w); - -def_vec_no_glam!(Long2, i64, 16, x, y); -def_vec_no_glam!(Long3, i64, 32, x, y, z); -def_vec_no_glam!(Long4, i64, 32, x, y, z, w); - -def_packed_vec_no_glam!(PackedLong2, Long2, i64, x, y); -def_packed_vec_no_glam!(PackedLong3, Long3, i64, x, y, z); -def_packed_vec_no_glam!(PackedLong4, Long4, i64, x, y, z, w); - -def_vec_no_glam!(Ushort2, u16, 4, x, y); -def_vec_no_glam!(Ushort3, u16, 8, x, y, z); -def_vec_no_glam!(Ushort4, u16, 8, x, y, z, w); - -def_packed_vec_no_glam!(PackedUshort2, Ushort2, u16, x, y); -def_packed_vec_no_glam!(PackedUshort3, Ushort3, u16, x, y, z); -def_packed_vec_no_glam!(PackedUshort4, Ushort4, u16, x, y, z, w); - -def_vec_no_glam!(Short2, i16, 4, x, y); -def_vec_no_glam!(Short3, i16, 8, x, y, z); -def_vec_no_glam!(Short4, i16, 8, x, y, z, w); +#[cfg(feature = "glam")] +mod glam; +#[cfg(feature = "nalgebra")] +mod nalgebra; -def_packed_vec_no_glam!(PackedShort2, Short2, i16, x, y); -def_packed_vec_no_glam!(PackedShort3, Short3, i16, x, y, z); -def_packed_vec_no_glam!(PackedShort4, Short4, i16, x, y, z, w); +pub mod coords; +mod element; +mod impls; +pub mod legacy; +pub mod swizzle; -def_vec_no_glam!(Half2, f16, 4, x, y); -def_vec_no_glam!(Half3, f16, 8, x, y, z); -def_vec_no_glam!(Half4, f16, 8, x, y, z, w); +use swizzle::*; -// def_packed_vec_no_glam!(PackedHalf2, f16, x, y); -// def_packed_vec_no_glam!(PackedHalf3, f16, x, y, z); -// pub type PackHalf4 = Half4; +pub trait VectorElement: VectorAlign<2> + VectorAlign<3> + VectorAlign<4> {} +impl + VectorAlign<3> + VectorAlign<4>> VectorElement for T {} -def_vec_no_glam!(Ubyte2, u8, 2, x, y); -def_vec_no_glam!(Ubyte3, u8, 4, x, y, z); -def_vec_no_glam!(Ubyte4, u8, 4, x, y, z, w); - -// def_packed_vec_no_glam!(PackedUbyte2, u8, x, y); -// def_packed_vec_no_glam!(PackedUbyte3, u8, x, y, z); -// pub type PackUbyte4 = Ubyte4; - -def_vec_no_glam!(Byte2, u8, 2, x, y); -def_vec_no_glam!(Byte3, u8, 4, x, y, z); -def_vec_no_glam!(Byte4, u8, 4, x, y, z, w); - -// def_packed_vec_no_glam!(PackedByte2, u8, x, y); -// def_packed_vec_no_glam!(PackedByte3, u8, x, y, z); -// pub type PackByte4 = Byte4; - -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -#[repr(C, align(8))] -pub struct Mat2 { - pub cols: [Float2; 2], +pub trait VectorAlign: Primitive { + type A: Alignment; + type VectorExpr: ExprProxy>; + type VectorVar: VarProxy>; + type VectorExprData: Clone + FromNode + 'static; + type VectorVarData: Clone + FromNode + 'static; } -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -#[repr(C, align(16))] -pub struct Mat3 { - pub cols: [Float3; 3], -} -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -#[repr(C, align(16))] -pub struct Mat4 { - pub cols: [Float4; 4], -} -impl Mat2 { - pub const fn from_cols(c0: Float2, c1: Float2) -> Self { - Self { cols: [c0, c1] } - } - pub const fn identity() -> Self { - Self::from_cols(Float2::new(1.0, 0.0), Float2::new(0.0, 1.0)) - } -} -impl Mat3 { - pub const fn from_cols(c0: Float3, c1: Float3, c2: Float3) -> Self { - Self { cols: [c0, c1, c2] } - } - pub const fn identity() -> Self { - Self::from_cols( - Float3::new(1.0, 0.0, 0.0), - Float3::new(0.0, 1.0, 0.0), - Float3::new(0.0, 0.0, 1.0), - ) - } -} -impl Mat4 { - pub const fn from_cols(c0: Float4, c1: Float4, c2: Float4, c3: Float4) -> Self { - Self { - cols: [c0, c1, c2, c3], - } - } - pub const fn identity() -> Self { - Self::from_cols( - Float4::new(1.0, 0.0, 0.0, 0.0), - Float4::new(0.0, 1.0, 0.0, 0.0), - Float4::new(0.0, 0.0, 1.0, 0.0), - Float4::new(0.0, 0.0, 0.0, 1.0), - ) - } - pub fn into_affine3x4(self) -> [f32; 12] { - // [ - // self.cols[0].x, - // self.cols[0].y, - // self.cols[0].z, - // self.cols[1].x, - // self.cols[1].y, - // self.cols[1].z, - // self.cols[2].x, - // self.cols[2].y, - // self.cols[2].z, - // self.cols[3].x, - // self.cols[3].y, - // self.cols[3].z, - // ] - [ - self.cols[0].x, - self.cols[1].x, - self.cols[2].x, - self.cols[3].x, - self.cols[0].y, - self.cols[1].y, - self.cols[2].y, - self.cols[3].y, - self.cols[0].z, - self.cols[1].z, - self.cols[2].z, - self.cols[3].z, - ] + +impl, const N: usize> Debug for Vector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.elements.fmt(f) } } -impl From for glam::Mat2 { - #[inline] - fn from(m: Mat2) -> Self { - Self::from_cols(m.cols[0].into(), m.cols[1].into()) - } + +#[repr(C)] +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct Vector, const N: usize> { + _align: T::A, + pub elements: [T; N], } -impl From for glam::Mat3 { - #[inline] - fn from(m: Mat3) -> Self { - Self::from_cols(m.cols[0].into(), m.cols[1].into(), m.cols[2].into()) + +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct VectorExprData, const N: usize>([Expr; N]); +impl, const N: usize> FromNode for VectorExprData { + fn from_node(node: NodeRef) -> Self { + Self(std::array::from_fn(|i| { + FromNode::from_node(__extract::(node, i)) + })) } } -impl From for glam::Mat4 { - #[inline] - fn from(m: Mat4) -> Self { - Self::from_cols( - m.cols[0].into(), - m.cols[1].into(), - m.cols[2].into(), - m.cols[3].into(), - ) +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct VectorVarData, const N: usize>([Var; N]); +impl, const N: usize> FromNode for VectorVarData { + fn from_node(node: NodeRef) -> Self { + Self(std::array::from_fn(|i| { + FromNode::from_node(__extract::(node, i)) + })) } } -impl From for Mat2 { - #[inline] - fn from(m: glam::Mat2) -> Self { - Self { - cols: [m.x_axis.into(), m.y_axis.into()], - } +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DoubledProxyData(X, X); +impl FromNode for DoubledProxyData { + fn from_node(node: NodeRef) -> Self { + Self(X::from_node(node), X::from_node(node)) + } +} + +pub trait VectorExprProxy { + const N: usize; + type T: Primitive; + fn node(&self) -> NodeRef; + fn _permute2(&self, x: u32, y: u32) -> Expr> + where + Self::T: VectorAlign<2>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node()], + Vec2::::type_(), + ) + })) } -} -impl From for Mat3 { - #[inline] - fn from(m: glam::Mat3) -> Self { - Self { - cols: [m.x_axis.into(), m.y_axis.into(), m.z_axis.into()], - } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> + where + Self::T: VectorAlign<3>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + assert!(z < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node(), z.node()], + Vec3::::type_(), + ) + })) } -} -impl From for Mat4 { - #[inline] - fn from(m: glam::Mat4) -> Self { - Self { - cols: [ - m.x_axis.into(), - m.y_axis.into(), - m.z_axis.into(), - m.w_axis.into(), - ], - } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> + where + Self::T: VectorAlign<4>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + assert!(z < Self::N as u32); + assert!(w < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + let w = w.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node(), z.node(), w.node()], + Vec4::::type_(), + ) + })) } } -macro_rules! impl_proxy_fields { - ($vec:ident, $proxy:ident, $scalar:ty, x) => { - impl $proxy { - #[inline] - pub fn x(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 0)) - } - #[inline] - pub fn set_x(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 0, ToNode::node(&value))) - } - } - }; - ($vec:ident,$proxy:ident, $scalar:ty, y) => { - impl $proxy { - #[inline] - pub fn y(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 1)) - } - #[inline] - pub fn set_y(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 1, ToNode::node(&value))) - } - } - }; - ($vec:ident,$proxy:ident, $scalar:ty, z) => { - impl $proxy { - #[inline] - pub fn z(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 2)) - } - #[inline] - pub fn set_z(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 2, ToNode::node(&value))) - } - } - }; - ($vec:ident,$proxy:ident, $scalar:ty, w) => { - impl $proxy { - #[inline] - pub fn w(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 3)) - } - #[inline] - pub fn set_w(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 3, ToNode::node(&value))) - } - } - }; -} -macro_rules! impl_var_proxy_fields { - ($proxy:ident, $scalar:ty, x) => { - impl $proxy { - #[inline] - pub fn x(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 0)) - } - } - }; - ($proxy:ident, $scalar:ty, y) => { - impl $proxy { - #[inline] - pub fn y(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 1)) - } - } - }; - ($proxy:ident, $scalar:ty, z) => { - impl $proxy { - #[inline] - pub fn z(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 2)) - } - } - }; - ($proxy:ident, $scalar:ty, w) => { - impl $proxy { - #[inline] - pub fn w(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 3)) - } - } - }; -} -macro_rules! impl_vec_proxy { - ($vec:ident, $expr_proxy:ident, $var_proxy:ident, $scalar:ty, $scalar_ty:ident, $length:literal, $($comp:ident), *) => { - #[derive(Clone, Copy)] - pub struct $expr_proxy { - node: NodeRef, - } - #[derive(Clone, Copy)] - pub struct $var_proxy { - node: NodeRef, - } - impl Value for $vec { - type Expr = $expr_proxy; - type Var = $var_proxy; - fn fields() -> Vec { - vec![$(stringify!($comp).to_string()),*] - } - } - impl TypeOf for $vec { - fn type_() -> luisa_compute_ir::CArc { - let type_ = Type::Vector(VectorType { - element: VectorElementType::Scalar(Primitive::$scalar_ty), - length: $length, - }); - register_type(type_) - } - } - impl Aggregate for $expr_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl VectorVarTrait for $expr_proxy { } - impl ScalarOrVector for $expr_proxy { - type Element = prim::Expr<$scalar>; - type ElementHost = $scalar; - } - impl BuiltinVarTrait for $expr_proxy { } - impl Aggregate for $var_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl FromNode for $expr_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $expr_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl FromNode for $var_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $var_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for $expr_proxy { - type Value = $vec; - } - impl VarProxy for $var_proxy { - type Value = $vec; - } - impl std::ops::Deref for $var_proxy { - type Target = $expr_proxy; - fn deref(&self) -> &Self::Target { - self._deref() - } - } - impl From<$var_proxy> for $expr_proxy { - fn from(var: $var_proxy) -> Self { - var.load() - } - } - impl_callable_param!($vec, $expr_proxy, $var_proxy); - $(impl_proxy_fields!($vec, $expr_proxy, $scalar, $comp);)* - $(impl_var_proxy_fields!($var_proxy, $scalar, $comp);)* - impl $expr_proxy { - #[inline] - pub fn new($($comp: prim::Expr<$scalar>), *) -> Self { - Self { - node: __compose::<$vec>(&[$(ToNode::node(&$comp)), *]), - } - } - pub fn at(&self, index: usize) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, index)) - } +macro_rules! vector_proxies { + ($N:literal [ $($c:ident),* ]: $ExprName:ident, $VarName:ident) => { + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct $ExprName> { + _node: NodeRef, + $(pub $c: Expr),* } - impl $vec { - #[inline] - pub fn expr($($comp: impl Into>), *) -> $expr_proxy { - $expr_proxy::new($($comp.into()), *) - } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct $VarName> { + _node: NodeRef, + $(pub $c: Var),* } - }; -} -macro_rules! impl_mat_proxy { - ($mat:ident, $expr_proxy:ident, $var_proxy:ident, $vec:ty, $scalar_ty:ident, $length:literal, $($comp:ident), *) => { - #[derive(Clone, Copy)] - pub struct $expr_proxy { - node: NodeRef, - } - #[derive(Clone, Copy)] - pub struct $var_proxy { - node: NodeRef, - } - impl Value for $mat { - type Expr = $expr_proxy; - type Var = $var_proxy; - fn fields() -> Vec { - vec![$(stringify!($comp).to_string()),*] - } - } - impl TypeOf for $mat { - fn type_() -> luisa_compute_ir::CArc { - let type_ = Type::Matrix(MatrixType { - element: VectorElementType::Scalar(Primitive::$scalar_ty), - dimension: $length, - }); - register_type(type_) - } - } - impl Aggregate for $expr_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl MatrixVarTrait for $expr_proxy { } - impl BuiltinVarTrait for $expr_proxy { } - impl Aggregate for $var_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl FromNode for $expr_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $expr_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for $expr_proxy { - type Value = $mat; - } - impl FromNode for $var_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } + unsafe impl> HasExprLayout< as Value>::ExprData> for $ExprName {} + unsafe impl> HasVarLayout< as Value>::VarData> for $VarName {} + + impl>> ExprProxy for $ExprName { + type Value = Vector; } - impl ToNode for $var_proxy { + impl> VectorExprProxy for $ExprName { + const N: usize = $N; + type T = T; fn node(&self) -> NodeRef { - self.node + self._node } } - impl VarProxy for $var_proxy { - type Value = $mat; + impl>> VarProxy for $VarName { + type Value = Vector; } - impl std::ops::Deref for $var_proxy { - type Target = $expr_proxy; + impl>> Deref for $VarName { + type Target = Expr>; fn deref(&self) -> &Self::Target { - self._deref() - } - } - impl From<$var_proxy> for $expr_proxy { - fn from(var: $var_proxy) -> Self { - var.load() - } - } - impl_callable_param!($mat, $expr_proxy, $var_proxy); - impl $expr_proxy { - #[inline] - pub fn new($($comp: Expr<$vec>), *) -> Self { - Self { - node: __compose::<$mat>(&[$(ToNode::node(&$comp)), *]), - } - } - pub fn col(&self, index: usize) -> Expr<$vec> { - Expr::<$vec>::from_node(__extract::<$vec>(self.node, index)) - } - } - impl $mat { - #[inline] - pub fn expr($($comp: impl Into>), *) -> $expr_proxy { - $expr_proxy::new($($comp.into()), *) - } - } - }; -} - -impl_vec_proxy!(Bool2, Bool2Expr, Bool2Var, bool, Bool, 2, x, y); -impl_vec_proxy!(Bool3, Bool3Expr, Bool3Var, bool, Bool, 3, x, y, z); -impl_vec_proxy!(Bool4, Bool4Expr, Bool4Var, bool, Bool, 4, x, y, z, w); - -impl_vec_proxy!(Half2, Half2Expr, Half2Var, f16, Float16, 2, x, y); -impl_vec_proxy!(Half3, Half3Expr, Half3Var, f16, Float16, 3, x, y, z); -impl_vec_proxy!(Half4, Half4Expr, Half4Var, f16, Float16, 4, x, y, z, w); - -impl_vec_proxy!(Float2, Float2Expr, Float2Var, f32, Float32, 2, x, y); -impl_vec_proxy!(Float3, Float3Expr, Float3Var, f32, Float32, 3, x, y, z); -impl_vec_proxy!(Float4, Float4Expr, Float4Var, f32, Float32, 4, x, y, z, w); - -impl_vec_proxy!(Double2, Double2Expr, Double2Var, f64, Float64, 2, x, y); -impl_vec_proxy!(Double3, Double3Expr, Double3Var, f64, Float64, 3, x, y, z); -impl_vec_proxy!( - Double4, - Double4Expr, - Double4Var, - f64, - Float64, - 4, - x, - y, - z, - w -); - -impl_vec_proxy!(Ushort2, Ushort2Expr, Ushort2Var, u16, Uint16, 2, x, y); -impl_vec_proxy!(Ushort3, Ushort3Expr, Ushort3Var, u16, Uint16, 3, x, y, z); -impl_vec_proxy!(Ushort4, Ushort4Expr, Ushort4Var, u16, Uint16, 4, x, y, z, w); - -impl_vec_proxy!(Short2, Short2Expr, Short2Var, i16, Int16, 2, x, y); -impl_vec_proxy!(Short3, Short3Expr, Short3Var, i16, Int16, 3, x, y, z); -impl_vec_proxy!(Short4, Short4Expr, Short4Var, i16, Int16, 4, x, y, z, w); - -impl_vec_proxy!(Uint2, Uint2Expr, Uint2Var, u32, Uint32, 2, x, y); -impl_vec_proxy!(Uint3, Uint3Expr, Uint3Var, u32, Uint32, 3, x, y, z); -impl_vec_proxy!(Uint4, Uint4Expr, Uint4Var, u32, Uint32, 4, x, y, z, w); - -impl_vec_proxy!(Int2, Int2Expr, Int2Var, i32, Int32, 2, x, y); -impl_vec_proxy!(Int3, Int3Expr, Int3Var, i32, Int32, 3, x, y, z); -impl_vec_proxy!(Int4, Int4Expr, Int4Var, i32, Int32, 4, x, y, z, w); - -impl_vec_proxy!(Ulong2, Ulong2Expr, Ulong2Var, u64, Uint64, 2, x, y); -impl_vec_proxy!(Ulong3, Ulong3Expr, Ulong3Var, u64, Uint64, 3, x, y, z); -impl_vec_proxy!(Ulong4, Ulong4Expr, Ulong4Var, u64, Uint64, 4, x, y, z, w); - -impl_vec_proxy!(Long2, Long2Expr, Long2Var, i64, Int64, 2, x, y); -impl_vec_proxy!(Long3, Long3Expr, Long3Var, i64, Int64, 3, x, y, z); -impl_vec_proxy!(Long4, Long4Expr, Long4Var, i64, Int64, 4, x, y, z, w); - -impl_mat_proxy!(Mat2, Mat2Expr, Mat2Var, Float2, Float32, 2, x, y); -impl_mat_proxy!(Mat3, Mat3Expr, Mat3Var, Float3, Float32, 3, x, y, z); -impl_mat_proxy!(Mat4, Mat4Expr, Mat4Var, Float4, Float32, 4, x, y, z, w); - -macro_rules! impl_packed_cvt { - ($packed:ty, $vec:ty, $($comp:ident), *) => { - impl From<$vec> for $packed { - fn from(v: $vec) -> Self { - Self::new($(v.$comp()), *) - } - } - impl $packed { - pub fn unpack(&self) -> $vec { - (*self).into() - } - } - impl From<$packed> for $vec { - fn from(v: $packed) -> Self { - Self::new($(v.$comp()), *) - } - } - impl $vec { - pub fn pack(&self) -> $packed { - (*self).into() + _deref_proxy(self) } } } } -impl_packed_cvt!(PackedFloat2Expr, Float2Expr, x, y); -impl_packed_cvt!(PackedFloat3Expr, Float3Expr, x, y, z); -impl_packed_cvt!(PackedFloat4Expr, Float4Expr, x, y, z, w); - -impl_packed_cvt!(PackedShort2Expr, Short2Expr, x, y); -impl_packed_cvt!(PackedShort3Expr, Short3Expr, x, y, z); -impl_packed_cvt!(PackedShort4Expr, Short4Expr, x, y, z, w); - -// ushort -impl_packed_cvt!(PackedUshort2Expr, Ushort2Expr, x, y); -impl_packed_cvt!(PackedUshort3Expr, Ushort3Expr, x, y, z); -impl_packed_cvt!(PackedUshort4Expr, Ushort4Expr, x, y, z, w); -// int -impl_packed_cvt!(PackedInt2Expr, Int2Expr, x, y); -impl_packed_cvt!(PackedInt3Expr, Int3Expr, x, y, z); -impl_packed_cvt!(PackedInt4Expr, Int4Expr, x, y, z, w); +vector_proxies!(2 [x, y]: VectorExprProxy2, VectorVarProxy2); +vector_proxies!(3 [x, y, z, r, g, b]: VectorExprProxy3, VectorVarProxy3); +vector_proxies!(4 [x, y, z, w, r, g, b, a]: VectorExprProxy4, VectorVarProxy4); -// uint -impl_packed_cvt!(PackedUint2Expr, Uint2Expr, x, y); -impl_packed_cvt!(PackedUint3Expr, Uint3Expr, x, y, z); -impl_packed_cvt!(PackedUint4Expr, Uint4Expr, x, y, z, w); - -// long -impl_packed_cvt!(PackedLong2Expr, Long2Expr, x, y); -impl_packed_cvt!(PackedLong3Expr, Long3Expr, x, y, z); -impl_packed_cvt!(PackedLong4Expr, Long4Expr, x, y, z, w); - -// ulong -impl_packed_cvt!(PackedUlong2Expr, Ulong2Expr, x, y); -impl_packed_cvt!(PackedUlong3Expr, Ulong3Expr, x, y, z); -impl_packed_cvt!(PackedUlong4Expr, Ulong4Expr, x, y, z, w); - -macro_rules! impl_binop { - ($t:ty, $scalar:ty, $proxy:ty, $tr:ident, $m:ident, $tr_assign:ident, $m_assign:ident) => { - impl std::ops::$tr_assign<$proxy> for $proxy { - fn $m_assign(&mut self, rhs: $proxy) { - use std::ops::$tr; - *self = (*self).$m(rhs); - } - } - impl std::ops::$tr for $proxy { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$scalar> for $proxy { - type Output = $proxy; - fn $m(self, rhs: $scalar) -> Self::Output { - let rhs = Self::splat(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for $scalar { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::splat(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr> for $proxy { - type Output = $proxy; - fn $m(self, rhs: prim::Expr<$scalar>) -> Self::Output { - let rhs = Self::splat(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for prim::Expr<$scalar> { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::splat(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_binop_for_mat { - ($t:ty, $scalar:ty, $proxy:ty, $tr:ident, $m:ident, $tr_assign:ident, $m_assign:ident) => { - impl std::ops::$tr_assign<$proxy> for $proxy { - fn $m_assign(&mut self, rhs: $proxy) { - use std::ops::$tr; - *self = (*self).$m(rhs); - } - } - impl std::ops::$tr for $proxy { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$scalar> for $proxy { - type Output = $proxy; - fn $m(self, rhs: $scalar) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for $scalar { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr> for $proxy { - type Output = $proxy; - fn $m(self, rhs: prim::Expr<$scalar>) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for prim::Expr<$scalar> { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_arith_binop { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_common_op!($t, $scalar, $proxy); - impl_binop!($t, $scalar, $proxy, Add, add, AddAssign, add_assign); - impl_binop!($t, $scalar, $proxy, Sub, sub, SubAssign, sub_assign); - impl_binop!($t, $scalar, $proxy, Mul, mul, MulAssign, mul_assign); - impl_binop!($t, $scalar, $proxy, Div, div, DivAssign, div_assign); - impl_binop!($t, $scalar, $proxy, Rem, rem, RemAssign, rem_assign); - impl_reduce!($t, $scalar, $proxy); - impl std::ops::Neg for $proxy { - type Output = $proxy; - fn neg(self) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Neg, &[self.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_arith_binop_for_mat { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_binop_for_mat!($t, $scalar, $proxy, Add, add, AddAssign, add_assign); - impl_binop_for_mat!($t, $scalar, $proxy, Sub, sub, SubAssign, sub_assign); - // Mat * Mat - impl std::ops::MulAssign<$proxy> for $proxy { - fn mul_assign(&mut self, rhs: $proxy) { - use std::ops::Mul; - *self = (*self).mul(rhs); - } - } - impl std::ops::Mul for $proxy { - type Output = $proxy; - fn mul(self, rhs: $proxy) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Mul, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - // Mat * Scalar - impl std::ops::MulAssign<$scalar> for $proxy { - fn mul_assign(&mut self, rhs: $scalar) { - use std::ops::Mul; - *self = (*self).mul(rhs); - } - } - impl std::ops::Mul<$scalar> for $proxy { - type Output = $proxy; - fn mul(self, rhs: $scalar) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[self.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - impl std::ops::Mul<$proxy> for $scalar { - type Output = $proxy; - fn mul(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[lhs.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - impl std::ops::Mul> for $proxy { - type Output = $proxy; - fn mul(self, rhs: prim::Expr<$scalar>) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[self.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - impl std::ops::Mul<$proxy> for prim::Expr<$scalar> { - type Output = $proxy; - fn mul(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[lhs.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - // Rem - impl std::ops::RemAssign<$scalar> for $proxy { - fn rem_assign(&mut self, rhs: $scalar) { - use std::ops::Rem; - *self = (*self).rem(rhs); - } - } - impl std::ops::Rem<$scalar> for $proxy { - type Output = $proxy; - fn rem(self, rhs: $scalar) -> Self::Output { - let rhs: prim::Expr<$scalar> = rhs.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Rem, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::Rem> for $proxy { - type Output = $proxy; - fn rem(self, rhs: prim::Expr<$scalar>) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Rem, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - // Div - impl std::ops::DivAssign<$scalar> for $proxy { - fn div_assign(&mut self, rhs: $scalar) { - use std::ops::Div; - *self = (*self).div(rhs); - } - } - impl std::ops::Div<$scalar> for $proxy { - type Output = $proxy; - fn div(self, rhs: $scalar) -> Self::Output { - let rhs: prim::Expr<$scalar> = rhs.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Div, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::Div> for $proxy { - type Output = $proxy; - fn div(self, rhs: prim::Expr<$scalar>) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Div, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - // Neg - impl std::ops::Neg for $proxy { - type Output = $proxy; - fn neg(self) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Neg, &[self.node], <$t as TypeOf>::type_()) - })) - } - } - impl $proxy { - pub fn comp_mul(&self, other: Self) -> Self { - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[self.node, other.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - }; -} -macro_rules! impl_int_binop { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_binop!( - $t, - $scalar, - $proxy, - BitAnd, - bitand, - BitAndAssign, - bitand_assign - ); - impl_binop!($t, $scalar, $proxy, BitOr, bitor, BitOrAssign, bitor_assign); - impl_binop!( - $t, - $scalar, - $proxy, - BitXor, - bitxor, - BitXorAssign, - bitxor_assign - ); - impl_binop!($t, $scalar, $proxy, Shl, shl, ShlAssign, shl_assign); - impl_binop!($t, $scalar, $proxy, Shr, shr, ShrAssign, shr_assign); - impl std::ops::Not for $proxy { - type Output = Expr<$t>; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; -} -macro_rules! impl_bool_binop { - ($t:ty, $proxy:ty) => { - impl_binop!( - $t, - bool, - $proxy, - BitAnd, - bitand, - BitAndAssign, - bitand_assign - ); - impl_binop!($t, bool, $proxy, BitOr, bitor, BitOrAssign, bitor_assign); - impl_binop!( - $t, - bool, - $proxy, - BitXor, - bitxor, - BitXorAssign, - bitxor_assign - ); - impl $proxy { - pub fn splat>>(value: V) -> Self { - let value = value.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) - })) - } - pub fn zero() -> Self { - Self::splat(false) - } - pub fn one() -> Self { - Self::splat(true) - } - pub fn all(&self) -> prim::Expr { - Expr::::from_node(__current_scope(|s| { - s.call(Func::All, &[self.node], ::type_()) - })) - } - pub fn any(&self) -> prim::Expr { - Expr::::from_node(__current_scope(|s| { - s.call(Func::Any, &[self.node], ::type_()) - })) - } - } - impl std::ops::Not for $proxy { - type Output = Expr<$t>; - fn not(self) -> Self::Output { - self ^ Self::splat(true) - } - } - }; -} -macro_rules! impl_reduce { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl $proxy { - #[inline] - pub fn reduce_sum(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceSum, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn reduce_prod(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceProd, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn reduce_min(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceMin, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn reduce_max(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceMax, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn dot(&self, rhs: $proxy) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call( - Func::Dot, - &[self.node, rhs.node], - <$scalar as TypeOf>::type_(), - ) - })) - } - } - }; -} -macro_rules! impl_common_op { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl $proxy { - pub fn splat>>(value: V) -> Self { - let value = value.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) - })) - } - pub fn zero() -> Self { - Self::splat(0.0 as $scalar) - } - pub fn one() -> Self { - Self::splat(1.0 as $scalar) - } - } - }; -} -macro_rules! impl_vec_op { - ($t:ty, $scalar:ty, $proxy:ty, $mat:ty) => { - impl $proxy { - #[inline] - pub fn length(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Length, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn normalize(&self) -> Self { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Normalize, &[self.node], <$t as TypeOf>::type_()) - })) - } - #[inline] - pub fn length_squared(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call( - Func::LengthSquared, - &[self.node], - <$scalar as TypeOf>::type_(), - ) - })) - } - #[inline] - pub fn distance(&self, rhs: $proxy) -> prim::Expr<$scalar> { - (*self - rhs).length() - } - #[inline] - pub fn distance_squared(&self, rhs: $proxy) -> prim::Expr<$scalar> { - (*self - rhs).length_squared() - } - #[inline] - pub fn fma(&self, a: $proxy, b: $proxy) -> Self { - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::Fma, - &[self.node, a.node, b.node], - <$t as TypeOf>::type_(), - ) - })) - } - #[inline] - pub fn outer_product(&self, rhs: $proxy) -> Expr<$mat> { - Expr::<$mat>::from_node(__current_scope(|s| { - s.call( - Func::OuterProduct, - &[self.node, rhs.node], - <$mat as TypeOf>::type_(), - ) - })) - } - } - }; -} - -// a little shit -macro_rules! impl_arith_binop_f16 { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_common_op_f16!($t, $scalar, $proxy); - impl_binop!($t, $scalar, $proxy, Add, add, AddAssign, add_assign); - impl_binop!($t, $scalar, $proxy, Sub, sub, SubAssign, sub_assign); - impl_binop!($t, $scalar, $proxy, Mul, mul, MulAssign, mul_assign); - impl_binop!($t, $scalar, $proxy, Div, div, DivAssign, div_assign); - impl_binop!($t, $scalar, $proxy, Rem, rem, RemAssign, rem_assign); - impl_reduce!($t, $scalar, $proxy); - impl std::ops::Neg for $proxy { - type Output = $proxy; - fn neg(self) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Neg, &[self.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_common_op_f16 { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl $proxy { - pub fn splat>>(value: V) -> Self { - let value = value.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) - })) - } - pub fn zero() -> Self { - Self::splat(f16::from_f32(0.0f32)) - } - pub fn one() -> Self { - Self::splat(f16::from_f32(1.0f32)) - } - } - }; -} - -impl_arith_binop_f16!(Half2, f16, Half2Expr); -impl_arith_binop_f16!(Half3, f16, Half3Expr); -impl_arith_binop_f16!(Half4, f16, Half4Expr); - -impl_arith_binop!(Float2, f32, Float2Expr); -impl_arith_binop!(Float3, f32, Float3Expr); -impl_arith_binop!(Float4, f32, Float4Expr); - -impl_arith_binop!(Short2, i16, Short2Expr); -impl_arith_binop!(Short3, i16, Short3Expr); -impl_arith_binop!(Short4, i16, Short4Expr); - -impl_arith_binop!(Ushort2, u16, Ushort2Expr); -impl_arith_binop!(Ushort3, u16, Ushort3Expr); -impl_arith_binop!(Ushort4, u16, Ushort4Expr); - -impl_arith_binop!(Int2, i32, Int2Expr); -impl_arith_binop!(Int3, i32, Int3Expr); -impl_arith_binop!(Int4, i32, Int4Expr); - -impl_arith_binop!(Uint2, u32, Uint2Expr); -impl_arith_binop!(Uint3, u32, Uint3Expr); -impl_arith_binop!(Uint4, u32, Uint4Expr); - -impl_arith_binop!(Long2, i64, Long2Expr); -impl_arith_binop!(Long3, i64, Long3Expr); -impl_arith_binop!(Long4, i64, Long4Expr); - -impl_arith_binop!(Ulong2, u64, Ulong2Expr); -impl_arith_binop!(Ulong3, u64, Ulong3Expr); -impl_arith_binop!(Ulong4, u64, Ulong4Expr); - -impl_int_binop!(Short2, i16, Short2Expr); -impl_int_binop!(Short3, i16, Short3Expr); -impl_int_binop!(Short4, i16, Short4Expr); - -impl_int_binop!(Ushort2, u16, Ushort2Expr); -impl_int_binop!(Ushort3, u16, Ushort3Expr); -impl_int_binop!(Ushort4, u16, Ushort4Expr); - -impl_int_binop!(Int2, i32, Int2Expr); -impl_int_binop!(Int3, i32, Int3Expr); -impl_int_binop!(Int4, i32, Int4Expr); - -impl_int_binop!(Uint2, u32, Uint2Expr); -impl_int_binop!(Uint3, u32, Uint3Expr); -impl_int_binop!(Uint4, u32, Uint4Expr); - -impl_int_binop!(Long2, i64, Long2Expr); -impl_int_binop!(Long3, i64, Long3Expr); -impl_int_binop!(Long4, i64, Long4Expr); - -impl_int_binop!(Ulong2, u64, Ulong2Expr); -impl_int_binop!(Ulong3, u64, Ulong3Expr); -impl_int_binop!(Ulong4, u64, Ulong4Expr); - -impl_bool_binop!(Bool2, Bool2Expr); -impl_bool_binop!(Bool3, Bool3Expr); -impl_bool_binop!(Bool4, Bool4Expr); - -macro_rules! impl_select { - ($bvec:ty, $vec:ty, $proxy:ty) => { - impl $proxy { - pub fn select(mask: Expr<$bvec>, a: Expr<$vec>, b: Expr<$vec>) -> Expr<$vec> { - Expr::<$vec>::from_node(__current_scope(|s| { - s.call( - Func::Select, - &[mask.node(), a.node(), b.node()], - <$vec as TypeOf>::type_(), - ) - })) - } - } - }; -} - -impl_select!(Bool2, Bool2, Bool2Expr); -impl_select!(Bool3, Bool3, Bool3Expr); -impl_select!(Bool4, Bool4, Bool4Expr); - -impl_select!(Bool2, Half2, Half2Expr); -impl_select!(Bool3, Half3, Half3Expr); -impl_select!(Bool4, Half4, Half4Expr); - -impl_select!(Bool2, Float2, Float2Expr); -impl_select!(Bool3, Float3, Float3Expr); -impl_select!(Bool4, Float4, Float4Expr); - -impl_select!(Bool2, Int2, Int2Expr); -impl_select!(Bool3, Int3, Int3Expr); -impl_select!(Bool4, Int4, Int4Expr); - -impl_select!(Bool2, Uint2, Uint2Expr); -impl_select!(Bool3, Uint3, Uint3Expr); -impl_select!(Bool4, Uint4, Uint4Expr); - -impl_select!(Bool2, Short2, Short2Expr); -impl_select!(Bool3, Short3, Short3Expr); -impl_select!(Bool4, Short4, Short4Expr); - -impl_select!(Bool2, Ushort2, Ushort2Expr); -impl_select!(Bool3, Ushort3, Ushort3Expr); -impl_select!(Bool4, Ushort4, Ushort4Expr); - -impl_select!(Bool2, Long2, Long2Expr); -impl_select!(Bool3, Long3, Long3Expr); -impl_select!(Bool4, Long4, Long4Expr); - -impl_select!(Bool2, Ulong2, Ulong2Expr); -impl_select!(Bool3, Ulong3, Ulong3Expr); -impl_select!(Bool4, Ulong4, Ulong4Expr); - -macro_rules! impl_permute { - ($tr:ident, $proxy:ty,$len:expr, $v2:ty, $v3:ty, $v4:ty) => { - impl $tr for $proxy { - type Vec2 = Expr<$v2>; - type Vec3 = Expr<$v3>; - type Vec4 = Expr<$v4>; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { - assert!(x < $len); - assert!(y < $len); - let x: Expr = x.into(); - let y: Expr = y.into(); - Expr::<$v2>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[self.node, ToNode::node(&x), ToNode::node(&y)], - <$v2 as TypeOf>::type_(), - ) - })) - } - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { - assert!(x < $len); - assert!(y < $len); - assert!(z < $len); - let x: Expr = x.into(); - let y: Expr = y.into(); - let z: Expr = z.into(); - Expr::<$v3>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[ - self.node, - ToNode::node(&x), - ToNode::node(&y), - ToNode::node(&z), - ], - <$v3 as TypeOf>::type_(), - ) - })) - } - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { - assert!(x < $len); - assert!(y < $len); - assert!(z < $len); - assert!(w < $len); - let x: Expr = x.into(); - let y: Expr = y.into(); - let z: Expr = z.into(); - let w: Expr = w.into(); - Expr::<$v4>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[ - self.node, - ToNode::node(&x), - ToNode::node(&y), - ToNode::node(&z), - ToNode::node(&w), - ], - <$v4 as TypeOf>::type_(), - ) - })) - } - } - }; -} - -impl_permute!(Vec2Swizzle, Half2Expr, 2, Half2, Half3, Half4); -impl_permute!(Vec3Swizzle, Half3Expr, 3, Half2, Half3, Half4); -impl_permute!(Vec4Swizzle, Half4Expr, 4, Half2, Half3, Half4); - -impl_permute!(Vec2Swizzle, Float2Expr, 2, Float2, Float3, Float4); -impl_permute!(Vec3Swizzle, Float3Expr, 3, Float2, Float3, Float4); -impl_permute!(Vec4Swizzle, Float4Expr, 4, Float2, Float3, Float4); - -impl_permute!(Vec2Swizzle, Short2Expr, 2, Short2, Short3, Short4); -impl_permute!(Vec3Swizzle, Short3Expr, 3, Short2, Short3, Short4); -impl_permute!(Vec4Swizzle, Short4Expr, 4, Short2, Short3, Short4); - -impl_permute!(Vec2Swizzle, Ushort2Expr, 2, Ushort2, Ushort3, Ushort4); -impl_permute!(Vec3Swizzle, Ushort3Expr, 3, Ushort2, Ushort3, Ushort4); -impl_permute!(Vec4Swizzle, Ushort4Expr, 4, Ushort2, Ushort3, Ushort4); - -impl_permute!(Vec2Swizzle, Int2Expr, 2, Int2, Int3, Int4); -impl_permute!(Vec3Swizzle, Int3Expr, 3, Int2, Int3, Int4); -impl_permute!(Vec4Swizzle, Int4Expr, 4, Int2, Int3, Int4); - -impl_permute!(Vec2Swizzle, Uint2Expr, 2, Uint2, Uint3, Uint4); -impl_permute!(Vec3Swizzle, Uint3Expr, 3, Uint2, Uint3, Uint4); -impl_permute!(Vec4Swizzle, Uint4Expr, 4, Uint2, Uint3, Uint4); - -impl_permute!(Vec2Swizzle, Long2Expr, 2, Long2, Long3, Long4); -impl_permute!(Vec3Swizzle, Long3Expr, 3, Long2, Long3, Long4); -impl_permute!(Vec4Swizzle, Long4Expr, 4, Long2, Long3, Long4); - -impl_permute!(Vec2Swizzle, Ulong2Expr, 2, Ulong2, Ulong3, Ulong4); -impl_permute!(Vec3Swizzle, Ulong3Expr, 3, Ulong2, Ulong3, Ulong4); -impl_permute!(Vec4Swizzle, Ulong4Expr, 4, Ulong2, Ulong3, Ulong4); - -impl Float3Expr { - #[inline] - pub fn cross(&self, rhs: Float3Expr) -> Self { - Float3Expr::from_node(__current_scope(|s| { - s.call( - Func::Cross, - &[self.node, rhs.node], - ::type_(), - ) - })) +impl, const N: usize> TypeOf for Vector { + fn type_() -> CArc { + let type_ = Type::Vector(VectorType { + element: VectorElementType::Scalar(T::primitive()), + length: N as u32, + }); + register_type(type_) } } -impl_vec_op!(Float2, f32, Float2Expr, Mat2); -impl_vec_op!(Float3, f32, Float3Expr, Mat3); -impl_vec_op!(Float4, f32, Float4Expr, Mat4); - -macro_rules! impl_var_trait2 { - ($t:ty, $v:ty) => { - impl VarTrait for $t { - type Value = $v; - type Short = Short2Expr; - type Ushort = Ushort2Expr; - type Int = Int2Expr; - type Uint = Uint2Expr; - type Float = Float2Expr; - type Half = Half2Expr; - type Bool = Bool2Expr; - type Double = Double2Expr; - type Long = Long2Expr; - type Ulong = Ulong2Expr; - } - impl CommonVarOp for $t {} - impl VarCmp for $t {} - impl VarCmpEq for $t {} - impl From<$v> for $t { - fn from(v: $v) -> Self { - Self::new((v.x).expr(), (v.y).expr()) - } - } - }; -} -macro_rules! impl_var_trait3 { - ($t:ty, $v:ty) => { - impl VarTrait for $t { - type Value = $v; - type Short = Short3Expr; - type Ushort = Ushort3Expr; - type Int = Int3Expr; - type Uint = Uint3Expr; - type Float = Float3Expr; - type Half = Half3Expr; - type Bool = Bool3Expr; - type Double = Double3Expr; - type Long = Long3Expr; - type Ulong = Ulong3Expr; - } - impl CommonVarOp for $t {} - impl VarCmp for $t {} - impl VarCmpEq for $t {} - impl From<$v> for $t { - fn from(v: $v) -> Self { - Self::new(v.x.expr(), v.y.expr(), v.z.expr()) - } - } - }; -} -macro_rules! impl_var_trait4 { - ($t:ty, $v:ty) => { - impl VarTrait for $t { - type Value = $v; - type Short = Short4Expr; - type Ushort = Ushort4Expr; - type Int = Int4Expr; - type Uint = Uint4Expr; - type Float = Float4Expr; - type Double = Double4Expr; - type Half = Half4Expr; - type Bool = Bool4Expr; - type Long = Long4Expr; - type Ulong = Ulong4Expr; - } - impl CommonVarOp for $t {} - impl VarCmp for $t {} - impl VarCmpEq for $t {} - impl From<$v> for $t { - fn from(v: $v) -> Self { - Self::new(v.x.expr(), v.y.expr(), v.z.expr(), v.w.expr()) - } - } - }; -} - -impl_var_trait2!(Half2Expr, Half2); -impl_var_trait2!(Float2Expr, Float2); -impl_var_trait2!(Double2Expr, Double2); -impl_var_trait2!(Short2Expr, Short2); -impl_var_trait2!(Ushort2Expr, Ushort2); -impl_var_trait2!(Int2Expr, Int2); -impl_var_trait2!(Uint2Expr, Uint2); -impl_var_trait2!(Bool2Expr, Bool2); -impl_var_trait2!(Long2Expr, Long2); -impl_var_trait2!(Ulong2Expr, Ulong2); - -impl_var_trait3!(Half3Expr, Half3); -impl_var_trait3!(Float3Expr, Float3); -impl_var_trait3!(Double3Expr, Double3); -impl_var_trait3!(Short3Expr, Short3); -impl_var_trait3!(Ushort3Expr, Ushort3); -impl_var_trait3!(Int3Expr, Int3); -impl_var_trait3!(Uint3Expr, Uint3); -impl_var_trait3!(Bool3Expr, Bool3); -impl_var_trait3!(Long3Expr, Long3); -impl_var_trait3!(Ulong3Expr, Ulong3); -impl_var_trait4!(Half4Expr, Half4); -impl_var_trait4!(Float4Expr, Float4); -impl_var_trait4!(Double4Expr, Double4); -impl_var_trait4!(Short4Expr, Short4); -impl_var_trait4!(Ushort4Expr, Ushort4); -impl_var_trait4!(Int4Expr, Int4); -impl_var_trait4!(Uint4Expr, Uint4); -impl_var_trait4!(Bool4Expr, Bool4); -impl_var_trait4!(Long4Expr, Long4); -impl_var_trait4!(Ulong4Expr, Ulong4); - -macro_rules! impl_float_trait { - ($t:ty) => { - impl From for $t { - fn from(v: f32) -> Self { - Self::splat(v) - } - } - impl FloatVarTrait for $t {} - }; +impl, const N: usize> Value for Vector { + type Expr = T::VectorExpr; + type Var = T::VectorVar; + type ExprData = T::VectorExprData; + type VarData = T::VectorVarData; } -impl_float_trait!(Half2Expr); -impl_float_trait!(Half3Expr); -impl_float_trait!(Half4Expr); -impl_float_trait!(Float2Expr); -impl_float_trait!(Float3Expr); -impl_float_trait!(Float4Expr); - -macro_rules! impl_int_trait { - ($t:ty) => { - impl From for $t { - fn from(v: i64) -> Self { - Self::splat(v) - } - } - impl IntVarTrait for $t {} - }; +impl, const N: usize> Vector { + fn _permute2(&self, x: u32, y: u32) -> Vec2 + where + T: VectorAlign<2>, + { + Vector::from_elements([self.elements[x as usize], self.elements[y as usize]]) + } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Vec3 + where + T: VectorAlign<3>, + { + Vector::from_elements([ + self.elements[x as usize], + self.elements[y as usize], + self.elements[z as usize], + ]) + } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Vec4 + where + T: VectorAlign<4>, + { + Vector::from_elements([ + self.elements[x as usize], + self.elements[y as usize], + self.elements[z as usize], + self.elements[w as usize], + ]) + } } -impl_int_trait!(Int2Expr); -impl_int_trait!(Int3Expr); -impl_int_trait!(Int4Expr); - -impl_int_trait!(Long2Expr); -impl_int_trait!(Long3Expr); -impl_int_trait!(Long4Expr); -impl_int_trait!(Uint2Expr); -impl_int_trait!(Uint3Expr); -impl_int_trait!(Uint4Expr); - -impl_int_trait!(Ulong2Expr); -impl_int_trait!(Ulong3Expr); -impl_int_trait!(Ulong4Expr); - -impl_int_trait!(Short2Expr); -impl_int_trait!(Short3Expr); -impl_int_trait!(Short4Expr); - -impl_int_trait!(Ushort2Expr); -impl_int_trait!(Ushort3Expr); -impl_int_trait!(Ushort4Expr); - -impl Mul for Mat2Expr { - type Output = Float2Expr; - #[inline] - fn mul(self, rhs: Float2Expr) -> Self::Output { - Float2Expr::from_node(__current_scope(|s| { - s.call( - Func::Mul, - &[self.node, rhs.node], - ::type_(), - ) - })) +impl Vec2Swizzle for Vec2 { + type Vec2 = Self; + type Vec3 = Vec3; + type Vec4 = Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) } -} -impl Mat2Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new(Float2::expr(e, e), Float2::expr(e, e)) + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) } - pub fn eye(e: Expr) -> Self { - Self::new(Float2::expr(e.x(), 0.0), Float2::expr(0.0, e.y())) + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) } - pub fn inverse(&self) -> Self { - Mat2Expr::from_node(__current_scope(|s| { - s.call(Func::Inverse, &[self.node], ::type_()) - })) +} +impl Vec3Swizzle for Vec3 { + type Vec2 = Vec2; + type Vec3 = Self; + type Vec4 = Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) } - pub fn transpose(&self) -> Self { - Mat2Expr::from_node(__current_scope(|s| { - s.call(Func::Transpose, &[self.node], ::type_()) - })) + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) } - pub fn determinant(&self) -> prim::Expr { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Determinant, &[self.node], ::type_()) - })) + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) } } -impl_arith_binop_for_mat!(Mat2, f32, Mat2Expr); -impl Mul for Mat3Expr { - type Output = Float3Expr; - #[inline] - fn mul(self, rhs: Float3Expr) -> Self::Output { - Float3Expr::from_node(__current_scope(|s| { - s.call( - Func::Mul, - &[self.node, rhs.node], - ::type_(), - ) - })) +impl Vec4Swizzle for Vec4 { + type Vec2 = Vec2; + type Vec3 = Vec3; + type Vec4 = Self; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) } -} -impl Mat3Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new( - Float3::expr(e, e, e), - Float3::expr(e, e, e), - Float3::expr(e, e, e), - ) + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) } - pub fn eye(e: Expr) -> Self { - Self::new( - Float3::expr(e.x(), 0.0, 0.0), - Float3::expr(0.0, e.y(), 0.0), - Float3::expr(0.0, 0.0, e.z()), - ) + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) } - pub fn inverse(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Inverse, &[self.node], ::type_()) - })) +} + +impl Vec2Swizzle for VectorExprProxy2 { + type Vec2 = Expr>; + type Vec3 = Expr>; + type Vec4 = Expr>; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) } - pub fn transpose(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Transpose, &[self.node], ::type_()) - })) + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) } - pub fn determinant(&self) -> prim::Expr { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Determinant, &[self.node], ::type_()) - })) + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) } } -impl Mul for Mat4Expr { - type Output = Float4Expr; - #[inline] - fn mul(self, rhs: Float4Expr) -> Self::Output { - Float4Expr::from_node(__current_scope(|s| { - s.call( - Func::Mul, - &[self.node, rhs.node], - ::type_(), - ) - })) +impl Vec3Swizzle for VectorExprProxy3 { + type Vec2 = Expr>; + type Vec3 = Expr>; + type Vec4 = Expr>; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) } -} -impl_arith_binop_for_mat!(Mat3, f32, Mat3Expr); -impl Mat4Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new( - Float4::expr(e, e, e, e), - Float4::expr(e, e, e, e), - Float4::expr(e, e, e, e), - Float4::expr(e, e, e, e), - ) + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) } - pub fn eye(e: Expr) -> Self { - Self::new( - Float4::expr(e.x(), 0.0, 0.0, 0.0), - Float4::expr(0.0, e.y(), 0.0, 0.0), - Float4::expr(0.0, 0.0, e.z(), 0.0), - Float4::expr(0.0, 0.0, 0.0, e.w()), - ) + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) } - pub fn inverse(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Inverse, &[self.node], ::type_()) - })) +} +impl Vec4Swizzle for VectorExprProxy4 { + type Vec2 = Expr>; + type Vec3 = Expr>; + type Vec4 = Expr>; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) } - pub fn transpose(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Transpose, &[self.node], ::type_()) - })) + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) } - pub fn determinant(&self) -> prim::Expr { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Determinant, &[self.node], ::type_()) - })) + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) } } -impl_arith_binop_for_mat!(Mat4, f32, Mat4Expr); -#[cfg(test)] -mod test { - #[test] - fn test_size() { - use crate::internal_prelude::*; - macro_rules! assert_size { - ($ty:ty) => { - {assert_eq!(std::mem::size_of::<$ty>(), <$ty as TypeOf>::type_().size());} - }; - ($ty:ty, $($rest:ty),*) => { - assert_size!($ty); - assert_size!($($rest),*); - }; - } - assert_size!(f32, f64, bool, u16, u32, u64, i16, i32, i64); - assert_size!(Float2, Float3, Float4, Int2, Int3, Int4, Uint2, Uint3, Uint4); - assert_size!(Short2, Short3, Short4, Ushort2, Ushort3, Ushort4); - assert_size!(Long2, Long3, Long4, Ulong2, Ulong3, Ulong4); - assert_size!(Mat2, Mat3, Mat4); - assert_size!(PackedFloat2, PackedFloat3, PackedFloat4); - assert_eq!(std::mem::size_of::(), 12); +pub type Vec2 = Vector; +pub type Vec3 = Vector; +pub type Vec4 = Vector; + +// Matrix + +impl Debug for SquareMatrix +where + f32: VectorAlign, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.elements.fmt(f) } } + +#[repr(C)] +#[derive(Copy, Clone, PartialEq)] +pub struct SquareMatrix +where + f32: VectorAlign, +{ + pub elements: [Vector; N], +} + +impl TypeOf for SquareMatrix +where + f32: VectorAlign, +{ + fn type_() -> CArc { + let type_ = Type::Matrix(ir::MatrixType { + element: VectorElementType::Scalar(f32::primitive()), + dimension: N as u32, + }); + register_type(type_) + } +} + +impl_simple_expr_proxy!(SquareMatrixExpr2 for SquareMatrix<2>); +impl_simple_var_proxy!(SquareMatrixVar2 for SquareMatrix<2>); + +impl_simple_expr_proxy!(SquareMatrixExpr3 for SquareMatrix<3>); +impl_simple_var_proxy!(SquareMatrixVar3 for SquareMatrix<3>); + +impl_simple_expr_proxy!(SquareMatrixExpr4 for SquareMatrix<4>); +impl_simple_var_proxy!(SquareMatrixVar4 for SquareMatrix<4>); + +impl Value for SquareMatrix<2> { + type Expr = SquareMatrixExpr2; + type Var = SquareMatrixVar2; + type ExprData = (); + type VarData = (); +} +impl Value for SquareMatrix<3> { + type Expr = SquareMatrixExpr3; + type Var = SquareMatrixVar3; + type ExprData = (); + type VarData = (); +} +impl Value for SquareMatrix<4> { + type Expr = SquareMatrixExpr4; + type Var = SquareMatrixVar4; + type ExprData = (); + type VarData = (); +} + +pub type Mat2 = SquareMatrix<2>; +pub type Mat3 = SquareMatrix<3>; +pub type Mat4 = SquareMatrix<4>; diff --git a/luisa_compute/src/lang/types/vector/coords.rs b/luisa_compute/src/lang/types/vector/coords.rs new file mode 100644 index 0000000..5811603 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/coords.rs @@ -0,0 +1,70 @@ +use super::*; +use std::ops::{Deref, DerefMut}; + +macro_rules! impl_coords { + ($T:ident [ $($c:ident), * ]) => { + #[repr(C)] + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub struct $T { + $(pub $c: T),* + } + } +} +macro_rules! impl_deref { + ($T:ident; $N:literal) => { + impl> Deref for Vector { + type Target = $T; + + #[inline] + fn deref(&self) -> &$T { + unsafe { &*(self as *const Self as *const $T) } + } + } + + impl> DerefMut for Vector { + #[inline] + fn deref_mut(&mut self) -> &mut $T { + unsafe { &mut *(self as *mut Self as *mut $T) } + } + } + }; +} + +impl_coords!(XY[x, y]); +impl_coords!(XYZ[x, y, z]); +impl_coords!(XYZW[x, y, z, w]); +impl_coords!(RGB[r, g, b]); +impl_coords!(RGBA[r, g, b, a]); + +impl_deref![XY; 2]; +impl_deref![XYZ; 3]; +impl_deref![XYZW; 4]; + +impl Deref for XYZ { + type Target = RGB; + + #[inline] + fn deref(&self) -> &RGB { + unsafe { &*(self as *const Self as *const RGB) } + } +} +impl DerefMut for XYZ { + #[inline] + fn deref_mut(&mut self) -> &mut RGB { + unsafe { &mut *(self as *mut Self as *mut RGB) } + } +} +impl Deref for XYZW { + type Target = RGBA; + + #[inline] + fn deref(&self) -> &RGBA { + unsafe { &*(self as *const Self as *const RGBA) } + } +} +impl DerefMut for XYZW { + #[inline] + fn deref_mut(&mut self) -> &mut RGBA { + unsafe { &mut *(self as *mut Self as *mut RGBA) } + } +} diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs new file mode 100644 index 0000000..8583536 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -0,0 +1,71 @@ +use super::*; + +macro_rules! element { + ($T:ty [ 2 ]: $A: ident) => { + impl VectorAlign<2> for $T { + type A = $A; + type VectorExpr = VectorExprProxy2<$T>; + type VectorVar = VectorVarProxy2<$T>; + type VectorExprData = VectorExprData<$T, 2>; + type VectorVarData = VectorVarData<$T, 2>; + } + }; + ($T:ty [ 3 ]: $A: ident) => { + impl VectorAlign<3> for $T { + type A = $A; + type VectorExpr = VectorExprProxy3<$T>; + type VectorVar = VectorVarProxy3<$T>; + type VectorExprData = DoubledProxyData>; + type VectorVarData = DoubledProxyData>; + } + }; + ($T:ty [ 4 ]: $A: ident) => { + impl VectorAlign<4> for $T { + type A = $A; + type VectorExpr = VectorExprProxy4<$T>; + type VectorVar = VectorVarProxy4<$T>; + type VectorExprData = DoubledProxyData>; + type VectorVarData = DoubledProxyData>; + } + }; +} + +element!(bool[2]: Align2); +element!(bool[3]: Align4); +element!(bool[4]: Align4); +element!(u8[2]: Align2); +element!(u8[3]: Align4); +element!(u8[4]: Align4); +element!(i8[2]: Align2); +element!(i8[3]: Align4); +element!(i8[4]: Align4); + +element!(f16[2]: Align4); +element!(f16[3]: Align8); +element!(f16[4]: Align8); +element!(u16[2]: Align4); +element!(u16[3]: Align8); +element!(u16[4]: Align8); +element!(i16[2]: Align4); +element!(i16[3]: Align8); +element!(i16[4]: Align8); + +element!(f32[2]: Align8); +element!(f32[3]: Align16); +element!(f32[4]: Align16); +element!(u32[2]: Align8); +element!(u32[3]: Align16); +element!(u32[4]: Align16); +element!(i32[2]: Align8); +element!(i32[3]: Align16); +element!(i32[4]: Align16); + +element!(f64[2]: Align16); +element!(f64[3]: Align32); +element!(f64[4]: Align32); +element!(u64[2]: Align16); +element!(u64[3]: Align32); +element!(u64[4]: Align32); +element!(i64[2]: Align16); +element!(i64[3]: Align32); +element!(i64[4]: Align32); diff --git a/luisa_compute/src/lang/gen_swizzle.py b/luisa_compute/src/lang/types/vector/gen_swizzle.py similarity index 88% rename from luisa_compute/src/lang/gen_swizzle.py rename to luisa_compute/src/lang/types/vector/gen_swizzle.py index 891b1c3..f619f33 100644 --- a/luisa_compute/src/lang/gen_swizzle.py +++ b/luisa_compute/src/lang/types/vector/gen_swizzle.py @@ -22,9 +22,9 @@ def swizzle_name(perm: List[int]): s += ' type Vec2;\n' s += ' type Vec3;\n' s += ' type Vec4;\n' - s += ' fn permute2(&self, x: i32, y: i32) -> Self::Vec2;\n' - s += ' fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3;\n' - s += ' fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4;\n' + s += ' fn permute2(&self, x: u32, y: u32) -> Self::Vec2;\n' + s += ' fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3;\n' + s += ' fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4;\n' for sw in sw_m_to_n[(n, 2)]: s += ' fn {}(&self) -> Self::Vec2 {{\n'.format(swizzle_name(sw)) s += ' self.permute2({}, {})\n'.format(sw[0], sw[1]) diff --git a/luisa_compute/src/lang/types/vector/glam.rs b/luisa_compute/src/lang/types/vector/glam.rs new file mode 100644 index 0000000..43fe63d --- /dev/null +++ b/luisa_compute/src/lang/types/vector/glam.rs @@ -0,0 +1,2 @@ +use super::*; +use glam::*; diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs new file mode 100644 index 0000000..80927d3 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -0,0 +1,47 @@ +use super::*; + +impl, const N: usize> Vector { + pub fn from_elements(elements: [T; N]) -> Self { + Self { + _align: T::A::default(), + elements, + } + } + pub fn splat(element: T) -> Self { + Self { + _align: T::A::default(), + elements: [element; N], + } + } + pub fn splat_expr(element: impl AsExpr) -> Expr { + Func::Vec.call(element.as_expr()) + } + pub fn map(&self, f: impl Fn(T) -> T) -> Self { + Self { + _align: T::A::default(), + elements: self.elements.map(f), + } + } + pub fn expr_from_elements(elements: [Expr; N]) -> Expr { + Expr::::from_node(__compose::(&elements.map(|x| x.node()))) + } +} + +macro_rules! impl_sized { + ($Vn:ident($N: literal): $($xs:ident),+) => { + impl> $Vn { + pub fn new($($xs: T),+) -> Self { + Self { + _align: T::A::default(), + elements: [$($xs),+], + } + } + pub fn expr($($xs: impl AsExpr),+) -> Expr { + Self::expr_from_elements([$($xs.as_expr()),+]) + } + } + } +} +impl_sized!(Vec2(2): x, y); +impl_sized!(Vec3(3): x, y, z); +impl_sized!(Vec4(4): x, y, z, w); diff --git a/luisa_compute/src/lang/types/vector/legacy.rs b/luisa_compute/src/lang/types/vector/legacy.rs new file mode 100644 index 0000000..bd7c9e3 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/legacy.rs @@ -0,0 +1,208 @@ +use super::*; + +#[deprecated] +pub type Half2 = Vec2; +#[deprecated] +pub type Half3 = Vec3; +#[deprecated] +pub type Half4 = Vec4; +#[deprecated] +pub type Float2 = Vec2; +#[deprecated] +pub type Float3 = Vec3; +#[deprecated] +pub type Float4 = Vec4; +#[deprecated] +pub type Double2 = Vec2; +#[deprecated] +pub type Double3 = Vec3; +#[deprecated] +pub type Double4 = Vec4; +#[deprecated] +pub type Byte2 = Vec2; +#[deprecated] +pub type Byte3 = Vec3; +#[deprecated] +pub type Byte4 = Vec4; +#[deprecated] +pub type Short2 = Vec2; +#[deprecated] +pub type Short3 = Vec3; +#[deprecated] +pub type Short4 = Vec4; +#[deprecated] +pub type Int2 = Vec2; +#[deprecated] +pub type Int3 = Vec3; +#[deprecated] +pub type Int4 = Vec4; +#[deprecated] +pub type Long2 = Vec2; +#[deprecated] +pub type Long3 = Vec3; +#[deprecated] +pub type Long4 = Vec4; +#[deprecated] +pub type Ubyte2 = Vec2; +#[deprecated] +pub type Ubyte3 = Vec3; +#[deprecated] +pub type Ubyte4 = Vec4; +#[deprecated] +pub type Ushort2 = Vec2; +#[deprecated] +pub type Ushort3 = Vec3; +#[deprecated] +pub type Ushort4 = Vec4; +#[deprecated] +pub type Uint2 = Vec2; +#[deprecated] +pub type Uint3 = Vec3; +#[deprecated] +pub type Uint4 = Vec4; +#[deprecated] +pub type Ulong2 = Vec2; +#[deprecated] +pub type Bool2 = Vec2; +#[deprecated] +pub type Bool3 = Vec3; +#[deprecated] +pub type Bool4 = Vec4; + +#[deprecated] +pub type Half2Expr = Expr>; +#[deprecated] +pub type Half3Expr = Expr>; +#[deprecated] +pub type Half4Expr = Expr>; +#[deprecated] +pub type Float2Expr = Expr>; +#[deprecated] +pub type Float3Expr = Expr>; +#[deprecated] +pub type Float4Expr = Expr>; +#[deprecated] +pub type Double2Expr = Expr>; +#[deprecated] +pub type Double3Expr = Expr>; +#[deprecated] +pub type Double4Expr = Expr>; +#[deprecated] +pub type Byte2Expr = Expr>; +#[deprecated] +pub type Byte3Expr = Expr>; +#[deprecated] +pub type Byte4Expr = Expr>; +#[deprecated] +pub type Short2Expr = Expr>; +#[deprecated] +pub type Short3Expr = Expr>; +#[deprecated] +pub type Short4Expr = Expr>; +#[deprecated] +pub type Int2Expr = Expr>; +#[deprecated] +pub type Int3Expr = Expr>; +#[deprecated] +pub type Int4Expr = Expr>; +#[deprecated] +pub type Long2Expr = Expr>; +#[deprecated] +pub type Long3Expr = Expr>; +#[deprecated] +pub type Long4Expr = Expr>; +#[deprecated] +pub type Ubyte2Expr = Expr>; +#[deprecated] +pub type Ubyte3Expr = Expr>; +#[deprecated] +pub type Ubyte4Expr = Expr>; +#[deprecated] +pub type Ushort2Expr = Expr>; +#[deprecated] +pub type Ushort3Expr = Expr>; +#[deprecated] +pub type Ushort4Expr = Expr>; +#[deprecated] +pub type Uint2Expr = Expr>; +#[deprecated] +pub type Uint3Expr = Expr>; +#[deprecated] +pub type Uint4Expr = Expr>; +#[deprecated] +pub type Ulong2Expr = Expr>; +#[deprecated] +pub type Bool2Expr = Expr>; +#[deprecated] +pub type Bool3Expr = Expr>; +#[deprecated] +pub type Bool4Expr = Expr>; + +#[deprecated] +pub type Half2Var = Var>; +#[deprecated] +pub type Half3Var = Var>; +#[deprecated] +pub type Half4Var = Var>; +#[deprecated] +pub type Float2Var = Var>; +#[deprecated] +pub type Float3Var = Var>; +#[deprecated] +pub type Float4Var = Var>; +#[deprecated] +pub type Double2Var = Var>; +#[deprecated] +pub type Double3Var = Var>; +#[deprecated] +pub type Double4Var = Var>; +#[deprecated] +pub type Byte2Var = Var>; +#[deprecated] +pub type Byte3Var = Var>; +#[deprecated] +pub type Byte4Var = Var>; +#[deprecated] +pub type Short2Var = Var>; +#[deprecated] +pub type Short3Var = Var>; +#[deprecated] +pub type Short4Var = Var>; +#[deprecated] +pub type Int2Var = Var>; +#[deprecated] +pub type Int3Var = Var>; +#[deprecated] +pub type Int4Var = Var>; +#[deprecated] +pub type Long2Var = Var>; +#[deprecated] +pub type Long3Var = Var>; +#[deprecated] +pub type Long4Var = Var>; +#[deprecated] +pub type Ubyte2Var = Var>; +#[deprecated] +pub type Ubyte3Var = Var>; +#[deprecated] +pub type Ubyte4Var = Var>; +#[deprecated] +pub type Ushort2Var = Var>; +#[deprecated] +pub type Ushort3Var = Var>; +#[deprecated] +pub type Ushort4Var = Var>; +#[deprecated] +pub type Uint2Var = Var>; +#[deprecated] +pub type Uint3Var = Var>; +#[deprecated] +pub type Uint4Var = Var>; +#[deprecated] +pub type Ulong2Var = Var>; +#[deprecated] +pub type Bool2Var = Var>; +#[deprecated] +pub type Bool3Var = Var>; +#[deprecated] +pub type Bool4Var = Var>; diff --git a/luisa_compute/src/lang/types/vector/nalgebra.rs b/luisa_compute/src/lang/types/vector/nalgebra.rs new file mode 100644 index 0000000..bcb7ea0 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/nalgebra.rs @@ -0,0 +1,2 @@ +use super::*; +use nalgebra::*; diff --git a/luisa_compute/src/lang/swizzle.rs b/luisa_compute/src/lang/types/vector/swizzle.rs similarity index 98% rename from luisa_compute/src/lang/swizzle.rs rename to luisa_compute/src/lang/types/vector/swizzle.rs index 2696954..188d599 100644 --- a/luisa_compute/src/lang/swizzle.rs +++ b/luisa_compute/src/lang/types/vector/swizzle.rs @@ -2,9 +2,9 @@ pub trait Vec2Swizzle { type Vec2; type Vec3; type Vec4; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2; - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3; - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2; + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3; + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4; fn xx(&self) -> Self::Vec2 { self.permute2(0, 0) } @@ -94,9 +94,9 @@ pub trait Vec3Swizzle { type Vec2; type Vec3; type Vec4; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2; - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3; - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2; + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3; + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4; fn xx(&self) -> Self::Vec2 { self.permute2(0, 0) } @@ -453,9 +453,9 @@ pub trait Vec4Swizzle { type Vec2; type Vec3; type Vec4; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2; - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3; - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2; + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3; + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4; fn xx(&self) -> Self::Vec2 { self.permute2(0, 0) } diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index ff9bb44..238be8d 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -22,22 +22,16 @@ pub mod prelude { pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; - pub use crate::lang::swizzle::*; + pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::{ - Bool2, Bool3, Bool4, Byte2, Byte3, Byte4, Double2, Double3, Double4, Float2, Float3, - Float4, Half2, Half3, Half4, Int2, Int3, Int4, Long2, Long3, Long4, Mat2, Mat3, Mat4, - PackedBool2, PackedBool3, PackedBool4, PackedFloat2, PackedFloat3, PackedFloat4, - PackedInt2, PackedInt3, PackedInt4, PackedLong2, PackedLong3, PackedLong4, PackedShort2, - PackedShort3, PackedShort4, PackedUint2, PackedUint3, PackedUint4, PackedUlong2, - PackedUlong3, PackedUlong4, PackedUshort2, PackedUshort3, PackedUshort4, Short2, Short3, - Short4, Ubyte2, Ubyte3, Ubyte4, Uint2, Uint3, Uint4, Ulong2, Ulong3, Ulong4, Ushort2, - Ushort3, Ushort4, + Mat2, Mat3, Mat4, SquareMatrix, Vec2, Vec3, Vec4, Vector, }; - pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; + pub use crate::lang::types::{AsExpr, Expr, Value, Var}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; + pub use crate::runtime::api::StreamTag; pub use crate::runtime::{ - api::StreamTag, create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, + create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, }; pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, struct_, while_, Context}; @@ -46,22 +40,25 @@ pub mod prelude { } mod internal_prelude { - pub(crate) use crate::lang::debug::{CpuFn, __env_need_backtrace, is_cpu_backend}; + pub(crate) use crate::lang::debug::{__env_need_backtrace, is_cpu_backend, CpuFn}; pub(crate) use crate::lang::ir::ffi::*; pub(crate) use crate::lang::ir::{ new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, PhiIncoming, Pooled, Type, TypeOf, INVALID_REF, }; - pub(crate) use crate::lang::types::vector::*; + pub(crate) use crate::lang::types::vector::legacy::*; pub(crate) use crate::lang::{ - ir, Recorder, __compose, __extract, __insert, __module_pools, need_runtime_check, FromNode, - NodeRef, ToNode, __current_scope, __pop_scope, RECORDER, + ir, CallFuncTrait, Recorder, __compose, __extract, __insert, __module_pools, + need_runtime_check, FromNode, NodeLike, NodeRef, ToNode, __current_scope, __pop_scope, + RECORDER, }; pub(crate) use crate::prelude::*; pub(crate) use crate::runtime::{ CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, }; - pub(crate) use crate::{get_backtrace, impl_callable_param, ResourceTracker}; + pub(crate) use crate::{ + get_backtrace, impl_simple_expr_proxy, impl_simple_var_proxy, ResourceTracker, + }; pub(crate) use luisa_compute_backend::Backend; pub(crate) use std::marker::PhantomData; } @@ -99,7 +96,8 @@ lazy_static! { } impl Context { /// path to libluisa-* - /// if the current_exe() is in the same directory as libluisa-*, then passing current_exe() is enough + /// if the current_exe() is in the same directory as libluisa-*, then + /// passing current_exe() is enough pub fn new(lib_path: impl AsRef) -> Self { let mut lib_path = lib_path.as_ref().to_path_buf(); lib_path = lib_path.canonicalize().unwrap(); diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index 61cdc76..a41b988 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -33,13 +33,11 @@ pub struct PrinterArgs { count: usize, } impl PrinterArgs { - pub fn append(&mut self, v: E) - where - E::Value: Debug, - { - let n = packed_size::(); + pub fn append(&mut self, v: Expr) { + let n = packed_size::(); self.count_per_arg.push(n); self.pack_fn.push(Box::new(move |offset, data| { + let v = (&v).clone(); pack_to(v, data, offset); })); self.count += n; @@ -104,8 +102,8 @@ macro_rules! lc_error { $crate::lc_log!($printer, log::Level::Error, $fmt, $($arg)*); }; } -pub fn _unpack_from_expr(data: *const u32, _: E) -> E::Value { - unsafe { std::ptr::read_unaligned(data as *const E::Value) } +pub fn _unpack_from_expr(data: *const u32, _: Expr) -> V { + unsafe { std::ptr::read_unaligned(data as *const V) } } impl Printer { pub fn new(device: &Device, size: usize) -> Self { @@ -130,8 +128,8 @@ impl Printer { let item_id = items.len() as u32; if_!( - offset.cmplt(data.len().uint()) - & (offset + 1 + args.count as u32).cmple(data.len().uint()), + offset.lt(data.len().cast::()) + & (offset + 1 + args.count as u32).le(data.len().cast::()), { data.atomic_fetch_add(0, 1); data.write(offset, item_id); diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 63848be..8138012 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -906,15 +906,15 @@ macro_rules! impl_io_texel { } }; } -impl_io_texel!(f32, f32, Float4, |x: Float4Expr| x.x(), |x| { - Float4Expr::splat(x) +impl_io_texel!(f32, f32, Float4, |x: Float4Expr| x.x, |x| { + Float4::splat_expr(x) }); impl_io_texel!( Float2, f32, Float4, |x: Float4Expr| x.xy(), - |x: Float2Expr| { Float4::expr(x.x(), x.y(), 0.0, 0.0) } + |x: Float2Expr| { Float4::expr(x.x, x.y, 0.0, 0.0) } ); impl_io_texel!(Float4, f32, Float4, |x: Float4Expr| x, |x: Float4Expr| x); @@ -924,15 +924,15 @@ impl_io_texel!(Float4, f32, Float4, |x: Float4Expr| x, |x: Float4Expr| x); // impl_io_texel!(Short2,); // impl_io_texel!(Ushort4,); // impl_io_texel!(Short4,); -impl_io_texel!(u32, u32, Uint4, |x: Uint4Expr| x.x(), |x| Uint4Expr::splat( +impl_io_texel!(u32, u32, Uint4, |x: Uint4Expr| x.x, |x| Uint4::splat_expr( x )); -impl_io_texel!(i32, i32, Int4, |x: Int4Expr| x.x(), |x| Int4Expr::splat(x)); +impl_io_texel!(i32, i32, Int4, |x: Int4Expr| x.x, |x| Int4::splat_expr(x)); impl_io_texel!(Uint2, u32, Uint4, |x: Uint4Expr| x.xy(), |x: Uint2Expr| { - Uint4::expr(x.x(), x.y(), 0u32, 0u32) + Uint4::expr(x.x, x.y, 0u32, 0u32) }); impl_io_texel!(Int2, i32, Int4, |x: Int4Expr| x.xy(), |x: Int2Expr| { - Int4::expr(x.x(), x.y(), 0i32, 0i32) + Int4::expr(x.x, x.y, 0i32, 0i32) }); impl_io_texel!(Uint4, u32, Uint4, |x: Uint4Expr| x, |x| x); impl_io_texel!(Int4, i32, Int4, |x: Int4Expr| x, |x| x); @@ -1003,8 +1003,9 @@ impl_storage_texel!(Half4, Half4, f32, Float2, Float4,); impl_storage_texel!([f16; 2], Byte2, f32, Float2, Float4, Int2, Int4, Uint2, Uint4,); impl_storage_texel!([f16; 4], Byte4, f32, Float2, Float4, Int2, Int4, Uint2, Uint4,); -// `T` is the read out type of the texture, which is not necessarily the same as the storage type -// In fact, the texture can be stored in any format as long as it can be converted to `T` +// `T` is the read out type of the texture, which is not necessarily the same as +// the storage type In fact, the texture can be stored in any format as long as +// it can be converted to `T` pub struct Tex2d { #[allow(dead_code)] pub(crate) width: u32, @@ -1014,8 +1015,9 @@ pub struct Tex2d { pub(crate) marker: PhantomData, } -// `T` is the read out type of the texture, which is not necessarily the same as the storage type -// In fact, the texture can be stored in any format as long as it can be converted to `T` +// `T` is the read out type of the texture, which is not necessarily the same as +// the storage type In fact, the texture can be stored in any format as long as +// it can be converted to `T` pub struct Tex3d { #[allow(dead_code)] pub(crate) width: u32, @@ -1214,8 +1216,8 @@ impl<'a, T: IoTexel> Tex2dView<'a, T> { } pub fn size(&self) -> [u32; 3] { [ - (self.tex.handle.width >> self.level).max(1), - (self.tex.handle.height >> self.level).max(1), + Ord::max(self.tex.handle.width >> self.level, 1), + Ord::max(self.tex.handle.height >> self.level, 1), 1, ] } @@ -1234,9 +1236,9 @@ impl<'a, T: IoTexel> Tex3dView<'a, T> { } pub fn size(&self) -> [u32; 3] { [ - (self.tex.handle.width >> self.level).max(1), - (self.tex.handle.height >> self.level).max(1), - (self.tex.handle.depth >> self.level).max(1), + Ord::max(self.tex.handle.width >> self.level, 1), + Ord::max(self.tex.handle.height >> self.level, 1), + Ord::max(self.tex.handle.depth >> self.level, 1), ] } pub fn var(&self) -> Tex3dVar { @@ -1325,7 +1327,7 @@ impl IndexRead for BindlessBufferVar { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::::from_node(__current_scope(|b| { @@ -1614,7 +1616,7 @@ impl BindlessArrayVar { } else if is_cpu_backend() { if need_runtime_check() { let expected = type_hash(&T::type_()); - lc_assert!(v.__type().cmpeq(expected)); + lc_assert!(v.__type().eq(expected)); } } v @@ -1661,7 +1663,7 @@ impl IndexRead for Buffer { } } impl IndexWrite for Buffer { - fn write>>(&self, i: I, v: V) { + fn write>(&self, i: I, v: V) { self.var().write(i, v) } } @@ -1670,7 +1672,7 @@ impl IndexRead for BufferVar { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } __current_scope(|b| { FromNode::from_node(b.call( @@ -1682,11 +1684,11 @@ impl IndexRead for BufferVar { } } impl IndexWrite for BufferVar { - fn write>>(&self, i: I, v: V) { + fn write>(&self, i: I, v: V) { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } __current_scope(|b| { b.call( @@ -1736,11 +1738,15 @@ impl BufferVar { macro_rules! impl_atomic { ($t:ty) => { impl BufferVar<$t> { - pub fn atomic_exchange>>(&self, i: I, v: V) -> Expr<$t> { + pub fn atomic_exchange>( + &self, + i: I, + v: V, + ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1750,17 +1756,21 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_compare_exchange>, V1: Into>>( + pub fn atomic_compare_exchange< + I: IntoIndex, + V0: AsExpr, + V1: AsExpr, + >( &self, i: I, expected: V0, desired: V1, ) -> Expr<$t> { let i = i.to_u64(); - let expected = expected.into(); - let desired = desired.into(); + let expected = expected.as_expr(); + let desired = desired.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1770,15 +1780,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_add>>( + pub fn atomic_fetch_add>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1788,15 +1798,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_sub>>( + pub fn atomic_fetch_sub>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1806,15 +1816,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_min>>( + pub fn atomic_fetch_min>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1824,15 +1834,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_max>>( + pub fn atomic_fetch_max>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1848,15 +1858,15 @@ macro_rules! impl_atomic { macro_rules! impl_atomic_bit { ($t:ty) => { impl BufferVar<$t> { - pub fn atomic_fetch_and>>( + pub fn atomic_fetch_and>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1866,11 +1876,15 @@ macro_rules! impl_atomic_bit { ) })) } - pub fn atomic_fetch_or>>(&self, i: I, v: V) -> Expr<$t> { + pub fn atomic_fetch_or>( + &self, + i: I, + v: V, + ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1880,15 +1894,15 @@ macro_rules! impl_atomic_bit { ) })) } - pub fn atomic_fetch_xor>>( + pub fn atomic_fetch_xor>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 90dc02b..4cb8a61 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -957,7 +957,7 @@ impl CallableArgEncoder { pub fn bindless_array(&mut self, array: &BindlessArrayVar) { self.args.push(array.node); } - pub fn var(&mut self, value: impl FromNode) { + pub fn var(&mut self, value: impl NodeLike) { self.args.push(value.node()); } pub fn accel(&mut self, accel: &rtx::AccelVar) { diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 1bee826..d70a26b 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -1,31 +1,20 @@ use super::*; -#[macro_export] -macro_rules! impl_callable_param { - ($t:ty, $e:ty, $v:ty) => { - impl CallableParameter for $e { - fn def_param( - _: Option>, - builder: &mut KernelBuilder, - ) -> Self { - builder.value::<$t>() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) - } - } - impl CallableParameter for $v { - fn def_param( - _: Option>, - builder: &mut KernelBuilder, - ) -> Self { - builder.var::<$t>() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) - } - } - }; +impl CallableParameter for Expr { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.value::() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.var(self.clone()) + } +} +impl CallableParameter for Var { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.var::() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.var(self.clone()) + } } // Not recommended to use this directly @@ -116,24 +105,11 @@ impl CallableParameter for BindlessArrayVar { } } -impl CallableParameter for rtx::AccelVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.accel() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.accel(self) - } -} - pub trait KernelParameter { fn def_param(builder: &mut KernelBuilder) -> Self; } -impl KernelParameter for U -where - U: ExprProxy, - T: Value, -{ +impl KernelParameter for Expr { fn def_param(builder: &mut KernelBuilder) -> Self { builder.uniform::() } @@ -167,11 +143,6 @@ impl KernelParameter for BindlessArrayVar { } } -impl KernelParameter for rtx::AccelVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.accel() - } -} macro_rules! impl_kernel_param_for_tuple { ($first:ident $($rest:ident)*) => { impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { @@ -290,14 +261,6 @@ impl KernelBuilder { self.args.push(node); BindlessArrayVar { node, handle: None } } - pub fn accel(&mut self) -> rtx::AccelVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Accel), Type::void()), - ); - self.args.push(node); - rtx::AccelVar { node, handle: None } - } fn collect_module_info(&self) -> (ResourceTracker, Vec>, Vec) { RECORDER.with(|r| { let mut resource_tracker = ResourceTracker::new(); @@ -518,12 +481,12 @@ unsafe impl CallableRet for () { fn _from_return(_: NodeRef) -> Self {} } -unsafe impl CallableRet for T { +unsafe impl CallableRet for Expr { fn _return(&self) -> CArc { __current_scope(|b| { b.return_(self.node()); }); - T::Value::type_() + V::type_() } fn _from_return(node: NodeRef) -> Self { Self::from_node(node) diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index b8b284f..517ae49 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit b8b284f5c8b8ee4470298f362a1cd7f9a7c79698 +Subproject commit 517ae49e84b6d255739d92d28590588c9de4ee56 diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 3129809..90f943d 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -5,6 +5,10 @@ use syn::spanned::Spanned; use syn::visit_mut::*; use syn::*; +// TODO: Impl let mut -> let = .var() +// TODO: Impl x as f32 -> .cast() +// TOOD: Impl switch! macro. + #[cfg(test)] use pretty_assertions::assert_eq; @@ -40,6 +44,20 @@ impl VisitMut for TraceVisitor { let trait_path = &self.trait_path; let span = node.span(); match node { + Expr::Assign(expr) => { + let left = &expr.left; + let right = &expr.right; + if let Expr::Unary(ExprUnary { + op: UnOp::Deref(_), + expr, + .. + }) = &**left + { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::StoreMaybeExpr>::deref_set(#expr, #right) + } + } + } Expr::If(expr) => { let cond = &expr.cond; let then_branch = &expr.then_branch; @@ -47,11 +65,11 @@ impl VisitMut for TraceVisitor { if let Expr::Let(_) = **cond { } else if let Some((_, else_branch)) = else_branch { *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolIfElseMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch) + <_ as #trait_path::SelectMaybeExpr<_>>::select(#cond, || #then_branch, || #else_branch) } } else { *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolIfMaybeExpr>::if_then(#cond, || #then_branch) + <_ as #trait_path::ActivateMaybeExpr>::activate(#cond, || #then_branch) } } } @@ -59,7 +77,7 @@ impl VisitMut for TraceVisitor { let cond = &expr.cond; let body = &expr.body; *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolWhileMaybeExpr>::while_loop(|| #cond, || #body) + <_ as #trait_path::LoopMaybeExpr>::while_loop(|| #cond, || #body) } } Expr::Loop(expr) => { @@ -99,15 +117,15 @@ impl VisitMut for TraceVisitor { let op_fn = Ident::new(op_fn_str, expr.op.span()); if op_fn_str == "eq" || op_fn_str == "ne" { *node = parse_quote_spanned! {span=> - <_ as #trait_path::EqMaybeExpr<_>>::#op_fn(#left, #right) + <_ as #trait_path::EqMaybeExpr<_, _>>::#op_fn(#left, #right) } } else if op_fn_str == "and" || op_fn_str == "or" { *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolLazyOpsMaybeExpr<_>>::#op_fn(#left, || #right) + <_ as #trait_path::LazyBoolMaybeExpr<_>>::#op_fn(#left, || #right) } } else { *node = parse_quote_spanned! {span=> - <_ as #trait_path::PartialOrdMaybeExpr<_>>::#op_fn(#left, #right) + <_ as #trait_path::CmpMaybeExpr<_, _>>::#op_fn(#left, #right) } } } @@ -171,6 +189,9 @@ impl VisitMut for TraceVisitor { #[proc_macro] pub fn track(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = TokenStream::from(input); + let input = quote!({ #input }); + let input = proc_macro::TokenStream::from(input); track_impl(parse_macro_input!(input as Expr)).into() } diff --git a/rustfmt.toml b/rustfmt.toml index d1bdb5e..174fd7a 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,3 @@ ignore = ["luisa_compute_sys"] imports_granularity = "Module" +wrap_comments = true