Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Several Updates in SMT verification module (part 1) #10437

Merged
merged 8 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ class CircuitBase {
std::unordered_map<SubcircuitType, std::unordered_map<size_t, CircuitProps>>
cached_subcircuits; // caches subcircuits during optimization
// No need to recompute them each time
std::unordered_map<uint32_t, std::vector<bb::fr>>
post_process; // Values idxs that should be post processed after the solver returns a witness.
// Basically it affects only optimized out variables.
// Because in BitVector case we can't collect negative values since they will not be
// the same in the field. That's why we store the expression and calculate it after the witness is
// obtained.

Solver* solver; // pointer to the solver
TermType type; // Type of the underlying Symbolic Terms
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ struct CircuitSchema {
std::vector<std::vector<std::vector<bb::fr>>> lookup_tables;
std::vector<uint32_t> real_variable_tags;
std::unordered_map<uint32_t, uint64_t> range_tags;
bool circuit_finalized;
MSGPACK_FIELDS(modulus,
public_inps,
vars_of_interest,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,32 +103,32 @@ size_t StandardCircuit::prepare_gates(size_t cursor)
// TODO(alex): Test the effect of this relaxation after the tests are merged.
if (univariate_flag) {
if ((q_m == 1) && (q_1 == 0) && (q_2 == 0) && (q_3 == -1) && (q_c == 0)) {
(Bool(symbolic_vars[w_l]) ==
(Bool(this->symbolic_vars[w_l]) ==
Bool(STerm(0, this->solver, this->type)) | // STerm(0, this->solver, this->type)) |
Bool(symbolic_vars[w_l]) ==
Bool(this->symbolic_vars[w_l]) ==
Bool(STerm(1, this->solver, this->type))) // STerm(1, this->solver, this->type)))
.assert_term();
} else {
this->handle_univariate_constraint(q_m, q_1, q_2, q_3, q_c, w_l);
}
} else {
STerm eq = symbolic_vars[0];
STerm eq = this->symbolic_vars[this->variable_names_inverse["zero"]];

// mul selector
if (q_m != 0) {
eq += symbolic_vars[w_l] * symbolic_vars[w_r] * q_m;
eq += this->symbolic_vars[w_l] * this->symbolic_vars[w_r] * q_m;
}
// left selector
if (q_1 != 0) {
eq += symbolic_vars[w_l] * q_1;
eq += this->symbolic_vars[w_l] * q_1;
}
// right selector
if (q_2 != 0) {
eq += symbolic_vars[w_r] * q_2;
eq += this->symbolic_vars[w_r] * q_2;
}
// out selector
if (q_3 != 0) {
eq += symbolic_vars[w_o] * q_3;
eq += this->symbolic_vars[w_o] * q_3;
}
// constant selector
if (q_c != 0) {
Expand Down Expand Up @@ -157,7 +157,7 @@ void StandardCircuit::handle_univariate_constraint(
bb::fr b = q_1 + q_2 + q_3;

if (q_m == 0) {
symbolic_vars[w] == -q_c / b;
this->symbolic_vars[w] == -q_c / b;
return;
}

Expand All @@ -169,10 +169,10 @@ void StandardCircuit::handle_univariate_constraint(
bb::fr x2 = (-b - d.second) / (bb::fr(2) * q_m);

if (d.second == 0) {
symbolic_vars[w] == STerm(x1, this->solver, type);
this->symbolic_vars[w] == STerm(x1, this->solver, type);
} else {
((Bool(symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) |
(Bool(symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type))))
((Bool(this->symbolic_vars[w]) == Bool(STerm(x1, this->solver, this->type))) |
(Bool(this->symbolic_vars[w]) == Bool(STerm(x2, this->solver, this->type))))
.assert_term();
}
}
Expand Down Expand Up @@ -285,8 +285,6 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor)
}
}

// TODO(alex): Figure out if I need to create range constraint here too or it'll be
// created anyway in any circuit
if (res != static_cast<size_t>(-1)) {
CircuitProps xor_props = get_standard_logic_circuit(res, true);
CircuitProps and_props = get_standard_logic_circuit(res, false);
Expand All @@ -307,6 +305,45 @@ size_t StandardCircuit::handle_logic_constraint(size_t cursor)
STerm right = this->symbolic_vars[right_idx];
STerm out = this->symbolic_vars[out_idx];

// Simulate the logic constraint circuit using the bitwise operations
size_t num_bits = res;
size_t processed_gates = 0;
for (size_t i = num_bits - 1; i < num_bits; i -= 2) {
// 8 here is the number of gates we have to skip to get proper indices
processed_gates += 8;
uint32_t left_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
uint32_t left_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]];
uint32_t left_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
processed_gates += 1;
uint32_t right_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
uint32_t right_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]];
uint32_t right_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
processed_gates += 1;
uint32_t out_quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
uint32_t out_lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][1]];
uint32_t out_hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
processed_gates += 1;
uint32_t old_left_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
processed_gates += 1;
uint32_t old_right_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
processed_gates += 1;
uint32_t old_out_acc_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
processed_gates += 1;

this->symbolic_vars[old_left_acc_idx] == (left >> static_cast<uint32_t>(i - 1));
this->symbolic_vars[left_quad_idx] == (this->symbolic_vars[old_left_acc_idx] & 3);
this->symbolic_vars[left_lo_idx] == (this->symbolic_vars[left_quad_idx] & 1);
this->symbolic_vars[left_hi_idx] == (this->symbolic_vars[left_quad_idx] >> 1);
this->symbolic_vars[old_right_acc_idx] == (right >> static_cast<uint32_t>(i - 1));
this->symbolic_vars[right_quad_idx] == (this->symbolic_vars[old_right_acc_idx] & 3);
this->symbolic_vars[right_lo_idx] == (this->symbolic_vars[right_quad_idx] & 1);
this->symbolic_vars[right_hi_idx] == (this->symbolic_vars[right_quad_idx] >> 1);
this->symbolic_vars[old_out_acc_idx] == (out >> static_cast<uint32_t>(i - 1));
this->symbolic_vars[out_quad_idx] == (this->symbolic_vars[old_out_acc_idx] & 3);
this->symbolic_vars[out_lo_idx] == (this->symbolic_vars[out_quad_idx] & 1);
this->symbolic_vars[out_hi_idx] == (this->symbolic_vars[out_quad_idx] >> 1);
}

if (logic_flag) {
(left ^ right) == out;
} else {
Expand Down Expand Up @@ -422,19 +459,41 @@ size_t StandardCircuit::handle_range_constraint(size_t cursor)
// we need this because even right shifts do not create
// any additional gates and therefore are undetectible

// TODO(alex): I think I should simulate the whole subcircuit at that point
// Otherwise optimized out variables are not correct in the final witness
// And I can't fix them by hand each time
size_t num_accs = range_props.gate_idxs.size() - 1;
for (size_t j = 1; j < num_accs + 1 && (this->type == TermType::BVTerm); j++) {
size_t acc_gate = range_props.gate_idxs[j];
uint32_t acc_gate_idx = range_props.idxs[j];
// Simulate the range constraint circuit using the bitwise operations
size_t num_bits = res;
size_t num_quads = num_bits >> 1;
num_quads += num_bits & 1;
uint32_t processed_gates = 0;

for (size_t i = num_quads - 1; i < num_quads; i--) {
uint32_t lo_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add a few more comments describing what you are doing?

processed_gates += 1;
uint32_t quad_idx = 0;
uint32_t old_accumulator_idx = 0;
uint32_t hi_idx = 0;

if (i == num_quads - 1 && ((num_bits & 1) == 1)) {
quad_idx = lo_idx;
} else {
hi_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
processed_gates += 1;
quad_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
processed_gates += 1;
}

uint32_t acc_idx = this->real_variable_index[this->wires_idxs[cursor + acc_gate][acc_gate_idx]];
if (i == num_quads - 1) {
old_accumulator_idx = quad_idx;
} else {
old_accumulator_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
processed_gates += 1;
}

this->symbolic_vars[acc_idx] == (left >> static_cast<uint32_t>(2 * j));
// I think the following is worse. The name of the variable is lost after that
// this->symbolic_vars[acc_idx] = (left >> static_cast<uint32_t>(2 * j));
this->symbolic_vars[old_accumulator_idx] == (left >> static_cast<uint32_t>(2 * i));
this->symbolic_vars[quad_idx] == (this->symbolic_vars[old_accumulator_idx] & 3);
this->symbolic_vars[lo_idx] == (this->symbolic_vars[quad_idx] & 1);
if (i != (num_quads - 1) || ((num_bits)&1) != 1) {
this->symbolic_vars[hi_idx] == (this->symbolic_vars[quad_idx] >> 1);
}
}

left <= (bb::fr(2).pow(res) - 1);
Expand Down Expand Up @@ -545,8 +604,35 @@ size_t StandardCircuit::handle_shr_constraint(size_t cursor)
STerm left = this->symbolic_vars[left_idx];
STerm out = this->symbolic_vars[out_idx];

STerm shled = left >> nr.second;
out == shled;
// Simulate the shr circuit using bitwise ops
uint32_t shift = nr.second;
if ((shift & 1) == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where are these formulas from? An explanation wouldn't hurt. It is very hard to understand what's happening here without context

size_t processed_gates = 0;
uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3);
STerm delta = this->symbolic_vars[delta_idx];
processed_gates += 1;
uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];

// this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7);
this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } });

processed_gates += 1;
uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[r1_idx] == (delta >> 1) * 6;
processed_gates += 1;
uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[r2_idx] == (left >> shift) * 6;
processed_gates += 1;
uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];

// this->symbolic_vars[temp_idx] == -6 * out;
this->post_process.insert({ temp_idx, { out_idx, out_idx, 0, -6, 0, 0 } });
}

STerm shred = left >> nr.second;
out == shred;

// You have to mark these arguments so they won't be optimized out
optimized[left_idx] = false;
Expand Down Expand Up @@ -652,7 +738,35 @@ size_t StandardCircuit::handle_shl_constraint(size_t cursor)
STerm left = this->symbolic_vars[left_idx];
STerm out = this->symbolic_vars[out_idx];

STerm shled = (left << nr.second) & (bb::fr(2).pow(nr.first) - 1);
// Simulate the shr circuit using bitwise ops
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shift left

uint32_t num_bits = nr.first;
uint32_t shift = nr.second;
if ((shift & 1) == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you provide an explanation of what you are doing here?

size_t processed_gates = 0;
uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3);
STerm delta = this->symbolic_vars[delta_idx];
processed_gates += 1;
uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];

// this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7);
this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } });

processed_gates += 1;
uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[r1_idx] == (delta >> 1) * 6;
processed_gates += 1;
uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[r2_idx] == (left >> (num_bits - shift)) * 6;
processed_gates += 1;
uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];

// this->symbolic_vraiables[temp_idx] == -6 * r2
this->post_process.insert({ temp_idx, { r2_idx, r2_idx, 0, -1, 0, 0 } });
}

STerm shled = (left << shift) & (bb::fr(2).pow(num_bits) - 1);
out == shled;

// You have to mark these arguments so they won't be optimized out
Expand Down Expand Up @@ -760,7 +874,35 @@ size_t StandardCircuit::handle_ror_constraint(size_t cursor)
STerm left = this->symbolic_vars[left_idx];
STerm out = this->symbolic_vars[out_idx];

STerm rored = ((left >> nr.second) | (left << (nr.first - nr.second))) & (bb::fr(2).pow(nr.first) - 1);
// Simulate the ror circuit using bitwise ops
uint32_t num_bits = nr.first;
uint32_t rotation = nr.second;
if ((rotation & 1) == 1) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please explain how this works

size_t processed_gates = 0;
uint32_t c_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][0]];
uint32_t delta_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[delta_idx] == (this->symbolic_vars[c_idx] & 3);
STerm delta = this->symbolic_vars[delta_idx];
processed_gates += 1;
uint32_t r0_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];

// this->symbolic_vars[r0_idx] == (-2 * delta * delta + 9 * delta - 7);
this->post_process.insert({ r0_idx, { delta_idx, delta_idx, -2, 9, 0, -7 } });

processed_gates += 1;
uint32_t r1_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[r1_idx] == (delta >> 1) * 6;
processed_gates += 1;
uint32_t r2_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];
this->symbolic_vars[r2_idx] == (left >> rotation) * 6;
processed_gates += 1;
uint32_t temp_idx = this->real_variable_index[this->wires_idxs[cursor + processed_gates][2]];

// this->symbolic_vraiables[temp_idx] == -6 * r2
this->post_process.insert({ temp_idx, { r2_idx, r2_idx, 0, -1, 0, 0 } });
}

STerm rored = ((left >> rotation) | (left << (num_bits - rotation))) & (bb::fr(2).pow(num_bits) - 1);
out == rored;

// You have to mark these arguments so they won't be optimized out
Expand Down Expand Up @@ -909,4 +1051,4 @@ std::pair<StandardCircuit, StandardCircuit> StandardCircuit::unique_witness(Circ
}
return { c1, c2 };
}
}; // namespace smt_circuit
}; // namespace smt_circuit
Loading
Loading