From 658050a0f9b77188aaf80f1f84ecb819662df128 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Mon, 18 Sep 2023 10:30:11 +0100 Subject: [PATCH 01/15] Rename FromNode to NodeLike and split up. --- luisa_compute/src/lang.rs | 5 ++++- luisa_compute/src/lang/control_flow.rs | 6 ++++-- luisa_compute/src/lang/diff.rs | 2 +- luisa_compute/src/lang/ops.rs | 2 +- luisa_compute/src/lang/types.rs | 4 ++-- luisa_compute/src/lib.rs | 5 +++-- luisa_compute/src/runtime.rs | 2 +- 7 files changed, 16 insertions(+), 10 deletions(-) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index dae2d7e..c176c2f 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -135,7 +135,10 @@ pub trait ToNode { fn node(&self) -> NodeRef; } -pub trait FromNode: ToNode { +pub trait NodeLike: FromNode + ToNode {} +impl NodeLike for T where T: FromNode + ToNode {} + +pub trait FromNode { fn from_node(node: NodeRef) -> Self; } diff --git a/luisa_compute/src/lang/control_flow.rs b/luisa_compute/src/lang/control_flow.rs index 9324695..c02751d 100644 --- a/luisa_compute/src/lang/control_flow.rs +++ b/luisa_compute/src/lang/control_flow.rs @@ -66,7 +66,7 @@ pub fn continue_() { }); } -pub fn return_v(v: T) { +pub fn return_v(v: T) { RECORDER.with(|r| { let mut r = r.borrow_mut(); if r.callable_ret_type.is_none() { @@ -296,7 +296,9 @@ impl_range!(u32); impl_range!(u64); pub fn loop_(body: impl Fn()) { - while_!(true.expr(), { body(); }); + while_!(true.expr(), { + body(); + }); } pub fn for_range(r: R, body: impl Fn(Expr)) { diff --git a/luisa_compute/src/lang/diff.rs b/luisa_compute/src/lang/diff.rs index cb37fbb..55278f1 100644 --- a/luisa_compute/src/lang/diff.rs +++ b/luisa_compute/src/lang/diff.rs @@ -118,7 +118,7 @@ pub fn grad(var: T) -> T { // .collect(); // R::from_vec_nodes(nodes) // } -pub fn detach(v: T) -> T { +pub fn detach(v: T) -> T { let v = v.node(); let node = __current_scope(|b| b.call(Func::Detach, &[v], v.type_().clone())); T::from_node(node) diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index 4920bf7..ff514bc 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -3,7 +3,7 @@ use std::ops::*; pub mod impls; -pub trait VarTrait: Copy + Clone + 'static + FromNode { +pub trait VarTrait: Copy + Clone + 'static + NodeLike { type Value: Value; type Short: VarTrait; type Ushort: VarTrait; diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index dc8f3f1..f6bb7a9 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -25,7 +25,7 @@ pub trait Value: Copy + ir::TypeOf + 'static { } } -pub trait ExprProxy: Copy + Aggregate + FromNode { +pub trait ExprProxy: Copy + Aggregate + NodeLike { type Value: Value; fn var(self) -> Var { @@ -37,7 +37,7 @@ pub trait ExprProxy: Copy + Aggregate + FromNode { } } -pub trait VarProxy: Copy + Aggregate + FromNode { +pub trait VarProxy: Copy + Aggregate + NodeLike { type Value: Value; fn store>>(&self, value: U) { let value = value.into(); diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index ff9bb44..42cf920 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -36,8 +36,9 @@ pub mod prelude { pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; + pub use crate::runtime::api::StreamTag; pub use crate::runtime::{ - api::StreamTag, create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, + create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, }; pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, struct_, while_, Context}; @@ -55,7 +56,7 @@ mod internal_prelude { pub(crate) use crate::lang::types::vector::*; pub(crate) use crate::lang::{ ir, Recorder, __compose, __extract, __insert, __module_pools, need_runtime_check, FromNode, - NodeRef, ToNode, __current_scope, __pop_scope, RECORDER, + NodeLike, NodeRef, ToNode, __current_scope, __pop_scope, RECORDER, }; pub(crate) use crate::prelude::*; pub(crate) use crate::runtime::{ diff --git a/luisa_compute/src/runtime.rs b/luisa_compute/src/runtime.rs index 90dc02b..4cb8a61 100644 --- a/luisa_compute/src/runtime.rs +++ b/luisa_compute/src/runtime.rs @@ -957,7 +957,7 @@ impl CallableArgEncoder { pub fn bindless_array(&mut self, array: &BindlessArrayVar) { self.args.push(array.node); } - pub fn var(&mut self, value: impl FromNode) { + pub fn var(&mut self, value: impl NodeLike) { self.args.push(value.node()); } pub fn accel(&mut self, accel: &rtx::AccelVar) { From 7d98587f1c87f1dbc736438d2678d656ae29bb70 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Mon, 18 Sep 2023 12:49:24 +0100 Subject: [PATCH 02/15] Minor track fix. --- luisa_compute_track/src/lib.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 3129809..6a94f79 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -171,6 +171,9 @@ impl VisitMut for TraceVisitor { #[proc_macro] pub fn track(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = TokenStream::from(input); + let input = quote!({ #input }); + let input = proc_macro::TokenStream::from(input); track_impl(parse_macro_input!(input as Expr)).into() } From b2b97238e8db7d02942fc8f0b761028417557b97 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Tue, 19 Sep 2023 12:41:22 +0100 Subject: [PATCH 03/15] Struct expr refactoring pt 1. --- luisa_compute/Cargo.toml | 25 +- luisa_compute/src/lang.rs | 6 + luisa_compute/src/lang/maybe_expr.rs | 78 ++++ luisa_compute/src/lang/ops/impls.rs | 2 +- luisa_compute/src/lang/types.rs | 374 +++++++++++----- luisa_compute/src/lang/types/alignment.rs | 26 ++ luisa_compute/src/lang/types/array.rs | 131 +----- luisa_compute/src/lang/types/core.rs | 273 ++++++------ luisa_compute/src/lang/types/vector.rs | 418 ++++-------------- luisa_compute/src/lang/types/vector/coords.rs | 1 + luisa_compute/src/lang/types/vector/glam.rs | 2 + .../src/lang/types/vector/nalgebra.rs | 2 + luisa_compute/src/lib.rs | 7 +- luisa_compute/src/runtime/kernel.rs | 41 +- luisa_compute_track/src/lib.rs | 14 + rustfmt.toml | 1 + 16 files changed, 653 insertions(+), 748 deletions(-) create mode 100644 luisa_compute/src/lang/types/alignment.rs create mode 100644 luisa_compute/src/lang/types/vector/coords.rs create mode 100644 luisa_compute/src/lang/types/vector/glam.rs create mode 100644 luisa_compute/src/lang/types/vector/nalgebra.rs diff --git a/luisa_compute/Cargo.toml b/luisa_compute/Cargo.toml index dbf69d2..1a0ca36 100644 --- a/luisa_compute/Cargo.toml +++ b/luisa_compute/Cargo.toml @@ -7,20 +7,12 @@ version = "0.1.1-alpha.1" base64ct = { version = "1.5.0", features = ["alloc"] } bumpalo = "3.12.0" env_logger = "0.10.0" -glam = "0.24.0" half = "2.2.1" lazy_static = "1.4.0" libc = "0.2" libloading = "0.8" log = "0.4" -luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" } -luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" } -luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" } -luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" } -luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" } -luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" } -luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" } parking_lot = "0.12.1" rayon = "1.6.0" serde = { version = "1.0", features = ["derive"] } @@ -28,12 +20,23 @@ serde_json = "1.0" sha2 = "0.10" winit = "0.28.3" raw-window-handle = "0.5.1" -indexmap = "1.9.3" +indexmap = "2.0.0" + +luisa_compute_api_types = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_api_types", version = "0.1.1-alpha.1" } +luisa_compute_backend = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_backend", version = "0.1.1-alpha.1" } +luisa_compute_derive = { path = "../luisa_compute_derive", version = "0.1.1-alpha.1" } +luisa_compute_derive_impl = { path = "../luisa_compute_derive_impl", version = "0.1.1-alpha.1" } +luisa_compute_track = { path = "../luisa_compute_track", version = "0.1.1-alpha.1" } +luisa_compute_ir = { path = "../luisa_compute_sys/LuisaCompute/src/rust/luisa_compute_ir", version = "0.1.1-alpha.1" } +luisa_compute_sys = { path = "../luisa_compute_sys", version = "0.1.1-alpha.1" } + +glam = { version = "0.24.0", optional = true } +nalgebra = { version = "0.32.3", optional = true } [dev-dependencies] rand = "0.8.5" image = "0.24.5" -tobj = "3.2.5" +tobj = "4.0.0" [features] default = ["remote", "cuda", "cpu", "metal", "dx"] @@ -43,3 +46,5 @@ dx = ["luisa_compute_sys/dx"] strict = ["luisa_compute_sys/strict"] remote = ["luisa_compute_sys/remote"] cpu = ["luisa_compute_sys/cpu"] +glam = ["dep:glam"] +nalgebra = ["dep:nalgebra"] diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index c176c2f..159c354 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -142,6 +142,12 @@ pub trait FromNode { fn from_node(node: NodeRef) -> Self; } +impl FromNode for T { + fn from_node(_: NodeRef) -> Self { + Default::default() + } +} + fn _store(var: &T1, value: &T2) { let value_nodes = value.to_vec_nodes(); let self_nodes = var.to_vec_nodes(); diff --git a/luisa_compute/src/lang/maybe_expr.rs b/luisa_compute/src/lang/maybe_expr.rs index 2654889..7bb76b2 100644 --- a/luisa_compute/src/lang/maybe_expr.rs +++ b/luisa_compute/src/lang/maybe_expr.rs @@ -2,10 +2,88 @@ //! either be an expression or a normal value. This is necessary for making the //! trace macro work for both types of value. +use std::ops::DerefMut; + use super::control_flow::{generic_loop, if_then_else}; use super::types::core::*; +use super::types::AsExpr; use crate::internal_prelude::*; +/*== Version 1 +pub trait DerefSet { + type Target; + fn deref_set(&mut self, target: Self::Target); +} +impl DerefSet for T +where + T::Target: Sized, +{ + type Target = T::Target; + fn deref_set(&mut self, target: Self::Target) { + **self = target; + } +} +impl DerefSet for Var { + type Target = Expr; + fn deref_set(&mut self, target: Self::Target) { + self.store(target); + } +} +*/ +/*== Version 2 +pub trait DerefSet { + type Target; + fn deref_set(self, target: Self::Target); +} +impl DerefSet for &mut T +where + T::Target: Sized, +{ + type Target = T::Target; + fn deref_set(self, target: Self::Target) { + *self = target; + } +} +impl DerefSet for Var { + type Target = Expr; + fn deref_set(&mut self, target: Self::Target) { + self.store(target); + } +} +// TODO: Confirm that `&mut Var` errors. Otherwise, make a `&mut Var` impl that +// panics. +impl DerefSet for &Var { + type Target = Expr; + fn deref_set(&mut self, target: Self::Target) { + self.store(target); + } +} +*/ +/* == Version 3 == */ +pub trait DerefSet { + fn deref_set(self, target: X); +} +impl DerefSet for &mut T +where + T::Target: Sized, +{ + fn deref_set(self, target: T::Target) { + *self = target; + } +} +impl> DerefSet for Var { + fn deref_set(self, target: X) { + self.store(target.as_expr()); + } +} +// TODO: Confirm that `&mut Var` errors. Otherwise, make a `&mut Var` impl that +// panics. +impl> DerefSet for &Var { + fn deref_set(self, target: X) { + self.store(target.as_expr()); + } +} + pub trait BoolIfElseMaybeExpr { fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R; } diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index db3af67..b581b40 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -1,6 +1,6 @@ use super::*; use crate::lang::types::core::*; -use crate::lang::types::VarDerefProxy; +use crate::lang::types::VarDeref; macro_rules! impl_var_trait { ($t:ty) => { diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index f6bb7a9..e33bf26 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -1,168 +1,312 @@ -use std::any::Any; -use std::cell::Cell; -use std::ops::{Deref, DerefMut}; +use std::mem::transmute; +use std::ops::Deref; use crate::internal_prelude::*; +pub mod alignment; pub mod array; pub mod core; pub mod dynamic; pub mod shared; pub mod vector; -pub type Expr = ::Expr; -pub type Var = ::Var; +// TODO: Check up on comments. -pub trait Value: Copy + ir::TypeOf + 'static { +/// A value that can be used in a [`Kernel`] or [`Callable`]. Call [`expr`] or +/// [`var`] to convert into a kernel-trackable type. +pub trait Value: Copy + TypeOf + 'static { + /// A proxy for additional impls on [`Expr`]. type Expr: ExprProxy; + /// A proxy for additional impls on [`Var`]. type Var: VarProxy; - fn fields() -> Vec; - fn expr(self) -> Self::Expr { - const_(self) + /// The type of the custom data within an [`Expr`]. + type ExprData: Clone + FromNode + 'static; + /// The type of the custom data within an [`Var`]. + type VarData: Clone + FromNode + 'static; + + fn expr(self) -> Expr { + let node = __current_scope(|s| -> NodeRef { + let mut buf = vec![0u8; std::mem::size_of::()]; + unsafe { + std::ptr::copy_nonoverlapping( + &self as *const Self as *const u8, + buf.as_mut_ptr(), + buf.len(), + ); + } + s.const_(Const::Generic(CBoxedSlice::new(buf), Self::type_())) + }); + Expr::::from_node(node) } - fn var(self) -> Self::Var { - local::(self.expr()) + fn var(self) -> Var { + self.expr().var() } } -pub trait ExprProxy: Copy + Aggregate + NodeLike { +/// A trait for implementing remote impls on top of an [`Expr`] using [`Deref`]. +/// +/// For example, `Expr<[f32; 4]>` dereferences to `ArrayExpr`, which +/// exposes an [`Index`](std::ops::Index) impl. +pub trait ExprProxy: Clone + HasExprLayout<::ExprData> + 'static { type Value: Value; +} - fn var(self) -> Var { - def(self) - } +/// A trait for implementing remote impls on top of an [`Var`] using [`Deref`]. +/// +/// For example, `Var<[f32; 4]>` dereferences to `ArrayVar`, which +/// exposes [`Index`](std::ops::Index) and [`IndexMut`](std::ops::IndexMut) +/// impls. +pub trait VarProxy: + Clone + HasVarLayout<::VarData> + Deref> + 'static +{ + type Value: Value; +} +/// This marker trait states that `Self` has the same layout as an [`Expr`] +/// with `T::ExprData = X`. +pub unsafe trait HasExprLayout {} +/// This marker trait states that `Self` has the same layout as an [`Var`] +/// with `T::VarData = X`. +pub unsafe trait HasVarLayout {} - fn zeroed() -> Self { - zeroed::() +/// An expression within a [`Kernel`] or [`Callable`]. Created from a raw value +/// using [`Value::expr`]. +/// +/// Note that this does not store the value, and in order to get the result of a +/// function returning an `Expr`, you must call [`Kernel::dispatch`]. +#[derive(Debug, Clone)] +#[repr(C)] +pub struct Expr { + pub(crate) node: NodeRef, + _marker: PhantomData, + /// Custom data stored within the expression. + pub data: T::ExprData, +} + +/// A variable within a [`Kernel`] or [`Callable`]. Created using [`Expr::var`] +/// and [`Value::var`]. +/// +/// Note that setting a `Var` using direct assignment will not work. Instead, +/// either use the [`store`](Var::store) method or the `track!` macro and `*var +/// = expr` syntax. +#[derive(Debug, Clone)] +#[repr(C)] +pub struct Var { + pub(crate) node: NodeRef, + _marker: PhantomData, + /// Custom data stored within the variable. + pub data: T::VarData, +} + +impl Copy for Expr where T::ExprData: Copy {} +impl Copy for Var where T::VarData: Copy {} + +impl Aggregate for Expr { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); + } + fn from_nodes>(iter: &mut I) -> Self { + let node = iter.next().unwrap(); + Self::from_node(node) + } +} +impl FromNode for Expr { + fn from_node(node: NodeRef) -> Self { + Self { + node, + _marker: PhantomData, + data: T::ExprData::from_node(node), + } + } +} +impl ToNode for Expr { + fn node(&self) -> NodeRef { + self.node } } -pub trait VarProxy: Copy + Aggregate + NodeLike { - type Value: Value; - fn store>>(&self, value: U) { - let value = value.into(); - super::_store(self, &value); +impl Aggregate for Var { + fn to_nodes(&self, nodes: &mut Vec) { + nodes.push(self.node); } - fn load(&self) -> Expr { - __current_scope(|b| { - let nodes = self.to_vec_nodes(); - let mut ret = vec![]; - for node in nodes { - ret.push(b.call(Func::Load, &[node], node.type_().clone())); - } - Expr::::from_nodes(&mut ret.into_iter()) - }) + fn from_nodes>(iter: &mut I) -> Self { + let node = iter.next().unwrap(); + Self::from_node(node) } - fn get_mut(&self) -> VarDerefProxy { - VarDerefProxy { - var: *self, - dirty: Cell::new(false), - assigned: self.load(), - _phantom: PhantomData, +} +impl FromNode for Var { + fn from_node(node: NodeRef) -> Self { + Self { + node, + _marker: PhantomData, + data: T::VarData::from_node(node), } } - fn _deref<'a>(&'a self) -> &'a Expr { +} +impl ToNode for Var { + fn node(&self) -> NodeRef { + self.node + } +} + +impl Deref for Expr { + type Target = T::Expr; + fn deref(&self) -> &Self::Target { + unsafe { transmute(self) } + } +} +impl Deref for Var { + type Target = T::Var; + fn deref(&self) -> &Self::Target { + unsafe { transmute(self) } + } +} + +impl Expr { + pub fn var(self) -> Var { + Var::::from_node(__current_scope(|b| b.local(self.node()))) + } + pub fn zeroed() -> Self { + FromNode::from_node(__current_scope(|b| b.zero_initializer(T::type_()))) + } + pub fn _ref<'a>(self) -> &'a Self { RECORDER.with(|r| { - let v: Expr = self.load(); let r = r.borrow(); - let v: &Expr = r.arena.alloc(v); + let v: &Expr = r.arena.alloc(self); unsafe { - let v: &'a Expr = std::mem::transmute(v); + let v: &'a Expr = transmute(v); v } }) } - fn zeroed() -> Self { - local_zeroed::() - } -} - -pub struct VarDerefProxy -where - P: VarProxy, -{ - pub(crate) var: P, - pub(crate) dirty: Cell, - pub(crate) assigned: Expr, - pub(crate) _phantom: PhantomData, } -impl Deref for VarDerefProxy -where - P: VarProxy, -{ - type Target = Expr; - fn deref(&self) -> &Self::Target { - &self.assigned +impl Var { + pub fn zeroed() -> Self { + Self::from_node(__current_scope(|b| { + b.local_zero_init(::type_()) + })) + } + pub fn load(&self) -> Expr { + __current_scope(|b| { + let nodes = self.to_vec_nodes(); + let mut ret = vec![]; + for node in nodes { + ret.push(b.call(Func::Load, &[node], node.type_().clone())); + } + Expr::::from_nodes(&mut ret.into_iter()) + }) + } + pub fn store(&self, value: impl AsExpr) { + let value = value.as_expr(); + super::_store(self, &value); + } + pub fn _deref(&self) -> &Expr { + self.load()._ref() } } -impl DerefMut for VarDerefProxy -where - P: VarProxy, -{ - fn deref_mut(&mut self) -> &mut Self::Target { - self.dirty.set(true); - &mut self.assigned +#[macro_export] +macro_rules! impl_simple_expr_proxy { + ([ $($bounds:tt)* ] $name: ident [ $($qualifiers:tt)* ] for $t: ty) => { + #[derive(Debug, Clone, Copy)] + #[repr(transparent)] + pub struct $name < $($bounds)* > ($crate::lang::types::Expr<$t>); + unsafe impl < $($bounds)* > $crate::lang::types::HasExprLayout< <$t as $crate::lang::types::Value>::ExprData > for $name < $($qualifiers)* > {} + impl < $($bounds)* > $crate::lang::types::ExprProxy for $name < $($qualifiers)* > { + type Value = $t; + } } } -impl Drop for VarDerefProxy -where - P: VarProxy, -{ - fn drop(&mut self) { - if self.dirty.get() { - self.var.store(self.assigned) +#[macro_export] +macro_rules! impl_simple_var_proxy { + ([ $($bounds:tt)* ] $name: ident [ $($qualifiers:tt)* ] for $t: ty) => { + #[derive(Debug, Clone, Copy)] + #[repr(transparent)] + pub struct $name < $($bounds)* > ($crate::lang::types::Var<$t>); + unsafe impl < $($bounds)* > $crate::lang::types::HasVarLayout< <$t as $crate::lang::types::Value>::VarData > for $name < $($qualifiers)* > {} + impl < $($bounds)* > $crate::lang::types::VarProxy for $name < $($qualifiers)* > { + type Value = $t; + } + impl < $($bounds)* > std::ops::Deref for $name < $($qualifiers)* > { + type Target = $crate::lang::types::Expr<$t>; + fn deref(&self) -> &Self::Target { + self.0._deref() + } } } } -fn def, T: Value>(init: E) -> Var { - Var::::from_node(__current_scope(|b| b.local(init.node()))) +mod private { + use super::*; + pub trait Sealed {} + impl Sealed for T {} + impl Sealed for Expr {} + impl Sealed for &Expr {} + impl Sealed for Var {} + impl Sealed for &Var {} } -fn local(init: Expr) -> Var { - Var::::from_node(__current_scope(|b| b.local(init.node()))) + +pub trait Tracked: private::Sealed { + type Type: TrackingType; + type Value: Value; } -fn local_zeroed() -> Var { - Var::::from_node(__current_scope(|b| { - b.local_zero_init(::type_()) - })) +trait TrackingType {} +struct ValueType; +impl TrackingType for ValueType {} +struct ExprType; +impl TrackingType for ExprType {} +struct VarType; +impl TrackingType for VarType {} + +impl Tracked for T { + type Type = ValueType; + type Value = T; +} +impl Tracked for Expr { + type Type = ExprType; + type Value = T; +} +impl Tracked for &Expr { + type Type = ExprType; + type Value = T; +} +impl Tracked for Var { + type Type = VarType; + type Value = T; +} +impl Tracked for &Var { + type Type = VarType; + type Value = T; } -fn zeroed() -> T::Expr { - FromNode::from_node(__current_scope(|b| b.zero_initializer(T::type_()))) +pub trait AsExpr: Tracked { + fn as_expr(&self) -> Expr; } -fn const_(value: T) -> T::Expr { - let node = __current_scope(|s| -> NodeRef { - let any = &value as &dyn Any; - if let Some(value) = any.downcast_ref::() { - s.const_(Const::Bool(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Int32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Uint32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Int64(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Uint64(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Float32(*value)) - } else if let Some(value) = any.downcast_ref::() { - s.const_(Const::Float64(*value)) - } else { - let mut buf = vec![0u8; std::mem::size_of::()]; - unsafe { - std::ptr::copy_nonoverlapping( - &value as *const T as *const u8, - buf.as_mut_ptr(), - buf.len(), - ); - } - s.const_(Const::Generic(CBoxedSlice::new(buf), T::type_())) - } - }); - FromNode::from_node(node) +impl AsExpr for T { + fn as_expr(&self) -> Expr { + self.expr() + } +} +impl AsExpr for Expr { + fn as_expr(&self) -> Expr { + *self + } +} +impl AsExpr for &Expr { + fn as_expr(&self) -> Expr { + **self + } +} +impl AsExpr for Var { + fn as_expr(&self) -> Expr { + self.load() + } +} +impl AsExpr for &Var { + fn as_expr(&self) -> Expr { + self.load() + } } diff --git a/luisa_compute/src/lang/types/alignment.rs b/luisa_compute/src/lang/types/alignment.rs new file mode 100644 index 0000000..c2294f0 --- /dev/null +++ b/luisa_compute/src/lang/types/alignment.rs @@ -0,0 +1,26 @@ +use super::*; + +pub(crate) trait Alignment: Default { + const ALIGNMENT: usize; +} + +macro_rules! alignment { + ($t:ident, $align:literal) => { + #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] + #[repr(align($align))] + pub struct $t; + impl Alignment for $t { + const ALIGNMENT: usize = $align; + } + }; +} + +alignment!(Align1, 1); +alignment!(Align2, 2); +alignment!(Align4, 4); +alignment!(Align8, 8); +alignment!(Align16, 16); +alignment!(Align32, 32); +alignment!(Align64, 64); +alignment!(Align128, 128); +alignment!(Align256, 256); diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index e367966..7ec31b1 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -1,140 +1,37 @@ +use std::ops::Index; + use super::*; use crate::lang::index::IntoIndex; use ir::ArrayType; -#[derive(Clone, Copy, Debug)] -pub struct ArrayExpr { - marker: PhantomData, - node: NodeRef, -} - -#[derive(Clone, Copy, Debug)] -pub struct ArrayVar { - marker: PhantomData, - node: NodeRef, -} - -impl FromNode for ArrayExpr { - fn from_node(node: NodeRef) -> Self { - Self { - marker: PhantomData, - node, - } - } -} - -impl ToNode for ArrayExpr { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for ArrayExpr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl FromNode for ArrayVar { - fn from_node(node: NodeRef) -> Self { - Self { - marker: PhantomData, - node, - } - } -} - -impl ToNode for ArrayVar { - fn node(&self) -> NodeRef { - self.node - } -} - -impl Aggregate for ArrayVar { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self::from_node(iter.next().unwrap()) - } -} - -impl ExprProxy for ArrayExpr { - type Value = [T; N]; +impl Value for [T; N] { + type Expr = ArrayExpr; + type Var = ArrayVar; + type ExprData = (); + type VarData = (); } -impl VarProxy for ArrayVar { - type Value = [T; N]; -} - -impl ArrayVar { - pub fn len(&self) -> Expr { - (N as u32).expr() - } -} +impl_simple_expr_proxy!([T: Value, const N: usize] ArrayExpr[T, N] for [T; N]); +impl_simple_var_proxy!([T: Value, const N: usize] ArrayVar[T, N] for [T; N]); impl ArrayExpr { - pub fn zero() -> Self { - let node = __current_scope(|b| b.call(Func::ZeroInitializer, &[], <[T; N]>::type_())); - Self::from_node(node) - } pub fn len(&self) -> Expr { (N as u32).expr() } } -impl IndexRead for ArrayExpr { - type Element = T; - fn read(&self, i: I) -> Expr { +impl Index for ArrayExpr { + type Output = Expr; + fn index(&self, i: X) -> &Self::Output { let i = i.to_u64(); + // TODO: Add need_runtime_check()? lc_assert!(i.cmplt((N as u64).expr())); Expr::::from_node(__current_scope(|b| { b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) })) - } -} - -impl IndexRead for ArrayVar { - type Element = T; - fn read(&self, i: I) -> Expr { - let i = i.to_u64(); - if need_runtime_check() { - lc_assert!(i.cmplt((N as u64).expr())); - } - - Expr::::from_node(__current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.call(Func::Load, &[gep], T::type_()) - })) - } -} - -impl IndexWrite for ArrayVar { - fn write>>(&self, i: I, value: V) { - let i = i.to_u64(); - let value = value.into(); - - if need_runtime_check() { - lc_assert!(i.cmplt((N as u64).expr())); - } - - __current_scope(|b| { - let gep = b.call(Func::GetElementPtr, &[self.node, i.node()], T::type_()); - b.update(gep, value.node()); - }); - } -} - -impl Value for [T; N] { - type Expr = ArrayExpr; - type Var = ArrayVar; - fn fields() -> Vec { - todo!("why this method exists?") + ._ref() } } diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index 9898c3e..3d5bd23 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -1,163 +1,164 @@ use super::*; use std::ops::Deref; -// This is a hack in order to get rust-analyzer to display type hints as Expr -// instead of Expr, which is rather redundant and generally clutters things up. -pub(crate) mod prim { - use super::*; +pub(crate) trait Primitive: Copy + TypeOf + 'static { + fn const_(&self) -> Const; +} +impl Value for T { + type Expr = PrimitiveExpr; + type Var = PrimitiveVar; + type ExprData = (); + type VarData = (); - #[derive(Clone, Copy, Debug)] - pub struct Expr { - pub(crate) node: NodeRef, - pub(crate) _phantom: PhantomData, + fn expr(&self) -> Expr { + let node = __current_scope(|s| -> NodeRef { s.const_(self.const_()) }); + Expr::::from_node(node) } +} - #[derive(Clone, Copy, Debug)] - pub struct Var { - pub(crate) node: NodeRef, - pub(crate) _phantom: PhantomData, +impl_simple_expr_proxy!([T: Primitive] PrimitiveExpr[T] for T); +impl_simple_var_proxy!([T: Primitive] PrimitiveVar[T] for T); + +impl Primitive for bool { + fn const_(&self) -> Const { + Const::Bool(*self) } } -impl Aggregate for prim::Expr { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - _phantom: PhantomData, - } +impl Primitive for f16 { + fn const_(&self) -> Const { + Const::F16(*self) } } - -impl Aggregate for prim::Var { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); +impl Primitive for f32 { + fn const_(&self) -> Const { + Const::F32(*self) } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - _phantom: PhantomData, - } +} +impl Primitive for f64 { + fn const_(&self) -> Const { + Const::F64(*self) } } -impl FromNode for prim::Expr { - fn from_node(node: NodeRef) -> Self { - Self { - node, - _phantom: PhantomData, - } +// impl Primitive for i8 { +// fn const_(&self) -> Const { +// Const::I8(*self) +// } +// } +impl Primitive for i16 { + fn const_(&self) -> Const { + Const::I16(*self) } } -impl ToNode for prim::Expr { - fn node(&self) -> NodeRef { - self.node +impl Primitive for i32 { + fn const_(&self) -> Const { + Const::I32(*self) } } - -impl Deref for prim::Var -where - prim::Var: VarProxy, -{ - type Target = T::Expr; - fn deref(&self) -> &Self::Target { - self._deref() +impl Primitive for i64 { + fn const_(&self) -> Const { + Const::I64(*self) } } -macro_rules! impl_prim { - ($t:ty) => { - impl From<$t> for prim::Expr<$t> { - fn from(v: $t) -> Self { - (v).expr() - } - } - impl From> for prim::Expr<$t> { - fn from(v: Var<$t>) -> Self { - v.load() - } - } - impl FromNode for prim::Var<$t> { - fn from_node(node: NodeRef) -> Self { - Self { - node, - _phantom: PhantomData, - } - } - } - impl ToNode for prim::Var<$t> { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for prim::Expr<$t> { - type Value = $t; - } - impl VarProxy for prim::Var<$t> { - type Value = $t; - } - impl Value for $t { - type Expr = prim::Expr<$t>; - type Var = prim::Var<$t>; - fn fields() -> Vec { - vec![] - } - } - impl_callable_param!($t, prim::Expr<$t>, prim::Var<$t>); - }; +// impl Primitive for u8 { +// fn const_(&self) -> Const { +// Const::U8(*self) +// } +// } +impl Primitive for u16 { + fn const_(&self) -> Const { + Const::U16(*self) + } +} +impl Primitive for u32 { + fn const_(&self) -> Const { + Const::U32(*self) + } +} +impl Primitive for u64 { + fn const_(&self) -> Const { + Const::U64(*self) + } } -impl_prim!(bool); -impl_prim!(u32); -impl_prim!(u64); -impl_prim!(i32); -impl_prim!(i64); -impl_prim!(i16); -impl_prim!(u16); -impl_prim!(f16); -impl_prim!(f32); -impl_prim!(f64); - -pub type Bool = prim::Expr; -pub type F16 = prim::Expr; -pub type F32 = prim::Expr; -pub type F64 = prim::Expr; -pub type I16 = prim::Expr; -pub type I32 = prim::Expr; -pub type I64 = prim::Expr; -pub type U16 = prim::Expr; -pub type U32 = prim::Expr; -pub type U64 = prim::Expr; +#[deprecated] +pub type Bool = Expr; +#[deprecated] +pub type F16 = Expr; +#[deprecated] +pub type F32 = Expr; +#[deprecated] +pub type F64 = Expr; +#[deprecated] +pub type I16 = Expr; +#[deprecated] +pub type I32 = Expr; +#[deprecated] +pub type I64 = Expr; +#[deprecated] +pub type U16 = Expr; +#[deprecated] +pub type U32 = Expr; +#[deprecated] +pub type U64 = Expr; -pub type F16Var = prim::Var; -pub type F32Var = prim::Var; -pub type F64Var = prim::Var; -pub type I16Var = prim::Var; -pub type I32Var = prim::Var; -pub type I64Var = prim::Var; -pub type U16Var = prim::Var; -pub type U32Var = prim::Var; -pub type U64Var = prim::Var; +#[deprecated] +pub type F16Var = Var; +#[deprecated] +pub type F32Var = Var; +#[deprecated] +pub type F64Var = Var; +#[deprecated] +pub type I16Var = Var; +#[deprecated] +pub type I32Var = Var; +#[deprecated] +pub type I64Var = Var; +#[deprecated] +pub type U16Var = Var; +#[deprecated] +pub type U32Var = Var; +#[deprecated] +pub type U64Var = Var; -pub type Half = prim::Expr; -pub type Float = prim::Expr; -pub type Double = prim::Expr; -pub type Int = prim::Expr; -pub type Long = prim::Expr; -pub type Uint = prim::Expr; -pub type Ulong = prim::Expr; -pub type Short = prim::Expr; -pub type Ushort = prim::Expr; +#[deprecated] +pub type Half = Expr; +#[deprecated] +pub type Float = Expr; +#[deprecated] +pub type Double = Expr; +#[deprecated] +pub type Int = Expr; +#[deprecated] +pub type Long = Expr; +#[deprecated] +pub type Uint = Expr; +#[deprecated] +pub type Ulong = Expr; +#[deprecated] +pub type Short = Expr; +#[deprecated] +pub type Ushort = Expr; -pub type BoolVar = prim::Var; -pub type HalfVar = prim::Var; -pub type FloatVar = prim::Var; -pub type DoubleVar = prim::Var; -pub type IntVar = prim::Var; -pub type LongVar = prim::Var; -pub type UintVar = prim::Var; -pub type UlongVar = prim::Var; -pub type ShortVar = prim::Var; -pub type UshortVar = prim::Var; +#[deprecated] +pub type BoolVar = Var; +#[deprecated] +pub type HalfVar = Var; +#[deprecated] +pub type FloatVar = Var; +#[deprecated] +pub type DoubleVar = Var; +#[deprecated] +pub type IntVar = Var; +#[deprecated] +pub type LongVar = Var; +#[deprecated] +pub type UintVar = Var; +#[deprecated] +pub type UlongVar = Var; +#[deprecated] +pub type ShortVar = Var; +#[deprecated] +pub type UshortVar = Var; diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index a1505ea..956acee 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -1,358 +1,94 @@ +use super::alignment::*; use super::core::*; use super::*; -use ir::{MatrixType, Primitive, VectorElementType, VectorType}; +use ir::{MatrixType, VectorElementType, VectorType}; use serde::{Deserialize, Serialize}; +use std::fmt::Debug; use std::ops::Mul; -macro_rules! def_vec { - ($name:ident, $glam_type:ident, $scalar:ty, $align:literal, $($comp:ident), *) => { - #[repr(C, align($align))] - #[derive(Copy, Clone, Debug, Default, PartialEq, Serialize, Deserialize)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub const fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub const fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - impl From<$name> for glam::$glam_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From for $name { - #[inline] - fn from(v: glam::$glam_type) -> Self { - Self::new($(v.$comp), *) - } - } - }; -} -macro_rules! def_packed_vec { - ($name:ident, $vec_type:ident, $glam_type:ident, $scalar:ty, $($comp:ident), *) => { - #[repr(C)] - #[derive(Copy, Clone, Debug, Default, Value, PartialEq, Serialize, Deserialize)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub const fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub const fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - impl From<$name> for glam::$glam_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From for $name { - #[inline] - fn from(v: glam::$glam_type) -> Self { - Self::new($(v.$comp), *) - } - } - impl From<$name> for $vec_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From<$vec_type> for $name { - #[inline] - fn from(v: $vec_type) -> Self { - Self::new($(v.$comp), *) - } - } - }; -} -macro_rules! def_packed_vec_no_glam { - ($name:ident, $vec_type:ident, $scalar:ty, $($comp:ident), *) => { - #[repr(C)] - #[derive(Copy, Clone, Debug, Default, Value)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub const fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub const fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - impl From<$name> for $vec_type { - #[inline] - fn from(v: $name) -> Self { - Self::new($(v.$comp), *) - } - } - impl From<$vec_type> for $name { - #[inline] - fn from(v: $vec_type) -> Self { - Self::new($(v.$comp), *) - } - } - }; -} -macro_rules! def_vec_no_glam { - ($name:ident, $scalar:ty, $align:literal, $($comp:ident), *) => { - #[repr(C, align($align))] - #[derive(Copy, Clone, Debug, Default)] - pub struct $name { - $(pub $comp: $scalar), * - } - impl $name { - #[inline] - pub fn new($($comp: $scalar), *) -> Self { - Self { $($comp), * } - } - #[inline] - pub fn splat(scalar: $scalar) -> Self { - Self { $($comp: scalar), * } - } - } - }; -} -def_vec!(Float2, Vec2, f32, 8, x, y); -def_vec!(Float3, Vec3, f32, 16, x, y, z); -def_vec!(Float4, Vec4, f32, 16, x, y, z, w); - -def_packed_vec!(PackedFloat2, Float2, Vec2, f32, x, y); -def_packed_vec!(PackedFloat3, Float3, Vec3, f32, x, y, z); -def_packed_vec!(PackedFloat4, Float4, Vec4, f32, x, y, z, w); - -def_vec!(Uint2, UVec2, u32, 8, x, y); -def_vec!(Uint3, UVec3, u32, 16, x, y, z); -def_vec!(Uint4, UVec4, u32, 16, x, y, z, w); - -def_packed_vec!(PackedUint2, Uint2, UVec2, u32, x, y); -def_packed_vec!(PackedUint3, Uint3, UVec3, u32, x, y, z); -def_packed_vec!(PackedUint4, Uint4, UVec4, u32, x, y, z, w); - -def_vec!(Int2, IVec2, i32, 8, x, y); -def_vec!(Int3, IVec3, i32, 16, x, y, z); -def_vec!(Int4, IVec4, i32, 16, x, y, z, w); - -def_packed_vec!(PackedInt2, Int2, IVec2, i32, x, y); -def_packed_vec!(PackedInt3, Int3, IVec3, i32, x, y, z); -def_packed_vec!(PackedInt4, Int4, IVec4, i32, x, y, z, w); - -def_vec!(Double2, DVec2, f64, 16, x, y); -def_vec!(Double3, DVec3, f64, 32, x, y, z); -def_vec!(Double4, DVec4, f64, 32, x, y, z, w); - -def_vec!(Bool2, BVec2, bool, 2, x, y); -def_vec!(Bool3, BVec3, bool, 4, x, y, z); -def_vec!(Bool4, BVec4, bool, 4, x, y, z, w); - -def_packed_vec!(PackedBool2, Bool2, BVec2, bool, x, y); -def_packed_vec!(PackedBool3, Bool3, BVec3, bool, x, y, z); -def_packed_vec!(PackedBool4, Bool4, BVec4, bool, x, y, z, w); - -def_vec_no_glam!(Ulong2, u64, 16, x, y); -def_vec_no_glam!(Ulong3, u64, 32, x, y, z); -def_vec_no_glam!(Ulong4, u64, 32, x, y, z, w); - -def_packed_vec_no_glam!(PackedUlong2, Ulong2, u64, x, y); -def_packed_vec_no_glam!(PackedUlong3, Ulong3, u64, x, y, z); -def_packed_vec_no_glam!(PackedUlong4, Ulong4, u64, x, y, z, w); - -def_vec_no_glam!(Long2, i64, 16, x, y); -def_vec_no_glam!(Long3, i64, 32, x, y, z); -def_vec_no_glam!(Long4, i64, 32, x, y, z, w); - -def_packed_vec_no_glam!(PackedLong2, Long2, i64, x, y); -def_packed_vec_no_glam!(PackedLong3, Long3, i64, x, y, z); -def_packed_vec_no_glam!(PackedLong4, Long4, i64, x, y, z, w); - -def_vec_no_glam!(Ushort2, u16, 4, x, y); -def_vec_no_glam!(Ushort3, u16, 8, x, y, z); -def_vec_no_glam!(Ushort4, u16, 8, x, y, z, w); - -def_packed_vec_no_glam!(PackedUshort2, Ushort2, u16, x, y); -def_packed_vec_no_glam!(PackedUshort3, Ushort3, u16, x, y, z); -def_packed_vec_no_glam!(PackedUshort4, Ushort4, u16, x, y, z, w); - -def_vec_no_glam!(Short2, i16, 4, x, y); -def_vec_no_glam!(Short3, i16, 8, x, y, z); -def_vec_no_glam!(Short4, i16, 8, x, y, z, w); - -def_packed_vec_no_glam!(PackedShort2, Short2, i16, x, y); -def_packed_vec_no_glam!(PackedShort3, Short3, i16, x, y, z); -def_packed_vec_no_glam!(PackedShort4, Short4, i16, x, y, z, w); - -def_vec_no_glam!(Half2, f16, 4, x, y); -def_vec_no_glam!(Half3, f16, 8, x, y, z); -def_vec_no_glam!(Half4, f16, 8, x, y, z, w); - -// def_packed_vec_no_glam!(PackedHalf2, f16, x, y); -// def_packed_vec_no_glam!(PackedHalf3, f16, x, y, z); -// pub type PackHalf4 = Half4; - -def_vec_no_glam!(Ubyte2, u8, 2, x, y); -def_vec_no_glam!(Ubyte3, u8, 4, x, y, z); -def_vec_no_glam!(Ubyte4, u8, 4, x, y, z, w); - -// def_packed_vec_no_glam!(PackedUbyte2, u8, x, y); -// def_packed_vec_no_glam!(PackedUbyte3, u8, x, y, z); -// pub type PackUbyte4 = Ubyte4; - -def_vec_no_glam!(Byte2, u8, 2, x, y); -def_vec_no_glam!(Byte3, u8, 4, x, y, z); -def_vec_no_glam!(Byte4, u8, 4, x, y, z, w); - -// def_packed_vec_no_glam!(PackedByte2, u8, x, y); -// def_packed_vec_no_glam!(PackedByte3, u8, x, y, z); -// pub type PackByte4 = Byte4; - -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -#[repr(C, align(8))] -pub struct Mat2 { - pub cols: [Float2; 2], -} -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -#[repr(C, align(16))] -pub struct Mat3 { - pub cols: [Float3; 3], -} -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -#[repr(C, align(16))] -pub struct Mat4 { - pub cols: [Float4; 4], -} -impl Mat2 { - pub const fn from_cols(c0: Float2, c1: Float2) -> Self { - Self { cols: [c0, c1] } - } - pub const fn identity() -> Self { - Self::from_cols(Float2::new(1.0, 0.0), Float2::new(0.0, 1.0)) - } -} -impl Mat3 { - pub const fn from_cols(c0: Float3, c1: Float3, c2: Float3) -> Self { - Self { cols: [c0, c1, c2] } - } - pub const fn identity() -> Self { - Self::from_cols( - Float3::new(1.0, 0.0, 0.0), - Float3::new(0.0, 1.0, 0.0), - Float3::new(0.0, 0.0, 1.0), - ) - } -} -impl Mat4 { - pub const fn from_cols(c0: Float4, c1: Float4, c2: Float4, c3: Float4) -> Self { - Self { - cols: [c0, c1, c2, c3], - } - } - pub const fn identity() -> Self { - Self::from_cols( - Float4::new(1.0, 0.0, 0.0, 0.0), - Float4::new(0.0, 1.0, 0.0, 0.0), - Float4::new(0.0, 0.0, 1.0, 0.0), - Float4::new(0.0, 0.0, 0.0, 1.0), - ) - } - pub fn into_affine3x4(self) -> [f32; 12] { - // [ - // self.cols[0].x, - // self.cols[0].y, - // self.cols[0].z, - // self.cols[1].x, - // self.cols[1].y, - // self.cols[1].z, - // self.cols[2].x, - // self.cols[2].y, - // self.cols[2].z, - // self.cols[3].x, - // self.cols[3].y, - // self.cols[3].z, - // ] - [ - self.cols[0].x, - self.cols[1].x, - self.cols[2].x, - self.cols[3].x, - self.cols[0].y, - self.cols[1].y, - self.cols[2].y, - self.cols[3].y, - self.cols[0].z, - self.cols[1].z, - self.cols[2].z, - self.cols[3].z, - ] - } -} -impl From for glam::Mat2 { - #[inline] - fn from(m: Mat2) -> Self { - Self::from_cols(m.cols[0].into(), m.cols[1].into()) - } -} -impl From for glam::Mat3 { - #[inline] - fn from(m: Mat3) -> Self { - Self::from_cols(m.cols[0].into(), m.cols[1].into(), m.cols[2].into()) - } +#[cfg(feature = "glam")] +mod glam; +#[cfg(feature = "nalgebra")] +mod nalgebra; + +pub mod coords; + +trait VectorElement: Primitive { + type A: Alignment; } -impl From for glam::Mat4 { - #[inline] - fn from(m: Mat4) -> Self { - Self::from_cols( - m.cols[0].into(), - m.cols[1].into(), - m.cols[2].into(), - m.cols[3].into(), - ) - } + +#[repr(C)] +#[derive(Copy, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct Vector, const PACKED: bool = false> { + _align: T::A, + elements: [T; N], } -impl From for Mat2 { - #[inline] - fn from(m: glam::Mat2) -> Self { - Self { - cols: [m.x_axis.into(), m.y_axis.into()], - } +impl, const P: bool> Debug for Vector { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.elements.fmt(f) } } -impl From for Mat3 { - #[inline] - fn from(m: glam::Mat3) -> Self { - Self { - cols: [m.x_axis.into(), m.y_axis.into(), m.z_axis.into()], + +impl, const PACKED: bool = false> + +macro_rules! element { + ($t:ty [ $l:literal ]: $a: ident, $p: ident) => { + impl VectorElement<$l, false> for $t { + type A = $a; } - } -} -impl From for Mat4 { - #[inline] - fn from(m: glam::Mat4) -> Self { - Self { - cols: [ - m.x_axis.into(), - m.y_axis.into(), - m.z_axis.into(), - m.w_axis.into(), - ], + impl VectorElement<$l, true> for $t { + type A = $p; } + }; + ($t:ty [ $l:literal ]: $a: ident) => { + element!($t [ $l ] : $a, Align1); } } +element!(bool[2]: Align2); +element!(bool[3]: Align4); +element!(bool[4]: Align4); +// TODO: Make u8 support ir::TypeOf. +// element!(u8[2]: Align2); +// element!(u8[3]: Align4); +// element!(u8[4]: Align4); +// element!(i8[2]: Align2); +// element!(i8[3]: Align4); +// element!(i8[4]: Align4); + +element!(f16[2]: Align4); +element!(f16[3]: Align8); +element!(f16[4]: Align8); +element!(u16[2]: Align4); +element!(u16[3]: Align8); +element!(u16[4]: Align8); +element!(i16[2]: Align4); +element!(i16[3]: Align8); +element!(i16[4]: Align8); + +element!(f32[2]: Align8); +element!(f32[3]: Align16); +element!(f32[4]: Align16); +element!(u32[2]: Align8); +element!(u32[3]: Align16); +element!(u32[4]: Align16); +element!(i32[2]: Align8); +element!(i32[3]: Align16); +element!(i32[4]: Align16); + +// TODO: Check whether size 8 alignment on packed f32 is necessary. +// This is an x86 feature though. +element!(f64[2]: Align16, Align8); +element!(f64[3]: Align32, Align8); +element!(f64[4]: Align32, Align8); +element!(u64[2]: Align16, Align8); +element!(u64[3]: Align32, Align8); +element!(u64[4]: Align32, Align8); +element!(i64[2]: Align16, Align8); +element!(i64[3]: Align32, Align8); +element!(i64[4]: Align32, Align8); + + macro_rules! impl_proxy_fields { ($vec:ident, $proxy:ident, $scalar:ty, x) => { impl $proxy { diff --git a/luisa_compute/src/lang/types/vector/coords.rs b/luisa_compute/src/lang/types/vector/coords.rs new file mode 100644 index 0000000..4563e55 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/coords.rs @@ -0,0 +1 @@ +use super::*; diff --git a/luisa_compute/src/lang/types/vector/glam.rs b/luisa_compute/src/lang/types/vector/glam.rs new file mode 100644 index 0000000..43fe63d --- /dev/null +++ b/luisa_compute/src/lang/types/vector/glam.rs @@ -0,0 +1,2 @@ +use super::*; +use glam::*; diff --git a/luisa_compute/src/lang/types/vector/nalgebra.rs b/luisa_compute/src/lang/types/vector/nalgebra.rs new file mode 100644 index 0000000..bcb7ea0 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/nalgebra.rs @@ -0,0 +1,2 @@ +use super::*; +use nalgebra::*; diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 42cf920..79d2b37 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -62,7 +62,9 @@ mod internal_prelude { pub(crate) use crate::runtime::{ CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, }; - pub(crate) use crate::{get_backtrace, impl_callable_param, ResourceTracker}; + pub(crate) use crate::{ + get_backtrace, impl_simple_expr_proxy, impl_simple_var_proxy, ResourceTracker, + }; pub(crate) use luisa_compute_backend::Backend; pub(crate) use std::marker::PhantomData; } @@ -100,7 +102,8 @@ lazy_static! { } impl Context { /// path to libluisa-* - /// if the current_exe() is in the same directory as libluisa-*, then passing current_exe() is enough + /// if the current_exe() is in the same directory as libluisa-*, then + /// passing current_exe() is enough pub fn new(lib_path: impl AsRef) -> Self { let mut lib_path = lib_path.as_ref().to_path_buf(); lib_path = lib_path.canonicalize().unwrap(); diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 1bee826..806e870 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -1,31 +1,20 @@ use super::*; -#[macro_export] -macro_rules! impl_callable_param { - ($t:ty, $e:ty, $v:ty) => { - impl CallableParameter for $e { - fn def_param( - _: Option>, - builder: &mut KernelBuilder, - ) -> Self { - builder.value::<$t>() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) - } - } - impl CallableParameter for $v { - fn def_param( - _: Option>, - builder: &mut KernelBuilder, - ) -> Self { - builder.var::<$t>() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) - } - } - }; +impl CallableParameter for Expr { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.value::() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.var(*self) + } +} +impl CallableParameter for Var { + fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { + builder.var::() + } + fn encode(&self, encoder: &mut CallableArgEncoder) { + encoder.var(*self) + } } // Not recommended to use this directly diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 6a94f79..8d5e848 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -40,6 +40,20 @@ impl VisitMut for TraceVisitor { let trait_path = &self.trait_path; let span = node.span(); match node { + Expr::Assign(expr) => { + let left = &expr.left; + let right = &expr.right; + if let Expr::Unary(ExprUnary { + op: UnOp::Deref(_), + expr, + .. + }) = &**left + { + *node = parse_quote_spanned! {span=> + <_ as #trait_path::DerefSet>::deref_set(#expr, #right) + } + } + } Expr::If(expr) => { let cond = &expr.cond; let then_branch = &expr.then_branch; diff --git a/rustfmt.toml b/rustfmt.toml index d1bdb5e..174fd7a 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,2 +1,3 @@ ignore = ["luisa_compute_sys"] imports_granularity = "Module" +wrap_comments = true From 133de6007c73a3f6c0e47e8b8646c68b5845cadb Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Tue, 19 Sep 2023 13:46:08 +0100 Subject: [PATCH 04/15] Initial vector version. --- luisa_compute/src/lang/types.rs | 18 +- luisa_compute/src/lang/types/alignment.rs | 8 +- luisa_compute/src/lang/types/core.rs | 31 + luisa_compute/src/lang/types/vector.rs | 1601 +---------------- luisa_compute/src/lang/types/vector/coords.rs | 69 + .../src/lang/types/vector/element.rs | 50 + 6 files changed, 256 insertions(+), 1521 deletions(-) create mode 100644 luisa_compute/src/lang/types/vector/element.rs diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index e33bf26..0386e64 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -1,4 +1,3 @@ -use std::mem::transmute; use std::ops::Deref; use crate::internal_prelude::*; @@ -151,13 +150,13 @@ impl ToNode for Var { impl Deref for Expr { type Target = T::Expr; fn deref(&self) -> &Self::Target { - unsafe { transmute(self) } + unsafe { &*(self as *const Self as *const T::Expr) } } } impl Deref for Var { type Target = T::Var; fn deref(&self) -> &Self::Target { - unsafe { transmute(self) } + unsafe { &*(self as *const Self as *const T::Var) } } } @@ -173,7 +172,7 @@ impl Expr { let r = r.borrow(); let v: &Expr = r.arena.alloc(self); unsafe { - let v: &'a Expr = transmute(v); + let v: &'a Expr = std::mem::transmute(v); v } }) @@ -200,9 +199,12 @@ impl Var { let value = value.as_expr(); super::_store(self, &value); } - pub fn _deref(&self) -> &Expr { - self.load()._ref() - } +} + +pub fn _deref_proxy(proxy: &P) -> &Expr { + unsafe { &*(proxy as *const P as *const Var) } + .load() + ._ref() } #[macro_export] @@ -231,7 +233,7 @@ macro_rules! impl_simple_var_proxy { impl < $($bounds)* > std::ops::Deref for $name < $($qualifiers)* > { type Target = $crate::lang::types::Expr<$t>; fn deref(&self) -> &Self::Target { - self.0._deref() + $crate::lang::types::_deref_proxy(self) } } } diff --git a/luisa_compute/src/lang/types/alignment.rs b/luisa_compute/src/lang/types/alignment.rs index c2294f0..f217a29 100644 --- a/luisa_compute/src/lang/types/alignment.rs +++ b/luisa_compute/src/lang/types/alignment.rs @@ -5,11 +5,11 @@ pub(crate) trait Alignment: Default { } macro_rules! alignment { - ($t:ident, $align:literal) => { - #[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] + ($T:ident, $align:literal) => { + #[derive(Copy, Clone, Debug, Hash, Default, PartialEq, Eq)] #[repr(align($align))] - pub struct $t; - impl Alignment for $t { + pub struct $T; + impl Alignment for $T { const ALIGNMENT: usize = $align; } }; diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index 3d5bd23..a48c6de 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -3,6 +3,7 @@ use std::ops::Deref; pub(crate) trait Primitive: Copy + TypeOf + 'static { fn const_(&self) -> Const; + fn primitive(&self) -> ir::Primitive; } impl Value for T { type Expr = PrimitiveExpr; @@ -23,22 +24,34 @@ impl Primitive for bool { fn const_(&self) -> Const { Const::Bool(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::Bool + } } impl Primitive for f16 { fn const_(&self) -> Const { Const::F16(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::F16 + } } impl Primitive for f32 { fn const_(&self) -> Const { Const::F32(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::F32 + } } impl Primitive for f64 { fn const_(&self) -> Const { Const::F64(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::F64 + } } // impl Primitive for i8 { @@ -50,16 +63,25 @@ impl Primitive for i16 { fn const_(&self) -> Const { Const::I16(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::Int16 + } } impl Primitive for i32 { fn const_(&self) -> Const { Const::I32(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::Int32 + } } impl Primitive for i64 { fn const_(&self) -> Const { Const::I64(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::Int64 + } } // impl Primitive for u8 { @@ -71,16 +93,25 @@ impl Primitive for u16 { fn const_(&self) -> Const { Const::U16(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::UInt16 + } } impl Primitive for u32 { fn const_(&self) -> Const { Const::U32(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::UInt32 + } } impl Primitive for u64 { fn const_(&self) -> Const { Const::U64(*self) } + fn primitive(&self) -> ir::Primitive { + ir::Primitive::UInt64 + } } #[deprecated] diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index 956acee..fd5ba2d 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -12,1550 +12,133 @@ mod glam; mod nalgebra; pub mod coords; +mod element; -trait VectorElement: Primitive { +trait VectorElement: Primitive { type A: Alignment; } -#[repr(C)] -#[derive(Copy, Clone, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct Vector, const PACKED: bool = false> { - _align: T::A, - elements: [T; N], -} -impl, const P: bool> Debug for Vector { +impl, const N: usize> Debug for Vector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.elements.fmt(f) } } -impl, const PACKED: bool = false> - -macro_rules! element { - ($t:ty [ $l:literal ]: $a: ident, $p: ident) => { - impl VectorElement<$l, false> for $t { - type A = $a; - } - impl VectorElement<$l, true> for $t { - type A = $p; - } - }; - ($t:ty [ $l:literal ]: $a: ident) => { - element!($t [ $l ] : $a, Align1); - } -} - -element!(bool[2]: Align2); -element!(bool[3]: Align4); -element!(bool[4]: Align4); -// TODO: Make u8 support ir::TypeOf. -// element!(u8[2]: Align2); -// element!(u8[3]: Align4); -// element!(u8[4]: Align4); -// element!(i8[2]: Align2); -// element!(i8[3]: Align4); -// element!(i8[4]: Align4); - -element!(f16[2]: Align4); -element!(f16[3]: Align8); -element!(f16[4]: Align8); -element!(u16[2]: Align4); -element!(u16[3]: Align8); -element!(u16[4]: Align8); -element!(i16[2]: Align4); -element!(i16[3]: Align8); -element!(i16[4]: Align8); - -element!(f32[2]: Align8); -element!(f32[3]: Align16); -element!(f32[4]: Align16); -element!(u32[2]: Align8); -element!(u32[3]: Align16); -element!(u32[4]: Align16); -element!(i32[2]: Align8); -element!(i32[3]: Align16); -element!(i32[4]: Align16); - -// TODO: Check whether size 8 alignment on packed f32 is necessary. -// This is an x86 feature though. -element!(f64[2]: Align16, Align8); -element!(f64[3]: Align32, Align8); -element!(f64[4]: Align32, Align8); -element!(u64[2]: Align16, Align8); -element!(u64[3]: Align32, Align8); -element!(u64[4]: Align32, Align8); -element!(i64[2]: Align16, Align8); -element!(i64[3]: Align32, Align8); -element!(i64[4]: Align32, Align8); - - -macro_rules! impl_proxy_fields { - ($vec:ident, $proxy:ident, $scalar:ty, x) => { - impl $proxy { - #[inline] - pub fn x(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 0)) - } - #[inline] - pub fn set_x(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 0, ToNode::node(&value))) - } - } - }; - ($vec:ident,$proxy:ident, $scalar:ty, y) => { - impl $proxy { - #[inline] - pub fn y(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 1)) - } - #[inline] - pub fn set_y(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 1, ToNode::node(&value))) - } - } - }; - ($vec:ident,$proxy:ident, $scalar:ty, z) => { - impl $proxy { - #[inline] - pub fn z(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 2)) - } - #[inline] - pub fn set_z(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 2, ToNode::node(&value))) - } - } - }; - ($vec:ident,$proxy:ident, $scalar:ty, w) => { - impl $proxy { - #[inline] - pub fn w(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 3)) - } - #[inline] - pub fn set_w(&self, value: prim::Expr<$scalar>) -> Self { - Self::from_node(__insert::<$vec>(self.node, 3, ToNode::node(&value))) - } - } - }; -} -macro_rules! impl_var_proxy_fields { - ($proxy:ident, $scalar:ty, x) => { - impl $proxy { - #[inline] - pub fn x(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 0)) - } - } - }; - ($proxy:ident, $scalar:ty, y) => { - impl $proxy { - #[inline] - pub fn y(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 1)) - } - } - }; - ($proxy:ident, $scalar:ty, z) => { - impl $proxy { - #[inline] - pub fn z(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 2)) - } - } - }; - ($proxy:ident, $scalar:ty, w) => { - impl $proxy { - #[inline] - pub fn w(&self) -> Var<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, 3)) - } - } - }; -} -macro_rules! impl_vec_proxy { - ($vec:ident, $expr_proxy:ident, $var_proxy:ident, $scalar:ty, $scalar_ty:ident, $length:literal, $($comp:ident), *) => { - #[derive(Clone, Copy)] - pub struct $expr_proxy { - node: NodeRef, - } - #[derive(Clone, Copy)] - pub struct $var_proxy { - node: NodeRef, - } - impl Value for $vec { - type Expr = $expr_proxy; - type Var = $var_proxy; - fn fields() -> Vec { - vec![$(stringify!($comp).to_string()),*] - } - } - impl TypeOf for $vec { - fn type_() -> luisa_compute_ir::CArc { - let type_ = Type::Vector(VectorType { - element: VectorElementType::Scalar(Primitive::$scalar_ty), - length: $length, - }); - register_type(type_) - } - } - impl Aggregate for $expr_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl VectorVarTrait for $expr_proxy { } - impl ScalarOrVector for $expr_proxy { - type Element = prim::Expr<$scalar>; - type ElementHost = $scalar; - } - impl BuiltinVarTrait for $expr_proxy { } - impl Aggregate for $var_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl FromNode for $expr_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $expr_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl FromNode for $var_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $var_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for $expr_proxy { - type Value = $vec; - } - impl VarProxy for $var_proxy { - type Value = $vec; - } - impl std::ops::Deref for $var_proxy { - type Target = $expr_proxy; - fn deref(&self) -> &Self::Target { - self._deref() - } - } - impl From<$var_proxy> for $expr_proxy { - fn from(var: $var_proxy) -> Self { - var.load() - } - } - impl_callable_param!($vec, $expr_proxy, $var_proxy); - $(impl_proxy_fields!($vec, $expr_proxy, $scalar, $comp);)* - $(impl_var_proxy_fields!($var_proxy, $scalar, $comp);)* - impl $expr_proxy { - #[inline] - pub fn new($($comp: prim::Expr<$scalar>), *) -> Self { - Self { - node: __compose::<$vec>(&[$(ToNode::node(&$comp)), *]), - } - } - pub fn at(&self, index: usize) -> prim::Expr<$scalar> { - FromNode::from_node(__extract::<$scalar>(self.node, index)) - } - } - impl $vec { - #[inline] - pub fn expr($($comp: impl Into>), *) -> $expr_proxy { - $expr_proxy::new($($comp.into()), *) - } - } - }; -} - -macro_rules! impl_mat_proxy { - ($mat:ident, $expr_proxy:ident, $var_proxy:ident, $vec:ty, $scalar_ty:ident, $length:literal, $($comp:ident), *) => { - #[derive(Clone, Copy)] - pub struct $expr_proxy { - node: NodeRef, - } - #[derive(Clone, Copy)] - pub struct $var_proxy { - node: NodeRef, - } - impl Value for $mat { - type Expr = $expr_proxy; - type Var = $var_proxy; - fn fields() -> Vec { - vec![$(stringify!($comp).to_string()),*] - } - } - impl TypeOf for $mat { - fn type_() -> luisa_compute_ir::CArc { - let type_ = Type::Matrix(MatrixType { - element: VectorElementType::Scalar(Primitive::$scalar_ty), - dimension: $length, - }); - register_type(type_) - } - } - impl Aggregate for $expr_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl MatrixVarTrait for $expr_proxy { } - impl BuiltinVarTrait for $expr_proxy { } - impl Aggregate for $var_proxy { - fn to_nodes(&self, nodes: &mut Vec) { - nodes.push(self.node); - } - fn from_nodes>(iter: &mut I) -> Self { - Self { - node: iter.next().unwrap(), - } - } - } - impl FromNode for $expr_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $expr_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl ExprProxy for $expr_proxy { - type Value = $mat; - } - impl FromNode for $var_proxy { - fn from_node(node: NodeRef) -> Self { - Self { node } - } - } - impl ToNode for $var_proxy { - fn node(&self) -> NodeRef { - self.node - } - } - impl VarProxy for $var_proxy { - type Value = $mat; - } - impl std::ops::Deref for $var_proxy { - type Target = $expr_proxy; - fn deref(&self) -> &Self::Target { - self._deref() - } - } - impl From<$var_proxy> for $expr_proxy { - fn from(var: $var_proxy) -> Self { - var.load() - } - } - impl_callable_param!($mat, $expr_proxy, $var_proxy); - impl $expr_proxy { - #[inline] - pub fn new($($comp: Expr<$vec>), *) -> Self { - Self { - node: __compose::<$mat>(&[$(ToNode::node(&$comp)), *]), - } - } - pub fn col(&self, index: usize) -> Expr<$vec> { - Expr::<$vec>::from_node(__extract::<$vec>(self.node, index)) - } - } - impl $mat { - #[inline] - pub fn expr($($comp: impl Into>), *) -> $expr_proxy { - $expr_proxy::new($($comp.into()), *) - } - } - }; +#[repr(C)] +#[derive(Copy, Clone, Hash, Default, PartialEq, Eq, Serialize, Deserialize)] +pub struct Vector, const N: usize> { + _align: T::A, + elements: [T; N], } -impl_vec_proxy!(Bool2, Bool2Expr, Bool2Var, bool, Bool, 2, x, y); -impl_vec_proxy!(Bool3, Bool3Expr, Bool3Var, bool, Bool, 3, x, y, z); -impl_vec_proxy!(Bool4, Bool4Expr, Bool4Var, bool, Bool, 4, x, y, z, w); - -impl_vec_proxy!(Half2, Half2Expr, Half2Var, f16, Float16, 2, x, y); -impl_vec_proxy!(Half3, Half3Expr, Half3Var, f16, Float16, 3, x, y, z); -impl_vec_proxy!(Half4, Half4Expr, Half4Var, f16, Float16, 4, x, y, z, w); - -impl_vec_proxy!(Float2, Float2Expr, Float2Var, f32, Float32, 2, x, y); -impl_vec_proxy!(Float3, Float3Expr, Float3Var, f32, Float32, 3, x, y, z); -impl_vec_proxy!(Float4, Float4Expr, Float4Var, f32, Float32, 4, x, y, z, w); - -impl_vec_proxy!(Double2, Double2Expr, Double2Var, f64, Float64, 2, x, y); -impl_vec_proxy!(Double3, Double3Expr, Double3Var, f64, Float64, 3, x, y, z); -impl_vec_proxy!( - Double4, - Double4Expr, - Double4Var, - f64, - Float64, - 4, - x, - y, - z, - w -); - -impl_vec_proxy!(Ushort2, Ushort2Expr, Ushort2Var, u16, Uint16, 2, x, y); -impl_vec_proxy!(Ushort3, Ushort3Expr, Ushort3Var, u16, Uint16, 3, x, y, z); -impl_vec_proxy!(Ushort4, Ushort4Expr, Ushort4Var, u16, Uint16, 4, x, y, z, w); - -impl_vec_proxy!(Short2, Short2Expr, Short2Var, i16, Int16, 2, x, y); -impl_vec_proxy!(Short3, Short3Expr, Short3Var, i16, Int16, 3, x, y, z); -impl_vec_proxy!(Short4, Short4Expr, Short4Var, i16, Int16, 4, x, y, z, w); - -impl_vec_proxy!(Uint2, Uint2Expr, Uint2Var, u32, Uint32, 2, x, y); -impl_vec_proxy!(Uint3, Uint3Expr, Uint3Var, u32, Uint32, 3, x, y, z); -impl_vec_proxy!(Uint4, Uint4Expr, Uint4Var, u32, Uint32, 4, x, y, z, w); - -impl_vec_proxy!(Int2, Int2Expr, Int2Var, i32, Int32, 2, x, y); -impl_vec_proxy!(Int3, Int3Expr, Int3Var, i32, Int32, 3, x, y, z); -impl_vec_proxy!(Int4, Int4Expr, Int4Var, i32, Int32, 4, x, y, z, w); - -impl_vec_proxy!(Ulong2, Ulong2Expr, Ulong2Var, u64, Uint64, 2, x, y); -impl_vec_proxy!(Ulong3, Ulong3Expr, Ulong3Var, u64, Uint64, 3, x, y, z); -impl_vec_proxy!(Ulong4, Ulong4Expr, Ulong4Var, u64, Uint64, 4, x, y, z, w); - -impl_vec_proxy!(Long2, Long2Expr, Long2Var, i64, Int64, 2, x, y); -impl_vec_proxy!(Long3, Long3Expr, Long3Var, i64, Int64, 3, x, y, z); -impl_vec_proxy!(Long4, Long4Expr, Long4Var, i64, Int64, 4, x, y, z, w); - -impl_mat_proxy!(Mat2, Mat2Expr, Mat2Var, Float2, Float32, 2, x, y); -impl_mat_proxy!(Mat3, Mat3Expr, Mat3Var, Float3, Float32, 3, x, y, z); -impl_mat_proxy!(Mat4, Mat4Expr, Mat4Var, Float4, Float32, 4, x, y, z, w); - -macro_rules! impl_packed_cvt { - ($packed:ty, $vec:ty, $($comp:ident), *) => { - impl From<$vec> for $packed { - fn from(v: $vec) -> Self { - Self::new($(v.$comp()), *) - } - } - impl $packed { - pub fn unpack(&self) -> $vec { - (*self).into() - } - } - impl From<$packed> for $vec { - fn from(v: $packed) -> Self { - Self::new($(v.$comp()), *) - } - } - impl $vec { - pub fn pack(&self) -> $packed { - (*self).into() - } +impl, const N: usize> Vector { + pub fn new(elements: [T; N]) -> Self { + Self { + _align: T::A::default(), + elements, } } -} -impl_packed_cvt!(PackedFloat2Expr, Float2Expr, x, y); -impl_packed_cvt!(PackedFloat3Expr, Float3Expr, x, y, z); -impl_packed_cvt!(PackedFloat4Expr, Float4Expr, x, y, z, w); - -impl_packed_cvt!(PackedShort2Expr, Short2Expr, x, y); -impl_packed_cvt!(PackedShort3Expr, Short3Expr, x, y, z); -impl_packed_cvt!(PackedShort4Expr, Short4Expr, x, y, z, w); - -// ushort -impl_packed_cvt!(PackedUshort2Expr, Ushort2Expr, x, y); -impl_packed_cvt!(PackedUshort3Expr, Ushort3Expr, x, y, z); -impl_packed_cvt!(PackedUshort4Expr, Ushort4Expr, x, y, z, w); - -// int -impl_packed_cvt!(PackedInt2Expr, Int2Expr, x, y); -impl_packed_cvt!(PackedInt3Expr, Int3Expr, x, y, z); -impl_packed_cvt!(PackedInt4Expr, Int4Expr, x, y, z, w); - -// uint -impl_packed_cvt!(PackedUint2Expr, Uint2Expr, x, y); -impl_packed_cvt!(PackedUint3Expr, Uint3Expr, x, y, z); -impl_packed_cvt!(PackedUint4Expr, Uint4Expr, x, y, z, w); - -// long -impl_packed_cvt!(PackedLong2Expr, Long2Expr, x, y); -impl_packed_cvt!(PackedLong3Expr, Long3Expr, x, y, z); -impl_packed_cvt!(PackedLong4Expr, Long4Expr, x, y, z, w); - -// ulong -impl_packed_cvt!(PackedUlong2Expr, Ulong2Expr, x, y); -impl_packed_cvt!(PackedUlong3Expr, Ulong3Expr, x, y, z); -impl_packed_cvt!(PackedUlong4Expr, Ulong4Expr, x, y, z, w); - -macro_rules! impl_binop { - ($t:ty, $scalar:ty, $proxy:ty, $tr:ident, $m:ident, $tr_assign:ident, $m_assign:ident) => { - impl std::ops::$tr_assign<$proxy> for $proxy { - fn $m_assign(&mut self, rhs: $proxy) { - use std::ops::$tr; - *self = (*self).$m(rhs); - } - } - impl std::ops::$tr for $proxy { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$scalar> for $proxy { - type Output = $proxy; - fn $m(self, rhs: $scalar) -> Self::Output { - let rhs = Self::splat(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for $scalar { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::splat(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr> for $proxy { - type Output = $proxy; - fn $m(self, rhs: prim::Expr<$scalar>) -> Self::Output { - let rhs = Self::splat(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for prim::Expr<$scalar> { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::splat(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_binop_for_mat { - ($t:ty, $scalar:ty, $proxy:ty, $tr:ident, $m:ident, $tr_assign:ident, $m_assign:ident) => { - impl std::ops::$tr_assign<$proxy> for $proxy { - fn $m_assign(&mut self, rhs: $proxy) { - use std::ops::$tr; - *self = (*self).$m(rhs); - } - } - impl std::ops::$tr for $proxy { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } + pub fn splat(element: T) -> Self { + Self { + _align: T::A::default(), + elements: [element; N], } - impl std::ops::$tr<$scalar> for $proxy { - type Output = $proxy; - fn $m(self, rhs: $scalar) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for $scalar { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr> for $proxy { - type Output = $proxy; - fn $m(self, rhs: prim::Expr<$scalar>) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::$tr<$proxy> for prim::Expr<$scalar> { - type Output = $proxy; - fn $m(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::$tr, &[lhs.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_arith_binop { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_common_op!($t, $scalar, $proxy); - impl_binop!($t, $scalar, $proxy, Add, add, AddAssign, add_assign); - impl_binop!($t, $scalar, $proxy, Sub, sub, SubAssign, sub_assign); - impl_binop!($t, $scalar, $proxy, Mul, mul, MulAssign, mul_assign); - impl_binop!($t, $scalar, $proxy, Div, div, DivAssign, div_assign); - impl_binop!($t, $scalar, $proxy, Rem, rem, RemAssign, rem_assign); - impl_reduce!($t, $scalar, $proxy); - impl std::ops::Neg for $proxy { - type Output = $proxy; - fn neg(self) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Neg, &[self.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_arith_binop_for_mat { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_binop_for_mat!($t, $scalar, $proxy, Add, add, AddAssign, add_assign); - impl_binop_for_mat!($t, $scalar, $proxy, Sub, sub, SubAssign, sub_assign); - // Mat * Mat - impl std::ops::MulAssign<$proxy> for $proxy { - fn mul_assign(&mut self, rhs: $proxy) { - use std::ops::Mul; - *self = (*self).mul(rhs); - } - } - impl std::ops::Mul for $proxy { - type Output = $proxy; - fn mul(self, rhs: $proxy) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Mul, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - // Mat * Scalar - impl std::ops::MulAssign<$scalar> for $proxy { - fn mul_assign(&mut self, rhs: $scalar) { - use std::ops::Mul; - *self = (*self).mul(rhs); - } - } - impl std::ops::Mul<$scalar> for $proxy { - type Output = $proxy; - fn mul(self, rhs: $scalar) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[self.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - impl std::ops::Mul<$proxy> for $scalar { - type Output = $proxy; - fn mul(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[lhs.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - impl std::ops::Mul> for $proxy { - type Output = $proxy; - fn mul(self, rhs: prim::Expr<$scalar>) -> Self::Output { - let rhs = Self::fill(rhs); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[self.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - impl std::ops::Mul<$proxy> for prim::Expr<$scalar> { - type Output = $proxy; - fn mul(self, rhs: $proxy) -> Self::Output { - let lhs = <$proxy>::fill(self); - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[lhs.node, rhs.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - // Rem - impl std::ops::RemAssign<$scalar> for $proxy { - fn rem_assign(&mut self, rhs: $scalar) { - use std::ops::Rem; - *self = (*self).rem(rhs); - } - } - impl std::ops::Rem<$scalar> for $proxy { - type Output = $proxy; - fn rem(self, rhs: $scalar) -> Self::Output { - let rhs: prim::Expr<$scalar> = rhs.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Rem, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::Rem> for $proxy { - type Output = $proxy; - fn rem(self, rhs: prim::Expr<$scalar>) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Rem, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - // Div - impl std::ops::DivAssign<$scalar> for $proxy { - fn div_assign(&mut self, rhs: $scalar) { - use std::ops::Div; - *self = (*self).div(rhs); - } - } - impl std::ops::Div<$scalar> for $proxy { - type Output = $proxy; - fn div(self, rhs: $scalar) -> Self::Output { - let rhs: prim::Expr<$scalar> = rhs.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Div, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - impl std::ops::Div> for $proxy { - type Output = $proxy; - fn div(self, rhs: prim::Expr<$scalar>) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Div, &[self.node, rhs.node], <$t as TypeOf>::type_()) - })) - } - } - // Neg - impl std::ops::Neg for $proxy { - type Output = $proxy; - fn neg(self) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Neg, &[self.node], <$t as TypeOf>::type_()) - })) - } - } - impl $proxy { - pub fn comp_mul(&self, other: Self) -> Self { - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::MatCompMul, - &[self.node, other.node], - <$t as TypeOf>::type_(), - ) - })) - } - } - }; -} -macro_rules! impl_int_binop { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_binop!( - $t, - $scalar, - $proxy, - BitAnd, - bitand, - BitAndAssign, - bitand_assign - ); - impl_binop!($t, $scalar, $proxy, BitOr, bitor, BitOrAssign, bitor_assign); - impl_binop!( - $t, - $scalar, - $proxy, - BitXor, - bitxor, - BitXorAssign, - bitxor_assign - ); - impl_binop!($t, $scalar, $proxy, Shl, shl, ShlAssign, shl_assign); - impl_binop!($t, $scalar, $proxy, Shr, shr, ShrAssign, shr_assign); - impl std::ops::Not for $proxy { - type Output = Expr<$t>; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; -} -macro_rules! impl_bool_binop { - ($t:ty, $proxy:ty) => { - impl_binop!( - $t, - bool, - $proxy, - BitAnd, - bitand, - BitAndAssign, - bitand_assign - ); - impl_binop!($t, bool, $proxy, BitOr, bitor, BitOrAssign, bitor_assign); - impl_binop!( - $t, - bool, - $proxy, - BitXor, - bitxor, - BitXorAssign, - bitxor_assign - ); - impl $proxy { - pub fn splat>>(value: V) -> Self { - let value = value.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) - })) - } - pub fn zero() -> Self { - Self::splat(false) - } - pub fn one() -> Self { - Self::splat(true) - } - pub fn all(&self) -> prim::Expr { - Expr::::from_node(__current_scope(|s| { - s.call(Func::All, &[self.node], ::type_()) - })) - } - pub fn any(&self) -> prim::Expr { - Expr::::from_node(__current_scope(|s| { - s.call(Func::Any, &[self.node], ::type_()) - })) - } - } - impl std::ops::Not for $proxy { - type Output = Expr<$t>; - fn not(self) -> Self::Output { - self ^ Self::splat(true) - } - } - }; -} -macro_rules! impl_reduce { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl $proxy { - #[inline] - pub fn reduce_sum(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceSum, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn reduce_prod(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceProd, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn reduce_min(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceMin, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn reduce_max(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::ReduceMax, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn dot(&self, rhs: $proxy) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call( - Func::Dot, - &[self.node, rhs.node], - <$scalar as TypeOf>::type_(), - ) - })) - } - } - }; -} -macro_rules! impl_common_op { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl $proxy { - pub fn splat>>(value: V) -> Self { - let value = value.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) - })) - } - pub fn zero() -> Self { - Self::splat(0.0 as $scalar) - } - pub fn one() -> Self { - Self::splat(1.0 as $scalar) - } - } - }; -} -macro_rules! impl_vec_op { - ($t:ty, $scalar:ty, $proxy:ty, $mat:ty) => { - impl $proxy { - #[inline] - pub fn length(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Length, &[self.node], <$scalar as TypeOf>::type_()) - })) - } - #[inline] - pub fn normalize(&self) -> Self { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Normalize, &[self.node], <$t as TypeOf>::type_()) - })) - } - #[inline] - pub fn length_squared(&self) -> prim::Expr<$scalar> { - FromNode::from_node(__current_scope(|s| { - s.call( - Func::LengthSquared, - &[self.node], - <$scalar as TypeOf>::type_(), - ) - })) - } - #[inline] - pub fn distance(&self, rhs: $proxy) -> prim::Expr<$scalar> { - (*self - rhs).length() - } - #[inline] - pub fn distance_squared(&self, rhs: $proxy) -> prim::Expr<$scalar> { - (*self - rhs).length_squared() - } - #[inline] - pub fn fma(&self, a: $proxy, b: $proxy) -> Self { - <$proxy>::from_node(__current_scope(|s| { - s.call( - Func::Fma, - &[self.node, a.node, b.node], - <$t as TypeOf>::type_(), - ) - })) - } - #[inline] - pub fn outer_product(&self, rhs: $proxy) -> Expr<$mat> { - Expr::<$mat>::from_node(__current_scope(|s| { - s.call( - Func::OuterProduct, - &[self.node, rhs.node], - <$mat as TypeOf>::type_(), - ) - })) - } - } - }; -} - -// a little shit -macro_rules! impl_arith_binop_f16 { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl_common_op_f16!($t, $scalar, $proxy); - impl_binop!($t, $scalar, $proxy, Add, add, AddAssign, add_assign); - impl_binop!($t, $scalar, $proxy, Sub, sub, SubAssign, sub_assign); - impl_binop!($t, $scalar, $proxy, Mul, mul, MulAssign, mul_assign); - impl_binop!($t, $scalar, $proxy, Div, div, DivAssign, div_assign); - impl_binop!($t, $scalar, $proxy, Rem, rem, RemAssign, rem_assign); - impl_reduce!($t, $scalar, $proxy); - impl std::ops::Neg for $proxy { - type Output = $proxy; - fn neg(self) -> Self::Output { - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Neg, &[self.node], <$t as TypeOf>::type_()) - })) - } - } - }; -} -macro_rules! impl_common_op_f16 { - ($t:ty, $scalar:ty, $proxy:ty) => { - impl $proxy { - pub fn splat>>(value: V) -> Self { - let value = value.into(); - <$proxy>::from_node(__current_scope(|s| { - s.call(Func::Vec, &[value.node], <$t as TypeOf>::type_()) - })) - } - pub fn zero() -> Self { - Self::splat(f16::from_f32(0.0f32)) - } - pub fn one() -> Self { - Self::splat(f16::from_f32(1.0f32)) - } - } - }; -} - -impl_arith_binop_f16!(Half2, f16, Half2Expr); -impl_arith_binop_f16!(Half3, f16, Half3Expr); -impl_arith_binop_f16!(Half4, f16, Half4Expr); - -impl_arith_binop!(Float2, f32, Float2Expr); -impl_arith_binop!(Float3, f32, Float3Expr); -impl_arith_binop!(Float4, f32, Float4Expr); - -impl_arith_binop!(Short2, i16, Short2Expr); -impl_arith_binop!(Short3, i16, Short3Expr); -impl_arith_binop!(Short4, i16, Short4Expr); - -impl_arith_binop!(Ushort2, u16, Ushort2Expr); -impl_arith_binop!(Ushort3, u16, Ushort3Expr); -impl_arith_binop!(Ushort4, u16, Ushort4Expr); - -impl_arith_binop!(Int2, i32, Int2Expr); -impl_arith_binop!(Int3, i32, Int3Expr); -impl_arith_binop!(Int4, i32, Int4Expr); - -impl_arith_binop!(Uint2, u32, Uint2Expr); -impl_arith_binop!(Uint3, u32, Uint3Expr); -impl_arith_binop!(Uint4, u32, Uint4Expr); - -impl_arith_binop!(Long2, i64, Long2Expr); -impl_arith_binop!(Long3, i64, Long3Expr); -impl_arith_binop!(Long4, i64, Long4Expr); - -impl_arith_binop!(Ulong2, u64, Ulong2Expr); -impl_arith_binop!(Ulong3, u64, Ulong3Expr); -impl_arith_binop!(Ulong4, u64, Ulong4Expr); - -impl_int_binop!(Short2, i16, Short2Expr); -impl_int_binop!(Short3, i16, Short3Expr); -impl_int_binop!(Short4, i16, Short4Expr); - -impl_int_binop!(Ushort2, u16, Ushort2Expr); -impl_int_binop!(Ushort3, u16, Ushort3Expr); -impl_int_binop!(Ushort4, u16, Ushort4Expr); - -impl_int_binop!(Int2, i32, Int2Expr); -impl_int_binop!(Int3, i32, Int3Expr); -impl_int_binop!(Int4, i32, Int4Expr); - -impl_int_binop!(Uint2, u32, Uint2Expr); -impl_int_binop!(Uint3, u32, Uint3Expr); -impl_int_binop!(Uint4, u32, Uint4Expr); - -impl_int_binop!(Long2, i64, Long2Expr); -impl_int_binop!(Long3, i64, Long3Expr); -impl_int_binop!(Long4, i64, Long4Expr); - -impl_int_binop!(Ulong2, u64, Ulong2Expr); -impl_int_binop!(Ulong3, u64, Ulong3Expr); -impl_int_binop!(Ulong4, u64, Ulong4Expr); - -impl_bool_binop!(Bool2, Bool2Expr); -impl_bool_binop!(Bool3, Bool3Expr); -impl_bool_binop!(Bool4, Bool4Expr); - -macro_rules! impl_select { - ($bvec:ty, $vec:ty, $proxy:ty) => { - impl $proxy { - pub fn select(mask: Expr<$bvec>, a: Expr<$vec>, b: Expr<$vec>) -> Expr<$vec> { - Expr::<$vec>::from_node(__current_scope(|s| { - s.call( - Func::Select, - &[mask.node(), a.node(), b.node()], - <$vec as TypeOf>::type_(), - ) - })) - } - } - }; + } } -impl_select!(Bool2, Bool2, Bool2Expr); -impl_select!(Bool3, Bool3, Bool3Expr); -impl_select!(Bool4, Bool4, Bool4Expr); - -impl_select!(Bool2, Half2, Half2Expr); -impl_select!(Bool3, Half3, Half3Expr); -impl_select!(Bool4, Half4, Half4Expr); - -impl_select!(Bool2, Float2, Float2Expr); -impl_select!(Bool3, Float3, Float3Expr); -impl_select!(Bool4, Float4, Float4Expr); - -impl_select!(Bool2, Int2, Int2Expr); -impl_select!(Bool3, Int3, Int3Expr); -impl_select!(Bool4, Int4, Int4Expr); - -impl_select!(Bool2, Uint2, Uint2Expr); -impl_select!(Bool3, Uint3, Uint3Expr); -impl_select!(Bool4, Uint4, Uint4Expr); - -impl_select!(Bool2, Short2, Short2Expr); -impl_select!(Bool3, Short3, Short3Expr); -impl_select!(Bool4, Short4, Short4Expr); - -impl_select!(Bool2, Ushort2, Ushort2Expr); -impl_select!(Bool3, Ushort3, Ushort3Expr); -impl_select!(Bool4, Ushort4, Ushort4Expr); - -impl_select!(Bool2, Long2, Long2Expr); -impl_select!(Bool3, Long3, Long3Expr); -impl_select!(Bool4, Long4, Long4Expr); - -impl_select!(Bool2, Ulong2, Ulong2Expr); -impl_select!(Bool3, Ulong3, Ulong3Expr); -impl_select!(Bool4, Ulong4, Ulong4Expr); - -macro_rules! impl_permute { - ($tr:ident, $proxy:ty,$len:expr, $v2:ty, $v3:ty, $v4:ty) => { - impl $tr for $proxy { - type Vec2 = Expr<$v2>; - type Vec3 = Expr<$v3>; - type Vec4 = Expr<$v4>; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { - assert!(x < $len); - assert!(y < $len); - let x: Expr = x.into(); - let y: Expr = y.into(); - Expr::<$v2>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[self.node, ToNode::node(&x), ToNode::node(&y)], - <$v2 as TypeOf>::type_(), - ) - })) - } - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { - assert!(x < $len); - assert!(y < $len); - assert!(z < $len); - let x: Expr = x.into(); - let y: Expr = y.into(); - let z: Expr = z.into(); - Expr::<$v3>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[ - self.node, - ToNode::node(&x), - ToNode::node(&y), - ToNode::node(&z), - ], - <$v3 as TypeOf>::type_(), - ) - })) - } - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { - assert!(x < $len); - assert!(y < $len); - assert!(z < $len); - assert!(w < $len); - let x: Expr = x.into(); - let y: Expr = y.into(); - let z: Expr = z.into(); - let w: Expr = w.into(); - Expr::<$v4>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[ - self.node, - ToNode::node(&x), - ToNode::node(&y), - ToNode::node(&z), - ToNode::node(&w), - ], - <$v4 as TypeOf>::type_(), - ) - })) - } - } - }; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct VectorExprData, const N: usize>([Expr; N]); +impl> FromNode for VectorExprData { + fn from_node(node: NodeRef) -> Self { + Self(std::array::from_fn(|i| { + FromNode::from_node(__extract::(node, i)) + })) + } } -impl_permute!(Vec2Swizzle, Half2Expr, 2, Half2, Half3, Half4); -impl_permute!(Vec3Swizzle, Half3Expr, 3, Half2, Half3, Half4); -impl_permute!(Vec4Swizzle, Half4Expr, 4, Half2, Half3, Half4); - -impl_permute!(Vec2Swizzle, Float2Expr, 2, Float2, Float3, Float4); -impl_permute!(Vec3Swizzle, Float3Expr, 3, Float2, Float3, Float4); -impl_permute!(Vec4Swizzle, Float4Expr, 4, Float2, Float3, Float4); - -impl_permute!(Vec2Swizzle, Short2Expr, 2, Short2, Short3, Short4); -impl_permute!(Vec3Swizzle, Short3Expr, 3, Short2, Short3, Short4); -impl_permute!(Vec4Swizzle, Short4Expr, 4, Short2, Short3, Short4); - -impl_permute!(Vec2Swizzle, Ushort2Expr, 2, Ushort2, Ushort3, Ushort4); -impl_permute!(Vec3Swizzle, Ushort3Expr, 3, Ushort2, Ushort3, Ushort4); -impl_permute!(Vec4Swizzle, Ushort4Expr, 4, Ushort2, Ushort3, Ushort4); - -impl_permute!(Vec2Swizzle, Int2Expr, 2, Int2, Int3, Int4); -impl_permute!(Vec3Swizzle, Int3Expr, 3, Int2, Int3, Int4); -impl_permute!(Vec4Swizzle, Int4Expr, 4, Int2, Int3, Int4); - -impl_permute!(Vec2Swizzle, Uint2Expr, 2, Uint2, Uint3, Uint4); -impl_permute!(Vec3Swizzle, Uint3Expr, 3, Uint2, Uint3, Uint4); -impl_permute!(Vec4Swizzle, Uint4Expr, 4, Uint2, Uint3, Uint4); - -impl_permute!(Vec2Swizzle, Long2Expr, 2, Long2, Long3, Long4); -impl_permute!(Vec3Swizzle, Long3Expr, 3, Long2, Long3, Long4); -impl_permute!(Vec4Swizzle, Long4Expr, 4, Long2, Long3, Long4); - -impl_permute!(Vec2Swizzle, Ulong2Expr, 2, Ulong2, Ulong3, Ulong4); -impl_permute!(Vec3Swizzle, Ulong3Expr, 3, Ulong2, Ulong3, Ulong4); -impl_permute!(Vec4Swizzle, Ulong4Expr, 4, Ulong2, Ulong3, Ulong4); - -impl Float3Expr { - #[inline] - pub fn cross(&self, rhs: Float3Expr) -> Self { - Float3Expr::from_node(__current_scope(|s| { - s.call( - Func::Cross, - &[self.node, rhs.node], - ::type_(), - ) +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct VectorVarData, const N: usize>([Var; N]); +impl> FromNode for VectorVarData { + fn from_node(node: NodeRef) -> Self { + Self(std::array::from_fn(|i| { + FromNode::from_node(__extract::(node, i)) })) } } -impl_vec_op!(Float2, f32, Float2Expr, Mat2); -impl_vec_op!(Float3, f32, Float3Expr, Mat3); -impl_vec_op!(Float4, f32, Float4Expr, Mat4); -macro_rules! impl_var_trait2 { - ($t:ty, $v:ty) => { - impl VarTrait for $t { - type Value = $v; - type Short = Short2Expr; - type Ushort = Ushort2Expr; - type Int = Int2Expr; - type Uint = Uint2Expr; - type Float = Float2Expr; - type Half = Half2Expr; - type Bool = Bool2Expr; - type Double = Double2Expr; - type Long = Long2Expr; - type Ulong = Ulong2Expr; - } - impl CommonVarOp for $t {} - impl VarCmp for $t {} - impl VarCmpEq for $t {} - impl From<$v> for $t { - fn from(v: $v) -> Self { - Self::new((v.x).expr(), (v.y).expr()) - } - } - }; -} -macro_rules! impl_var_trait3 { - ($t:ty, $v:ty) => { - impl VarTrait for $t { - type Value = $v; - type Short = Short3Expr; - type Ushort = Ushort3Expr; - type Int = Int3Expr; - type Uint = Uint3Expr; - type Float = Float3Expr; - type Half = Half3Expr; - type Bool = Bool3Expr; - type Double = Double3Expr; - type Long = Long3Expr; - type Ulong = Ulong3Expr; - } - impl CommonVarOp for $t {} - impl VarCmp for $t {} - impl VarCmpEq for $t {} - impl From<$v> for $t { - fn from(v: $v) -> Self { - Self::new(v.x.expr(), v.y.expr(), v.z.expr()) - } - } - }; +#[repr(C)] +#[derive(Debug, Copy, Clone)] +pub struct DoubledProxyData(X, X); +impl FromNode for DoubledProxyData { + fn from_node(node: NodeRef) -> Self { + Self(X::from_node(node), X::from_node(node)) + } } -macro_rules! impl_var_trait4 { - ($t:ty, $v:ty) => { - impl VarTrait for $t { - type Value = $v; - type Short = Short4Expr; - type Ushort = Ushort4Expr; - type Int = Int4Expr; - type Uint = Uint4Expr; - type Float = Float4Expr; - type Double = Double4Expr; - type Half = Half4Expr; - type Bool = Bool4Expr; - type Long = Long4Expr; - type Ulong = Ulong4Expr; + +macro_rules! vector_proxies { + ($N:literal [ $($c:ident),* ]: $ExprName:ident, $VarName:ident) => { + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct $ExprName> { + _node: NodeRef, + $(pub $c: Expr),* } - impl CommonVarOp for $t {} - impl VarCmp for $t {} - impl VarCmpEq for $t {} - impl From<$v> for $t { - fn from(v: $v) -> Self { - Self::new(v.x.expr(), v.y.expr(), v.z.expr(), v.w.expr()) - } + #[repr(C)] + #[derive(Debug, Copy, Clone)] + pub struct $VarName> { + _node: NodeRef, + $(pub $c: Var),* } - }; -} - -impl_var_trait2!(Half2Expr, Half2); -impl_var_trait2!(Float2Expr, Float2); -impl_var_trait2!(Double2Expr, Double2); -impl_var_trait2!(Short2Expr, Short2); -impl_var_trait2!(Ushort2Expr, Ushort2); -impl_var_trait2!(Int2Expr, Int2); -impl_var_trait2!(Uint2Expr, Uint2); -impl_var_trait2!(Bool2Expr, Bool2); -impl_var_trait2!(Long2Expr, Long2); -impl_var_trait2!(Ulong2Expr, Ulong2); - -impl_var_trait3!(Half3Expr, Half3); -impl_var_trait3!(Float3Expr, Float3); -impl_var_trait3!(Double3Expr, Double3); -impl_var_trait3!(Short3Expr, Short3); -impl_var_trait3!(Ushort3Expr, Ushort3); -impl_var_trait3!(Int3Expr, Int3); -impl_var_trait3!(Uint3Expr, Uint3); -impl_var_trait3!(Bool3Expr, Bool3); -impl_var_trait3!(Long3Expr, Long3); -impl_var_trait3!(Ulong3Expr, Ulong3); -impl_var_trait4!(Half4Expr, Half4); -impl_var_trait4!(Float4Expr, Float4); -impl_var_trait4!(Double4Expr, Double4); -impl_var_trait4!(Short4Expr, Short4); -impl_var_trait4!(Ushort4Expr, Ushort4); -impl_var_trait4!(Int4Expr, Int4); -impl_var_trait4!(Uint4Expr, Uint4); -impl_var_trait4!(Bool4Expr, Bool4); -impl_var_trait4!(Long4Expr, Long4); -impl_var_trait4!(Ulong4Expr, Ulong4); + unsafe impl> HasExprLayout::ExprData> for $ExprName {} + unsafe impl> HasVarLayout::VarData> for $VarName {} -macro_rules! impl_float_trait { - ($t:ty) => { - impl From for $t { - fn from(v: f32) -> Self { - Self::splat(v) - } + impl> ExprProxy for $ExprName { + type Value = Vector; } - impl FloatVarTrait for $t {} - }; -} - -impl_float_trait!(Half2Expr); -impl_float_trait!(Half3Expr); -impl_float_trait!(Half4Expr); -impl_float_trait!(Float2Expr); -impl_float_trait!(Float3Expr); -impl_float_trait!(Float4Expr); - -macro_rules! impl_int_trait { - ($t:ty) => { - impl From for $t { - fn from(v: i64) -> Self { - Self::splat(v) + impl> VarProxy for $VarName { + type Value = Vector; + } + impl> Deref for $VarName { + type Target = $ExprName; + fn deref(&self) -> &Self::Target { + _deref_proxy(self) } } - impl IntVarTrait for $t {} - }; + } } -impl_int_trait!(Int2Expr); -impl_int_trait!(Int3Expr); -impl_int_trait!(Int4Expr); - -impl_int_trait!(Long2Expr); -impl_int_trait!(Long3Expr); -impl_int_trait!(Long4Expr); - -impl_int_trait!(Uint2Expr); -impl_int_trait!(Uint3Expr); -impl_int_trait!(Uint4Expr); - -impl_int_trait!(Ulong2Expr); -impl_int_trait!(Ulong3Expr); -impl_int_trait!(Ulong4Expr); - -impl_int_trait!(Short2Expr); -impl_int_trait!(Short3Expr); -impl_int_trait!(Short4Expr); -impl_int_trait!(Ushort2Expr); -impl_int_trait!(Ushort3Expr); -impl_int_trait!(Ushort4Expr); +vector_proxies!(2 [x, y]: VectorExprProxy2, VectorVarProxy2); +vector_proxies!(3 [x, y, z, r, g, b]: VectorExprProxy3, VectorVarProxy3); +vector_proxies!(4 [x, y, z, w, r, g, b, a]: VectorExprProxy4, VectorVarProxy4); -impl Mul for Mat2Expr { - type Output = Float2Expr; - #[inline] - fn mul(self, rhs: Float2Expr) -> Self::Output { - Float2Expr::from_node(__current_scope(|s| { - s.call( - Func::Mul, - &[self.node, rhs.node], - ::type_(), - ) - })) - } -} -impl Mat2Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new(Float2::expr(e, e), Float2::expr(e, e)) - } - pub fn eye(e: Expr) -> Self { - Self::new(Float2::expr(e.x(), 0.0), Float2::expr(0.0, e.y())) - } - pub fn inverse(&self) -> Self { - Mat2Expr::from_node(__current_scope(|s| { - s.call(Func::Inverse, &[self.node], ::type_()) - })) - } - pub fn transpose(&self) -> Self { - Mat2Expr::from_node(__current_scope(|s| { - s.call(Func::Transpose, &[self.node], ::type_()) - })) - } - pub fn determinant(&self) -> prim::Expr { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Determinant, &[self.node], ::type_()) - })) - } -} -impl_arith_binop_for_mat!(Mat2, f32, Mat2Expr); -impl Mul for Mat3Expr { - type Output = Float3Expr; - #[inline] - fn mul(self, rhs: Float3Expr) -> Self::Output { - Float3Expr::from_node(__current_scope(|s| { - s.call( - Func::Mul, - &[self.node, rhs.node], - ::type_(), - ) - })) - } -} -impl Mat3Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new( - Float3::expr(e, e, e), - Float3::expr(e, e, e), - Float3::expr(e, e, e), - ) - } - pub fn eye(e: Expr) -> Self { - Self::new( - Float3::expr(e.x(), 0.0, 0.0), - Float3::expr(0.0, e.y(), 0.0), - Float3::expr(0.0, 0.0, e.z()), - ) - } - pub fn inverse(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Inverse, &[self.node], ::type_()) - })) - } - pub fn transpose(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Transpose, &[self.node], ::type_()) - })) - } - pub fn determinant(&self) -> prim::Expr { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Determinant, &[self.node], ::type_()) - })) +impl, const N: usize> TypeOf for Vector { + fn type_() -> CArc { + let type_ = Type::Vector(VectorType { + element: VectorElementType::Scalar(T::type_()), + length: N as u32, + }); + register_type(type_) } } -impl Mul for Mat4Expr { - type Output = Float4Expr; - #[inline] - fn mul(self, rhs: Float4Expr) -> Self::Output { - Float4Expr::from_node(__current_scope(|s| { - s.call( - Func::Mul, - &[self.node, rhs.node], - ::type_(), - ) - })) - } + +impl> Value for Vector { + type Expr = VectorExprProxy2; + type Var = VectorVarProxy2; + type ExprData = VectorExprData; + type VarData = VectorVarData; } -impl_arith_binop_for_mat!(Mat3, f32, Mat3Expr); -impl Mat4Expr { - pub fn fill(e: impl Into> + Copy) -> Self { - Self::new( - Float4::expr(e, e, e, e), - Float4::expr(e, e, e, e), - Float4::expr(e, e, e, e), - Float4::expr(e, e, e, e), - ) - } - pub fn eye(e: Expr) -> Self { - Self::new( - Float4::expr(e.x(), 0.0, 0.0, 0.0), - Float4::expr(0.0, e.y(), 0.0, 0.0), - Float4::expr(0.0, 0.0, e.z(), 0.0), - Float4::expr(0.0, 0.0, 0.0, e.w()), - ) - } - pub fn inverse(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Inverse, &[self.node], ::type_()) - })) - } - pub fn transpose(&self) -> Self { - Self::from_node(__current_scope(|s| { - s.call(Func::Transpose, &[self.node], ::type_()) - })) - } - pub fn determinant(&self) -> prim::Expr { - FromNode::from_node(__current_scope(|s| { - s.call(Func::Determinant, &[self.node], ::type_()) - })) - } +impl> Value for Vector { + type Expr = VectorExprProxy3; + type Var = VectorVarProxy3; + type ExprData = DoubledProxyData>; + type VarData = DoubledProxyData>; } -impl_arith_binop_for_mat!(Mat4, f32, Mat4Expr); - -#[cfg(test)] -mod test { - #[test] - fn test_size() { - use crate::internal_prelude::*; - macro_rules! assert_size { - ($ty:ty) => { - {assert_eq!(std::mem::size_of::<$ty>(), <$ty as TypeOf>::type_().size());} - }; - ($ty:ty, $($rest:ty),*) => { - assert_size!($ty); - assert_size!($($rest),*); - }; - } - assert_size!(f32, f64, bool, u16, u32, u64, i16, i32, i64); - assert_size!(Float2, Float3, Float4, Int2, Int3, Int4, Uint2, Uint3, Uint4); - assert_size!(Short2, Short3, Short4, Ushort2, Ushort3, Ushort4); - assert_size!(Long2, Long3, Long4, Ulong2, Ulong3, Ulong4); - assert_size!(Mat2, Mat3, Mat4); - assert_size!(PackedFloat2, PackedFloat3, PackedFloat4); - assert_eq!(std::mem::size_of::(), 12); - } +impl> Value for Vector { + type Expr = VectorExprProxy4; + type Var = VectorVarProxy4; + type ExprData = DoubledProxyData>; + type VarData = DoubledProxyData>; } diff --git a/luisa_compute/src/lang/types/vector/coords.rs b/luisa_compute/src/lang/types/vector/coords.rs index 4563e55..5b11642 100644 --- a/luisa_compute/src/lang/types/vector/coords.rs +++ b/luisa_compute/src/lang/types/vector/coords.rs @@ -1 +1,70 @@ use super::*; +use std::ops::{Deref, DerefMut}; + +macro_rules! impl_coords { + ($T:ident [ $($c:ident), * ]) => { + #[repr(C)] + #[derive(Copy, Clone, PartialEq, Eq, Hash)] + pub struct $T { + $(pub $c: T),* + } + } +} +macro_rules! impl_deref { + ($T:ident; $N:literal) => { + impl> Deref for Vector { + type Target = $T; + + #[inline] + fn deref(&self) -> &$T { + unsafe { &*(self as *const Self as *const $T) } + } + } + + impl> DerefMut for Vector { + #[inline] + fn deref_mut(&self) -> &$T { + unsafe { &*(self as *const Self as *const $T) } + } + } + }; +} + +impl_coords!(XY[x, y]); +impl_coords!(XYZ[x, y, z]); +impl_coords!(XYZW[x, y, z, w]); +impl_coords!(RGB[r, g, b]); +impl_coords!(RGBA[r, g, b, a]); + +impl_deref![XY; 2]; +impl_deref![XYZ; 3]; +impl_deref![XYZW; 4]; + +impl Deref for XYZ { + type Target = RGB; + + #[inline] + fn deref(&self) -> &RGB { + unsafe { &*(self as *const Self as *const RGB) } + } +} +impl DerefMut for XYZ { + #[inline] + fn deref_mut(&self) -> &RGB { + unsafe { &*(self as *const Self as *const RGB) } + } +} +impl Deref for XYZW { + type Target = RGBA; + + #[inline] + fn deref(&self) -> &RGBA { + unsafe { &*(self as *const Self as *const RGBA) } + } +} +impl DerefMut for XYZW { + #[inline] + fn deref_mut(&self) -> &RGBA { + unsafe { &*(self as *const Self as *const RGBA) } + } +} diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs new file mode 100644 index 0000000..72425e6 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -0,0 +1,50 @@ +use super::*; + +macro_rules! element { + ($t:ty [ $l:literal ]: $a: ident) => { + impl VectorElement<$l> for $t { + type A = $a; + } + }; +} + +element!(bool[2]: Align2); +element!(bool[3]: Align4); +element!(bool[4]: Align4); +// TODO: Make u8 support ir::TypeOf. +// element!(u8[2]: Align2); +// element!(u8[3]: Align4); +// element!(u8[4]: Align4); +// element!(i8[2]: Align2); +// element!(i8[3]: Align4); +// element!(i8[4]: Align4); + +element!(f16[2]: Align4); +element!(f16[3]: Align8); +element!(f16[4]: Align8); +element!(u16[2]: Align4); +element!(u16[3]: Align8); +element!(u16[4]: Align8); +element!(i16[2]: Align4); +element!(i16[3]: Align8); +element!(i16[4]: Align8); + +element!(f32[2]: Align8); +element!(f32[3]: Align16); +element!(f32[4]: Align16); +element!(u32[2]: Align8); +element!(u32[3]: Align16); +element!(u32[4]: Align16); +element!(i32[2]: Align8); +element!(i32[3]: Align16); +element!(i32[4]: Align16); + +element!(f64[2]: Align16); +element!(f64[3]: Align32); +element!(f64[4]: Align32); +element!(u64[2]: Align16); +element!(u64[3]: Align32); +element!(u64[4]: Align32); +element!(i64[2]: Align16); +element!(i64[3]: Align32); +element!(i64[4]: Align32); From bc02deb658da16b6a2aad230beff66bafde72814 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Tue, 19 Sep 2023 13:54:06 +0100 Subject: [PATCH 05/15] Bugfixes for vector. --- luisa_compute/src/lang/types/alignment.rs | 3 ++- luisa_compute/src/lang/types/vector.rs | 7 ++++--- luisa_compute/src/lib.rs | 12 +----------- 3 files changed, 7 insertions(+), 15 deletions(-) diff --git a/luisa_compute/src/lang/types/alignment.rs b/luisa_compute/src/lang/types/alignment.rs index f217a29..2b6977d 100644 --- a/luisa_compute/src/lang/types/alignment.rs +++ b/luisa_compute/src/lang/types/alignment.rs @@ -1,6 +1,7 @@ use super::*; +use std::hash::Hash; -pub(crate) trait Alignment: Default { +pub(crate) trait Alignment: Default + Copy + Hash + Eq + 'static { const ALIGNMENT: usize; } diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index fd5ba2d..4a6318f 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -27,6 +27,7 @@ impl, const N: usize> Debug for Vector { #[repr(C)] #[derive(Copy, Clone, Hash, Default, PartialEq, Eq, Serialize, Deserialize)] pub struct Vector, const N: usize> { + #[serde(skip)] _align: T::A, elements: [T; N], } @@ -92,8 +93,8 @@ macro_rules! vector_proxies { $(pub $c: Var),* } - unsafe impl> HasExprLayout::ExprData> for $ExprName {} - unsafe impl> HasVarLayout::VarData> for $VarName {} + unsafe impl> HasExprLayout< as Value>::ExprData> for $ExprName {} + unsafe impl> HasVarLayout< as Value>::VarData> for $VarName {} impl> ExprProxy for $ExprName { type Value = Vector; @@ -102,7 +103,7 @@ macro_rules! vector_proxies { type Value = Vector; } impl> Deref for $VarName { - type Target = $ExprName; + type Target = Expr>; fn deref(&self) -> &Self::Target { _deref_proxy(self) } diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 79d2b37..98d0138 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -23,16 +23,7 @@ pub mod prelude { pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; pub use crate::lang::swizzle::*; - pub use crate::lang::types::vector::{ - Bool2, Bool3, Bool4, Byte2, Byte3, Byte4, Double2, Double3, Double4, Float2, Float3, - Float4, Half2, Half3, Half4, Int2, Int3, Int4, Long2, Long3, Long4, Mat2, Mat3, Mat4, - PackedBool2, PackedBool3, PackedBool4, PackedFloat2, PackedFloat3, PackedFloat4, - PackedInt2, PackedInt3, PackedInt4, PackedLong2, PackedLong3, PackedLong4, PackedShort2, - PackedShort3, PackedShort4, PackedUint2, PackedUint3, PackedUint4, PackedUlong2, - PackedUlong3, PackedUlong4, PackedUshort2, PackedUshort3, PackedUshort4, Short2, Short3, - Short4, Ubyte2, Ubyte3, Ubyte4, Uint2, Uint3, Uint4, Ulong2, Ulong3, Ulong4, Ushort2, - Ushort3, Ushort4, - }; + pub use crate::lang::types::vector::Vector; pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; @@ -53,7 +44,6 @@ mod internal_prelude { new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, PhiIncoming, Pooled, Type, TypeOf, INVALID_REF, }; - pub(crate) use crate::lang::types::vector::*; pub(crate) use crate::lang::{ ir, Recorder, __compose, __extract, __insert, __module_pools, need_runtime_check, FromNode, NodeLike, NodeRef, ToNode, __current_scope, __pop_scope, RECORDER, From 9e6eacf6b0c2b59a2646f45a727e4d82e84a46cd Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 12:19:40 +0100 Subject: [PATCH 06/15] Working on linear operations. --- luisa_compute/src/lang.rs | 40 +- luisa_compute/src/lang/ops.rs | 529 ++------------- luisa_compute/src/lang/ops/impls.rs | 611 +++++++++--------- luisa_compute/src/lang/ops/spread.rs | 264 ++++++++ luisa_compute/src/lang/ops/traits.rs | 141 ++++ luisa_compute/src/lang/types.rs | 6 + luisa_compute/src/lang/types/core.rs | 78 ++- luisa_compute/src/lang/types/dynamic.rs | 2 +- luisa_compute/src/lang/types/vector.rs | 211 +++++- .../src/lang/types/vector/element.rs | 18 +- .../lang/{ => types/vector}/gen_swizzle.py | 6 +- .../src/lang/{ => types/vector}/swizzle.rs | 18 +- luisa_compute/src/lib.rs | 7 +- luisa_compute_sys/LuisaCompute | 2 +- luisa_compute_track/src/lib.rs | 4 + 15 files changed, 1106 insertions(+), 831 deletions(-) create mode 100644 luisa_compute/src/lang/ops/spread.rs create mode 100644 luisa_compute/src/lang/ops/traits.rs rename luisa_compute/src/lang/{ => types/vector}/gen_swizzle.py (88%) rename luisa_compute/src/lang/{ => types/vector}/swizzle.rs (98%) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 159c354..6799cb2 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -29,12 +29,48 @@ pub mod debug; pub mod diff; pub mod functions; pub mod index; -pub mod maybe_expr; +// pub mod maybe_expr; pub mod ops; pub mod poly; -pub mod swizzle; pub mod types; +pub(crate) trait CallFuncTrait { + fn call(self, x: Expr) -> Expr; + fn call2(self, x: Expr, y: Expr) -> Expr; + fn call3( + self, + x: Expr, + y: Expr, + z: Expr, + ) -> Expr; +} +impl CallFuncTrait for Func { + fn call(self, x: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(self, &[x.node()], ::type_()) + })) + } + fn call2(self, x: Expr, y: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call(self, &[x.node(), y.node()], ::type_()) + })) + } + fn call3( + self, + x: Expr, + y: Expr, + z: Expr, + ) -> Expr { + Expr::::from_node(__current_scope(|b| { + b.call( + self, + &[x.node(), y.node(), z.node()], + ::type_(), + ) + })) + } +} + #[allow(dead_code)] pub(crate) static KERNEL_ID: AtomicUsize = AtomicUsize::new(0); // prevent node being shared across kernels diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index ff514bc..9e8b551 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -1,490 +1,47 @@ use crate::internal_prelude::*; use std::ops::*; -pub mod impls; - -pub trait VarTrait: Copy + Clone + 'static + NodeLike { - type Value: Value; - type Short: VarTrait; - type Ushort: VarTrait; - type Int: VarTrait; - type Uint: VarTrait; - type Long: VarTrait; - type Ulong: VarTrait; - type Half: VarTrait; - type Float: VarTrait; - type Double: VarTrait; - type Bool: VarTrait + Not + BitAnd; - fn type_() -> CArc { - ::type_() - } -} +use super::types::core::{Floating, Integral, Numeric, Primitive, Signed}; +use super::types::vector::VectorElement; -fn _cast(expr: T) -> U { - let node = expr.node(); - __current_scope(|s| { - let ret = s.call(Func::Cast, &[node], U::type_()); - U::from_node(ret) - }) -} -pub trait CommonVarOp: VarTrait { - fn max>(&self, other: A) -> Self { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Max, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } - fn min>(&self, other: A) -> Self { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Min, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } - fn clamp, B: Into>(&self, min: A, max: B) -> Self { - let min = min.into().node(); - let max = max.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Clamp, &[self.node(), min, max], Self::type_()); - Self::from_node(ret) - }) - } - fn abs(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Abs, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn bitcast(&self) -> Expr { - assert_eq!(std::mem::size_of::(), std::mem::size_of::()); - let ty = ::type_(); - let node = __current_scope(|s| s.bitcast(self.node(), ty)); - Expr::::from_node(node) - } - fn uint(&self) -> Self::Uint { - _cast(*self) - } - fn int(&self) -> Self::Int { - _cast(*self) - } - fn ulong(&self) -> Self::Ulong { - _cast(*self) - } - fn long(&self) -> Self::Long { - _cast(*self) - } - fn float(&self) -> Self::Float { - _cast(*self) - } - fn short(&self) -> Self::Short { - _cast(*self) - } - fn ushort(&self) -> Self::Ushort { - _cast(*self) - } - fn half(&self) -> Self::Half { - _cast(*self) - } - fn double(&self) -> Self::Double { - _cast(*self) - } - fn bool_(&self) -> Self::Bool { - _cast(*self) - } -} -pub trait VarCmpEq: VarTrait { - fn cmpeq>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Eq, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmpne>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Ne, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } -} -pub trait VarCmp: VarTrait + VarCmpEq { - fn cmplt>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Lt, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmple>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Le, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmpgt>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Gt, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } - fn cmpge>(&self, other: A) -> Self::Bool { - let lhs = self.node(); - let rhs = other.into().node(); - __current_scope(|s| { - let ret = s.call(Func::Ge, &[lhs, rhs], Self::Bool::type_()); - FromNode::from_node(ret) - }) - } -} -pub trait IntVarTrait: - VarTrait - + CommonVarOp - + VarCmp - + Add - + Sub - + Mul - + Div - + Rem - + Shl - + Shr - + BitAnd - + BitOr - + BitXor - + AddAssign - + SubAssign - + MulAssign - + DivAssign - + RemAssign - + ShlAssign - + ShrAssign - + BitAndAssign - + BitOrAssign - + BitXorAssign - + Neg - + Clone - + Not - + From - + From -{ - fn one() -> Self { - Self::from(1i64) - } - fn zero() -> Self { - Self::from(0i64) - } - fn rotate_right(&self, n: Expr) -> Self { - let lhs = self.node(); - let rhs = Expr::::node(&n); - __current_scope(|s| { - let ret = s.call(Func::RotRight, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } - fn rotate_left(&self, n: Expr) -> Self { - let lhs = self.node(); - let rhs = Expr::::node(&n); - __current_scope(|s| { - let ret = s.call(Func::RotLeft, &[lhs, rhs], Self::type_()); - Self::from_node(ret) - }) - } -} -pub trait FloatVarTrait: - VarTrait - + CommonVarOp - + VarCmp - + Add - + Sub - + Mul - + Div - + Rem - + Neg - + AddAssign - + SubAssign - + MulAssign - + DivAssign - + RemAssign - + Clone - + From - + From -{ - fn one() -> Self { - Self::from(1.0f32) - } - fn zero() -> Self { - Self::from(0.0f32) - } - fn mul_add, B: Into>(&self, a: A, b: B) -> Self { - let a: Self = a.into(); - let b: Self = b.into(); - let node = __current_scope(|s| { - s.call(Func::Fma, &[self.node(), a.node(), b.node()], Self::type_()) - }); - Self::from_node(node) - } - fn ceil(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Ceil, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn floor(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Floor, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn round(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Round, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn trunc(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Trunc, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn copysign>(&self, other: A) -> Self { - __current_scope(|s| { - let ret = s.call( - Func::Copysign, - &[self.node(), other.into().node()], - Self::type_(), - ); - Self::from_node(ret) - }) - } - fn sqrt(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Sqrt, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn rsqrt(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Rsqrt, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn fract(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Fract, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - - // x.step(edge) - fn step(&self, edge: Self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Step, &[edge.node(), self.node()], Self::type_()); - Self::from_node(ret) - }) - } - - fn smooth_step(&self, edge0: Self, edge1: Self) -> Self { - __current_scope(|s| { - let ret = s.call( - Func::SmoothStep, - &[edge0.node(), edge1.node(), self.node()], - Self::type_(), - ); - Self::from_node(ret) - }) - } - - fn saturate(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Saturate, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - - fn sin(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Sin, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - // crate::math::approx_sin_cos(self.clone(), true, false).0 - } - fn cos(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Cos, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - // crate::math::approx_sin_cos(self.clone(), false, true).1 - } - fn tan(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Tan, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn asin(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Asin, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn acos(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Acos, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn atan(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Atan, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn atan2(&self, other: Self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Atan2, &[self.node(), other.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn sinh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Sinh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn cosh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Cosh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn tanh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Tanh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn asinh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Asinh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn acosh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Acosh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn atanh(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Atanh, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn exp(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Exp, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn exp2(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Exp2, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn is_finite(&self) -> Self::Bool { - !self.is_infinite() & !self.is_nan() - } - fn is_infinite(&self) -> Self::Bool { - __current_scope(|s| { - let ret = s.call(Func::IsInf, &[self.node()], ::type_()); - FromNode::from_node(ret) - }) - } - fn is_nan(&self) -> Self::Bool { - __current_scope(|s| { - let ret = s.call(Func::IsNan, &[self.node()], ::type_()); - FromNode::from_node(ret) - }) - } - fn ln(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Log, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn log(&self, base: impl Into) -> Self { - self.ln() / base.into().ln() - } - fn log2(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Log2, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn log10(&self) -> Self { - __current_scope(|s| { - let ret = s.call(Func::Log10, &[self.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn powf(&self, exp: impl Into) -> Self { - let exp = exp.into(); - __current_scope(|s| { - let ret = s.call(Func::Powf, &[self.node(), exp.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn sqr(&self) -> Self { - *self * *self - } - fn cube(&self) -> Self { - *self * *self * *self - } - fn powi(&self, exp: impl Into) -> Self { - let exp = exp.into(); - __current_scope(|s| { - let ret = s.call(Func::Powi, &[self.node(), exp.node()], Self::type_()); - Self::from_node(ret) - }) - } - fn lerp(&self, other: impl Into, frac: impl Into) -> Self { - let other = other.into(); - let frac = frac.into(); - __current_scope(|s| { - let ret = s.call( - Func::Lerp, - &[self.node(), other.node(), frac.node()], - Self::type_(), - ); - Self::from_node(ret) - }) - } - fn recip(&self) -> Self { - Self::one() / self.clone() - } - fn sin_cos(&self) -> (Self, Self) { - (self.sin(), self.cos()) - } -} - -pub trait ScalarVarTrait: ToNode + FromNode {} -pub trait VectorVarTrait: ToNode + FromNode {} -pub trait MatrixVarTrait: ToNode + FromNode {} -pub trait ScalarOrVector: ToNode + FromNode { - type Element: ScalarVarTrait; - type ElementHost: Value; -} -pub trait BuiltinVarTrait: ToNode + FromNode {} +pub mod impls; +pub mod spread; +pub mod traits; + +trait CastFrom: Primitive {} +impl CastFrom for T {} +impl CastFrom for T {} + +// Hack because using an associated constant is not allowed within a trait bound +// without #![feature(generic_const_exprs)]. +pub trait Linear: Value { + type Scalar: VectorElement; + type WithScalar>: Linear; + // We don't actually know that the vector has equivalent vectors of every + // primitive type. + type WithBool: Linear; +} +impl Linear<1> for T { + type Scalar = T; + type WithScalar = S; + type WithBool = bool; +} +macro_rules! impl_linear_vectors { + ($t:ty) => { + impl_linear_vectors!($t: 2, 3, 4); + }; + ($t:ty, $($ts:ty),+) => { + impl_linear_vectors!($t); + impl_linear_vectors!($($ts),+); + }; + ($t:ty : $($n:literal),+) => { + $( + impl Linear<$n> for Vector<$t, $n> { + type Scalar = $t; + type WithScalar = Vector; + type WithBool = Vector; + } + )+ + } +} +impl_linear_vectors!(bool, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index b581b40..916e13a 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -1,338 +1,347 @@ use super::*; -use crate::lang::types::core::*; -use crate::lang::types::VarDeref; +use traits::*; -macro_rules! impl_var_trait { - ($t:ty) => { - impl VarTrait for prim::Expr<$t> { - type Value = $t; - type Short = prim::Expr; - type Ushort = prim::Expr; - type Int = prim::Expr; - type Uint = prim::Expr; - type Long = prim::Expr; - type Ulong = prim::Expr; - type Half = prim::Expr; - type Float = prim::Expr; - type Double = prim::Expr; - type Bool = prim::Expr; - } - impl ScalarVarTrait for prim::Expr<$t> {} - impl ScalarOrVector for prim::Expr<$t> { - type Element = prim::Expr<$t>; - type ElementHost = $t; - } - impl BuiltinVarTrait for prim::Expr<$t> {} - }; +impl> Expr { + fn as_>(self) -> Y + where + Y::Scalar: CastFrom, + { + Func::Cast.call(self) + } + fn cast>(self) -> Expr> + where + S: CastFrom, + { + self.as_::>() + } } -impl_var_trait!(f16); -impl_var_trait!(f32); -impl_var_trait!(f64); -impl_var_trait!(i16); -impl_var_trait!(u16); -impl_var_trait!(i32); -impl_var_trait!(u32); -impl_var_trait!(i64); -impl_var_trait!(u64); -impl_var_trait!(bool); - -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} -impl VarCmpEq for prim::Expr {} - -impl VarCmpEq for prim::Expr {} - -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} -impl VarCmp for prim::Expr {} - -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} -impl CommonVarOp for prim::Expr {} - -impl CommonVarOp for prim::Expr {} - -impl FloatVarTrait for prim::Expr {} -impl FloatVarTrait for prim::Expr {} -impl FloatVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} -impl IntVarTrait for prim::Expr {} +impl> MinMaxExpr for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; -macro_rules! impl_from { - ($from:ty, $to:ty) => { - impl From<$from> for prim::Expr<$to> { - fn from(x: $from) -> Self { - let y: $to = (x.try_into().unwrap()); - y.expr() - } - } - }; + fn max(self, other: Self) -> Self { + Func::Max.call2(self, other) + } + fn min(self, other: Self) -> Self { + Func::Min.call2(self, other) + } } -impl_from!(i16, u16); -impl_from!(i16, i32); -impl_from!(i16, u32); -impl_from!(i16, i64); -impl_from!(i16, u64); - -impl_from!(u16, i16); -impl_from!(u16, i32); -impl_from!(u16, u32); -impl_from!(u16, i64); -impl_from!(u16, u64); - -impl_from!(i32, u16); -impl_from!(i32, i16); -impl_from!(i32, u32); -impl_from!(i32, i64); -impl_from!(i32, u64); +impl> ClampExpr for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; -impl_from!(i64, u16); -impl_from!(i64, i16); -impl_from!(i64, u64); -impl_from!(i64, i32); -impl_from!(i64, u32); + fn clamp(self, min: Self, max: Self) -> Self { + Func::Clamp.call3(self, min, max) + } +} -impl_from!(u32, u16); -impl_from!(u32, i16); -impl_from!(u32, i32); -impl_from!(u32, i64); -impl_from!(u32, u64); +impl> AbsExpr for Expr +where + X::Scalar: Signed, +{ + fn abs(&self) -> Self { + Func::Abs.call(self) + } +} -impl_from!(u64, u16); -impl_from!(u64, i16); -impl_from!(u64, i64); -impl_from!(u64, i32); -impl_from!(u64, u32); +impl> EqExpr for Expr { + type Output = Expr; + fn eq(self, other: Self) -> Self::Output { + Func::Eq.call2(self, other) + } + fn ne(self, other: Self) -> Self::Output { + Func::Ne.call2(self, other) + } +} +impl> CmpExpr for Expr { + fn lt(self, other: Self) -> Self::Output { + Func::Lt.call2(self, other) + } + fn le(self, other: Self) -> Self::Output { + Func::Le.call2(self, other) + } + fn gt(self, other: Self) -> Self::Output { + Func::Gt.call2(self, other) + } + fn ge(self, other: Self) -> Self::Output { + Func::Ge.call2(self, other) + } +} -impl From for prim::Expr { - fn from(x: f64) -> Self { - (x as f32).into() +impl> Add for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn add(self, other: Self) -> Self { + Func::Add.call2(self, other) } } -impl From for prim::Expr { - fn from(x: f32) -> Self { - (x as f64).into() +impl> Sub for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn sub(self, other: Self) -> Self { + Func::Sub.call2(self, other) } } -impl From for prim::Expr { - fn from(x: f64) -> Self { - f16::from_f64(x).into() +impl> Mul for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn mul(self, other: Self) -> Self { + Func::Mul.call2(self, other) } } -impl From for prim::Expr { - fn from(x: f32) -> Self { - f16::from_f32(x).into() +impl> Div for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn div(self, other: Self) -> Self { + Func::Div.call2(self, other) + } +} +impl> Rem for Expr +where + X::Scalar: Numeric, +{ + type Output = Self; + fn rem(self, other: Self) -> Self { + Func::Rem.call2(self, other) } } -macro_rules! impl_binop { - ($t:ty, $proxy:ty, $tr_assign:ident, $method_assign:ident, $tr:ident, $method:ident) => { - impl $tr_assign> for $proxy { - fn $method_assign(&mut self, rhs: prim::Expr<$t>) { - *self = self.clone().$method(rhs); - } - } - impl $tr_assign<$t> for $proxy { - fn $method_assign(&mut self, rhs: $t) { - *self = self.clone().$method(rhs); - } - } - impl $tr> for $proxy { - type Output = prim::Expr<$t>; - fn $method(self, rhs: prim::Expr<$t>) -> Self::Output { - __current_scope(|s| { - let lhs = ToNode::node(&self); - let rhs = ToNode::node(&rhs); - let ret = s.call(Func::$tr, &[lhs, rhs], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } +impl> BitAnd for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn bitand(self, other: Self) -> Self { + Func::BitAnd.call2(self, other) + } +} +impl> BitOr for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn bitor(self, other: Self) -> Self { + Func::BitOr.call2(self, other) + } +} +impl> BitXor for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn bitxor(self, other: Self) -> Self { + Func::BitXor.call2(self, other) + } +} +impl> Shl for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn shl(self, other: Self) -> Self { + Func::Shl.call2(self, other) + } +} +impl> Shr for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn shr(self, other: Self) -> Self { + Func::Shr.call2(self, other) + } +} - impl $tr<$t> for $proxy { - type Output = prim::Expr<$t>; - fn $method(self, rhs: $t) -> Self::Output { - $tr::$method(self, rhs.expr()) - } - } - impl $tr<$proxy> for $t { - type Output = prim::Expr<$t>; - fn $method(self, rhs: $proxy) -> Self::Output { - $tr::$method(self.expr(), rhs) - } - } - }; +impl> Neg for Expr +where + X::Scalar: Signed, +{ + type Output = Self; + fn neg(self) -> Self { + Func::Neg.call(self) + } } -macro_rules! impl_common_binop { - ($t:ty,$proxy:ty) => { - impl_binop!($t, $proxy, AddAssign, add_assign, Add, add); - impl_binop!($t, $proxy, SubAssign, sub_assign, Sub, sub); - impl_binop!($t, $proxy, MulAssign, mul_assign, Mul, mul); - impl_binop!($t, $proxy, DivAssign, div_assign, Div, div); - impl_binop!($t, $proxy, RemAssign, rem_assign, Rem, rem); - }; +impl> Not for Expr +where + X::Scalar: Integral, +{ + type Output = Self; + fn not(self) -> Self { + Func::BitNot.call(self) + } } -macro_rules! impl_int_binop { - ($t:ty,$proxy:ty) => { - impl_binop!($t, $proxy, ShlAssign, shl_assign, Shl, shl); - impl_binop!($t, $proxy, ShrAssign, shr_assign, Shr, shr); - impl_binop!($t, $proxy, BitAndAssign, bitand_assign, BitAnd, bitand); - impl_binop!($t, $proxy, BitOrAssign, bitor_assign, BitOr, bitor); - impl_binop!($t, $proxy, BitXorAssign, bitxor_assign, BitXor, bitxor); - }; + +impl> IntExpr for Expr +where + X::Scalar: Integral + Numeric, +{ + fn rotate_left(&self, n: Expr) -> Self { + Func::RotRight.call2(self, n) + } + fn rotate_right(&self, n: Expr) -> Self { + Func::RotLeft.call2(self, n) + } } -macro_rules! impl_not { - ($t:ty,$proxy:ty) => { - impl Not for $proxy { - type Output = prim::Expr<$t>; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } +macro_rules! impl_simple_fns { + ($($fname:ident => $func:ident),+) => {$( + fn $fname(&self) -> Self { + Func::$func.call(self) } - }; + )+}; } -macro_rules! impl_neg { - ($t:ty,$proxy:ty) => { - impl Neg for $proxy { - type Output = prim::Expr<$t>; - fn neg(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; + +impl> FloatExpr for Expr +where + X::Scalar: Floating, +{ + type Bool = Self::WithBool; + impl_simple_fns! { + ceil => Ceil, + floor => Floor, + round => Round, + trunc => Trunc, + sqrt => Sqrt, + rsqrt => Rsqrt, + fract => Fract, + saturate => Saturate, + sin => Sin, + cos => Cos, + tan => Tan, + asin => Asin, + acos => Acos, + atan => Atan, + sinh => Sinh, + cosh => Cosh, + tanh => Tanh, + asinh => Asinh, + acosh => Acosh, + atanh => Atanh, + exp => Exp, + exp2 => Exp2, + is_infinite => IsInf, + is_nan => IsNan, + ln => Log, + log2 => Log2, + log10 => Log10 + } + fn is_finite(&self) -> Self::Bool { + !self.is_infinite() & !self.is_nan() + } + fn sqr(&self) -> Self { + *self * *self + } + fn cube(&self) -> Self { + *self * *self * *self + } + fn recip(&self) -> Self { + 1.0 / *self + } + fn sin_cos(&self) -> (Self, Self) { + (self.sin(), self.cos()) + } } -macro_rules! impl_fneg { - ($t:ty, $proxy:ty) => { - impl Neg for $proxy { - type Output = prim::Expr<$t>; - fn neg(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::Neg, &[ToNode::node(&self)], Self::Output::type_()); - Expr::<$t>::from_node(ret) - }) - } - } - }; +impl> FloatMulAddExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; + + fn mul_add(self, a: Self, b: Self) -> Self::Output { + Func::Fma.call3(self, a, b) + } } -impl Not for prim::Expr { - type Output = prim::Expr; - fn not(self) -> Self::Output { - __current_scope(|s| { - let ret = s.call(Func::BitNot, &[ToNode::node(&self)], Self::Output::type_()); - FromNode::from_node(ret) - }) +impl> FloatCopySignExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; + + fn copy_sign(self, sign: Self) -> Self::Output { + Func::Copysign.call2(self, sign) } } -impl_common_binop!(f16, prim::Expr); -impl_common_binop!(f32, prim::Expr); -impl_common_binop!(f64, prim::Expr); -impl_common_binop!(i16, prim::Expr); -impl_common_binop!(i32, prim::Expr); -impl_common_binop!(i64, prim::Expr); -impl_common_binop!(u16, prim::Expr); -impl_common_binop!(u32, prim::Expr); -impl_common_binop!(u64, prim::Expr); +impl> FloatStepExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; -impl_binop!( - bool, - prim::Expr, - BitAndAssign, - bitand_assign, - BitAnd, - bitand -); -impl_binop!( - bool, - prim::Expr, - BitOrAssign, - bitor_assign, - BitOr, - bitor -); -impl_binop!( - bool, - prim::Expr, - BitXorAssign, - bitxor_assign, - BitXor, - bitxor -); -impl_int_binop!(i16, prim::Expr); -impl_int_binop!(i32, prim::Expr); -impl_int_binop!(i64, prim::Expr); -impl_int_binop!(u16, prim::Expr); -impl_int_binop!(u32, prim::Expr); -impl_int_binop!(u64, prim::Expr); + fn step(self, edge: Self) -> Self::Output { + Func::Step.call2(edge, self) + } +} +impl> FloatSmoothStepExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; -impl_not!(i16, prim::Expr); -impl_not!(i32, prim::Expr); -impl_not!(i64, prim::Expr); -impl_not!(u16, prim::Expr); -impl_not!(u32, prim::Expr); -impl_not!(u64, prim::Expr); + fn smooth_step(self, edge0: Self, edge1: Self) -> Self::Output { + Func::SmoothStep.call3(edge0, edge1, self) + } +} +impl> FloatArcTan2Expr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; -impl_neg!(i16, prim::Expr); -impl_neg!(i32, prim::Expr); -impl_neg!(i64, prim::Expr); -impl_neg!(u16, prim::Expr); -impl_neg!(u32, prim::Expr); -impl_neg!(u64, prim::Expr); + fn atan2(self, other: Self) -> Self::Output { + Func::Atan2.call2(self, other) + } +} +impl> FloatLogExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; -impl_fneg!(f16, prim::Expr); -impl_fneg!(f32, prim::Expr); -impl_fneg!(f64, prim::Expr); + fn log(self, base: Self) -> Self::Output { + self.ln() / base.ln() + } +} +impl> FloatPowfExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; -macro_rules! impl_assign_ops { - ($ass:ident, $ass_m:ident, $o:ident, $o_m:ident) => { - impl std::ops::$ass for VarDerefProxy - where - P: VarProxy, - Expr: std::ops::$o>, - { - fn $ass_m(&mut self, rhs: Rhs) { - *self.deref_mut() = std::ops::$o::$o_m(**self, rhs); - } - } - }; + fn powf(self, exponent: Self) -> Self::Output { + Func::Powf.call2(self, exponent) + } +} +impl, Y: Linear> FloatPowiExpr> for Expr +where + X::Scalar: Floating, +{ + type Output = Self; + + fn powi(self, exponent: Expr) -> Self::Output { + Func::Powi.call2(self, exponent) + } +} +impl> FloatLerpExpr for Expr +where + X::Scalar: Floating, +{ + type Output = Self; + + fn lerp(self, other: Self, frac: Self) -> Self::Output { + Func::Lerp.call3(self, other, frac) + } } -impl_assign_ops!(AddAssign, add_assign, Add, add); -impl_assign_ops!(SubAssign, sub_assign, Sub, sub); -impl_assign_ops!(MulAssign, mul_assign, Mul, mul); -impl_assign_ops!(DivAssign, div_assign, Div, div); -impl_assign_ops!(RemAssign, rem_assign, Rem, rem); -impl_assign_ops!(BitAndAssign, bitand_assign, BitAnd, bitand); -impl_assign_ops!(BitOrAssign, bitor_assign, BitOr, bitor); -impl_assign_ops!(BitXorAssign, bitxor_assign, BitXor, bitxor); -impl_assign_ops!(ShlAssign, shl_assign, Shl, shl); -impl_assign_ops!(ShrAssign, shr_assign, Shr, shr); diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs new file mode 100644 index 0000000..98a0db3 --- /dev/null +++ b/luisa_compute/src/lang/ops/spread.rs @@ -0,0 +1,264 @@ +use super::*; +use traits::*; + +trait SpreadOps { + type Join; + fn lift_self(x: Self) -> Expr; + fn lift_other(x: Other) -> Expr; +} + +macro_rules! impl_spread_single { + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl<$($bounds)*> SpreadOps<$S> for $T { + type Join = $J; + fn lift_self($x: $T) -> Expr { + $f + } + fn lift_other($x: $S) -> Expr { + $g + } + } + }; +} +macro_rules! impl_spread { + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl<$($bounds)*> SpreadOps<$S> for $T { + type Join = $J; + fn lift_self($x: $T) -> Expr { + $f + } + fn lift_other($x: $S) -> Expr { + $g + } + } + impl<$($bounds)*> SpreadOps<$T> for $S { + type Join = $J; + fn lift_self($y: $S) -> Expr { + $g + } + fn lift_other($y: $T) -> Expr { + $f + } + } + }; +} +impl_spread!([T: Value] T: |x| x.expr(), Expr: |x| x => Expr); +impl_spread!([T: Value] &T: |x| x.expr(), Expr: |x| x => Expr); +impl_spread!([T: Value] T: |x| x.expr(), &Expr: |x| x.clone() => Expr); +impl_spread!([T: Value] &T: |x| x.expr(), &Expr: |x| x.clone() => Expr); + +impl_spread!([T: Value] Expr: |x| x, &Expr: |x| x.clone() => Expr); +impl_spread!([T: Value] &Expr: |x| x.clone(), &Expr: |x| x.clone() => Expr); + +impl_spread!([T: Value] T: |x| x.expr(), Var: |x| x.load() => Expr); +impl_spread!([T: Value] &T: |x| x.expr(), Var: |x| x.load() => Expr); +impl_spread!([T: Value] T: |x| x.expr(), &Var: |x| x.load() => Expr); +impl_spread!([T: Value] &T: |x| x.expr(), &Var: |x| x.load() => Expr); + +// Other way is unneded because of the deref impl. +impl_spread_single!([T: Value] &Expr: |x| x.clone(), Var: |x| x.load() => Expr); +impl_spread_single!([T: Value] &Expr: |x| x.clone(), &Var: |x| x.load() => Expr); + +impl_spread!([const N: usize, T: VectorElement] T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorElement] &T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorElement] T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorElement] &T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); + +impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); + +mod impls { + use super::*; + impl MinMaxExpr for T + where + T: SpreadOps, + Expr: MinMaxExpr, + { + type Output = Expr::Output; + fn max(self, other: S) -> Self::Output { + Expr::::max(Self::lift_self(self), Self::lift_other(other)) + } + fn min(self, other: S) -> Self::Output { + Expr::::min(Self::lift_self(self), Self::lift_other(other)) + } + } + impl ClampExpr for T + where + S: SpreadOps, + T: SpreadOps, + Expr: ClampExpr, + { + /// T::Join + /// / \ + /// / \ + /// / \ + /// / \ + /// / S::Join + /// / / \ + /// / / \ + /// / / \ + /// / / \ + /// T S U + + type Output = Expr::Output; + fn clamp(&self, min: S, max: U) -> Self::Output { + Expr::::clamp( + Self::lift_self(self), + Self::lift_other(S::lift_self(min)), + Self::lift_other(S::lift_other(max)), + ) + } + } + impl EqExpr for T + where + T: SpreadOps, + Expr: EqExpr, + { + type Output = Expr::Output; + fn eq(self, other: S) -> Self::Output { + Expr::::eq(Self::lift_self(self), Self::lift_other(other)) + } + fn ne(self, other: S) -> Self::Output { + Expr::::ne(Self::lift_self(self), Self::lift_other(other)) + } + } + impl CmpExpr for T + where + T: SpreadOps, + Expr: CmpExpr, + { + fn lt(self, other: S) -> Self::Output { + Expr::::lt(Self::lift_self(self), Self::lift_other(other)) + } + fn le(self, other: S) -> Self::Output { + Expr::::le(Self::lift_self(self), Self::lift_other(other)) + } + fn gt(self, other: S) -> Self::Output { + Expr::::gt(Self::lift_self(self), Self::lift_other(other)) + } + fn ge(self, other: S) -> Self::Output { + Expr::::ge(Self::lift_self(self), Self::lift_other(other)) + } + } + impl FloatMulAddExpr for T + where + S: SpreadOps, + T: SpreadOps, + Expr: FloatMulAddExpr, + { + type Output = Expr::Output; + fn mul_add(self, mul: S, add: U) -> Self::Output { + Expr::::mul_add( + Self::lift_self(self), + Self::lift_other(S::lift_self(mul)), + Self::lift_other(S::lift_other(add)), + ) + } + } + impl FloatCopySignExpr for T + where + T: SpreadOps, + Expr: FloatCopySignExpr, + { + type Output = Expr::Output; + fn copy_sign(self, sign: S) -> Self::Output { + Expr::::copy_sign(Self::lift_self(self), Self::lift_other(sign)) + } + } + impl FloatStepExpr for T + where + T: SpreadOps, + Expr: FloatStepExpr, + { + type Output = Expr::Output; + fn step(self, edge: S) -> Self::Output { + Expr::::step(Self::lift_self(self), Self::lift_other(edge)) + } + } + impl FloatSmoothStepExpr for T + where + S: SpreadOps, + T: SpreadOps, + Expr: FloatSmoothStepExpr, + { + type Output = Expr::Output; + fn smooth_step(self, edge0: S, edge1: U) -> Self::Output { + Expr::::smooth_step( + Self::lift_self(self), + Self::lift_other(S::lift_self(edge0)), + Self::lift_other(S::lift_other(edge1)), + ) + } + } + impl FloatArcTan2Expr for T + where + T: SpreadOps, + Expr: FloatArcTan2Expr, + { + type Output = Expr::Output; + fn atan2(self, other: S) -> Self::Output { + Expr::::atan2(Self::lift_self(self), Self::lift_other(other)) + } + } + impl FloatLogExpr for T + where + T: SpreadOps, + Expr: FloatLogExpr, + { + type Output = Expr::Output; + fn log(self, base: S) -> Self::Output { + Expr::::log(Self::lift_self(self), Self::lift_other(base)) + } + } + impl FloatPowfExpr for T + where + T: SpreadOps, + Expr: FloatPowfExpr, + { + type Output = Expr::Output; + fn powf(self, exponent: S) -> Self::Output { + Expr::::powf(Self::lift_self(self), Self::lift_other(exponent)) + } + } + impl FloatLerpExpr for T + where + S: SpreadOps, + T: SpreadOps, + Expr: FloatLerpExpr, + { + type Output = Expr::Output; + fn lerp(self, other: S, frac: U) -> Self::Output { + Expr::::lerp( + Self::lift_self(self), + Self::lift_other(S::lift_self(other)), + Self::lift_other(S::lift_other(frac)), + ) + } + } +} +macro_rules! impl_spread_op { + ([ $($bounds:tt)* ]: $Op:ident::$op_fn:ident for $T:ty, $S:ty) => { + impl<$($bounds)*> $Op <$S> for $T where $T: SpreadOps<$T>, Expr<$T::Join>: $Op { + type Output = as $Op>::Output; + fn $op_fn (self, other: $S) -> Self::Output { + as $Op>::$op_fn (Self::lift_self(self), Self::lift_other(other)) + } + } + } +} + +macro_rules! impl_num_spread { + ([ $($bounds:tt)* ]: $T:ty, $S:ty) => { + impl_spread_op!( [ $($bounds)* ]: Add::add for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Sub::sub for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Mul::mul for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Div::div for $T, $S); + impl_spread_op!( [ $($bounds)* ]: Rem::rem for $T, $S); + }; +} diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs new file mode 100644 index 0000000..ce49a28 --- /dev/null +++ b/luisa_compute/src/lang/ops/traits.rs @@ -0,0 +1,141 @@ +use super::*; + +pub trait MinMaxExpr { + type Output; + + fn max(self, other: T) -> Self::Output; + fn min(self, other: T) -> Self::Output; +} + +pub trait ClampExpr { + type Output; + + fn clamp(self, min: A, max: B) -> Self::Output; +} + +pub trait AbsExpr { + fn abs(&self) -> Self; +} + +pub trait EqExpr { + type Output; + + fn eq(self, other: T) -> Self::Output; + fn ne(self, other: T) -> Self::Output; +} + +pub trait CmpExpr: EqExpr { + fn lt(self, other: T) -> Self::Output; + fn le(self, other: T) -> Self::Output; + fn gt(self, other: T) -> Self::Output; + fn ge(self, other: T) -> Self::Output; +} + +pub trait IntExpr { + fn rotate_right(&self, n: Expr) -> Self; + fn rotate_left(&self, n: Expr) -> Self; +} + +pub trait FloatExpr { + type Bool; + + fn ceil(&self) -> Self; + fn floor(&self) -> Self; + fn round(&self) -> Self; + fn trunc(&self) -> Self; + fn sqrt(&self) -> Self; + fn rsqrt(&self) -> Self; + fn fract(&self) -> Self; + fn saturate(&self) -> Self; + fn sin(&self) -> Self; + fn cos(&self) -> Self; + fn tan(&self) -> Self; + fn asin(&self) -> Self; + fn acos(&self) -> Self; + fn atan(&self) -> Self; + fn sinh(&self) -> Self; + fn cosh(&self) -> Self; + fn tanh(&self) -> Self; + fn asinh(&self) -> Self; + fn acosh(&self) -> Self; + fn atanh(&self) -> Self; + fn exp(&self) -> Self; + fn exp2(&self) -> Self; + fn is_finite(&self) -> Self::Bool; + fn is_infinite(&self) -> Self::Bool; + fn is_nan(&self) -> Self::Bool; + fn ln(&self) -> Self; + fn log2(&self) -> Self; + fn log10(&self) -> Self; + fn sqr(&self) -> Self; + fn cube(&self) -> Self; + fn recip(&self) -> Self; + fn sin_cos(&self) -> (Self, Self); +} +pub trait FloatMulAddExpr { + type Output; + + fn mul_add(self, a: A, b: B) -> Self::Output; +} +pub trait FloatCopySignExpr { + type Output; + + fn copy_sign(self, sign: T) -> Self::Output; +} +pub trait FloatStepExpr { + type Output; + + fn step(self, edge: T) -> Self::Output; +} +pub trait FloatSmoothStepExpr { + type Output; + + fn smooth_step(self, edge0: T, edge1: S) -> Self::Output; +} +pub trait FloatArcTan2Expr { + type Output; + + fn atan2(self, other: T) -> Self::Output; +} +pub trait FloatLogExpr { + type Output; + + fn log(self, base: T) -> Self::Output; +} +pub trait FloatPowfExpr { + type Output; + + fn powf(self, exponent: T) -> Self::Output; +} +pub trait FloatPowiExpr { + type Output; + + fn powi(self, exponent: T) -> Self::Output; +} +pub trait FloatLerpExpr { + type Output; + + fn lerp(self, other: T, frac: S) -> Self::Output; +} + +pub trait StoreExpr { + fn store(self, value: V); +} + +pub trait SwitchExpr { + fn switch(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R; +} + +pub trait ActivateExpr { + fn activate(self, then: impl FnOnce()); +} + +pub trait LoopExpr { + fn while_loop(cond: impl FnMut() -> Self, body: impl FnMut()); +} + +pub trait LazyBoolExpr { + type Bool; + fn and(self, other: impl FnOnce() -> T) -> Self::Bool; + fn or(self, other: impl FnOnce() -> T) -> Self::Bool; +} diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 0386e64..c046030 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -177,6 +177,12 @@ impl Expr { } }) } + pub unsafe fn bitcast(self) -> Expr { + assert_eq!(std::mem::size_of::(), std::mem::size_of::()); + let ty = S::type_(); + let node = __current_scope(|s| s.bitcast(self.node(), ty)); + Expr::::from_node(node) + } } impl Var { diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index a48c6de..5bf0660 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -1,7 +1,23 @@ use super::*; -use std::ops::Deref; -pub(crate) trait Primitive: Copy + TypeOf + 'static { +mod private { + use super::*; + pub trait Sealed {} + impl Sealed for bool {} + impl Sealed for f16 {} + impl Sealed for f32 {} + impl Sealed for f64 {} + impl Sealed for i8 {} + impl Sealed for i16 {} + impl Sealed for i32 {} + impl Sealed for i64 {} + impl Sealed for u8 {} + impl Sealed for u16 {} + impl Sealed for u32 {} + impl Sealed for u64 {} +} + +pub trait Primitive: private::Sealed + Copy + TypeOf + 'static { fn const_(&self) -> Const; fn primitive(&self) -> ir::Primitive; } @@ -54,11 +70,11 @@ impl Primitive for f64 { } } -// impl Primitive for i8 { -// fn const_(&self) -> Const { -// Const::I8(*self) -// } -// } +impl Primitive for i8 { + fn const_(&self) -> Const { + todo!() // Const::I8(*self) + } +} impl Primitive for i16 { fn const_(&self) -> Const { Const::I16(*self) @@ -84,11 +100,11 @@ impl Primitive for i64 { } } -// impl Primitive for u8 { -// fn const_(&self) -> Const { -// Const::U8(*self) -// } -// } +impl Primitive for u8 { + fn const_(&self) -> Const { + todo!() // Const::U8(*self) + } +} impl Primitive for u16 { fn const_(&self) -> Const { Const::U16(*self) @@ -114,6 +130,44 @@ impl Primitive for u64 { } } +pub trait Integral: Primitive {} +impl Integral for bool {} +impl Integral for i8 {} +impl Integral for i16 {} +impl Integral for i32 {} +impl Integral for i64 {} +impl Integral for u8 {} +impl Integral for u16 {} +impl Integral for u32 {} +impl Integral for u64 {} + +pub trait Numeric: Primitive {} +impl Numeric for f16 {} +impl Numeric for f32 {} +impl Numeric for f64 {} +impl Numeric for i8 {} +impl Numeric for i16 {} +impl Numeric for i32 {} +impl Numeric for i64 {} +impl Numeric for u8 {} +impl Numeric for u16 {} +impl Numeric for u32 {} +impl Numeric for u64 {} + +pub trait Floating: Numeric {} +impl Floating for f16 {} +impl Floating for f32 {} +impl Floating for f64 {} + +pub trait Signed: Numeric {} +impl Signed for f16 {} +impl Signed for f32 {} +impl Signed for f64 {} +impl Signed for i8 {} +impl Signed for i16 {} +impl Signed for i32 {} +impl Signed for i64 {} + #[deprecated] pub type Bool = Expr; #[deprecated] diff --git a/luisa_compute/src/lang/types/dynamic.rs b/luisa_compute/src/lang/types/dynamic.rs index 1744779..73c49b6 100644 --- a/luisa_compute/src/lang/types/dynamic.rs +++ b/luisa_compute/src/lang/types/dynamic.rs @@ -204,7 +204,7 @@ impl DynVar { __current_scope(|b| b.update(self.node, value.node)); } pub fn zero() -> Self { - let v = local_zeroed::(); + let v = Var::::zeroed(); Self { node: v.node() } } } diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index 4a6318f..88f8563 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -13,8 +13,11 @@ mod nalgebra; pub mod coords; mod element; +pub mod swizzle; -trait VectorElement: Primitive { +use swizzle::*; + +pub trait VectorElement: Primitive { type A: Alignment; } @@ -33,7 +36,7 @@ pub struct Vector, const N: usize> { } impl, const N: usize> Vector { - pub fn new(elements: [T; N]) -> Self { + pub fn from_elements(elements: [T; N]) -> Self { Self { _align: T::A::default(), elements, @@ -45,30 +48,58 @@ impl, const N: usize> Vector { elements: [element; N], } } + pub fn splat_expr(element: impl AsExpr) -> Expr { + Func::Vec.call(element.as_expr()) + } + fn _permute2(&self, x: u32, y: u32) -> Vector + where + T: VectorElement<2>, + { + Vector::from_elements([self.elements[x as usize], self.elements[y as usize]]) + } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Vector + where + T: VectorElement<3>, + { + Vector::from_elements([ + self.elements[x as usize], + self.elements[y as usize], + self.elements[z as usize], + ]) + } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Vector + where + T: VectorElement<4>, + { + Vector::from_elements([ + self.elements[x as usize], + self.elements[y as usize], + self.elements[z as usize], + self.elements[w as usize], + ]) + } } #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct VectorExprData, const N: usize>([Expr; N]); -impl> FromNode for VectorExprData { +impl, const N: usize> FromNode for VectorExprData { fn from_node(node: NodeRef) -> Self { Self(std::array::from_fn(|i| { FromNode::from_node(__extract::(node, i)) })) } } - #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct VectorVarData, const N: usize>([Var; N]); -impl> FromNode for VectorVarData { +impl, const N: usize> FromNode for VectorVarData { fn from_node(node: NodeRef) -> Self { Self(std::array::from_fn(|i| { FromNode::from_node(__extract::(node, i)) })) } } - #[repr(C)] #[derive(Debug, Copy, Clone)] pub struct DoubledProxyData(X, X); @@ -78,6 +109,10 @@ impl FromNode for DoubledProxyData { } } +pub trait VectorExprProxy { + type T: Primitive; +} + macro_rules! vector_proxies { ($N:literal [ $($c:ident),* ]: $ExprName:ident, $VarName:ident) => { #[repr(C)] @@ -99,6 +134,9 @@ macro_rules! vector_proxies { impl> ExprProxy for $ExprName { type Value = Vector; } + impl> VectorExprProxy for $ExprName { + type T = T; + } impl> VarProxy for $VarName { type Value = Vector; } @@ -143,3 +181,164 @@ impl> Value for Vector { type ExprData = DoubledProxyData>; type VarData = DoubledProxyData>; } + +impl + VectorElement<3> + VectorElement<4>> Vec2Swizzle for Vector { + type Vec2 = Self; + type Vec3 = Vector; + type Vec4 = Vector; + fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { + self._permute2(x, y) + } + fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { + self._permute3(x, y, z) + } + fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { + self._permute4(x, y, z, w) + } +} +impl + VectorElement<3> + VectorElement<4>> Vec3Swizzle for Vector { + type Vec2 = Vector; + type Vec3 = Self; + type Vec4 = Vector; + fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { + self._permute2(x, y) + } + fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { + self._permute3(x, y, z) + } + fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { + self._permute4(x, y, z, w) + } +} +impl + VectorElement<3> + VectorElement<4>> Vec4Swizzle for Vector { + type Vec2 = Vector; + type Vec3 = Vector; + type Vec4 = Self; + fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { + self._permute2(x, y) + } + fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { + self._permute3(x, y, z) + } + fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { + self._permute4(x, y, z, w) + } +} + +impl, const N: usize> VectorExprData { + fn _permute2(&self, x: u32, y: u32) -> Expr> + where + T: VectorElement<2>, + { + assert!(x < N as u32); + assert!(y < N as u32); + let x = x.expr(); + let y = y.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node, ToNode::node(&x), ToNode::node(&y)], + Vector::::type_(), + ) + })) + } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> + where + T: VectorElement<3>, + { + assert!(x < N as u32); + assert!(y < N as u32); + assert!(z < N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[ + self.node, + ToNode::node(&x), + ToNode::node(&y), + ToNode::node(&z), + ], + Vector::::type_(), + ) + })) + } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> + where + T: VectorElement<4>, + { + assert!(x < N as u32); + assert!(y < N as u32); + assert!(z < N as u32); + assert!(w < N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + let w = w.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[ + self.node, + ToNode::node(&x), + ToNode::node(&y), + ToNode::node(&z), + ToNode::node(&w), + ], + Vector::::type_(), + ) + })) + } +} + +impl + VectorElement<3> + VectorElement<4>> Vec2Swizzle + for VectorExprProxy2 +{ + type Vec2 = Self; + type Vec3 = Expr>; + type Vec4 = Expr>; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) + } + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) + } + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) + } +} + +impl + VectorElement<3> + VectorElement<4>> Vec3Swizzle + for VectorExprProxy3 +{ + type Vec2 = Expr>; + type Vec3 = Self; + type Vec4 = Expr>; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) + } + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) + } + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) + } +} +impl + VectorElement<3> + VectorElement<4>> Vec4Swizzle + for VectorExprProxy4 +{ + type Vec2 = Expr>; + type Vec3 = Expr>; + type Vec4 = Self; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { + self._permute2(x, y) + } + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { + self._permute3(x, y, z) + } + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { + self._permute4(x, y, z, w) + } +} diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs index 72425e6..ff2f661 100644 --- a/luisa_compute/src/lang/types/vector/element.rs +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -1,5 +1,10 @@ use super::*; +// Stupid hack to make ops work. +impl VectorElement<1> for T { + type A = Align1; +} + macro_rules! element { ($t:ty [ $l:literal ]: $a: ident) => { impl VectorElement<$l> for $t { @@ -11,13 +16,12 @@ macro_rules! element { element!(bool[2]: Align2); element!(bool[3]: Align4); element!(bool[4]: Align4); -// TODO: Make u8 support ir::TypeOf. -// element!(u8[2]: Align2); -// element!(u8[3]: Align4); -// element!(u8[4]: Align4); -// element!(i8[2]: Align2); -// element!(i8[3]: Align4); -// element!(i8[4]: Align4); +element!(u8[2]: Align2); +element!(u8[3]: Align4); +element!(u8[4]: Align4); +element!(i8[2]: Align2); +element!(i8[3]: Align4); +element!(i8[4]: Align4); element!(f16[2]: Align4); element!(f16[3]: Align8); diff --git a/luisa_compute/src/lang/gen_swizzle.py b/luisa_compute/src/lang/types/vector/gen_swizzle.py similarity index 88% rename from luisa_compute/src/lang/gen_swizzle.py rename to luisa_compute/src/lang/types/vector/gen_swizzle.py index 891b1c3..f619f33 100644 --- a/luisa_compute/src/lang/gen_swizzle.py +++ b/luisa_compute/src/lang/types/vector/gen_swizzle.py @@ -22,9 +22,9 @@ def swizzle_name(perm: List[int]): s += ' type Vec2;\n' s += ' type Vec3;\n' s += ' type Vec4;\n' - s += ' fn permute2(&self, x: i32, y: i32) -> Self::Vec2;\n' - s += ' fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3;\n' - s += ' fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4;\n' + s += ' fn permute2(&self, x: u32, y: u32) -> Self::Vec2;\n' + s += ' fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3;\n' + s += ' fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4;\n' for sw in sw_m_to_n[(n, 2)]: s += ' fn {}(&self) -> Self::Vec2 {{\n'.format(swizzle_name(sw)) s += ' self.permute2({}, {})\n'.format(sw[0], sw[1]) diff --git a/luisa_compute/src/lang/swizzle.rs b/luisa_compute/src/lang/types/vector/swizzle.rs similarity index 98% rename from luisa_compute/src/lang/swizzle.rs rename to luisa_compute/src/lang/types/vector/swizzle.rs index 2696954..188d599 100644 --- a/luisa_compute/src/lang/swizzle.rs +++ b/luisa_compute/src/lang/types/vector/swizzle.rs @@ -2,9 +2,9 @@ pub trait Vec2Swizzle { type Vec2; type Vec3; type Vec4; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2; - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3; - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2; + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3; + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4; fn xx(&self) -> Self::Vec2 { self.permute2(0, 0) } @@ -94,9 +94,9 @@ pub trait Vec3Swizzle { type Vec2; type Vec3; type Vec4; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2; - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3; - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2; + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3; + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4; fn xx(&self) -> Self::Vec2 { self.permute2(0, 0) } @@ -453,9 +453,9 @@ pub trait Vec4Swizzle { type Vec2; type Vec3; type Vec4; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2; - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3; - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4; + fn permute2(&self, x: u32, y: u32) -> Self::Vec2; + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3; + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4; fn xx(&self) -> Self::Vec2 { self.permute2(0, 0) } diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 98d0138..0dd0b6b 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -22,7 +22,7 @@ pub mod prelude { pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; - pub use crate::lang::swizzle::*; + pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::Vector; pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; pub use crate::lang::Aggregate; @@ -45,8 +45,9 @@ mod internal_prelude { PhiIncoming, Pooled, Type, TypeOf, INVALID_REF, }; pub(crate) use crate::lang::{ - ir, Recorder, __compose, __extract, __insert, __module_pools, need_runtime_check, FromNode, - NodeLike, NodeRef, ToNode, __current_scope, __pop_scope, RECORDER, + ir, CallFuncTrait, Recorder, __compose, __extract, __insert, __module_pools, + need_runtime_check, FromNode, NodeLike, NodeRef, ToNode, __current_scope, __pop_scope, + RECORDER, }; pub(crate) use crate::prelude::*; pub(crate) use crate::runtime::{ diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index b8b284f..f4e8dab 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit b8b284f5c8b8ee4470298f362a1cd7f9a7c79698 +Subproject commit f4e8dabfd1c5d8ce0923b1f81f2dc8ef3ea8f68e diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 8d5e848..7b08f58 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -5,6 +5,10 @@ use syn::spanned::Spanned; use syn::visit_mut::*; use syn::*; +// TODO: Impl let mut -> let = .var() +// TODO: Impl x as f32 -> .cast() +// TOOD: Impl switch! macro. + #[cfg(test)] use pretty_assertions::assert_eq; From 86f7053e31e7e7d1cefd8867b042c1a6196d2244 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 12:29:48 +0100 Subject: [PATCH 07/15] De-constanted the Linear type. --- luisa_compute/src/lang/ops.rs | 24 +++--- luisa_compute/src/lang/ops/impls.rs | 62 +++++++-------- luisa_compute/src/lang/ops/spread.rs | 48 ++++++------ luisa_compute/src/lang/types/vector.rs | 75 +++++++++---------- .../src/lang/types/vector/element.rs | 4 +- 5 files changed, 106 insertions(+), 107 deletions(-) diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index 9e8b551..f1d6e73 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -2,7 +2,7 @@ use crate::internal_prelude::*; use std::ops::*; use super::types::core::{Floating, Integral, Numeric, Primitive, Signed}; -use super::types::vector::VectorElement; +use super::types::vector::{VectorAlign, VectorElement}; pub mod impls; pub mod spread; @@ -12,19 +12,21 @@ trait CastFrom: Primitive {} impl CastFrom for T {} impl CastFrom for T {} -// Hack because using an associated constant is not allowed within a trait bound -// without #![feature(generic_const_exprs)]. -pub trait Linear: Value { - type Scalar: VectorElement; - type WithScalar>: Linear; +pub trait Linear: Value { + // Note that without #![feature(generic_const_exprs)], I can't use this within + // the WithScalar restriction. As such, we can't support higher dimensional + // vector operations. If that ever becomes necessary, check commit + // 9e6eacf6b0c2b59a2646f45a727e4d82e84a46cd. + const N: usize; + type Scalar: VectorElement; + type WithScalar: Linear; // We don't actually know that the vector has equivalent vectors of every // primitive type. - type WithBool: Linear; } -impl Linear<1> for T { +impl Linear for T { + const N: usize = 1; type Scalar = T; type WithScalar = S; - type WithBool = bool; } macro_rules! impl_linear_vectors { ($t:ty) => { @@ -36,10 +38,10 @@ macro_rules! impl_linear_vectors { }; ($t:ty : $($n:literal),+) => { $( - impl Linear<$n> for Vector<$t, $n> { + impl Linear for Vector<$t, $n> { + const N: usize = $n; type Scalar = $t; type WithScalar = Vector; - type WithBool = Vector; } )+ } diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 916e13a..40c138d 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -1,14 +1,14 @@ use super::*; use traits::*; -impl> Expr { - fn as_>(self) -> Y +impl Expr { + fn as_>(self) -> Y where Y::Scalar: CastFrom, { Func::Cast.call(self) } - fn cast>(self) -> Expr> + fn cast(self) -> Expr> where S: CastFrom, { @@ -16,7 +16,7 @@ impl> Expr { } } -impl> MinMaxExpr for Expr +impl MinMaxExpr for Expr where X::Scalar: Numeric, { @@ -30,7 +30,7 @@ where } } -impl> ClampExpr for Expr +impl ClampExpr for Expr where X::Scalar: Numeric, { @@ -41,7 +41,7 @@ where } } -impl> AbsExpr for Expr +impl AbsExpr for Expr where X::Scalar: Signed, { @@ -50,7 +50,7 @@ where } } -impl> EqExpr for Expr { +impl EqExpr for Expr { type Output = Expr; fn eq(self, other: Self) -> Self::Output { Func::Eq.call2(self, other) @@ -59,7 +59,7 @@ impl> EqExpr for Expr { Func::Ne.call2(self, other) } } -impl> CmpExpr for Expr { +impl CmpExpr for Expr { fn lt(self, other: Self) -> Self::Output { Func::Lt.call2(self, other) } @@ -74,7 +74,7 @@ impl> CmpExpr for Expr { } } -impl> Add for Expr +impl Add for Expr where X::Scalar: Numeric, { @@ -83,7 +83,7 @@ where Func::Add.call2(self, other) } } -impl> Sub for Expr +impl Sub for Expr where X::Scalar: Numeric, { @@ -92,7 +92,7 @@ where Func::Sub.call2(self, other) } } -impl> Mul for Expr +impl Mul for Expr where X::Scalar: Numeric, { @@ -101,7 +101,7 @@ where Func::Mul.call2(self, other) } } -impl> Div for Expr +impl Div for Expr where X::Scalar: Numeric, { @@ -110,7 +110,7 @@ where Func::Div.call2(self, other) } } -impl> Rem for Expr +impl Rem for Expr where X::Scalar: Numeric, { @@ -120,7 +120,7 @@ where } } -impl> BitAnd for Expr +impl BitAnd for Expr where X::Scalar: Integral, { @@ -129,7 +129,7 @@ where Func::BitAnd.call2(self, other) } } -impl> BitOr for Expr +impl BitOr for Expr where X::Scalar: Integral, { @@ -138,7 +138,7 @@ where Func::BitOr.call2(self, other) } } -impl> BitXor for Expr +impl BitXor for Expr where X::Scalar: Integral, { @@ -147,7 +147,7 @@ where Func::BitXor.call2(self, other) } } -impl> Shl for Expr +impl Shl for Expr where X::Scalar: Integral, { @@ -156,7 +156,7 @@ where Func::Shl.call2(self, other) } } -impl> Shr for Expr +impl Shr for Expr where X::Scalar: Integral, { @@ -166,7 +166,7 @@ where } } -impl> Neg for Expr +impl Neg for Expr where X::Scalar: Signed, { @@ -175,7 +175,7 @@ where Func::Neg.call(self) } } -impl> Not for Expr +impl Not for Expr where X::Scalar: Integral, { @@ -185,7 +185,7 @@ where } } -impl> IntExpr for Expr +impl IntExpr for Expr where X::Scalar: Integral + Numeric, { @@ -205,7 +205,7 @@ macro_rules! impl_simple_fns { )+}; } -impl> FloatExpr for Expr +impl FloatExpr for Expr where X::Scalar: Floating, { @@ -255,7 +255,7 @@ where (self.sin(), self.cos()) } } -impl> FloatMulAddExpr for Expr +impl FloatMulAddExpr for Expr where X::Scalar: Floating, { @@ -265,7 +265,7 @@ where Func::Fma.call3(self, a, b) } } -impl> FloatCopySignExpr for Expr +impl FloatCopySignExpr for Expr where X::Scalar: Floating, { @@ -275,7 +275,7 @@ where Func::Copysign.call2(self, sign) } } -impl> FloatStepExpr for Expr +impl FloatStepExpr for Expr where X::Scalar: Floating, { @@ -285,7 +285,7 @@ where Func::Step.call2(edge, self) } } -impl> FloatSmoothStepExpr for Expr +impl FloatSmoothStepExpr for Expr where X::Scalar: Floating, { @@ -295,7 +295,7 @@ where Func::SmoothStep.call3(edge0, edge1, self) } } -impl> FloatArcTan2Expr for Expr +impl FloatArcTan2Expr for Expr where X::Scalar: Floating, { @@ -305,7 +305,7 @@ where Func::Atan2.call2(self, other) } } -impl> FloatLogExpr for Expr +impl FloatLogExpr for Expr where X::Scalar: Floating, { @@ -315,7 +315,7 @@ where self.ln() / base.ln() } } -impl> FloatPowfExpr for Expr +impl FloatPowfExpr for Expr where X::Scalar: Floating, { @@ -325,7 +325,7 @@ where Func::Powf.call2(self, exponent) } } -impl, Y: Linear> FloatPowiExpr> for Expr +impl> FloatPowiExpr> for Expr where X::Scalar: Floating, { @@ -335,7 +335,7 @@ where Func::Powi.call2(self, exponent) } } -impl> FloatLerpExpr for Expr +impl FloatLerpExpr for Expr where X::Scalar: Floating, { diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index 98a0db3..3fc3bda 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -42,36 +42,36 @@ macro_rules! impl_spread { } }; } -impl_spread!([T: Value] T: |x| x.expr(), Expr: |x| x => Expr); -impl_spread!([T: Value] &T: |x| x.expr(), Expr: |x| x => Expr); -impl_spread!([T: Value] T: |x| x.expr(), &Expr: |x| x.clone() => Expr); -impl_spread!([T: Value] &T: |x| x.expr(), &Expr: |x| x.clone() => Expr); +impl_spread!([T: Linear] T: |x| x.expr(), Expr: |x| x => Expr); +impl_spread!([T: Linear] &T: |x| x.expr(), Expr: |x| x => Expr); +impl_spread!([T: Linear] T: |x| x.expr(), &Expr: |x| x.clone() => Expr); +impl_spread!([T: Linear] &T: |x| x.expr(), &Expr: |x| x.clone() => Expr); -impl_spread!([T: Value] Expr: |x| x, &Expr: |x| x.clone() => Expr); -impl_spread!([T: Value] &Expr: |x| x.clone(), &Expr: |x| x.clone() => Expr); +impl_spread!([T: Linear] Expr: |x| x, &Expr: |x| x.clone() => Expr); +impl_spread!([T: Linear] &Expr: |x| x.clone(), &Expr: |x| x.clone() => Expr); -impl_spread!([T: Value] T: |x| x.expr(), Var: |x| x.load() => Expr); -impl_spread!([T: Value] &T: |x| x.expr(), Var: |x| x.load() => Expr); -impl_spread!([T: Value] T: |x| x.expr(), &Var: |x| x.load() => Expr); -impl_spread!([T: Value] &T: |x| x.expr(), &Var: |x| x.load() => Expr); +impl_spread!([T: Linear] T: |x| x.expr(), Var: |x| x.load() => Expr); +impl_spread!([T: Linear] &T: |x| x.expr(), Var: |x| x.load() => Expr); +impl_spread!([T: Linear] T: |x| x.expr(), &Var: |x| x.load() => Expr); +impl_spread!([T: Linear] &T: |x| x.expr(), &Var: |x| x.load() => Expr); // Other way is unneded because of the deref impl. -impl_spread_single!([T: Value] &Expr: |x| x.clone(), Var: |x| x.load() => Expr); -impl_spread_single!([T: Value] &Expr: |x| x.clone(), &Var: |x| x.load() => Expr); +impl_spread_single!([T: Linear] &Expr: |x| x.clone(), Var: |x| x.load() => Expr); +impl_spread_single!([T: Linear] &Expr: |x| x.clone(), &Var: |x| x.load() => Expr); -impl_spread!([const N: usize, T: VectorElement] T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorElement] &T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorElement] T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorElement] &T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorAlign] &T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); +impl_spread!([const N: usize, T: VectorAlign] T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] &T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); -impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); -impl_spread!([const N: usize, T: VectorElement] Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); -impl_spread!([const N: usize, T: VectorElement] &Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); +impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); mod impls { use super::*; diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index 88f8563..a2167bc 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -17,11 +17,14 @@ pub mod swizzle; use swizzle::*; -pub trait VectorElement: Primitive { +pub trait VectorElement: VectorAlign<2> + VectorAlign<3> + VectorAlign<4> {} +impl + VectorAlign<3> + VectorAlign<4>> VectorElement for T {} + +pub trait VectorAlign: Primitive { type A: Alignment; } -impl, const N: usize> Debug for Vector { +impl, const N: usize> Debug for Vector { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.elements.fmt(f) } @@ -29,13 +32,13 @@ impl, const N: usize> Debug for Vector { #[repr(C)] #[derive(Copy, Clone, Hash, Default, PartialEq, Eq, Serialize, Deserialize)] -pub struct Vector, const N: usize> { +pub struct Vector, const N: usize> { #[serde(skip)] _align: T::A, elements: [T; N], } -impl, const N: usize> Vector { +impl, const N: usize> Vector { pub fn from_elements(elements: [T; N]) -> Self { Self { _align: T::A::default(), @@ -53,13 +56,13 @@ impl, const N: usize> Vector { } fn _permute2(&self, x: u32, y: u32) -> Vector where - T: VectorElement<2>, + T: VectorAlign<2>, { Vector::from_elements([self.elements[x as usize], self.elements[y as usize]]) } fn _permute3(&self, x: u32, y: u32, z: u32) -> Vector where - T: VectorElement<3>, + T: VectorAlign<3>, { Vector::from_elements([ self.elements[x as usize], @@ -69,7 +72,7 @@ impl, const N: usize> Vector { } fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Vector where - T: VectorElement<4>, + T: VectorAlign<4>, { Vector::from_elements([ self.elements[x as usize], @@ -82,8 +85,8 @@ impl, const N: usize> Vector { #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct VectorExprData, const N: usize>([Expr; N]); -impl, const N: usize> FromNode for VectorExprData { +pub struct VectorExprData, const N: usize>([Expr; N]); +impl, const N: usize> FromNode for VectorExprData { fn from_node(node: NodeRef) -> Self { Self(std::array::from_fn(|i| { FromNode::from_node(__extract::(node, i)) @@ -92,8 +95,8 @@ impl, const N: usize> FromNode for VectorExprData { } #[repr(C)] #[derive(Debug, Copy, Clone)] -pub struct VectorVarData, const N: usize>([Var; N]); -impl, const N: usize> FromNode for VectorVarData { +pub struct VectorVarData, const N: usize>([Var; N]); +impl, const N: usize> FromNode for VectorVarData { fn from_node(node: NodeRef) -> Self { Self(std::array::from_fn(|i| { FromNode::from_node(__extract::(node, i)) @@ -117,30 +120,30 @@ macro_rules! vector_proxies { ($N:literal [ $($c:ident),* ]: $ExprName:ident, $VarName:ident) => { #[repr(C)] #[derive(Debug, Copy, Clone)] - pub struct $ExprName> { + pub struct $ExprName> { _node: NodeRef, $(pub $c: Expr),* } #[repr(C)] #[derive(Debug, Copy, Clone)] - pub struct $VarName> { + pub struct $VarName> { _node: NodeRef, $(pub $c: Var),* } - unsafe impl> HasExprLayout< as Value>::ExprData> for $ExprName {} - unsafe impl> HasVarLayout< as Value>::VarData> for $VarName {} + unsafe impl> HasExprLayout< as Value>::ExprData> for $ExprName {} + unsafe impl> HasVarLayout< as Value>::VarData> for $VarName {} - impl> ExprProxy for $ExprName { + impl> ExprProxy for $ExprName { type Value = Vector; } - impl> VectorExprProxy for $ExprName { + impl> VectorExprProxy for $ExprName { type T = T; } - impl> VarProxy for $VarName { + impl> VarProxy for $VarName { type Value = Vector; } - impl> Deref for $VarName { + impl> Deref for $VarName { type Target = Expr>; fn deref(&self) -> &Self::Target { _deref_proxy(self) @@ -153,7 +156,7 @@ vector_proxies!(2 [x, y]: VectorExprProxy2, VectorVarProxy2); vector_proxies!(3 [x, y, z, r, g, b]: VectorExprProxy3, VectorVarProxy3); vector_proxies!(4 [x, y, z, w, r, g, b, a]: VectorExprProxy4, VectorVarProxy4); -impl, const N: usize> TypeOf for Vector { +impl, const N: usize> TypeOf for Vector { fn type_() -> CArc { let type_ = Type::Vector(VectorType { element: VectorElementType::Scalar(T::type_()), @@ -163,26 +166,26 @@ impl, const N: usize> TypeOf for Vector { } } -impl> Value for Vector { +impl> Value for Vector { type Expr = VectorExprProxy2; type Var = VectorVarProxy2; type ExprData = VectorExprData; type VarData = VectorVarData; } -impl> Value for Vector { +impl> Value for Vector { type Expr = VectorExprProxy3; type Var = VectorVarProxy3; type ExprData = DoubledProxyData>; type VarData = DoubledProxyData>; } -impl> Value for Vector { +impl> Value for Vector { type Expr = VectorExprProxy4; type Var = VectorVarProxy4; type ExprData = DoubledProxyData>; type VarData = DoubledProxyData>; } -impl + VectorElement<3> + VectorElement<4>> Vec2Swizzle for Vector { +impl Vec2Swizzle for Vector { type Vec2 = Self; type Vec3 = Vector; type Vec4 = Vector; @@ -196,7 +199,7 @@ impl + VectorElement<3> + VectorElement<4>> Vec2Swizzle for self._permute4(x, y, z, w) } } -impl + VectorElement<3> + VectorElement<4>> Vec3Swizzle for Vector { +impl Vec3Swizzle for Vector { type Vec2 = Vector; type Vec3 = Self; type Vec4 = Vector; @@ -210,7 +213,7 @@ impl + VectorElement<3> + VectorElement<4>> Vec3Swizzle for self._permute4(x, y, z, w) } } -impl + VectorElement<3> + VectorElement<4>> Vec4Swizzle for Vector { +impl Vec4Swizzle for Vector { type Vec2 = Vector; type Vec3 = Vector; type Vec4 = Self; @@ -225,10 +228,10 @@ impl + VectorElement<3> + VectorElement<4>> Vec4Swizzle for } } -impl, const N: usize> VectorExprData { +impl, const N: usize> VectorExprData { fn _permute2(&self, x: u32, y: u32) -> Expr> where - T: VectorElement<2>, + T: VectorAlign<2>, { assert!(x < N as u32); assert!(y < N as u32); @@ -244,7 +247,7 @@ impl, const N: usize> VectorExprData { } fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> where - T: VectorElement<3>, + T: VectorAlign<3>, { assert!(x < N as u32); assert!(y < N as u32); @@ -267,7 +270,7 @@ impl, const N: usize> VectorExprData { } fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> where - T: VectorElement<4>, + T: VectorAlign<4>, { assert!(x < N as u32); assert!(y < N as u32); @@ -293,9 +296,7 @@ impl, const N: usize> VectorExprData { } } -impl + VectorElement<3> + VectorElement<4>> Vec2Swizzle - for VectorExprProxy2 -{ +impl Vec2Swizzle for VectorExprProxy2 { type Vec2 = Self; type Vec3 = Expr>; type Vec4 = Expr>; @@ -310,9 +311,7 @@ impl + VectorElement<3> + VectorElement<4>> Vec2Swizzle } } -impl + VectorElement<3> + VectorElement<4>> Vec3Swizzle - for VectorExprProxy3 -{ +impl Vec3Swizzle for VectorExprProxy3 { type Vec2 = Expr>; type Vec3 = Self; type Vec4 = Expr>; @@ -326,9 +325,7 @@ impl + VectorElement<3> + VectorElement<4>> Vec3Swizzle self._permute4(x, y, z, w) } } -impl + VectorElement<3> + VectorElement<4>> Vec4Swizzle - for VectorExprProxy4 -{ +impl Vec4Swizzle for VectorExprProxy4 { type Vec2 = Expr>; type Vec3 = Expr>; type Vec4 = Self; diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs index ff2f661..cf53277 100644 --- a/luisa_compute/src/lang/types/vector/element.rs +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -1,13 +1,13 @@ use super::*; // Stupid hack to make ops work. -impl VectorElement<1> for T { +impl VectorAlign<1> for T { type A = Align1; } macro_rules! element { ($t:ty [ $l:literal ]: $a: ident) => { - impl VectorElement<$l> for $t { + impl VectorAlign<$l> for $t { type A = $a; } }; From 8fa839575d125d49783ecac019d4f243c16590b8 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 13:01:22 +0100 Subject: [PATCH 08/15] Commented out irrelevant things for debugging. --- luisa_compute/src/lang.rs | 16 ++++++------- luisa_compute/src/lang/types.rs | 2 +- luisa_compute/src/lib.rs | 40 +++++++++++++++++---------------- 3 files changed, 30 insertions(+), 28 deletions(-) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 6799cb2..3932d41 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -10,7 +10,7 @@ use crate::internal_prelude::*; use bumpalo::Bump; use indexmap::IndexMap; -use crate::runtime::WeakDevice; +// use crate::runtime::WeakDevice; pub mod ir { pub use luisa_compute_ir::context::register_type; @@ -24,14 +24,14 @@ use ir::{ Instruction, IrBuilder, ModulePools, Pooled, Type, TypeOf, UserNodeData, }; -pub mod control_flow; +// pub mod control_flow; pub mod debug; -pub mod diff; -pub mod functions; +// pub mod diff; +// pub mod functions; pub mod index; // pub mod maybe_expr; pub mod ops; -pub mod poly; +// pub mod poly; pub mod types; pub(crate) trait CallFuncTrait { @@ -235,7 +235,7 @@ pub(crate) struct Recorder { pub(crate) cpu_custom_ops: IndexMap)>, pub(crate) callables: IndexMap, pub(crate) shared: Vec, - pub(crate) device: Option, + // pub(crate) device: Option, pub(crate) block_size: Option<[u32; 3]>, pub(crate) building_kernel: bool, pub(crate) pools: Option>, @@ -250,7 +250,7 @@ impl Recorder { self.cpu_custom_ops.clear(); self.callables.clear(); self.lock = false; - self.device = None; + // self.device = None; self.block_size = None; self.arena.reset(); self.shared.clear(); @@ -265,7 +265,7 @@ impl Recorder { cpu_custom_ops: IndexMap::new(), callables: IndexMap::new(), shared: vec![], - device: None, + // device: None, block_size: None, pools: None, arena: Bump::new(), diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index c046030..38195fb 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -5,7 +5,7 @@ use crate::internal_prelude::*; pub mod alignment; pub mod array; pub mod core; -pub mod dynamic; +// pub mod dynamic; pub mod shared; pub mod vector; diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 0dd0b6b..549d51c 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -8,30 +8,31 @@ use std::path::Path; use std::sync::Arc; pub mod lang; -pub mod printer; -pub mod resource; -pub mod rtx; -pub mod runtime; +// pub mod printer; +// pub mod resource; +// pub mod rtx; +// pub mod runtime; pub mod prelude { pub use half::f16; - pub use crate::lang::control_flow::{ - break_, continue_, for_range, return_, return_v, select, switch, - }; - pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; + // pub use crate::lang::control_flow::{ + // break_, continue_, for_range, return_, return_v, select, switch, + // }; + // pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, + // set_block_size}; pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::Vector; pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; pub use crate::lang::Aggregate; - pub use crate::resource::{IoTexel, StorageTexel, *}; - pub use crate::runtime::api::StreamTag; - pub use crate::runtime::{ - create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, - }; - pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, struct_, while_, Context}; + // pub use crate::resource::{IoTexel, StorageTexel, *}; + // pub use crate::runtime::api::StreamTag; + // pub use crate::runtime::{ + // create_static_callable, Command, Device, KernelBuildOptions, Scope, + // Stream, }; + pub use crate::{cpu_dbg, lc_assert, lc_unreachable, struct_}; pub use luisa_compute_derive::*; pub use luisa_compute_track::track; @@ -50,11 +51,11 @@ mod internal_prelude { RECORDER, }; pub(crate) use crate::prelude::*; - pub(crate) use crate::runtime::{ - CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, - }; + // pub(crate) use crate::runtime::{ + // CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, + // }; pub(crate) use crate::{ - get_backtrace, impl_simple_expr_proxy, impl_simple_var_proxy, ResourceTracker, + get_backtrace, impl_simple_expr_proxy, impl_simple_var_proxy, /* ResourceTracker, */ }; pub(crate) use luisa_compute_backend::Backend; pub(crate) use std::marker::PhantomData; @@ -69,10 +70,11 @@ use lazy_static::lazy_static; use luisa_compute_backend::Backend; use parking_lot::lock_api::RawMutex as RawMutexTrait; use parking_lot::{Mutex, RawMutex}; -use runtime::{Device, DeviceHandle, StreamHandle}; +// use runtime::{Device, DeviceHandle, StreamHandle}; use std::collections::HashMap; use std::sync::Weak; +/* pub struct Context { inner: Arc, } From 5a5501ae80f3d9f9f634473467ce968c2d83868f Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 13:01:30 +0100 Subject: [PATCH 09/15] Bugfixes. --- luisa_compute/src/lang/ops.rs | 37 +++++++++---------- luisa_compute/src/lang/ops/impls.rs | 4 +- luisa_compute/src/lang/ops/spread.rs | 37 ++++++++++++------- luisa_compute/src/lang/ops/traits.rs | 2 +- luisa_compute/src/lang/types/vector.rs | 31 ++++++---------- luisa_compute/src/lang/types/vector/coords.rs | 4 +- .../src/lang/types/vector/element.rs | 33 +++++++++++++---- luisa_compute/src/lib.rs | 1 + luisa_compute_sys/LuisaCompute | 2 +- 9 files changed, 84 insertions(+), 67 deletions(-) diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index f1d6e73..b99dae9 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -23,27 +23,24 @@ pub trait Linear: Value { // We don't actually know that the vector has equivalent vectors of every // primitive type. } -impl Linear for T { +impl Linear for T { const N: usize = 1; type Scalar = T; - type WithScalar = S; + type WithScalar = S; } -macro_rules! impl_linear_vectors { - ($t:ty) => { - impl_linear_vectors!($t: 2, 3, 4); - }; - ($t:ty, $($ts:ty),+) => { - impl_linear_vectors!($t); - impl_linear_vectors!($($ts),+); - }; - ($t:ty : $($n:literal),+) => { - $( - impl Linear for Vector<$t, $n> { - const N: usize = $n; - type Scalar = $t; - type WithScalar = Vector; - } - )+ - } + +impl Linear for Vector { + const N: usize = 2; + type Scalar = T; + type WithScalar = Vector; +} +impl Linear for Vector { + const N: usize = 3; + type Scalar = T; + type WithScalar = Vector; +} +impl Linear for Vector { + const N: usize = 4; + type Scalar = T; + type WithScalar = Vector; } -impl_linear_vectors!(bool, f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 40c138d..4342680 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -51,7 +51,7 @@ where } impl EqExpr for Expr { - type Output = Expr; + type Output = Expr>; fn eq(self, other: Self) -> Self::Output { Func::Eq.call2(self, other) } @@ -209,7 +209,7 @@ impl FloatExpr for Expr where X::Scalar: Floating, { - type Bool = Self::WithBool; + type Bool = Expr>; impl_simple_fns! { ceil => Ceil, floor => Floor, diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index 3fc3bda..b52daba 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -2,7 +2,7 @@ use super::*; use traits::*; trait SpreadOps { - type Join; + type Join: Value; fn lift_self(x: Self) -> Expr; fn lift_other(x: Other) -> Expr; } @@ -48,7 +48,7 @@ impl_spread!([T: Linear] T: |x| x.expr(), &Expr: |x| x.clone() => Expr); impl_spread!([T: Linear] &T: |x| x.expr(), &Expr: |x| x.clone() => Expr); impl_spread!([T: Linear] Expr: |x| x, &Expr: |x| x.clone() => Expr); -impl_spread!([T: Linear] &Expr: |x| x.clone(), &Expr: |x| x.clone() => Expr); +impl_spread_single!([T: Linear] &Expr: |x| x.clone(), &Expr: |x| x.clone() => Expr); impl_spread!([T: Linear] T: |x| x.expr(), Var: |x| x.load() => Expr); impl_spread!([T: Linear] &T: |x| x.expr(), Var: |x| x.load() => Expr); @@ -80,7 +80,7 @@ mod impls { T: SpreadOps, Expr: MinMaxExpr, { - type Output = Expr::Output; + type Output = as MinMaxExpr>::Output; fn max(self, other: S) -> Self::Output { Expr::::max(Self::lift_self(self), Self::lift_other(other)) } @@ -106,7 +106,7 @@ mod impls { /// / / \ /// T S U - type Output = Expr::Output; + type Output = as ClampExpr>::Output; fn clamp(&self, min: S, max: U) -> Self::Output { Expr::::clamp( Self::lift_self(self), @@ -120,7 +120,7 @@ mod impls { T: SpreadOps, Expr: EqExpr, { - type Output = Expr::Output; + type Output = as EqExpr>::Output; fn eq(self, other: S) -> Self::Output { Expr::::eq(Self::lift_self(self), Self::lift_other(other)) } @@ -152,7 +152,7 @@ mod impls { T: SpreadOps, Expr: FloatMulAddExpr, { - type Output = Expr::Output; + type Output = as FloatMulAddExpr>::Output; fn mul_add(self, mul: S, add: U) -> Self::Output { Expr::::mul_add( Self::lift_self(self), @@ -166,7 +166,7 @@ mod impls { T: SpreadOps, Expr: FloatCopySignExpr, { - type Output = Expr::Output; + type Output = as FloatCopySignExpr>::Output; fn copy_sign(self, sign: S) -> Self::Output { Expr::::copy_sign(Self::lift_self(self), Self::lift_other(sign)) } @@ -176,7 +176,7 @@ mod impls { T: SpreadOps, Expr: FloatStepExpr, { - type Output = Expr::Output; + type Output = as FloatStepExpr>::Output; fn step(self, edge: S) -> Self::Output { Expr::::step(Self::lift_self(self), Self::lift_other(edge)) } @@ -187,7 +187,7 @@ mod impls { T: SpreadOps, Expr: FloatSmoothStepExpr, { - type Output = Expr::Output; + type Output = as FloatSmoothStepExpr>::Output; fn smooth_step(self, edge0: S, edge1: U) -> Self::Output { Expr::::smooth_step( Self::lift_self(self), @@ -201,7 +201,7 @@ mod impls { T: SpreadOps, Expr: FloatArcTan2Expr, { - type Output = Expr::Output; + type Output = as FloatArcTan2Expr>::Output; fn atan2(self, other: S) -> Self::Output { Expr::::atan2(Self::lift_self(self), Self::lift_other(other)) } @@ -211,7 +211,7 @@ mod impls { T: SpreadOps, Expr: FloatLogExpr, { - type Output = Expr::Output; + type Output = as FloatLogExpr>::Output; fn log(self, base: S) -> Self::Output { Expr::::log(Self::lift_self(self), Self::lift_other(base)) } @@ -221,7 +221,7 @@ mod impls { T: SpreadOps, Expr: FloatPowfExpr, { - type Output = Expr::Output; + type Output = as FloatPowfExpr>::Output; fn powf(self, exponent: S) -> Self::Output { Expr::::powf(Self::lift_self(self), Self::lift_other(exponent)) } @@ -232,7 +232,7 @@ mod impls { T: SpreadOps, Expr: FloatLerpExpr, { - type Output = Expr::Output; + type Output = as FloatLerpExpr>::Output; fn lerp(self, other: S, frac: U) -> Self::Output { Expr::::lerp( Self::lift_self(self), @@ -262,3 +262,14 @@ macro_rules! impl_num_spread { impl_spread_op!( [ $($bounds)* ]: Rem::rem for $T, $S); }; } + +mod tests { + fn test() { + let x = 10.0f32; + let y = 20.0f32; + let z = x.min(y); + + let w = x.expr().min(y); + println!("{:?}", w); + } +} diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index ce49a28..68ec7e7 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -36,7 +36,7 @@ pub trait IntExpr { fn rotate_left(&self, n: Expr) -> Self; } -pub trait FloatExpr { +pub trait FloatExpr: Sized { type Bool; fn ceil(&self) -> Self; diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index a2167bc..ae86b85 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -1,10 +1,9 @@ use super::alignment::*; use super::core::*; use super::*; -use ir::{MatrixType, VectorElementType, VectorType}; +use ir::{VectorElementType, VectorType}; use serde::{Deserialize, Serialize}; use std::fmt::Debug; -use std::ops::Mul; #[cfg(feature = "glam")] mod glam; @@ -22,6 +21,10 @@ impl + VectorAlign<3> + VectorAlign<4>> VectorElement for T {} pub trait VectorAlign: Primitive { type A: Alignment; + type VectorExpr: ExprProxy>; + type VectorVar: VarProxy>; + type VectorExprData: Clone + FromNode + 'static; + type VectorVarData: Clone + FromNode + 'static; } impl, const N: usize> Debug for Vector { @@ -51,7 +54,7 @@ impl, const N: usize> Vector { elements: [element; N], } } - pub fn splat_expr(element: impl AsExpr) -> Expr { + pub fn splat_expr(element: impl AsExpr) -> Expr { Func::Vec.call(element.as_expr()) } fn _permute2(&self, x: u32, y: u32) -> Vector @@ -166,23 +169,11 @@ impl, const N: usize> TypeOf for Vector { } } -impl> Value for Vector { - type Expr = VectorExprProxy2; - type Var = VectorVarProxy2; - type ExprData = VectorExprData; - type VarData = VectorVarData; -} -impl> Value for Vector { - type Expr = VectorExprProxy3; - type Var = VectorVarProxy3; - type ExprData = DoubledProxyData>; - type VarData = DoubledProxyData>; -} -impl> Value for Vector { - type Expr = VectorExprProxy4; - type Var = VectorVarProxy4; - type ExprData = DoubledProxyData>; - type VarData = DoubledProxyData>; +impl, const N: usize> Value for Vector { + type Expr = T::VectorExpr; + type Var = T::VectorVar; + type ExprData = T::VectorExprData; + type VarData = T::VectorVarData; } impl Vec2Swizzle for Vector { diff --git a/luisa_compute/src/lang/types/vector/coords.rs b/luisa_compute/src/lang/types/vector/coords.rs index 5b11642..583122c 100644 --- a/luisa_compute/src/lang/types/vector/coords.rs +++ b/luisa_compute/src/lang/types/vector/coords.rs @@ -12,7 +12,7 @@ macro_rules! impl_coords { } macro_rules! impl_deref { ($T:ident; $N:literal) => { - impl> Deref for Vector { + impl> Deref for Vector { type Target = $T; #[inline] @@ -21,7 +21,7 @@ macro_rules! impl_deref { } } - impl> DerefMut for Vector { + impl> DerefMut for Vector { #[inline] fn deref_mut(&self) -> &$T { unsafe { &*(self as *const Self as *const $T) } diff --git a/luisa_compute/src/lang/types/vector/element.rs b/luisa_compute/src/lang/types/vector/element.rs index cf53277..8583536 100644 --- a/luisa_compute/src/lang/types/vector/element.rs +++ b/luisa_compute/src/lang/types/vector/element.rs @@ -1,14 +1,31 @@ use super::*; -// Stupid hack to make ops work. -impl VectorAlign<1> for T { - type A = Align1; -} - macro_rules! element { - ($t:ty [ $l:literal ]: $a: ident) => { - impl VectorAlign<$l> for $t { - type A = $a; + ($T:ty [ 2 ]: $A: ident) => { + impl VectorAlign<2> for $T { + type A = $A; + type VectorExpr = VectorExprProxy2<$T>; + type VectorVar = VectorVarProxy2<$T>; + type VectorExprData = VectorExprData<$T, 2>; + type VectorVarData = VectorVarData<$T, 2>; + } + }; + ($T:ty [ 3 ]: $A: ident) => { + impl VectorAlign<3> for $T { + type A = $A; + type VectorExpr = VectorExprProxy3<$T>; + type VectorVar = VectorVarProxy3<$T>; + type VectorExprData = DoubledProxyData>; + type VectorVarData = DoubledProxyData>; + } + }; + ($T:ty [ 4 ]: $A: ident) => { + impl VectorAlign<4> for $T { + type A = $A; + type VectorExpr = VectorExprProxy4<$T>; + type VectorVar = VectorVarProxy4<$T>; + type VectorExprData = DoubledProxyData>; + type VectorVarData = DoubledProxyData>; } }; } diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 549d51c..0308cd7 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -163,6 +163,7 @@ unsafe impl Send for ResourceTracker {} unsafe impl Sync for ResourceTracker {} + */ pub(crate) fn get_backtrace() -> Backtrace { Backtrace::force_capture() } diff --git a/luisa_compute_sys/LuisaCompute b/luisa_compute_sys/LuisaCompute index f4e8dab..517ae49 160000 --- a/luisa_compute_sys/LuisaCompute +++ b/luisa_compute_sys/LuisaCompute @@ -1 +1 @@ -Subproject commit f4e8dabfd1c5d8ce0923b1f81f2dc8ef3ea8f68e +Subproject commit 517ae49e84b6d255739d92d28590588c9de4ee56 From 9ed99cbc8ef77bbc05664b7e74e7747a1db614d3 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 20:01:03 +0100 Subject: [PATCH 10/15] Initial working ops version. --- luisa_compute/src/lang.rs | 26 +- luisa_compute/src/lang/control_flow.rs | 31 +- luisa_compute/src/lang/debug.rs | 7 +- luisa_compute/src/lang/index.rs | 2 +- luisa_compute/src/lang/maybe_expr.rs | 334 --------------- luisa_compute/src/lang/ops.rs | 2 + luisa_compute/src/lang/ops/impls.rs | 390 +++++++++++++----- luisa_compute/src/lang/ops/spread.rs | 249 +++++++---- luisa_compute/src/lang/ops/traits.rs | 171 +++++--- luisa_compute/src/lang/types.rs | 18 +- luisa_compute/src/lang/types/array.rs | 10 +- luisa_compute/src/lang/types/core.rs | 104 ++--- luisa_compute/src/lang/types/shared.rs | 2 +- luisa_compute/src/lang/types/vector.rs | 33 +- luisa_compute/src/lang/types/vector/coords.rs | 12 +- luisa_compute/src/lib.rs | 4 +- luisa_compute_track/src/lib.rs | 14 +- 17 files changed, 683 insertions(+), 726 deletions(-) delete mode 100644 luisa_compute/src/lang/maybe_expr.rs diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 3932d41..89f58c4 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -24,12 +24,11 @@ use ir::{ Instruction, IrBuilder, ModulePools, Pooled, Type, TypeOf, UserNodeData, }; -// pub mod control_flow; +pub mod control_flow; pub mod debug; // pub mod diff; // pub mod functions; pub mod index; -// pub mod maybe_expr; pub mod ops; // pub mod poly; pub mod types; @@ -43,6 +42,9 @@ pub(crate) trait CallFuncTrait { y: Expr, z: Expr, ) -> Expr; + fn call_void(self, x: Expr); + fn call2_void(self, x: Expr, y: Expr); + fn call3_void(self, x: Expr, y: Expr, z: Expr); } impl CallFuncTrait for Func { fn call(self, x: Expr) -> Expr { @@ -69,6 +71,21 @@ impl CallFuncTrait for Func { ) })) } + fn call_void(self, x: Expr) { + __current_scope(|b| { + b.call(self, &[x.node()], Type::void()); + }); + } + fn call2_void(self, x: Expr, y: Expr) { + __current_scope(|b| { + b.call(self, &[x.node(), y.node()], Type::void()); + }); + } + fn call3_void(self, x: Expr, y: Expr, z: Expr) { + __current_scope(|b| { + b.call(self, &[x.node(), y.node(), z.node()], Type::void()); + }); + } } #[allow(dead_code)] @@ -452,12 +469,11 @@ pub const fn packed_size() -> usize { (std::mem::size_of::() + 3) / 4 } -pub fn pack_to(expr: E, buffer: &B, index: impl Into>) +pub fn pack_to(expr: Expr, buffer: &B, index: impl AsExpr) where - E: ExprProxy, B: IndexWrite, { - let index = index.into(); + let index = index.as_expr(); __current_scope(|b| { b.call( Func::Pack, diff --git a/luisa_compute/src/lang/control_flow.rs b/luisa_compute/src/lang/control_flow.rs index c02751d..06ca73f 100644 --- a/luisa_compute/src/lang/control_flow.rs +++ b/luisa_compute/src/lang/control_flow.rs @@ -4,45 +4,30 @@ use crate::internal_prelude::*; use ir::SwitchCase; /** - * If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) or if_!(cond, { .. }, else, {...}) - * instead of if_!(cond, { .. }, else {...}). + * If you want rustfmt to format your code, use if_!(cond, { .. }, { .. }) + * or if_!(cond, { .. }, else, {...}) instead of if_!(cond, { .. }, else + * {...}). * */ #[macro_export] macro_rules! if_ { ($cond:expr, $then:block, else $else_:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || $else_, - ) + <_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_) }; ($cond:expr, $then:block, else, $else_:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || $else_, - ) + <_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_) }; ($cond:expr, $then:block, $else_:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || $else_, - ) + <_ as $crate::lang::ops::SelectMaybeExpr<_>>::select($cond, || $then, || $else_) }; ($cond:expr, $then:block) => { - <_ as $crate::lang::maybe_expr::BoolIfElseMaybeExpr<_>>::if_then_else( - $cond, - || $then, - || {}, - ) + <_ as $crate::lang::ops::ActivateMaybeExpr>::activate($cond, || $then) }; } #[macro_export] macro_rules! while_ { ($cond:expr,$body:block) => { - <_ as $crate::lang::maybe_expr::BoolWhileMaybeExpr>::while_loop(|| $cond, || $body) + <_ as $crate::lang::ops::LoopMaybeExpr>::while_loop(|| $cond, || $body) }; } #[macro_export] diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs index 1f2c362..7a31e96 100644 --- a/luisa_compute/src/lang/debug.rs +++ b/luisa_compute/src/lang/debug.rs @@ -31,7 +31,7 @@ impl CpuFn { _marker: PhantomData, } } - pub fn call(&self, arg: impl ExprProxy) -> Expr { + pub fn call(&self, arg: impl AsExpr) -> Expr { RECORDER.with(|r| { let mut r = r.borrow_mut(); assert!(r.lock); @@ -107,10 +107,7 @@ macro_rules! lc_assert { $crate::lang::debug::__assert($arg, $msg, file!(), line!(), column!()) }; } -pub fn __cpu_dbg(arg: T, file: &'static str, line: u32) -where - T::Value: Debug, -{ +pub fn __cpu_dbg(arg: Expr, file: &'static str, line: u32) { if !is_cpu_backend() { return; } diff --git a/luisa_compute/src/lang/index.rs b/luisa_compute/src/lang/index.rs index 31d84d3..9c2fc5f 100644 --- a/luisa_compute/src/lang/index.rs +++ b/luisa_compute/src/lang/index.rs @@ -25,7 +25,7 @@ impl IntoIndex for u64 { } impl IntoIndex for Expr { fn to_u64(&self) -> Expr { - self.ulong() + self.cast::() } } impl IntoIndex for Expr { diff --git a/luisa_compute/src/lang/maybe_expr.rs b/luisa_compute/src/lang/maybe_expr.rs deleted file mode 100644 index 7bb76b2..0000000 --- a/luisa_compute/src/lang/maybe_expr.rs +++ /dev/null @@ -1,334 +0,0 @@ -//! The purpose of this module is to provide traits to represent things that may -//! either be an expression or a normal value. This is necessary for making the -//! trace macro work for both types of value. - -use std::ops::DerefMut; - -use super::control_flow::{generic_loop, if_then_else}; -use super::types::core::*; -use super::types::AsExpr; -use crate::internal_prelude::*; - -/*== Version 1 -pub trait DerefSet { - type Target; - fn deref_set(&mut self, target: Self::Target); -} -impl DerefSet for T -where - T::Target: Sized, -{ - type Target = T::Target; - fn deref_set(&mut self, target: Self::Target) { - **self = target; - } -} -impl DerefSet for Var { - type Target = Expr; - fn deref_set(&mut self, target: Self::Target) { - self.store(target); - } -} -*/ -/*== Version 2 -pub trait DerefSet { - type Target; - fn deref_set(self, target: Self::Target); -} -impl DerefSet for &mut T -where - T::Target: Sized, -{ - type Target = T::Target; - fn deref_set(self, target: Self::Target) { - *self = target; - } -} -impl DerefSet for Var { - type Target = Expr; - fn deref_set(&mut self, target: Self::Target) { - self.store(target); - } -} -// TODO: Confirm that `&mut Var` errors. Otherwise, make a `&mut Var` impl that -// panics. -impl DerefSet for &Var { - type Target = Expr; - fn deref_set(&mut self, target: Self::Target) { - self.store(target); - } -} -*/ -/* == Version 3 == */ -pub trait DerefSet { - fn deref_set(self, target: X); -} -impl DerefSet for &mut T -where - T::Target: Sized, -{ - fn deref_set(self, target: T::Target) { - *self = target; - } -} -impl> DerefSet for Var { - fn deref_set(self, target: X) { - self.store(target.as_expr()); - } -} -// TODO: Confirm that `&mut Var` errors. Otherwise, make a `&mut Var` impl that -// panics. -impl> DerefSet for &Var { - fn deref_set(self, target: X) { - self.store(target.as_expr()); - } -} - -pub trait BoolIfElseMaybeExpr { - fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R; -} -impl BoolIfElseMaybeExpr for bool { - fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R { - if self { - then() - } else { - else_() - } - } -} -impl BoolIfElseMaybeExpr for Bool { - fn if_then_else(self, then: impl FnOnce() -> R, else_: impl FnOnce() -> R) -> R { - if_then_else(self, then, else_) - } -} - -pub trait BoolIfMaybeExpr { - fn if_then(self, then: impl FnOnce()); -} -impl BoolIfMaybeExpr for bool { - fn if_then(self, then: impl FnOnce()) { - if self { - then() - } - } -} -impl BoolIfMaybeExpr for Bool { - fn if_then(self, then: impl FnOnce()) { - if_then_else(self, then, || {}) - } -} - -pub trait BoolWhileMaybeExpr { - fn while_loop(this: impl FnMut() -> Self, body: impl FnMut()); -} -impl BoolWhileMaybeExpr for bool { - fn while_loop(mut this: impl FnMut() -> Self, mut body: impl FnMut()) { - while this() { - body() - } - } -} -impl BoolWhileMaybeExpr for Bool { - fn while_loop(this: impl FnMut() -> Self, body: impl FnMut()) { - generic_loop(this, body, || {}); - } -} - -// TODO: Support lazy expressions if that isn't done already? -pub trait BoolLazyOpsMaybeExpr { - type Ret; - fn and(self, other: impl FnOnce() -> R) -> Self::Ret; - fn or(self, other: impl FnOnce() -> R) -> Self::Ret; -} -impl BoolLazyOpsMaybeExpr for bool { - type Ret = bool; - fn and(self, other: impl FnOnce() -> bool) -> Self::Ret { - self && other() - } - fn or(self, other: impl FnOnce() -> bool) -> Self::Ret { - self || other() - } -} -impl BoolLazyOpsMaybeExpr for bool { - type Ret = Bool; - fn and(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self & other() - } - fn or(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self | other() - } -} -impl BoolLazyOpsMaybeExpr for Bool { - type Ret = Bool; - fn and(self, other: impl FnOnce() -> bool) -> Self::Ret { - self & other() - } - fn or(self, other: impl FnOnce() -> bool) -> Self::Ret { - self | other() - } -} -impl BoolLazyOpsMaybeExpr for Bool { - type Ret = Bool; - fn and(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self & other() - } - fn or(self, other: impl FnOnce() -> Bool) -> Self::Ret { - self | other() - } -} - -pub trait EqMaybeExpr { - type Bool; - fn eq(self, other: X) -> Self::Bool; - fn ne(self, other: X) -> Self::Bool; -} -impl> EqMaybeExpr for A { - type Bool = bool; - fn eq(self, other: R) -> Self::Bool { - self == other - } - fn ne(self, other: R) -> Self::Bool { - self != other - } -} -macro_rules! impl_eme { - ($t: ty, $s: ty) => { - impl EqMaybeExpr<$s> for $t { - type Bool = <$t as VarTrait>::Bool; - fn eq(self, other: $s) -> Self::Bool { - self.cmpeq(other) - } - fn ne(self, other: $s) -> Self::Bool { - self.cmpne(other) - } - } - }; -} -macro_rules! impl_mem { - ($t: ty, $s: ty) => { - impl EqMaybeExpr<$s> for $t { - type Bool = <$s as VarTrait>::Bool; - fn eq(self, other: $s) -> Self::Bool { - other.cmpeq(self) - } - fn ne(self, other: $s) -> Self::Bool { - other.cmpne(self) - } - } - }; -} -macro_rules! emes { - ($x: ty $(, $y: ty)*) => { - impl_eme!(Expr<$x>, Expr<$x>); - impl_eme!(Expr<$x>, $x); - impl_mem!($x, Expr<$x>); - $(impl_eme!(Expr<$x>, $y); - impl_mem!($y, Expr<$x>);)* - }; -} -emes!(bool); -emes!(Bool2); -emes!(Bool3); -emes!(Bool4); - -pub trait PartialOrdMaybeExpr { - type Bool; - fn lt(self, other: R) -> Self::Bool; - fn le(self, other: R) -> Self::Bool; - fn ge(self, other: R) -> Self::Bool; - fn gt(self, other: R) -> Self::Bool; -} -impl> PartialOrdMaybeExpr for A { - type Bool = bool; - fn lt(self, other: R) -> Self::Bool { - self < other - } - fn le(self, other: R) -> Self::Bool { - self <= other - } - fn ge(self, other: R) -> Self::Bool { - self >= other - } - fn gt(self, other: R) -> Self::Bool { - self > other - } -} -macro_rules! impl_pome { - ($t: ty, $s: ty) => { - impl_eme!($t, $s); - impl PartialOrdMaybeExpr<$s> for $t { - type Bool = <$t as VarTrait>::Bool; - fn lt(self, other: $s) -> Self::Bool { - self.cmplt(other) - } - fn le(self, other: $s) -> Self::Bool { - self.cmple(other) - } - fn ge(self, other: $s) -> Self::Bool { - self.cmpge(other) - } - fn gt(self, other: $s) -> Self::Bool { - self.cmpgt(other) - } - } - }; -} -macro_rules! impl_emop { - ($t: ty, $s: ty) => { - impl_mem!($t, $s); - impl PartialOrdMaybeExpr<$s> for $t { - type Bool = <$s as VarTrait>::Bool; - fn lt(self, other: $s) -> Self::Bool { - other.cmpgt(self) - } - fn le(self, other: $s) -> Self::Bool { - other.cmpge(self) - } - fn ge(self, other: $s) -> Self::Bool { - other.cmplt(self) - } - fn gt(self, other: $s) -> Self::Bool { - other.cmplt(self) - } - } - }; -} -macro_rules! pomes { - ($x: ty $(, $y:ty)*) => { - impl_pome!(Expr<$x>, Expr<$x>); - impl_pome!(Expr<$x>, $x); - impl_emop!($x, Expr<$x>); - impl_pome!(Expr<$x>, Var<$x>); - impl_emop!(Var<$x>, Expr<$x>); - $(impl_pome!(Expr<$x>, $y); - impl_emop!($y, Expr<$x>);)* - }; -} -pomes!(f16); -pomes!(f32); -pomes!(f64); -pomes!(i16); -pomes!(i32); -pomes!(i64); -pomes!(u16); -pomes!(u32); -pomes!(u64); - -pomes!(Float2, Expr, f32); -pomes!(Float3, Expr, f32); -pomes!(Float4, Expr, f32); -pomes!(Double2); -pomes!(Double3); -pomes!(Double4); -pomes!(Int2, Expr); -pomes!(Int3, Expr); -pomes!(Int4, Expr); -pomes!(Uint2, Expr); -pomes!(Uint3, Expr); -pomes!(Uint4, Expr); - -#[allow(dead_code)] -fn tests() { - <_ as BoolWhileMaybeExpr>::while_loop(|| true, || {}); - <_ as BoolWhileMaybeExpr>::while_loop(|| Bool::from(true), || {}); -} diff --git a/luisa_compute/src/lang/ops.rs b/luisa_compute/src/lang/ops.rs index b99dae9..3758647 100644 --- a/luisa_compute/src/lang/ops.rs +++ b/luisa_compute/src/lang/ops.rs @@ -8,6 +8,8 @@ pub mod impls; pub mod spread; pub mod traits; +pub use traits::*; + trait CastFrom: Primitive {} impl CastFrom for T {} impl CastFrom for T {} diff --git a/luisa_compute/src/lang/ops/impls.rs b/luisa_compute/src/lang/ops/impls.rs index 4342680..0cf2fe4 100644 --- a/luisa_compute/src/lang/ops/impls.rs +++ b/luisa_compute/src/lang/ops/impls.rs @@ -1,78 +1,106 @@ use super::*; -use traits::*; impl Expr { - fn as_>(self) -> Y + pub fn as_(self) -> Expr where Y::Scalar: CastFrom, { + assert_eq!(X::N, Y::N); Func::Cast.call(self) } - fn cast(self) -> Expr> + pub fn cast(self) -> Expr> where S: CastFrom, { - self.as_::>() + self.as_::<::WithScalar>() } } -impl MinMaxExpr for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; - - fn max(self, other: Self) -> Self { - Func::Max.call2(self, other) - } - fn min(self, other: Self) -> Self { - Func::Min.call2(self, other) - } -} +macro_rules! impl_ops_trait { + ( + [$($bounds:tt)*] $TraitExpr:ident [$TraitThis:ident] for $T:ty where [$($where:tt)*] { + $( + fn $fn:ident [$fn_this:ident] ($sl:ident, $($arg:ident),*) { $body:expr } + )* + } + ) => { + impl<$($bounds)*> $TraitThis for $T where $($where)* { + $( + fn $fn_this($sl, $($arg: Self),*) -> Self { + $body + } + )* + } + impl<$($bounds)*> $TraitExpr for $T where $($where)* { + type Output = Self; -impl ClampExpr for Expr -where - X::Scalar: Numeric, -{ - type Output = Self; + $( + fn $fn($sl, $($arg: Self),*) -> Self { + <$T as $TraitThis>::$fn_this($sl, $($arg),*) + } + )* + } + }; + ( + [$($bounds:tt)*] $TraitExpr:ident [$TraitThis:ident] for $T:ty where [$($where:tt)*] { + type Output = $Output:ty; + $( + fn $fn:ident [$fn_this:ident] ($sl:ident, $($arg:ident),*) { $body:expr } + )* + } + ) => { + impl<$($bounds)*> $TraitThis for $T where $($where)* { + type Output = $Output; + $( + fn $fn_this($sl, $($arg: Self),*) -> Self::Output { + $body + } + )* + } + impl<$($bounds)*> $TraitExpr for $T where $($where)* { + type Output = $Output; - fn clamp(self, min: Self, max: Self) -> Self { - Func::Clamp.call3(self, min, max) + $( + fn $fn($sl, $($arg: Self),*) -> Self::Output { + <$T as $TraitThis>::$fn_this($sl, $($arg),*) + } + )* + } } } +impl_ops_trait!([X: Linear] MinMaxExpr[MinMaxThis] for Expr where [X::Scalar: Numeric] { + fn max[_max](self, other) { Func::Max.call2(self, other) } + fn min[_min](self, other) { Func::Min.call2(self, other) } +}); + +impl_ops_trait!([X: Linear] ClampExpr[ClampThis] for Expr where [X::Scalar: Numeric] { + fn clamp[_clamp](self, min, max) { Func::Clamp.call3(self, min, max) } +}); impl AbsExpr for Expr where X::Scalar: Signed, { fn abs(&self) -> Self { - Func::Abs.call(self) + Func::Abs.call(self.clone()) } } -impl EqExpr for Expr { +impl_ops_trait!([X: Linear] EqExpr[EqThis] for Expr where [X::Scalar: VectorElement] { + type Output = Expr>; + + fn eq[_eq](self, other) { Func::Eq.call2(self, other) } + fn ne[_ne](self, other) { Func::Ne.call2(self, other) } +}); + +impl_ops_trait!([X: Linear] CmpExpr[CmpThis] for Expr where [X::Scalar: Numeric] { type Output = Expr>; - fn eq(self, other: Self) -> Self::Output { - Func::Eq.call2(self, other) - } - fn ne(self, other: Self) -> Self::Output { - Func::Ne.call2(self, other) - } -} -impl CmpExpr for Expr { - fn lt(self, other: Self) -> Self::Output { - Func::Lt.call2(self, other) - } - fn le(self, other: Self) -> Self::Output { - Func::Le.call2(self, other) - } - fn gt(self, other: Self) -> Self::Output { - Func::Gt.call2(self, other) - } - fn ge(self, other: Self) -> Self::Output { - Func::Ge.call2(self, other) - } -} + + fn lt[_lt](self, other) { Func::Lt.call2(self, other) } + fn le[_le](self, other) { Func::Le.call2(self, other) } + fn gt[_gt](self, other) { Func::Gt.call2(self, other) } + fn ge[_ge](self, other) { Func::Ge.call2(self, other) } +}); impl Add for Expr where @@ -190,17 +218,17 @@ where X::Scalar: Integral + Numeric, { fn rotate_left(&self, n: Expr) -> Self { - Func::RotRight.call2(self, n) + Func::RotRight.call2(self.clone(), n) } fn rotate_right(&self, n: Expr) -> Self { - Func::RotLeft.call2(self, n) + Func::RotLeft.call2(self.clone(), n) } } macro_rules! impl_simple_fns { ($($fname:ident => $func:ident),+) => {$( fn $fname(&self) -> Self { - Func::$func.call(self) + Func::$func.call(self.clone()) } )+}; } @@ -233,8 +261,6 @@ where atanh => Atanh, exp => Exp, exp2 => Exp2, - is_infinite => IsInf, - is_nan => IsNan, ln => Log, log2 => Log2, log10 => Log10 @@ -242,106 +268,250 @@ where fn is_finite(&self) -> Self::Bool { !self.is_infinite() & !self.is_nan() } + fn is_infinite(&self) -> Self::Bool { + Func::IsInf.call(self.clone()) + } + fn is_nan(&self) -> Self::Bool { + Func::IsNan.call(self.clone()) + } fn sqr(&self) -> Self { - *self * *self + self.clone() * self.clone() } fn cube(&self) -> Self { - *self * *self * *self + self.clone() * self.clone() * self.clone() } fn recip(&self) -> Self { - 1.0 / *self + todo!() + // 1.0 / self.clone() } fn sin_cos(&self) -> (Self, Self) { (self.sin(), self.cos()) } } -impl FloatMulAddExpr for Expr +impl_ops_trait!([X: Linear] FloatMulAddExpr[FloatMulAddThis] for Expr where [X::Scalar: Floating] { + fn mul_add[_mul_add](self, a, b) { Func::Fma.call3(self, a, b) } +}); + +impl_ops_trait!([X: Linear] FloatCopySignExpr[FloatCopySignThis] for Expr where [X::Scalar: Floating] { + fn copy_sign[_copy_sign](self, sign) { Func::Copysign.call2(self, sign) } +}); + +impl_ops_trait!([X: Linear] FloatStepExpr[FloatStepThis] for Expr where [X::Scalar: Floating] { + fn step[_step](self, edge) { Func::Step.call2(edge, self) } +}); + +impl_ops_trait!([X: Linear] FloatSmoothStepExpr[FloatSmoothStepThis] for Expr where [X::Scalar: Floating] { + fn smooth_step[_smooth_step](self, edge0, edge1) { Func::SmoothStep.call3(edge0, edge1, self) } +}); + +impl_ops_trait!([X: Linear] FloatArcTan2Expr[FloatArcTan2This] for Expr where [X::Scalar: Floating] { + fn atan2[_atan2](self, other) { Func::Atan2.call2(self, other) } +}); + +impl_ops_trait!([X: Linear] FloatLogExpr[FloatLogThis] for Expr where [X::Scalar: Floating] { + fn log[_log](self, base) { self.ln() / base.ln()} +}); + +impl_ops_trait!([X: Linear] FloatPowfExpr[FloatPowfThis] for Expr where [X::Scalar: Floating] { + fn powf[_powf](self, exponent) { Func::Powf.call2(self, exponent) } +}); + +impl> FloatPowiExpr> for Expr where X::Scalar: Floating, { type Output = Self; - fn mul_add(self, a: Self, b: Self) -> Self::Output { - Func::Fma.call3(self, a, b) + fn powi(self, exponent: Expr) -> Self::Output { + Func::Powi.call2(self, exponent) } } -impl FloatCopySignExpr for Expr -where - X::Scalar: Floating, -{ - type Output = Self; - fn copy_sign(self, sign: Self) -> Self::Output { - Func::Copysign.call2(self, sign) +impl_ops_trait!([X: Linear] FloatLerpExpr[FloatLerpThis] for Expr where [X::Scalar: Floating] { + fn lerp[_lerp](self, other, frac) { Func::Lerp.call3(self, other, frac) } +}); + +// Traits for `track!`. + +impl StoreMaybeExpr for &mut T { + fn store(self, value: T) { + *self = value; + } +} +impl> StoreMaybeExpr for &Var { + fn store(self, value: E) { + crate::lang::_store(self, &value.as_expr()); + } +} +impl> StoreMaybeExpr for Var { + fn store(self, value: E) { + crate::lang::_store(&self, &value.as_expr()); + } +} + +impl SelectMaybeExpr for bool { + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { + if self { + on() + } else { + off() + } + } + fn select(self, on: R, off: R) -> R { + if self { + on + } else { + off + } + } +} +impl SelectMaybeExpr for Expr { + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R { + crate::lang::control_flow::if_then_else(self, on, off) + } + fn select(self, on: R, off: R) -> R { + crate::lang::control_flow::select(self, on, off) } } -impl FloatStepExpr for Expr -where - X::Scalar: Floating, -{ - type Output = Self; - fn step(self, edge: Self) -> Self::Output { - Func::Step.call2(edge, self) +impl ActivateMaybeExpr for bool { + fn activate(self, then: impl FnOnce()) { + if self { + then() + } + } +} +impl ActivateMaybeExpr for Expr { + fn activate(self, then: impl FnOnce()) { + crate::lang::control_flow::if_then_else(self, then, || {}) } } -impl FloatSmoothStepExpr for Expr -where - X::Scalar: Floating, -{ - type Output = Self; - fn smooth_step(self, edge0: Self, edge1: Self) -> Self::Output { - Func::SmoothStep.call3(edge0, edge1, self) +impl LoopMaybeExpr for bool { + fn while_loop(mut cond: impl FnMut() -> Self, mut body: impl FnMut()) { + while cond() { + body() + } } } -impl FloatArcTan2Expr for Expr -where - X::Scalar: Floating, -{ - type Output = Self; - fn atan2(self, other: Self) -> Self::Output { - Func::Atan2.call2(self, other) +impl LoopMaybeExpr for Expr { + fn while_loop(cond: impl FnMut() -> Self, body: impl FnMut()) { + crate::lang::control_flow::generic_loop(cond, body, || {}) } } -impl FloatLogExpr for Expr -where - X::Scalar: Floating, -{ - type Output = Self; - fn log(self, base: Self) -> Self::Output { - self.ln() / base.ln() +impl LazyBoolMaybeExpr for bool { + type Bool = bool; + fn and(self, other: impl FnOnce() -> bool) -> bool { + self && other() + } + fn or(self, other: impl FnOnce() -> bool) -> bool { + self || other() + } +} +impl LazyBoolMaybeExpr> for bool { + type Bool = Expr; + fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { + if self { + other() + } else { + false.expr() + } + } + fn or(self, other: impl FnOnce() -> Expr) -> Self::Bool { + if self { + true.expr() + } else { + other() + } + } +} +impl LazyBoolMaybeExpr for Expr { + type Bool = Expr; + fn and(self, other: impl FnOnce() -> bool) -> Self::Bool { + if other() { + self + } else { + false.expr() + } + } + fn or(self, other: impl FnOnce() -> bool) -> Self::Bool { + if other() { + true.expr() + } else { + self + } + } +} +impl LazyBoolMaybeExpr for Expr { + type Bool = Expr; + fn and(self, other: impl FnOnce() -> Expr) -> Self::Bool { + crate::lang::control_flow::if_then_else(self, other, || false.expr()) + } + fn or(self, other: impl FnOnce() -> Expr) -> Self::Bool { + crate::lang::control_flow::if_then_else(self, || true.expr(), other) } } -impl FloatPowfExpr for Expr + +impl EqMaybeExpr for T where - X::Scalar: Floating, + T: EqExpr, { - type Output = Self; - - fn powf(self, exponent: Self) -> Self::Output { - Func::Powf.call2(self, exponent) + type Bool = >::Output; + fn __eq(self, other: S) -> Self::Bool { + self.eq(other) + } + fn __ne(self, other: S) -> Self::Bool { + self.ne(other) } } -impl> FloatPowiExpr> for Expr +impl EqMaybeExpr for T where - X::Scalar: Floating, + T: PartialEq, { - type Output = Self; + type Bool = bool; + fn __eq(self, other: S) -> Self::Bool { + self == other + } + fn __ne(self, other: S) -> Self::Bool { + self != other + } +} - fn powi(self, exponent: Expr) -> Self::Output { - Func::Powi.call2(self, exponent) +impl CmpMaybeExpr for T +where + T: CmpExpr, +{ + type Bool = >::Output; + fn __lt(self, other: S) -> Self::Bool { + self.lt(other) + } + fn __le(self, other: S) -> Self::Bool { + self.le(other) + } + fn __gt(self, other: S) -> Self::Bool { + self.gt(other) + } + fn __ge(self, other: S) -> Self::Bool { + self.ge(other) } } -impl FloatLerpExpr for Expr +impl CmpMaybeExpr for T where - X::Scalar: Floating, + T: PartialOrd, { - type Output = Self; - - fn lerp(self, other: Self, frac: Self) -> Self::Output { - Func::Lerp.call3(self, other, frac) + type Bool = bool; + fn __lt(self, other: S) -> Self::Bool { + self < other + } + fn __le(self, other: S) -> Self::Bool { + self <= other + } + fn __gt(self, other: S) -> Self::Bool { + self > other + } + fn __ge(self, other: S) -> Self::Bool { + self >= other } } diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index b52daba..ae08e57 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -7,20 +7,18 @@ trait SpreadOps { fn lift_other(x: Other) -> Expr; } -macro_rules! impl_spread_single { - ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { +macro_rules! impl_spread { + (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { impl<$($bounds)*> SpreadOps<$S> for $T { type Join = $J; fn lift_self($x: $T) -> Expr { $f } - fn lift_other($x: $S) -> Expr { + fn lift_other($y: $S) -> Expr { $g } } }; -} -macro_rules! impl_spread { ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { impl<$($bounds)*> SpreadOps<$S> for $T { type Join = $J; @@ -42,57 +40,97 @@ macro_rules! impl_spread { } }; } -impl_spread!([T: Linear] T: |x| x.expr(), Expr: |x| x => Expr); -impl_spread!([T: Linear] &T: |x| x.expr(), Expr: |x| x => Expr); -impl_spread!([T: Linear] T: |x| x.expr(), &Expr: |x| x.clone() => Expr); -impl_spread!([T: Linear] &T: |x| x.expr(), &Expr: |x| x.clone() => Expr); -impl_spread!([T: Linear] Expr: |x| x, &Expr: |x| x.clone() => Expr); -impl_spread_single!([T: Linear] &Expr: |x| x.clone(), &Expr: |x| x.clone() => Expr); +macro_rules! call_linear_fn_spread { + ($f:ident [$($bounds:tt)*]($T:ty)) => { + $f!([$($bounds)*] $T: |x| x.expr(), Expr<$T>: |x| x => Expr<$T>); + $f!(['a, $($bounds)*] &'a $T: |x| x.expr(), Expr<$T>: |x| x => Expr<$T>); + $f!(['b, $($bounds)*] $T: |x| x.expr(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); -impl_spread!([T: Linear] T: |x| x.expr(), Var: |x| x.load() => Expr); -impl_spread!([T: Linear] &T: |x| x.expr(), Var: |x| x.load() => Expr); -impl_spread!([T: Linear] T: |x| x.expr(), &Var: |x| x.load() => Expr); -impl_spread!([T: Linear] &T: |x| x.expr(), &Var: |x| x.load() => Expr); + $f!(['b, $($bounds)*] Expr<$T>: |x| x, &'b Expr<$T>: |x| x.clone() => Expr<$T>); + $f!(['b, $($bounds)*] Var<$T>: |x| x.load(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!(@sym ['a, 'b, $($bounds)*] &'a Expr<$T>: |x| x.clone(), &'b Expr<$T>: |x| x.clone() => Expr<$T>); + $f!(@sym [$($bounds)*] Var<$T>: |x| x.load(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(@sym ['a, 'b, $($bounds)*] &'a Var<$T>: |x| x.load(), &'b Var<$T>: |x| x.load() => Expr<$T>); -// Other way is unneded because of the deref impl. -impl_spread_single!([T: Linear] &Expr: |x| x.clone(), Var: |x| x.load() => Expr); -impl_spread_single!([T: Linear] &Expr: |x| x.clone(), &Var: |x| x.load() => Expr); + $f!([$($bounds)*] $T: |x| x.expr(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, $($bounds)*] &'a $T: |x| x.expr(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(['b, $($bounds)*] $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| x.expr(), &'b Var<$T>: |x| x.load() => Expr<$T>); + + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| x.clone(), Var<$T>: |x| x.load() => Expr<$T>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| x.clone(), &'b Var<$T>: |x| x.load() => Expr<$T>); + $f!([$($bounds)*] Expr<$T>: |x| x, Var<$T>: |x| x.load() => Expr<$T>); + $f!(['b, $($bounds)*] Expr<$T>: |x| x, &'b Var<$T>: |x| x.load() => Expr<$T>); + }; + ($f:ident [$T:ident]) => { + call_linear_fn_spread!($f [$T: Linear]($T)); + } +} -impl_spread!([const N: usize, T: VectorAlign] T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorAlign] &T: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), Expr>: |x| x => Expr>); -impl_spread!([const N: usize, T: VectorAlign] T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorAlign] &T: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); -impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), &Expr>: |x| x.clone() => Expr>); +call_linear_fn_spread!(impl_spread[T]); + +macro_rules! call_vector_fn_spread { + ($f:ident [$($bounds:tt)*]($N:tt, $T:ty) $Vt:ty, $Vsplat:path) => { + $f!([$($bounds)*] $T: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a $T: |x| $Vsplat(*x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), Expr<$Vt>: |x| x => Expr<$Vt>); + $f!(['b, $($bounds)*] $T: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| $Vsplat(*x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b Expr<$Vt>: |x| x.clone() => Expr<$Vt>); + + $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), $Vt: |x| x.expr() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), $Vt: |x| x.expr() => Expr<$Vt>); + $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b $Vt: |x| x.expr() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b $Vt: |x| x.expr() => Expr<$Vt>); + + $f!([$($bounds)*] $T: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a $T: |x| $Vsplat(*x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!([$($bounds)*] Expr<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['b, $($bounds)*] $T: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a $T: |x| $Vsplat(*x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['b, $($bounds)*] Expr<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Expr<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + + $f!([$($bounds)*] Var<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, $($bounds)*] &'a Var<$T>: |x| $Vsplat(x), Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['b, $($bounds)*] Var<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + $f!(['a, 'b, $($bounds)*] &'a Var<$T>: |x| $Vsplat(x), &'b Var<$Vt>: |x| x.load() => Expr<$Vt>); + }; + ($f:ident [$($bounds:tt)*]($N:tt, $T:ty)) => { + call_vector_fn_spread!($f[$($bounds)*]($N, $T) Vector<$T, $N>, Vector::<$T, $N>::splat_expr); + }; + ($f:ident[$N:ident, $T:ident]) => { + call_vector_fn_spread!($f[const $N: usize, $T: VectorAlign<$N>]($N, $T)); + } +} -impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); -impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), Vector: |x| x.expr() => Expr>); -impl_spread!([const N: usize, T: VectorAlign] Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); -impl_spread!([const N: usize, T: VectorAlign] &Expr: |x| Vector::::splat_expr(x), &Vector: |x| x.expr() => Expr>); +call_vector_fn_spread!(impl_spread[N, T]); -mod impls { +mod trait_impls { use super::*; impl MinMaxExpr for T where T: SpreadOps, - Expr: MinMaxExpr, + Expr: MinMaxThis, { - type Output = as MinMaxExpr>::Output; + type Output = Expr; fn max(self, other: S) -> Self::Output { - Expr::::max(Self::lift_self(self), Self::lift_other(other)) + Expr::::_max(Self::lift_self(self), Self::lift_other(other)) } fn min(self, other: S) -> Self::Output { - Expr::::min(Self::lift_self(self), Self::lift_other(other)) + Expr::::_min(Self::lift_self(self), Self::lift_other(other)) } } impl ClampExpr for T where S: SpreadOps, - T: SpreadOps, - Expr: ClampExpr, + T: SpreadOps>, + Expr: ClampThis, { /// T::Join /// / \ @@ -106,9 +144,9 @@ mod impls { /// / / \ /// T S U - type Output = as ClampExpr>::Output; - fn clamp(&self, min: S, max: U) -> Self::Output { - Expr::::clamp( + type Output = Expr; + fn clamp(self, min: S, max: U) -> Self::Output { + Expr::::_clamp( Self::lift_self(self), Self::lift_other(S::lift_self(min)), Self::lift_other(S::lift_other(max)), @@ -118,43 +156,44 @@ mod impls { impl EqExpr for T where T: SpreadOps, - Expr: EqExpr, + Expr: EqThis, { - type Output = as EqExpr>::Output; + type Output = as EqThis>::Output; fn eq(self, other: S) -> Self::Output { - Expr::::eq(Self::lift_self(self), Self::lift_other(other)) + Expr::::_eq(Self::lift_self(self), Self::lift_other(other)) } fn ne(self, other: S) -> Self::Output { - Expr::::ne(Self::lift_self(self), Self::lift_other(other)) + Expr::::_ne(Self::lift_self(self), Self::lift_other(other)) } } impl CmpExpr for T where T: SpreadOps, - Expr: CmpExpr, + Expr: CmpThis, { + type Output = as CmpThis>::Output; fn lt(self, other: S) -> Self::Output { - Expr::::lt(Self::lift_self(self), Self::lift_other(other)) + Expr::::_lt(Self::lift_self(self), Self::lift_other(other)) } fn le(self, other: S) -> Self::Output { - Expr::::le(Self::lift_self(self), Self::lift_other(other)) + Expr::::_le(Self::lift_self(self), Self::lift_other(other)) } fn gt(self, other: S) -> Self::Output { - Expr::::gt(Self::lift_self(self), Self::lift_other(other)) + Expr::::_gt(Self::lift_self(self), Self::lift_other(other)) } fn ge(self, other: S) -> Self::Output { - Expr::::ge(Self::lift_self(self), Self::lift_other(other)) + Expr::::_ge(Self::lift_self(self), Self::lift_other(other)) } } impl FloatMulAddExpr for T where S: SpreadOps, - T: SpreadOps, - Expr: FloatMulAddExpr, + T: SpreadOps>, + Expr: FloatMulAddThis, { - type Output = as FloatMulAddExpr>::Output; + type Output = Expr; fn mul_add(self, mul: S, add: U) -> Self::Output { - Expr::::mul_add( + Expr::::_mul_add( Self::lift_self(self), Self::lift_other(S::lift_self(mul)), Self::lift_other(S::lift_other(add)), @@ -164,32 +203,32 @@ mod impls { impl FloatCopySignExpr for T where T: SpreadOps, - Expr: FloatCopySignExpr, + Expr: FloatCopySignThis, { - type Output = as FloatCopySignExpr>::Output; + type Output = Expr; fn copy_sign(self, sign: S) -> Self::Output { - Expr::::copy_sign(Self::lift_self(self), Self::lift_other(sign)) + Expr::::_copy_sign(Self::lift_self(self), Self::lift_other(sign)) } } impl FloatStepExpr for T where T: SpreadOps, - Expr: FloatStepExpr, + Expr: FloatStepThis, { - type Output = as FloatStepExpr>::Output; + type Output = Expr; fn step(self, edge: S) -> Self::Output { - Expr::::step(Self::lift_self(self), Self::lift_other(edge)) + Expr::::_step(Self::lift_self(self), Self::lift_other(edge)) } } impl FloatSmoothStepExpr for T where S: SpreadOps, - T: SpreadOps, - Expr: FloatSmoothStepExpr, + T: SpreadOps>, + Expr: FloatSmoothStepThis, { - type Output = as FloatSmoothStepExpr>::Output; + type Output = Expr; fn smooth_step(self, edge0: S, edge1: U) -> Self::Output { - Expr::::smooth_step( + Expr::::_smooth_step( Self::lift_self(self), Self::lift_other(S::lift_self(edge0)), Self::lift_other(S::lift_other(edge1)), @@ -199,42 +238,42 @@ mod impls { impl FloatArcTan2Expr for T where T: SpreadOps, - Expr: FloatArcTan2Expr, + Expr: FloatArcTan2This, { - type Output = as FloatArcTan2Expr>::Output; + type Output = Expr; fn atan2(self, other: S) -> Self::Output { - Expr::::atan2(Self::lift_self(self), Self::lift_other(other)) + Expr::::_atan2(Self::lift_self(self), Self::lift_other(other)) } } impl FloatLogExpr for T where T: SpreadOps, - Expr: FloatLogExpr, + Expr: FloatLogThis, { - type Output = as FloatLogExpr>::Output; + type Output = Expr; fn log(self, base: S) -> Self::Output { - Expr::::log(Self::lift_self(self), Self::lift_other(base)) + Expr::::_log(Self::lift_self(self), Self::lift_other(base)) } } impl FloatPowfExpr for T where T: SpreadOps, - Expr: FloatPowfExpr, + Expr: FloatPowfThis, { - type Output = as FloatPowfExpr>::Output; + type Output = Expr; fn powf(self, exponent: S) -> Self::Output { - Expr::::powf(Self::lift_self(self), Self::lift_other(exponent)) + Expr::::_powf(Self::lift_self(self), Self::lift_other(exponent)) } } impl FloatLerpExpr for T where S: SpreadOps, - T: SpreadOps, - Expr: FloatLerpExpr, + T: SpreadOps>, + Expr: FloatLerpThis, { - type Output = as FloatLerpExpr>::Output; + type Output = Expr; fn lerp(self, other: S, frac: U) -> Self::Output { - Expr::::lerp( + Expr::::_lerp( Self::lift_self(self), Self::lift_other(S::lift_self(other)), Self::lift_other(S::lift_other(frac)), @@ -244,32 +283,74 @@ mod impls { } macro_rules! impl_spread_op { ([ $($bounds:tt)* ]: $Op:ident::$op_fn:ident for $T:ty, $S:ty) => { - impl<$($bounds)*> $Op <$S> for $T where $T: SpreadOps<$T>, Expr<$T::Join>: $Op { - type Output = as $Op>::Output; + impl<$($bounds)*> $Op <$S> for $T where $T: SpreadOps<$S>, Expr<<$T as SpreadOps<$S>>::Join>: $Op { + type Output = >::Join> as $Op>::Output; fn $op_fn (self, other: $S) -> Self::Output { - as $Op>::$op_fn (Self::lift_self(self), Self::lift_other(other)) + >::Join> as $Op>::$op_fn (<$T as SpreadOps<$S>>::lift_self(self), <$T as SpreadOps<$S>>::lift_other(other)) } } } } -macro_rules! impl_num_spread { - ([ $($bounds:tt)* ]: $T:ty, $S:ty) => { +macro_rules! impl_num_spread_single { + ([ $($bounds:tt)* ] $T:ty, $S:ty) => { impl_spread_op!( [ $($bounds)* ]: Add::add for $T, $S); impl_spread_op!( [ $($bounds)* ]: Sub::sub for $T, $S); impl_spread_op!( [ $($bounds)* ]: Mul::mul for $T, $S); impl_spread_op!( [ $($bounds)* ]: Div::div for $T, $S); impl_spread_op!( [ $($bounds)* ]: Rem::rem for $T, $S); + } +} +macro_rules! impl_int_spread_single { + ([ $($bounds:tt)* ] $T:ty, $S:ty) => { + impl_spread_op!([ $($bounds)* ]: BitAnd::bitand for $T, $S); + impl_spread_op!([ $($bounds)* ]: BitOr::bitor for $T, $S); + impl_spread_op!([ $($bounds)* ]: BitXor::bitxor for $T, $S); + impl_spread_op!([ $($bounds)* ]: Shl::shl for $T, $S); + impl_spread_op!([ $($bounds)* ]: Shr::shr for $T, $S); + } +} + +macro_rules! impl_num_spread { + (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_num_spread_single!([$($bounds)*] $T, $S); + }; + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_num_spread_single!([$($bounds)*] $T, $S); + impl_num_spread_single!([$($bounds)*] $S, $T); + } +} +macro_rules! impl_int_spread { + (@sym [$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_int_spread_single!([$($bounds)*] $T, $S); + }; + ([$($bounds:tt)*] $T:ty : |$x:ident| $f:expr, $S:ty : |$y:ident| $g:expr => Expr<$J:ty>) => { + impl_int_spread_single!([$($bounds)*] $T, $S); + impl_int_spread_single!([$($bounds)*] $S, $T); + } +} +macro_rules! call_spreads { + ($f:ident: $($T:ty),+) => { + $( + call_linear_fn_spread!($f []($T)); + call_vector_fn_spread!($f [](2, $T)); + call_vector_fn_spread!($f [](3, $T)); + call_vector_fn_spread!($f [](4, $T)); + )+ }; } +call_spreads!(impl_num_spread: f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); +call_spreads!(impl_int_spread: bool, i8, i16, i32, i64, u8, u16, u32, u64); +#[allow(dead_code)] mod tests { + use super::*; fn test() { let x = 10.0f32; - let y = 20.0f32; - let z = x.min(y); + let y = Vector::<_, 2>::splat(20.0f32); + let x = x.expr(); - let w = x.expr().min(y); + let w = (&x.var()).min(&0.0_f32.expr()); println!("{:?}", w); } } diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 68ec7e7..08732a2 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -1,35 +1,78 @@ use super::*; -pub trait MinMaxExpr { - type Output; - - fn max(self, other: T) -> Self::Output; - fn min(self, other: T) -> Self::Output; -} - -pub trait ClampExpr { - type Output; - - fn clamp(self, min: A, max: B) -> Self::Output; -} +macro_rules! ops_trait { + ( + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { + $( + fn $fn:ident [$fn_this:ident] (self, $($arg:ident: $S:ident),*); + )+ + } + ) => { + pub(crate) trait $TraitThis { + $( + fn $fn_this(self, $($arg: Self),*) -> Self; + )* + } + pub trait $TraitExpr<$($T = Self),*> { + type Output; + + $( + fn $fn(self, $($arg: $S),*) -> Self::Output; + )* + } + }; + ( + $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { + type Output; + $( + fn $fn:ident [$fn_this:ident] (self, $($arg:ident: $S:ident),*); + )+ + } + ) => { + pub(crate) trait $TraitThis { + type Output; + $( + fn $fn_this(self, $($arg: Self),*) -> Self::Output; + )* + } + pub trait $TraitExpr<$($T = Self),*> { + type Output; + + $( + fn $fn(self, $($arg: $S),*) -> Self::Output; + )* + } + } +} + +ops_trait!(MinMaxExpr[MinMaxThis] { + fn max[_max](self, other: T); + fn min[_min](self, other: T); +}); + +ops_trait!(ClampExpr[ClampThis] { + fn clamp[_clamp](self, min: A, max: B); +}); pub trait AbsExpr { fn abs(&self) -> Self; } -pub trait EqExpr { +ops_trait!(EqExpr[EqThis] { type Output; - fn eq(self, other: T) -> Self::Output; - fn ne(self, other: T) -> Self::Output; -} + fn eq[_eq](self, other: T); + fn ne[_ne](self, other: T); +}); -pub trait CmpExpr: EqExpr { - fn lt(self, other: T) -> Self::Output; - fn le(self, other: T) -> Self::Output; - fn gt(self, other: T) -> Self::Output; - fn ge(self, other: T) -> Self::Output; -} +ops_trait!(CmpExpr[CmpThis] { + type Output; + + fn lt[_lt](self, other: T); + fn le[_le](self, other: T); + fn gt[_gt](self, other: T); + fn ge[_ge](self, other: T); +}); pub trait IntExpr { fn rotate_right(&self, n: Expr) -> Self; @@ -72,70 +115,78 @@ pub trait FloatExpr: Sized { fn recip(&self) -> Self; fn sin_cos(&self) -> (Self, Self); } -pub trait FloatMulAddExpr { - type Output; - fn mul_add(self, a: A, b: B) -> Self::Output; -} -pub trait FloatCopySignExpr { - type Output; +ops_trait!(FloatMulAddExpr[FloatMulAddThis] { + fn mul_add[_mul_add](self, a: A, b: B); +}); - fn copy_sign(self, sign: T) -> Self::Output; -} -pub trait FloatStepExpr { - type Output; +ops_trait!(FloatCopySignExpr[FloatCopySignThis] { + fn copy_sign[_copy_sign](self, sign: T); +}); - fn step(self, edge: T) -> Self::Output; -} -pub trait FloatSmoothStepExpr { - type Output; +ops_trait!(FloatStepExpr[FloatStepThis] { + fn step[_step](self, edge: T); +}); - fn smooth_step(self, edge0: T, edge1: S) -> Self::Output; -} -pub trait FloatArcTan2Expr { - type Output; +ops_trait!(FloatSmoothStepExpr[FloatSmoothStepThis] { + fn smooth_step[_smooth_step](self, edge0: T, edge1: S); +}); - fn atan2(self, other: T) -> Self::Output; -} -pub trait FloatLogExpr { - type Output; +ops_trait!(FloatArcTan2Expr[FloatArcTan2This] { + fn atan2[_atan2](self, other: T); +}); - fn log(self, base: T) -> Self::Output; -} -pub trait FloatPowfExpr { - type Output; +ops_trait!(FloatLogExpr[FloatLogThis] { + fn log[_log](self, base: T); +}); + +ops_trait!(FloatPowfExpr[FloatPowfThis] { + fn powf[_powf](self, exponent: T); +}); - fn powf(self, exponent: T) -> Self::Output; -} pub trait FloatPowiExpr { type Output; fn powi(self, exponent: T) -> Self::Output; } -pub trait FloatLerpExpr { - type Output; - fn lerp(self, other: T, frac: S) -> Self::Output; -} +ops_trait!(FloatLerpExpr[FloatLerpThis] { + fn lerp[_lerp](self, other: A, frac: B); +}); -pub trait StoreExpr { +pub trait StoreMaybeExpr { fn store(self, value: V); } -pub trait SwitchExpr { - fn switch(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R; +pub trait SelectMaybeExpr { + fn if_then_else(self, on: impl FnOnce() -> R, off: impl FnOnce() -> R) -> R; + fn select(self, on: R, off: R) -> R; } -pub trait ActivateExpr { +pub trait ActivateMaybeExpr { fn activate(self, then: impl FnOnce()); } -pub trait LoopExpr { +pub trait LoopMaybeExpr { fn while_loop(cond: impl FnMut() -> Self, body: impl FnMut()); } -pub trait LazyBoolExpr { +pub trait LazyBoolMaybeExpr { type Bool; fn and(self, other: impl FnOnce() -> T) -> Self::Bool; fn or(self, other: impl FnOnce() -> T) -> Self::Bool; } + +pub trait EqMaybeExpr { + type Bool; + fn __eq(self, other: T) -> Self::Bool; + fn __ne(self, other: T) -> Self::Bool; +} + +pub trait CmpMaybeExpr { + type Bool; + fn __lt(self, other: T) -> Self::Bool; + fn __le(self, other: T) -> Self::Bool; + fn __gt(self, other: T) -> Self::Bool; + fn __ge(self, other: T) -> Self::Bool; +} diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 38195fb..f845920 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -198,13 +198,9 @@ impl Var { for node in nodes { ret.push(b.call(Func::Load, &[node], node.type_().clone())); } - Expr::::from_nodes(&mut ret.into_iter()) + Expr::::from_nodes(&mut ret.into_iter()) }) } - pub fn store(&self, value: impl AsExpr) { - let value = value.as_expr(); - super::_store(self, &value); - } } pub fn _deref_proxy(proxy: &P) -> &Expr { @@ -245,8 +241,16 @@ macro_rules! impl_simple_var_proxy { } } +#[macro_export] +macro_rules! impl_marker_trait { + ($T:ident for $($t:ty),*) => { + $(impl $T for $t {})* + }; +} + mod private { use super::*; + pub trait Sealed {} impl Sealed for T {} impl Sealed for Expr {} @@ -300,12 +304,12 @@ impl AsExpr for T { } impl AsExpr for Expr { fn as_expr(&self) -> Expr { - *self + self.clone() } } impl AsExpr for &Expr { fn as_expr(&self) -> Expr { - **self + (*self).clone() } } impl AsExpr for Var { diff --git a/luisa_compute/src/lang/types/array.rs b/luisa_compute/src/lang/types/array.rs index 7ec31b1..04c6029 100644 --- a/luisa_compute/src/lang/types/array.rs +++ b/luisa_compute/src/lang/types/array.rs @@ -26,10 +26,10 @@ impl Index for ArrayExpr { let i = i.to_u64(); // TODO: Add need_runtime_check()? - lc_assert!(i.cmplt((N as u64).expr())); + lc_assert!(i.lt((N as u64).expr())); Expr::::from_node(__current_scope(|b| { - b.call(Func::ExtractElement, &[self.node, i.node()], T::type_()) + b.call(Func::ExtractElement, &[self.0.node, i.node()], T::type_()) })) ._ref() } @@ -99,7 +99,7 @@ impl VLArrayVar { pub fn read>>(&self, i: I) -> Expr { let i = i.into(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); } Expr::::from_node(__current_scope(|b| { @@ -124,7 +124,7 @@ impl VLArrayVar { let value = value.into(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); } __current_scope(|b| { @@ -175,7 +175,7 @@ impl VLArrayExpr { pub fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::::from_node(__current_scope(|b| { diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index 5bf0660..4c968a0 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -19,7 +19,7 @@ mod private { pub trait Primitive: private::Sealed + Copy + TypeOf + 'static { fn const_(&self) -> Const; - fn primitive(&self) -> ir::Primitive; + fn primitive() -> ir::Primitive; } impl Value for T { type Expr = PrimitiveExpr; @@ -27,7 +27,7 @@ impl Value for T { type ExprData = (); type VarData = (); - fn expr(&self) -> Expr { + fn expr(self) -> Expr { let node = __current_scope(|s| -> NodeRef { s.const_(self.const_()) }); Expr::::from_node(node) } @@ -40,133 +40,119 @@ impl Primitive for bool { fn const_(&self) -> Const { Const::Bool(*self) } - fn primitive(&self) -> ir::Primitive { + fn primitive() -> ir::Primitive { ir::Primitive::Bool } } impl Primitive for f16 { fn const_(&self) -> Const { - Const::F16(*self) + Const::Float16(*self) } - fn primitive(&self) -> ir::Primitive { - ir::Primitive::F16 + fn primitive() -> ir::Primitive { + ir::Primitive::Float16 } } impl Primitive for f32 { fn const_(&self) -> Const { - Const::F32(*self) + Const::Float32(*self) } - fn primitive(&self) -> ir::Primitive { - ir::Primitive::F32 + fn primitive() -> ir::Primitive { + ir::Primitive::Float32 } } impl Primitive for f64 { fn const_(&self) -> Const { - Const::F64(*self) + Const::Float64(*self) } - fn primitive(&self) -> ir::Primitive { - ir::Primitive::F64 + fn primitive() -> ir::Primitive { + ir::Primitive::Float64 } } impl Primitive for i8 { fn const_(&self) -> Const { - todo!() // Const::I8(*self) + Const::Int8(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Int8 } } impl Primitive for i16 { fn const_(&self) -> Const { - Const::I16(*self) + Const::Int16(*self) } - fn primitive(&self) -> ir::Primitive { + fn primitive() -> ir::Primitive { ir::Primitive::Int16 } } impl Primitive for i32 { fn const_(&self) -> Const { - Const::I32(*self) + Const::Int32(*self) } - fn primitive(&self) -> ir::Primitive { + fn primitive() -> ir::Primitive { ir::Primitive::Int32 } } impl Primitive for i64 { fn const_(&self) -> Const { - Const::I64(*self) + Const::Int64(*self) } - fn primitive(&self) -> ir::Primitive { + fn primitive() -> ir::Primitive { ir::Primitive::Int64 } } impl Primitive for u8 { fn const_(&self) -> Const { - todo!() // Const::U8(*self) + Const::Uint8(*self) + } + fn primitive() -> ir::Primitive { + ir::Primitive::Uint8 } } impl Primitive for u16 { fn const_(&self) -> Const { - Const::U16(*self) + Const::Uint16(*self) } - fn primitive(&self) -> ir::Primitive { - ir::Primitive::UInt16 + fn primitive() -> ir::Primitive { + ir::Primitive::Uint16 } } impl Primitive for u32 { fn const_(&self) -> Const { - Const::U32(*self) + Const::Uint32(*self) } - fn primitive(&self) -> ir::Primitive { - ir::Primitive::UInt32 + fn primitive() -> ir::Primitive { + ir::Primitive::Uint32 } } impl Primitive for u64 { fn const_(&self) -> Const { - Const::U64(*self) + Const::Uint64(*self) } - fn primitive(&self) -> ir::Primitive { - ir::Primitive::UInt64 + fn primitive() -> ir::Primitive { + ir::Primitive::Uint64 } } +macro_rules! impls { + ($T:ident for $($t:ty),*) => { + $(impl $T for $t {})* + }; +} + pub trait Integral: Primitive {} -impl Integral for bool {} -impl Integral for i8 {} -impl Integral for i16 {} -impl Integral for i32 {} -impl Integral for i64 {} -impl Integral for u8 {} -impl Integral for u16 {} -impl Integral for u32 {} -impl Integral for u64 {} +impls!(Integral for bool, i8, i16, i32, i64, u8, u16, u32, u64); pub trait Numeric: Primitive {} -impl Numeric for f16 {} -impl Numeric for f32 {} -impl Numeric for f64 {} -impl Numeric for i8 {} -impl Numeric for i16 {} -impl Numeric for i32 {} -impl Numeric for i64 {} -impl Numeric for u8 {} -impl Numeric for u16 {} -impl Numeric for u32 {} -impl Numeric for u64 {} +impls!(Numeric for f16, f32, f64, i8, i16, i32, i64, u8, u16, u32, u64); pub trait Floating: Numeric {} -impl Floating for f16 {} -impl Floating for f32 {} -impl Floating for f64 {} +impls!(Floating for f16, f32, f64); pub trait Signed: Numeric {} -impl Signed for f16 {} -impl Signed for f32 {} -impl Signed for f64 {} -impl Signed for i8 {} -impl Signed for i16 {} -impl Signed for i32 {} -impl Signed for i64 {} +impls!(Signed for f16, f32, f64, i8, i16, i32, i64); #[deprecated] pub type Bool = Expr; diff --git a/luisa_compute/src/lang/types/shared.rs b/luisa_compute/src/lang/types/shared.rs index a374d69..0b14635 100644 --- a/luisa_compute/src/lang/types/shared.rs +++ b/luisa_compute/src/lang/types/shared.rs @@ -47,7 +47,7 @@ impl Shared { let value = value.into(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len()), "VLArrayVar::read out of bounds"); + lc_assert!(i.lt(self.len()), "VLArrayVar::read out of bounds"); } __current_scope(|b| { diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index ae86b85..e8124ca 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -2,7 +2,6 @@ use super::alignment::*; use super::core::*; use super::*; use ir::{VectorElementType, VectorType}; -use serde::{Deserialize, Serialize}; use std::fmt::Debug; #[cfg(feature = "glam")] @@ -34,9 +33,8 @@ impl, const N: usize> Debug for Vector { } #[repr(C)] -#[derive(Copy, Clone, Hash, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Copy, Clone, PartialEq, Eq)] pub struct Vector, const N: usize> { - #[serde(skip)] _align: T::A, elements: [T; N], } @@ -137,16 +135,16 @@ macro_rules! vector_proxies { unsafe impl> HasExprLayout< as Value>::ExprData> for $ExprName {} unsafe impl> HasVarLayout< as Value>::VarData> for $VarName {} - impl> ExprProxy for $ExprName { + impl>> ExprProxy for $ExprName { type Value = Vector; } impl> VectorExprProxy for $ExprName { type T = T; } - impl> VarProxy for $VarName { + impl>> VarProxy for $VarName { type Value = Vector; } - impl> Deref for $VarName { + impl>> Deref for $VarName { type Target = Expr>; fn deref(&self) -> &Self::Target { _deref_proxy(self) @@ -162,7 +160,7 @@ vector_proxies!(4 [x, y, z, w, r, g, b, a]: VectorExprProxy4, VectorVarProxy4); impl, const N: usize> TypeOf for Vector { fn type_() -> CArc { let type_ = Type::Vector(VectorType { - element: VectorElementType::Scalar(T::type_()), + element: VectorElementType::Scalar(T::primitive()), length: N as u32, }); register_type(type_) @@ -180,13 +178,13 @@ impl Vec2Swizzle for Vector { type Vec2 = Self; type Vec3 = Vector; type Vec4 = Vector; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { self._permute3(x, y, z) } - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { self._permute4(x, y, z, w) } } @@ -194,13 +192,13 @@ impl Vec3Swizzle for Vector { type Vec2 = Vector; type Vec3 = Self; type Vec4 = Vector; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { self._permute3(x, y, z) } - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { self._permute4(x, y, z, w) } } @@ -208,17 +206,17 @@ impl Vec4Swizzle for Vector { type Vec2 = Vector; type Vec3 = Vector; type Vec4 = Self; - fn permute2(&self, x: i32, y: i32) -> Self::Vec2 { + fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } - fn permute3(&self, x: i32, y: i32, z: i32) -> Self::Vec3 { + fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3 { self._permute3(x, y, z) } - fn permute4(&self, x: i32, y: i32, z: i32, w: i32) -> Self::Vec4 { + fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4 { self._permute4(x, y, z, w) } } - +/* impl, const N: usize> VectorExprData { fn _permute2(&self, x: u32, y: u32) -> Expr> where @@ -330,3 +328,4 @@ impl Vec4Swizzle for VectorExprProxy4 { self._permute4(x, y, z, w) } } + */ diff --git a/luisa_compute/src/lang/types/vector/coords.rs b/luisa_compute/src/lang/types/vector/coords.rs index 583122c..5811603 100644 --- a/luisa_compute/src/lang/types/vector/coords.rs +++ b/luisa_compute/src/lang/types/vector/coords.rs @@ -23,8 +23,8 @@ macro_rules! impl_deref { impl> DerefMut for Vector { #[inline] - fn deref_mut(&self) -> &$T { - unsafe { &*(self as *const Self as *const $T) } + fn deref_mut(&mut self) -> &mut $T { + unsafe { &mut *(self as *mut Self as *mut $T) } } } }; @@ -50,8 +50,8 @@ impl Deref for XYZ { } impl DerefMut for XYZ { #[inline] - fn deref_mut(&self) -> &RGB { - unsafe { &*(self as *const Self as *const RGB) } + fn deref_mut(&mut self) -> &mut RGB { + unsafe { &mut *(self as *mut Self as *mut RGB) } } } impl Deref for XYZW { @@ -64,7 +64,7 @@ impl Deref for XYZW { } impl DerefMut for XYZW { #[inline] - fn deref_mut(&self) -> &RGBA { - unsafe { &*(self as *const Self as *const RGBA) } + fn deref_mut(&mut self) -> &mut RGBA { + unsafe { &mut *(self as *mut Self as *mut RGBA) } } } diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 0308cd7..bbcac49 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -25,7 +25,7 @@ pub mod prelude { pub use crate::lang::ops::*; pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::Vector; - pub use crate::lang::types::{Expr, ExprProxy, Value, Var, VarProxy}; + pub use crate::lang::types::{AsExpr, Expr, Value, Var}; pub use crate::lang::Aggregate; // pub use crate::resource::{IoTexel, StorageTexel, *}; // pub use crate::runtime::api::StreamTag; @@ -39,7 +39,7 @@ pub mod prelude { } mod internal_prelude { - pub(crate) use crate::lang::debug::{CpuFn, __env_need_backtrace, is_cpu_backend}; + pub(crate) use crate::lang::debug::{__env_need_backtrace, is_cpu_backend}; pub(crate) use crate::lang::ir::ffi::*; pub(crate) use crate::lang::ir::{ new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, diff --git a/luisa_compute_track/src/lib.rs b/luisa_compute_track/src/lib.rs index 7b08f58..90f943d 100644 --- a/luisa_compute_track/src/lib.rs +++ b/luisa_compute_track/src/lib.rs @@ -54,7 +54,7 @@ impl VisitMut for TraceVisitor { }) = &**left { *node = parse_quote_spanned! {span=> - <_ as #trait_path::DerefSet>::deref_set(#expr, #right) + <_ as #trait_path::StoreMaybeExpr>::deref_set(#expr, #right) } } } @@ -65,11 +65,11 @@ impl VisitMut for TraceVisitor { if let Expr::Let(_) = **cond { } else if let Some((_, else_branch)) = else_branch { *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolIfElseMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch) + <_ as #trait_path::SelectMaybeExpr<_>>::select(#cond, || #then_branch, || #else_branch) } } else { *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolIfMaybeExpr>::if_then(#cond, || #then_branch) + <_ as #trait_path::ActivateMaybeExpr>::activate(#cond, || #then_branch) } } } @@ -77,7 +77,7 @@ impl VisitMut for TraceVisitor { let cond = &expr.cond; let body = &expr.body; *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolWhileMaybeExpr>::while_loop(|| #cond, || #body) + <_ as #trait_path::LoopMaybeExpr>::while_loop(|| #cond, || #body) } } Expr::Loop(expr) => { @@ -117,15 +117,15 @@ impl VisitMut for TraceVisitor { let op_fn = Ident::new(op_fn_str, expr.op.span()); if op_fn_str == "eq" || op_fn_str == "ne" { *node = parse_quote_spanned! {span=> - <_ as #trait_path::EqMaybeExpr<_>>::#op_fn(#left, #right) + <_ as #trait_path::EqMaybeExpr<_, _>>::#op_fn(#left, #right) } } else if op_fn_str == "and" || op_fn_str == "or" { *node = parse_quote_spanned! {span=> - <_ as #trait_path::BoolLazyOpsMaybeExpr<_>>::#op_fn(#left, || #right) + <_ as #trait_path::LazyBoolMaybeExpr<_>>::#op_fn(#left, || #right) } } else { *node = parse_quote_spanned! {span=> - <_ as #trait_path::PartialOrdMaybeExpr<_>>::#op_fn(#left, #right) + <_ as #trait_path::CmpMaybeExpr<_, _>>::#op_fn(#left, #right) } } } From f532ca8aafbc6a1d53cb25a2b0af724e39388b8e Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 20:06:24 +0100 Subject: [PATCH 11/15] Revert "Commented out irrelevant things for debugging." This reverts commit 8fa839575d125d49783ecac019d4f243c16590b8. --- luisa_compute/src/lang.rs | 14 ++++++------ luisa_compute/src/lang/types.rs | 2 +- luisa_compute/src/lib.rs | 40 ++++++++++++++++----------------- 3 files changed, 27 insertions(+), 29 deletions(-) diff --git a/luisa_compute/src/lang.rs b/luisa_compute/src/lang.rs index 89f58c4..4fdf9cb 100644 --- a/luisa_compute/src/lang.rs +++ b/luisa_compute/src/lang.rs @@ -10,7 +10,7 @@ use crate::internal_prelude::*; use bumpalo::Bump; use indexmap::IndexMap; -// use crate::runtime::WeakDevice; +use crate::runtime::WeakDevice; pub mod ir { pub use luisa_compute_ir::context::register_type; @@ -26,11 +26,11 @@ use ir::{ pub mod control_flow; pub mod debug; -// pub mod diff; -// pub mod functions; +pub mod diff; +pub mod functions; pub mod index; pub mod ops; -// pub mod poly; +pub mod poly; pub mod types; pub(crate) trait CallFuncTrait { @@ -252,7 +252,7 @@ pub(crate) struct Recorder { pub(crate) cpu_custom_ops: IndexMap)>, pub(crate) callables: IndexMap, pub(crate) shared: Vec, - // pub(crate) device: Option, + pub(crate) device: Option, pub(crate) block_size: Option<[u32; 3]>, pub(crate) building_kernel: bool, pub(crate) pools: Option>, @@ -267,7 +267,7 @@ impl Recorder { self.cpu_custom_ops.clear(); self.callables.clear(); self.lock = false; - // self.device = None; + self.device = None; self.block_size = None; self.arena.reset(); self.shared.clear(); @@ -282,7 +282,7 @@ impl Recorder { cpu_custom_ops: IndexMap::new(), callables: IndexMap::new(), shared: vec![], - // device: None, + device: None, block_size: None, pools: None, arena: Bump::new(), diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index f845920..1f7447d 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -5,7 +5,7 @@ use crate::internal_prelude::*; pub mod alignment; pub mod array; pub mod core; -// pub mod dynamic; +pub mod dynamic; pub mod shared; pub mod vector; diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index bbcac49..ece3d92 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -8,31 +8,30 @@ use std::path::Path; use std::sync::Arc; pub mod lang; -// pub mod printer; -// pub mod resource; -// pub mod rtx; -// pub mod runtime; +pub mod printer; +pub mod resource; +pub mod rtx; +pub mod runtime; pub mod prelude { pub use half::f16; - // pub use crate::lang::control_flow::{ - // break_, continue_, for_range, return_, return_v, select, switch, - // }; - // pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, - // set_block_size}; + pub use crate::lang::control_flow::{ + break_, continue_, for_range, return_, return_v, select, switch, + }; + pub use crate::lang::functions::{block_size, dispatch_id, dispatch_size, set_block_size}; pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; pub use crate::lang::types::vector::swizzle::*; pub use crate::lang::types::vector::Vector; pub use crate::lang::types::{AsExpr, Expr, Value, Var}; pub use crate::lang::Aggregate; - // pub use crate::resource::{IoTexel, StorageTexel, *}; - // pub use crate::runtime::api::StreamTag; - // pub use crate::runtime::{ - // create_static_callable, Command, Device, KernelBuildOptions, Scope, - // Stream, }; - pub use crate::{cpu_dbg, lc_assert, lc_unreachable, struct_}; + pub use crate::resource::{IoTexel, StorageTexel, *}; + pub use crate::runtime::api::StreamTag; + pub use crate::runtime::{ + create_static_callable, Command, Device, KernelBuildOptions, Scope, Stream, + }; + pub use crate::{cpu_dbg, if_, lc_assert, lc_unreachable, loop_, struct_, while_, Context}; pub use luisa_compute_derive::*; pub use luisa_compute_track::track; @@ -51,11 +50,11 @@ mod internal_prelude { RECORDER, }; pub(crate) use crate::prelude::*; - // pub(crate) use crate::runtime::{ - // CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, - // }; + pub(crate) use crate::runtime::{ + CallableArgEncoder, CallableParameter, CallableRet, KernelBuilder, + }; pub(crate) use crate::{ - get_backtrace, impl_simple_expr_proxy, impl_simple_var_proxy, /* ResourceTracker, */ + get_backtrace, impl_simple_expr_proxy, impl_simple_var_proxy, ResourceTracker, }; pub(crate) use luisa_compute_backend::Backend; pub(crate) use std::marker::PhantomData; @@ -70,11 +69,10 @@ use lazy_static::lazy_static; use luisa_compute_backend::Backend; use parking_lot::lock_api::RawMutex as RawMutexTrait; use parking_lot::{Mutex, RawMutex}; -// use runtime::{Device, DeviceHandle, StreamHandle}; +use runtime::{Device, DeviceHandle, StreamHandle}; use std::collections::HashMap; use std::sync::Weak; -/* pub struct Context { inner: Arc, } From 5fb6137422b123091139a3afdd5ed039cf93ec44 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 20:06:42 +0100 Subject: [PATCH 12/15] Minor fix. --- luisa_compute/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index ece3d92..fd752ec 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -161,7 +161,6 @@ unsafe impl Send for ResourceTracker {} unsafe impl Sync for ResourceTracker {} - */ pub(crate) fn get_backtrace() -> Backtrace { Backtrace::force_capture() } From d5febd889bf05488b879f03aff746b7011de7eb5 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 21:20:49 +0100 Subject: [PATCH 13/15] Updated most things to fix except rtx. --- luisa_compute/src/lang/debug.rs | 2 +- luisa_compute/src/lang/diff.rs | 29 +-- luisa_compute/src/lang/functions.rs | 46 ++-- luisa_compute/src/lang/types/core.rs | 162 +++++++------- luisa_compute/src/lang/types/vector.rs | 129 +++++------ luisa_compute/src/lang/types/vector/impls.rs | 47 ++++ luisa_compute/src/lang/types/vector/legacy.rs | 208 ++++++++++++++++++ luisa_compute/src/lib.rs | 5 +- luisa_compute/src/printer.rs | 11 +- luisa_compute/src/resource.rs | 10 +- luisa_compute/src/runtime/kernel.rs | 10 +- 11 files changed, 458 insertions(+), 201 deletions(-) create mode 100644 luisa_compute/src/lang/types/vector/impls.rs create mode 100644 luisa_compute/src/lang/types/vector/legacy.rs diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs index 7a31e96..36dbd5e 100644 --- a/luisa_compute/src/lang/debug.rs +++ b/luisa_compute/src/lang/debug.rs @@ -111,7 +111,7 @@ pub fn __cpu_dbg(arg: Expr, file: &'static str, line: u32) if !is_cpu_backend() { return; } - let f = CpuFn::new(move |x: &mut T::Value| { + let f = CpuFn::new(move |x: &mut V| { println!("[{}:{}] {:?}", file, line, x); }); let _ = f.call(arg); diff --git a/luisa_compute/src/lang/diff.rs b/luisa_compute/src/lang/diff.rs index 55278f1..ca454d9 100644 --- a/luisa_compute/src/lang/diff.rs +++ b/luisa_compute/src/lang/diff.rs @@ -34,7 +34,7 @@ impl AdContext { thread_local! { static AD_CONTEXT:RefCell = RefCell::new(AdContext::new_rev()); } -pub fn requires_grad(var: impl ExprProxy) { +pub fn requires_grad(var: Expr) { AD_CONTEXT.with(|c| { let c = c.borrow(); assert!(c.started, "autodiff section is not started"); @@ -49,15 +49,15 @@ pub fn requires_grad(var: impl ExprProxy) { }); } -pub fn backward(out: T) { +pub fn backward(out: Expr) { backward_with_grad( out, FromNode::from_node(__current_scope(|b| { let one = new_node( b.pools(), Node::new( - CArc::new(Instruction::Const(Const::One(::type_()))), - ::type_(), + CArc::new(Instruction::Const(Const::One(V::type_()))), + V::type_(), ), ); b.append(one); @@ -66,7 +66,7 @@ pub fn backward(out: T) { ); } -pub fn backward_with_grad(out: T, grad: T) { +pub fn backward_with_grad(out: Expr, grad: Expr) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); assert!(c.started, "autodiff section is not started"); @@ -83,19 +83,19 @@ pub fn backward_with_grad(out: T, grad: T) { } /// Gradient of a value in *Reverse mode* AD -pub fn gradient(var: T) -> T { +pub fn gradient(var: Expr) -> Expr { AD_CONTEXT.with(|c| { let c = c.borrow(); assert!(c.started, "autodiff section is not started"); assert!(!c.is_forward_mode, "gradient() is called in forward mode"); assert!(c.backward_called, "backward is not called"); }); - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call(Func::Gradient, &[var.node()], var.node().type_().clone()) })) } /// Gradient of a value in *Reverse mode* AD -pub fn grad(var: T) -> T { +pub fn grad(var: Expr) -> Expr { gradient(var) } @@ -108,8 +108,8 @@ pub fn grad(var: T) -> T { // let ret = body(); // let fwd = pop_scope(); // __current_scope(|b| { -// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), Type::void())); -// b.append(node); +// let node = new_node(Node::new(CArc::new(Instruction::AdDetach(fwd)), +// Type::void())); b.append(node); // }); // let nodes = ret.to_vec_nodes(); // let nodes: Vec<_> = nodes @@ -124,7 +124,8 @@ pub fn detach(v: T) -> T { T::from_node(node) } -/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input variable +/// Start a *Forward mode* AD section that propagates N gradients w.r.t to input +/// variable pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { AD_CONTEXT.with(|c| { let mut c = c.borrow_mut(); @@ -152,7 +153,7 @@ pub fn forward_autodiff(n_grads: usize, body: impl Fn()) { } /// Propagate N gradients w.r.t to input variable using *Forward mode* AD -pub fn propagate_gradient(v: T, grads: &[T]) { +pub fn propagate_gradient(v: Expr, grads: &[Expr]) { AD_CONTEXT.with(|c| { let c = c.borrow(); assert_eq!(grads.len(), c.n_forward_grads); @@ -169,7 +170,7 @@ pub fn propagate_gradient(v: T, grads: &[T]) { }); } -pub fn output_gradients(v: T) -> Vec { +pub fn output_gradients(v: Expr) -> Vec> { let n = AD_CONTEXT.with(|c| { let c = c.borrow(); assert!(c.started, "autodiff section is not started"); @@ -183,7 +184,7 @@ pub fn output_gradients(v: T) -> Vec { let mut grads = vec![]; for i in 0..n { let idx = b.const_(Const::Int32(i as i32)); - grads.push(T::from_node(b.call( + grads.push(Expr::::from_node(b.call( Func::OutputGrad, &[v.node(), idx], v.node().type_().clone(), diff --git a/luisa_compute/src/lang/functions.rs b/luisa_compute/src/lang/functions.rs index 23be6fa..4094d01 100644 --- a/luisa_compute/src/lang/functions.rs +++ b/luisa_compute/src/lang/functions.rs @@ -1,5 +1,7 @@ use crate::internal_prelude::*; +use super::types::core::{Integral, Numeric}; + pub fn thread_id() -> Expr { Expr::::from_node(__current_scope(|b| { b.call(Func::ThreadId, &[], Uint3::type_()) @@ -54,7 +56,7 @@ pub fn warp_is_first_active_lane() -> Expr { b.call(Func::WarpIsFirstActiveLane, &[], Expr::::type_()) })) } -pub fn warp_active_all_equal(v: impl ScalarOrVector) -> Expr { +pub fn warp_active_all_equal(v: Expr) -> Expr { Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveAllEqual, @@ -63,7 +65,10 @@ pub fn warp_active_all_equal(v: impl ScalarOrVector) -> Expr { ) })) } -pub fn warp_active_bit_and, E: IntVarTrait>(v: T) -> T { +pub fn warp_active_bit_and(v: Expr) -> Expr +where + T::Scalar: Integral + Numeric, +{ T::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitAnd, @@ -73,7 +78,10 @@ pub fn warp_active_bit_and, E: IntVarTrait>(v: T) })) } -pub fn warp_active_bit_or, E: IntVarTrait>(v: T) -> T { +pub fn warp_active_bit_or(v: Expr) -> Expr +where + T::Scalar: Integral + Numeric, +{ T::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitOr, @@ -83,7 +91,10 @@ pub fn warp_active_bit_or, E: IntVarTrait>(v: T) })) } -pub fn warp_active_bit_xor, E: IntVarTrait>(v: T) -> T { +pub fn warp_active_bit_xor(v: Expr) -> Expr +where + T::Scalar: Integral + Numeric, +{ T::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitXor, @@ -93,26 +104,26 @@ pub fn warp_active_bit_xor, E: IntVarTrait>(v: T) })) } -pub fn warp_active_count_bits(v: impl Into>) -> Expr { +pub fn warp_active_count_bits(v: impl AsExpr) -> Expr { Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveCountBits, - &[v.into().node()], + &[v.as_expr().node()], ::type_(), ) })) } -pub fn warp_active_max(v: T) -> T::Element { +pub fn warp_active_max(v: Expr) -> Expr { ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveMax, &[v.node()], ::type_()) + b.call(Func::WarpActiveMax, &[v.node()], ::type_()) })) } -pub fn warp_active_min(v: T) -> T::Element { +pub fn warp_active_min(v: Expr) -> Expr { ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveMin, &[v.node()], ::type_()) + b.call(Func::WarpActiveMin, &[v.node()], ::type_()) })) } -pub fn warp_active_product(v: T) -> T::Element { +pub fn warp_active_product(v: Expr) -> Expr { ::from_node(__current_scope(|b| { b.call( Func::WarpActiveProduct, @@ -121,9 +132,9 @@ pub fn warp_active_product(v: T) -> T::Element { ) })) } -pub fn warp_active_sum(v: T) -> T::Element { +pub fn warp_active_sum(v: Expr) -> Expr { ::from_node(__current_scope(|b| { - b.call(Func::WarpActiveSum, &[v.node()], ::type_()) + b.call(Func::WarpActiveSum, &[v.node()], ::type_()) })) } pub fn warp_active_all(v: Expr) -> Expr { @@ -150,12 +161,12 @@ pub fn warp_prefix_count_bits(v: Expr) -> Expr { ) })) } -pub fn warp_prefix_sum_exclusive(v: T) -> T { +pub fn warp_prefix_sum_exclusive(v: Expr) -> Expr { T::from_node(__current_scope(|b| { b.call(Func::WarpPrefixSum, &[v.node()], v.node().type_().clone()) })) } -pub fn warp_prefix_product_exclusive(v: T) -> T { +pub fn warp_prefix_product_exclusive(v: Expr) -> Expr { T::from_node(__current_scope(|b| { b.call( Func::WarpPrefixProduct, @@ -164,7 +175,8 @@ pub fn warp_prefix_product_exclusive(v: T) -> T { ) })) } -pub fn warp_read_lane_at(v: T, index: impl Into>) -> T { +// TODO: Difference between `Linear` and BuiltinVarTrait? +pub fn warp_read_lane_at(v: T, index: impl AsExpr) -> T { let index = index.into(); T::from_node(__current_scope(|b| { b.call( @@ -174,7 +186,7 @@ pub fn warp_read_lane_at(v: T, index: impl Into>) ) })) } -pub fn warp_read_first_active_lane(v: T) -> T { +pub fn warp_read_first_active_lane(v: T) -> T { T::from_node(__current_scope(|b| { b.call( Func::WarpReadFirstLane, diff --git a/luisa_compute/src/lang/types/core.rs b/luisa_compute/src/lang/types/core.rs index 4c968a0..6fd624a 100644 --- a/luisa_compute/src/lang/types/core.rs +++ b/luisa_compute/src/lang/types/core.rs @@ -154,82 +154,86 @@ impls!(Floating for f16, f32, f64); pub trait Signed: Numeric {} impls!(Signed for f16, f32, f64, i8, i16, i32, i64); -#[deprecated] -pub type Bool = Expr; -#[deprecated] -pub type F16 = Expr; -#[deprecated] -pub type F32 = Expr; -#[deprecated] -pub type F64 = Expr; -#[deprecated] -pub type I16 = Expr; -#[deprecated] -pub type I32 = Expr; -#[deprecated] -pub type I64 = Expr; -#[deprecated] -pub type U16 = Expr; -#[deprecated] -pub type U32 = Expr; -#[deprecated] -pub type U64 = Expr; - -#[deprecated] -pub type F16Var = Var; -#[deprecated] -pub type F32Var = Var; -#[deprecated] -pub type F64Var = Var; -#[deprecated] -pub type I16Var = Var; -#[deprecated] -pub type I32Var = Var; -#[deprecated] -pub type I64Var = Var; -#[deprecated] -pub type U16Var = Var; -#[deprecated] -pub type U32Var = Var; -#[deprecated] -pub type U64Var = Var; - -#[deprecated] -pub type Half = Expr; -#[deprecated] -pub type Float = Expr; -#[deprecated] -pub type Double = Expr; -#[deprecated] -pub type Int = Expr; -#[deprecated] -pub type Long = Expr; -#[deprecated] -pub type Uint = Expr; -#[deprecated] -pub type Ulong = Expr; -#[deprecated] -pub type Short = Expr; -#[deprecated] -pub type Ushort = Expr; - -#[deprecated] -pub type BoolVar = Var; -#[deprecated] -pub type HalfVar = Var; -#[deprecated] -pub type FloatVar = Var; -#[deprecated] -pub type DoubleVar = Var; -#[deprecated] -pub type IntVar = Var; -#[deprecated] -pub type LongVar = Var; -#[deprecated] -pub type UintVar = Var; -#[deprecated] -pub type UlongVar = Var; -#[deprecated] -pub type ShortVar = Var; -#[deprecated] -pub type UshortVar = Var; +mod legacy { + use super::*; + + #[deprecated] + pub type Bool = Expr; + #[deprecated] + pub type F16 = Expr; + #[deprecated] + pub type F32 = Expr; + #[deprecated] + pub type F64 = Expr; + #[deprecated] + pub type I16 = Expr; + #[deprecated] + pub type I32 = Expr; + #[deprecated] + pub type I64 = Expr; + #[deprecated] + pub type U16 = Expr; + #[deprecated] + pub type U32 = Expr; + #[deprecated] + pub type U64 = Expr; + + #[deprecated] + pub type F16Var = Var; + #[deprecated] + pub type F32Var = Var; + #[deprecated] + pub type F64Var = Var; + #[deprecated] + pub type I16Var = Var; + #[deprecated] + pub type I32Var = Var; + #[deprecated] + pub type I64Var = Var; + #[deprecated] + pub type U16Var = Var; + #[deprecated] + pub type U32Var = Var; + #[deprecated] + pub type U64Var = Var; + + #[deprecated] + pub type Half = Expr; + #[deprecated] + pub type Float = Expr; + #[deprecated] + pub type Double = Expr; + #[deprecated] + pub type Int = Expr; + #[deprecated] + pub type Long = Expr; + #[deprecated] + pub type Uint = Expr; + #[deprecated] + pub type Ulong = Expr; + #[deprecated] + pub type Short = Expr; + #[deprecated] + pub type Ushort = Expr; + + #[deprecated] + pub type BoolVar = Var; + #[deprecated] + pub type HalfVar = Var; + #[deprecated] + pub type FloatVar = Var; + #[deprecated] + pub type DoubleVar = Var; + #[deprecated] + pub type IntVar = Var; + #[deprecated] + pub type LongVar = Var; + #[deprecated] + pub type UintVar = Var; + #[deprecated] + pub type UlongVar = Var; + #[deprecated] + pub type ShortVar = Var; + #[deprecated] + pub type UshortVar = Var; +} diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index e8124ca..a0c1611 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -11,6 +11,8 @@ mod nalgebra; pub mod coords; mod element; +mod impls; +pub mod legacy; pub mod swizzle; use swizzle::*; @@ -36,52 +38,7 @@ impl, const N: usize> Debug for Vector { #[derive(Copy, Clone, PartialEq, Eq)] pub struct Vector, const N: usize> { _align: T::A, - elements: [T; N], -} - -impl, const N: usize> Vector { - pub fn from_elements(elements: [T; N]) -> Self { - Self { - _align: T::A::default(), - elements, - } - } - pub fn splat(element: T) -> Self { - Self { - _align: T::A::default(), - elements: [element; N], - } - } - pub fn splat_expr(element: impl AsExpr) -> Expr { - Func::Vec.call(element.as_expr()) - } - fn _permute2(&self, x: u32, y: u32) -> Vector - where - T: VectorAlign<2>, - { - Vector::from_elements([self.elements[x as usize], self.elements[y as usize]]) - } - fn _permute3(&self, x: u32, y: u32, z: u32) -> Vector - where - T: VectorAlign<3>, - { - Vector::from_elements([ - self.elements[x as usize], - self.elements[y as usize], - self.elements[z as usize], - ]) - } - fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Vector - where - T: VectorAlign<4>, - { - Vector::from_elements([ - self.elements[x as usize], - self.elements[y as usize], - self.elements[z as usize], - self.elements[w as usize], - ]) - } + pub elements: [T; N], } #[repr(C)] @@ -174,10 +131,40 @@ impl, const N: usize> Value for Vector { type VarData = T::VectorVarData; } -impl Vec2Swizzle for Vector { +impl, const N: usize> Vector { + fn _permute2(&self, x: u32, y: u32) -> Vec2 + where + T: VectorAlign<2>, + { + Vector::from_elements([self.elements[x as usize], self.elements[y as usize]]) + } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Vec3 + where + T: VectorAlign<3>, + { + Vector::from_elements([ + self.elements[x as usize], + self.elements[y as usize], + self.elements[z as usize], + ]) + } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Vec4 + where + T: VectorAlign<4>, + { + Vector::from_elements([ + self.elements[x as usize], + self.elements[y as usize], + self.elements[z as usize], + self.elements[w as usize], + ]) + } +} + +impl Vec2Swizzle for Vec2 { type Vec2 = Self; - type Vec3 = Vector; - type Vec4 = Vector; + type Vec3 = Vec3; + type Vec4 = Vec4; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } @@ -188,10 +175,10 @@ impl Vec2Swizzle for Vector { self._permute4(x, y, z, w) } } -impl Vec3Swizzle for Vector { - type Vec2 = Vector; +impl Vec3Swizzle for Vec3 { + type Vec2 = Vec2; type Vec3 = Self; - type Vec4 = Vector; + type Vec4 = Vec4; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } @@ -202,9 +189,9 @@ impl Vec3Swizzle for Vector { self._permute4(x, y, z, w) } } -impl Vec4Swizzle for Vector { - type Vec2 = Vector; - type Vec3 = Vector; +impl Vec4Swizzle for Vec4 { + type Vec2 = Vec2; + type Vec3 = Vec3; type Vec4 = Self; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) @@ -216,9 +203,9 @@ impl Vec4Swizzle for Vector { self._permute4(x, y, z, w) } } -/* + impl, const N: usize> VectorExprData { - fn _permute2(&self, x: u32, y: u32) -> Expr> + fn _permute2(&self, x: u32, y: u32) -> Expr> where T: VectorAlign<2>, { @@ -226,7 +213,7 @@ impl, const N: usize> VectorExprData { assert!(y < N as u32); let x = x.expr(); let y = y.expr(); - Expr::>::from_node(__current_scope(|s| { + Expr::>::from_node(__current_scope(|s| { s.call( Func::Permute, &[self.node, ToNode::node(&x), ToNode::node(&y)], @@ -234,7 +221,7 @@ impl, const N: usize> VectorExprData { ) })) } - fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> + fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> where T: VectorAlign<3>, { @@ -244,7 +231,7 @@ impl, const N: usize> VectorExprData { let x = x.expr(); let y = y.expr(); let z = z.expr(); - Expr::>::from_node(__current_scope(|s| { + Expr::>::from_node(__current_scope(|s| { s.call( Func::Permute, &[ @@ -257,7 +244,7 @@ impl, const N: usize> VectorExprData { ) })) } - fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> where T: VectorAlign<4>, { @@ -269,7 +256,7 @@ impl, const N: usize> VectorExprData { let y = y.expr(); let z = z.expr(); let w = w.expr(); - Expr::>::from_node(__current_scope(|s| { + Expr::>::from_node(__current_scope(|s| { s.call( Func::Permute, &[ @@ -287,8 +274,8 @@ impl, const N: usize> VectorExprData { impl Vec2Swizzle for VectorExprProxy2 { type Vec2 = Self; - type Vec3 = Expr>; - type Vec4 = Expr>; + type Vec3 = Expr>; + type Vec4 = Expr>; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } @@ -299,11 +286,10 @@ impl Vec2Swizzle for VectorExprProxy2 { self._permute4(x, y, z, w) } } - impl Vec3Swizzle for VectorExprProxy3 { - type Vec2 = Expr>; + type Vec2 = Expr>; type Vec3 = Self; - type Vec4 = Expr>; + type Vec4 = Expr>; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } @@ -315,8 +301,8 @@ impl Vec3Swizzle for VectorExprProxy3 { } } impl Vec4Swizzle for VectorExprProxy4 { - type Vec2 = Expr>; - type Vec3 = Expr>; + type Vec2 = Expr>; + type Vec3 = Expr>; type Vec4 = Self; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) @@ -328,4 +314,7 @@ impl Vec4Swizzle for VectorExprProxy4 { self._permute4(x, y, z, w) } } - */ + +pub type Vec2> = Vector; +pub type Vec3> = Vector; +pub type Vec4> = Vector; diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs new file mode 100644 index 0000000..d679e74 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -0,0 +1,47 @@ +use super::*; + +impl, const N: usize> Vector { + pub fn from_elements(elements: [T; N]) -> Self { + Self { + _align: T::A::default(), + elements, + } + } + pub fn splat(element: T) -> Self { + Self { + _align: T::A::default(), + elements: [element; N], + } + } + pub fn splat_expr(element: impl AsExpr) -> Expr { + Func::Vec.call(element.as_expr()) + } + pub fn map(&self, f: impl Fn(T) -> T) -> Self { + Self { + _align: T::A::default(), + elements: self.elements.map(f), + } + } + pub fn expr_from_elements(elements: [Expr; N]) -> Expr { + Expr::::from_node(__compose(elements.map(ToNode::node))) + } +} + +macro_rules! impl_sized { + ($Vn:ident($N: literal): $($xs:ident),+) => { + impl> $Vn { + pub fn new($($xs: T),+) -> Self { + Self { + _align: T::A::default(), + elements: [$($xs),+], + } + } + pub fn expr($($xs: impl AsExpr),+) -> Expr { + Self::expr_from_elements([$($xs.as_expr()),+]) + } + } + } +} +impl_sized!(Vec2(2): x, y); +impl_sized!(Vec3(3): x, y, z); +impl_sized!(Vec4(4): x, y, z, w); diff --git a/luisa_compute/src/lang/types/vector/legacy.rs b/luisa_compute/src/lang/types/vector/legacy.rs new file mode 100644 index 0000000..bd7c9e3 --- /dev/null +++ b/luisa_compute/src/lang/types/vector/legacy.rs @@ -0,0 +1,208 @@ +use super::*; + +#[deprecated] +pub type Half2 = Vec2; +#[deprecated] +pub type Half3 = Vec3; +#[deprecated] +pub type Half4 = Vec4; +#[deprecated] +pub type Float2 = Vec2; +#[deprecated] +pub type Float3 = Vec3; +#[deprecated] +pub type Float4 = Vec4; +#[deprecated] +pub type Double2 = Vec2; +#[deprecated] +pub type Double3 = Vec3; +#[deprecated] +pub type Double4 = Vec4; +#[deprecated] +pub type Byte2 = Vec2; +#[deprecated] +pub type Byte3 = Vec3; +#[deprecated] +pub type Byte4 = Vec4; +#[deprecated] +pub type Short2 = Vec2; +#[deprecated] +pub type Short3 = Vec3; +#[deprecated] +pub type Short4 = Vec4; +#[deprecated] +pub type Int2 = Vec2; +#[deprecated] +pub type Int3 = Vec3; +#[deprecated] +pub type Int4 = Vec4; +#[deprecated] +pub type Long2 = Vec2; +#[deprecated] +pub type Long3 = Vec3; +#[deprecated] +pub type Long4 = Vec4; +#[deprecated] +pub type Ubyte2 = Vec2; +#[deprecated] +pub type Ubyte3 = Vec3; +#[deprecated] +pub type Ubyte4 = Vec4; +#[deprecated] +pub type Ushort2 = Vec2; +#[deprecated] +pub type Ushort3 = Vec3; +#[deprecated] +pub type Ushort4 = Vec4; +#[deprecated] +pub type Uint2 = Vec2; +#[deprecated] +pub type Uint3 = Vec3; +#[deprecated] +pub type Uint4 = Vec4; +#[deprecated] +pub type Ulong2 = Vec2; +#[deprecated] +pub type Bool2 = Vec2; +#[deprecated] +pub type Bool3 = Vec3; +#[deprecated] +pub type Bool4 = Vec4; + +#[deprecated] +pub type Half2Expr = Expr>; +#[deprecated] +pub type Half3Expr = Expr>; +#[deprecated] +pub type Half4Expr = Expr>; +#[deprecated] +pub type Float2Expr = Expr>; +#[deprecated] +pub type Float3Expr = Expr>; +#[deprecated] +pub type Float4Expr = Expr>; +#[deprecated] +pub type Double2Expr = Expr>; +#[deprecated] +pub type Double3Expr = Expr>; +#[deprecated] +pub type Double4Expr = Expr>; +#[deprecated] +pub type Byte2Expr = Expr>; +#[deprecated] +pub type Byte3Expr = Expr>; +#[deprecated] +pub type Byte4Expr = Expr>; +#[deprecated] +pub type Short2Expr = Expr>; +#[deprecated] +pub type Short3Expr = Expr>; +#[deprecated] +pub type Short4Expr = Expr>; +#[deprecated] +pub type Int2Expr = Expr>; +#[deprecated] +pub type Int3Expr = Expr>; +#[deprecated] +pub type Int4Expr = Expr>; +#[deprecated] +pub type Long2Expr = Expr>; +#[deprecated] +pub type Long3Expr = Expr>; +#[deprecated] +pub type Long4Expr = Expr>; +#[deprecated] +pub type Ubyte2Expr = Expr>; +#[deprecated] +pub type Ubyte3Expr = Expr>; +#[deprecated] +pub type Ubyte4Expr = Expr>; +#[deprecated] +pub type Ushort2Expr = Expr>; +#[deprecated] +pub type Ushort3Expr = Expr>; +#[deprecated] +pub type Ushort4Expr = Expr>; +#[deprecated] +pub type Uint2Expr = Expr>; +#[deprecated] +pub type Uint3Expr = Expr>; +#[deprecated] +pub type Uint4Expr = Expr>; +#[deprecated] +pub type Ulong2Expr = Expr>; +#[deprecated] +pub type Bool2Expr = Expr>; +#[deprecated] +pub type Bool3Expr = Expr>; +#[deprecated] +pub type Bool4Expr = Expr>; + +#[deprecated] +pub type Half2Var = Var>; +#[deprecated] +pub type Half3Var = Var>; +#[deprecated] +pub type Half4Var = Var>; +#[deprecated] +pub type Float2Var = Var>; +#[deprecated] +pub type Float3Var = Var>; +#[deprecated] +pub type Float4Var = Var>; +#[deprecated] +pub type Double2Var = Var>; +#[deprecated] +pub type Double3Var = Var>; +#[deprecated] +pub type Double4Var = Var>; +#[deprecated] +pub type Byte2Var = Var>; +#[deprecated] +pub type Byte3Var = Var>; +#[deprecated] +pub type Byte4Var = Var>; +#[deprecated] +pub type Short2Var = Var>; +#[deprecated] +pub type Short3Var = Var>; +#[deprecated] +pub type Short4Var = Var>; +#[deprecated] +pub type Int2Var = Var>; +#[deprecated] +pub type Int3Var = Var>; +#[deprecated] +pub type Int4Var = Var>; +#[deprecated] +pub type Long2Var = Var>; +#[deprecated] +pub type Long3Var = Var>; +#[deprecated] +pub type Long4Var = Var>; +#[deprecated] +pub type Ubyte2Var = Var>; +#[deprecated] +pub type Ubyte3Var = Var>; +#[deprecated] +pub type Ubyte4Var = Var>; +#[deprecated] +pub type Ushort2Var = Var>; +#[deprecated] +pub type Ushort3Var = Var>; +#[deprecated] +pub type Ushort4Var = Var>; +#[deprecated] +pub type Uint2Var = Var>; +#[deprecated] +pub type Uint3Var = Var>; +#[deprecated] +pub type Uint4Var = Var>; +#[deprecated] +pub type Ulong2Var = Var>; +#[deprecated] +pub type Bool2Var = Var>; +#[deprecated] +pub type Bool3Var = Var>; +#[deprecated] +pub type Bool4Var = Var>; diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index fd752ec..28e360b 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -23,7 +23,7 @@ pub mod prelude { pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; pub use crate::lang::types::vector::swizzle::*; - pub use crate::lang::types::vector::Vector; + pub use crate::lang::types::vector::{Vec2, Vec3, Vec4, Vector}; pub use crate::lang::types::{AsExpr, Expr, Value, Var}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; @@ -38,12 +38,13 @@ pub mod prelude { } mod internal_prelude { - pub(crate) use crate::lang::debug::{__env_need_backtrace, is_cpu_backend}; + pub(crate) use crate::lang::debug::{__env_need_backtrace, is_cpu_backend, CpuFn}; pub(crate) use crate::lang::ir::ffi::*; pub(crate) use crate::lang::ir::{ new_node, register_type, BasicBlock, Const, Func, Instruction, IrBuilder, Node, PhiIncoming, Pooled, Type, TypeOf, INVALID_REF, }; + pub(crate) use crate::lang::types::vector::legacy::*; pub(crate) use crate::lang::{ ir, CallFuncTrait, Recorder, __compose, __extract, __insert, __module_pools, need_runtime_check, FromNode, NodeLike, NodeRef, ToNode, __current_scope, __pop_scope, diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index 61cdc76..d696ca4 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -33,11 +33,8 @@ pub struct PrinterArgs { count: usize, } impl PrinterArgs { - pub fn append(&mut self, v: E) - where - E::Value: Debug, - { - let n = packed_size::(); + pub fn append(&mut self, v: Expr) { + let n = packed_size::(); self.count_per_arg.push(n); self.pack_fn.push(Box::new(move |offset, data| { pack_to(v, data, offset); @@ -104,8 +101,8 @@ macro_rules! lc_error { $crate::lc_log!($printer, log::Level::Error, $fmt, $($arg)*); }; } -pub fn _unpack_from_expr(data: *const u32, _: E) -> E::Value { - unsafe { std::ptr::read_unaligned(data as *const E::Value) } +pub fn _unpack_from_expr(data: *const u32, _: Expr) -> V { + unsafe { std::ptr::read_unaligned(data as *const V) } } impl Printer { pub fn new(device: &Device, size: usize) -> Self { diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 63848be..2451746 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -1003,8 +1003,9 @@ impl_storage_texel!(Half4, Half4, f32, Float2, Float4,); impl_storage_texel!([f16; 2], Byte2, f32, Float2, Float4, Int2, Int4, Uint2, Uint4,); impl_storage_texel!([f16; 4], Byte4, f32, Float2, Float4, Int2, Int4, Uint2, Uint4,); -// `T` is the read out type of the texture, which is not necessarily the same as the storage type -// In fact, the texture can be stored in any format as long as it can be converted to `T` +// `T` is the read out type of the texture, which is not necessarily the same as +// the storage type In fact, the texture can be stored in any format as long as +// it can be converted to `T` pub struct Tex2d { #[allow(dead_code)] pub(crate) width: u32, @@ -1014,8 +1015,9 @@ pub struct Tex2d { pub(crate) marker: PhantomData, } -// `T` is the read out type of the texture, which is not necessarily the same as the storage type -// In fact, the texture can be stored in any format as long as it can be converted to `T` +// `T` is the read out type of the texture, which is not necessarily the same as +// the storage type In fact, the texture can be stored in any format as long as +// it can be converted to `T` pub struct Tex3d { #[allow(dead_code)] pub(crate) width: u32, diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 806e870..2963cee 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -118,11 +118,7 @@ pub trait KernelParameter { fn def_param(builder: &mut KernelBuilder) -> Self; } -impl KernelParameter for U -where - U: ExprProxy, - T: Value, -{ +impl KernelParameter for Expr { fn def_param(builder: &mut KernelBuilder) -> Self { builder.uniform::() } @@ -507,12 +503,12 @@ unsafe impl CallableRet for () { fn _from_return(_: NodeRef) -> Self {} } -unsafe impl CallableRet for T { +unsafe impl CallableRet for Expr { fn _return(&self) -> CArc { __current_scope(|b| { b.return_(self.node()); }); - T::Value::type_() + V::type_() } fn _from_return(node: NodeRef) -> Self { Self::from_node(node) From 3cb95a687d65b78192bfecd67679b7cf5360aba6 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 21:48:23 +0100 Subject: [PATCH 14/15] Added matrix. --- luisa_compute/src/lang/ops/traits.rs | 2 + luisa_compute/src/lang/types.rs | 20 ++++---- luisa_compute/src/lang/types/vector.rs | 65 ++++++++++++++++++++++++++ luisa_compute/src/lib.rs | 4 +- 4 files changed, 81 insertions(+), 10 deletions(-) diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index 08732a2..bf0dcf9 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -1,5 +1,7 @@ use super::*; +// The double trait implementation is necessary as the compiler infinite loops +// when trying to resolve the Expr: SpreadOps>> bound. macro_rules! ops_trait { ( $TraitExpr:ident<$($T:ident),*> [ $TraitThis:ident] { diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index 1f7447d..d110eda 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -211,12 +211,13 @@ pub fn _deref_proxy(proxy: &P) -> &Expr { #[macro_export] macro_rules! impl_simple_expr_proxy { - ([ $($bounds:tt)* ] $name: ident [ $($qualifiers:tt)* ] for $t: ty) => { + ($([ $($bounds:tt)* ])? $name: ident $([ $($qualifiers:tt)* ])? for $t: ty $(where $($where_bounds:tt)+)?) => { #[derive(Debug, Clone, Copy)] #[repr(transparent)] - pub struct $name < $($bounds)* > ($crate::lang::types::Expr<$t>); - unsafe impl < $($bounds)* > $crate::lang::types::HasExprLayout< <$t as $crate::lang::types::Value>::ExprData > for $name < $($qualifiers)* > {} - impl < $($bounds)* > $crate::lang::types::ExprProxy for $name < $($qualifiers)* > { + pub struct $name $(< $($bounds)* >)? ($crate::lang::types::Expr<$t>) $(where $($where_bounds)+)?; + unsafe impl $(< $($bounds)* >)? $crate::lang::types::HasExprLayout< <$t as $crate::lang::types::Value>::ExprData > + for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? {} + impl $(< $($bounds)* >)? $crate::lang::types::ExprProxy for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { type Value = $t; } } @@ -224,15 +225,16 @@ macro_rules! impl_simple_expr_proxy { #[macro_export] macro_rules! impl_simple_var_proxy { - ([ $($bounds:tt)* ] $name: ident [ $($qualifiers:tt)* ] for $t: ty) => { + ($([ $($bounds:tt)* ])? $name: ident $([ $($qualifiers:tt)* ])? for $t: ty $(where $($where_bounds:tt)+)?) => { #[derive(Debug, Clone, Copy)] #[repr(transparent)] - pub struct $name < $($bounds)* > ($crate::lang::types::Var<$t>); - unsafe impl < $($bounds)* > $crate::lang::types::HasVarLayout< <$t as $crate::lang::types::Value>::VarData > for $name < $($qualifiers)* > {} - impl < $($bounds)* > $crate::lang::types::VarProxy for $name < $($qualifiers)* > { + pub struct $name $(< $($bounds)* >)? ($crate::lang::types::Var<$t>) $(where $($where_bounds)+)?; + unsafe impl $(< $($bounds)* >)? $crate::lang::types::HasVarLayout< <$t as $crate::lang::types::Value>::VarData > + for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? {} + impl $(< $($bounds)* >)? $crate::lang::types::VarProxy for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { type Value = $t; } - impl < $($bounds)* > std::ops::Deref for $name < $($qualifiers)* > { + impl $(< $($bounds)* >)? std::ops::Deref for $name $(< $($qualifiers)* >)? $(where $($where_bounds)+)? { type Target = $crate::lang::types::Expr<$t>; fn deref(&self) -> &Self::Target { $crate::lang::types::_deref_proxy(self) diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index a0c1611..21c7dd3 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -318,3 +318,68 @@ impl Vec4Swizzle for VectorExprProxy4 { pub type Vec2> = Vector; pub type Vec3> = Vector; pub type Vec4> = Vector; + +// Matrix + +impl Debug for SquareMatrix +where + f32: VectorAlign, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.elements.fmt(f) + } +} + +#[repr(C)] +#[derive(Copy, Clone, PartialEq, Eq)] +pub struct SquareMatrix +where + f32: VectorAlign, +{ + pub elements: [Vector; N], +} + +impl TypeOf for SquareMatrix +where + f32: VectorAlign, +{ + fn type_() -> CArc { + let type_ = Type::Matrix(ir::MatrixType { + element: VectorElementType::Scalar(Primitive::Float32), + dimension: N, + }); + register_type(type_) + } +} + +impl_simple_expr_proxy!(SquareMatrixExpr2 for SquareMatrix<2>); +impl_simple_var_proxy!(SquareMatrixVar2 for SquareMatrix<2>); + +impl_simple_expr_proxy!(SquareMatrixExpr3 for SquareMatrix<3>); +impl_simple_var_proxy!(SquareMatrixVar3 for SquareMatrix<3>); + +impl_simple_expr_proxy!(SquareMatrixExpr4 for SquareMatrix<4>); +impl_simple_var_proxy!(SquareMatrixVar4 for SquareMatrix<4>); + +impl Value for SquareMatrix<2> { + type Expr = SquareMatrixExpr2; + type Var = SquareMatrixVar2; + type ExprData = (); + type VarData = (); +} +impl Value for SquareMatrix<3> { + type Expr = SquareMatrixExpr3; + type Var = SquareMatrixVar3; + type ExprData = (); + type VarData = (); +} +impl Value for SquareMatrix<4> { + type Expr = SquareMatrixExpr4; + type Var = SquareMatrixVar4; + type ExprData = (); + type VarData = (); +} + +pub type Mat2 = SquareMatrix<2>; +pub type Mat3 = SquareMatrix<3>; +pub type Mat4 = SquareMatrix<4>; diff --git a/luisa_compute/src/lib.rs b/luisa_compute/src/lib.rs index 28e360b..238be8d 100644 --- a/luisa_compute/src/lib.rs +++ b/luisa_compute/src/lib.rs @@ -23,7 +23,9 @@ pub mod prelude { pub use crate::lang::index::{IndexRead, IndexWrite}; pub use crate::lang::ops::*; pub use crate::lang::types::vector::swizzle::*; - pub use crate::lang::types::vector::{Vec2, Vec3, Vec4, Vector}; + pub use crate::lang::types::vector::{ + Mat2, Mat3, Mat4, SquareMatrix, Vec2, Vec3, Vec4, Vector, + }; pub use crate::lang::types::{AsExpr, Expr, Value, Var}; pub use crate::lang::Aggregate; pub use crate::resource::{IoTexel, StorageTexel, *}; From dc5818d06ba235a198231853fd98bdae0d776b80 Mon Sep 17 00:00:00 2001 From: ReversedGravity Date: Wed, 20 Sep 2023 22:31:53 +0100 Subject: [PATCH 15/15] Compiled. --- luisa_compute/src/lang/debug.rs | 2 +- luisa_compute/src/lang/functions.rs | 36 ++--- luisa_compute/src/lang/index.rs | 2 +- luisa_compute/src/lang/ops/spread.rs | 2 +- luisa_compute/src/lang/ops/traits.rs | 4 +- luisa_compute/src/lang/types.rs | 12 +- luisa_compute/src/lang/types/dynamic.rs | 10 +- luisa_compute/src/lang/types/vector.rs | 146 +++++++++---------- luisa_compute/src/lang/types/vector/impls.rs | 2 +- luisa_compute/src/printer.rs | 5 +- luisa_compute/src/resource.rs | 106 ++++++++------ luisa_compute/src/runtime/kernel.rs | 26 +--- 12 files changed, 168 insertions(+), 185 deletions(-) diff --git a/luisa_compute/src/lang/debug.rs b/luisa_compute/src/lang/debug.rs index 36dbd5e..d2213e0 100644 --- a/luisa_compute/src/lang/debug.rs +++ b/luisa_compute/src/lang/debug.rs @@ -58,7 +58,7 @@ impl CpuFn { Expr::::from_node(__current_scope(|b| { b.call( Func::CpuCustomOp(self.op.clone()), - &[arg.node()], + &[arg.as_expr().node()], T::type_(), ) })) diff --git a/luisa_compute/src/lang/functions.rs b/luisa_compute/src/lang/functions.rs index 4094d01..a5d057d 100644 --- a/luisa_compute/src/lang/functions.rs +++ b/luisa_compute/src/lang/functions.rs @@ -53,7 +53,7 @@ pub fn sync_block() { pub fn warp_is_first_active_lane() -> Expr { Expr::::from_node(__current_scope(|b| { - b.call(Func::WarpIsFirstActiveLane, &[], Expr::::type_()) + b.call(Func::WarpIsFirstActiveLane, &[], bool::type_()) })) } pub fn warp_active_all_equal(v: Expr) -> Expr { @@ -69,7 +69,7 @@ pub fn warp_active_bit_and(v: Expr) -> Expr where T::Scalar: Integral + Numeric, { - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitAnd, &[v.node()], @@ -82,7 +82,7 @@ pub fn warp_active_bit_or(v: Expr) -> Expr where T::Scalar: Integral + Numeric, { - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitOr, &[v.node()], @@ -95,7 +95,7 @@ pub fn warp_active_bit_xor(v: Expr) -> Expr where T::Scalar: Integral + Numeric, { - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpActiveBitXor, &[v.node()], @@ -114,26 +114,22 @@ pub fn warp_active_count_bits(v: impl AsExpr) -> Expr { })) } pub fn warp_active_max(v: Expr) -> Expr { - ::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call(Func::WarpActiveMax, &[v.node()], ::type_()) })) } pub fn warp_active_min(v: Expr) -> Expr { - ::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call(Func::WarpActiveMin, &[v.node()], ::type_()) })) } pub fn warp_active_product(v: Expr) -> Expr { - ::from_node(__current_scope(|b| { - b.call( - Func::WarpActiveProduct, - &[v.node()], - ::type_(), - ) + Expr::::from_node(__current_scope(|b| { + b.call(Func::WarpActiveProduct, &[v.node()], ::type_()) })) } pub fn warp_active_sum(v: Expr) -> Expr { - ::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call(Func::WarpActiveSum, &[v.node()], ::type_()) })) } @@ -162,12 +158,12 @@ pub fn warp_prefix_count_bits(v: Expr) -> Expr { })) } pub fn warp_prefix_sum_exclusive(v: Expr) -> Expr { - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call(Func::WarpPrefixSum, &[v.node()], v.node().type_().clone()) })) } pub fn warp_prefix_product_exclusive(v: Expr) -> Expr { - T::from_node(__current_scope(|b| { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpPrefixProduct, &[v.node()], @@ -176,9 +172,9 @@ pub fn warp_prefix_product_exclusive(v: Expr) -> Expr { })) } // TODO: Difference between `Linear` and BuiltinVarTrait? -pub fn warp_read_lane_at(v: T, index: impl AsExpr) -> T { - let index = index.into(); - T::from_node(__current_scope(|b| { +pub fn warp_read_lane_at(v: Expr, index: impl AsExpr) -> Expr { + let index = index.as_expr(); + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpReadLaneAt, &[v.node(), index.node()], @@ -186,8 +182,8 @@ pub fn warp_read_lane_at(v: T, index: impl AsExpr) -> T ) })) } -pub fn warp_read_first_active_lane(v: T) -> T { - T::from_node(__current_scope(|b| { +pub fn warp_read_first_active_lane(v: Expr) -> Expr { + Expr::::from_node(__current_scope(|b| { b.call( Func::WarpReadFirstLane, &[v.node()], diff --git a/luisa_compute/src/lang/index.rs b/luisa_compute/src/lang/index.rs index 9c2fc5f..d83a0c2 100644 --- a/luisa_compute/src/lang/index.rs +++ b/luisa_compute/src/lang/index.rs @@ -40,5 +40,5 @@ pub trait IndexRead: ToNode { } pub trait IndexWrite: IndexRead { - fn write>>(&self, i: I, value: V); + fn write>(&self, i: I, value: V); } diff --git a/luisa_compute/src/lang/ops/spread.rs b/luisa_compute/src/lang/ops/spread.rs index ae08e57..2abab65 100644 --- a/luisa_compute/src/lang/ops/spread.rs +++ b/luisa_compute/src/lang/ops/spread.rs @@ -1,7 +1,7 @@ use super::*; use traits::*; -trait SpreadOps { +pub trait SpreadOps { type Join: Value; fn lift_self(x: Self) -> Expr; fn lift_other(x: Other) -> Expr; diff --git a/luisa_compute/src/lang/ops/traits.rs b/luisa_compute/src/lang/ops/traits.rs index bf0dcf9..a4561e1 100644 --- a/luisa_compute/src/lang/ops/traits.rs +++ b/luisa_compute/src/lang/ops/traits.rs @@ -10,7 +10,7 @@ macro_rules! ops_trait { )+ } ) => { - pub(crate) trait $TraitThis { + pub trait $TraitThis { $( fn $fn_this(self, $($arg: Self),*) -> Self; )* @@ -31,7 +31,7 @@ macro_rules! ops_trait { )+ } ) => { - pub(crate) trait $TraitThis { + pub trait $TraitThis { type Output; $( fn $fn_this(self, $($arg: Self),*) -> Self::Output; diff --git a/luisa_compute/src/lang/types.rs b/luisa_compute/src/lang/types.rs index d110eda..27b11af 100644 --- a/luisa_compute/src/lang/types.rs +++ b/luisa_compute/src/lang/types.rs @@ -259,6 +259,10 @@ mod private { impl Sealed for &Expr {} impl Sealed for Var {} impl Sealed for &Var {} + + impl Sealed for ValueType {} + impl Sealed for ExprType {} + impl Sealed for VarType {} } pub trait Tracked: private::Sealed { @@ -266,12 +270,12 @@ pub trait Tracked: private::Sealed { type Value: Value; } -trait TrackingType {} -struct ValueType; +pub trait TrackingType: private::Sealed {} +pub struct ValueType; impl TrackingType for ValueType {} -struct ExprType; +pub struct ExprType; impl TrackingType for ExprType {} -struct VarType; +pub struct VarType; impl TrackingType for VarType {} impl Tracked for T { diff --git a/luisa_compute/src/lang/types/dynamic.rs b/luisa_compute/src/lang/types/dynamic.rs index 73c49b6..595520d 100644 --- a/luisa_compute/src/lang/types/dynamic.rs +++ b/luisa_compute/src/lang/types/dynamic.rs @@ -9,14 +9,14 @@ pub struct DynExpr { node: NodeRef, } -impl From for DynExpr { - fn from(value: T) -> Self { +impl From> for DynExpr { + fn from(value: Expr) -> Self { Self { node: value.node() } } } -impl From for DynVar { - fn from(value: T) -> Self { +impl From> for DynVar { + fn from(value: Var) -> Self { Self { node: value.node() } } } @@ -62,7 +62,7 @@ impl DynExpr { ) }) } - pub fn new(expr: E) -> Self { + pub fn new(expr: Expr) -> Self { Self { node: expr.node() } } } diff --git a/luisa_compute/src/lang/types/vector.rs b/luisa_compute/src/lang/types/vector.rs index 21c7dd3..1381f18 100644 --- a/luisa_compute/src/lang/types/vector.rs +++ b/luisa_compute/src/lang/types/vector.rs @@ -71,7 +71,63 @@ impl FromNode for DoubledProxyData { } pub trait VectorExprProxy { + const N: usize; type T: Primitive; + fn node(&self) -> NodeRef; + fn _permute2(&self, x: u32, y: u32) -> Expr> + where + Self::T: VectorAlign<2>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node()], + Vec2::::type_(), + ) + })) + } + fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> + where + Self::T: VectorAlign<3>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + assert!(z < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node(), z.node()], + Vec3::::type_(), + ) + })) + } + fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> + where + Self::T: VectorAlign<4>, + { + assert!(x < Self::N as u32); + assert!(y < Self::N as u32); + assert!(z < Self::N as u32); + assert!(w < Self::N as u32); + let x = x.expr(); + let y = y.expr(); + let z = z.expr(); + let w = w.expr(); + Expr::>::from_node(__current_scope(|s| { + s.call( + Func::Permute, + &[self.node(), x.node(), y.node(), z.node(), w.node()], + Vec4::::type_(), + ) + })) + } } macro_rules! vector_proxies { @@ -96,7 +152,11 @@ macro_rules! vector_proxies { type Value = Vector; } impl> VectorExprProxy for $ExprName { + const N: usize = $N; type T = T; + fn node(&self) -> NodeRef { + self._node + } } impl>> VarProxy for $VarName { type Value = Vector; @@ -204,76 +264,8 @@ impl Vec4Swizzle for Vec4 { } } -impl, const N: usize> VectorExprData { - fn _permute2(&self, x: u32, y: u32) -> Expr> - where - T: VectorAlign<2>, - { - assert!(x < N as u32); - assert!(y < N as u32); - let x = x.expr(); - let y = y.expr(); - Expr::>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[self.node, ToNode::node(&x), ToNode::node(&y)], - Vector::::type_(), - ) - })) - } - fn _permute3(&self, x: u32, y: u32, z: u32) -> Expr> - where - T: VectorAlign<3>, - { - assert!(x < N as u32); - assert!(y < N as u32); - assert!(z < N as u32); - let x = x.expr(); - let y = y.expr(); - let z = z.expr(); - Expr::>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[ - self.node, - ToNode::node(&x), - ToNode::node(&y), - ToNode::node(&z), - ], - Vector::::type_(), - ) - })) - } - fn _permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Expr> - where - T: VectorAlign<4>, - { - assert!(x < N as u32); - assert!(y < N as u32); - assert!(z < N as u32); - assert!(w < N as u32); - let x = x.expr(); - let y = y.expr(); - let z = z.expr(); - let w = w.expr(); - Expr::>::from_node(__current_scope(|s| { - s.call( - Func::Permute, - &[ - self.node, - ToNode::node(&x), - ToNode::node(&y), - ToNode::node(&z), - ToNode::node(&w), - ], - Vector::::type_(), - ) - })) - } -} - impl Vec2Swizzle for VectorExprProxy2 { - type Vec2 = Self; + type Vec2 = Expr>; type Vec3 = Expr>; type Vec4 = Expr>; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { @@ -288,7 +280,7 @@ impl Vec2Swizzle for VectorExprProxy2 { } impl Vec3Swizzle for VectorExprProxy3 { type Vec2 = Expr>; - type Vec3 = Self; + type Vec3 = Expr>; type Vec4 = Expr>; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) @@ -303,7 +295,7 @@ impl Vec3Swizzle for VectorExprProxy3 { impl Vec4Swizzle for VectorExprProxy4 { type Vec2 = Expr>; type Vec3 = Expr>; - type Vec4 = Self; + type Vec4 = Expr>; fn permute2(&self, x: u32, y: u32) -> Self::Vec2 { self._permute2(x, y) } @@ -315,9 +307,9 @@ impl Vec4Swizzle for VectorExprProxy4 { } } -pub type Vec2> = Vector; -pub type Vec3> = Vector; -pub type Vec4> = Vector; +pub type Vec2 = Vector; +pub type Vec3 = Vector; +pub type Vec4 = Vector; // Matrix @@ -331,7 +323,7 @@ where } #[repr(C)] -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq)] pub struct SquareMatrix where f32: VectorAlign, @@ -345,8 +337,8 @@ where { fn type_() -> CArc { let type_ = Type::Matrix(ir::MatrixType { - element: VectorElementType::Scalar(Primitive::Float32), - dimension: N, + element: VectorElementType::Scalar(f32::primitive()), + dimension: N as u32, }); register_type(type_) } diff --git a/luisa_compute/src/lang/types/vector/impls.rs b/luisa_compute/src/lang/types/vector/impls.rs index d679e74..80927d3 100644 --- a/luisa_compute/src/lang/types/vector/impls.rs +++ b/luisa_compute/src/lang/types/vector/impls.rs @@ -23,7 +23,7 @@ impl, const N: usize> Vector { } } pub fn expr_from_elements(elements: [Expr; N]) -> Expr { - Expr::::from_node(__compose(elements.map(ToNode::node))) + Expr::::from_node(__compose::(&elements.map(|x| x.node()))) } } diff --git a/luisa_compute/src/printer.rs b/luisa_compute/src/printer.rs index d696ca4..a41b988 100644 --- a/luisa_compute/src/printer.rs +++ b/luisa_compute/src/printer.rs @@ -37,6 +37,7 @@ impl PrinterArgs { let n = packed_size::(); self.count_per_arg.push(n); self.pack_fn.push(Box::new(move |offset, data| { + let v = (&v).clone(); pack_to(v, data, offset); })); self.count += n; @@ -127,8 +128,8 @@ impl Printer { let item_id = items.len() as u32; if_!( - offset.cmplt(data.len().uint()) - & (offset + 1 + args.count as u32).cmple(data.len().uint()), + offset.lt(data.len().cast::()) + & (offset + 1 + args.count as u32).le(data.len().cast::()), { data.atomic_fetch_add(0, 1); data.write(offset, item_id); diff --git a/luisa_compute/src/resource.rs b/luisa_compute/src/resource.rs index 2451746..8138012 100644 --- a/luisa_compute/src/resource.rs +++ b/luisa_compute/src/resource.rs @@ -906,15 +906,15 @@ macro_rules! impl_io_texel { } }; } -impl_io_texel!(f32, f32, Float4, |x: Float4Expr| x.x(), |x| { - Float4Expr::splat(x) +impl_io_texel!(f32, f32, Float4, |x: Float4Expr| x.x, |x| { + Float4::splat_expr(x) }); impl_io_texel!( Float2, f32, Float4, |x: Float4Expr| x.xy(), - |x: Float2Expr| { Float4::expr(x.x(), x.y(), 0.0, 0.0) } + |x: Float2Expr| { Float4::expr(x.x, x.y, 0.0, 0.0) } ); impl_io_texel!(Float4, f32, Float4, |x: Float4Expr| x, |x: Float4Expr| x); @@ -924,15 +924,15 @@ impl_io_texel!(Float4, f32, Float4, |x: Float4Expr| x, |x: Float4Expr| x); // impl_io_texel!(Short2,); // impl_io_texel!(Ushort4,); // impl_io_texel!(Short4,); -impl_io_texel!(u32, u32, Uint4, |x: Uint4Expr| x.x(), |x| Uint4Expr::splat( +impl_io_texel!(u32, u32, Uint4, |x: Uint4Expr| x.x, |x| Uint4::splat_expr( x )); -impl_io_texel!(i32, i32, Int4, |x: Int4Expr| x.x(), |x| Int4Expr::splat(x)); +impl_io_texel!(i32, i32, Int4, |x: Int4Expr| x.x, |x| Int4::splat_expr(x)); impl_io_texel!(Uint2, u32, Uint4, |x: Uint4Expr| x.xy(), |x: Uint2Expr| { - Uint4::expr(x.x(), x.y(), 0u32, 0u32) + Uint4::expr(x.x, x.y, 0u32, 0u32) }); impl_io_texel!(Int2, i32, Int4, |x: Int4Expr| x.xy(), |x: Int2Expr| { - Int4::expr(x.x(), x.y(), 0i32, 0i32) + Int4::expr(x.x, x.y, 0i32, 0i32) }); impl_io_texel!(Uint4, u32, Uint4, |x: Uint4Expr| x, |x| x); impl_io_texel!(Int4, i32, Int4, |x: Int4Expr| x, |x| x); @@ -1216,8 +1216,8 @@ impl<'a, T: IoTexel> Tex2dView<'a, T> { } pub fn size(&self) -> [u32; 3] { [ - (self.tex.handle.width >> self.level).max(1), - (self.tex.handle.height >> self.level).max(1), + Ord::max(self.tex.handle.width >> self.level, 1), + Ord::max(self.tex.handle.height >> self.level, 1), 1, ] } @@ -1236,9 +1236,9 @@ impl<'a, T: IoTexel> Tex3dView<'a, T> { } pub fn size(&self) -> [u32; 3] { [ - (self.tex.handle.width >> self.level).max(1), - (self.tex.handle.height >> self.level).max(1), - (self.tex.handle.depth >> self.level).max(1), + Ord::max(self.tex.handle.width >> self.level, 1), + Ord::max(self.tex.handle.height >> self.level, 1), + Ord::max(self.tex.handle.depth >> self.level, 1), ] } pub fn var(&self) -> Tex3dVar { @@ -1327,7 +1327,7 @@ impl IndexRead for BindlessBufferVar { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::::from_node(__current_scope(|b| { @@ -1616,7 +1616,7 @@ impl BindlessArrayVar { } else if is_cpu_backend() { if need_runtime_check() { let expected = type_hash(&T::type_()); - lc_assert!(v.__type().cmpeq(expected)); + lc_assert!(v.__type().eq(expected)); } } v @@ -1663,7 +1663,7 @@ impl IndexRead for Buffer { } } impl IndexWrite for Buffer { - fn write>>(&self, i: I, v: V) { + fn write>(&self, i: I, v: V) { self.var().write(i, v) } } @@ -1672,7 +1672,7 @@ impl IndexRead for BufferVar { fn read(&self, i: I) -> Expr { let i = i.to_u64(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } __current_scope(|b| { FromNode::from_node(b.call( @@ -1684,11 +1684,11 @@ impl IndexRead for BufferVar { } } impl IndexWrite for BufferVar { - fn write>>(&self, i: I, v: V) { + fn write>(&self, i: I, v: V) { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } __current_scope(|b| { b.call( @@ -1738,11 +1738,15 @@ impl BufferVar { macro_rules! impl_atomic { ($t:ty) => { impl BufferVar<$t> { - pub fn atomic_exchange>>(&self, i: I, v: V) -> Expr<$t> { + pub fn atomic_exchange>( + &self, + i: I, + v: V, + ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1752,17 +1756,21 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_compare_exchange>, V1: Into>>( + pub fn atomic_compare_exchange< + I: IntoIndex, + V0: AsExpr, + V1: AsExpr, + >( &self, i: I, expected: V0, desired: V1, ) -> Expr<$t> { let i = i.to_u64(); - let expected = expected.into(); - let desired = desired.into(); + let expected = expected.as_expr(); + let desired = desired.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1772,15 +1780,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_add>>( + pub fn atomic_fetch_add>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1790,15 +1798,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_sub>>( + pub fn atomic_fetch_sub>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1808,15 +1816,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_min>>( + pub fn atomic_fetch_min>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1826,15 +1834,15 @@ macro_rules! impl_atomic { ) })) } - pub fn atomic_fetch_max>>( + pub fn atomic_fetch_max>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1850,15 +1858,15 @@ macro_rules! impl_atomic { macro_rules! impl_atomic_bit { ($t:ty) => { impl BufferVar<$t> { - pub fn atomic_fetch_and>>( + pub fn atomic_fetch_and>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1868,11 +1876,15 @@ macro_rules! impl_atomic_bit { ) })) } - pub fn atomic_fetch_or>>(&self, i: I, v: V) -> Expr<$t> { + pub fn atomic_fetch_or>( + &self, + i: I, + v: V, + ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( @@ -1882,15 +1894,15 @@ macro_rules! impl_atomic_bit { ) })) } - pub fn atomic_fetch_xor>>( + pub fn atomic_fetch_xor>( &self, i: I, v: V, ) -> Expr<$t> { let i = i.to_u64(); - let v = v.into(); + let v = v.as_expr(); if need_runtime_check() { - lc_assert!(i.cmplt(self.len())); + lc_assert!(i.lt(self.len())); } Expr::<$t>::from_node(__current_scope(|b| { b.call( diff --git a/luisa_compute/src/runtime/kernel.rs b/luisa_compute/src/runtime/kernel.rs index 2963cee..d70a26b 100644 --- a/luisa_compute/src/runtime/kernel.rs +++ b/luisa_compute/src/runtime/kernel.rs @@ -5,7 +5,7 @@ impl CallableParameter for Expr { builder.value::() } fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) + encoder.var(self.clone()) } } impl CallableParameter for Var { @@ -13,7 +13,7 @@ impl CallableParameter for Var { builder.var::() } fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.var(*self) + encoder.var(self.clone()) } } @@ -105,15 +105,6 @@ impl CallableParameter for BindlessArrayVar { } } -impl CallableParameter for rtx::AccelVar { - fn def_param(_: Option>, builder: &mut KernelBuilder) -> Self { - builder.accel() - } - fn encode(&self, encoder: &mut CallableArgEncoder) { - encoder.accel(self) - } -} - pub trait KernelParameter { fn def_param(builder: &mut KernelBuilder) -> Self; } @@ -152,11 +143,6 @@ impl KernelParameter for BindlessArrayVar { } } -impl KernelParameter for rtx::AccelVar { - fn def_param(builder: &mut KernelBuilder) -> Self { - builder.accel() - } -} macro_rules! impl_kernel_param_for_tuple { ($first:ident $($rest:ident)*) => { impl<$first:KernelParameter, $($rest: KernelParameter),*> KernelParameter for ($first, $($rest,)*) { @@ -275,14 +261,6 @@ impl KernelBuilder { self.args.push(node); BindlessArrayVar { node, handle: None } } - pub fn accel(&mut self) -> rtx::AccelVar { - let node = new_node( - __module_pools(), - Node::new(CArc::new(Instruction::Accel), Type::void()), - ); - self.args.push(node); - rtx::AccelVar { node, handle: None } - } fn collect_module_info(&self) -> (ResourceTracker, Vec>, Vec) { RECORDER.with(|r| { let mut resource_tracker = ResourceTracker::new();