Skip to content

Commit

Permalink
feat(avm): shift relations (#5716)
Browse files Browse the repository at this point in the history
Please read [contributing guidelines](CONTRIBUTING.md) and remove this
line.
  • Loading branch information
IlyasRidhuan authored Apr 23, 2024
1 parent 100744f commit a516637
Show file tree
Hide file tree
Showing 22 changed files with 1,934 additions and 364 deletions.
1 change: 1 addition & 0 deletions barretenberg/cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ set(CMAKE_CXX_STANDARD_REQUIRED TRUE)
set(CMAKE_CXX_EXTENSIONS ON)

if(CMAKE_CXX_COMPILER_ID MATCHES "Clang")
add_compile_options(-fbracket-depth=512)
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "14")
message(WARNING "Clang <14 is not supported")
endif()
Expand Down
134 changes: 123 additions & 11 deletions barretenberg/cpp/pil/avm/avm_alu.pil
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ namespace avm_alu(256);
pol commit op_lt;
pol commit op_lte;
pol commit cmp_sel; // Predicate if LT or LTE is set
pol commit rng_chk_sel; // Predicate representing a range check row used in LT/LTE.
pol commit rng_chk_sel; // Predicate representing a range check row.
pol commit op_shl;
pol commit op_shr;
pol commit shift_sel; // Predicate if SHR or SHR is set

// Instruction tag (1: u8, 2: u16, 3: u32, 4: u64, 5: u128, 6: field) copied from Main table
pol commit in_tag;
Expand Down Expand Up @@ -61,8 +64,9 @@ namespace avm_alu(256);
pol commit cf;

// Compute predicate telling whether there is a row entry in the ALU table.
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte;
alu_sel = op_add + op_sub + op_mul + op_not + op_eq + op_cast + op_lt + op_lte + op_shr + op_shl;
cmp_sel = op_lt + op_lte;
shift_sel = op_shl + op_shr;

// ========= Type Constraints =============================================
// TODO: Range constraints
Expand Down Expand Up @@ -355,7 +359,7 @@ namespace avm_alu(256);
// (b) IS_GT = 1 - ic = 0
// (c) res_lo = B_SUB_A_LO and res_hi = B_SUB_A_HI
// (d) res_lo = y_lo - x_lo + borrow * 2**128 and res_hi = y_hi - x_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> y_lo >= x_lo && y_hi >= x_hi
Expand All @@ -368,7 +372,7 @@ namespace avm_alu(256);
// (b) IS_GT = 1 - ic = 1
// (c) res_lo = A_SUB_B_LO and res_hi = A_SUB_B_HI
// (d) res_lo = x_lo - y_lo - 1 + borrow * 2**128 and res_hi = x_hi - y_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> x_lo > y_lo && x_hi >= y_hi
Expand All @@ -383,7 +387,7 @@ namespace avm_alu(256);
// (b) IS_GT = ic = 1
// (c) res_lo = A_SUB_B_LO and res_hi = A_SUB_B_HI, **remember we have swapped inputs**
// (d) res_lo = y_lo - x_lo - 1 + borrow * 2**128 and res_hi = y_hi - x_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> y_lo > x_lo && y_hi >= x_hi
Expand All @@ -395,8 +399,8 @@ namespace avm_alu(256);
// (a) We DO swap the operands, so a = y and b = x,
// (b) IS_GT = ic = 0
// (c) res_lo = B_SUB_A_LO and res_hi = B_SUB_A_HI, **remember we have swapped inputs**
// (d) res_lo = x_lo - y_lo + borrow * 2**128 and res_hi = x_hi - y_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, y_lo, we
// (d) res_lo = a_lo - y_lo + borrow * 2**128 and res_hi = a_hi - y_hi - borrow.
// (e) Due to 128-bit range checks on res_lo, res_hi, y_lo, x_lo, y_hi, x_hi, we
// have the guarantee that res_lo >= 0 && res_hi >= 0. Furthermore, borrow is
// boolean and so we have two cases to consider:
// (i) borrow == 0 ==> x_lo >= y_lo && x_hi >= y_hi
Expand Down Expand Up @@ -434,11 +438,13 @@ namespace avm_alu(256);
cmp_rng_ctr * ((1 - rng_chk_sel) * (1 - op_eq_diff_inv) + op_eq_diff_inv) - rng_chk_sel = 0;

// We perform a range check if we have some range checks remaining or we are performing a comparison op
pol RNG_CHK_OP = rng_chk_sel + cmp_sel + op_cast + op_cast_prev;
pol RNG_CHK_OP = rng_chk_sel + cmp_sel + op_cast + op_cast_prev + shift_lt_bit_len;

pol commit rng_chk_lookup_selector;
// TODO: Possible optimisation here if we swap the op_shl and op_shr with shift_lt_bit_len.
// Shift_lt_bit_len is a more restrictive form therefore we can avoid performing redundant range checks when we know the result == 0.
#[RNG_CHK_LOOKUP_SELECTOR]
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag + op_cast' + op_cast_prev';
rng_chk_lookup_selector' = cmp_sel' + rng_chk_sel' + op_add' + op_sub' + op_mul' + op_mul * u128_tag + op_cast' + op_cast_prev' + op_shl' + op_shr';

// Perform 128-bit range check on lo part
#[LOWER_CMP_RNG_CHK]
Expand Down Expand Up @@ -469,7 +475,6 @@ namespace avm_alu(256);
(p_sub_b_lo' - res_lo) * rng_chk_sel'= 0;
(p_sub_b_hi' - res_hi) * rng_chk_sel'= 0;


// ========= CAST Operation Constraints ===============================
// We handle the input ia independently of its tag, i.e., we suppose it can take
// any value between 0 and p-1.
Expand Down Expand Up @@ -509,4 +514,111 @@ namespace avm_alu(256);
// 128-bit multiplication and CAST need two rows in ALU trace. We need to ensure
// that another ALU operation does not start in the second row.
#[TWO_LINE_OP_NO_OVERLAP]
(op_mul * ff_tag + op_cast) * alu_sel' = 0;
(op_mul * ff_tag + op_cast) * alu_sel' = 0;

// ========= SHIFT LEFT/RIGHT OPERATIONS ===============================
// Given (1) an input b, within the range [0, 2**128-1],
// (2) a value s, the amount of bits to shift b by,
// (3) and a memory tag, mem_tag that supports a maximum of t bits.
// Split input b into Big Endian hi and lo limbs, (we re-use the b_hi and b_lo columns we used for the comparison operators)
// b_hi and b_lo, and the number of bits represented by the memory tag, t.
// If we are shifting by more than the bit length represented by the memory tag, the result is trivially zero
//
// === Steps when performing SHR
// (1) Prove the correct decomposition: b_hi * 2**s + b_lo = b
// (2) Range check b_hi < 2**(t-s) && b_lo < 2**s, ensure we have not overflowed the limbs during decomp
// (3) Return b_hi
//
// <--(t-s) bits --> | <-- s bits -->
// -------------------|-------------------
// | b_hi | b_lo | --> b
// ---------------------------------------
//
// === Steps when performing SHL
// (1) Prove the correct decomposition: b_hi * 2**(t-s) + b_lo = b
// (2) Range check b_hi < 2**s && b_lo < 2**(t-s)
// (3) Return b_lo * 2**s
//
// <-- s bits --> | <-- (t-s) bits -->
// ------------------|-------------------
// | b_hi | b_lo | --> b
// --------------------------------------

// Check that b_lo and b_hi are range checked such that:
// SHR: b_hi < 2**(t - s) && b_lo < 2**s
// SHL: b_hi < 2**s && b_lo < 2**(t - s)

// In lieu of a variable length check, we can utilise 2 fixed range checks instead.
// Given the dynamic range check of 0 <= b_hi < 2**(t-s), where s < t
// (1) 0 <= b_hi < 2**128
// (2) 0 <= 2**(t - s) - b_hi < 2**128
// Note that (1) is guaranteed elsewhere through the tagged memory model, so we focus on (2) here.

// === General Notes:
// There are probably ways to merge various relations for the SHL/SHR, but they are separate
// now while we are still figuring out.


// We re-use the a_lo and a_hi cols from the comparison operators for the range checks
// SHR: (1) a_lo = 2**s - b_lo, (2) a_hi = 2**(t-s) - b_hi
// === Range checks: (1) a_lo < 2**128, (2) a_hi < 2**128.
#[SHR_RANGE_0]
shift_lt_bit_len * op_shr * (a_lo - (two_pow_s - b_lo - 1)) = 0;
#[SHR_RANGE_1]
shift_lt_bit_len * op_shr * (a_hi - (two_pow_t_sub_s - b_hi - 1)) = 0;

// SHL: (1) a_lo = 2**(t-s) - b_lo, (2) a_hi = 2**s - b_hi
// === Range checks: (1) a_lo < 2**128, (2) a_hi < 2**128.
#[SHL_RANGE_0]
shift_lt_bit_len * op_shl * (a_lo - (two_pow_t_sub_s - b_lo - 1)) = 0;
#[SHL_RANGE_1]
shift_lt_bit_len * op_shl * (a_hi - (two_pow_s - b_hi - 1)) = 0;

// Indicate if the shift amount < MAX_BITS
pol commit shift_lt_bit_len;
shift_lt_bit_len * (1 - shift_lt_bit_len) = 0;

// The number of bits represented by the memory tag, any shifts greater than this will result in zero.
pol MAX_BITS = u8_tag * 8 +
u16_tag * 16 +
u32_tag * 32 +
u64_tag * 64 +
u128_tag * 128;

// The result of MAX_BITS - ib, this used as part of the range check with the main trace
pol commit t_sub_s_bits;

// Lookups for powers of 2.
// 2**(MAX_BITS - ib), constrained as part of the range check to the main trace
pol commit two_pow_t_sub_s;
// 2 ** ib, constrained as part of the range check to the main trace
pol commit two_pow_s;

// For our assumptions to hold, we must check that s < MAX_BITS. This can be achieved by the following relation.
// We check if s <= MAX_BITS || s >= MAX_BITS using boolean shift_lt_bit_len.
// Regardless of which side is evaluated, the value of t_sub_s_bits < 2**8
// There is no chance of an underflow involving ib to result in a t_sub_b_bits < 2**8 ib is range checked to be < 2**8
// The range checking of t_sub_b_bits in the range [0, 2**8) is done by the lookup for 2**t_sub_s_bits
#[SHIFT_LT_BIT_LEN]
t_sub_s_bits = shift_sel * (shift_lt_bit_len * (MAX_BITS - ib) + (1 - shift_lt_bit_len) * (ib - MAX_BITS));

// ========= SHIFT RIGHT OPERATIONS ===============================
// a_hi * 2**s + a_lo = a
// If ib >= MAX_BITS, we trivially skip this check since the result will be forced to 0.
#[SHR_INPUT_DECOMPOSITION]
shift_lt_bit_len * op_shr * ((b_hi * two_pow_s + b_lo) - ia) = 0;

// Return hi limb, if ib >= MAX_BITS, the output is forced to be 0
#[SHR_OUTPUT]
op_shr * (ic - (b_hi * shift_lt_bit_len)) = 0;

// ========= SHIFT LEFT OPERATIONS ===============================
// a_hi * 2**(t-s) + a_lo = a
// If ib >= MAX_BITS, we trivially skip this check since the result will be forced to 0.
#[SHL_INPUT_DECOMPOSITION]
shift_lt_bit_len * op_shl * ((b_hi * two_pow_t_sub_s + b_lo) - ia) = 0;

// Return lo limb a_lo * 2**s, if ib >= MAX_BITS, the output is forced to be 0
#[SHL_OUTPUT]
op_shl * (ic - (b_lo * two_pow_s * shift_lt_bit_len)) = 0;

26 changes: 23 additions & 3 deletions barretenberg/cpp/pil/avm/avm_main.pil
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ namespace avm_main(256);
pol commit sel_rng_8; // Boolean selector for the 8-bit range check lookup
pol commit sel_rng_16; // Boolean selector for the 16-bit range check lookup

//===== Lookup table powers of 2 =============================================
pol commit table_pow_2; // Table of powers of 2 for 8-bit numbers.

//===== CONTROL FLOW ==========================================================
// Program counter
pol commit pc;
Expand Down Expand Up @@ -60,6 +63,10 @@ namespace avm_main(256);
pol commit sel_op_lt;
// LTE
pol commit sel_op_lte;
// SHL
pol commit sel_op_shl;
// SHR
pol commit sel_op_shr;

// Helper selector to characterize an ALU chiplet selector
pol commit alu_sel;
Expand Down Expand Up @@ -138,6 +145,8 @@ namespace avm_main(256);
sel_op_cast * (1 - sel_op_cast) = 0;
sel_op_lt * (1 - sel_op_lt) = 0;
sel_op_lte * (1 - sel_op_lte) = 0;
sel_op_shl * (1 - sel_op_shl) = 0;
sel_op_shr * (1 - sel_op_shr) = 0;

sel_internal_call * (1 - sel_internal_call) = 0;
sel_internal_return * (1 - sel_internal_return) = 0;
Expand Down Expand Up @@ -295,7 +304,7 @@ namespace avm_main(256);

//===== ALU CONSTRAINTS =====================================================
// TODO: when division is moved to the alu, we will need to add the selector in the list below.
pol ALU_R_TAG_SEL = sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte;
pol ALU_R_TAG_SEL = sel_op_add + sel_op_sub + sel_op_mul + sel_op_not + sel_op_eq + sel_op_lt + sel_op_lte + sel_op_shr + sel_op_shl;
pol ALU_W_TAG_SEL = sel_op_cast;
pol ALU_ALL_SEL = ALU_R_TAG_SEL + ALU_W_TAG_SEL;

Expand All @@ -317,11 +326,11 @@ namespace avm_main(256);
#[PERM_MAIN_ALU]
alu_sel {clk, ia, ib, ic, sel_op_add, sel_op_sub,
sel_op_mul, sel_op_eq, sel_op_not, sel_op_cast,
sel_op_lt, sel_op_lte, alu_in_tag}
sel_op_lt, sel_op_lte, sel_op_shr, sel_op_shl, alu_in_tag}
is
avm_alu.alu_sel {avm_alu.clk, avm_alu.ia, avm_alu.ib, avm_alu.ic, avm_alu.op_add, avm_alu.op_sub,
avm_alu.op_mul, avm_alu.op_eq, avm_alu.op_not, avm_alu.op_cast,
avm_alu.op_lt, avm_alu.op_lte, avm_alu.in_tag};
avm_alu.op_lt, avm_alu.op_lte, avm_alu.op_shr, avm_alu.op_shl, avm_alu.in_tag};

// Based on the boolean selectors, we derive the binary op id to lookup in the table;
// TODO: Check if having 4 columns (op_id + 3 boolean selectors) is more optimal that just using the op_id
Expand Down Expand Up @@ -379,6 +388,17 @@ namespace avm_main(256);
#[PERM_MAIN_MEM_IND_D]
ind_op_d {clk, ind_d, mem_idx_d} is avm_mem.ind_op_d {avm_mem.clk, avm_mem.addr, avm_mem.val};

//====== Inter-table Shift Constraints (Lookups) ============================================
// Currently only used for shift operations but can be generalised for other uses.

// Lookup for 2**(ib)
#[LOOKUP_POW_2_0]
avm_alu.shift_sel {avm_alu.ib, avm_alu.two_pow_s} in sel_rng_8 {clk, table_pow_2};

// Lookup for 2**(t-ib)
#[LOOKUP_POW_2_1]
avm_alu.shift_sel {avm_alu.t_sub_s_bits , avm_alu.two_pow_t_sub_s} in sel_rng_8 {clk, table_pow_2};

//====== Inter-table Constraints (Range Checks) ============================================
// TODO: Investigate optimising these range checks. Handling non-FF elements should require less range checks.
// One can increase the granularity based on the operation and tag. In the most extreme case,
Expand Down
Loading

0 comments on commit a516637

Please sign in to comment.