Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Initialise databus using return values #6074

Merged
merged 5 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1944,6 +1944,15 @@ impl<F: AcirField> AcirContext<F> {
Ok(())
}

/// Insert the MemoryInit for the Return Data array, using the provided witnesses
pub(crate) fn initialize_return_data(&mut self, block_id: BlockId, init: Vec<Witness>) {
self.acir_ir.push_opcode(Opcode::MemoryInit {
block_id,
init,
block_type: BlockType::ReturnData,
});
}

/// Initializes an array in memory with the given values `optional_values`.
/// If `optional_values` is empty, then the array is initialized with zeros.
pub(crate) fn initialize_array(
Expand Down
68 changes: 33 additions & 35 deletions compiler/noirc_evaluator/src/ssa/acir_gen/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -434,16 +434,9 @@ impl<'a> Context<'a> {
for instruction_id in entry_block.instructions() {
warnings.extend(self.convert_ssa_instruction(*instruction_id, dfg, ssa, brillig)?);
}

let (return_vars, return_warnings) =
self.convert_ssa_return(entry_block.unwrap_terminator(), dfg)?;

let call_data_arrays: Vec<ValueId> =
self.data_bus.call_data.iter().map(|cd| cd.array_id).collect();
for call_data_array in call_data_arrays {
self.ensure_array_is_initialized(call_data_array, dfg)?;
}

// TODO: This is a naive method of assigning the return values to their witnesses as
// we're likely to get a number of constraints which are asserting one witness to be equal to another.
//
Expand All @@ -452,13 +445,42 @@ impl<'a> Context<'a> {
self.acir_context.assert_eq_var(*witness_var, return_var, None)?;
}

self.initialize_databus(&return_witnesses, dfg)?;
warnings.extend(return_warnings);
warnings.extend(self.acir_context.warnings.clone());

// Add the warnings from the alter Ssa passes
Ok(self.acir_context.finish(input_witness, return_witnesses, warnings))
}

fn initialize_databus(
&mut self,
witnesses: &Vec<Witness>,
dfg: &DataFlowGraph,
) -> Result<(), RuntimeError> {
// Initialize return_data using provided witnesses
if let Some(return_data) = self.data_bus.return_data {
let block_id = self.block_id(&return_data);
let already_initialized = self.initialized_arrays.contains(&block_id);
if !already_initialized {
// We hijack ensure_array_is_initialized() because we want the return data to use the return value witnesses,
// but the databus contains the computed values instead, that have just been asserted to be equal to the return values.
// We do not use initialize_array either for the case where a constant value is returned.
// In that case, the constant value has already been assigned a witness and the returned acir vars will be
// converted to it, instead of the corresponding return value witness.
self.acir_context.initialize_return_data(block_id, witnesses.to_owned());
}
}

// Initialize call_data
let call_data_arrays: Vec<ValueId> =
self.data_bus.call_data.iter().map(|cd| cd.array_id).collect();
for call_data_array in call_data_arrays {
self.ensure_array_is_initialized(call_data_array, dfg)?;
}
Ok(())
}

fn convert_brillig_main(
mut self,
main_func: &Function,
Expand Down Expand Up @@ -1792,19 +1814,9 @@ impl<'a> Context<'a> {
_ => unreachable!("ICE: Program must have a singular return"),
};

return_values.iter().fold(0, |acc, value_id| {
let is_databus = self
.data_bus
.return_data
.map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]);

if is_databus {
// We do not return value for the data bus.
acc
} else {
acc + dfg.type_of_value(*value_id).flattened_size()
}
})
return_values
.iter()
.fold(0, |acc, value_id| acc + dfg.type_of_value(*value_id).flattened_size())
}

/// Converts an SSA terminator's return values into their ACIR representations
Expand All @@ -1824,27 +1836,13 @@ impl<'a> Context<'a> {
let mut has_constant_return = false;
let mut return_vars: Vec<AcirVar> = Vec::new();
for value_id in return_values {
let is_databus = self
.data_bus
.return_data
.map_or(false, |return_databus| dfg[*value_id] == dfg[return_databus]);
let value = self.convert_value(*value_id, dfg);

// `value` may or may not be an array reference. Calling `flatten` will expand the array if there is one.
let acir_vars = self.acir_context.flatten(value)?;
for (acir_var, _) in acir_vars {
has_constant_return |= self.acir_context.is_constant(&acir_var);
if is_databus {
// We do not return value for the data bus.
self.ensure_array_is_initialized(
self.data_bus.return_data.expect(
"`is_databus == true` implies `data_bus.return_data` is `Some`",
),
dfg,
)?;
} else {
return_vars.push(acir_var);
}
return_vars.push(acir_var);
}
}

Expand Down
Loading