Skip to content

Commit

Permalink
Consolidate permutation mapping computation into one method (AztecPro…
Browse files Browse the repository at this point in the history
  • Loading branch information
ledwards2225 authored Apr 7, 2023
1 parent 646cf49 commit ea7df26
Showing 1 changed file with 73 additions and 102 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ struct permutation_subgroup_element {
bool is_public_input = false;
bool is_tag = false;
};

template <size_t program_width> struct PermutationMapping {
using Mapping = std::array<std::vector<permutation_subgroup_element>, program_width>;
Mapping sigmas;
Mapping ids;
};

using CyclicPermutation = std::vector<cycle_node>;

namespace {
Expand Down Expand Up @@ -120,63 +127,98 @@ std::vector<CyclicPermutation> compute_wire_copy_cycles(const CircuitConstructor
}

/**
* @brief Compute the permutation mapping for the basic no-tags case
* @brief Compute the traditional or generalized permutation mapping
*
* @details This function computes the permutation information in a commonf format that can then be used to generate
* either Plonk-style FFT-ready sigma polynomials or Honk-style indexed vectors
* @details Computes the mappings from which the sigma polynomials (and conditionally, the id polynomials)
* can be computed. The output is proving system agnostic.
*
* @tparam program_width The number of wires
* @tparam generalized (bool) Triggers use of gen perm tags and computation of id mappings when true
* @tparam CircuitConstructor The class that holds basic circuitl ogic
* @param circuit_constructor Circuit-containing object
* @param key Pointer to the proving key
* @return PermutationMapping sigma mapping (and id mapping if generalized == true)
*/
template <size_t program_width, typename CircuitConstructor>
std::array<std::vector<permutation_subgroup_element>, program_width> compute_basic_proof_system_sigma_permutations(
const CircuitConstructor& circuit_constructor, plonk::proving_key* key)
template <size_t program_width, bool generalized, typename CircuitConstructor>
PermutationMapping<program_width> compute_permutation_mapping(const CircuitConstructor& circuit_constructor,
plonk::proving_key* key)
{
// Compute wire copy cycles (cycles of permutations)
auto wire_copy_cycles = compute_wire_copy_cycles<program_width>(circuit_constructor);

PermutationMapping<program_width> mapping;

// Initialize the table of permutations so that every element points to itself
std::array<std::vector<permutation_subgroup_element>, program_width> sigma_mappings;
for (size_t i = 0; i < program_width; ++i) {
sigma_mappings[i].reserve(key->circuit_size);
mapping.sigmas[i].reserve(key->circuit_size);
if (generalized) {
mapping.ids[i].reserve(key->circuit_size);
}

for (size_t j = 0; j < key->circuit_size; ++j) {
sigma_mappings[i].emplace_back(permutation_subgroup_element{
mapping.sigmas[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });
if (generalized) {
mapping.ids[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });
}
}
}

// Represents the index of a variable in circuit_constructor.variables (needed only for generalized)
std::span<const uint32_t> real_variable_tags = circuit_constructor.real_variable_tags;

// Go through each cycle
for (size_t i = 0; i < wire_copy_cycles.size(); ++i) {
for (size_t j = 0; j < wire_copy_cycles[i].size(); ++j) {
// Get the indices of the current node and next node int he cycle
cycle_node current_cycle_node = wire_copy_cycles[i][j];
size_t cycle_index = 0;
for (auto& copy_cycle : wire_copy_cycles) {
for (size_t node_idx = 0; node_idx < copy_cycle.size(); ++node_idx) {
// Get the indices of the current node and next node in the cycle
cycle_node current_cycle_node = copy_cycle[node_idx];
// If current node is the last one in the cycle, then the next one is the first one
size_t next_cycle_node_index = j == wire_copy_cycles[i].size() - 1 ? 0 : j + 1;
cycle_node next_cycle_node = wire_copy_cycles[i][next_cycle_node_index];
size_t next_cycle_node_index = (node_idx == copy_cycle.size() - 1 ? 0 : node_idx + 1);
cycle_node next_cycle_node = copy_cycle[next_cycle_node_index];
const auto current_row = current_cycle_node.gate_index;
const auto next_row = next_cycle_node.gate_index;

const uint32_t current_column = current_cycle_node.wire_index;
const uint32_t next_column = next_cycle_node.wire_index;
const auto current_column = current_cycle_node.wire_index;
const auto next_column = static_cast<uint8_t>(next_cycle_node.wire_index);
// Point current node to the next node
sigma_mappings[current_column][current_row] = {
.row_index = next_row, .column_index = (uint8_t)next_column, .is_public_input = false, .is_tag = false
mapping.sigmas[current_column][current_row] = {
.row_index = next_row, .column_index = next_column, .is_public_input = false, .is_tag = false
};

if (generalized) {
bool first_node = (node_idx == 0);
bool last_node = (next_cycle_node_index == 0);

if (first_node) {
mapping.ids[current_column][current_row].is_tag = true;
mapping.ids[current_column][current_row].row_index = (real_variable_tags[cycle_index]);
}
if (last_node) {
mapping.sigmas[current_column][current_row].is_tag = true;

// TODO(Zac): yikes, std::maps (tau) are expensive. Can we find a way to get rid of this?
mapping.sigmas[current_column][current_row].row_index =
circuit_constructor.tau.at(real_variable_tags[cycle_index]);
}
}
}
cycle_index++;
}

// Add informationa about public inputs to the computation
// Add information about public inputs to the computation
const uint32_t num_public_inputs = static_cast<uint32_t>(circuit_constructor.public_inputs.size());

for (size_t i = 0; i < num_public_inputs; ++i) {
sigma_mappings[0][i].row_index = static_cast<uint32_t>(i);
sigma_mappings[0][i].column_index = 0;
sigma_mappings[0][i].is_public_input = true;
mapping.sigmas[0][i].row_index = static_cast<uint32_t>(i);
mapping.sigmas[0][i].column_index = 0;
mapping.sigmas[0][i].is_public_input = true;
if (mapping.sigmas[0][i].is_tag) {
std::cerr << "MAPPING IS BOTH A TAG AND A PUBLIC INPUT" << std::endl;
}
}
return sigma_mappings;
return mapping;
}

/**
Expand Down Expand Up @@ -403,9 +445,9 @@ template <size_t program_width, typename CircuitConstructor>
void compute_standard_honk_sigma_permutations(CircuitConstructor& circuit_constructor, plonk::proving_key* key)
{
// Compute the permutation table specifying which element becomes which
auto sigma_mappings = compute_basic_proof_system_sigma_permutations<program_width>(circuit_constructor, key);
auto mapping = compute_permutation_mapping<program_width, false>(circuit_constructor, key);
// Compute Honk-style sigma polynomial fromt the permutation table
compute_honk_style_sigma_lagrange_polynomials_from_mapping(sigma_mappings, key);
compute_honk_style_sigma_lagrange_polynomials_from_mapping(mapping.sigmas, key);
}

/**
Expand All @@ -420,9 +462,9 @@ template <size_t program_width, typename CircuitConstructor>
void compute_standard_plonk_sigma_permutations(CircuitConstructor& circuit_constructor, plonk::proving_key* key)
{
// Compute the permutation table specifying which element becomes which
auto sigma_mappings = compute_basic_proof_system_sigma_permutations<program_width>(circuit_constructor, key);
auto mapping = compute_permutation_mapping<program_width, false>(circuit_constructor, key);
// Compute Plonk-style sigma polynomials from the mapping
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", sigma_mappings, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", mapping.sigmas, key);
// Compute their monomial and coset versions
compute_monomial_and_coset_fft_polynomials_from_lagrange<program_width>("sigma", key);
}
Expand Down Expand Up @@ -456,82 +498,11 @@ template <size_t program_width, typename CircuitConstructor>
void compute_plonk_generalized_sigma_permutations(const CircuitConstructor& circuit_constructor,
plonk::proving_key* key)
{
// Compute wire copy cycles for public and private variables
auto wire_copy_cycles = compute_wire_copy_cycles<program_width>(circuit_constructor);
std::array<std::vector<permutation_subgroup_element>, program_width> sigma_mappings;
std::array<std::vector<permutation_subgroup_element>, program_width> id_mappings;

// Instantiate the sigma and id mappings by reserving enough space and pushing 'default' permutation subgroup
// elements that point to themselves.
for (size_t i = 0; i < program_width; ++i) {
sigma_mappings[i].reserve(key->circuit_size);
id_mappings[i].reserve(key->circuit_size);
}
for (size_t i = 0; i < program_width; ++i) {
for (size_t j = 0; j < key->circuit_size; ++j) {
sigma_mappings[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });

id_mappings[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });
}
}

// // Represents the index of a variable in circuit_constructor.variables
std::span<const uint32_t> real_variable_tags = circuit_constructor.real_variable_tags;
// const std::map<uint32_t, uint32_t>& tau = circuit_constructor.tau;

// Go through all wire cycles and update sigma and id mappings to point to the next element
// within each cycle as well as set the appropriate tags
for (size_t i = 0; i < wire_copy_cycles.size(); ++i) {
for (size_t j = 0; j < wire_copy_cycles[i].size(); ++j) {
cycle_node current_cycle_node = wire_copy_cycles[i][j];
size_t next_cycle_node_index = j == wire_copy_cycles[i].size() - 1 ? 0 : j + 1;
cycle_node next_cycle_node = wire_copy_cycles[i][next_cycle_node_index];
const auto current_row = current_cycle_node.gate_index;
const auto next_row = next_cycle_node.gate_index;

const uint32_t current_column = current_cycle_node.wire_index;
const uint32_t next_column = next_cycle_node.wire_index;

sigma_mappings[current_column][current_row] = {
.row_index = next_row, .column_index = (uint8_t)next_column, .is_public_input = false, .is_tag = false
};

bool first_node, last_node;

first_node = j == 0;
last_node = next_cycle_node_index == 0;
if (first_node) {
id_mappings[current_column][current_row].is_tag = true;
id_mappings[current_column][current_row].row_index = (real_variable_tags[i]);
}
if (last_node) {
sigma_mappings[current_column][current_row].is_tag = true;

// TODO: yikes, std::maps are expensive. Can we find a way to get rid of this?
sigma_mappings[current_column][current_row].row_index =
circuit_constructor.tau.at(real_variable_tags[i]);
}
}
}

const uint32_t num_public_inputs = static_cast<uint32_t>(circuit_constructor.public_inputs.size());

// This corresponds in the paper to modifying sigma to sigma' with the zeta_i values; this enforces public input
// consistency
for (size_t i = 0; i < num_public_inputs; ++i) {
sigma_mappings[0][i].row_index = static_cast<uint32_t>(i);
sigma_mappings[0][i].column_index = 0;
sigma_mappings[0][i].is_public_input = true;
if (sigma_mappings[0][i].is_tag) {
std::cerr << "MAPPING IS BOTH A TAG AND A PUBLIC INPUT" << std::endl;
}
}
auto mapping = compute_permutation_mapping<program_width, true>(circuit_constructor, key);

// Compute Plonk-style sigma and ID polynomials from the corresponding mappings
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", sigma_mappings, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("id", id_mappings, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", mapping.sigmas, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("id", mapping.ids, key);
// Compute the monomial and coset-ffts for sigmas and IDs
compute_monomial_and_coset_fft_polynomials_from_lagrange<program_width>("sigma", key);
compute_monomial_and_coset_fft_polynomials_from_lagrange<program_width>("id", key);
Expand Down

0 comments on commit ea7df26

Please sign in to comment.