Skip to content

Commit

Permalink
insert line number to every tracked control flow stmt/fn
Browse files Browse the repository at this point in the history
  • Loading branch information
shiinamiyuki committed Oct 15, 2023
1 parent d11eba3 commit 484772b
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 23 deletions.
2 changes: 1 addition & 1 deletion luisa_compute/examples/path_tracer_cutout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ fn main() {
s.submit(cmds);
s.synchronize();
});
let spp_per_dispatch = 1;;
let spp_per_dispatch = 1;
// use create_kernel_async to compile multiple kernels in parallel
let path_tracer = Kernel::<fn(Tex2d<Float4>, Tex2d<u32>, Accel, Uint2)>::new_async(
&device,
Expand Down
2 changes: 1 addition & 1 deletion luisa_compute/examples/raytracing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use image::Rgb;
use luisa::lang::types::vector::alias::*;
use luisa::lang::types::vector::*;
use luisa::lang::types::*;
use luisa::lang::*;

use luisa::prelude::*;
use luisa::rtx::{AccelBuildRequest, AccelOption, Ray};
use luisa_compute as luisa;
Expand Down
25 changes: 23 additions & 2 deletions luisa_compute/src/lang.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,9 +309,21 @@ pub(crate) struct FnRecorder {
pub(crate) pools: CArc<ModulePools>,
pub(crate) arena: Bump,
pub(crate) callable_ret_type: Option<CArc<Type>>,
pub(crate) const_builder: IrBuilder,
pub(crate) index_const_pool: IndexMap<i32, NodeRef>,
}
pub(crate) type FnRecorderPtr = Rc<RefCell<FnRecorder>>;
impl FnRecorder {
pub(crate) fn make_index_const(&mut self, idx: i32) -> NodeRef {
if let Some(node) = self.index_const_pool.get(&idx) {
return *node;
}
let b = &mut self.const_builder;
let node = b.const_(Const::Int32(idx));
self.defined.insert(node, true);
self.index_const_pool.insert(idx, node);
node
}
pub(crate) fn add_block_to_inaccessible(&self, block: &BasicBlock) {
let mut inaccessible = self.inaccessible.borrow_mut();
for n in block.iter() {
Expand Down Expand Up @@ -390,6 +402,7 @@ impl FnRecorder {
}
}
pub(crate) fn new(kernel_id: usize, parent: Option<FnRecorderPtr>) -> Self {
let pools = CArc::new(ModulePools::new());
FnRecorder {
inaccessible: parent
.as_ref()
Expand All @@ -404,12 +417,14 @@ impl FnRecorder {
shared: vec![],
device: None,
block_size: None,
pools: CArc::new(ModulePools::new()),
pools: pools.clone(),
arena: Bump::new(),
building_kernel: false,
callable_ret_type: None,
kernel_id,
parent,
index_const_pool: IndexMap::new(),
const_builder: IrBuilder::new(pools.clone()),
}
}
pub(crate) fn map_captured_vars(&mut self, node0: SafeNodeRef) -> SafeNodeRef {
Expand Down Expand Up @@ -566,6 +581,12 @@ pub(crate) fn pop_recorder() -> FnRecorderPtr {
cur.unwrap()
})
}
pub(crate) fn recording_started() -> bool {
RECORDER.with(|r| {
let r = r.borrow();
r.is_some()
})
}

pub(crate) fn with_recorder<R>(f: impl FnOnce(&mut FnRecorder) -> R) -> R {
RECORDER.with(|r| {
Expand Down Expand Up @@ -741,7 +762,7 @@ fn __extract_impl(safe_node: SafeNodeRef, index: usize, ty: CArc<Type>) -> SafeN
b.set_insert_point(first_scope_bb.first());
}

let i = b.const_(Const::Int32(index as i32));
let i = r.make_index_const(index as i32);
// Since we have inserted something, the insertion point in cur_builder might
// not be up to date So we need to set it to the end of the current
// basic block
Expand Down
13 changes: 11 additions & 2 deletions luisa_compute/src/lang/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::fmt::Debug;

use crate::internal_prelude::*;

use super::with_recorder;
use super::{with_recorder, recording_started};

#[macro_export]
macro_rules! cpu_dbg {
Expand Down Expand Up @@ -139,6 +139,9 @@ pub fn __assert(cond: impl Into<Expr<bool>>, msg: &str, file: &str, line: u32, c
}

pub fn comment(msg: &str) {
if !recording_started() {
return;
}
__current_scope(|b| {
b.comment(CBoxedSlice::new(
CString::new(msg).unwrap().into_bytes_with_nul(),
Expand All @@ -152,6 +155,12 @@ macro_rules! lc_comment_lineno {
$crate::lang::debug::comment(&format!("{}:{}:{}", file!(), line!(), column!()))
};
($msg:literal) => {
$crate::lang::debug::comment(&format!("`{}` at {}:{}:{}", $msg, file!(), line!(), column!()))
$crate::lang::debug::comment(&format!(
"`{}` at {}:{}:{}",
$msg,
file!(),
line!(),
column!()
))
};
}
31 changes: 26 additions & 5 deletions luisa_compute/src/runtime/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,14 @@ macro_rules! impl_kernel_param_for_tuple {
}
}
impl_kernel_param_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8 T9 T10 T11 T12 T13 T14 T15);

fn transform_module(module: Module) -> Module {
// use luisa_compute_ir::transform::Transform;
// let module = luisa_compute_ir::transform::dce::Dce.transform(module);
let module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(module);
module
}

impl KernelBuilder {
pub fn new(device: Option<crate::runtime::Device>, is_kernel: bool) -> Self {
let kernel_id = RECORDER.with(|r| {
Expand All @@ -204,8 +212,7 @@ impl KernelBuilder {
push_recorder(kernel_id);
with_recorder(|r| {
r.device = device.as_ref().map(|d| WeakDevice::new(d));
r.pools = CArc::new(ModulePools::new());
r.scopes.clear();
assert!(r.scopes.is_empty());
r.building_kernel = is_kernel;
let pools = r.pools.clone();
r.scopes.push(IrBuilder::new(pools));
Expand Down Expand Up @@ -412,6 +419,13 @@ impl KernelBuilder {
assert_eq!(r.scopes.len(), 1);
let scope = r.scopes.pop().unwrap();
let entry = scope.finish();
let const_block = std::mem::replace(
&mut r.const_builder,
IrBuilder::new_without_bb(r.pools.clone()),
)
.finish();
const_block.merge(entry);
let entry = const_block;
r.add_block_to_inaccessible(&entry);
let ir_module = Module {
entry,
Expand All @@ -420,12 +434,12 @@ impl KernelBuilder {
flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM
| ModuleFlags::REQUIRES_FWD_AD_TRANSFORM,
};
let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module);
let ir_module = transform_module(ir_module);

let mut args = self.args.clone();

args.extend(r.captured_vars.values().map(|x| unsafe { x.1.get_raw() }));

for a in &args {
r.inaccessible.borrow_mut().insert(*a);
}
Expand Down Expand Up @@ -463,7 +477,14 @@ impl KernelBuilder {
let ret = with_recorder(|r| {
assert_eq!(r.scopes.len(), 1);
let scope = r.scopes.pop().unwrap();
let const_block = std::mem::replace(
&mut r.const_builder,
IrBuilder::new_without_bb(r.pools.clone()),
)
.finish();
let entry = scope.finish();
const_block.merge(entry);
let entry = const_block;
assert!(r.captured_vars.is_empty());
let ir_module = Module {
entry,
Expand All @@ -472,7 +493,7 @@ impl KernelBuilder {
flags: ModuleFlags::REQUIRES_REV_AD_TRANSFORM
| ModuleFlags::REQUIRES_FWD_AD_TRANSFORM,
};
let ir_module = luisa_compute_ir::transform::luisa_compute_ir_transform_auto(ir_module);
let ir_module = transform_module(ir_module);
let module = KernelModule {
module: ir_module,
cpu_custom_ops: CBoxedSlice::new(cpu_custom_ops),
Expand Down
73 changes: 61 additions & 12 deletions luisa_compute_track/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use syn::*;
// TODO: Impl x as f32 -> .cast() <- Don't
// TOOD: Impl switch! macro. <- Don't


struct TraceVisitor {
trait_path: TokenStream,
flow_path: TokenStream,
Expand Down Expand Up @@ -99,25 +98,45 @@ impl VisitMut for TraceVisitor {
if let Expr::Let(_) = **cond {
} else if let Some((_, else_branch)) = else_branch {
*node = parse_quote_spanned! {span=>
<_ as #trait_path::SelectMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch)
{
::luisa_compute::lc_comment_lineno!("if stmt begin");
let __ret = <_ as #trait_path::SelectMaybeExpr<_>>::if_then_else(#cond, || #then_branch, || #else_branch);
::luisa_compute::lc_comment_lineno!("if stmt end");
__ret
}
}
} else {
*node = parse_quote_spanned! {span=>
<_ as #trait_path::ActivateMaybeExpr>::activate(#cond, || #then_branch)
{
::luisa_compute::lc_comment_lineno!("if stmt begin");
let __ret = <_ as #trait_path::ActivateMaybeExpr>::activate(#cond, || #then_branch);
::luisa_compute::lc_comment_lineno!("if stmt end");
__ret
}
}
}
}
Expr::While(expr) => {
let cond = &expr.cond;
let body = &expr.body;
*node = parse_quote_spanned! {span=>
<_ as #trait_path::LoopMaybeExpr>::while_loop(|| #cond, || #body)
{
::luisa_compute::lc_comment_lineno!("while stmt begin");
let __ret = <_ as #trait_path::LoopMaybeExpr>::while_loop(|| #cond, || #body);
::luisa_compute::lc_comment_lineno!("while stmt end");
__ret
}
}
}
Expr::Loop(expr) => {
let body = &expr.body;
*node = parse_quote_spanned! {span=>
#flow_path::loop_(|| #body)
{
::luisa_compute::lc_comment_lineno!("loop stmt begin");
let __ret = #flow_path::loop_(|| #body);
::luisa_compute::lc_comment_lineno!("loop stmt end");
__ret
}
}
}
Expr::ForLoop(expr) => {
Expand All @@ -127,19 +146,25 @@ impl VisitMut for TraceVisitor {
if let Expr::Range(range) = &**expr {
let attrs = &range.attrs;
// check if #[unroll] is present
let unroll = attrs.iter().any(|attr| {
attr.path().is_ident("unroll")
});
let unroll = attrs.iter().any(|attr| attr.path().is_ident("unroll"));
if unroll {
*node = parse_quote_spanned! {span=>
#range.for_each(|#pat| #body)
{
::luisa_compute::lc_comment_lineno!("for loop stmt begin");
let __ret = #range.for_each(|#pat| #body);
::luisa_compute::lc_comment_lineno!("for loop stmt end");
__ret
}
}
} else {
*node = parse_quote_spanned! {span=>
#flow_path::for_range(#range, |#pat| #body)
*node = parse_quote_spanned! {span=> {
::luisa_compute::lc_comment_lineno!("for loop stmt begin");
let __ret = #flow_path::for_range(#range, |#pat| #body);
::luisa_compute::lc_comment_lineno!("for loop stmt end");
__ret
}
}
}

}
}
// Expr::Unary(op) => {
Expand Down Expand Up @@ -345,8 +370,32 @@ pub fn tracked(
) -> proc_macro::TokenStream {
let item = syn::parse_macro_input!(item as ItemFn);
let body = &item.block;
let body_span = body.span();
let ret_type = match &item.sig.output {
ReturnType::Default => quote_spanned! {body_span=> () },
ReturnType::Type(_, ty) => quote_spanned! {body_span=> #ty },
};
let body = proc_macro::TokenStream::from(quote!({ #body }));
let body = track_impl(parse_macro_input!(body as Expr));
let body = quote_spanned! {body_span=>
{
let __fn_name = {
fn f() {}
fn type_name_of<T>(_: T) -> &'static str {
std::any::type_name::<T>()
}
let name = type_name_of(f);
name.strip_suffix("::f").unwrap()
};
::luisa_compute::lang::debug::comment(&format!("begin fn {} at {}:{}:{}", __fn_name, file!(), line!(), column!()));
let __ret: #ret_type = #body;
#[allow(unreachable_code)]
{
::luisa_compute::lang::debug::comment(&format!("end fn {} at {}:{}:{}", __fn_name, file!(), line!(), column!()));
__ret
}
}
};
let attrs = &item.attrs;
let sig = &item.sig;
let vis = &item.vis;
Expand Down

0 comments on commit 484772b

Please sign in to comment.