Skip to content

Commit

Permalink
Merge pull request #21 from iMplode-nZ/main
Browse files Browse the repository at this point in the history
Added Deref bound on VarProxy and fixed Kernel Command lifetime.
  • Loading branch information
shiinamiyuki authored Sep 23, 2023
2 parents b06f49e + 898009a commit 19caaa9
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 16 deletions.
4 changes: 1 addition & 3 deletions luisa_compute/src/lang/ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use std::ops::*;
use super::types::core::{Floating, Integral, Numeric, Primitive, Signed};
use super::types::vector::{VectorAlign, VectorElement};

mod cast_impls;
mod impls;
mod spread;
mod traits;
mod cast_impls;

pub use spread::*;
pub use traits::*;
Expand All @@ -24,8 +24,6 @@ pub trait Linear: Value {
const N: usize;
type Scalar: VectorElement;
type WithScalar<S: VectorElement>: Linear<Scalar = S>;
// We don't actually know that the vector has equivalent vectors of every
// primitive type.
}
impl<T: VectorElement> Linear for T {
const N: usize = 1;
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/src/lang/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ pub trait ExprProxy: Copy + 'static {
/// For example, `Var<[f32; 4]>` dereferences to `ArrayVar<f32, 4>`, which
/// exposes [`Index`](std::ops::Index) and [`IndexMut`](std::ops::IndexMut)
/// impls.
pub trait VarProxy: Copy + 'static {
pub trait VarProxy: Copy + 'static + Deref<Target = Expr<Self::Value>> {
type Value: Value<Var = Self>;
fn as_var_from_proxy(&self) -> &Var<Self::Value>;

Expand Down
30 changes: 18 additions & 12 deletions luisa_compute/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,12 +401,14 @@ impl Device {
}
}

/// Compile a [`KernelDef`] into a [`Kernel`]. See [`Kernel`] for more details on kernel creation
/// Compile a [`KernelDef`] into a [`Kernel`]. See [`Kernel`] for more
/// details on kernel creation
pub fn compile_kernel<S: KernelSignature>(&self, k: &KernelDef<S>) -> Kernel<S> {
self.compile_kernel_with_options(k, KernelBuildOptions::default())
}

/// Compile a [`KernelDef`] into a [`Kernel`] asynchronously. See [`Kernel`] for more details on kernel creation
/// Compile a [`KernelDef`] into a [`Kernel`] asynchronously. See [`Kernel`]
/// for more details on kernel creation
pub fn compile_kernel_async<S: KernelSignature>(&self, k: &KernelDef<S>) -> Kernel<S> {
self.compile_kernel_with_options(
k,
Expand All @@ -417,7 +419,8 @@ impl Device {
)
}

/// Compile a [`KernelDef`] into a [`Kernel`] with options. See [`Kernel`] for more details on kernel creation
/// Compile a [`KernelDef`] into a [`Kernel`] with options. See [`Kernel`]
/// for more details on kernel creation
pub fn compile_kernel_with_options<S: KernelSignature>(
&self,
k: &KernelDef<S>,
Expand Down Expand Up @@ -1084,7 +1087,11 @@ impl RawKernel {
}
}

pub fn dispatch_async(&self, args: KernelArgEncoder, dispatch_size: [u32; 3]) -> Command {
pub fn dispatch_async(
&self,
args: KernelArgEncoder,
dispatch_size: [u32; 3],
) -> Command<'static> {
let mut rt = ResourceTracker::new();
rt.add(Arc::new(args.uniform_data));
let args = args.args;
Expand Down Expand Up @@ -1218,20 +1225,19 @@ pub struct KernelDef<T: KernelSignature> {
/// use luisa_compute::prelude::*;
/// let ctx = Context::new(std::env::current_exe().unwrap());
/// let device = ctx.create_device("cpu");
/// let kernel = KernelDef::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>::new(&device, track!(|a,b,c|{ }));
/// // Compilation:
/// let kernel = KernelDef::<fn(Buffer<f32>, Buffer<f32>,
/// Buffer<f32>)>::new(&device, track!(|a,b,c|{ })); // Compilation:
/// let kernel = device.compile_kernel(&kernel);
/// ```
/// - Recording and compilation in one step:
/// ```no_run
/// use luisa_compute::prelude::*;
/// let ctx = Context::new(std::env::current_exe().unwrap());
/// let device = ctx.create_device("cpu");
/// let kernel = Kernel::<fn(Buffer<f32>, Buffer<f32>, Buffer<f32>)>::new(&device, track!(|a,b,c|{ }));
/// ```
/// let kernel = Kernel::<fn(Buffer<f32>, Buffer<f32>,
/// Buffer<f32>)>::new(&device, track!(|a,b,c|{ })); ```
/// - Asynchronous compilation use [`Kernel::<T>::new_async`]
/// - Custom build options using [`Kernel::<T>::new_with_options`]
///
pub struct Kernel<T: KernelSignature> {
pub(crate) inner: RawKernel,
pub(crate) _marker: PhantomData<T>,
Expand Down Expand Up @@ -1335,10 +1341,10 @@ macro_rules! impl_dispatch_for_kernel {
}
#[allow(non_snake_case)]
#[allow(unused_mut)]
pub fn dispatch_async<'a>(
&'a self,
pub fn dispatch_async(
&self,
dispatch_size: [u32; 3], $($Ts:&impl AsKernelArg<$Ts>),*
) -> Command<'a> {
) -> Command<'static> {
let mut encoder = KernelArgEncoder::new();
$($Ts.encode(&mut encoder);)*
self.inner.dispatch_async(encoder, dispatch_size)
Expand Down

0 comments on commit 19caaa9

Please sign in to comment.