Skip to content

Commit

Permalink
Merge pull request #42 from entropylost/main
Browse files Browse the repository at this point in the history
Cleanup and add more IoTexel implementations.
  • Loading branch information
shiinamiyuki authored Oct 1, 2024
2 parents c3e42ea + fcf4f5a commit 30532e3
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 55 deletions.
24 changes: 0 additions & 24 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ pub(crate) trait CallFuncTrait {
y: Expr<S>,
z: Expr<U>,
) -> Expr<V>;
fn call_void<T: Value>(self, x: Expr<T>);
fn call2_void<T: Value, S: Value>(self, x: Expr<T>, y: Expr<S>);
fn call3_void<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>, z: Expr<U>);
}
impl CallFuncTrait for Func {
fn call<T: Value, S: Value>(self, x: Expr<T>) -> Expr<S> {
Expand Down Expand Up @@ -82,27 +79,6 @@ impl CallFuncTrait for Func {
b.call(self, &[x, y, z], <V as TypeOf>::type_())
})))
}
/// Emits a call to this function with one argument and a `void` return type.
///
/// The call is appended through the builder handed out by `__current_scope`;
/// whatever `call` returns is discarded, so this method yields nothing.
fn call_void<T: Value>(self, x: Expr<T>) {
    // The argument node is fetched before the scope closure runs.
    // NOTE(review): presumably deliberate — keep the lookup outside the closure.
    let arg = x.node().get();
    __current_scope(|builder| {
        builder.call(self, &[arg], Type::void());
    });
}
/// Emits a call to this function with two arguments and a `void` return type.
///
/// Arguments are passed in the order `x`, `y`. The call is appended through
/// the builder handed out by `__current_scope`, and its result is discarded.
fn call2_void<T: Value, S: Value>(self, x: Expr<T>, y: Expr<S>) {
    // Both argument nodes are fetched (x first, then y) before the scope closure runs.
    // NOTE(review): presumably deliberate — keep the lookups outside the closure.
    let first = x.node().get();
    let second = y.node().get();
    __current_scope(|builder| {
        builder.call(self, &[first, second], Type::void());
    });
}
/// Emits a call to this function with three arguments and a `void` return type.
///
/// Arguments are passed in the order `x`, `y`, `z`. The call is appended
/// through the builder handed out by `__current_scope`, and its result is
/// discarded.
fn call3_void<T: Value, S: Value, U: Value>(self, x: Expr<T>, y: Expr<S>, z: Expr<U>) {
    // All argument nodes are fetched (x, y, z in that order) before the scope closure runs.
    // NOTE(review): presumably deliberate — keep the lookups outside the closure.
    let first = x.node().get();
    let second = y.node().get();
    let third = z.node().get();
    __current_scope(|builder| {
        builder.call(self, &[first, second, third], Type::void());
    });
}
}

/**
Expand Down
2 changes: 0 additions & 2 deletions luisa_compute/src/lang/control_flow.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::ffi::CString;

use crate::internal_prelude::*;
use ir::SwitchCase;

Expand Down
53 changes: 30 additions & 23 deletions luisa_compute/src/lang/types/vector/gen_swizzle.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import List
from itertools import permutations, product
s = ''

s = ""


def swizzle_name(perm: List[int]):
return ''.join('xyzw'[i] for i in perm)
return "".join("xyzw"[i] for i in perm)


swizzles2 = list(product(range(4), repeat=2))
swizzles3 = list(product(range(4), repeat=3))
Expand All @@ -15,29 +19,32 @@ def swizzle_name(perm: List[int]):
sw_m_to_n = {}
for m in range(2, 5):
for n in range(2, 5):
comps = 'xyzw'[:m]
sw_m_to_n[(m, n)] = [sw for sw in all_swizzles[n] if len(sw) == n and all([s < m for s in sw])]
for n in range(2,5):
s += 'pub trait Vec{}Swizzle {{\n'.format(n)
s += ' type Vec2;\n'
s += ' type Vec3;\n'
s += ' type Vec4;\n'
s += ' fn permute2(&self, x: u32, y: u32) -> Self::Vec2;\n'
s += ' fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3;\n'
s += ' fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4;\n'
sw_m_to_n[(m, n)] = [
sw for sw in all_swizzles[n] if len(sw) == n and all([s < m for s in sw])
]
for n in range(2, 5):
s += "pub trait Vec{}Swizzle {{\n".format(n)
s += " type Vec2;\n"
s += " type Vec3;\n"
s += " type Vec4;\n"
s += " fn permute2(&self, x: u32, y: u32) -> Self::Vec2;\n"
s += " fn permute3(&self, x: u32, y: u32, z: u32) -> Self::Vec3;\n"
s += " fn permute4(&self, x: u32, y: u32, z: u32, w: u32) -> Self::Vec4;\n"
for sw in sw_m_to_n[(n, 2)]:
s += ' fn {}(&self) -> Self::Vec2 {{\n'.format(swizzle_name(sw))
s += ' self.permute2({}, {})\n'.format(sw[0], sw[1])
s += ' }\n'
s += " fn {}(&self) -> Self::Vec2 {{\n".format(swizzle_name(sw))
s += " self.permute2({}, {})\n".format(sw[0], sw[1])
s += " }\n"
for sw in sw_m_to_n[(n, 3)]:
s += ' fn {}(&self) -> Self::Vec3 {{\n'.format(swizzle_name(sw))
s += ' self.permute3({}, {}, {})\n'.format(sw[0], sw[1], sw[2])
s += ' }\n'
s += " fn {}(&self) -> Self::Vec3 {{\n".format(swizzle_name(sw))
s += " self.permute3({}, {}, {})\n".format(sw[0], sw[1], sw[2])
s += " }\n"
for sw in sw_m_to_n[(n, 4)]:
s += ' fn {}(&self) -> Self::Vec4 {{\n'.format(swizzle_name(sw))
s += ' self.permute4({}, {}, {}, {})\n'.format(sw[0], sw[1], sw[2], sw[3])
s += ' }\n'
s += '}\n'
s += " fn {}(&self) -> Self::Vec4 {{\n".format(swizzle_name(sw))
s += " self.permute4({}, {}, {}, {})\n".format(
sw[0], sw[1], sw[2], sw[3]
)
s += " }\n"
s += "}\n"

with open('swizzle.rs', 'w') as f:
with open("swizzle.rs", "w") as f:
f.write(s)
51 changes: 45 additions & 6 deletions luisa_compute/src/resource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,12 +292,12 @@ impl<T: Value> BufferView<T> {
}
impl<T: Value + fmt::Debug> fmt::Debug for Buffer<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
struct DebugEllipsis;
impl fmt::Debug for DebugEllipsis {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("..")
}
}
// struct DebugEllipsis;
// impl fmt::Debug for DebugEllipsis {
// fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// f.write_str("..")
// }
// }

write!(f, "Buffer<{}>({})", std::any::type_name::<T>(), self.len())?;
// if self.len() <= 16 || f.precision().is_some() {
Expand Down Expand Up @@ -901,6 +901,37 @@ macro_rules! impl_io_texel {
}
};
}
// Registers `f16` with `impl_io_texel!`, paired with `f32` and `Float4`.
// First closure (Float4 -> f16): keeps only the `x` lane and casts it to f16.
// Second closure (f16 -> Float4): casts to f32 and hands the value to
// `Float4::splat_expr` (presumably a broadcast to every lane — confirm).
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(
f16,
f32,
Float4,
|a: Expr<Float4>| a.x.cast_f16(),
|x: Expr<f16>| { Float4::splat_expr(x.cast_f32()) }
);
// Registers `Half2` with `impl_io_texel!`, paired with `f32` and `Float4`.
// First closure (Float4 -> Half2): takes the `xy` swizzle and casts it to f16.
// Second closure (Half2 -> Float4): casts to f32, then builds a Float4 from
// the x and y lanes with the two remaining lanes set to 0.0.
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(
Half2,
f32,
Float4,
|a: Expr<Float4>| a.xy().cast_f16(),
|x: Expr<Half2>| {
let x = x.cast_f32();
Float4::expr(x.x, x.y, 0.0, 0.0)
}
);
// Registers `Half3` with `impl_io_texel!`, paired with `f32` and `Float4`.
// First closure (Float4 -> Half3): takes the `xyz` swizzle and casts it to f16.
// Second closure (Half3 -> Float4): casts to f32 and calls `extend(0.0)`
// (presumably appends 0.0 as the fourth lane — confirm against `extend`).
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(
Half3,
f32,
Float4,
|a: Expr<Float4>| a.xyz().cast_f16(),
|x: Expr<Half3>| { x.cast_f32().extend(0.0) }
);
// Registers `Half4` with `impl_io_texel!`, paired with `f32` and `Float4`.
// First closure (Float4 -> Half4): casts the whole vector to f16.
// Second closure (Half4 -> Float4): casts the whole vector to f32.
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(
Half4,
f32,
Float4,
|a: Expr<Float4>| a.cast_f16(),
|x: Expr<Half4>| { x.cast_f32() }
);
impl_io_texel!(
bool,
f32,
Expand Down Expand Up @@ -946,6 +977,14 @@ impl_io_texel!(Uint2, u32, Uint4, |x: Expr<Uint4>| x.xy(), |x: Expr<
// Registers `Int2` with `impl_io_texel!`, paired with `i32` and `Int4`.
// First closure (Int4 -> Int2): takes the `xy` swizzle.
// Second closure (Int2 -> Int4): builds an Int4 from the x and y lanes with
// the two remaining lanes set to 0i32.
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(Int2, i32, Int4, |x: Expr<Int4>| x.xy(), |x: Expr<Int2>| {
Int4::expr(x.x, x.y, 0i32, 0i32)
});
// Registers `Uint3` with `impl_io_texel!`, paired with `u32` and `Uint4`.
// First closure (Uint4 -> Uint3): takes the `xyz` swizzle.
// Second closure (Uint3 -> Uint4): builds a Uint4 from the x, y and z lanes
// with the fourth lane set to 0u32.
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(Uint3, u32, Uint4, |x: Expr<Uint4>| x.xyz(), |x: Expr<
Uint3,
>| {
Uint4::expr(x.x, x.y, x.z, 0u32)
});
// Registers `Int3` with `impl_io_texel!`, paired with `i32` and `Int4`.
// First closure (Int4 -> Int3): takes the `xyz` swizzle.
// Second closure (Int3 -> Int4): builds an Int4 from the x, y and z lanes
// with the fourth lane set to 0i32.
// NOTE(review): the macro body is not visible here; how it uses these
// arguments is assumed from the macro's name.
impl_io_texel!(Int3, i32, Int4, |x: Expr<Int4>| x.xyz(), |x: Expr<Int3>| {
Int4::expr(x.x, x.y, x.z, 0i32)
});
// Registers `Uint4` with `impl_io_texel!`, paired with `u32` and `Uint4`.
// Both closures return their argument unchanged, since the type already is `Uint4`.
impl_io_texel!(Uint4, u32, Uint4, |x: Expr<Uint4>| x, |x| x);
// Registers `Int4` with `impl_io_texel!`, paired with `i32` and `Int4`.
// Both closures return their argument unchanged, since the type already is `Int4`.
impl_io_texel!(Int4, i32, Int4, |x: Expr<Int4>| x, |x| x);

Expand Down

0 comments on commit 30532e3

Please sign in to comment.