Skip to content

Commit

Permalink
feat: Signed integer division and modulus in brillig gen (#5279)
Browse files Browse the repository at this point in the history
Removes the signed div opcode since the AVM won't support it and
replaces it with the emulated version using unsigned division.
  • Loading branch information
sirasistant authored Mar 18, 2024
1 parent d0a5b19 commit 82f8cf5
Show file tree
Hide file tree
Showing 16 changed files with 315 additions and 262 deletions.
2 changes: 1 addition & 1 deletion avm-transpiler/src/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ pub fn brillig_to_avm(brillig: &Brillig) -> Vec<u8> {
BinaryIntOp::Add => AvmOpcode::ADD,
BinaryIntOp::Sub => AvmOpcode::SUB,
BinaryIntOp::Mul => AvmOpcode::MUL,
BinaryIntOp::UnsignedDiv => AvmOpcode::DIV,
BinaryIntOp::Div => AvmOpcode::DIV,
BinaryIntOp::Equals => AvmOpcode::EQ,
BinaryIntOp::LessThan => AvmOpcode::LT,
BinaryIntOp::LessThanEquals => AvmOpcode::LTE,
Expand Down
74 changes: 13 additions & 61 deletions barretenberg/cpp/src/barretenberg/dsl/acir_format/serde/acir.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,10 @@ struct BinaryIntOp {
static Mul bincodeDeserialize(std::vector<uint8_t>);
};

struct SignedDiv {
friend bool operator==(const SignedDiv&, const SignedDiv&);
std::vector<uint8_t> bincodeSerialize() const;
static SignedDiv bincodeDeserialize(std::vector<uint8_t>);
};

struct UnsignedDiv {
friend bool operator==(const UnsignedDiv&, const UnsignedDiv&);
struct Div {
friend bool operator==(const Div&, const Div&);
std::vector<uint8_t> bincodeSerialize() const;
static UnsignedDiv bincodeDeserialize(std::vector<uint8_t>);
static Div bincodeDeserialize(std::vector<uint8_t>);
};

struct Equals {
Expand Down Expand Up @@ -142,7 +136,7 @@ struct BinaryIntOp {
static Shr bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, SignedDiv, UnsignedDiv, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;
std::variant<Add, Sub, Mul, Div, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;

friend bool operator==(const BinaryIntOp&, const BinaryIntOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1768,63 +1762,22 @@ Circuit::BinaryIntOp::Mul serde::Deserializable<Circuit::BinaryIntOp::Mul>::dese

namespace Circuit {

inline bool operator==(const BinaryIntOp::SignedDiv& lhs, const BinaryIntOp::SignedDiv& rhs)
{
return true;
}

inline std::vector<uint8_t> BinaryIntOp::SignedDiv::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::SignedDiv>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::SignedDiv BinaryIntOp::SignedDiv::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::SignedDiv>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::SignedDiv>::serialize(const Circuit::BinaryIntOp::SignedDiv& obj,
Serializer& serializer)
{}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::SignedDiv serde::Deserializable<Circuit::BinaryIntOp::SignedDiv>::deserialize(
Deserializer& deserializer)
{
Circuit::BinaryIntOp::SignedDiv obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryIntOp::UnsignedDiv& lhs, const BinaryIntOp::UnsignedDiv& rhs)
inline bool operator==(const BinaryIntOp::Div& lhs, const BinaryIntOp::Div& rhs)
{
return true;
}

inline std::vector<uint8_t> BinaryIntOp::UnsignedDiv::bincodeSerialize() const
inline std::vector<uint8_t> BinaryIntOp::Div::bincodeSerialize() const
{
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::UnsignedDiv>::serialize(*this, serializer);
serde::Serializable<BinaryIntOp::Div>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std::vector<uint8_t> input)
inline BinaryIntOp::Div BinaryIntOp::Div::bincodeDeserialize(std::vector<uint8_t> input)
{
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::UnsignedDiv>::deserialize(deserializer);
auto value = serde::Deserializable<BinaryIntOp::Div>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw_or_abort("Some input bytes were not read");
}
Expand All @@ -1835,16 +1788,15 @@ inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::UnsignedDiv>::serialize(const Circuit::BinaryIntOp::UnsignedDiv& obj,
Serializer& serializer)
void serde::Serializable<Circuit::BinaryIntOp::Div>::serialize(const Circuit::BinaryIntOp::Div& obj,
Serializer& serializer)
{}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::UnsignedDiv serde::Deserializable<Circuit::BinaryIntOp::UnsignedDiv>::deserialize(
Deserializer& deserializer)
Circuit::BinaryIntOp::Div serde::Deserializable<Circuit::BinaryIntOp::Div>::deserialize(Deserializer& deserializer)
{
Circuit::BinaryIntOp::UnsignedDiv obj;
Circuit::BinaryIntOp::Div obj;
return obj;
}

Expand Down
65 changes: 12 additions & 53 deletions noir/noir-repo/acvm-repo/acir/codegen/acir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,10 @@ namespace Circuit {
static Mul bincodeDeserialize(std::vector<uint8_t>);
};

struct SignedDiv {
friend bool operator==(const SignedDiv&, const SignedDiv&);
std::vector<uint8_t> bincodeSerialize() const;
static SignedDiv bincodeDeserialize(std::vector<uint8_t>);
};

struct UnsignedDiv {
friend bool operator==(const UnsignedDiv&, const UnsignedDiv&);
struct Div {
friend bool operator==(const Div&, const Div&);
std::vector<uint8_t> bincodeSerialize() const;
static UnsignedDiv bincodeDeserialize(std::vector<uint8_t>);
static Div bincodeDeserialize(std::vector<uint8_t>);
};

struct Equals {
Expand Down Expand Up @@ -142,7 +136,7 @@ namespace Circuit {
static Shr bincodeDeserialize(std::vector<uint8_t>);
};

std::variant<Add, Sub, Mul, SignedDiv, UnsignedDiv, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;
std::variant<Add, Sub, Mul, Div, Equals, LessThan, LessThanEquals, And, Or, Xor, Shl, Shr> value;

friend bool operator==(const BinaryIntOp&, const BinaryIntOp&);
std::vector<uint8_t> bincodeSerialize() const;
Expand Down Expand Up @@ -1634,54 +1628,19 @@ Circuit::BinaryIntOp::Mul serde::Deserializable<Circuit::BinaryIntOp::Mul>::dese

namespace Circuit {

inline bool operator==(const BinaryIntOp::SignedDiv &lhs, const BinaryIntOp::SignedDiv &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryIntOp::SignedDiv::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::SignedDiv>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::SignedDiv BinaryIntOp::SignedDiv::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::SignedDiv>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
return value;
}

} // end of namespace Circuit

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::SignedDiv>::serialize(const Circuit::BinaryIntOp::SignedDiv &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::SignedDiv serde::Deserializable<Circuit::BinaryIntOp::SignedDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::SignedDiv obj;
return obj;
}

namespace Circuit {

inline bool operator==(const BinaryIntOp::UnsignedDiv &lhs, const BinaryIntOp::UnsignedDiv &rhs) {
inline bool operator==(const BinaryIntOp::Div &lhs, const BinaryIntOp::Div &rhs) {
return true;
}

inline std::vector<uint8_t> BinaryIntOp::UnsignedDiv::bincodeSerialize() const {
inline std::vector<uint8_t> BinaryIntOp::Div::bincodeSerialize() const {
auto serializer = serde::BincodeSerializer();
serde::Serializable<BinaryIntOp::UnsignedDiv>::serialize(*this, serializer);
serde::Serializable<BinaryIntOp::Div>::serialize(*this, serializer);
return std::move(serializer).bytes();
}

inline BinaryIntOp::UnsignedDiv BinaryIntOp::UnsignedDiv::bincodeDeserialize(std::vector<uint8_t> input) {
inline BinaryIntOp::Div BinaryIntOp::Div::bincodeDeserialize(std::vector<uint8_t> input) {
auto deserializer = serde::BincodeDeserializer(input);
auto value = serde::Deserializable<BinaryIntOp::UnsignedDiv>::deserialize(deserializer);
auto value = serde::Deserializable<BinaryIntOp::Div>::deserialize(deserializer);
if (deserializer.get_buffer_offset() < input.size()) {
throw serde::deserialization_error("Some input bytes were not read");
}
Expand All @@ -1692,13 +1651,13 @@ namespace Circuit {

template <>
template <typename Serializer>
void serde::Serializable<Circuit::BinaryIntOp::UnsignedDiv>::serialize(const Circuit::BinaryIntOp::UnsignedDiv &obj, Serializer &serializer) {
void serde::Serializable<Circuit::BinaryIntOp::Div>::serialize(const Circuit::BinaryIntOp::Div &obj, Serializer &serializer) {
}

template <>
template <typename Deserializer>
Circuit::BinaryIntOp::UnsignedDiv serde::Deserializable<Circuit::BinaryIntOp::UnsignedDiv>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::UnsignedDiv obj;
Circuit::BinaryIntOp::Div serde::Deserializable<Circuit::BinaryIntOp::Div>::deserialize(Deserializer &deserializer) {
Circuit::BinaryIntOp::Div obj;
return obj;
}

Expand Down
3 changes: 1 addition & 2 deletions noir/noir-repo/acvm-repo/brillig/src/opcodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ pub enum BinaryIntOp {
Add,
Sub,
Mul,
SignedDiv,
UnsignedDiv,
Div,
/// (==) equal
Equals,
/// (<) Field less than
Expand Down
71 changes: 3 additions & 68 deletions noir/noir-repo/acvm-repo/brillig_vm/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use acir::brillig::{BinaryFieldOp, BinaryIntOp};
use acir::FieldElement;
use num_bigint::{BigInt, BigUint};
use num_bigint::BigUint;
use num_traits::{One, ToPrimitive, Zero};

/// Evaluate a binary operation on two FieldElements and return the result as a FieldElement.
Expand Down Expand Up @@ -42,24 +42,14 @@ pub(crate) fn evaluate_binary_bigint_op(
BinaryIntOp::Sub => (bit_modulo + a - b) % bit_modulo,
BinaryIntOp::Mul => (a * b) % bit_modulo,
// Perform unsigned division using the modulo operation on a and b.
BinaryIntOp::UnsignedDiv => {
BinaryIntOp::Div => {
let b_mod = b % bit_modulo;
if b_mod.is_zero() {
BigUint::zero()
} else {
(a % bit_modulo) / b_mod
}
}
// Perform signed division by first converting a and b to signed integers and then back to unsigned after the operation.
BinaryIntOp::SignedDiv => {
let b_signed = to_big_signed(b, bit_size);
if b_signed.is_zero() {
BigUint::zero()
} else {
let signed_div = to_big_signed(a, bit_size) / b_signed;
to_big_unsigned(signed_div, bit_size)
}
}
// Perform a == operation, returning 0 or 1
BinaryIntOp::Equals => {
if (a % bit_modulo) == (b % bit_modulo) {
Expand Down Expand Up @@ -103,23 +93,6 @@ pub(crate) fn evaluate_binary_bigint_op(
Ok(result)
}

fn to_big_signed(a: BigUint, bit_size: u32) -> BigInt {
let pow_2 = BigUint::from(2_u32).pow(bit_size - 1);
if a < pow_2 {
BigInt::from(a)
} else {
BigInt::from(a) - 2 * BigInt::from(pow_2)
}
}

fn to_big_unsigned(a: BigInt, bit_size: u32) -> BigUint {
if a >= BigInt::zero() {
BigUint::from_bytes_le(&a.to_bytes_le().1)
} else {
BigUint::from(2_u32).pow(bit_size) - BigUint::from_bytes_le(&a.to_bytes_le().1)
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand All @@ -139,24 +112,6 @@ mod tests {
result_value.to_u128().unwrap()
}

fn to_signed(a: u128, bit_size: u32) -> i128 {
assert!(bit_size < 128);
let pow_2 = 2_u128.pow(bit_size - 1);
if a < pow_2 {
a as i128
} else {
(a.wrapping_sub(2 * pow_2)) as i128
}
}

fn to_unsigned(a: i128, bit_size: u32) -> u128 {
if a >= 0 {
a as u128
} else {
(a + 2_i128.pow(bit_size)) as u128
}
}

fn to_negative(a: u128, bit_size: u32) -> u128 {
assert!(a > 0);
let two_pow = 2_u128.pow(bit_size);
Expand Down Expand Up @@ -233,26 +188,6 @@ mod tests {
let test_ops =
vec![TestParams { a: 5, b: 3, result: 1 }, TestParams { a: 5, b: 10, result: 0 }];

evaluate_int_ops(test_ops, BinaryIntOp::UnsignedDiv, bit_size);
}

#[test]
fn to_signed_roundtrip() {
let bit_size = 32;
let minus_one = 2_u128.pow(bit_size) - 1;
assert_eq!(to_unsigned(to_signed(minus_one, bit_size), bit_size), minus_one);
}

#[test]
fn signed_div_test() {
let bit_size = 32;

let test_ops = vec![
TestParams { a: 5, b: to_negative(10, bit_size), result: 0 },
TestParams { a: 5, b: to_negative(1, bit_size), result: to_negative(5, bit_size) },
TestParams { a: to_negative(5, bit_size), b: to_negative(1, bit_size), result: 5 },
];

evaluate_int_ops(test_ops, BinaryIntOp::SignedDiv, bit_size);
evaluate_int_ops(test_ops, BinaryIntOp::Div, bit_size);
}
}
Loading

0 comments on commit 82f8cf5

Please sign in to comment.