Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Derive Ord and Hash in the stdlib; add std::meta::make_impl helper #5683

Merged
merged 15 commits into from
Aug 6, 2024
7 changes: 5 additions & 2 deletions compiler/noirc_frontend/src/hir/comptime/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -219,15 +219,14 @@
}
} else {
let name = self.elaborator.interner.function_name(&function);
unreachable!("Non-builtin, lowlevel or oracle builtin fn '{name}'")

Check warning on line 222 in compiler/noirc_frontend/src/hir/comptime/interpreter.rs

View workflow job for this annotation

GitHub Actions / Code

Unknown word (lowlevel)
}
}

fn call_closure(
&mut self,
closure: HirLambda,
// TODO: How to define environment here?
_environment: Vec<Value>,
environment: Vec<Value>,
arguments: Vec<(Value, Location)>,
call_location: Location,
) -> IResult<Value> {
Expand All @@ -246,6 +245,10 @@
self.define_pattern(parameter, typ, argument, arg_location)?;
}

for (param, arg) in closure.captures.into_iter().zip(environment) {
self.define(param.ident.id, arg);
}

let result = self.evaluate(closure.body)?;

self.exit_function(previous_state);
Expand Down
42 changes: 20 additions & 22 deletions noir_stdlib/src/cmp.nr
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,10 @@ trait Eq {
// docs:end:eq-trait

comptime fn derive_eq(s: StructDefinition) -> Quoted {
let typ = s.as_type();

let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,});

let where_clause = s.generics().map(|name| quote { $name: Eq }).join(quote {,});

// `(self.a == other.a) & (self.b == other.b) & ...`
let equalities = s.fields().map(
|f: (Quoted, Type)| {
let name = f.0;
quote { (self.$name == other.$name) }
}
);
let body = equalities.join(quote { & });

quote {
impl<$impl_generics> Eq for $typ where $where_clause {
fn eq(self, other: Self) -> bool {
$body
}
}
}
let signature = quote { fn eq(_self: Self, _other: Self) -> bool };
let for_each_field = |name| quote { (_self.$name == _other.$name) };
let body = |fields| fields;
crate::meta::make_trait_impl(s, quote { Eq }, signature, for_each_field, quote { & }, body)
}

impl Eq for Field { fn eq(self, other: Field) -> bool { self == other } }
Expand Down Expand Up @@ -127,12 +109,28 @@ impl Ordering {
}
}

#[derive_via(derive_ord)]
// docs:start:ord-trait
trait Ord {
fn cmp(self, other: Self) -> Ordering;
}
// docs:end:ord-trait

comptime fn derive_ord(s: StructDefinition) -> Quoted {
let signature = quote { fn cmp(_self: Self, _other: Self) -> std::cmp::Ordering };
let for_each_field = |name| quote {
if result == std::cmp::Ordering::equal() {
result = _self.$name.cmp(_other.$name);
}
};
let body = |fields| quote {
let mut result = std::cmp::Ordering::equal();
$fields
result
};
crate::meta::make_trait_impl(s, quote { Ord }, signature, for_each_field, quote {}, body)
}

// Note: Field deliberately does not implement Ord

impl Ord for u64 {
Expand Down
27 changes: 5 additions & 22 deletions noir_stdlib/src/default.nr
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,11 @@ trait Default {
// docs:end:default-trait

comptime fn derive_default(s: StructDefinition) -> Quoted {
let typ = s.as_type();

let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,});

let where_clause = s.generics().map(|name| quote { $name: Default }).join(quote {,});

// `foo: Default::default(), bar: Default::default(), ...`
let fields = s.fields().map(
|f: (Quoted, Type)| {
let name = f.0;
quote { $name: Default::default() }
}
);
let fields = fields.join(quote {,});

quote {
impl<$impl_generics> Default for $typ where $where_clause {
fn default() -> Self {
Self { $fields }
}
}
}
let name = quote { Default };
let signature = quote { fn default() -> Self };
let for_each_field = |name| quote { $name: Default::default() };
let body = |fields| quote { Self { $fields } };
crate::meta::make_trait_impl(s, name, signature, for_each_field, quote { , }, body)
}

impl Default for Field { fn default() -> Field { 0 } }
Expand Down
11 changes: 10 additions & 1 deletion noir_stdlib/src/hash/mod.nr
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use crate::uint128::U128;
use crate::sha256::{digest, sha256_var};
use crate::collections::vec::Vec;
use crate::embedded_curve_ops::{EmbeddedCurvePoint, EmbeddedCurveScalar, multi_scalar_mul, multi_scalar_mul_slice};
use crate::meta::derive_via;

#[foreign(sha256)]
// docs:start:sha256
Expand Down Expand Up @@ -141,10 +142,18 @@ pub fn sha256_compression(_input: [u32; 16], _state: [u32; 8]) -> [u32; 8] {}
// Partially ported and impacted by rust.

// Hash trait shall be implemented per type.
trait Hash{
#[derive_via(derive_hash)]
trait Hash {
fn hash<H>(self, state: &mut H) where H: Hasher;
}

comptime fn derive_hash(s: StructDefinition) -> Quoted {
let name = quote { Hash };
let signature = quote { fn hash<H>(_self: Self, _state: &mut H) where H: std::hash::Hasher };
let for_each_field = |name| quote { _self.$name.hash(_state); };
crate::meta::make_trait_impl(s, name, signature, for_each_field, quote {}, |fields| fields)
}

// Hasher trait shall be implemented by algorithms to provide hash-agnostic means.
// TODO: consider making the types generic here ([u8], [Field], etc.)
trait Hasher{
Expand Down
42 changes: 42 additions & 0 deletions noir_stdlib/src/meta/mod.nr
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,45 @@ pub comptime fn derive(s: StructDefinition, traits: [TraitDefinition]) -> Quoted
unconstrained pub comptime fn derive_via(t: TraitDefinition, f: DeriveFunction) {
HANDLERS.insert(t, f);
}

/// `make_impl` is a helper function to make a simple impl, usually while deriving a trait.
/// This impl has a couple assumptions:
/// 1. The impl only has one function, with the signature `function_signature`
/// 2. The trait itself does not have any generics.
///
/// While these assumptions are met, `make_impl` will create an impl from a StructDefinition,
/// automatically filling in the required generics from the struct, along with the where clause.
/// The function body is created by mapping each field with `for_each_field` and joining the
/// results with `join_fields_with`. The result of this is passed to the `body` function for
/// any final processing - e.g. wrapping each field in a `StructConstructor { .. }` expression.
///
/// See `derive_eq` and `derive_default` for example usage.
pub comptime fn make_trait_impl<Env1, Env2>(
s: StructDefinition,
trait_name: Quoted,
function_signature: Quoted,
for_each_field: fn[Env1](Quoted) -> Quoted,
join_fields_with: Quoted,
body: fn[Env2](Quoted) -> Quoted
) -> Quoted {
let typ = s.as_type();
let impl_generics = s.generics().map(|g| quote { $g }).join(quote {,});
let where_clause = s.generics().map(|name| quote { $name: $trait_name }).join(quote {,});

// `for_each_field(field1) $join_fields_with for_each_field(field2) $join_fields_with ...`
let fields = s.fields().map(
|f: (Quoted, Type)| {
let name = f.0;
for_each_field(name)
}
);
let body = body(fields.join(join_fields_with));

quote {
impl<$impl_generics> $trait_name for $typ where $where_clause {
$function_signature {
$body
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[package]
name = "comptime_closures"
type = "bin"
authors = [""]
compiler_version = ">=0.32.0"

[dependencies]
39 changes: 39 additions & 0 deletions test_programs/compile_success_empty/comptime_closures/src/main.nr
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
fn main() {
comptime
{
closure_test(0);
}
}

fn closure_test(mut x: Field) {
let one = 1;
let add1 = |z| {
(|| {
*z += one;
})()
};

let two = 2;
let add2 = |z| {
*z = *z + two;
};

add1(&mut x);
assert(x == 1);

add2(&mut x);
assert(x == 3);

issue_2120();
}

fn issue_2120() {
let x1 = &mut 42;
let set_x1 = |y| { *x1 = y; };

assert(*x1 == 42);
set_x1(44);
assert(*x1 == 44);
set_x1(*x1);
assert(*x1 == 44);
}
30 changes: 28 additions & 2 deletions test_programs/execution_success/derive/src/main.nr
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::hash::Hash;

#[derive_via(derive_do_nothing)]
trait DoNothing {
fn do_nothing(self);
Expand All @@ -20,14 +22,15 @@ comptime fn derive_do_nothing(s: StructDefinition) -> Quoted {
}

// Test stdlib derive fns & multiple traits
#[derive(Eq, Default)]
// - We can derive Ord and Hash even though std::cmp::Ordering and std::hash::Hasher aren't imported
#[derive(Eq, Default, Hash, Ord)]
struct MyOtherStruct<A, B> {
field1: A,
field2: B,
field3: MyOtherOtherStruct<B>,
}

#[derive(Eq, Default)]
#[derive(Eq, Default, Hash, Ord)]
struct MyOtherOtherStruct<T> {
x: T,
}
Expand All @@ -41,4 +44,27 @@ fn main() {

let o: MyOtherStruct<u8, [str<2>]> = MyOtherStruct::default();
assert_eq(o, o);

// Field & str<2> above don't implement Ord
let o1 = MyOtherStruct { field1: 12 as u32, field2: 24 as i8, field3: MyOtherOtherStruct { x: 54 as i8 } };
let o2 = MyOtherStruct { field1: 12 as u32, field2: 24 as i8, field3: MyOtherOtherStruct { x: 55 as i8 } };
assert(o1 < o2);

let mut hasher = TestHasher { result: 0 };
o1.hash(&mut hasher);
assert_eq(hasher.finish(), 12 + 24 + 54);
}

struct TestHasher {
result: Field,
}

impl std::hash::Hasher for TestHasher {
fn finish(self) -> Field {
self.result
}

fn write(&mut self, input: Field) {
self.result += input;
}
}
Loading