Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Consolidate permutation mapping computation into one method #330

Merged
merged 4 commits into from
Apr 7, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 73 additions & 102 deletions cpp/src/barretenberg/proof_system/composer/permutation_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ struct permutation_subgroup_element {
bool is_public_input = false;
bool is_tag = false;
};

template <size_t program_width> struct PermutationMapping {
using Mapping = std::array<std::vector<permutation_subgroup_element>, program_width>;
Mapping sigmas;
Mapping ids;
};

using CyclicPermutation = std::vector<cycle_node>;

namespace {
Expand Down Expand Up @@ -120,63 +127,98 @@ std::vector<CyclicPermutation> compute_wire_copy_cycles(const CircuitConstructor
}

/**
* @brief Compute the permutation mapping for the basic no-tags case
* @brief Compute the traditional or generalized permutation mapping
*
* @details This function computes the permutation information in a commonf format that can then be used to generate
* either Plonk-style FFT-ready sigma polynomials or Honk-style indexed vectors
* @details Computes the mappings from which the sigma polynomials (and conditionally, the id polynomials)
* can be computed. The output is proving system agnostic.
*
* @tparam program_width The number of wires
* @tparam generalized (bool) Triggers use of gen perm tags and computation of id mappings when true
* @tparam CircuitConstructor The class that holds basic circuitl ogic
* @param circuit_constructor Circuit-containing object
* @param key Pointer to the proving key
* @return PermutationMapping sigma mapping (and id mapping if generalized == true)
*/
template <size_t program_width, typename CircuitConstructor>
std::array<std::vector<permutation_subgroup_element>, program_width> compute_basic_proof_system_sigma_permutations(
const CircuitConstructor& circuit_constructor, plonk::proving_key* key)
template <size_t program_width, bool generalized, typename CircuitConstructor>
PermutationMapping<program_width> compute_permutation_mapping(const CircuitConstructor& circuit_constructor,
plonk::proving_key* key)
{
// Compute wire copy cycles (cycles of permutations)
auto wire_copy_cycles = compute_wire_copy_cycles<program_width>(circuit_constructor);

PermutationMapping<program_width> mapping;

// Initialize the table of permutations so that every element points to itself
std::array<std::vector<permutation_subgroup_element>, program_width> sigma_mappings;
for (size_t i = 0; i < program_width; ++i) {
sigma_mappings[i].reserve(key->circuit_size);
mapping.sigmas[i].reserve(key->circuit_size);
if (generalized) {
mapping.ids[i].reserve(key->circuit_size);
}

for (size_t j = 0; j < key->circuit_size; ++j) {
sigma_mappings[i].emplace_back(permutation_subgroup_element{
mapping.sigmas[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });
if (generalized) {
mapping.ids[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });
}
}
}

// Represents the index of a variable in circuit_constructor.variables (needed only for generalized)
std::span<const uint32_t> real_variable_tags = circuit_constructor.real_variable_tags;

// Go through each cycle
for (size_t i = 0; i < wire_copy_cycles.size(); ++i) {
for (size_t j = 0; j < wire_copy_cycles[i].size(); ++j) {
// Get the indices of the current node and next node int he cycle
cycle_node current_cycle_node = wire_copy_cycles[i][j];
size_t cycle_index = 0;
for (auto& copy_cycle : wire_copy_cycles) {
ledwards2225 marked this conversation as resolved.
Show resolved Hide resolved
for (size_t node_idx = 0; node_idx < copy_cycle.size(); ++node_idx) {
// Get the indices of the current node and next node in the cycle
cycle_node current_cycle_node = copy_cycle[node_idx];
// If current node is the last one in the cycle, then the next one is the first one
size_t next_cycle_node_index = j == wire_copy_cycles[i].size() - 1 ? 0 : j + 1;
cycle_node next_cycle_node = wire_copy_cycles[i][next_cycle_node_index];
size_t next_cycle_node_index = (node_idx == copy_cycle.size() - 1 ? 0 : node_idx + 1);
cycle_node next_cycle_node = copy_cycle[next_cycle_node_index];
const auto current_row = current_cycle_node.gate_index;
const auto next_row = next_cycle_node.gate_index;

const uint32_t current_column = current_cycle_node.wire_index;
const uint32_t next_column = next_cycle_node.wire_index;
const auto current_column = current_cycle_node.wire_index;
const auto next_column = static_cast<uint8_t>(next_cycle_node.wire_index);
// Point current node to the next node
sigma_mappings[current_column][current_row] = {
.row_index = next_row, .column_index = (uint8_t)next_column, .is_public_input = false, .is_tag = false
mapping.sigmas[current_column][current_row] = {
.row_index = next_row, .column_index = next_column, .is_public_input = false, .is_tag = false
};

if (generalized) {
bool first_node = (node_idx == 0);
bool last_node = (next_cycle_node_index == 0);

if (first_node) {
mapping.ids[current_column][current_row].is_tag = true;
mapping.ids[current_column][current_row].row_index = (real_variable_tags[cycle_index]);
}
if (last_node) {
mapping.sigmas[current_column][current_row].is_tag = true;

// TODO(Zac): yikes, std::maps (tau) are expensive. Can we find a way to get rid of this?
mapping.sigmas[current_column][current_row].row_index =
circuit_constructor.tau.at(real_variable_tags[cycle_index]);
}
}
}
cycle_index++;
}

// Add informationa about public inputs to the computation
// Add information about public inputs to the computation
const uint32_t num_public_inputs = static_cast<uint32_t>(circuit_constructor.public_inputs.size());

for (size_t i = 0; i < num_public_inputs; ++i) {
sigma_mappings[0][i].row_index = static_cast<uint32_t>(i);
sigma_mappings[0][i].column_index = 0;
sigma_mappings[0][i].is_public_input = true;
mapping.sigmas[0][i].row_index = static_cast<uint32_t>(i);
mapping.sigmas[0][i].column_index = 0;
mapping.sigmas[0][i].is_public_input = true;
if (mapping.sigmas[0][i].is_tag) {
std::cerr << "MAPPING IS BOTH A TAG AND A PUBLIC INPUT" << std::endl;
}
}
return sigma_mappings;
return mapping;
}

/**
Expand Down Expand Up @@ -403,9 +445,9 @@ template <size_t program_width, typename CircuitConstructor>
void compute_standard_honk_sigma_permutations(CircuitConstructor& circuit_constructor, plonk::proving_key* key)
{
// Compute the permutation table specifying which element becomes which
auto sigma_mappings = compute_basic_proof_system_sigma_permutations<program_width>(circuit_constructor, key);
auto mapping = compute_permutation_mapping<program_width, false>(circuit_constructor, key);
// Compute Honk-style sigma polynomial fromt the permutation table
compute_honk_style_sigma_lagrange_polynomials_from_mapping(sigma_mappings, key);
compute_honk_style_sigma_lagrange_polynomials_from_mapping(mapping.sigmas, key);
}

/**
Expand All @@ -420,9 +462,9 @@ template <size_t program_width, typename CircuitConstructor>
void compute_standard_plonk_sigma_permutations(CircuitConstructor& circuit_constructor, plonk::proving_key* key)
{
// Compute the permutation table specifying which element becomes which
auto sigma_mappings = compute_basic_proof_system_sigma_permutations<program_width>(circuit_constructor, key);
auto mapping = compute_permutation_mapping<program_width, false>(circuit_constructor, key);
// Compute Plonk-style sigma polynomials from the mapping
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", sigma_mappings, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", mapping.sigmas, key);
// Compute their monomial and coset versions
compute_monomial_and_coset_fft_polynomials_from_lagrange<program_width>("sigma", key);
}
Expand Down Expand Up @@ -456,82 +498,11 @@ template <size_t program_width, typename CircuitConstructor>
void compute_plonk_generalized_sigma_permutations(const CircuitConstructor& circuit_constructor,
plonk::proving_key* key)
{
// Compute wire copy cycles for public and private variables
auto wire_copy_cycles = compute_wire_copy_cycles<program_width>(circuit_constructor);
std::array<std::vector<permutation_subgroup_element>, program_width> sigma_mappings;
std::array<std::vector<permutation_subgroup_element>, program_width> id_mappings;

// Instantiate the sigma and id mappings by reserving enough space and pushing 'default' permutation subgroup
// elements that point to themselves.
for (size_t i = 0; i < program_width; ++i) {
sigma_mappings[i].reserve(key->circuit_size);
id_mappings[i].reserve(key->circuit_size);
}
for (size_t i = 0; i < program_width; ++i) {
for (size_t j = 0; j < key->circuit_size; ++j) {
sigma_mappings[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });

id_mappings[i].emplace_back(permutation_subgroup_element{
.row_index = (uint32_t)j, .column_index = (uint8_t)i, .is_public_input = false, .is_tag = false });
}
}

// // Represents the index of a variable in circuit_constructor.variables
std::span<const uint32_t> real_variable_tags = circuit_constructor.real_variable_tags;
// const std::map<uint32_t, uint32_t>& tau = circuit_constructor.tau;

// Go through all wire cycles and update sigma and id mappings to point to the next element
// within each cycle as well as set the appropriate tags
for (size_t i = 0; i < wire_copy_cycles.size(); ++i) {
for (size_t j = 0; j < wire_copy_cycles[i].size(); ++j) {
cycle_node current_cycle_node = wire_copy_cycles[i][j];
size_t next_cycle_node_index = j == wire_copy_cycles[i].size() - 1 ? 0 : j + 1;
cycle_node next_cycle_node = wire_copy_cycles[i][next_cycle_node_index];
const auto current_row = current_cycle_node.gate_index;
const auto next_row = next_cycle_node.gate_index;

const uint32_t current_column = current_cycle_node.wire_index;
const uint32_t next_column = next_cycle_node.wire_index;

sigma_mappings[current_column][current_row] = {
.row_index = next_row, .column_index = (uint8_t)next_column, .is_public_input = false, .is_tag = false
};

bool first_node, last_node;

first_node = j == 0;
last_node = next_cycle_node_index == 0;
if (first_node) {
id_mappings[current_column][current_row].is_tag = true;
id_mappings[current_column][current_row].row_index = (real_variable_tags[i]);
}
if (last_node) {
sigma_mappings[current_column][current_row].is_tag = true;

// TODO: yikes, std::maps are expensive. Can we find a way to get rid of this?
sigma_mappings[current_column][current_row].row_index =
circuit_constructor.tau.at(real_variable_tags[i]);
}
}
}

const uint32_t num_public_inputs = static_cast<uint32_t>(circuit_constructor.public_inputs.size());

// This corresponds in the paper to modifying sigma to sigma' with the zeta_i values; this enforces public input
// consistency
for (size_t i = 0; i < num_public_inputs; ++i) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where did this logic go?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This takes place in the final lines of compute_permutation_mapping

sigma_mappings[0][i].row_index = static_cast<uint32_t>(i);
sigma_mappings[0][i].column_index = 0;
sigma_mappings[0][i].is_public_input = true;
if (sigma_mappings[0][i].is_tag) {
std::cerr << "MAPPING IS BOTH A TAG AND A PUBLIC INPUT" << std::endl;
}
}
auto mapping = compute_permutation_mapping<program_width, true>(circuit_constructor, key);

// Compute Plonk-style sigma and ID polynomials from the corresponding mappings
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", sigma_mappings, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("id", id_mappings, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("sigma", mapping.sigmas, key);
compute_plonk_permutation_lagrange_polynomials_from_mapping("id", mapping.ids, key);
// Compute the monomial and coset-ffts for sigmas and IDs
compute_monomial_and_coset_fft_polynomials_from_lagrange<program_width>("sigma", key);
compute_monomial_and_coset_fft_polynomials_from_lagrange<program_width>("id", key);
Expand Down