Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[VAN-1370] handle non-scalar function parameters in BucketInterpreter #121

Merged
merged 1 commit into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions circom/tests/constraints/with_call_vec_arg.circom
Original file line number Diff line number Diff line change
Expand Up @@ -25,19 +25,24 @@ template ComputeFee() {
component main = ComputeFee();

//CHECK-LABEL: define{{.*}} void @..generated..loop.body.
//CHECK-SAME: [[$F_ID_1:[0-9]+]]([0 x i256]* %lvars, [0 x i256]* %signals, i256* %sig_0, i256* %var_1){{.*}} {
//CHECK-SAME: [[$F_ID_1:[0-9]+]]([0 x i256]* %lvars, [0 x i256]* %signals, i256* %sig_0){{.*}} {
//CHECK-NEXT: ..generated..loop.body.[[$F_ID_1]]:
//CHECK-NEXT: br label %call1
//CHECK-EMPTY:
//CHECK-NEXT: call1:
//CHECK-NEXT: %[[FUN_NAME:[0-9a-zA-Z_.]+]]_arena = alloca [3 x i256], align 8
//CHECK-NEXT: %[[T00:[0-9a-zA-Z_.]+]] = getelementptr [3 x i256], [3 x i256]* %[[FUN_NAME]]_arena, i32 0, i32 0
//CHECK-NEXT: %[[T01:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %var_1, i32 0
//CHECK-NEXT: %[[COPY_SRC_0:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T01]], i32 0
//CHECK-NEXT: %[[T01:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 8
//CHECK-NEXT: %[[T12:[0-9a-zA-Z_.]+]] = load i256, i256* %[[T01]], align 4
//CHECK-NEXT: %[[T13:[0-9a-zA-Z_.]+]] = call i32 @fr_cast_to_addr(i256 %[[T12]])
//CHECK-NEXT: %[[T14:[0-9a-zA-Z_.]+]] = mul i32 2, %[[T13]]
//CHECK-NEXT: %[[T15:[0-9a-zA-Z_.]+]] = add i32 %[[T14]], 0
//CHECK-NEXT: %[[T16:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %lvars, i32 0, i32 %[[T15]]
//CHECK-NEXT: %[[COPY_SRC_0:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T16]], i32 0
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize this looks like a regression; what was fixed indexing via a parameter has become dynamic indexing. However, this is actually caused by a different long-standing issue https://veridise.atlassian.net/browse/VAN-671 and the fix to the interpreter has brought it to the surface.

//CHECK-NEXT: %[[COPY_DST_0:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T00]], i32 0
//CHECK-NEXT: %[[COPY_VAL_0:[0-9a-zA-Z_.]+]] = load i256, i256* %[[COPY_SRC_0]], align 4
//CHECK-NEXT: store i256 %[[COPY_VAL_0]], i256* %[[COPY_DST_0]], align 4
//CHECK-NEXT: %[[COPY_SRC_1:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T01]], i32 1
//CHECK-NEXT: %[[COPY_SRC_1:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T16]], i32 1
//CHECK-NEXT: %[[COPY_DST_1:[0-9a-zA-Z_.]+]] = getelementptr i256, i256* %[[T00]], i32 1
//CHECK-NEXT: %[[COPY_VAL_1:[0-9a-zA-Z_.]+]] = load i256, i256* %[[COPY_SRC_1]], align 4
//CHECK-NEXT: store i256 %[[COPY_VAL_1]], i256* %[[COPY_DST_1]], align 4
Expand Down Expand Up @@ -141,14 +146,10 @@ component main = ComputeFee();
//CHECK-NEXT: unrolled_loop8:
//CHECK-NEXT: %[[T10:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T11:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i256 0
//CHECK-NEXT: %[[T12:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T13:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %[[T12]], i32 0, i256 0
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T10]], [0 x i256]* %0, i256* %[[T11]], i256* %[[T13]])
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T10]], [0 x i256]* %0, i256* %[[T11]])
//CHECK-NEXT: %[[T14:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T15:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %0, i32 0, i256 1
//CHECK-NEXT: %[[T16:[0-9a-zA-Z_.]+]] = bitcast [9 x i256]* %lvars to [0 x i256]*
//CHECK-NEXT: %[[T17:[0-9a-zA-Z_.]+]] = getelementptr [0 x i256], [0 x i256]* %[[T16]], i32 0, i256 2
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T14]], [0 x i256]* %0, i256* %[[T15]], i256* %[[T17]])
//CHECK-NEXT: call void @..generated..loop.body.[[$F_ID_1]]([0 x i256]* %[[T14]], [0 x i256]* %0, i256* %[[T15]])
//CHECK-NEXT: br label %prologue
//CHECK-EMPTY:
//CHECK-NEXT: prologue:
Expand Down
59 changes: 50 additions & 9 deletions circuit_passes/src/bucket_interpreter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@ pub mod write_collector;
use std::cell::RefCell;
use std::ops::Range;
use paste::paste;
use code_producers::llvm_elements::{fr, array_switch};
use code_producers::llvm_elements::{array_switch, fr};
use code_producers::llvm_elements::stdlib::{GENERATED_FN_PREFIX, LLVM_DONOTHING_FN_NAME};
use compiler::intermediate_representation::{
BucketId, Instruction, InstructionList, InstructionPointer,
new_id, BucketId, Instruction, InstructionList, InstructionPointer,
};
use compiler::intermediate_representation::ir_interface::*;
use compiler::num_bigint::BigInt;
use observer::Observer;
use program_structure::error_code::ReportCode;
use crate::passes::builders::{build_compute, build_u32_value};
use crate::passes::loop_unroll::LOOP_BODY_FN_PREFIX;
use crate::passes::GlobalPassData;
use self::env::{CallStack, CallStackFrame, Env, LibraryAccess};
Expand Down Expand Up @@ -738,7 +739,6 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> {
env: Env<'env>,
observe: bool,
) -> REI<'env> {
let mut env = env;
let res = if bucket.symbol.starts_with(GENERATED_FN_PREFIX) {
// ASSUME: The arguments to a generated function will always be LoadBucket with 'bounded_fn'
// that are intended to generate pointers or a call to some built-in function that returns
Expand All @@ -760,12 +760,53 @@ impl<'a: 'd, 'd> BucketInterpreter<'a, 'd> {
self._execute_function_extracted(&bucket, env, observe)
))
} else {
let mut args = Vec::with_capacity(bucket.arguments.len());
for i in &bucket.arguments {
let (val, new_env) = check_res!(self._execute_instruction(i, env, observe));
let val = check_std_res!(opt_as_result(val, "function argument"));
args.push(val);
env = new_env;
let mut args = vec![];
for a in &bucket.arguments {
// Case: vector load
if let Instruction::Load(load) = &**a {
let load_size = load.context.size;
if load_size > 1 {
assert!(load.bounded_fn.is_none());
match &load.src {
LocationRule::Mapped { .. } => todo!("Can this happen?"),
LocationRule::Indexed { location, template_header } => {
for i in 0..load_size {
let scalar_load = LoadBucket {
id: new_id(),
source_file_id: load.source_file_id,
line: load.line,
message_id: load.message_id,
address_type: load.address_type.clone(),
src: LocationRule::Indexed {
location: build_compute(
load,
OperatorType::Add,
0,
vec![location.clone(), build_u32_value(load, i)],
),
template_header: template_header.clone(),
},
context: InstrContext { size: 1 },
bounded_fn: None,
}
.allocate();
let val = check_res!(
self._compute_instruction(&scalar_load, &env, observe),
|v| (v, env)
);
args.push(check_std_res!(opt_as_result(
val,
"function argument"
)));
}
}
}
continue;
}
}
// Case: anything else
let val = check_res!(self._compute_instruction(a, &env, observe), |v| (v, env));
args.push(check_std_res!(opt_as_result(val, "function argument")));
}
check_res!(InterpRes::try_continue(self._execute_function_basic(
&bucket.symbol,
Expand Down
34 changes: 15 additions & 19 deletions circuit_passes/src/passes/unreachable_code_removal.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::cell::RefCell;
use std::collections::HashSet;
use code_producers::llvm_elements::stdlib::GENERATED_FN_PREFIX;
use paste::paste;
use compiler::circuit_design::template::TemplateCode;
use compiler::intermediate_representation::{new_id, BucketId};
Expand Down Expand Up @@ -131,25 +130,22 @@ impl CircuitTransformationPass for UnreachableRemovalPass<'_> {

fn transform_call_bucket(&self, bucket: &CallBucket) -> Result<InstructionPointer, BadInterp> {
if self.visited.borrow_mut().remove(&bucket.id) {
if bucket.symbol.starts_with(GENERATED_FN_PREFIX) {
// BucketInterpreter::_execute_call_bucket() will never visit arguments
// within these generated functions so we need a special case so they
// are not removed unless the CallBucket is removed entirely.
Ok(CallBucket {
id: new_id(),
source_file_id: bucket.source_file_id,
line: bucket.line,
message_id: bucket.message_id,
symbol: bucket.symbol.to_string(),
argument_types: bucket.argument_types.clone(),
arguments: bucket.arguments.clone(),
arena_size: bucket.arena_size,
return_info: self.transform_return_type(&bucket.id, &bucket.return_info)?,
}
.allocate())
} else {
self.transform_call_bucket_default(bucket)
// BucketInterpreter::_execute_call_bucket() will never visit arguments within
// generated functions and will only visit the scalar arguments of all other
// functions so we need a special case to prevent the arguments from being
// removed unless the CallBucket is removed entirely.
Ok(CallBucket {
id: new_id(),
source_file_id: bucket.source_file_id,
line: bucket.line,
message_id: bucket.message_id,
symbol: bucket.symbol.to_string(),
argument_types: bucket.argument_types.clone(),
arguments: bucket.arguments.clone(),
arena_size: bucket.arena_size,
return_info: self.transform_return_type(&bucket.id, &bucket.return_info)?,
}
.allocate())
} else {
Ok(NopBucket { id: new_id() }.allocate())
}
Expand Down
Loading