Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move to_radix to a blackbox #6294

Merged
merged 4 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 69 additions & 5 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,6 @@ struct BlackBoxOp {
Program::HeapVector inputs;
Program::HeapArray iv;
Program::HeapArray key;
Program::MemoryAddress length;
Program::HeapVector outputs;

friend bool operator==(const AES128Encrypt&, const AES128Encrypt&);
Expand Down Expand Up @@ -896,6 +895,16 @@ struct BlackBoxOp {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

struct ToRadix {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
static ToRadix bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt,
Sha256,
Blake2s,
Expand All @@ -916,7 +925,8 @@ struct BlackBoxOp {
BigIntFromLeBytes,
BigIntToLeBytes,
Poseidon2Permutation,
Sha256Compression>
Sha256Compression,
ToRadix>
value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
Expand Down Expand Up @@ -3939,9 +3949,6 @@ inline bool operator==(const BlackBoxOp::AES128Encrypt& lhs, const BlackBoxOp::A
if (!(lhs.key == rhs.key)) {
return false;
}
if (!(lhs.length == rhs.length)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
Expand Down Expand Up @@ -5141,6 +5148,63 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable<Program::BlackBoxOp

namespace Program {

inline bool operator==(const BlackBoxOp::ToRadix& lhs, const BlackBoxOp::ToRadix& rhs)
{
if (!(lhs.input == rhs.input)) {
return false;
}
if (!(lhs.radix == rhs.radix)) {
return false;
}
if (!(lhs.output == rhs.output)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxOp::ToRadix::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::ToRadix>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::ToRadix>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program::BlackBoxOp::ToRadix& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>::deserialize(
Deserializer& deserializer)
{
Program::BlackBoxOp::ToRadix obj;
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlockId& lhs, const BlockId& rhs)
{
if (!(lhs.value == rhs.value)) {
Expand Down
56 changes: 55 additions & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -870,7 +870,17 @@ namespace Program {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt, Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, MultiScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression> value;
struct ToRadix {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
static ToRadix bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt, Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, MultiScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression, ToRadix> value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -4293,6 +4303,50 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable<Program::BlackBoxOp
return obj;
}

namespace Program {

inline bool operator==(const BlackBoxOp::ToRadix &lhs, const BlackBoxOp::ToRadix &rhs) {
if (!(lhs.input == rhs.input)) { return false; }
if (!(lhs.radix == rhs.radix)) { return false; }
if (!(lhs.output == rhs.output)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxOp::ToRadix::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::ToRadix>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::ToRadix>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program::BlackBoxOp::ToRadix &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>::deserialize(Deserializer &deserializer) {
Program::BlackBoxOp::ToRadix obj;
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlockId &lhs, const BlockId &rhs) {
Expand Down
5 changes: 5 additions & 0 deletions noir/noir-repo/acvm-repo/brillig/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,9 @@ pub enum BlackBoxOp {
hash_values: HeapVector,
output: HeapArray,
},
ToRadix {
input: MemoryAddress,
radix: u32,
output: HeapArray,
},
}
21 changes: 21 additions & 0 deletions noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use acvm_blackbox_solver::{
aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256,
keccakf1600, sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError,
};
use num_bigint::BigUint;

use crate::memory::MemoryValue;
use crate::Memory;
Expand Down Expand Up @@ -295,6 +296,25 @@ pub(crate) fn evaluate_black_box<Solver: BlackBoxFunctionSolver>(
memory.write_slice(memory.read_ref(output.pointer), &state);
Ok(())
}
BlackBoxOp::ToRadix { input, radix, output } => {
let input: FieldElement =
memory.read(*input).try_into().expect("ToRadix input not a field");

let mut input = BigUint::from_bytes_be(&input.to_be_bytes());
let radix = BigUint::from(*radix);

let mut limbs: Vec<MemoryValue> = Vec::with_capacity(output.size);

for _ in 0..output.size {
let limb = &input % &radix;
limbs.push(FieldElement::from_be_bytes_reduce(&limb.to_bytes_be()).into());
input /= &radix;
}

memory.write_slice(memory.read_ref(output.pointer), &limbs);

Ok(())
}
}
}

Expand All @@ -321,6 +341,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc {
BlackBoxOp::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes,
BlackBoxOp::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation,
BlackBoxOp::Sha256Compression { .. } => BlackBoxFunc::Sha256Compression,
BlackBoxOp::ToRadix { .. } => unreachable!("ToRadix is not an ACIR BlackBoxFunc"),
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,22 @@ impl<'block> BrilligBlock<'block> {
}
Value::Intrinsic(Intrinsic::ToRadix(endianness)) => {
let source = self.convert_ssa_single_addr_value(arguments[0], dfg);
let radix = self.convert_ssa_single_addr_value(arguments[1], dfg);
let limb_count = self.convert_ssa_single_addr_value(arguments[2], dfg);

let radix: u32 = dfg
.get_numeric_constant(arguments[1])
.expect("Radix should be known")
.try_to_u64()
.expect("Radix should fit in u64")
.try_into()
.expect("Radix should be u32");

let limb_count: usize = dfg
.get_numeric_constant(arguments[2])
.expect("Limb count should be known")
.try_to_u64()
.expect("Limb count should fit in u64")
.try_into()
.expect("Limb count should fit in usize");

let results = dfg.instruction_results(instruction_id);

Expand All @@ -511,7 +525,8 @@ impl<'block> BrilligBlock<'block> {
.extract_vector();

// Update the user-facing slice length
self.brillig_context.cast_instruction(target_len, limb_count);
self.brillig_context
.usize_const_instruction(target_len.address, limb_count.into());

self.brillig_context.codegen_to_radix(
source,
Expand All @@ -524,7 +539,13 @@ impl<'block> BrilligBlock<'block> {
}
Value::Intrinsic(Intrinsic::ToBits(endianness)) => {
let source = self.convert_ssa_single_addr_value(arguments[0], dfg);
let limb_count = self.convert_ssa_single_addr_value(arguments[1], dfg);
let limb_count: usize = dfg
.get_numeric_constant(arguments[1])
.expect("Limb count should be known")
.try_to_u64()
.expect("Limb count should fit in u64")
.try_into()
.expect("Limb count should fit in usize");

let results = dfg.instruction_results(instruction_id);

Expand All @@ -549,21 +570,18 @@ impl<'block> BrilligBlock<'block> {
BrilligVariable::SingleAddr(..) => unreachable!("ICE: ToBits on non-array"),
};

let radix = self.brillig_context.make_constant_instruction(2_usize.into(), 32);

// Update the user-facing slice length
self.brillig_context.cast_instruction(target_len, limb_count);
self.brillig_context
.usize_const_instruction(target_len.address, limb_count.into());

self.brillig_context.codegen_to_radix(
source,
target_vector,
radix,
2,
limb_count,
matches!(endianness, Endian::Big),
1,
);

self.brillig_context.deallocate_single_addr(radix);
}
_ => {
unreachable!("unsupported function call type {:?}", dfg[*func])
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use acvm::FieldElement;

use crate::brillig::brillig_ir::BrilligBinaryOp;
use acvm::{
acir::brillig::{BlackBoxOp, HeapArray},
FieldElement,
};

use super::{
brillig_variable::{BrilligVector, SingleAddrVariable},
Expand Down Expand Up @@ -36,57 +37,46 @@ impl BrilligContext {
&mut self,
source_field: SingleAddrVariable,
target_vector: BrilligVector,
radix: SingleAddrVariable,
limb_count: SingleAddrVariable,
radix: u32,
limb_count: usize,
big_endian: bool,
limb_bit_size: u32,
) {
assert!(source_field.bit_size == FieldElement::max_num_bits());
assert!(radix.bit_size == 32);
assert!(limb_count.bit_size == 32);
let radix_as_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());
self.cast_instruction(radix_as_field, radix);

self.cast_instruction(SingleAddrVariable::new_usize(target_vector.size), limb_count);
self.usize_const_instruction(target_vector.size, limb_count.into());
self.usize_const_instruction(target_vector.rc, 1_usize.into());
self.codegen_allocate_array(target_vector.pointer, target_vector.size);

let shifted_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());
self.mov_instruction(shifted_field.address, source_field.address);
self.black_box_op_instruction(BlackBoxOp::ToRadix {
input: source_field.address,
radix,
output: HeapArray { pointer: target_vector.pointer, size: limb_count },
});

let limb_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());

let limb_casted = SingleAddrVariable::new(self.allocate_register(), limb_bit_size);

self.codegen_loop(target_vector.size, |ctx, iterator_register| {
// Compute the modulus
ctx.binary_instruction(
shifted_field,
radix_as_field,
limb_field,
BrilligBinaryOp::Modulo,
);
// Cast it
ctx.cast_instruction(limb_casted, limb_field);
// Write it
ctx.codegen_array_set(target_vector.pointer, iterator_register, limb_casted.address);
// Integer div the field
ctx.binary_instruction(
shifted_field,
radix_as_field,
shifted_field,
BrilligBinaryOp::UnsignedDiv,
);
});
if limb_bit_size != FieldElement::max_num_bits() {
self.codegen_loop(target_vector.size, |ctx, iterator_register| {
// Read the limb
ctx.codegen_array_get(target_vector.pointer, iterator_register, limb_field.address);
// Cast it
ctx.cast_instruction(limb_casted, limb_field);
// Write it
ctx.codegen_array_set(
target_vector.pointer,
iterator_register,
limb_casted.address,
);
});
}

// Deallocate our temporary registers
self.deallocate_single_addr(shifted_field);
self.deallocate_single_addr(limb_field);
self.deallocate_single_addr(limb_casted);
self.deallocate_single_addr(radix_as_field);

if big_endian {
self.codegen_reverse_vector_in_place(target_vector);
Expand Down
Loading
Loading