Skip to content

Commit

Permalink
Update the current_witness_index after optimisation
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh committed Dec 4, 2024
1 parent 73e11c6 commit b59cb34
Showing 1 changed file with 264 additions and 5 deletions.
269 changes: 264 additions & 5 deletions acvm-repo/acvm/src/compiler/transformers/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
use std::collections::BTreeSet;

use acir::{
circuit::{brillig::BrilligOutputs, Circuit, ExpressionWidth, Opcode},
circuit::{
self,
brillig::{BrilligInputs, BrilligOutputs},
opcodes::{BlackBoxFuncCall, FunctionInput, MemOp},
Circuit, ExpressionWidth, Opcode,
},
native_types::{Expression, Witness},
AcirField,
};
Expand Down Expand Up @@ -79,8 +86,6 @@ pub(super) fn transform_internal<F: AcirField>(
&mut next_witness_index,
);

// Update next_witness counter
next_witness_index += (intermediate_variables.len() - len) as u32;
let mut new_opcodes = Vec::new();
for (g, (norm, w)) in intermediate_variables.iter().skip(len) {
// de-normalize
Expand Down Expand Up @@ -160,13 +165,267 @@ pub(super) fn transform_internal<F: AcirField>(
let mut merge_optimizer = MergeExpressionsOptimizer::new();
let (opcodes, new_acir_opcode_positions) =
merge_optimizer.eliminate_intermediate_variable(&acir, new_acir_opcode_positions);
// n.b. we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let acir = Circuit {

// n.b. if we do not update current_witness_index after the eliminate_intermediate_variable pass, the real index could be less.
let mut acir = Circuit {
current_witness_index,
expression_width,
opcodes,
// The optimizer does not add new public inputs
..acir
};

// After the elimination of intermediate variables the `current_witness_index` is potentially higher than it needs to be,
// which would cause gaps if we ran the optimization a second time, making it look like new variables were added.
// Here we figure out what is the final state of witnesses by visiting each opcode.
let witnesses = WitnessCollector::collect_from_circuit(&acir);
if let Some(max_witness) = witnesses.last() {
acir.current_witness_index = max_witness.0;
}

(acir, new_acir_opcode_positions)
}

/// Collect all witnesses in a circuit.
#[derive(Default, Clone, Debug)]
struct WitnessCollector {
witnesses: BTreeSet<Witness>,
}

impl WitnessCollector {
/// Collect all witnesses in a circuit.
fn collect_from_circuit<F: AcirField>(circuit: &Circuit<F>) -> BTreeSet<Witness> {
let mut collector = Self::default();
collector.extend_from_circuit(circuit);
collector.witnesses
}

fn add(&mut self, witness: Witness) {
self.witnesses.insert(witness);
}

fn add_many(&mut self, witnesses: &[Witness]) {
self.witnesses.extend(witnesses);
}

/// Add all witnesses from the circuit.
fn extend_from_circuit<F: AcirField>(&mut self, circuit: &Circuit<F>) {
self.witnesses.extend(&circuit.private_parameters);
self.witnesses.extend(&circuit.public_parameters.0);
self.witnesses.extend(&circuit.return_values.0);
for opcode in &circuit.opcodes {
self.extend_from_opcode(opcode);
}
}

/// Add witnesses from the opcode.
fn extend_from_opcode<F: AcirField>(&mut self, opcode: &Opcode<F>) {
match opcode {
Opcode::AssertZero(expr) => {
self.extend_from_expr(expr);
}
Opcode::BlackBoxFuncCall(call) => self.extend_from_blackbox(call),
Opcode::MemoryOp { block_id: _, op, predicate } => {
let MemOp { operation, index, value } = op;
self.extend_from_expr(operation);
self.extend_from_expr(index);
self.extend_from_expr(value);
if let Some(pred) = predicate {
self.extend_from_expr(pred);
}
}
Opcode::MemoryInit { block_id: _, init, block_type: _ } => {
for w in init {
self.add(*w);
}
}
// We keep the display for a BrilligCall and circuit Call separate as they
// are distinct in their functionality and we should maintain this separation for debugging.
Opcode::BrilligCall { id: _, inputs, outputs, predicate } => {
if let Some(pred) = predicate {
self.extend_from_expr(pred);
}
self.extend_from_brillig_inputs(inputs);
self.extend_from_brillig_outputs(outputs);
}
Opcode::Call { id: _, inputs, outputs, predicate } => {
if let Some(pred) = predicate {
self.extend_from_expr(pred);
}
self.add_many(inputs);
self.add_many(outputs);
}
}
}

fn extend_from_expr<F: AcirField>(&mut self, expr: &Expression<F>) {
for i in &expr.mul_terms {
self.add(i.1);
self.add(i.2);
}
for i in &expr.linear_combinations {
self.add(i.1);
}
}

fn extend_from_brillig_inputs<F: AcirField>(&mut self, inputs: &[BrilligInputs<F>]) {
for input in inputs {
match input {
BrilligInputs::Single(expr) => {
self.extend_from_expr(expr);
}
BrilligInputs::Array(exprs) => {
for expr in exprs {
self.extend_from_expr(expr);
}
}
BrilligInputs::MemoryArray(_) => {}
}
}
}

fn extend_from_brillig_outputs(&mut self, outputs: &[BrilligOutputs]) {
for output in outputs {
match output {
BrilligOutputs::Simple(w) => {
self.add(*w);
}
BrilligOutputs::Array(ws) => self.add_many(ws),
}
}
}

fn extend_from_blackbox<F: AcirField>(&mut self, call: &BlackBoxFuncCall<F>) {
match call {
BlackBoxFuncCall::AES128Encrypt { inputs, iv, key, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.extend_from_function_inputs(iv.as_slice());
self.extend_from_function_inputs(key.as_slice());
self.add_many(outputs);
}
BlackBoxFuncCall::AND { lhs, rhs, output } => {
self.extend_from_function_input(lhs);
self.extend_from_function_input(rhs);
self.add(*output);
}
BlackBoxFuncCall::XOR { lhs, rhs, output } => {
self.extend_from_function_input(lhs);
self.extend_from_function_input(rhs);
self.add(*output);
}
BlackBoxFuncCall::RANGE { input } => {
self.extend_from_function_input(input);
}
BlackBoxFuncCall::Blake2s { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::Blake3 { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::SchnorrVerify {
public_key_x,
public_key_y,
signature,
message,
output,
} => {
self.extend_from_function_input(public_key_x);
self.extend_from_function_input(public_key_y);
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(message.as_slice());
self.add(*output);
}
BlackBoxFuncCall::EcdsaSecp256k1 {
public_key_x,
public_key_y,
signature,
hashed_message,
output,
} => {
self.extend_from_function_inputs(public_key_x.as_slice());
self.extend_from_function_inputs(public_key_y.as_slice());
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(hashed_message.as_slice());
self.add(*output);
}
BlackBoxFuncCall::EcdsaSecp256r1 {
public_key_x,
public_key_y,
signature,
hashed_message,
output,
} => {
self.extend_from_function_inputs(public_key_x.as_slice());
self.extend_from_function_inputs(public_key_y.as_slice());
self.extend_from_function_inputs(signature.as_slice());
self.extend_from_function_inputs(hashed_message.as_slice());
self.add(*output);
}
BlackBoxFuncCall::MultiScalarMul { points, scalars, outputs } => {
self.extend_from_function_inputs(points.as_slice());
self.extend_from_function_inputs(scalars.as_slice());
let (x, y, i) = outputs;
self.add(*x);
self.add(*y);
self.add(*i);
}
BlackBoxFuncCall::EmbeddedCurveAdd { input1, input2, outputs } => {
self.extend_from_function_inputs(input1.as_slice());
self.extend_from_function_inputs(input2.as_slice());
let (x, y, i) = outputs;
self.add(*x);
self.add(*y);
self.add(*i);
}
BlackBoxFuncCall::Keccakf1600 { inputs, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::RecursiveAggregation {
verification_key,
proof,
public_inputs,
key_hash,
proof_type: _,
} => {
self.extend_from_function_inputs(verification_key.as_slice());
self.extend_from_function_inputs(proof.as_slice());
self.extend_from_function_inputs(public_inputs.as_slice());
self.extend_from_function_input(key_hash);
}
BlackBoxFuncCall::BigIntAdd { .. }
| BlackBoxFuncCall::BigIntSub { .. }
| BlackBoxFuncCall::BigIntMul { .. }
| BlackBoxFuncCall::BigIntDiv { .. } => {}
BlackBoxFuncCall::BigIntFromLeBytes { inputs, modulus: _, output: _ } => {
self.extend_from_function_inputs(inputs.as_slice());
}
BlackBoxFuncCall::BigIntToLeBytes { input: _, outputs } => {
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::Poseidon2Permutation { inputs, outputs, len: _ } => {
self.extend_from_function_inputs(inputs.as_slice());
self.add_many(outputs.as_slice());
}
BlackBoxFuncCall::Sha256Compression { inputs, hash_values, outputs } => {
self.extend_from_function_inputs(inputs.as_slice());
self.extend_from_function_inputs(hash_values.as_slice());
self.add_many(outputs.as_slice());
}
}
}

fn extend_from_function_input<F: AcirField>(&mut self, input: &FunctionInput<F>) {
if let circuit::opcodes::ConstantOrWitnessEnum::Witness(witness) = input.input() {
self.add(witness);
}
}

fn extend_from_function_inputs<F: AcirField>(&mut self, inputs: &[FunctionInput<F>]) {
for input in inputs {
self.extend_from_function_input(input);
}
}
}

0 comments on commit b59cb34

Please sign in to comment.