Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: faster square roots #2694

Merged
merged 5 commits into from
Oct 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,17 @@ TEST(secp256k1, TestSqr)
}
}

/**
 * @brief Checks that sqrt() succeeds on a random square in secp256k1::fq.
 *
 * The input is produced by squaring a random element, so it is a quadratic residue by construction.
 * sqrt() must therefore report success through its boolean flag, and the returned root must square
 * back to the input.
 */
TEST(secp256k1, SqrtRandom)
{
    constexpr size_t num_iterations = 1;
    for (size_t i = 0; i < num_iterations; ++i) {
        secp256k1::fq input = secp256k1::fq::random_element().sqr();
        auto [is_sqr, root] = input.sqrt();
        // The flag is part of the contract: a square input must be reported as a square.
        EXPECT_TRUE(is_sqr);
        EXPECT_EQ(root.sqr(), input);
    }
}

TEST(secp256k1, TestArithmetic)
{
secp256k1::fq a = secp256k1::fq::random_element();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,8 +330,10 @@ template <class Params_> struct alignas(32) field {
*
* @return <true, root> if the element is a quadratic residue, <false, 0> if it's not
*/
constexpr std::pair<bool, field> sqrt() const noexcept;

constexpr std::pair<bool, field> sqrt() const noexcept
requires((Params_::modulus_0 & 0x3UL) == 0x3UL);
constexpr std::pair<bool, field> sqrt() const noexcept
requires((Params_::modulus_0 & 0x3UL) != 0x3UL);
BB_INLINE constexpr void self_neg() & noexcept;

BB_INLINE constexpr void self_to_montgomery_form() & noexcept;
Expand Down
236 changes: 161 additions & 75 deletions barretenberg/cpp/src/barretenberg/ecc/fields/field_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -452,102 +452,187 @@ template <class T> void field<T>::batch_invert(std::span<field> coeffs) noexcept
}
}

/**
* @brief Implements an optimised variant of Tonelli-Shanks via lookup tables.
* Algorithm taken from https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf
* "FASTER SQUARE ROOTS IN ANNOYING FINITE FIELDS" by D. Bernstein
* Page 5 "Accelerated Discrete Logarithm"
* @tparam T
* @return constexpr field<T>
*/
template <class T> constexpr field<T> field<T>::tonelli_shanks_sqrt() const noexcept
{
BB_OP_COUNT_TRACK_NAME("fr::tonelli_shanks_sqrt");
// Tonelli-shanks algorithm begins by finding a field element Q and integer S,
// such that (p - 1) = Q.2^{s}

// We can compute the square root of a, by considering a^{(Q + 1) / 2} = R
// Once we have found such an R, we have
// R^{2} = a^{Q + 1} = a^{Q}a
// If a^{Q} = 1, we have found our square root.
// Otherwise, we have a^{Q} = t, where t is a 2^{s-1}'th root of unity.
// This is because t^{2^{s-1}} = a^{Q.2^{s-1}}.
// We know that (p - 1) = Q.2^{s}, therefore t^{2^{s-1}} = a^{(p - 1) / 2}
// From Euler's criterion, if a is a quadratic residue, a^{(p - 1) / 2} = 1
// i.e. t^{2^{s-1}} = 1

// To proceed with computing our square root, we want to transform t into a smaller subgroup,
// specifically, the 2^{s-2}'th roots of unity.
// We do this by finding some value b,such that
// (t.b^2)^{2^{s-2}} = 1 and R' = R.b
// Finding such a b is trivial, because from Euler's criterion, we know that,
// for any quadratic non-residue z, z^{(p - 1) / 2} = -1
// i.e. z^{Q.2^{s-1}} = -1
// => z^Q is a 2^{s-1}'th root of -1
// => (z^Q)^2 = z^{2Q} is a 2^{s-2}'th root of -1
// Since t^{2^{s-1}} = 1, t^{2^{s-2}} is either 1 or -1; when it is -1,
// t.z^{2Q} is a 2^{s-2}'th root of unity.

// We can iteratively transform t into ever smaller subgroups, until t = 1.
// At each iteration, we need to find a new value for b, which we can obtain
// by repeatedly squaring z^{Q}
constexpr uint256_t Q = (modulus - 1) >> static_cast<uint64_t>(primitive_root_log_size() - 1);
constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 2;

// __to_montgomery_form(Q_minus_one_over_two, Q_minus_one_over_two);
field z = coset_generator(0); // the generator is a non-residue
field b = pow(Q_minus_one_over_two);
field r = operator*(b); // r = a^{(Q + 1) / 2}
field t = r * b; // t = a^{(Q - 1) / 2 + (Q + 1) / 2} = a^{Q}
// We can determine s by counting the least significant set bit of `p - 1`
// We pick elements `r, g` such that g = r^Q and r is not a square.
// (the coset generators are all nonresidues and satisfy this condition)
//
// To find the square root of `u`, consider `v = u^((Q - 1) / 2)`
// There exists an integer `e` where uv^2 = g^e (see Theorem 3.1 in paper).
// If `u` is a square, `e` is even and (uvg^{-e/2})^2 = u^2v^2g^{-e} = u^{Q+1}g^{-e} = u
//
// The goal of the algorithm is two fold:
// 1. find `e` given `u`
// 2. compute `sqrt(u) = uvg^{−e/2}`
constexpr uint256_t Q = (modulus - 1) >> static_cast<uint64_t>(primitive_root_log_size());
constexpr uint256_t Q_minus_one_over_two = (Q - 1) >> 1;
field v = pow(Q_minus_one_over_two);
field uv = operator*(v); // uv = u^{(Q + 1) / 2}
// uvv = g^e for some unknown e. Goal is to find e.
field uvv = uv * v; // uvv = u^{(Q - 1) / 2 + (Q + 1) / 2} = u^{Q}

// check if t is a square with euler's criterion
// if not, we don't have a quadratic residue and a has no square root!
field check = t;
field check = uvv;
for (size_t i = 0; i < primitive_root_log_size() - 1; ++i) {
check.self_sqr();
}
if (check != one()) {
return zero();
if (check != 1) {
return 0;
}
field t1 = z.pow(Q_minus_one_over_two);
field t2 = t1 * z;
field c = t2 * t1; // z^Q

size_t m = primitive_root_log_size();
constexpr field g = coset_generator(0).pow(Q);
constexpr field g_inv = coset_generator(0).pow(modulus - 1 - Q);
constexpr size_t root_bits = primitive_root_log_size();
constexpr size_t table_bits = 6;
constexpr size_t num_tables = root_bits / table_bits + (root_bits % table_bits != 0 ? 1 : 0);
constexpr size_t num_offset_tables = num_tables - 1;
constexpr size_t table_size = static_cast<size_t>(1UL) << table_bits;

using GTable = std::array<field, table_size>;
constexpr auto get_g_table = [&](const field& h) {
GTable result;
result[0] = 1;
for (size_t i = 1; i < table_size; ++i) {
result[i] = result[i - 1] * h;
}
return result;
};
constexpr std::array<GTable, num_tables> g_tables = [&]() {
field working_base = g_inv;
std::array<GTable, num_tables> result;
for (size_t i = 0; i < num_tables; ++i) {
result[i] = get_g_table(working_base);
for (size_t j = 0; j < table_bits; ++j) {
working_base.self_sqr();
}
}
return result;
}();
constexpr std::array<GTable, num_offset_tables> offset_g_tables = [&]() {
field working_base = g_inv;
for (size_t i = 0; i < root_bits % table_bits; ++i) {
working_base.self_sqr();
}
std::array<GTable, num_offset_tables> result;
for (size_t i = 0; i < num_offset_tables; ++i) {
result[i] = get_g_table(working_base);
for (size_t j = 0; j < table_bits; ++j) {
working_base.self_sqr();
}
}
return result;
}();

constexpr GTable root_table_a = get_g_table(g.pow(1UL << ((num_tables - 1) * table_bits)));
constexpr GTable root_table_b = get_g_table(g.pow(1UL << (root_bits - table_bits)));
// compute uvv, uvv^{2^table_bits}, uvv^{2^{table_bits*2}}, ..., uvv^{2^{table_bits*(num_tables-1)}}
std::array<field, num_tables> uvv_powers;
field base = uvv;
for (size_t i = 0; i < num_tables - 1; ++i) {
uvv_powers[i] = base;
for (size_t j = 0; j < table_bits; ++j) {
base.self_sqr();
}
}
uvv_powers[num_tables - 1] = base;
std::array<size_t, num_tables> e_slices;
for (size_t i = 0; i < num_tables; ++i) {
size_t table_index = num_tables - 1 - i;
field target = uvv_powers[table_index];
for (size_t j = 0; j < i; ++j) {
size_t e_idx = num_tables - 1 - (i - 1) + j;
size_t g_idx = num_tables - 2 - j;

field g_lookup;
if (j != i - 1) {
g_lookup = offset_g_tables[g_idx - 1][e_slices[e_idx]]; // e1
} else {
g_lookup = g_tables[g_idx][e_slices[e_idx]];
}
target *= g_lookup;
}
size_t count = 0;

if (i == 0) {
for (auto& x : root_table_a) {
if (x == target) {
break;
}
count += 1;
}
} else {
for (auto& x : root_table_b) {
if (x == target) {
break;
}
count += 1;
}
}

while (t != one()) {
size_t i = 0;
field t2m = t;
ASSERT(count != table_size);
e_slices[table_index] = count;
}

// find the smallest value of m, such that t^{2^m} = 1
while (t2m != one()) {
t2m.self_sqr();
i += 1;
// We want to compute g^{-e/2} which requires computing `e/2` via our slice representation
for (size_t i = 0; i < num_tables; ++i) {
auto& e_slice = e_slices[num_tables - 1 - i];
// e_slices[num_tables - 1] is always even.
// From theorem 3.1 (https://cr.yp.to/papers/sqroot-20011123-retypeset20220327.pdf)
// if slice is odd, propagate the downshifted bit into previous slice value
if ((e_slice & 1UL) == 1UL) {
size_t borrow_value = (i == 1) ? 1UL << ((root_bits % table_bits) - 1) : (1UL << (table_bits - 1));
e_slices[num_tables - i] += borrow_value;
}
e_slice >>= 1;
}

size_t j = m - i - 1;
b = c;
while (j > 0) {
b.self_sqr();
--j;
} // b = z^2^(m-i-1)

c = b.sqr();
t = t * c;
r = r * b;
m = i;
field g_pow_minus_e_over_2 = 1;
for (size_t i = 0; i < num_tables; ++i) {
if (i == 0) {
g_pow_minus_e_over_2 *= g_tables[i][e_slices[num_tables - 1 - i]];
} else {
g_pow_minus_e_over_2 *= offset_g_tables[i - 1][e_slices[num_tables - 1 - i]];
}
}
return r;
return uv * g_pow_minus_e_over_2;
}

template <class T> constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
template <class T>
constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
requires((T::modulus_0 & 0x3UL) == 0x3UL)
{
BB_OP_COUNT_TRACK_NAME("fr::sqrt");
field root;
if constexpr ((T::modulus_0 & 0x3UL) == 0x3UL) {
constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2;
root = pow(sqrt_exponent);
} else {
root = tonelli_shanks_sqrt();
}
constexpr uint256_t sqrt_exponent = (modulus + uint256_t(1)) >> 2;
field root = pow(sqrt_exponent);
if ((root * root) == (*this)) {
return std::pair<bool, field>(true, root);
}
return std::pair<bool, field>(false, field::zero());
}

} // namespace bb;
/**
 * @brief Square root for fields where the low two bits of T::modulus_0 are not both set.
 *
 * Obtains a candidate root from tonelli_shanks_sqrt() and accepts it only if squaring it reproduces
 * this element.
 *
 * @return <true, root> when the candidate squares back to this element, <false, 0> otherwise
 */
template <class T>
constexpr std::pair<bool, field<T>> field<T>::sqrt() const noexcept
    requires((T::modulus_0 & 0x3UL) != 0x3UL)
{
    const field candidate = tonelli_shanks_sqrt();
    // Verify the candidate explicitly rather than trusting the helper's result.
    const bool candidate_is_root = ((candidate * candidate) == (*this));
    if (!candidate_is_root) {
        return std::pair<bool, field>(false, field::zero());
    }
    return std::pair<bool, field>(true, candidate);
}

template <class T> constexpr field<T> field<T>::operator/(const field& other) const noexcept
{
Expand Down Expand Up @@ -634,8 +719,8 @@ constexpr std::array<field<T>, field<T>::COSET_GENERATOR_SIZE> field<T>::compute

size_t count = 1;
while (count < n) {
// work_variable contains a new field element, and we need to test that, for all previous vector elements,
// result[i] / work_variable is not a member of our subgroup
// work_variable contains a new field element, and we need to test that, for all previous vector
// elements, result[i] / work_variable is not a member of our subgroup
field work_inverse = work_variable.invert();
bool valid = true;
for (size_t j = 0; j < count; ++j) {
Expand Down Expand Up @@ -674,8 +759,9 @@ template <class Params> void field<Params>::msgpack_pack(auto& packer) const
// The field is first converted from Montgomery form, similar to how the old format did it.
auto adjusted = from_montgomery_form();

// The data is then converted to big endian format using htonll, which stands for "host to network long long".
// This is necessary because the data will be written to a raw msgpack buffer, which requires big endian format.
// The data is then converted to big endian format using htonll, which stands for "host to network long
// long". This is necessary because the data will be written to a raw msgpack buffer, which requires big
// endian format.
uint64_t bin_data[4] = {
htonll(adjusted.data[3]), htonll(adjusted.data[2]), htonll(adjusted.data[1]), htonll(adjusted.data[0])
};
Expand All @@ -693,8 +779,8 @@ template <class Params> void field<Params>::msgpack_unpack(auto o)
// The binary data is first extracted from the msgpack object.
std::array<uint8_t, sizeof(data)> raw_data = o;

// The binary data is then read as big endian uint64_t's. This is done by casting the raw data to uint64_t* and then
// using ntohll ("network to host long long") to correct the endianness to the host's endianness.
// The binary data is then read as big endian uint64_t's. This is done by casting the raw data to uint64_t*
// and then using ntohll ("network to host long long") to correct the endianness to the host's endianness.
uint64_t* cast_data = (uint64_t*)&raw_data[0]; // NOLINT
uint64_t reversed[] = { ntohll(cast_data[3]), ntohll(cast_data[2]), ntohll(cast_data[1]), ntohll(cast_data[0]) };

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ export class BarretenbergWasmMain extends BarretenbergWasmBase {
module: WebAssembly.Module,
threads = Math.min(getNumCpu(), BarretenbergWasmMain.MAX_THREADS),
logger: (msg: string) => void = debug,
initial = 30,
initial = 31,
maximum = 2 ** 16,
) {
this.logger = logger;
Expand Down
2 changes: 1 addition & 1 deletion yarn-project/foundation/src/wasm/wasm_module.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ export class WasmModule implements IWasmModule {
* @param initMethod - Defaults to calling '_initialize'.
* @param maximum - 8192 maximum by default. 512mb.
*/
public async init(initial = 30, maximum = 8192, initMethod: string | null = '_initialize') {
public async init(initial = 31, maximum = 8192, initMethod: string | null = '_initialize') {
this.debug(
`initial mem: ${initial} pages, ${(initial * 2 ** 16) / (1024 * 1024)}mb. max mem: ${maximum} pages, ${
(maximum * 2 ** 16) / (1024 * 1024)
Expand Down
Loading