Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added type hierarchy support for arrays and slices #1957

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,71 +2,93 @@ use dep::std::slice;
use dep::std;

unconstrained fn main(x: Field, y: Field) {
// Mark it as mut so the compiler doesn't simplify the following operations
// But don't reuse the mut slice variable until this is fixed https://github.com/noir-lang/noir/issues/1931
let slice: [Field] = [y, x];
//Get the slice from a function to correct the type until https://github.com/noir-lang/noir/issues/1931 is fixed
let mut slice: [Field] = get_me_a_slice(y, x);
assert(slice.len() == 2);

let mut pushed_back_slice = slice.push_back(7);
assert(pushed_back_slice.len() == 3);
assert(pushed_back_slice[0] == y);
assert(pushed_back_slice[1] == x);
assert(pushed_back_slice[2] == 7);
slice = slice.push_back(7);
assert(slice.len() == 3);
assert(slice[0] == y);
assert(slice[1] == x);
assert(slice[2] == 7);

// Array set on slice target
pushed_back_slice[0] = x;
pushed_back_slice[1] = y;
pushed_back_slice[2] = 1;

assert(pushed_back_slice[0] == x);
assert(pushed_back_slice[1] == y);
assert(pushed_back_slice[2] == 1);

assert(slice.len() == 2);

let pushed_front_slice = pushed_back_slice.push_front(2);
assert(pushed_front_slice.len() == 4);
assert(pushed_front_slice[0] == 2);
assert(pushed_front_slice[1] == x);
assert(pushed_front_slice[2] == y);
assert(pushed_front_slice[3] == 1);

let (item, popped_front_slice) = pushed_front_slice.pop_front();
slice[0] = x;
slice[1] = y;
slice[2] = 1;

assert(slice[0] == x);
assert(slice[1] == y);
assert(slice[2] == 1);

slice = slice.push_front(2);
assert(slice.len() == 4);
assert(slice[0] == 2);
assert(slice[1] == x);
assert(slice[2] == y);
assert(slice[3] == 1);

let (item, popped_front_slice) = slice.pop_front();
assert(item == 2);
slice = popped_front_slice;

assert(popped_front_slice.len() == 3);
assert(popped_front_slice[0] == x);
assert(popped_front_slice[1] == y);
assert(popped_front_slice[2] == 1);
assert(slice.len() == 3);
assert(slice[0] == x);
assert(slice[1] == y);
assert(slice[2] == 1);

let (popped_back_slice, another_item) = popped_front_slice.pop_back();
let (popped_back_slice, another_item) = slice.pop_back();
assert(another_item == 1);
slice = popped_back_slice;

assert(popped_back_slice.len() == 2);
assert(popped_back_slice[0] == x);
assert(popped_back_slice[1] == y);
assert(slice.len() == 2);
assert(slice[0] == x);
assert(slice[1] == y);

let inserted_slice = popped_back_slice.insert(1, 2);
assert(inserted_slice.len() == 3);
assert(inserted_slice[0] == x);
assert(inserted_slice[1] == 2);
assert(inserted_slice[2] == y);
slice = slice.insert(1, 2);
assert(slice.len() == 3);
assert(slice[0] == x);
assert(slice[1] == 2);
assert(slice[2] == y);

let (removed_slice, should_be_2) = inserted_slice.remove(1);
let (removed_slice, should_be_2) = slice.remove(1);
assert(should_be_2 == 2);
slice = removed_slice;

assert(removed_slice.len() == 2);
assert(removed_slice[0] == x);
assert(removed_slice[1] == y);
assert(slice.len() == 2);
assert(slice[0] == x);
assert(slice[1] == y);

let (slice_with_only_x, should_be_y) = removed_slice.remove(1);
let (slice_with_only_x, should_be_y) = slice.remove(1);
assert(should_be_y == y);
slice = slice_with_only_x;

assert(slice_with_only_x.len() == 1);
assert(slice.len() == 1);
assert(removed_slice[0] == x);

let (empty_slice, should_be_x) = slice_with_only_x.remove(0);
let (empty_slice, should_be_x) = slice.remove(0);
assert(should_be_x == x);
assert(empty_slice.len() == 0);


let sequence = create_sequence([], 5);

assert(sequence.len() == 5);
assert(sequence[0] == 5);
assert(sequence[1] == 4);
assert(sequence[2] == 3);
assert(sequence[3] == 2);
assert(sequence[4] == 1);
}

unconstrained fn get_me_a_slice(x: Field, y: Field) -> [Field] {
[x, y]
}

unconstrained fn create_sequence(mut sequence: [Field], n: Field) -> [Field] {
if n != 0 {
sequence = create_sequence(sequence.push_back(n), n - 1);
}

sequence
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[package]
authors = [""]
compiler_version = "0.6.0"

[dependencies]
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
x = "5"
y = "10"
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
use dep::std::slice;
use dep::std;

unconstrained fn main(x: Field, y: Field) {
let mut array = [x, y];

// Dynamic dispatch typing
let mut test_slice = wrapper(push_accepts_array_returns_slice, array);
assert(test_slice[2] == 1);

test_slice = wrapper(push_accepts_slice_returns_slice, array);
assert(test_slice[2] == 1);


//Get the slice from a function to correct the type until https://github.com/noir-lang/noir/issues/1931 is fixed
let mut slice = get_me_a_slice(x, y);


// No cast
let another_array = push_accepts_array_returns_array(array);
assert(another_array[2] == 1);

slice = push_accepts_array_returns_slice(array);
assert(slice[2] == 1);

// Cast return value
slice = push_accepts_array_returns_array(array);
assert(slice[2] == 1);

// Cast param
slice = push_accepts_slice_returns_slice(array);
assert(slice[2] == 1);
}

unconstrained fn get_me_a_slice(x: Field, y: Field) -> [Field] {
[x, y]
}

unconstrained fn push_accepts_array_returns_slice<N>(array: [Field; N]) -> [Field] {
let slice: [Field] = array;
slice.push_back(1)
}

unconstrained fn push_accepts_array_returns_array(array: [Field; 2]) -> [Field; 3] {
[array[0], array[1], 1]
}

unconstrained fn push_accepts_slice_returns_slice(slice: [Field]) -> [Field] {
slice.push_back(1)
}

unconstrained fn wrapper(function: fn ([Field]) -> [Field], param: [Field; 2]) -> [Field] {
function(param)
}
19 changes: 16 additions & 3 deletions crates/noirc_evaluator/src/brillig/brillig_gen.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ pub(crate) mod brillig_directive;
pub(crate) mod brillig_fn;
pub(crate) mod brillig_slice_ops;

use crate::ssa_refactor::ir::{function::Function, post_order::PostOrder};
use crate::ssa_refactor::ir::{
function::{Function, FunctionId, Signature},
post_order::PostOrder,
};

use std::collections::HashMap;

Expand All @@ -13,7 +16,11 @@ use self::{brillig_block::BrilligBlock, brillig_fn::FunctionContext};
use super::brillig_ir::{artifact::BrilligArtifact, BrilligContext};

/// Converting an SSA function into Brillig bytecode.
pub(crate) fn convert_ssa_function(func: &Function, enable_debug_trace: bool) -> BrilligArtifact {
pub(crate) fn convert_ssa_function(
func: &Function,
function_to_signature: &HashMap<FunctionId, Signature>,
enable_debug_trace: bool,
) -> BrilligArtifact {
let mut reverse_post_order = Vec::new();
reverse_post_order.extend_from_slice(PostOrder::with_function(func).as_slice());
reverse_post_order.reverse();
Expand All @@ -29,7 +36,13 @@ pub(crate) fn convert_ssa_function(func: &Function, enable_debug_trace: bool) ->

brillig_context.enter_context(FunctionContext::function_id_to_function_label(func.id()));
for block in reverse_post_order {
BrilligBlock::compile(&mut function_context, &mut brillig_context, block, &func.dfg);
BrilligBlock::compile(
&mut function_context,
&mut brillig_context,
block,
&func.dfg,
function_to_signature,
);
}

brillig_context.artifact()
Expand Down
87 changes: 79 additions & 8 deletions crates/noirc_evaluator/src/brillig/brillig_gen/brillig_block.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::collections::HashMap;

use crate::brillig::brillig_gen::brillig_slice_ops::{
convert_array_or_vector_to_vector, slice_push_back_operation,
};
use crate::brillig::brillig_ir::{
BrilligBinaryOp, BrilligContext, BRILLIG_INTEGER_ARITHMETIC_BIT_SIZE,
};
use crate::ssa_refactor::ir::function::FunctionId;
use crate::ssa_refactor::ir::function::{FunctionId, Signature};
use crate::ssa_refactor::ir::instruction::Intrinsic;
use crate::ssa_refactor::ir::{
basic_block::{BasicBlock, BasicBlockId},
Expand All @@ -19,7 +21,7 @@ use acvm::FieldElement;
use iter_extended::vecmap;

use super::brillig_black_box::convert_black_box_call;
use super::brillig_fn::{compute_size_of_composite_type, FunctionContext};
use super::brillig_fn::{compute_size_of_composite_type, compute_size_of_type, FunctionContext};
use super::brillig_slice_ops::{
slice_insert_operation, slice_pop_back_operation, slice_pop_front_operation,
slice_push_front_operation, slice_remove_operation,
Expand All @@ -32,6 +34,8 @@ pub(crate) struct BrilligBlock<'block> {
block_id: BasicBlockId,
/// Context for creating brillig opcodes
brillig_context: &'block mut BrilligContext,
/// A map of function ids to their signatures
function_to_signature: &'block HashMap<FunctionId, Signature>,
}

impl<'block> BrilligBlock<'block> {
Expand All @@ -41,8 +45,10 @@ impl<'block> BrilligBlock<'block> {
brillig_context: &'block mut BrilligContext,
block_id: BasicBlockId,
dfg: &DataFlowGraph,
function_to_signature: &'block HashMap<FunctionId, Signature>,
) {
let mut brillig_block = BrilligBlock { function_context, block_id, brillig_context };
let mut brillig_block =
BrilligBlock { function_context, block_id, brillig_context, function_to_signature };

brillig_block.convert_block(dfg);
}
Expand Down Expand Up @@ -401,12 +407,18 @@ impl<'block> BrilligBlock<'block> {
dfg: &DataFlowGraph,
instruction_id: InstructionId,
) {
let signature_of_called_function =
self.function_to_signature.get(&func_id).expect("ICE: cannot find function signature");

// Convert the arguments to registers casting those to the types of the receiving function
let argument_registers: Vec<RegisterIndex> = arguments
.iter()
.flat_map(|argument_id| {
.zip(&signature_of_called_function.params)
.flat_map(|(argument_id, receiver_typ)| {
let variable_to_pass = self.convert_ssa_value(*argument_id, dfg);
self.function_context.extract_registers(variable_to_pass)
let casted_to_param_type =
self.cast_variable_for_call(variable_to_pass, receiver_typ);
self.function_context.extract_registers(casted_to_param_type)
})
.collect();

Expand All @@ -429,11 +441,20 @@ impl<'block> BrilligBlock<'block> {
self.function_context.create_variable(self.brillig_context, *result_id, dfg)
});

// Transform the assigned to variables into the types of the called function returns
let returned_variables: Vec<RegisterOrMemory> = variables_assigned_to
.iter()
.zip(&signature_of_called_function.returns)
.map(|(variable_assigned_to, return_typ)| {
self.cast_back_variable_from_call(*variable_assigned_to, return_typ)
})
.collect();

// Collect the registers that should have been returned
let returned_registers: Vec<RegisterIndex> = variables_assigned_to
let returned_registers: Vec<RegisterIndex> = returned_variables
.iter()
.flat_map(|returned_variable| {
self.function_context.extract_registers(*returned_variable)
.flat_map(|casted_to_return_type| {
self.function_context.extract_registers(*casted_to_return_type)
})
.collect();

Expand All @@ -445,6 +466,56 @@ impl<'block> BrilligBlock<'block> {
// puts the returns into the returned_registers and restores saved_registers
self.brillig_context
.post_call_prep_returns_load_registers(&returned_registers, &saved_registers);

// Reconciliate the types of the variables that the returns are assigned to with the types of the returns
variables_assigned_to.iter().zip(returned_variables).for_each(
|(variable_assigned_to, return_variable)| {
self.reconciliate_from_call(*variable_assigned_to, return_variable);
},
);
}

fn cast_variable_for_call(
&mut self,
variable_to_pass: RegisterOrMemory,
param_type: &Type,
) -> RegisterOrMemory {
match (variable_to_pass, param_type) {
(RegisterOrMemory::HeapArray(array), Type::Slice(..)) => {
RegisterOrMemory::HeapVector(self.brillig_context.array_to_vector(&array))
}
(_, _) => variable_to_pass,
}
}

fn cast_back_variable_from_call(
&mut self,
variable_assigned_to: RegisterOrMemory,
return_type: &Type,
) -> RegisterOrMemory {
match (variable_assigned_to, return_type) {
(RegisterOrMemory::HeapVector(vector), Type::Array(..)) => {
RegisterOrMemory::HeapArray(HeapArray {
pointer: vector.pointer,
size: compute_size_of_type(return_type),
})
}
(_, _) => variable_assigned_to,
}
}

fn reconciliate_from_call(
&mut self,
variable_assigned_to: RegisterOrMemory,
return_variable: RegisterOrMemory,
) -> RegisterOrMemory {
match (variable_assigned_to, return_variable) {
(RegisterOrMemory::HeapVector(vector), RegisterOrMemory::HeapArray(array)) => {
self.brillig_context.const_instruction(vector.size, array.size.into());
RegisterOrMemory::HeapVector(vector)
}
(_, _) => variable_assigned_to,
}
}

/// Array set operation in SSA returns a new array or slice that is a copy of the parameter array or slice
Expand Down
Loading