Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: move to_radix to a blackbox #6294

Merged
merged 4 commits into from
May 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
feat: move to_radix to a blackbox
sirasistant committed May 9, 2024
commit e907bd8b03856a7cc581aff48bba95cb8302ad58
74 changes: 69 additions & 5 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
@@ -686,7 +686,6 @@ struct BlackBoxOp {
Program::HeapVector inputs;
Program::HeapArray iv;
Program::HeapArray key;
Program::MemoryAddress length;
Program::HeapVector outputs;

friend bool operator==(const AES128Encrypt&, const AES128Encrypt&);
@@ -896,6 +895,16 @@ struct BlackBoxOp {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

struct ToRadix {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
static ToRadix bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt,
Sha256,
Blake2s,
@@ -916,7 +925,8 @@ struct BlackBoxOp {
BigIntFromLeBytes,
BigIntToLeBytes,
Poseidon2Permutation,
Sha256Compression>
Sha256Compression,
ToRadix>
value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
@@ -3939,9 +3949,6 @@ inline bool operator==(const BlackBoxOp::AES128Encrypt& lhs, const BlackBoxOp::A
if (!(lhs.key == rhs.key)) {
return false;
}
if (!(lhs.length == rhs.length)) {
return false;
}
if (!(lhs.outputs == rhs.outputs)) {
return false;
}
@@ -5141,6 +5148,63 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable<Program::BlackBoxOp

namespace Program {

inline bool operator==(const BlackBoxOp::ToRadix& lhs, const BlackBoxOp::ToRadix& rhs)
{
if (!(lhs.input == rhs.input)) {
return false;
}
if (!(lhs.radix == rhs.radix)) {
return false;
}
if (!(lhs.output == rhs.output)) {
return false;
}
return true;
}

inline std::vector<uint8_t> BlackBoxOp::ToRadix::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::ToRadix>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::ToRadix>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program::BlackBoxOp::ToRadix& obj,
Serializer& serializer)
{
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>::deserialize(
Deserializer& deserializer)
{
Program::BlackBoxOp::ToRadix obj;
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlockId& lhs, const BlockId& rhs)
{
if (!(lhs.value == rhs.value)) {
56 changes: 55 additions & 1 deletion noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
@@ -870,7 +870,17 @@ namespace Program {
static Sha256Compression bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt, Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, MultiScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression> value;
struct ToRadix {
Program::MemoryAddress input;
uint32_t radix;
Program::HeapArray output;

friend bool operator==(const ToRadix&, const ToRadix&);
std::vector<uint8_t> bincodeSerialize() const;
static ToRadix bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<AES128Encrypt, Sha256, Blake2s, Blake3, Keccak256, Keccakf1600, EcdsaSecp256k1, EcdsaSecp256r1, SchnorrVerify, PedersenCommitment, PedersenHash, MultiScalarMul, EmbeddedCurveAdd, BigIntAdd, BigIntSub, BigIntMul, BigIntDiv, BigIntFromLeBytes, BigIntToLeBytes, Poseidon2Permutation, Sha256Compression, ToRadix> value;

friend bool operator==(const BlackBoxOp&, const BlackBoxOp&);
std::vector<uint8_t> bincodeSerialize() const;
@@ -4293,6 +4303,50 @@ Program::BlackBoxOp::Sha256Compression serde::Deserializable<Program::BlackBoxOp
return obj;
}

namespace Program {

inline bool operator==(const BlackBoxOp::ToRadix &lhs, const BlackBoxOp::ToRadix &rhs) {
if (!(lhs.input == rhs.input)) { return false; }
if (!(lhs.radix == rhs.radix)) { return false; }
if (!(lhs.output == rhs.output)) { return false; }
return true;
}

inline std::vector<uint8_t> BlackBoxOp::ToRadix::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BlackBoxOp::ToRadix>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BlackBoxOp::ToRadix BlackBoxOp::ToRadix::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BlackBoxOp::ToRadix>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Program

template <>
template <typename Serializer>
void serde::Serializable<Program::BlackBoxOp::ToRadix>::serialize(const Program::BlackBoxOp::ToRadix &obj, Serializer &serializer) {
serde::Serializable<decltype(obj.input)>::serialize(obj.input, serializer);
serde::Serializable<decltype(obj.radix)>::serialize(obj.radix, serializer);
serde::Serializable<decltype(obj.output)>::serialize(obj.output, serializer);
}

template <>
template <typename Deserializer>
Program::BlackBoxOp::ToRadix serde::Deserializable<Program::BlackBoxOp::ToRadix>::deserialize(Deserializer &deserializer) {
Program::BlackBoxOp::ToRadix obj;
obj.input = serde::Deserializable<decltype(obj.input)>::deserialize(deserializer);
obj.radix = serde::Deserializable<decltype(obj.radix)>::deserialize(deserializer);
obj.output = serde::Deserializable<decltype(obj.output)>::deserialize(deserializer);
return obj;
}

namespace Program {

inline bool operator==(const BlockId &lhs, const BlockId &rhs) {
5 changes: 5 additions & 0 deletions noir/noir-repo/acvm-repo/brillig/src/black_box.rs
Original file line number Diff line number Diff line change
@@ -126,4 +126,9 @@ pub enum BlackBoxOp {
hash_values: HeapVector,
output: HeapArray,
},
ToRadix {
input: MemoryAddress,
radix: u32,
output: HeapArray,
},
}
21 changes: 21 additions & 0 deletions noir/noir-repo/acvm-repo/brillig_vm/src/black_box.rs
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@ use acvm_blackbox_solver::{
aes128_encrypt, blake2s, blake3, ecdsa_secp256k1_verify, ecdsa_secp256r1_verify, keccak256,
keccakf1600, sha256, sha256compression, BlackBoxFunctionSolver, BlackBoxResolutionError,
};
use num_bigint::BigUint;

use crate::memory::MemoryValue;
use crate::Memory;
@@ -295,6 +296,25 @@ pub(crate) fn evaluate_black_box<Solver: BlackBoxFunctionSolver>(
memory.write_slice(memory.read_ref(output.pointer), &state);
Ok(())
}
BlackBoxOp::ToRadix { input, radix, output } => {
let input: FieldElement =
memory.read(*input).try_into().expect("ToRadix input not a field");

let mut input = BigUint::from_bytes_be(&input.to_be_bytes());
let radix = BigUint::from(*radix);

let mut limbs: Vec<MemoryValue> = Vec::with_capacity(output.size);

for _ in 0..output.size {
let limb = &input % &radix;
limbs.push(FieldElement::from_be_bytes_reduce(&limb.to_bytes_be()).into());
input /= &radix;
}

memory.write_slice(memory.read_ref(output.pointer), &limbs);

Ok(())
}
}
}

@@ -321,6 +341,7 @@ fn black_box_function_from_op(op: &BlackBoxOp) -> BlackBoxFunc {
BlackBoxOp::BigIntToLeBytes { .. } => BlackBoxFunc::BigIntToLeBytes,
BlackBoxOp::Poseidon2Permutation { .. } => BlackBoxFunc::Poseidon2Permutation,
BlackBoxOp::Sha256Compression { .. } => BlackBoxFunc::Sha256Compression,
BlackBoxOp::ToRadix { .. } => unreachable!("ToRadix is not an ACIR BlackBoxFunc"),
}
}

Original file line number Diff line number Diff line change
@@ -488,8 +488,22 @@ impl<'block> BrilligBlock<'block> {
}
Value::Intrinsic(Intrinsic::ToRadix(endianness)) => {
let source = self.convert_ssa_single_addr_value(arguments[0], dfg);
let radix = self.convert_ssa_single_addr_value(arguments[1], dfg);
let limb_count = self.convert_ssa_single_addr_value(arguments[2], dfg);

let radix: u32 = dfg
.get_numeric_constant(arguments[1])
.expect("Radix should be known")
.try_to_u64()
.expect("Radix should fit in u64")
.try_into()
.expect("Radix should be u32");

let limb_count: usize = dfg
.get_numeric_constant(arguments[2])
.expect("Limb count should be known")
.try_to_u64()
.expect("Limb count should fit in u64")
.try_into()
.expect("Limb count should fit in usize");

let results = dfg.instruction_results(instruction_id);

@@ -511,7 +525,8 @@ impl<'block> BrilligBlock<'block> {
.extract_vector();

// Update the user-facing slice length
self.brillig_context.cast_instruction(target_len, limb_count);
self.brillig_context
.usize_const_instruction(target_len.address, limb_count.into());

self.brillig_context.codegen_to_radix(
source,
@@ -524,7 +539,13 @@ impl<'block> BrilligBlock<'block> {
}
Value::Intrinsic(Intrinsic::ToBits(endianness)) => {
let source = self.convert_ssa_single_addr_value(arguments[0], dfg);
let limb_count = self.convert_ssa_single_addr_value(arguments[1], dfg);
let limb_count: usize = dfg
.get_numeric_constant(arguments[1])
.expect("Limb count should be known")
.try_to_u64()
.expect("Limb count should fit in u64")
.try_into()
.expect("Limb count should fit in usize");

let results = dfg.instruction_results(instruction_id);

@@ -549,21 +570,18 @@ impl<'block> BrilligBlock<'block> {
BrilligVariable::SingleAddr(..) => unreachable!("ICE: ToBits on non-array"),
};

let radix = self.brillig_context.make_constant_instruction(2_usize.into(), 32);

// Update the user-facing slice length
self.brillig_context.cast_instruction(target_len, limb_count);
self.brillig_context
.usize_const_instruction(target_len.address, limb_count.into());

self.brillig_context.codegen_to_radix(
source,
target_vector,
radix,
2,
limb_count,
matches!(endianness, Endian::Big),
1,
);

self.brillig_context.deallocate_single_addr(radix);
}
_ => {
unreachable!("unsupported function call type {:?}", dfg[*func])
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use acvm::FieldElement;

use crate::brillig::brillig_ir::BrilligBinaryOp;
use acvm::{
acir::brillig::{BlackBoxOp, HeapArray},
FieldElement,
};

use super::{
brillig_variable::{BrilligVector, SingleAddrVariable},
@@ -36,57 +37,46 @@ impl BrilligContext {
&mut self,
source_field: SingleAddrVariable,
target_vector: BrilligVector,
radix: SingleAddrVariable,
limb_count: SingleAddrVariable,
radix: u32,
limb_count: usize,
big_endian: bool,
limb_bit_size: u32,
) {
assert!(source_field.bit_size == FieldElement::max_num_bits());
assert!(radix.bit_size == 32);
assert!(limb_count.bit_size == 32);
let radix_as_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());
self.cast_instruction(radix_as_field, radix);

self.cast_instruction(SingleAddrVariable::new_usize(target_vector.size), limb_count);
self.usize_const_instruction(target_vector.size, limb_count.into());
self.usize_const_instruction(target_vector.rc, 1_usize.into());
self.codegen_allocate_array(target_vector.pointer, target_vector.size);

let shifted_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());
self.mov_instruction(shifted_field.address, source_field.address);
self.black_box_op_instruction(BlackBoxOp::ToRadix {
input: source_field.address,
radix,
output: HeapArray { pointer: target_vector.pointer, size: limb_count },
});

let limb_field =
SingleAddrVariable::new(self.allocate_register(), FieldElement::max_num_bits());

let limb_casted = SingleAddrVariable::new(self.allocate_register(), limb_bit_size);

self.codegen_loop(target_vector.size, |ctx, iterator_register| {
// Compute the modulus
ctx.binary_instruction(
shifted_field,
radix_as_field,
limb_field,
BrilligBinaryOp::Modulo,
);
// Cast it
ctx.cast_instruction(limb_casted, limb_field);
// Write it
ctx.codegen_array_set(target_vector.pointer, iterator_register, limb_casted.address);
// Integer div the field
ctx.binary_instruction(
shifted_field,
radix_as_field,
shifted_field,
BrilligBinaryOp::UnsignedDiv,
);
});
if limb_bit_size != FieldElement::max_num_bits() {
self.codegen_loop(target_vector.size, |ctx, iterator_register| {
// Read the limb
ctx.codegen_array_get(target_vector.pointer, iterator_register, limb_field.address);
// Cast it
ctx.cast_instruction(limb_casted, limb_field);
// Write it
ctx.codegen_array_set(
target_vector.pointer,
iterator_register,
limb_casted.address,
);
});
}

// Deallocate our temporary registers
self.deallocate_single_addr(shifted_field);
self.deallocate_single_addr(limb_field);
self.deallocate_single_addr(limb_casted);
self.deallocate_single_addr(radix_as_field);

if big_endian {
self.codegen_reverse_vector_in_place(target_vector);
Original file line number Diff line number Diff line change
@@ -451,6 +451,15 @@ impl DebugShow {
output
);
}
BlackBoxOp::ToRadix { input, radix, output } => {
debug_println!(
self.enable_debug_trace,
" TO_RADIX {} {} -> {}",
input,
radix,
output
);
}
}
}

57 changes: 36 additions & 21 deletions noir/noir-repo/noir_stdlib/src/field/bn254.nr
Original file line number Diff line number Diff line change
@@ -25,7 +25,7 @@ unconstrained fn decompose_unsafe(x: Field) -> (Field, Field) {
fn assert_gt_limbs(a: (Field, Field), b: (Field, Field)) {
let (alo, ahi) = a;
let (blo, bhi) = b;
let borrow = lte_unsafe(alo, blo, 16);
let borrow = lte_unsafe_16(alo, blo);

let rlo = alo - blo - 1 + (borrow as Field) * TWO_POW_128;
let rhi = ahi - bhi - (borrow as Field);
@@ -51,9 +51,9 @@ pub fn decompose(x: Field) -> (Field, Field) {
(xlo, xhi)
}

unconstrained fn lt_unsafe(x: Field, y: Field, num_bytes: u32) -> bool {
let x_bytes = x.__to_le_radix(256, num_bytes);
let y_bytes = y.__to_le_radix(256, num_bytes);
fn lt_unsafe_internal(x: Field, y: Field, num_bytes: u32) -> bool {
let x_bytes = x.to_le_radix(256, num_bytes);
let y_bytes = y.to_le_radix(256, num_bytes);
let mut x_is_lt = false;
let mut done = false;
for i in 0..num_bytes {
@@ -70,8 +70,20 @@ unconstrained fn lt_unsafe(x: Field, y: Field, num_bytes: u32) -> bool {
x_is_lt
}

unconstrained fn lte_unsafe(x: Field, y: Field, num_bytes: u32) -> bool {
lt_unsafe(x, y, num_bytes) | (x == y)
fn lte_unsafe_internal(x: Field, y: Field, num_bytes: u32) -> bool {
if x == y {
true
} else {
lt_unsafe_internal(x, y, num_bytes)
}
}

unconstrained fn lt_unsafe_32(x: Field, y: Field) -> bool {
lt_unsafe_internal(x, y, 32)
}

unconstrained fn lte_unsafe_16(x: Field, y: Field) -> bool {
lte_unsafe_internal(x, y, 16)
}

pub fn assert_gt(a: Field, b: Field) {
@@ -90,7 +102,7 @@ pub fn assert_lt(a: Field, b: Field) {
pub fn gt(a: Field, b: Field) -> bool {
if a == b {
false
} else if lt_unsafe(a, b, 32) {
} else if lt_unsafe_32(a, b) {
assert_gt(b, a);
false
} else {
@@ -105,7 +117,10 @@ pub fn lt(a: Field, b: Field) -> bool {

mod tests {
// TODO: Allow imports from "super"
use crate::field::bn254::{decompose_unsafe, decompose, lt_unsafe, assert_gt, gt, lt, TWO_POW_128, lte_unsafe, PLO, PHI};
use crate::field::bn254::{
decompose_unsafe, decompose, lt_unsafe_internal, assert_gt, gt, lt, TWO_POW_128,
lte_unsafe_internal, PLO, PHI
};

#[test]
fn check_decompose_unsafe() {
@@ -123,23 +138,23 @@ mod tests {

#[test]
fn check_lt_unsafe() {
assert(lt_unsafe(0, 1, 16));
assert(lt_unsafe(0, 0x100, 16));
assert(lt_unsafe(0x100, TWO_POW_128 - 1, 16));
assert(!lt_unsafe(0, TWO_POW_128, 16));
assert(lt_unsafe_internal(0, 1, 16));
assert(lt_unsafe_internal(0, 0x100, 16));
assert(lt_unsafe_internal(0x100, TWO_POW_128 - 1, 16));
assert(!lt_unsafe_internal(0, TWO_POW_128, 16));
}

#[test]
fn check_lte_unsafe() {
assert(lte_unsafe(0, 1, 16));
assert(lte_unsafe(0, 0x100, 16));
assert(lte_unsafe(0x100, TWO_POW_128 - 1, 16));
assert(!lte_unsafe(0, TWO_POW_128, 16));

assert(lte_unsafe(0, 0, 16));
assert(lte_unsafe(0x100, 0x100, 16));
assert(lte_unsafe(TWO_POW_128 - 1, TWO_POW_128 - 1, 16));
assert(lte_unsafe(TWO_POW_128, TWO_POW_128, 16));
assert(lte_unsafe_internal(0, 1, 16));
assert(lte_unsafe_internal(0, 0x100, 16));
assert(lte_unsafe_internal(0x100, TWO_POW_128 - 1, 16));
assert(!lte_unsafe_internal(0, TWO_POW_128, 16));

assert(lte_unsafe_internal(0, 0, 16));
assert(lte_unsafe_internal(0x100, 0x100, 16));
assert(lte_unsafe_internal(TWO_POW_128 - 1, TWO_POW_128 - 1, 16));
assert(lte_unsafe_internal(TWO_POW_128, TWO_POW_128, 16));
}

#[test]