Skip to content

Commit

Permalink
fix: address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasRidhuan committed May 9, 2024
1 parent cecf2d5 commit ca89aab
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 44 deletions.
34 changes: 17 additions & 17 deletions barretenberg/cpp/pil/avm/avm_alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,13 @@ namespace avm_alu(256);
pol commit op_div_a_lt_b;
op_div_a_lt_b * (1 - op_div_a_lt_b) = 0;
// To show this, we constrain ib - ia - 1 to be within 128 bits.
// Since we need a range check we use the existing a_lo column that is range checked over 128-bits
// Since we need a range check we use the existing a_lo column that is range checked over 128 bits.
op_div_a_lt_b * (a_lo - (ib - ia - 1)) = 0;
op_div_a_lt_b * ic = 0; // ic = 0
op_div_a_lt_b * (ia - remainder) = 0; // remainder = a, might not be needed.


// ====== Handling ia > ib =====
// ====== Handling ia >= ib =====
pol commit op_div_std;
op_div_std * (1 - op_div_std) = 0;
pol commit divisor_lo; // b
Expand All @@ -673,7 +673,7 @@ namespace avm_alu(256);
// (2) divisor_lo * quotient_hi + quotient_lo * divisor_hi --> Represents the middle 128 bits of the result, i.e. values between [2**64, 2**196)
// (3) divisor_hi * quotient_hi --> Represents the topmost 128 bits of the result, i.e. values between [2**128, 2**256).

// We simplify (2) by further decomposing it two limbs of 64 bits and adding the upper 64 bit to (3)
// We simplify (2) by further decomposing it into two limbs of 64 bits and adding the upper 64 bit to (3)
// divisor_lo * quotient_hi + quotient_lo * divisor_hi = partial_prod_lo + 2**64 * partial_prod_hi
// Need to range check that these are 64 bits
pol commit partial_prod_lo;
Expand All @@ -689,27 +689,27 @@ namespace avm_alu(256);
// Range checks already performed via a_lo and a_hi
// Primality checks already performed above via p_sub_a_lo and p_sub_a_hi

//Range check remainder < ib and put the value in b_hi, it has to fit into a 128 bit range check
// Range check remainder < ib and put the value in b_hi, it has to fit into a 128 bit range check
#[REMAINDER_RANGE_CHK]
op_div_std * (b_hi - (ib - remainder - 1)) = 0;
op_div_std * (b_hi - (ib - remainder - 1)) = 0;

// We need to perform 3x 256-bit range checks: (a_lo, a_hi), (b_lo, b_hi), and (p_sub_a_lo, p_sub_a_hi)
// We need to perform 3 x 256-bit range checks: (a_lo, a_hi), (b_lo, b_hi), and (p_sub_a_lo, p_sub_a_hi)
// One range check happens in-line with the division
#[CMP_CTR_REL_3]
(cmp_rng_ctr' - 2) * op_div_std = 0;

// If we have more range checks left we cannot do more divisions operations that might truncate the steps
rng_chk_sel * op_div_std = 0;

// Check PRODUCT = ia - remainder
// Check PRODUCT = ia - remainder
#[DIVISION_RELATION]
op_div_std * (PRODUCT - (ia - remainder)) = 0;

// === DIVISION 64 BIT RANGE CHECKS
// 64-bit decompositions and implicit 64 bit range checks for each limb,
// TODO: We need extra slice registers because we are performing an additional 64-bit bit range check in the same row, look into re-using old columns or refactoring
// === DIVISION 64-BIT RANGE CHECKS
// 64-bit decompositions and implicit 64-bit range checks for each limb,
// TODO: We need extra slice registers because we are performing an additional 64-bit range check in the same row, look into re-using old columns or refactoring
// range checks to be more modular.
// boolean to account for the division-specific 64 bit range checks.
// boolean to account for the division-specific 64-bit range checks.
pol commit div_rng_chk_selector;
div_rng_chk_selector * (1 - div_rng_chk_selector) = 0;
// div_rng_chk_selector && div_rng_chk_selector' = 1 if op_div_std = 1
Expand All @@ -724,14 +724,14 @@ namespace avm_alu(256);
pol commit div_u16_r6;
pol commit div_u16_r7;

divisor_lo = op_div_std * (div_u16_r0 + div_u16_r1 * 2**16 + div_u16_r2 * 2**32 + div_u16_r3 * 2 **48);
divisor_lo = op_div_std * (div_u16_r0 + div_u16_r1 * 2**16 + div_u16_r2 * 2**32 + div_u16_r3 * 2**48);
divisor_hi = op_div_std * (div_u16_r4 + div_u16_r5 * 2**16 + div_u16_r6 * 2**32 + div_u16_r7 * 2**48);
quotient_lo = op_div_std * (div_u16_r0' + div_u16_r1' * 2**16 + div_u16_r2' * 2**32 + div_u16_r3' * 2**48);
quotient_hi = op_div_std * (div_u16_r4' + div_u16_r5' * 2**16 + div_u16_r6' * 2**32 + div_u16_r7' * 2**48);

// We need an extra 128 bits to do 2 more 64 bit range checks. We use b_lo (128 bits) to store partial_prod_lo(64 bits) and partial_prod_hi(64 bits.
// We need an extra 128 bits to do 2 more 64-bit range checks. We use b_lo (128 bits) to store partial_prod_lo(64 bits) and partial_prod_hi(64 bits.
// Use a shift to access the slices (b_lo is moved into the alu slice registers on the next row anyways as part of the SHIFT_RELS_0 relations)
pol NEXT_SUM_64 = u8_r0' + u8_r1' * 2**8 + u16_r0' * 2**16 + u16_r1' * 2**32 + u16_r2' * 2**48;
pol NEXT_SUM_128 = u16_r3' + u16_r4' * 2**16 + u16_r5' * 2**32 + u16_r6' * 2**48;
partial_prod_lo = op_div_std * NEXT_SUM_64;
partial_prod_hi = op_div_std * NEXT_SUM_128;
pol NEXT_SUM_64_LO = u8_r0' + u8_r1' * 2**8 + u16_r0' * 2**16 + u16_r1' * 2**32 + u16_r2' * 2**48;
pol NEXT_SUM_128_HI = u16_r3' + u16_r4' * 2**16 + u16_r5' * 2**32 + u16_r6' * 2**48;
partial_prod_lo = op_div_std * NEXT_SUM_64_LO;
partial_prod_hi = op_div_std * NEXT_SUM_128_HI;
13 changes: 5 additions & 8 deletions barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_alu_trace.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -530,12 +530,12 @@ std::tuple<uint256_t, uint256_t> decompose(uint256_t const& a, uint8_t const b)
// This is useful when we want to enforce in certain checks that a must be greater than b
std::tuple<uint256_t, uint256_t, bool> gt_witness(uint256_t const& a, uint256_t const& b)
{
uint256_t two_pow_126 = uint256_t(1) << uint256_t(128);
uint256_t two_pow_128 = uint256_t(1) << uint256_t(128);
auto [a_lo, a_hi] = decompose(a, 128);
auto [b_lo, b_hi] = decompose(b, 128);
bool borrow = a_lo <= b_lo;
auto borrow_u256 = uint256_t(static_cast<uint64_t>(borrow));
uint256_t r_lo = a_lo - b_lo - 1 + borrow_u256 * two_pow_126;
uint256_t r_lo = a_lo - b_lo - 1 + borrow_u256 * two_pow_128;
uint256_t r_hi = a_hi - b_hi - borrow_u256;
return std::make_tuple(r_lo, r_hi, borrow);
}
Expand Down Expand Up @@ -983,7 +983,6 @@ FF AvmAluTraceBuilder::op_div(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
alu_trace.push_back(AvmAluTraceBuilder::AluTraceEntry({
.alu_clk = clk,
.alu_op_div = true,
.alu_ff_tag = in_tag == AvmMemoryTag::FF,
.alu_u8_tag = in_tag == AvmMemoryTag::U8,
.alu_u16_tag = in_tag == AvmMemoryTag::U16,
.alu_u32_tag = in_tag == AvmMemoryTag::U32,
Expand Down Expand Up @@ -1036,7 +1035,6 @@ FF AvmAluTraceBuilder::op_div(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
AvmAluTraceBuilder::AluTraceEntry row{
.alu_clk = clk,
.alu_op_div = true,
.alu_ff_tag = in_tag == AvmMemoryTag::FF,
.alu_u8_tag = in_tag == AvmMemoryTag::U8,
.alu_u16_tag = in_tag == AvmMemoryTag::U16,
.alu_u32_tag = in_tag == AvmMemoryTag::U32,
Expand All @@ -1058,11 +1056,10 @@ FF AvmAluTraceBuilder::op_div(FF const& a, FF const& b, AvmMemoryTag in_tag, uin
};
// We perform the range checks here
std::vector<AvmAluTraceBuilder::AluTraceEntry> rows = cmp_range_check_helper(row, hi_lo_limbs);
// Add the range checks for the quotient limbs in the row after the division operation
rows.at(1).div_u64_range_chk = div_u64_rng_chk_shifted;
rows.at(1).div_u64_range_chk_sel = true;
alu_trace.insert(alu_trace.end(), rows.begin(), rows.end());
// Add the range checks for the quotient limbs in the next row
alu_trace.at(1).div_u64_range_chk = div_u64_rng_chk_shifted;
alu_trace.at(1).div_u64_range_chk_sel = true;

return c_u256;
}
} // namespace bb::avm_trace
21 changes: 2 additions & 19 deletions barretenberg/cpp/src/barretenberg/vm/tests/avm_arithmetic.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ size_t common_validate_div(std::vector<Row> const& trace,
common_validate_arithmetic_op(*row, *alu_row, a, b, c, addr_a, addr_b, addr_c, tag);
EXPECT_EQ(row->avm_main_w_in_tag, FF(static_cast<uint32_t>(tag)));

// Check that multiplication selector is set.
// Check that division selector is set.
EXPECT_EQ(alu_row->avm_alu_op_div, FF(1));

return static_cast<size_t>(alu_row - trace.begin());
Expand Down Expand Up @@ -612,7 +612,7 @@ INSTANTIATE_TEST_SUITE_P(AvmArithmeticTestsDiv,
AvmArithmeticTestsDiv,
testing::ValuesIn(gen_three_op_params(positive_op_div_test_values, uint_mem_tags)));

// Test on division by zero over finite field type.
// Test on division by zero over U128.
// We check that the operator error flag is raised.
TEST_F(AvmArithmeticTests, DivisionByZeroError)
{
Expand Down Expand Up @@ -808,23 +808,6 @@ TEST_F(AvmArithmeticTestsU8, nonEquality)
validate_trace(std::move(trace));
}

// Test correct division of U8 elements using faster method
TEST_F(AvmArithmeticTestsU8, fastDivision)
{
auto trace_builder = avm_trace::AvmTraceBuilder();
trace_builder.op_set(0, 153, 0, AvmMemoryTag::U8);
trace_builder.op_set(0, 2, 1, AvmMemoryTag::U8);
trace_builder.op_div(0, 0, 1, 2, AvmMemoryTag::U8);
trace_builder.return_op(0, 0, 0);
auto trace = trace_builder.finalize();

auto alu_row_index = common_validate_div(trace, 153, 2, 76, 0, 1, 2, AvmMemoryTag::U8);
auto alu_row = trace.at(alu_row_index);

EXPECT_EQ(alu_row.avm_alu_u8_tag, FF(1));
validate_trace_check_circuit(std::move(trace));
}

/******************************************************************************
* Positive Tests - U16
******************************************************************************/
Expand Down

0 comments on commit ca89aab

Please sign in to comment.