Skip to content

Commit

Permalink
fix: Temporary register leaks in brillig gen (#8350)
Browse files Browse the repository at this point in the history
  • Loading branch information
sirasistant authored Sep 4, 2024
1 parent abc83a2 commit 5f6d2e2
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
message: message_vector.to_heap_vector(),
output: result_array.to_heap_array(),
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: SHA256 expects one array argument and one array result")
}
Expand All @@ -42,6 +43,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
message: message_vector.to_heap_vector(),
output: result_array.to_heap_array(),
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: Blake2s expects one array argument and one array result")
}
Expand All @@ -55,6 +57,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
message: message_vector.to_heap_vector(),
output: result_array.to_heap_array(),
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: Blake3 expects one array argument and one array result")
}
Expand All @@ -78,6 +81,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
output: result_array.to_heap_array(),
});
brillig_context.deallocate_single_addr(message_size_as_usize);
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: Keccak256 expects message, message size and result array")
}
Expand All @@ -92,6 +96,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
message: state_vector.to_heap_vector(),
output: result_array.to_heap_array(),
});
deallocate_converted_vector(brillig_context, message, state_vector, bb_func);
} else {
unreachable!("ICE: Keccakf1600 expects one array argument and one array result")
}
Expand All @@ -111,6 +116,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
signature: signature.to_heap_array(),
result: result_register.address,
});
deallocate_converted_vector(brillig_context, message, message_hash_vector, bb_func);
} else {
unreachable!(
"ICE: EcdsaSecp256k1 expects four array arguments and one register result"
Expand All @@ -132,6 +138,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
signature: signature.to_heap_array(),
result: result_register.address,
});
deallocate_converted_vector(brillig_context, message, message_hash_vector, bb_func);
} else {
unreachable!(
"ICE: EcdsaSecp256r1 expects four array arguments and one register result"
Expand All @@ -151,6 +158,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
domain_separator: domain_separator.address,
output: result_array.to_heap_array(),
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: Pedersen expects one array argument, a register for the domain separator, and one array result")
}
Expand All @@ -167,6 +175,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
domain_separator: domain_separator.address,
output: result.address,
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: Pedersen hash expects one array argument, a register for the domain separator, and one register result")
}
Expand All @@ -178,14 +187,16 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
) = (function_arguments, function_results)
{
let message_hash = convert_array_or_vector(brillig_context, message, bb_func);
let signature = brillig_context.array_to_vector_instruction(signature);
let signature_vector = brillig_context.array_to_vector_instruction(signature);
brillig_context.black_box_op_instruction(BlackBoxOp::SchnorrVerify {
public_key_x: public_key_x.address,
public_key_y: public_key_y.address,
message: message_hash.to_heap_vector(),
signature: signature.to_heap_vector(),
signature: signature_vector.to_heap_vector(),
result: result_register.address,
});
deallocate_converted_vector(brillig_context, message, message_hash, bb_func);
brillig_context.deallocate_register(signature_vector.size);
} else {
unreachable!("ICE: Schnorr verify expects two registers for the public key, an array for signature, an array for the message hash and one result register")
}
Expand All @@ -194,13 +205,15 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
if let ([points, scalars], [BrilligVariable::BrilligArray(outputs)]) =
(function_arguments, function_results)
{
let points = convert_array_or_vector(brillig_context, points, bb_func);
let scalars = convert_array_or_vector(brillig_context, scalars, bb_func);
let points_vector = convert_array_or_vector(brillig_context, points, bb_func);
let scalars_vector = convert_array_or_vector(brillig_context, scalars, bb_func);
brillig_context.black_box_op_instruction(BlackBoxOp::MultiScalarMul {
points: points.to_heap_vector(),
scalars: scalars.to_heap_vector(),
points: points_vector.to_heap_vector(),
scalars: scalars_vector.to_heap_vector(),
outputs: outputs.to_heap_array(),
});
deallocate_converted_vector(brillig_context, points, points_vector, bb_func);
deallocate_converted_vector(brillig_context, scalars, scalars_vector, bb_func);
} else {
unreachable!(
"ICE: MultiScalarMul expects two register arguments and one array result"
Expand Down Expand Up @@ -319,6 +332,8 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
modulus: modulus_vector.to_heap_vector(),
output: output.address,
});
deallocate_converted_vector(brillig_context, inputs, inputs_vector, bb_func);
deallocate_converted_vector(brillig_context, modulus, modulus_vector, bb_func);
} else {
unreachable!(
"ICE: BigIntFromLeBytes expects a register and an array as arguments and two result registers"
Expand All @@ -336,6 +351,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
input: input.address,
output: output.to_heap_vector(),
});
deallocate_converted_vector(brillig_context, result_array, output, bb_func);
} else {
unreachable!(
"ICE: BigIntToLeBytes expects two register arguments and one array result"
Expand All @@ -354,6 +370,7 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
output: result_array.to_heap_array(),
len: state_len.address,
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
} else {
unreachable!("ICE: Poseidon2Permutation expects one array argument, a length and one array result")
}
Expand All @@ -369,6 +386,8 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
hash_values: hash_vector.to_heap_vector(),
output: result_array.to_heap_array(),
});
deallocate_converted_vector(brillig_context, message, message_vector, bb_func);
deallocate_converted_vector(brillig_context, hash_values, hash_vector, bb_func);
} else {
unreachable!("ICE: Sha256Compression expects two array argument, one array result")
}
Expand All @@ -379,18 +398,19 @@ pub(crate) fn convert_black_box_call<F: AcirField + DebugToString, Registers: Re
[BrilligVariable::SingleAddr(out_len), outputs],
) = (function_arguments, function_results)
{
let inputs = convert_array_or_vector(brillig_context, inputs, bb_func);
let outputs = convert_array_or_vector(brillig_context, outputs, bb_func);
let output_vec = outputs.to_heap_vector();
let inputs_vector = convert_array_or_vector(brillig_context, inputs, bb_func);
let outputs_vector = convert_array_or_vector(brillig_context, outputs, bb_func);
brillig_context.black_box_op_instruction(BlackBoxOp::AES128Encrypt {
inputs: inputs.to_heap_vector(),
inputs: inputs_vector.to_heap_vector(),
iv: iv.to_heap_array(),
key: key.to_heap_array(),
outputs: output_vec,
outputs: outputs_vector.to_heap_vector(),
});
brillig_context.mov_instruction(out_len.address, output_vec.size);
brillig_context.mov_instruction(out_len.address, outputs_vector.size);
// Returns slice, so we need to allocate memory for it after the fact
brillig_context.increase_free_memory_pointer_instruction(output_vec.size);
brillig_context.increase_free_memory_pointer_instruction(outputs_vector.size);
deallocate_converted_vector(brillig_context, inputs, inputs_vector, bb_func);
deallocate_converted_vector(brillig_context, outputs, outputs_vector, bb_func);
} else {
unreachable!("ICE: AES128Encrypt expects three array arguments, one array result")
}
Expand All @@ -413,3 +433,24 @@ fn convert_array_or_vector<F: AcirField + DebugToString, Registers: RegisterAllo
),
}
}

/// Deallocates any new register allocated by the function above.
/// Concretely, the only allocated register between array and vector is the size register if the array was converted to a vector.
fn deallocate_converted_vector<F: AcirField + DebugToString, Registers: RegisterAllocator>(
brillig_context: &mut BrilligContext<F, Registers>,
original_array_or_vector: &BrilligVariable,
converted_vector: BrilligVector,
bb_func: &BlackBoxFunc,
) {
match original_array_or_vector {
BrilligVariable::BrilligArray(_) => {
brillig_context.deallocate_register(converted_vector.size);
}
BrilligVariable::BrilligVector(_) => {}
_ => unreachable!(
"ICE: {} expected an array or a vector, but got {:?}",
bb_func.name(),
original_array_or_vector
),
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,12 @@ impl<'block> BrilligBlock<'block> {
// puts the returns into the returned_registers and restores saved_registers
self.brillig_context
.codegen_post_call_prep_returns_load_registers(&returned_registers, &saved_registers);

// Reset the register state to the one needed to hold the current available variables
let variables = self.variables.get_available_variables(self.function_context);
let registers =
variables.into_iter().flat_map(|variable| variable.extract_registers()).collect();
self.brillig_context.set_allocated_registers(registers);
}

fn validate_array_index(
Expand Down Expand Up @@ -1751,7 +1757,7 @@ impl<'block> BrilligBlock<'block> {
dfg,
);
let array = variable.extract_array();
self.allocate_nested_array(typ, Some(array));
self.allocate_foreign_call_result_array(typ, array);

variable
}
Expand All @@ -1778,40 +1784,39 @@ impl<'block> BrilligBlock<'block> {
}
}

fn allocate_nested_array(
&mut self,
typ: &Type,
array: Option<BrilligArray>,
) -> BrilligVariable {
match typ {
Type::Array(types, size) => {
let array = array.unwrap_or(BrilligArray {
pointer: self.brillig_context.allocate_register(),
size: *size,
rc: self.brillig_context.allocate_register(),
});
self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size);
self.brillig_context.usize_const_instruction(array.rc, 1_usize.into());

let mut index = 0_usize;
for _ in 0..*size {
for element_type in types.iter() {
match element_type {
Type::Array(_, _) => {
let inner_array = self.allocate_nested_array(element_type, None);
let idx =
self.brillig_context.make_usize_constant_instruction(index.into());
self.brillig_context.codegen_store_variable_in_array(array.pointer, idx, inner_array);
}
Type::Slice(_) => unreachable!("ICE: unsupported slice type in allocate_nested_array(), expects an array or a numeric type"),
_ => (),
}
index += 1;
fn allocate_foreign_call_result_array(&mut self, typ: &Type, array: BrilligArray) {
let Type::Array(types, size) = typ else {
unreachable!("ICE: allocate_foreign_call_array() expects an array, got {typ:?}")
};

self.brillig_context.codegen_allocate_fixed_length_array(array.pointer, array.size);
self.brillig_context.usize_const_instruction(array.rc, 1_usize.into());

let mut index = 0_usize;
for _ in 0..*size {
for element_type in types.iter() {
match element_type {
Type::Array(_, nested_size) => {
let inner_array = BrilligArray {
pointer: self.brillig_context.allocate_register(),
rc: self.brillig_context.allocate_register(),
size: *nested_size,
};
self.allocate_foreign_call_result_array(element_type, inner_array);

let idx =
self.brillig_context.make_usize_constant_instruction(index.into());
self.brillig_context.codegen_store_variable_in_array(array.pointer, idx, BrilligVariable::BrilligArray(inner_array));

self.brillig_context.deallocate_single_addr(idx);
self.brillig_context.deallocate_register(inner_array.pointer);
self.brillig_context.deallocate_register(inner_array.rc);
}
Type::Slice(_) => unreachable!("ICE: unsupported slice type in allocate_nested_array(), expects an array or a numeric type"),
_ => (),
}
BrilligVariable::BrilligArray(array)
index += 1;
}
_ => unreachable!("ICE: allocate_nested_array() expects an array, got {typ:?}"),
}
}

Expand Down

0 comments on commit 5f6d2e2

Please sign in to comment.