From a18238180cbd6c71f75fcfcb1a093ac29c839aeb Mon Sep 17 00:00:00 2001 From: Tom French <15848336+TomAFrench@users.noreply.github.com> Date: Thu, 11 Jan 2024 22:57:19 +0000 Subject: [PATCH] feat!: implement keccakf1600 in brillig (#3914) --- .../dsl/acir_format/serde/acir.hpp | 62 +++++++++++++++++++ noir/acvm-repo/acir/codegen/acir.cpp | 52 +++++++++++++++- noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs | 4 +- noir/acvm-repo/brillig/src/black_box.rs | 2 + noir/acvm-repo/brillig_vm/src/black_box.rs | 19 +++++- .../brillig/brillig_gen/brillig_black_box.rs | 17 ++++- .../src/brillig/brillig_ir/debug_show.rs | 3 + 7 files changed, 151 insertions(+), 8 deletions(-) diff --git a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp index 93ced6a19b3..52e4d5a0b55 100644 --- a/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp +++ b/barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp @@ -466,6 +466,15 @@ struct BlackBoxOp { static Keccak256 bincodeDeserialize(std::vector); }; + struct Keccakf1600 { + Circuit::HeapVector message; + Circuit::HeapArray output; + + friend bool operator==(const Keccakf1600&, const Keccakf1600&); + std::vector bincodeSerialize() const; + static Keccakf1600 bincodeDeserialize(std::vector); + }; + struct EcdsaSecp256k1 { Circuit::HeapVector hashed_msg; Circuit::HeapArray public_key_x; @@ -558,6 +567,7 @@ struct BlackBoxOp { Blake2s, Blake3, Keccak256, + Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, @@ -3148,6 +3158,58 @@ Circuit::BlackBoxOp::Keccak256 serde::Deserializable BlackBoxOp::Keccakf1600::bincodeSerialize() const +{ + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); +} + +inline BlackBoxOp::Keccakf1600 BlackBoxOp::Keccakf1600::bincodeDeserialize(std::vector input) +{ + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw_or_abort("Some input bytes were not read"); + } + return value; +} + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BlackBoxOp::Keccakf1600& obj, + Serializer& serializer) +{ + serde::Serializable::serialize(obj.message, serializer); + serde::Serializable::serialize(obj.output, serializer); +} + +template <> +template +Circuit::BlackBoxOp::Keccakf1600 serde::Deserializable::deserialize( + Deserializer& deserializer) +{ + Circuit::BlackBoxOp::Keccakf1600 obj; + obj.message = serde::Deserializable::deserialize(deserializer); + obj.output = serde::Deserializable::deserialize(deserializer); + return obj; +} + +namespace Circuit { + inline bool operator==(const BlackBoxOp::EcdsaSecp256k1& lhs, const BlackBoxOp::EcdsaSecp256k1& rhs) { if (!(lhs.hashed_msg == rhs.hashed_msg)) { diff --git a/noir/acvm-repo/acir/codegen/acir.cpp b/noir/acvm-repo/acir/codegen/acir.cpp index 7d9d293a776..30f6e756337 100644 --- a/noir/acvm-repo/acir/codegen/acir.cpp +++ b/noir/acvm-repo/acir/codegen/acir.cpp @@ -446,6 +446,15 @@ namespace Circuit { static Keccak256 bincodeDeserialize(std::vector); }; + struct Keccakf1600 { + Circuit::HeapVector message; + Circuit::HeapArray output; + + friend bool operator==(const Keccakf1600&, const Keccakf1600&); + std::vector bincodeSerialize() const; + static Keccakf1600 bincodeDeserialize(std::vector); + }; + struct EcdsaSecp256k1 { Circuit::HeapVector hashed_msg; Circuit::HeapArray public_key_x; @@ -534,7 +543,7 @@ namespace Circuit { static EmbeddedCurveDouble bincodeDeserialize(std::vector); }; - std::variant value; + std::variant value; friend bool operator==(const BlackBoxOp&, const BlackBoxOp&); std::vector bincodeSerialize() const; @@ -2686,6 +2695,47 @@ Circuit::BlackBoxOp::Keccak256 serde::Deserializable BlackBoxOp::Keccakf1600::bincodeSerialize() const { + auto serializer = serde::BincodeSerializer(); + serde::Serializable::serialize(*this, serializer); + return std::move(serializer).bytes(); + } + + inline BlackBoxOp::Keccakf1600 BlackBoxOp::Keccakf1600::bincodeDeserialize(std::vector input) { + auto deserializer = serde::BincodeDeserializer(input); + auto value = serde::Deserializable::deserialize(deserializer); + if (deserializer.get_buffer_offset() < input.size()) { + throw serde::deserialization_error("Some input bytes were not read"); + } + return value; + } + +} // end of namespace Circuit + +template <> +template +void serde::Serializable::serialize(const Circuit::BlackBoxOp::Keccakf1600 &obj, Serializer &serializer) { + serde::Serializable::serialize(obj.message, serializer); + serde::Serializable::serialize(obj.output, serializer); +} + +template <> +template +Circuit::BlackBoxOp::Keccakf1600 serde::Deserializable::deserialize(Deserializer &deserializer) { + Circuit::BlackBoxOp::Keccakf1600 obj; + obj.message = serde::Deserializable::deserialize(deserializer); + obj.output = serde::Deserializable::deserialize(deserializer); + return obj; +} + namespace Circuit { inline bool operator==(const BlackBoxOp::EcdsaSecp256k1 &lhs, const BlackBoxOp::EcdsaSecp256k1 &rhs) { diff --git a/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs b/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs index ca355b6045d..5eea234885c 100644 --- a/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs +++ b/noir/acvm-repo/acvm/src/pwg/blackbox/mod.rs @@ -119,8 +119,8 @@ pub(crate) fn solve( let lane = witness_assignment.try_to_u64(); state[i] = lane.unwrap(); } - let state = keccakf1600(state)?; - for (output_witness, value) in outputs.iter().zip(state.into_iter()) { + let output_state = keccakf1600(state)?; + for (output_witness, value) in outputs.iter().zip(output_state.into_iter()) { insert_value(output_witness, FieldElement::from(value as u128), initial_witness)?; } Ok(()) diff --git a/noir/acvm-repo/brillig/src/black_box.rs b/noir/acvm-repo/brillig/src/black_box.rs index 2286539e4c1..e63da276a7f 100644 --- a/noir/acvm-repo/brillig/src/black_box.rs +++ b/noir/acvm-repo/brillig/src/black_box.rs @@ -13,6 +13,8 @@ pub enum BlackBoxOp { Blake3 { message: HeapVector, output: HeapArray }, /// Calculates the Keccak256 hash of the inputs. Keccak256 { message: HeapVector, output: HeapArray }, + /// Keccak Permutation function of 1600 width + Keccakf1600 { message: HeapVector, output: HeapArray }, /// Verifies a ECDSA signature over the secp256k1 curve. EcdsaSecp256k1 { hashed_msg: HeapVector, diff --git a/noir/acvm-repo/brillig_vm/src/black_box.rs b/noir/acvm-repo/brillig_vm/src/black_box.rs index a6e904c2902..463038509e1 100644 --- a/noir/acvm-repo/brillig_vm/src/black_box.rs +++ b/noir/acvm-repo/brillig_vm/src/black_box.rs @@ -1,8 +1,8 @@ use acir::brillig::{BlackBoxOp, HeapArray, HeapVector, Value}; use acir::{BlackBoxFunc, FieldElement}; use acvm_blackbox_solver::{ - blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256, sha256, - BlackBoxFunctionSolver, BlackBoxResolutionError, + blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256, keccakf1600, + sha256, BlackBoxFunctionSolver, BlackBoxResolutionError, }; use crate::{Memory, Registers}; @@ -70,6 +70,20 @@ pub(crate) fn evaluate_black_box( memory.write_slice(registers.get(output.pointer).to_usize(), &to_value_vec(&bytes)); Ok(()) } + BlackBoxOp::Keccakf1600 { message, output } => { + let state_vec: Vec = read_heap_vector(memory, registers, message) + .iter() + .map(|value| value.to_field().try_to_u64().unwrap()) + .collect(); + let state: [u64; 25] = state_vec.try_into().unwrap(); + + let new_state = keccakf1600(state)?; + + let new_state: Vec = + new_state.into_iter().map(|x| Value::from(x as usize)).collect(); + memory.write_slice(registers.get(output.pointer).to_usize(), &new_state); + Ok(()) + } BlackBoxOp::EcdsaSecp256k1 { hashed_msg, public_key_x, @@ -195,6 +209,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc { BlackBoxOp::Blake2s { .. } => BlackBoxFunc::Blake2s, BlackBoxOp::Blake3 { .. } => BlackBoxFunc::Blake3, BlackBoxOp::Keccak256 { .. } => BlackBoxFunc::Keccak256, + BlackBoxOp::Keccakf1600 { .. } => BlackBoxFunc::Keccakf1600, BlackBoxOp::EcdsaSecp256k1 { .. } => BlackBoxFunc::EcdsaSecp256k1, BlackBoxOp::EcdsaSecp256r1 { .. } => BlackBoxFunc::EcdsaSecp256r1, BlackBoxOp::SchnorrVerify { .. } => BlackBoxFunc::SchnorrVerify, diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs index 5a5f9694534..c081806f4a7 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_gen/brillig_black_box.rs @@ -71,6 +71,20 @@ pub(crate) fn convert_black_box_call( unreachable!("ICE: Keccak256 expects message, message size and result array") } } + BlackBoxFunc::Keccakf1600 => { + if let ([message], [BrilligVariable::BrilligArray(result_array)]) = + (function_arguments, function_results) + { + let state_vector = convert_array_or_vector(brillig_context, message, bb_func); + + brillig_context.black_box_op_instruction(BlackBoxOp::Keccakf1600 { + message: state_vector.to_heap_vector(), + output: result_array.to_heap_array(), + }); + } else { + unreachable!("ICE: Keccakf1600 expects one array argument and one array result") + } + } BlackBoxFunc::EcdsaSecp256k1 => { if let ( [BrilligVariable::BrilligArray(public_key_x), BrilligVariable::BrilligArray(public_key_y), BrilligVariable::BrilligArray(signature), message], @@ -230,9 +244,6 @@ pub(crate) fn convert_black_box_call( BlackBoxFunc::RecursiveAggregation => unimplemented!( "ICE: `BlackBoxFunc::RecursiveAggregation` is not implemented by the Brillig VM" ), - BlackBoxFunc::Keccakf1600 => { - unimplemented!("ICE: `BlackBoxFunc::Keccakf1600` is not implemented by the Brillig VM") - } } } diff --git a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs index 66c6b3b0249..dc8c6b6694c 100644 --- a/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs +++ b/noir/compiler/noirc_evaluator/src/brillig/brillig_ir/debug_show.rs @@ -350,6 +350,9 @@ impl DebugShow { BlackBoxOp::Keccak256 { message, output } => { debug_println!(self.enable_debug_trace, " KECCAK256 {} -> {}", message, output); } + BlackBoxOp::Keccakf1600 { message, output } => { + debug_println!(self.enable_debug_trace, " KECCAKF1600 {} -> {}", message, output); + } BlackBoxOp::Blake2s { message, output } => { debug_println!(self.enable_debug_trace, " BLAKE2S {} -> {}", message, output); }