diff --git a/barretenberg/cpp/pil/avm/avm_alu.pil b/barretenberg/cpp/pil/avm/avm_alu.pil index 8005535d05f..5d6db0544ba 100644 --- a/barretenberg/cpp/pil/avm/avm_alu.pil +++ b/barretenberg/cpp/pil/avm/avm_alu.pil @@ -652,13 +652,13 @@ namespace avm_alu(256); pol commit op_div_a_lt_b; op_div_a_lt_b * (1 - op_div_a_lt_b) = 0; // To show this, we constrain ib - ia - 1 to be within 128 bits. - // Since we need a range check we use the existing a_lo column that is range checked over 128-bits + // Since we need a range check we use the existing a_lo column that is range checked over 128 bits. op_div_a_lt_b * (a_lo - (ib - ia - 1)) = 0; op_div_a_lt_b * ic = 0; // ic = 0 op_div_a_lt_b * (ia - remainder) = 0; // remainder = a, might not be needed. - // ====== Handling ia > ib ===== + // ====== Handling ia >= ib ===== pol commit op_div_std; op_div_std * (1 - op_div_std) = 0; pol commit divisor_lo; // b @@ -673,7 +673,7 @@ namespace avm_alu(256); // (2) divisor_lo * quotient_hi + quotient_lo * divisor_hi --> Represents the middle 128 bits of the result, i.e. values between [2**64, 2**196) // (3) divisor_hi * quotient_hi --> Represents the topmost 128 bits of the result, i.e. values between [2**128, 2**256). 
- // We simplify (2) by further decomposing it two limbs of 64 bits and adding the upper 64 bit to (3) + // We simplify (2) by further decomposing it into two limbs of 64 bits and adding the upper 64 bits to (3) // divisor_lo * quotient_hi + quotient_lo * divisor_hi = partial_prod_lo + 2**64 * partial_prod_hi // Need to range check that these are 64 bits pol commit partial_prod_lo; @@ -689,11 +689,11 @@ namespace avm_alu(256); // Range checks already performed via a_lo and a_hi // Primality checks already performed above via p_sub_a_lo and p_sub_a_hi - //Range check remainder < ib and put the value in b_hi, it has to fit into a 128 bit range check + // Range check remainder < ib and put the value in b_hi, it has to fit into a 128-bit range check #[REMAINDER_RANGE_CHK] - op_div_std * (b_hi - (ib - remainder - 1)) = 0; + op_div_std * (b_hi - (ib - remainder - 1)) = 0; - // We need to perform 3x 256-bit range checks: (a_lo, a_hi), (b_lo, b_hi), and (p_sub_a_lo, p_sub_a_hi) + // We need to perform 3 x 256-bit range checks: (a_lo, a_hi), (b_lo, b_hi), and (p_sub_a_lo, p_sub_a_hi) // One range check happens in-line with the division #[CMP_CTR_REL_3] (cmp_rng_ctr' - 2) * op_div_std = 0; @@ -701,15 +701,15 @@ namespace avm_alu(256); // If we have more range checks left we cannot do more divisions operations that might truncate the steps rng_chk_sel * op_div_std = 0; - // Check PRODUCT = ia - remainder + // Check PRODUCT = ia - remainder #[DIVISION_RELATION] op_div_std * (PRODUCT - (ia - remainder)) = 0; - // === DIVISION 64 BIT RANGE CHECKS - // 64-bit decompositions and implicit 64 bit range checks for each limb, - // TODO: We need extra slice registers because we are performing an additional 64-bit bit range check in the same row, look into re-using old columns or refactoring + // === DIVISION 64-BIT RANGE CHECKS - // 64-bit decompositions and implicit 64 bit range checks for each limb, + // 64-bit decompositions and implicit 64-bit range checks for each limb, + // TODO: We need extra slice registers because we are performing an additional 
64-bit range check in the same row, look into re-using old columns or refactoring // range checks to be more modular. - // boolean to account for the division-specific 64 bit range checks. + // boolean to account for the division-specific 64-bit range checks. pol commit div_rng_chk_selector; div_rng_chk_selector * (1 - div_rng_chk_selector) = 0; // div_rng_chk_selector && div_rng_chk_selector' = 1 if op_div_std = 1 @@ -724,14 +724,14 @@ namespace avm_alu(256); pol commit div_u16_r6; pol commit div_u16_r7; - divisor_lo = op_div_std * (div_u16_r0 + div_u16_r1 * 2**16 + div_u16_r2 * 2**32 + div_u16_r3 * 2 **48); + divisor_lo = op_div_std * (div_u16_r0 + div_u16_r1 * 2**16 + div_u16_r2 * 2**32 + div_u16_r3 * 2**48); divisor_hi = op_div_std * (div_u16_r4 + div_u16_r5 * 2**16 + div_u16_r6 * 2**32 + div_u16_r7 * 2**48); quotient_lo = op_div_std * (div_u16_r0' + div_u16_r1' * 2**16 + div_u16_r2' * 2**32 + div_u16_r3' * 2**48); quotient_hi = op_div_std * (div_u16_r4' + div_u16_r5' * 2**16 + div_u16_r6' * 2**32 + div_u16_r7' * 2**48); - // We need an extra 128 bits to do 2 more 64 bit range checks. We use b_lo (128 bits) to store partial_prod_lo(64 bits) and partial_prod_hi(64 bits. + // We need an extra 128 bits to do 2 more 64-bit range checks. We use b_lo (128 bits) to store partial_prod_lo (64 bits) and partial_prod_hi (64 bits). 
// Use a shift to access the slices (b_lo is moved into the alu slice registers on the next row anyways as part of the SHIFT_RELS_0 relations) - pol NEXT_SUM_64 = u8_r0' + u8_r1' * 2**8 + u16_r0' * 2**16 + u16_r1' * 2**32 + u16_r2' * 2**48; - pol NEXT_SUM_128 = u16_r3' + u16_r4' * 2**16 + u16_r5' * 2**32 + u16_r6' * 2**48; - partial_prod_lo = op_div_std * NEXT_SUM_64; - partial_prod_hi = op_div_std * NEXT_SUM_128; + pol NEXT_SUM_64_LO = u8_r0' + u8_r1' * 2**8 + u16_r0' * 2**16 + u16_r1' * 2**32 + u16_r2' * 2**48; + pol NEXT_SUM_128_HI = u16_r3' + u16_r4' * 2**16 + u16_r5' * 2**32 + u16_r6' * 2**48; + partial_prod_lo = op_div_std * NEXT_SUM_64_LO; + partial_prod_hi = op_div_std * NEXT_SUM_128_HI; diff --git a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_alu_trace.cpp b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_alu_trace.cpp index 4a3a7073134..9a055c79aba 100644 --- a/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_alu_trace.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/avm_trace/avm_alu_trace.cpp @@ -530,12 +530,12 @@ std::tuple decompose(uint256_t const& a, uint8_t const b) // This is useful when we want to enforce in certain checks that a must be greater than b std::tuple gt_witness(uint256_t const& a, uint256_t const& b) { - uint256_t two_pow_126 = uint256_t(1) << uint256_t(128); + uint256_t two_pow_128 = uint256_t(1) << uint256_t(128); auto [a_lo, a_hi] = decompose(a, 128); auto [b_lo, b_hi] = decompose(b, 128); bool borrow = a_lo <= b_lo; auto borrow_u256 = uint256_t(static_cast(borrow)); - uint256_t r_lo = a_lo - b_lo - 1 + borrow_u256 * two_pow_126; + uint256_t r_lo = a_lo - b_lo - 1 + borrow_u256 * two_pow_128; uint256_t r_hi = a_hi - b_hi - borrow_u256; return std::make_tuple(r_lo, r_hi, borrow); } @@ -983,7 +983,6 @@ FF AvmAluTraceBuilder::op_div(FF const& a, FF const& b, AvmMemoryTag in_tag, uin alu_trace.push_back(AvmAluTraceBuilder::AluTraceEntry({ .alu_clk = clk, .alu_op_div = true, - .alu_ff_tag = in_tag == AvmMemoryTag::FF, 
.alu_u8_tag = in_tag == AvmMemoryTag::U8, .alu_u16_tag = in_tag == AvmMemoryTag::U16, .alu_u32_tag = in_tag == AvmMemoryTag::U32, @@ -1036,7 +1035,6 @@ FF AvmAluTraceBuilder::op_div(FF const& a, FF const& b, AvmMemoryTag in_tag, uin AvmAluTraceBuilder::AluTraceEntry row{ .alu_clk = clk, .alu_op_div = true, - .alu_ff_tag = in_tag == AvmMemoryTag::FF, .alu_u8_tag = in_tag == AvmMemoryTag::U8, .alu_u16_tag = in_tag == AvmMemoryTag::U16, .alu_u32_tag = in_tag == AvmMemoryTag::U32, @@ -1058,11 +1056,10 @@ FF AvmAluTraceBuilder::op_div(FF const& a, FF const& b, AvmMemoryTag in_tag, uin }; // We perform the range checks here std::vector rows = cmp_range_check_helper(row, hi_lo_limbs); + // Add the range checks for the quotient limbs in the row after the division operation + rows.at(1).div_u64_range_chk = div_u64_rng_chk_shifted; + rows.at(1).div_u64_range_chk_sel = true; alu_trace.insert(alu_trace.end(), rows.begin(), rows.end()); - // Add the range checks for the quotient limbs in the next row - alu_trace.at(1).div_u64_range_chk = div_u64_rng_chk_shifted; - alu_trace.at(1).div_u64_range_chk_sel = true; - return c_u256; } } // namespace bb::avm_trace diff --git a/barretenberg/cpp/src/barretenberg/vm/tests/avm_arithmetic.test.cpp b/barretenberg/cpp/src/barretenberg/vm/tests/avm_arithmetic.test.cpp index 1f5c1f80eb4..c0754b31d4c 100644 --- a/barretenberg/cpp/src/barretenberg/vm/tests/avm_arithmetic.test.cpp +++ b/barretenberg/cpp/src/barretenberg/vm/tests/avm_arithmetic.test.cpp @@ -191,7 +191,7 @@ size_t common_validate_div(std::vector const& trace, common_validate_arithmetic_op(*row, *alu_row, a, b, c, addr_a, addr_b, addr_c, tag); EXPECT_EQ(row->avm_main_w_in_tag, FF(static_cast(tag))); - // Check that multiplication selector is set. + // Check that division selector is set. 
EXPECT_EQ(alu_row->avm_alu_op_div, FF(1)); return static_cast(alu_row - trace.begin()); @@ -612,7 +612,7 @@ INSTANTIATE_TEST_SUITE_P(AvmArithmeticTestsDiv, AvmArithmeticTestsDiv, testing::ValuesIn(gen_three_op_params(positive_op_div_test_values, uint_mem_tags))); -// Test on division by zero over finite field type. +// Test on division by zero over U128. // We check that the operator error flag is raised. TEST_F(AvmArithmeticTests, DivisionByZeroError) { @@ -808,23 +808,6 @@ TEST_F(AvmArithmeticTestsU8, nonEquality) validate_trace(std::move(trace)); } -// Test correct division of U8 elements using faster method -TEST_F(AvmArithmeticTestsU8, fastDivision) -{ - auto trace_builder = avm_trace::AvmTraceBuilder(); - trace_builder.op_set(0, 153, 0, AvmMemoryTag::U8); - trace_builder.op_set(0, 2, 1, AvmMemoryTag::U8); - trace_builder.op_div(0, 0, 1, 2, AvmMemoryTag::U8); - trace_builder.return_op(0, 0, 0); - auto trace = trace_builder.finalize(); - - auto alu_row_index = common_validate_div(trace, 153, 2, 76, 0, 1, 2, AvmMemoryTag::U8); - auto alu_row = trace.at(alu_row_index); - - EXPECT_EQ(alu_row.avm_alu_u8_tag, FF(1)); - validate_trace_check_circuit(std::move(trace)); -} - /****************************************************************************** * Positive Tests - U16 ******************************************************************************/