Skip to content

Commit

Permalink
[VAN-1370] handle non-scalar function parameters in BucketInterpreter (
Browse files Browse the repository at this point in the history
  • Loading branch information
tim-hoffman authored Jun 26, 2024
1 parent 940ccdd commit 0bc3ca6
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 38 deletions.
21 changes: 11 additions & 10 deletions circom/tests/constraints/with_call_vec_arg.circom
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,24 @@ template ComputeFee() {
component main = ComputeFee();

//CHECK-LABEL: define{{.*}} void @..generated..loop.body.
//CHECK-SAME: [[$F_ID_1:[0-9]+]]([0 x i256]* %lvars, [0 x i256]* %signals, i256* %sig_0, i256* %var_1){{.*}} {
//CHECK-SAME: [[$F_ID_1:[0-9]+]]([0 x i256]* %lvars, [0 x i256]* %signals, i256* %sig_0){{.*}} {
//CHECK-NEXT: ..generated..loop.body.[[$F_ID_1]]:
//CHECK-NEXT: br label %call1
//CHECK-EMPTY:
//CHECK-NEXT: call1:
//CHECK-NEXT: %[[FUN_NAME:[0-9a-zA-Z_.]+]]_arena = alloca [3 x i256], align 8
//CHECK-NEXT: %[[T00:[0-9a-zA-Z_.]+]] = getelementptr [3 x i256], [3 x i256]* %[[FUN_NAME]]_arena, i32 0, i32 0
//CHECK-NEXT: %[[T01:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %var_1, i32 0
//CHECK-NEXT: %[[COPY_SRC_0:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T01]], i32 0
//CHECK-NEXT: %[[T01:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 8
//CHECK-NEXT: %[[T12:[0-9a-zA-Z_.]+]] = load i256, i256* %[[T01]], align 4
//CHECK-NEXT: %[[T13:[0-9a-zA-Z_.]+]] = call i32 @fr_cast_to_addr(i256 %[[T12]])
//CHECK-NEXT: %[[T14:[0-9a-zA-Z_.]+]] = mul i32 2, %[[T13]]
//CHECK-NEXT: %[[T15:[0-9a-zA-Z_.]+]] = add i32 %[[T14]], 0
//CHECK-NEXT: %[[T16:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 %[[T15]]
//CHECK-NEXT: %[[COPY_SRC_0:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T16]], i32 0
//CHECK-NEXT: %[[COPY_DST_0:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T00]], i32 0
//CHECK-NEXT: %[[COPY_VAL_0:[0-9a-zA-Z_.]+]] = load i256, i256* %[[COPY_SRC_0]], align 4
//CHECK-NEXT: store i256 %[[COPY_VAL_0]], i256* %[[COPY_DST_0]], align 4
//CHECK-NEXT: %[[COPY_SRC_1:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T01]], i32 1
//CHECK-NEXT: %[[COPY_SRC_1:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T16]], i32 1
//CHECK-NEXT: %[[COPY_DST_1:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T00]], i32 1
//CHECK-NEXT: %[[COPY_VAL_1:[0-9a-zA-Z_.]+]] = load i256, i256* %[[COPY_SRC_1]], align 4
//CHECK-NEXT: store i256 %[[COPY_VAL_1]], i256* %[[COPY_DST_1]], align 4
Expand Down Expand Up @@ -141,14 +146,10 @@ component main = ComputeFee();
//CHECK-NEXT: unrolled_loop8:
//CHECK-NEXT: %[[T10:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T11:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i256 0
//CHECK-NEXT: %[[T12:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T13:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %[[T12]], i32 0, i256 0
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T10]], [0 x i256]* %0, i256* %[[T11]], i256* %[[T13]])
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T10]], [0 x i256]* %0, i256* %[[T11]])
//CHECK-NEXT: %[[T14:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T15:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i256 1
//CHECK-NEXT: %[[T16:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T17:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %[[T16]], i32 0, i256 2
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T14]], [0 x i256]* %0, i256* %[[T15]], i256* %[[T17]])
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T14]], [0 x i256]* %0, i256* %[[T15]])
//CHECK-NEXT: br label %prologue
//CHECK-EMPTY:
//CHECK-NEXT: prologue:
Expand Down
59 changes: 50 additions & 9 deletions circuit_passes/src/bucket_interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ pub mod write_collector;
use std::cell::RefCell;
use std::ops::Range;
use paste::paste;
use code_producers::llvm_elements::{fr, array_switch};
use code_producers::llvm_elements::{array_switch, fr};
use code_producers::llvm_elements::stdlib::{GENERATED_FN_PREFIX, LLVM_DONOTHING_FN_NAME};
use compiler::intermediate_representation::{
BucketId, Instruction, InstructionList, InstructionPointer,
new_id, BucketId, Instruction, InstructionList, InstructionPointer,
};
use compiler::intermediate_representation::ir_interface::*;
use compiler::num_bigint::BigInt;
use observer::Observer;
use program_structure::error_code::ReportCode;
use crate::passes::builders::{build_compute, build_u32_value};
use crate::passes::loop_unroll::LOOP_BODY_FN_PREFIX;
use crate::passes::GlobalPassData;
use self::env::{CallStack, CallStackFrame, Env, LibraryAccess};
Expand Down Expand Up @@ -738,7 +739,6 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> {
env: Env<'env>,
observe: bool,
) -> REI<'env> {
let mut env = env;
let res = if bucket.symbol.starts_with(GENERATED_FN_PREFIX) {
// ASSUME: The arguments to a generated function will always be LoadBucket with 'bounded_fn'
// that are intended to generate pointers or a call to some built-in function that returns
Expand All @@ -760,12 +760,53 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> {
self._execute_function_extracted(&bucket, env, observe)
))
} else {
let mut args = Vec::with_capacity(bucket.arguments.len());
for i in &bucket.arguments {
let (val, new_env) = check_res!(self._execute_instruction(i, env, observe));
let val = check_std_res!(opt_as_result(val, "function argument"));
args.push(val);
env = new_env;
let mut args = vec![];
for a in &bucket.arguments {
// Case: vector load
if let Instruction::Load(load) = &**a {
let load_size = load.context.size;
if load_size > 1 {
assert!(load.bounded_fn.is_none());
match &load.src {
LocationRule::Mapped { .. } => todo!("Can this happen?"),
LocationRule::Indexed { location, template_header } => {
for i in 0..load_size {
let scalar_load = LoadBucket {
id: new_id(),
source_file_id: load.source_file_id,
line: load.line,
message_id: load.message_id,
address_type: load.address_type.clone(),
src: LocationRule::Indexed {
location: build_compute(
load,
OperatorType::Add,
0,
vec![location.clone(), build_u32_value(load, i)],
),
template_header: template_header.clone(),
},
context: InstrContext { size: 1 },
bounded_fn: None,
}
.allocate();
let val = check_res!(
self._compute_instruction(&scalar_load, &env, observe),
|v| (v, env)
);
args.push(check_std_res!(opt_as_result(
val,
"function argument"
)));
}
}
}
continue;
}
}
// Case: anything else
let val = check_res!(self._compute_instruction(a, &env, observe), |v| (v, env));
args.push(check_std_res!(opt_as_result(val, "function argument")));
}
check_res!(InterpRes::try_continue(self._execute_function_basic(
&bucket.symbol,
Expand Down
34 changes: 15 additions & 19 deletions circuit_passes/src/passes/unreachable_code_removal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::cell::RefCell;
use std::collections::HashSet;
use code_producers::llvm_elements::stdlib::GENERATED_FN_PREFIX;
use paste::paste;
use compiler::circuit_design::template::TemplateCode;
use compiler::intermediate_representation::{new_id, BucketId};
Expand Down Expand Up @@ -131,25 +130,22 @@ impl CircuitTransformationPass for UnreachableRemovalPass<'_> {

fn transform_call_bucket(&self, bucket: &CallBucket) -> Result<InstructionPointer, BadInterp> {
if self.visited.borrow_mut().remove(&bucket.id) {
if bucket.symbol.starts_with(GENERATED_FN_PREFIX) {
// BucketInterpreter::_execute_call_bucket() will never visit arguments
// within these generated functions so we need a special case so they
// are not removed unless the CallBucket is removed entirely.
Ok(CallBucket {
id: new_id(),
source_file_id: bucket.source_file_id,
line: bucket.line,
message_id: bucket.message_id,
symbol: bucket.symbol.to_string(),
argument_types: bucket.argument_types.clone(),
arguments: bucket.arguments.clone(),
arena_size: bucket.arena_size,
return_info: self.transform_return_type(&bucket.id, &bucket.return_info)?,
}
.allocate())
} else {
self.transform_call_bucket_default(bucket)
// BucketInterpreter::_execute_call_bucket() will never visit arguments within
// generated functions and will only visit the scalar arguments of all other
// functions so we need a special case to prevent the arguments from being
// removed unless the CallBucket is removed entirely.
Ok(CallBucket {
id: new_id(),
source_file_id: bucket.source_file_id,
line: bucket.line,
message_id: bucket.message_id,
symbol: bucket.symbol.to_string(),
argument_types: bucket.argument_types.clone(),
arguments: bucket.arguments.clone(),
arena_size: bucket.arena_size,
return_info: self.transform_return_type(&bucket.id, &bucket.return_info)?,
}
.allocate())
} else {
Ok(NopBucket { id: new_id() }.allocate())
}
Expand Down

0 comments on commit 0bc3ca6

Please sign in to comment.