From 0c7df67fdd5eed99e7715f31c4c54be3e408d34b Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Fri, 20 Sep 2024 15:50:31 +0300 Subject: [PATCH 01/22] montgomery mult with correctness --- icicle/include/icicle/fields/field.h | 33 ++++++++++++ icicle/include/icicle/fields/host_math.h | 34 +++++++++++++ .../fields/snark_fields/bls12_377_scalar.h | 1 + icicle/tests/test_field_api.cpp | 50 +++++++++++++++++++ 4 files changed, 118 insertions(+) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 85099e929..7a09295bf 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -122,6 +122,9 @@ class Field */ static constexpr HOST_DEVICE_INLINE ff_storage get_neg_modulus() { return CONFIG::neg_modulus; } + static constexpr HOST_DEVICE_INLINE ff_storage get_mont_inv_modulus() { return CONFIG::mont_inv_modulus; } + static constexpr HOST_DEVICE_INLINE ff_storage get_mont_r() { return CONFIG::montgomery_r; } + static constexpr HOST_DEVICE_INLINE ff_storage get_mont_r_inv() { return CONFIG::montgomery_r_inv; } /** * A new addition to the config file - the number of times to reduce in [reduce](@ref reduce) function. */ @@ -802,6 +805,29 @@ class Field return r; } + template + static constexpr HOST_DEVICE_INLINE Field mont_reduce(const Wide& xs) + { + // Field xs_lo = Wide::get_lower(xs); + // Field xs_hi = Wide::get_higher(xs); + // Wide l1 = {}; + // Wide l2 = {}; + // host_math::template multiply_raw(xs_lo.limbs_storage, get_m(), l1.limbs_storage); + // Field l1_lo = Wide::get_lower(l1); + // host_math::template multiply_raw(l1_lo.limbs_storage, get_modulus<1>(), l2.limbs_storage); + // Field l2_hi = Wide::get_higher(l2); + // Field r = {}; + // add_limbs(l2_hi.limbs_storage, xs_hi.limbs_storage, r.limbs_storage); + + Field r = Wide::get_lower(xs); + ff_storage r_reduced = {}; + uint64_t carry = 0; + carry = sub_limbs(r.limbs_storage, get_modulus<1>(), r_reduced); + if (carry == 0) r = Field{r_reduced}; + return r; + } + + HOST_DEVICE Field& operator=(Field const& other) { for (int i = 0; i < TLC; i++) { @@ -816,6 +842,13 @@ class Field return reduce(xy); // reduce mod p } + static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) + { + Wide r = {}; + host_math::multiply_mont_64(xs.limbs_storage.limbs64, ys.limbs_storage.limbs64, get_mont_inv_modulus().limbs64, get_modulus<1>().limbs64, r.limbs_storage.limbs64); + return mont_reduce(r); + } + friend HOST_DEVICE bool operator==(const Field& xs, const Field& ys) { #ifdef __CUDA_ARCH__ diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index e256aa922..c944dda7a 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -199,6 +199,40 @@ namespace host_math { } } + template + static HOST_INLINE void multiply_mont_64(const uint64_t* a, const uint64_t* b, const uint64_t* q, const uint64_t* p, uint64_t* r) + { + // printf("r0: "); + // for (unsigned i = 0; i < NLIMBS_B / 2; i++) { + // printf(" %lu,",r[i]); + // } + // printf("\n"); + for (unsigned i = 0; i < NLIMBS_B / 2; i++) { + // printf("i %d\n", i); + uint64_t A = 0, C = 0; + r[0] = host_math::madc_cc_64(a[0], b[i], r[0], A); + // printf("r0 %lu\n",r[0]); + // printf("q0 %lu\n",q[0]); + // printf("p0 %lu\n",p[0]); + // printf("A %lu\n",A); + uint64_t m = host_math::madc_cc_64(r[0], q[0], 0, C); //TODO - multiply inst + // printf("m %lu\n",m); + C = 0; + host_math::madc_cc_64(m, p[0], r[0], C); + // printf("c %lu\n",C); + for (unsigned j = 1; j < NLIMBS_A / 2; j++) { + r[j] = host_math::madc_cc_64(a[j], b[i], r[j], A); + r[j - 1] = host_math::madc_cc_64(m, p[j], r[j], C); + } + r[NLIMBS_A / 2 - 1] = C + A; + } + // printf("rf: "); + // for (unsigned i = 0; i < NLIMBS_B / 2; i++) { + // printf(" %lu,",r[i]); + // } + // printf("\n"); + } + template static HOST_INLINE void multiply_raw_64(const storage& as, const storage& bs, storage& rs) diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h index 54a56db69..9b88bc858 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h @@ -8,6 +8,7 @@ namespace bls12_377 { struct fp_config { static constexpr storage<8> modulus = {0x00000001, 0x0a118000, 0xd0000001, 0x59aa76fe, 0x5c37b001, 0x60b44d1e, 0x9a2ca556, 0x12ab655e}; + static constexpr storage<8> mont_inv_modulus = {0xffffffff, 0xa117fff, 0x90000001, 0x452217cc, 0x4790a000, 0x249765c3, 0x68b29556, 0x6992d0fa}; PARAMS(modulus) static constexpr storage<8> rou = {0xec2a895e, 0x476ef4a4, 0x63e3f04a, 0x9b506ee3, diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 9743c6d2d..4cbfab454 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -84,6 +84,56 @@ TYPED_TEST(FieldApiTest, FieldSanityTest) ASSERT_EQ(a * scalar_t::from(2), a + a); } +#ifndef EXT_FIELD +TYPED_TEST(FieldApiTest, FieldLimbsTypeSanityTest) +{ + // std::cout << "__cplusplus: " << __cplusplus << std::endl; + // for (int i = 0; i < 100000; i++) { + auto a = TypeParam::rand_host(); + auto b = TypeParam::rand_host(); + // auto b = a; + auto ar = TypeParam::to_montgomery(a); + auto br = TypeParam::to_montgomery(b); + auto rr = TypeParam::mont_mult(ar,br); + auto r = TypeParam::from_montgomery(rr); + // if (r != a*b){ + std::cout << "a: "<< a << std::endl; + std::cout << "b: "<< b << std::endl; + std::cout << "ar: "<< ar << std::endl; + std::cout << "br: "<< br << std::endl; + std::cout << "rr: "<< rr << std::endl; + std::cout << "r: "<< r << std::endl; + std::cout << "p: "<(a.limbs_storage, a.limbs_storage, r_wide.limbs_storage); + // a = TypeParam::reduce(r_wide); + ar = TypeParam::mont_mult(ar,ar); + // a = TypeParam::mont_reduce(r_wide); + // host_math::template multiply_raw(b.limbs_storage, b.limbs_storage, r2_wide.limbs_storage); + // host_math::template add_sub_limbs(a.limbs_storage, a.limbs_storage, a.limbs_storage); + // host_math::template add_sub_limbs(b.limbs_storage, b.limbs_storage, b.limbs_storage); + // a = TypeParam::Wide::get_lower(r_wide); + // b = TypeParam::Wide::get_lower(r2_wide); + // a = a + a; + a = a * a; + } + END_TIMER(MULT_sync, oss.str().c_str(), true); + ASSERT_EQ(TypeParam::from_montgomery(ar), a); +} +#endif + + TYPED_TEST(FieldApiTest, vectorOps) { const uint64_t N = 1 << 22; From 5f2923943dafb0a9c2c030f06a890323325f9963 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Mon, 30 Sep 2024 11:27:12 +0300 Subject: [PATCH 02/22] fix reduction condition --- icicle/include/icicle/fields/field.h | 4 ++++ icicle/tests/test_field_api.cpp | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 7a09295bf..951336282 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -820,6 +820,9 @@ class Field // add_limbs(l2_hi.limbs_storage, xs_hi.limbs_storage, r.limbs_storage); Field r = Wide::get_lower(xs); + Field p = Field{get_modulus<1>()}; + if (p.limbs_storage.limbs[TLC-1] > r.limbs_storage.limbs[TLC-1]) + return r; ff_storage r_reduced = {}; uint64_t carry = 0; carry = sub_limbs(r.limbs_storage, get_modulus<1>(), r_reduced); @@ -847,6 +850,7 @@ class Field Wide r = {}; host_math::multiply_mont_64(xs.limbs_storage.limbs64, ys.limbs_storage.limbs64, get_mont_inv_modulus().limbs64, get_modulus<1>().limbs64, r.limbs_storage.limbs64); return mont_reduce(r); + // return Wide::get_lower(r); } friend HOST_DEVICE bool operator==(const Field& xs, const Field& ys) diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 4cbfab454..7b8fdfcb7 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -96,7 +96,7 @@ TYPED_TEST(FieldApiTest, FieldLimbsTypeSanityTest) auto br = TypeParam::to_montgomery(b); auto rr = TypeParam::mont_mult(ar,br); auto r = TypeParam::from_montgomery(rr); - // if (r != a*b){ + if (r != a*b){ std::cout << "a: "<< a << std::endl; std::cout << "b: "<< b << std::endl; std::cout << "ar: "<< ar << std::endl; @@ -108,7 +108,7 @@ TYPED_TEST(FieldApiTest, FieldLimbsTypeSanityTest) std::cout << "R: "< Date: Tue, 1 Oct 2024 12:14:18 +0300 Subject: [PATCH 03/22] test performance --- icicle/backend/cpu/include/cpu_ntt.h | 2 +- icicle/backend/cpu/include/cpu_ntt_domain.h | 31 ++++++++++--------- icicle/include/icicle/fields/field.h | 1 + .../fields/snark_fields/bls12_377_base.h | 2 ++ icicle/tests/test_curve_api.cpp | 7 +++-- icicle/tests/test_field_api.cpp | 17 +++++----- 6 files changed, 33 insertions(+), 27 deletions(-) diff --git a/icicle/backend/cpu/include/cpu_ntt.h b/icicle/backend/cpu/include/cpu_ntt.h index c6724d912..5883180fe 100644 --- a/icicle/backend/cpu/include/cpu_ntt.h +++ b/icicle/backend/cpu/include/cpu_ntt.h @@ -54,7 +54,7 @@ namespace ntt_cpu { NttTaskCordinates ntt_task_cordinates = {0, 0, 0, 0, 0}; NttTasksManager ntt_tasks_manager(logn); const int nof_threads = std::thread::hardware_concurrency(); - auto tasks_manager = new TasksManager>(nof_threads - 1); + auto tasks_manager = new TasksManager>(1); NttTask* task_slot; std::unique_ptr arbitrary_coset = nullptr; const int coset_stride = ntt.find_or_generate_coset(arbitrary_coset); diff --git a/icicle/backend/cpu/include/cpu_ntt_domain.h b/icicle/backend/cpu/include/cpu_ntt_domain.h index 1b4534793..e368c8fe2 100644 --- a/icicle/backend/cpu/include/cpu_ntt_domain.h +++ b/icicle/backend/cpu/include/cpu_ntt_domain.h @@ -64,23 +64,24 @@ namespace ntt_cpu { if (s_ntt_domain.twiddles == nullptr) { // (2) build the domain - bool found_logn = false; - S omega = primitive_root; - const unsigned omegas_count = S::get_omegas_count(); - for (int i = 0; i < omegas_count; i++) { - omega = S::sqr(omega); - if (!found_logn) { - ++s_ntt_domain.max_log_size; - found_logn = omega == S::one(); - if (found_logn) break; - } - } + // bool found_logn = false; + // S omega = primitive_root; + // const unsigned omegas_count = S::get_omegas_count(); + // for (int i = 0; i < omegas_count; i++) { + // omega = S::sqr(omega); + // if (!found_logn) { + // ++s_ntt_domain.max_log_size; + // found_logn = omega == S::one(); + // if (found_logn) break; + // } + // } + s_ntt_domain.max_log_size = 21; s_ntt_domain.max_size = (int)pow(2, s_ntt_domain.max_log_size); - if (omega != S::one()) { - ICICLE_LOG_ERROR << "Primitive root provided to the InitDomain function is not a root-of-unity"; - return eIcicleError::INVALID_ARGUMENT; - } + // if (omega != S::one()) { + // ICICLE_LOG_ERROR << "Primitive root provided to the InitDomain function is not a root-of-unity"; + // return eIcicleError::INVALID_ARGUMENT; + // } // calculate twiddles // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 951336282..5cc72ee20 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -843,6 +843,7 @@ class Field { Wide xy = mul_wide(xs, ys); // full mult return reduce(xy); // reduce mod p + // return mont_mult(xs,ys); } static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h index ec540a7de..aee776fdd 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h @@ -7,6 +7,8 @@ namespace bls12_377 { struct fq_config { static constexpr storage<12> modulus = {0x00000001, 0x8508c000, 0x30000000, 0x170b5d44, 0xba094800, 0x1ef3622f, 0x00f5138f, 0x1a22d9f3, 0x6ca1493b, 0xc63b05c0, 0x17c510ea, 0x01ae3a46}; + static constexpr storage<12> mont_inv_modulus = {0xffffffff, 0x8508bfff, 0xa0000000, 0xd1e94577, 0x970debff, 0x35ed1347, 0xcced7a13, 0x5b245b86, 0x806a3cec, 0x22f80141, 0xeec82e3d, 0xbfa5205f}; + PARAMS(modulus) static constexpr storage<12> rou = {0xc563b9a1, 0x7eca603c, 0x06fe0bc3, 0x06df0a43, 0x0ddff8c6, 0xb44d994a, diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index 0769df7f9..b62436b0b 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -65,10 +65,10 @@ class CurveApiTest : public ::testing::Test template void MSM_test() { - const int logn = 12; - const int batch = 2; + const int logn = 20; + const int batch = 1; const int N = 1 << logn; - const int precompute_factor = 2; + const int precompute_factor = 1; const int total_nof_elemets = batch * N; auto scalars = std::make_unique(total_nof_elemets); @@ -84,6 +84,7 @@ class CurveApiTest : public ::testing::Test config.batch_size = batch; config.are_points_shared_in_batch = true; config.precompute_factor = precompute_factor; + config.c = 15; auto run = [&](const std::string& dev_type, P* result, const char* msg, bool measure, int iters) { Device dev = {dev_type, 0}; diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 7b8fdfcb7..874031905 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -126,7 +126,7 @@ TYPED_TEST(FieldApiTest, FieldLimbsTypeSanityTest) // a = TypeParam::Wide::get_lower(r_wide); // b = TypeParam::Wide::get_lower(r2_wide); // a = a + a; - a = a * a; + // a = a * a; } END_TIMER(MULT_sync, oss.str().c_str(), true); ASSERT_EQ(TypeParam::from_montgomery(ar), a); @@ -351,21 +351,21 @@ TYPED_TEST(FieldApiTest, ntt) int seed = time(0); srand(seed); - const bool inplace = rand() % 2; - const int logn = rand() % 16 + 3; + const bool inplace = 1; + const int logn = 15; const uint64_t N = 1 << logn; const int log_ntt_domain_size = logn + 1; - const int log_batch_size = rand() % 3; + const int log_batch_size = 0; const int batch_size = 1 << log_batch_size; - const Ordering ordering = static_cast(rand() % 4); + const Ordering ordering = static_cast(0); bool columns_batch; if (logn == 7 || logn < 4) { columns_batch = false; // currently not supported (icicle_v3/backend/cuda/src/ntt/ntt.cuh line 578) } else { - columns_batch = rand() % 2; + columns_batch = 0; } - const NTTDir dir = static_cast(rand() % 2); // 0: forward, 1: inverse - const int log_coset_stride = rand() % 3; + const NTTDir dir = static_cast(0); // 0: forward, 1: inverse + const int log_coset_stride = 0; scalar_t coset_gen; if (log_coset_stride) { coset_gen = scalar_t::omega(logn + log_coset_stride); @@ -399,6 +399,7 @@ TYPED_TEST(FieldApiTest, ntt) config.are_outputs_on_device = true; config.is_async = false; ICICLE_CHECK(ntt_init_domain(scalar_t::omega(log_ntt_domain_size), init_domain_config)); + // ntt_init_domain(scalar_t::omega(log_ntt_domain_size), init_domain_config); TypeParam *d_in, *d_out; ICICLE_CHECK(icicle_malloc_async((void**)&d_in, total_size * sizeof(TypeParam), config.stream)); ICICLE_CHECK(icicle_malloc_async((void**)&d_out, total_size * sizeof(TypeParam), config.stream)); From 779babee751962ac3bca18ff0cf4b1f417c08565 Mon Sep 17 00:00:00 2001 From: Yuval Shekel Date: Sun, 3 Nov 2024 11:31:45 +0200 Subject: [PATCH 04/22] tmp --- icicle/backend/cpu/include/ntt_cpu.h | 1 + icicle/include/icicle/fields/field.h | 284 +++++++++++++++++++++++ icicle/include/icicle/fields/host_math.h | 34 ++- icicle/include/icicle/fields/storage.h | 34 +-- icicle/include/icicle/utils/modifiers.h | 7 +- icicle/tests/test_curve_api.cpp | 108 ++++++++- icicle/tests/test_field_api.cpp | 36 +-- 7 files changed, 450 insertions(+), 54 deletions(-) diff --git a/icicle/backend/cpu/include/ntt_cpu.h b/icicle/backend/cpu/include/ntt_cpu.h index 725de780e..aa4e75414 100644 --- a/icicle/backend/cpu/include/ntt_cpu.h +++ b/icicle/backend/cpu/include/ntt_cpu.h @@ -28,6 +28,7 @@ namespace ntt_cpu { NttCpu(uint32_t logn, NTTDir direction, const NTTConfig& config, const E* input, E* output) : input(input), ntt_data(logn, output, config, direction), ntt_tasks_manager(ntt_data.ntt_sub_logn, logn), tasks_manager(std::make_unique>>(std::thread::hardware_concurrency() - 1)) + // tasks_manager(std::make_unique>>(1)) { } eIcicleError run(); diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 46caa0f3b..bd209397e 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -810,11 +810,295 @@ class Field return *this; } + // #if defined(__CUDACC__) +#if 1 friend HOST_DEVICE Field operator*(const Field& xs, const Field& ys) { Wide xy = mul_wide(xs, ys); // full mult return reduce(xy); // reduce mod p } +#else + + // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) + // #pragma GCC optimize("no-strict-aliasing") + // #endif + + friend HOST_DEVICE_INLINE Field original_multiplier(const Field& xs, const Field& ys) + { + Wide xy = mul_wide(xs, ys); // full mult + return reduce(xy); // reduce mod p + } + + // #include + + /* GNARK CODE START*/ + // those two funcs are copied from bits.go implementation (/usr/local/go/src/math/bits/bits.go) + static HOST_DEVICE_INLINE void Mul64(uint64_t x, uint64_t y, uint64_t& hi, uint64_t& lo) + { + // constexpr uint64_t mask32 = 4294967295ULL; // 2^32 - 1 + // uint64_t x0 = x & mask32; + // uint64_t x1 = x >> 32; + // uint64_t y0 = y & mask32; + // uint64_t y1 = y >> 32; + // uint64_t w0 = x0 * y0; + // uint64_t t = x1 * y0 + w0 >> 32; + // uint64_t w1 = t & mask32; + // uint64_t w2 = t >> 32; + // w1 += x0 * y1; + // hi = x1 * y1 + w2 + w1 >> 32; + // lo = x * y; + + // #if defined(__GNUC__) || defined(__clang__) + // lo = _umul128(x, y, &hi); + // #else + __uint128_t result = static_cast<__uint128_t>(x) * y; + hi = static_cast(result >> 64); + lo = static_cast(result); + // #endif + } + + // #if defined(__GNUC__) || defined(__clang__) + // #include + // #endif + + static HOST_DEVICE_INLINE void Add64(uint64_t x, uint64_t y, uint64_t carry, uint64_t& sum, uint64_t& carry_out) + { + // #if defined(__GNUC__) || defined(__clang__) + // carry_out = _addcarry_u64(carry, x, y, &sum); + // #else + sum = x + y + carry; + carry_out = ((x & y) | ((x | y) & ~sum)) >> 63; + // #endif + } + + static HOST_DEVICE_INLINE void Sub64(uint64_t x, uint64_t y, uint64_t borrow, uint64_t& diff, uint64_t& borrowOut) + { + // #if defined(__GNUC__) || defined(__clang__) + // borrowOut = _subborrow_u64(borrow, x, y, &diff); + // #else + diff = x - y - borrow; + // See Sub32 for the bit logic. + borrowOut = ((~x & y) | (~(x ^ y) & diff)) >> 63; + // #endif + } + + static HOST_DEVICE_INLINE bool smallerThanModulus(const Field& z) + { + // for bn254 specifically + constexpr uint64_t q0 = 4891460686036598785ULL; + constexpr uint64_t q1 = 2896914383306846353ULL; + constexpr uint64_t q2 = 13281191951274694749ULL; + constexpr uint64_t q3 = 3486998266802970665ULL; + return ( + z.limbs_storage.limbs64[3] < q3 || + (z.limbs_storage.limbs64[3] == q3 && + (z.limbs_storage.limbs64[2] < q2 || + (z.limbs_storage.limbs64[2] == q2 && + (z.limbs_storage.limbs64[1] < q1 || + (z.limbs_storage.limbs64[1] == q1 && (z.limbs_storage.limbs64[0] < q0))))))); + } + + // #define WITH_MONT_CONVERSIONS + + #ifdef WITH_MONT_CONVERSIONS + friend HOST_DEVICE Field operator*(const Field& x_orig, const Field& y_orig) + #else + friend HOST_DEVICE Field operator*(const Field& x, const Field& y) + #endif + { + // for bn254 specifically + constexpr uint64_t qInvNeg = 14042775128853446655ULL; + constexpr uint64_t q0 = 4891460686036598785ULL; + constexpr uint64_t q1 = 2896914383306846353ULL; + constexpr uint64_t q2 = 13281191951274694749ULL; + constexpr uint64_t q3 = 3486998266802970665ULL; + + #ifdef WITH_MONT_CONVERSIONS + // auto x = original_multiplier(x_orig, original_multiplier(Field{CONFIG::montgomery_r}, + // Field{CONFIG::montgomery_r})); auto y = original_multiplier(y_orig, + // original_multiplier(Field{CONFIG::montgomery_r}, Field{CONFIG::montgomery_r})); + auto x = original_multiplier(x_orig, Field{CONFIG::montgomery_r}); + auto y = original_multiplier(y_orig, Field{CONFIG::montgomery_r}); + #endif + + Field z{}; + uint64_t t0, t1, t2, t3; + uint64_t u0, u1, u2, u3; + + { + uint64_t c0, c1, c2, _; + uint64_t v = x.limbs_storage.limbs64[0]; + Mul64(v, y.limbs_storage.limbs64[0], u0, t0); + Mul64(v, y.limbs_storage.limbs64[1], u1, t1); + Mul64(v, y.limbs_storage.limbs64[2], u2, t2); + Mul64(v, y.limbs_storage.limbs64[3], u3, t3); + Add64(u0, t1, 0, t1, c0); + Add64(u1, t2, c0, t2, c0); + Add64(u2, t3, c0, t3, c0); + Add64(u3, 0, c0, c2, _); + + uint64_t m = qInvNeg * t0; + + Mul64(m, q0, u0, c1); + Add64(t0, c1, 0, _, c0); + Mul64(m, q1, u1, c1); + Add64(t1, c1, c0, t0, c0); + Mul64(m, q2, u2, c1); + Add64(t2, c1, c0, t1, c0); + Mul64(m, q3, u3, c1); + + Add64(0, c1, c0, t2, c0); + Add64(u3, 0, c0, u3, _); + Add64(u0, t0, 0, t0, c0); + Add64(u1, t1, c0, t1, c0); + Add64(u2, t2, c0, t2, c0); + Add64(c2, 0, c0, c2, _); + Add64(t3, t2, 0, t2, c0); + Add64(u3, c2, c0, t3, _); + } + + { + uint64_t c0, c1, c2, _; + uint64_t v = x.limbs_storage.limbs64[1]; + Mul64(v, y.limbs_storage.limbs64[0], u0, c1); + Add64(c1, t0, 0, t0, c0); + Mul64(v, y.limbs_storage.limbs64[1], u1, c1); + Add64(c1, t1, c0, t1, c0); + Mul64(v, y.limbs_storage.limbs64[2], u2, c1); + Add64(c1, t2, c0, t2, c0); + Mul64(v, y.limbs_storage.limbs64[3], u3, c1); + Add64(c1, t3, c0, t3, c0); + + Add64(0, 0, c0, c2, _); + Add64(u0, t1, 0, t1, c0); + Add64(u1, t2, c0, t2, c0); + Add64(u2, t3, c0, t3, c0); + Add64(u3, c2, c0, c2, _); + + uint64_t m = qInvNeg * t0; + + Mul64(m, q0, u0, c1); + Add64(t0, c1, 0, _, c0); + Mul64(m, q1, u1, c1); + Add64(t1, c1, c0, t0, c0); + Mul64(m, q2, u2, c1); + Add64(t2, c1, c0, t1, c0); + Mul64(m, q3, u3, c1); + + Add64(0, c1, c0, t2, c0); + Add64(u3, 0, c0, u3, _); + Add64(u0, t0, 0, t0, c0); + Add64(u1, t1, c0, t1, c0); + Add64(u2, t2, c0, t2, c0); + Add64(c2, 0, c0, c2, _); + Add64(t3, t2, 0, t2, c0); + Add64(u3, c2, c0, t3, _); + } + + { + uint64_t c0, c1, c2, _; + uint64_t v = x.limbs_storage.limbs64[2]; + Mul64(v, y.limbs_storage.limbs64[0], u0, c1); + Add64(c1, t0, 0, t0, c0); + Mul64(v, y.limbs_storage.limbs64[1], u1, c1); + Add64(c1, t1, c0, t1, c0); + Mul64(v, y.limbs_storage.limbs64[2], u2, c1); + Add64(c1, t2, c0, t2, c0); + Mul64(v, y.limbs_storage.limbs64[3], u3, c1); + Add64(c1, t3, c0, t3, c0); + + Add64(0, 0, c0, c2, _); + Add64(u0, t1, 0, t1, c0); + Add64(u1, t2, c0, t2, c0); + Add64(u2, t3, c0, t3, c0); + Add64(u3, c2, c0, c2, _); + + uint64_t m = qInvNeg * t0; + + Mul64(m, q0, u0, c1); + Add64(t0, c1, 0, _, c0); + Mul64(m, q1, u1, c1); + Add64(t1, c1, c0, t0, c0); + Mul64(m, q2, u2, c1); + Add64(t2, c1, c0, t1, c0); + Mul64(m, q3, u3, c1); + + Add64(0, c1, c0, t2, c0); + Add64(u3, 0, c0, u3, _); + Add64(u0, t0, 0, t0, c0); + Add64(u1, t1, c0, t1, c0); + Add64(u2, t2, c0, t2, c0); + Add64(c2, 0, c0, c2, _); + Add64(t3, t2, 0, t2, c0); + Add64(u3, c2, c0, t3, _); + } + + { + uint64_t c0, c1, c2, _; + uint64_t v = x.limbs_storage.limbs64[3]; + Mul64(v, y.limbs_storage.limbs64[0], u0, c1); + Add64(c1, t0, 0, t0, c0); + Mul64(v, y.limbs_storage.limbs64[1], u1, c1); + Add64(c1, t1, c0, t1, c0); + Mul64(v, y.limbs_storage.limbs64[2], u2, c1); + Add64(c1, t2, c0, t2, c0); + Mul64(v, y.limbs_storage.limbs64[3], u3, c1); + Add64(c1, t3, c0, t3, c0); + + Add64(0, 0, c0, c2, _); + Add64(u0, t1, 0, t1, c0); + Add64(u1, t2, c0, t2, c0); + Add64(u2, t3, c0, t3, c0); + Add64(u3, c2, c0, c2, _); + + uint64_t m = qInvNeg * t0; + + Mul64(m, q0, u0, c1); + Add64(t0, c1, 0, _, c0); + Mul64(m, q1, u1, c1); + Add64(t1, c1, c0, t0, c0); + Mul64(m, q2, u2, c1); + Add64(t2, c1, c0, t1, c0); + Mul64(m, q3, u3, c1); + + Add64(0, c1, c0, t2, c0); + Add64(u3, 0, c0, u3, _); + Add64(u0, t0, 0, t0, c0); + Add64(u1, t1, c0, t1, c0); + Add64(u2, t2, c0, t2, c0); + Add64(c2, 0, c0, c2, _); + Add64(t3, t2, 0, t2, c0); + Add64(u3, c2, c0, t3, _); + } + + z.limbs_storage.limbs64[0] = t0; + z.limbs_storage.limbs64[1] = t1; + z.limbs_storage.limbs64[2] = t2; + z.limbs_storage.limbs64[3] = t3; + + if (smallerThanModulus(z)) { + uint64_t b, _; + Sub64(z.limbs_storage.limbs64[0], q0, 0, z.limbs_storage.limbs64[0], b); + Sub64(z.limbs_storage.limbs64[1], q1, b, z.limbs_storage.limbs64[1], b); + Sub64(z.limbs_storage.limbs64[2], q2, b, z.limbs_storage.limbs64[2], b); + Sub64(z.limbs_storage.limbs64[3], q3, b, z.limbs_storage.limbs64[3], _); + } + + #ifdef WITH_MONT_CONVERSIONS + z = original_multiplier(z, Field{CONFIG::montgomery_r_inv}); + // z = original_multiplier(z, original_multiplier(Field{CONFIG::montgomery_r_inv}, + // Field{CONFIG::montgomery_r_inv})); + #endif + return z; + } + + // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) + // #pragma GCC reset_options + // #endif + +#endif // __CUDACC__ + + /*GNARK CODE END*/ friend HOST_DEVICE bool operator==(const Field& xs, const Field& ys) { diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index 90162caec..663d33949 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -84,14 +84,46 @@ namespace host_math { return result; } - static constexpr __host__ uint64_t madc_cc_64(const uint64_t x, const uint64_t y, const uint64_t z, uint64_t& carry) + static inline __host__ __uint128_t mul64(uint64_t x, uint64_t y) + { + uint64_t high, low; + asm("mulq %3" : "=d"(high), "=a"(low) : "a"(x), "r"(y) : "cc"); + return (static_cast<__uint128_t>(high) << 64) | low; + } + + static __host__ uint64_t madc_cc_64(const uint64_t x, const uint64_t y, const uint64_t z, uint64_t& carry) { __uint128_t r = static_cast<__uint128_t>(x) * y + z + carry; + // __uint128_t r = mul64(x, y) + z + carry; + carry = (uint64_t)(r >> 64); uint64_t result = r & 0xffffffffffffffff; return result; } +#include + + // static inline __host__ uint64_t madc_cc_64(const uint64_t x, const uint64_t y, const uint64_t z, uint64_t& carry) + // { + // uint64_t high, low; + + // // Perform multiplication of x * y + // asm("mulq %3\n\t" // x * y -> result in RDX:RAX + // "addq %4, %%rax\n\t" // Add z to the low 64 bits (RAX), setting flags + // "adcq $0, %%rdx\n\t" // Propagate carry to high 64 bits (RDX) + // "addq %5, %%rax\n\t" // Add the input carry to RAX, setting flags + // "adcq $0, %%rdx" // Propagate any carry to RDX + // : "=a"(low), "=d"(high) // Output operands + // : "a"(x), "r"(y), "r"(z), "r"(carry) // Input operands + // : "cc"); // Clobbers + + // // Set carry to the high 64 bits of the result + // carry = high; + + // // Return the low 64 bits of the result + // return low; + // } + template struct carry_chain { unsigned index; diff --git a/icicle/include/icicle/fields/storage.h b/icicle/include/icicle/fields/storage.h index 76245db16..14c4d1a01 100644 --- a/icicle/include/icicle/fields/storage.h +++ b/icicle/include/icicle/fields/storage.h @@ -4,44 +4,25 @@ #define LIMBS_ALIGNMENT(x) ((x) % 4 == 0 ? 16 : ((x) % 2 == 0 ? 8 : 4)) template -struct -#ifdef __CUDA_ARCH__ - __align__(LIMBS_ALIGNMENT(LIMBS_COUNT)) -#endif - storage; +struct ALIGN(LIMBS_ALIGNMENT(LIMBS_COUNT)) storage; // Specialization for LIMBS_COUNT == 1 template <> -struct -#ifdef __CUDA_ARCH__ - __align__(LIMBS_ALIGNMENT(1)) -#endif - storage<1> -{ +struct ALIGN(LIMBS_ALIGNMENT(1)) storage<1> { static constexpr unsigned LC = 1; uint32_t limbs[1]; }; // Specialization for LIMBS_COUNT == 3 template <> -struct -#ifdef __CUDA_ARCH__ - __align__(LIMBS_ALIGNMENT(1)) -#endif - storage<3> -{ +struct ALIGN(LIMBS_ALIGNMENT(3)) storage<3> { static constexpr unsigned LC = 3; uint32_t limbs[3]; }; // General template for LIMBS_COUNT > 1 template -struct -#ifdef __CUDA_ARCH__ - __align__(LIMBS_ALIGNMENT(LIMBS_COUNT)) -#endif - storage -{ +struct ALIGN(LIMBS_ALIGNMENT(LIMBS_COUNT)) storage { static_assert(LIMBS_COUNT % 2 == 0, "odd number of limbs is not supported\n"); static constexpr unsigned LC = LIMBS_COUNT; union { // works only with even LIMBS_COUNT @@ -51,11 +32,6 @@ struct }; template -struct -#ifdef __CUDA_ARCH__ - __align__(LIMBS_ALIGNMENT(LIMBS_COUNT)) -#endif - storage_array -{ +struct ALIGN(LIMBS_ALIGNMENT(LIMBS_COUNT)) storage_array { storage storages[OMEGAS_COUNT]; }; \ No newline at end of file diff --git a/icicle/include/icicle/utils/modifiers.h b/icicle/include/icicle/utils/modifiers.h index a8728d279..ba7e0ce85 100644 --- a/icicle/include/icicle/utils/modifiers.h +++ b/icicle/include/icicle/utils/modifiers.h @@ -16,15 +16,18 @@ #define DEVICE_INLINE __device__ INLINE_MACRO #define HOST_DEVICE __host__ __device__ #define HOST_DEVICE_INLINE HOST_DEVICE INLINE_MACRO + #define ALIGN(size) __align__(size) #else // not CUDA #define INLINE_MACRO #define UNROLL - #define HOST_INLINE + #define HOST_INLINE inline #define DEVICE_INLINE #define HOST_DEVICE - #define HOST_DEVICE_INLINE + #define HOST_DEVICE_INLINE inline + #define HOST_DEVICE_FORCE_INLINE __forceinline__ #define __host__ #define __device__ + #define ALIGN(size) alignas(size) #endif #if defined(_MSC_VER) diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index cc577b58a..28a3ee7b2 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -263,14 +263,14 @@ TEST_F(CurveApiTest, ecnttDeviceMem) { // (TODO) Randomize configuration const bool inplace = false; - const int logn = 10; + const int logn = 14; const uint64_t N = 1 << logn; const int log_ntt_domain_size = logn; const int log_batch_size = 0; const int batch_size = 1 << log_batch_size; - const Ordering ordering = static_cast(0); + const Ordering ordering = static_cast(1); // NR bool columns_batch = false; - const NTTDir dir = static_cast(0); // 0: forward, 1: inverse + const NTTDir dir = static_cast(1); // 0: forward, 1: inverse const int total_size = N * batch_size; auto input = std::make_unique(total_size); @@ -303,6 +303,11 @@ TEST_F(CurveApiTest, ecnttDeviceMem) std::ostringstream oss; oss << dev_type << " " << msg; + + // std::cout << "press any key to proceed..."; + // int a; + // std::cin >> a; + // std::cout << "proceeding\n"; START_TIMER(NTT_sync) for (int i = 0; i < iters; ++i) { ICICLE_CHECK(ntt(d_in, N, dir, config, inplace ? d_in : d_out)); @@ -318,7 +323,7 @@ TEST_F(CurveApiTest, ecnttDeviceMem) ICICLE_CHECK(ntt_release_domain()); }; - run(s_main_target, out_main.get(), "ecntt", false /*=measure*/, 1 /*=iters*/); // warmup + // run(s_main_target, out_main.get(), "ecntt", true /*=measure*/, 1 /*=iters*/); // warmup run(s_ref_target, out_ref.get(), "ecntt", VERBOSE /*=measure*/, 1 /*=iters*/); run(s_main_target, out_main.get(), "ecntt", VERBOSE /*=measure*/, 1 /*=iters*/); // note that memcmp is tricky here because projetive points can have many representations @@ -379,6 +384,101 @@ TYPED_TEST(CurveSanity, ScalarMultTest) ASSERT_EQ(mult, expected_mult); } +TYPED_TEST(CurveSanity, ECarith) +{ + constexpr int n = 1 << 10; + auto a = std::make_unique(n); + auto b = std::make_unique(n); + auto c = std::make_unique(n); + + auto scalars = std::make_unique(n); + auto d = std::make_unique(n); + + TypeParam::rand_host_many(a.get(), n); + TypeParam::rand_host_many(b.get(), n); + scalar_t::rand_host_many(scalars.get(), n); + + START_TIMER(add); + for (int i = 0; i < n; ++i) { + c[i] = a[i] + b[i]; + } + END_TIMER(add, "ADD", true); + START_TIMER(dbl); + for (int i = 0; i < n; ++i) { + c[i] = TypeParam::dbl(c[i]); + } + END_TIMER(dbl, "DOUBLE", true); + + START_TIMER(scalarmult); + for (int i = 0; i < n; ++i) { + d[i] = c[i] * scalars[i]; + } + END_TIMER(scalarmult, "SCALAR-EC-MULT", true); +} + +TYPED_TEST(CurveSanity, FieldArith) +{ + constexpr int n = 1 << 20; + + auto scalars = std::make_unique(n); + auto scalars2 = std::make_unique(n); + auto scalars3 = std::make_unique(n); + + scalar_t::rand_host_many(scalars.get(), n); + scalar_t::rand_host_many(scalars2.get(), n); + + START_TIMER(scalarScalarmult); + for (int i = 0; i < n; ++i) { + scalars3[i] = scalars2[i] * scalars[i]; + } + END_TIMER(scalarScalarmult, "SCALAR-SCALAR-MULT", true); +} + +#include + +TYPED_TEST(CurveSanity, u64Mul) +{ + constexpr int n = 1 << 25; + + auto scalars = std::make_unique(n); + auto scalars2 = std::make_unique(n); + // auto scalars_res = std::make_unique(n); + auto scalars_res_128 = std::make_unique<__uint128_t[]>(n); + + // Initialize a random generator for uint64_t + std::random_device rd; + std::mt19937_64 gen(rd()); + std::uniform_int_distribution dis(0, UINT64_MAX); + + for (int i = 0; i < n; ++i) { + scalars[i] = dis(gen); + scalars2[i] = dis(gen); + } + + START_TIMER(u64Mult); + for (int i = 0; i < n; ++i) { + scalars_res_128[i] = scalars2[i] * scalars[i]; + // auto res = scalars2[i] * scalars[i]; + } + END_TIMER(u64Mult, "U64-MULT-native", true); + + START_TIMER(u64Mult_with128); + for (int i = 0; i < n; ++i) { + scalars_res_128[i] = static_cast<__uint128_t>(scalars2[i]) * scalars[i]; + // auto res = static_cast<__uint128_t>(scalars2[i]) * scalars[i]; + } + END_TIMER(u64Mult, "U64-MULT-via-u128", true); + + START_TIMER(u64Mult_asm); + // #pragma unroll + uint64_t high, low; + for (int i = 0; i < n; ++i) { + asm("mulq %3" : "=d"(high), "=a"(low) : "a"(scalars[i]), "r"(scalars2[i]) : "cc"); + scalars_res_128[i] = (static_cast<__uint128_t>(high) << 64) | low; + } + END_TIMER(u64Mult_asm, "U64-MULT-asm", true); +} + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 072142876..37fdeb5c7 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -301,27 +301,27 @@ TYPED_TEST(FieldApiTest, ntt) int seed = time(0); srand(seed); - const bool inplace = rand() % 2; - const int logn = rand() % 15 + 3; + const bool inplace = 0; + const int logn = 16; const uint64_t N = 1 << logn; const int log_ntt_domain_size = logn + 1; - const int log_batch_size = rand() % 3; + const int log_batch_size = 0; const int batch_size = 1 << log_batch_size; - const Ordering ordering = static_cast(rand() % 4); - bool columns_batch; - if (logn == 7 || logn < 4) { - columns_batch = false; // currently not supported (icicle_v3/backend/cuda/src/ntt/ntt.cuh line 578) - } else { - columns_batch = rand() % 2; - } - const NTTDir dir = static_cast(rand() % 2); // 0: forward, 1: inverse - const int log_coset_stride = rand() % 3; - scalar_t coset_gen; - if (log_coset_stride) { - coset_gen = scalar_t::omega(logn + log_coset_stride); - } else { - coset_gen = scalar_t::one(); - } + const Ordering ordering = static_cast(0); + bool columns_batch = false; + // if (logn == 7 || logn < 4) { + // columns_batch = false; // currently not supported (icicle_v3/backend/cuda/src/ntt/ntt.cuh line 578) + // } else { + // columns_batch = rand() % 2; + // } + const NTTDir dir = static_cast(0); // 0: forward, 1: inverse + const int log_coset_stride = 0; + scalar_t coset_gen = scalar_t::one(); + // if (log_coset_stride) { + // coset_gen = scalar_t::omega(logn + log_coset_stride); + // } else { + // coset_gen = scalar_t::one(); + // } const int total_size = N * batch_size; auto scalars = std::make_unique(total_size); From df45d694c45a1dcb97f67fa995fe57f8954d7b63 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Tue, 5 Nov 2024 15:32:54 +0200 Subject: [PATCH 05/22] merge with montgomery cpu --- icicle/backend/cpu/include/ntt_cpu.h | 4 ++-- icicle/include/icicle/fields/field.h | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/icicle/backend/cpu/include/ntt_cpu.h b/icicle/backend/cpu/include/ntt_cpu.h index aa4e75414..ccf076d1a 100644 --- a/icicle/backend/cpu/include/ntt_cpu.h +++ b/icicle/backend/cpu/include/ntt_cpu.h @@ -27,8 +27,8 @@ namespace ntt_cpu { public: NttCpu(uint32_t logn, NTTDir direction, const NTTConfig& config, const E* input, E* output) : input(input), ntt_data(logn, output, config, direction), ntt_tasks_manager(ntt_data.ntt_sub_logn, logn), - tasks_manager(std::make_unique>>(std::thread::hardware_concurrency() - 1)) - // tasks_manager(std::make_unique>>(1)) + // tasks_manager(std::make_unique>>(std::thread::hardware_concurrency() - 1)) + tasks_manager(std::make_unique>>(1)) { } eIcicleError run(); diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 548add282..12ed7cc2f 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -843,9 +843,9 @@ class Field #if 1 friend HOST_DEVICE Field operator*(const Field& xs, const Field& ys) { - Wide xy = mul_wide(xs, ys); // full mult - return reduce(xy); // reduce mod p - // return mont_mult(xs,ys); + // Wide xy = mul_wide(xs, ys); // full mult + // return reduce(xy); // reduce mod p + return mont_mult(xs,ys); } static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) From 0f3275e48d5a1ecf8d43c76192256b9fddf0a1d9 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Tue, 12 Nov 2024 12:52:52 +0200 Subject: [PATCH 06/22] added mont const computation --- icicle/include/icicle/fields/params_gen.h | 44 +++++++++++++++++++ .../fields/snark_fields/bls12_377_base.h | 2 +- .../fields/snark_fields/bls12_377_scalar.h | 2 +- 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/icicle/include/icicle/fields/params_gen.h b/icicle/include/icicle/fields/params_gen.h index 3626971b5..37975e91a 100644 --- a/icicle/include/icicle/fields/params_gen.h +++ b/icicle/include/icicle/fields/params_gen.h @@ -49,6 +49,48 @@ namespace params_gen { return rs; } + template + static constexpr HOST_INLINE storage get_lower(const storage<2*NLIMBS>& xs) + { + storage rs = {}; + for (unsigned i = 0; i < NLIMBS; i++) + rs.limbs[i] = xs.limbs[i]; + return rs; + } + + template + static constexpr HOST_INLINE storage get_montgomery_mult_constant(const storage& modulus) + { + //p^R-1 without carry (this is mod r) and then r-res; + storage rs = {}; + storage<2*NLIMBS> w_rs = {}; + storage tmp = {}; + storage<2*NLIMBS> w_tmp = {}; + host_math::template multiply_raw(modulus, modulus, w_tmp); + tmp = params_gen::template get_lower(w_tmp); + for (int i = 0; i < NLIMBS; i++) + { + rs.limbs[i] = modulus.limbs[i]; + } + host_math::template multiply_raw(tmp, rs, w_rs); + rs = params_gen::template get_lower(w_rs); + for (int i = 0; i < 252; i++) { + storage<2*NLIMBS> w_tmp2 = {}; + host_math::template multiply_raw(tmp, tmp, w_tmp2); + tmp = params_gen::template get_lower(w_tmp2); + storage<2*NLIMBS> w_rs2 = {}; + host_math::template multiply_raw(tmp, rs, w_rs2); + rs = params_gen::template get_lower(w_rs2); + } + storage mont_r = {}; + for (int i = 0; i < NLIMBS; i++) + { + mont_r.limbs[i] = 0; + } + host_math::template add_sub_limbs(mont_r, rs, rs); + return rs; + } + constexpr unsigned floorlog2(uint32_t x) { return x == 1 ? 0 : 1 + floorlog2(x >> 1); } template @@ -116,6 +158,8 @@ namespace params_gen { params_gen::template get_montgomery_constant(modulus); \ static constexpr storage montgomery_r_inv = \ params_gen::template get_montgomery_constant(modulus); \ + static constexpr storage mont_inv_modulus = \ + params_gen::template get_montgomery_mult_constant(modulus); \ static constexpr unsigned num_of_reductions = \ params_gen::template num_of_reductions(modulus, m); diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h index aee776fdd..3feccb729 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h @@ -7,7 +7,7 @@ namespace bls12_377 { struct fq_config { static constexpr storage<12> modulus = {0x00000001, 0x8508c000, 0x30000000, 0x170b5d44, 0xba094800, 0x1ef3622f, 0x00f5138f, 0x1a22d9f3, 0x6ca1493b, 0xc63b05c0, 0x17c510ea, 0x01ae3a46}; - static constexpr storage<12> mont_inv_modulus = {0xffffffff, 0x8508bfff, 0xa0000000, 0xd1e94577, 0x970debff, 0x35ed1347, 0xcced7a13, 0x5b245b86, 0x806a3cec, 0x22f80141, 0xeec82e3d, 0xbfa5205f}; + // static constexpr storage<12> mont_inv_modulus = {0xffffffff, 0x8508bfff, 0xa0000000, 0xd1e94577, 0x970debff, 0x35ed1347, 0xcced7a13, 0x5b245b86, 0x806a3cec, 0x22f80141, 0xeec82e3d, 0xbfa5205f}; PARAMS(modulus) diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h index 9b88bc858..2e366ae53 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h @@ -8,7 +8,7 @@ namespace bls12_377 { struct fp_config { static constexpr storage<8> modulus = {0x00000001, 0x0a118000, 0xd0000001, 0x59aa76fe, 0x5c37b001, 0x60b44d1e, 0x9a2ca556, 0x12ab655e}; - static constexpr storage<8> mont_inv_modulus = {0xffffffff, 0xa117fff, 0x90000001, 0x452217cc, 0x4790a000, 0x249765c3, 0x68b29556, 0x6992d0fa}; + // static constexpr storage<8> mont_inv_modulus = {0xffffffff, 0xa117fff, 0x90000001, 0x452217cc, 0x4790a000, 0x249765c3, 0x68b29556, 0x6992d0fa}; PARAMS(modulus) static constexpr storage<8> rou = {0xec2a895e, 0x476ef4a4, 0x63e3f04a, 0x9b506ee3, From 9cdb4bd838cbbf4b765f8637323311fcd0d398da Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Wed, 13 Nov 2024 15:36:16 +0200 Subject: [PATCH 07/22] field tests pass --- icicle/backend/cpu/include/cpu_ntt_domain.h | 32 +- icicle/include/icicle/curves/projective.h | 15 + icicle/include/icicle/fields/field.h | 513 +++++++++++++++++++- icicle/include/icicle/fields/host_math.h | 17 + icicle/include/icicle/fields/params_gen.h | 35 +- icicle/tests/test_field_api.cpp | 54 ++- 6 files changed, 614 insertions(+), 52 deletions(-) diff --git a/icicle/backend/cpu/include/cpu_ntt_domain.h b/icicle/backend/cpu/include/cpu_ntt_domain.h index ef09c920f..d3c83ebd0 100644 --- a/icicle/backend/cpu/include/cpu_ntt_domain.h +++ b/icicle/backend/cpu/include/cpu_ntt_domain.h @@ -83,24 +83,24 @@ namespace ntt_cpu { if (s_ntt_domain.twiddles == nullptr) { // (2) build the domain - // bool found_logn = false; - // S omega = primitive_root; - // const unsigned omegas_count = S::get_omegas_count(); - // for (int i = 0; i < omegas_count; i++) { - // omega = S::sqr(omega); - // if (!found_logn) { - // ++s_ntt_domain.max_log_size; - // found_logn = omega == S::one(); - // if (found_logn) break; - // } - // } - s_ntt_domain.max_log_size = 21; + bool found_logn = false; + S omega = primitive_root; + const unsigned omegas_count = S::get_omegas_count(); + for (int i = 0; i < omegas_count; i++) { + omega = S::sqr(omega); + if (!found_logn) { + ++s_ntt_domain.max_log_size; + found_logn = omega == S::one(); + if (found_logn) break; + } + } + // s_ntt_domain.max_log_size = 21; s_ntt_domain.max_size = (int)pow(2, s_ntt_domain.max_log_size); - // if (omega != S::one()) { - // ICICLE_LOG_ERROR << "Primitive root provided to the InitDomain function is not a root-of-unity"; - // return eIcicleError::INVALID_ARGUMENT; - // } + if (omega != S::one()) { + ICICLE_LOG_ERROR << "Primitive root provided to the InitDomain function is not a root-of-unity"; + return eIcicleError::INVALID_ARGUMENT; + } // calculate twiddles // Note: radix-2 INTT needs ONE in last element (in addition to first element), therefore have n+1 elements diff --git a/icicle/include/icicle/curves/projective.h b/icicle/include/icicle/curves/projective.h index 01d439c64..0a2feb323 100644 --- a/icicle/include/icicle/curves/projective.h +++ b/icicle/include/icicle/curves/projective.h @@ -169,6 +169,21 @@ class Projective const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + // const auto t24 = FF::mul_widez(t12.limbs_storage, t23.limbs_storage); // t24 ← t12 · t23 < 2 + // const auto t25 = FF::mul_widez(t07.limbs_storage, t22.limbs_storage); // t25 ← t07 · t22 < 2 + // typename FF::Wide W3 = typename FF::Wide{t25} - typename FF::Wide{t24}; // X3 ← t25 − t24 < 2 + // FF::redc_wide_inplacez(W3.limbs_storage); // X3 ← t25 − t24 < 2 + // const auto X3 = FF::Wide::get_lower(W3); + // const auto t27 = FF::mul_widez(t23.limbs_storage, t19.limbs_storage); // t27 ← t23 · t19 < 2 + // const auto t28 = FF::mul_widez(t22.limbs_storage, t21.limbs_storage); // t28 ← t22 · t21 < 2 + // W3 = typename FF::Wide{t28} + typename FF::Wide{t27}; // Y3 ← t28 + t27 < 2 + // FF::redc_wide_inplacez(W3.limbs_storage); // Y3 ← t28 + t27 < 2 + // const auto Y3 = FF::Wide::get_lower(W3); + // const auto t30 = FF::mul_widez(t19.limbs_storage, t07.limbs_storage); // t30 ← t19 · t07 < 2 + // const auto t31 = FF::mul_widez(t21.limbs_storage, t12.limbs_storage); // t31 ← t21 · t12 < 2 + // W3 = typename FF::Wide{t31} + typename FF::Wide{t30}; // Z3 ← t31 + t30 < 2 + // FF::redc_wide_inplacez(W3.limbs_storage); // Z3 ← t31 + t30 < 2 + // const auto Z3 = FF::Wide::get_lower(W3); return {X3, Y3, Z3}; } diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 12ed7cc2f..f8dc177e9 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -35,6 +35,11 @@ using namespace icicle; +// #ifdef __CUDA_ARCH__ +// __device__ __location__(constant) uint32_t FF_BLS12377_INV = 0xffffffff; +// #else +// static constexpr uint32_t FF_BLS12377_INV = 0xffffffff; +// #endif // __CUDA_ARCH__ template class Field { @@ -53,7 +58,12 @@ class Field for (int i = 1; i < TLC; i++) { scalar.limbs[i] = 0; } - return Field{scalar}; + // printf("what?\n"); + // std::cout < CONFIG::omegas_count) { THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid omega index"); } - Field omega = Field{CONFIG::rou}; + Field omega = to_montgomery(Field{CONFIG::rou}); + // Field omega = Field{CONFIG::rou}; for (int i = 0; i < CONFIG::omegas_count - logn; i++) omega = sqr(omega); + std::cout << "omega: " << omega < const inv = CONFIG::inv; - return Field{inv.storages[logn - 1]}; + return to_montgomery(Field{inv.storages[logn - 1]}); } static constexpr HOST_INLINE unsigned get_omegas_count() @@ -124,6 +136,7 @@ class Field static constexpr HOST_DEVICE_INLINE ff_storage get_mont_inv_modulus() { return CONFIG::mont_inv_modulus; } static constexpr HOST_DEVICE_INLINE ff_storage get_mont_r() { return CONFIG::montgomery_r; } + static constexpr HOST_DEVICE_INLINE ff_storage get_mont_r_sqr() { return CONFIG::montgomery_r_sqr; } static constexpr HOST_DEVICE_INLINE ff_storage get_mont_r_inv() { return CONFIG::montgomery_r_inv; } /** * A new addition to the config file - the number of times to reduce in [reduce](@ref reduce) function. @@ -304,6 +317,476 @@ class Field } #ifdef __CUDACC__ + +template struct carry_chainz { + unsigned index; + + constexpr __device__ __forceinline__ carry_chainz() : index(0) {} + + __device__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::add(x, y); + else if (index == 1 && !CARRY_IN) + return ptx::add_cc(x, y); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::addc_cc(x, y); + else + return ptx::addc(x, y); + } + + __device__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::sub(x, y); + else if (index == 1 && !CARRY_IN) + return ptx::sub_cc(x, y); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::subc_cc(x, y); + else + return ptx::subc(x, y); + } + + __device__ __forceinline__ uint32_t mad_lo(const uint32_t x, const uint32_t y, const uint32_t z) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::mad_lo(x, y, z); + else if (index == 1 && !CARRY_IN) + return ptx::mad_lo_cc(x, y, z); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::madc_lo_cc(x, y, z); + else + return ptx::madc_lo(x, y, z); + } + + __device__ __forceinline__ uint32_t mad_hi(const uint32_t x, const uint32_t y, const uint32_t z) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::mad_hi(x, y, z); + else if (index == 1 && !CARRY_IN) + return ptx::mad_hi_cc(x, y, z); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::madc_hi_cc(x, y, z); + else + return ptx::madc_hi(x, y, z); + } +}; + + + + // add or subtract limbs + template static constexpr DEVICE_INLINE uint32_t add_sub_limbs_devicez(const ff_storage &xs, const ff_storage &ys, ff_storage &rs) { + const uint32_t *x = xs.limbs; + const uint32_t *y = ys.limbs; + uint32_t *r = rs.limbs; + carry_chainz chain; +#pragma unroll + for (unsigned i = 0; i < TLC; i++) + r[i] = SUBTRACT ? chain.sub(x[i], y[i]) : chain.add(x[i], y[i]); + if (!CARRY_OUT) + return 0; + return SUBTRACT ? chain.sub(0, 0) : chain.add(0, 0); + } + + // If we want, we could make "2*TLC" a template parameter to deduplicate with "ff_storage" overload, but that's a minor issue. + template + static constexpr DEVICE_INLINE uint32_t add_sub_limbs_devicez(const ff_wide_storage &xs, const ff_wide_storage &ys, ff_wide_storage &rs) { + const uint32_t *x = xs.limbs; + const uint32_t *y = ys.limbs; + uint32_t *r = rs.limbs; + carry_chainz chain; +#pragma unroll + for (unsigned i = 0; i < 2 * TLC; i++) { + r[i] = SUBTRACT ? chain.sub(x[i], y[i]) : chain.add(x[i], y[i]); + } + if (!CARRY_OUT) + return 0; + return SUBTRACT ? chain.sub(0, 0) : chain.add(0, 0); + } + + template static constexpr DEVICE_INLINE uint32_t add_sub_limbsz(const T &xs, const T &ys, T &rs) { + // No need for static_assert(std::is_same::value || std::is_same::value). + // Instantiation will fail if appropriate add_sub_limbs_device overload does not exist. + return add_sub_limbs_devicez(xs, ys, rs); + } + + template static constexpr DEVICE_INLINE uint32_t add_limbsz(const T &xs, const T &ys, T &rs) { + return add_sub_limbsz(xs, ys, rs); + } + + template static constexpr DEVICE_INLINE uint32_t sub_limbsz(const T &xs, const T &ys, T &rs) { + return add_sub_limbsz(xs, ys, rs); + } + + // return xs == 0 with field operands + static constexpr DEVICE_INLINE bool is_zero_devicez(const ff_storage &xs) { + const uint32_t *x = xs.limbs; + uint32_t limbs_or = x[0]; +#pragma unroll + for (unsigned i = 1; i < TLC; i++) + limbs_or |= x[i]; + return limbs_or == 0; + } + + static constexpr DEVICE_INLINE bool is_zeroz(const ff_storage &xs) { + return is_zero_devicez(xs); + } + + // return xs == ys with field operands + static constexpr DEVICE_INLINE bool eq_devicez(const ff_storage &xs, const ff_storage &ys) { + const uint32_t *x = xs.limbs; + const uint32_t *y = ys.limbs; + uint32_t limbs_or = x[0] ^ y[0]; +#pragma unroll + for (unsigned i = 1; i < TLC; i++) + limbs_or |= x[i] ^ y[i]; + return limbs_or == 0; + } + + static constexpr DEVICE_INLINE bool eqz(const ff_storage &xs, const ff_storage &ys) { + return eq_devicez(xs, ys); + } + + template static constexpr DEVICE_INLINE ff_storage reducez(const ff_storage &xs) { + if (REDUCTION_SIZE == 0) + return xs; + const ff_storage modulus = get_modulus(); + ff_storage rs = {}; + return sub_limbsz(xs, modulus, rs) ? xs : rs; + } + + template static constexpr DEVICE_INLINE ff_wide_storage reduce_widez(const ff_wide_storage &xs) { + if (REDUCTION_SIZE == 0) + return xs; + const ff_wide_storage modulus_squared = get_modulus_squared(); + ff_wide_storage rs = {}; + return sub_limbsz(xs, modulus_squared, rs) ? xs : rs; + } + + // return xs + ys with field operands + template static constexpr DEVICE_INLINE ff_storage addz(const ff_storage &xs, const ff_storage &ys) { + ff_storage rs = {}; + add_limbsz(xs, ys, rs); + return reducez(rs); + } + + template static constexpr DEVICE_INLINE ff_wide_storage add_widez(const ff_wide_storage &xs, const ff_wide_storage &ys) { + ff_wide_storage rs = {}; + add_limbsz(xs, ys, rs); + return reduce_widez(rs); + } + + // return xs - ys with field operands + template static DEVICE_INLINE ff_storage subz(const ff_storage &xs, const ff_storage &ys) { + ff_storage rs = {}; + if (REDUCTION_SIZE == 0) { + sub_limbsz(xs, ys, rs); + } else { + uint32_t carry = sub_limbsz(xs, ys, rs); + if (carry == 0) + return rs; + const ff_storage modulus = get_modulus(); + add_limbsz(rs, modulus, rs); + } + return rs; + } + + template static DEVICE_INLINE ff_wide_storage sub_widez(const ff_wide_storage &xs, const ff_wide_storage &ys) { + ff_wide_storage rs = {}; + if (REDUCTION_SIZE == 0) { + sub_limbsz(xs, ys, rs); + } else { + uint32_t carry = sub_limbsz(xs, ys, rs); + if (carry == 0) + return rs; + const ff_wide_storage modulus_squared = get_modulus_squared(); + add_limbsz(rs, modulus_squared, rs); + } + return rs; + } + + + // The following algorithms are adaptations of + // http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, + // taken from https://github.com/z-prize/test-msm-gpu (under Apache 2.0 license) + // and modified to use our datatypes. + // We had our own implementation of http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, + // but the sppark versions achieved lower instruction count thanks to clever carry handling, + // so we decided to just use theirs. + +//change + static DEVICE_INLINE void mul_nz(uint32_t *acc, const uint32_t *a, uint32_t bi, size_t n = TLC) { +#pragma unroll + for (size_t i = 0; i < n; i += 2) { + acc[i] = ptx::mul_lo(a[i], bi); + acc[i + 1] = ptx::mul_hi(a[i], bi); + } + } + +//change + static DEVICE_INLINE void cmad_nz(uint32_t *acc, const uint32_t *a, uint32_t bi, size_t n = TLC) { + acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); + acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); +#pragma unroll + for (size_t i = 2; i < n; i += 2) { + acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); + acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); + } + // return carry flag + } + + //add + static DEVICE_INLINE void madc_n_rshiftz(uint32_t *odd, const uint32_t *a, uint32_t bi) { + constexpr uint32_t n = TLC; +#pragma unroll + for (size_t i = 0; i < n - 2; i += 2) { + odd[i] = ptx::madc_lo_cc(a[i], bi, odd[i + 2]); + odd[i + 1] = ptx::madc_hi_cc(a[i], bi, odd[i + 3]); + } + odd[n - 2] = ptx::madc_lo_cc(a[n - 2], bi, 0); + odd[n - 1] = ptx::madc_hi(a[n - 2], bi, 0); + } + + //add + static DEVICE_INLINE void mad_n_redcz(uint32_t *even, uint32_t *odd, const uint32_t *a, uint32_t bi, bool first = false) { + constexpr uint32_t n = TLC; + constexpr auto modulus = CONFIG::modulus; + const uint32_t *const MOD = modulus.limbs; + constexpr auto mont_inv_modulus = CONFIG::mont_inv_modulus; + if (first) { + mul_nz(odd, a + 1, bi); + mul_nz(even, a, bi); + } else { + even[0] = ptx::add_cc(even[0], odd[1]); + madc_n_rshiftz(odd, a + 1, bi); + cmad_nz(even, a, bi); + odd[n - 1] = ptx::addc(odd[n - 1], 0); + } + uint32_t mi = even[0] * mont_inv_modulus.limbs[0]; + cmad_nz(odd, MOD + 1, mi); + cmad_nz(even, MOD, mi); + odd[n - 1] = ptx::addc(odd[n - 1], 0); + } + +//change + static DEVICE_INLINE void mad_rowz(uint32_t *odd, uint32_t *even, const uint32_t *a, uint32_t bi, size_t n = TLC) { + cmad_nz(odd, a + 1, bi, n - 2); + odd[n - 2] = ptx::madc_lo_cc(a[n - 1], bi, 0); + odd[n - 1] = ptx::madc_hi(a[n - 1], bi, 0); + cmad_nz(even, a, bi, n); + odd[n - 1] = ptx::addc(odd[n - 1], 0); + } + +//add + static DEVICE_INLINE void qad_rowz(uint32_t *odd, uint32_t *even, const uint32_t *a, uint32_t bi, size_t n = TLC) { + cmad_nz(odd, a, bi, n - 2); + odd[n - 2] = ptx::madc_lo_cc(a[n - 2], bi, 0); + odd[n - 1] = ptx::madc_hi(a[n - 2], bi, 0); + cmad_nz(even, a + 1, bi, n - 2); + odd[n - 1] = ptx::addc(odd[n - 1], 0); + } + +//change + static DEVICE_INLINE void multiply_rawz(const ff_storage &as, const ff_storage &bs, ff_wide_storage &rs) { + const uint32_t *a = as.limbs; + const uint32_t *b = bs.limbs; + uint32_t *even = rs.limbs; + __align__(8) uint32_t odd[2 * TLC - 2]; + mul_nz(even, a, b[0]); + mul_nz(odd, a + 1, b[0]); + mad_rowz(&even[2], &odd[0], a, b[1]); + size_t i; +#pragma unroll + for (i = 2; i < TLC - 1; i += 2) { + mad_rowz(&odd[i], &even[i], a, b[i]); + mad_rowz(&even[i + 2], &odd[i], a, b[i + 1]); + } + // merge |even| and |odd| + even[1] = ptx::add_cc(even[1], odd[0]); + for (i = 1; i < 2 * TLC - 2; i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], 0); + } + + static DEVICE_INLINE void sqr_rawz(const ff_storage &as, ff_wide_storage &rs) { + const uint32_t *a = as.limbs; + uint32_t *even = rs.limbs; + size_t i = 0, j; + __align__(8) uint32_t odd[2 * TLC - 2]; + + // perform |a[i]|*|a[j]| for all j>i + mul_nz(even + 2, a + 2, a[0], TLC - 2); + mul_nz(odd, a + 1, a[0], TLC); + +#pragma unroll + while (i < TLC - 4) { + ++i; + mad_rowz(&even[2 * i + 2], &odd[2 * i], &a[i + 1], a[i], TLC - i - 1); + ++i; + qad_rowz(&odd[2 * i], &even[2 * i + 2], &a[i + 1], a[i], TLC - i); + } + + even[2 * TLC - 4] = ptx::mul_lo(a[TLC - 1], a[TLC - 3]); + even[2 * TLC - 3] = ptx::mul_hi(a[TLC - 1], a[TLC - 3]); + odd[2 * TLC - 6] = ptx::mad_lo_cc(a[TLC - 2], a[TLC - 3], odd[2 * TLC - 6]); + odd[2 * TLC - 5] = ptx::madc_hi_cc(a[TLC - 2], a[TLC - 3], odd[2 * TLC - 5]); + even[2 * TLC - 3] = ptx::addc(even[2 * TLC - 3], 0); + + odd[2 * TLC - 4] = ptx::mul_lo(a[TLC - 1], a[TLC - 2]); + odd[2 * TLC - 3] = ptx::mul_hi(a[TLC - 1], a[TLC - 2]); + + // merge |even[2:]| and |odd[1:]| + even[2] = ptx::add_cc(even[2], odd[1]); + for (j = 2; j < 2 * TLC - 3; j++) + even[j + 1] = ptx::addc_cc(even[j + 1], odd[j]); + even[j + 1] = ptx::addc(odd[j], 0); + + // double |even| + even[0] = 0; + even[1] = ptx::add_cc(odd[0], odd[0]); + for (j = 2; j < 2 * TLC - 1; j++) + even[j] = ptx::addc_cc(even[j], even[j]); + even[j] = ptx::addc(0, 0); + + // accumulate "diagonal" |a[i]|*|a[i]| product + i = 0; + even[2 * i] = ptx::mad_lo_cc(a[i], a[i], even[2 * i]); + even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); + for (++i; i < TLC; i++) { + even[2 * i] = ptx::madc_lo_cc(a[i], a[i], even[2 * i]); + even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); + } + } + +//add + static DEVICE_INLINE void mul_by_1_rowz(uint32_t *even, uint32_t *odd, bool first = false) { + uint32_t mi; + constexpr auto modulus = CONFIG::modulus; + const uint32_t *const MOD = modulus.limbs; + constexpr auto mont_inv_modulus = CONFIG::mont_inv_modulus; + if (first) { + mi = even[0] * mont_inv_modulus.limbs[0]; + mul_nz(odd, MOD + 1, mi); + cmad_nz(even, MOD, mi); + odd[TLC - 1] = ptx::addc(odd[TLC - 1], 0); + } else { + even[0] = ptx::add_cc(even[0], odd[1]); + // we trust the compiler to *not* touch the carry flag here + // this code sits in between two "asm volatile" instructions witch should guarantee that nothing else interferes wit the carry flag + mi = even[0] * mont_inv_modulus.limbs[0]; + madc_n_rshiftz(odd, MOD + 1, mi); + cmad_nz(even, MOD, mi); + odd[TLC - 1] = ptx::addc(odd[TLC - 1], 0); + } + } + +//add + // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). + // Does not implement an in-place reduce epilogue. If you want to further reduce the result, + // call reduce(xs.get_lo()) after the call to redc_wide_inplace. + static DEVICE_INLINE void redc_wide_inplacez(ff_wide_storage &xs) { + uint32_t *even = xs.limbs; + // Yields montmul of lo TLC limbs * 1. + // Since the hi TLC limbs don't participate in computing the "mi" factor at each mul-and-rightshift stage, + // it's ok to ignore the hi TLC limbs during this process and just add them in afterward. + uint32_t odd[TLC]; + size_t i; +#pragma unroll + for (i = 0; i < TLC; i += 2) { + mul_by_1_rowz(&even[0], &odd[0], i == 0); + mul_by_1_rowz(&odd[0], &even[0]); + } + even[0] = ptx::add_cc(even[0], odd[1]); +#pragma unroll + for (i = 1; i < TLC - 1; i++) + even[i] = ptx::addc_cc(even[i], odd[i + 1]); + even[i] = ptx::addc(even[i], 0); + // Adds in (hi TLC limbs), implicitly right-shifting them by TLC limbs as if they had participated in the + // add-and-rightshift stages above. + xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[TLC]); +#pragma unroll + for (i = 1; i < TLC - 1; i++) + xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + TLC]); + xs.limbs[TLC - 1] = ptx::addc(xs.limbs[TLC - 1], xs.limbs[2 * TLC - 1]); + } + +//add + static DEVICE_INLINE void montmul_rawz(const ff_storage &a_in, const ff_storage &b_in, ff_storage &r_in) { + constexpr uint32_t n = TLC; + constexpr auto modulus = CONFIG::modulus; + const uint32_t *const MOD = modulus.limbs; + const uint32_t *a = a_in.limbs; + const uint32_t *b = b_in.limbs; + uint32_t *even = r_in.limbs; + __align__(8) uint32_t odd[n + 1]; + size_t i; +#pragma unroll + for (i = 0; i < n; i += 2) { + mad_n_redcz(&even[0], &odd[0], a, b[i], i == 0); + mad_n_redcz(&odd[0], &even[0], a, b[i + 1]); + } + // merge |even| and |odd| + even[0] = ptx::add_cc(even[0], odd[1]); +#pragma unroll + for (i = 1; i < n - 1; i++) + even[i] = ptx::addc_cc(even[i], odd[i + 1]); + even[i] = ptx::addc(even[i], 0); + // final reduction from [0, 2*mod) to [0, mod) not done here, instead performed optionally in mul_device wrapper + } + +//change + // Returns xs * ys without Montgomery reduction. + template static constexpr DEVICE_INLINE ff_wide_storage mul_widez(const ff_storage &xs, const ff_storage &ys) { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); + ff_wide_storage rs = {0}; + multiply_rawz(xs, ys, rs); + return reduce_widez(rs); + } + +//add + // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). + template static constexpr DEVICE_INLINE ff_storage redc_widez(const ff_wide_storage &xs) { + ff_wide_storage tmp{xs}; + redc_wide_inplacez(tmp); // after reduce_twopass, tmp's low TLC limbs should represent a value in [0, 2*mod) + return reducez(tmp.get_lo()); + } + +//add + template static constexpr DEVICE_INLINE ff_storage mul_devicez(const ff_storage &xs, const ff_storage &ys) { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); + // printf(" "); + ff_storage rs = {0}; + montmul_rawz(xs, ys, rs); + return reducez(rs); + } + + template static constexpr DEVICE_INLINE ff_storage sqr_devicez(const ff_storage &xs) { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); + ff_wide_storage rs = {0}; + sqr_rawz(xs, rs); + redc_wide_inplacez(rs); // after reduce_twopass, tmp's low TLC limbs should represent a value in [0, 2*mod) + return reducez(rs.get_lo()); + } + +//add + // return xs * ys with field operands + // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. + // Host path uses CIOS. + template static constexpr DEVICE_INLINE ff_storage mulz(const ff_storage &xs, const ff_storage &ys) { + return mul_devicez(xs, ys); + } + + + + + + + + + static DEVICE_INLINE void mul_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC) { UNROLL @@ -447,7 +930,7 @@ class Field * \cdot b_0}{2^{32}}} + \dots + \floor{\frac{a_0 \cdot b_{TLC - 2}}{2^{32}}}) \leq 2^{64} + 2\cdot 2^{96} + \dots + * (TLC - 2) \cdot 2^{32(TLC - 1)} + (TLC - 1) \cdot 2^{32(TLC - 1)} \leq 2(TLC - 1) \cdot 2^{32(TLC - 1)}\f$. */ - static DEVICE_INLINE void multiply_msb_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) + static DEVICE_INLINE void multiply_msb_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) { if constexpr (TLC > 1) { const uint32_t* a = as.limbs; @@ -594,7 +1077,7 @@ class Field * with so far. This method implements [subtractive * Karatsuba](https://en.wikipedia.org/wiki/Karatsuba_algorithm#Implementation). */ - static DEVICE_INLINE void multiply_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) + static DEVICE_INLINE void multiply_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) { const uint32_t* a = as.limbs; const uint32_t* b = bs.limbs; @@ -702,7 +1185,7 @@ class Field value.limbs_storage.limbs[i] = distribution(generator); while (lt(Field{get_modulus()}, value)) value = value - Field{get_modulus()}; - return value; + return to_montgomery(value); } static void rand_host_many(Field* out, int size) @@ -843,9 +1326,15 @@ class Field #if 1 friend HOST_DEVICE Field operator*(const Field& xs, const Field& ys) { + #ifdef __CUDA_ARCH__ //cuda + return Field{mulz(xs.limbs_storage,ys.limbs_storage)}; // Wide xy = mul_wide(xs, ys); // full mult // return reduce(xy); // reduce mod p + #else + // Wide xy = mul_wide(xs, ys); // full mult + // return reduce(xy); // reduce mod p return mont_mult(xs,ys); + #endif } static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) @@ -1206,11 +1695,11 @@ class Field return xs * xs; } - static constexpr HOST_DEVICE_INLINE Field to_montgomery(const Field& xs) { return xs * Field{CONFIG::montgomery_r}; } + static constexpr HOST_DEVICE_INLINE Field to_montgomery(const Field& xs) { return xs * Field{CONFIG::montgomery_r_sqr}; } static constexpr HOST_DEVICE_INLINE Field from_montgomery(const Field& xs) { - return xs * Field{CONFIG::montgomery_r_inv}; + return xs * Field{1}; } template @@ -1259,9 +1748,9 @@ class Field static constexpr HOST_DEVICE Field inverse(const Field& xs) { if (xs == zero()) return zero(); - constexpr Field one = Field{CONFIG::one}; + constexpr Field one = {1}; constexpr ff_storage modulus = CONFIG::modulus; - Field u = xs; + Field u = from_montgomery(xs); Field v = Field{modulus}; Field b = one; Field c = {}; @@ -1284,7 +1773,7 @@ class Field c = c - b; } } - return (u == one) ? b : c; + return (u == one) ? to_montgomery(b) : to_montgomery(c); } static constexpr HOST_DEVICE Field pow(Field base, int exp) diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index 241db32a6..49866752f 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -229,6 +229,23 @@ namespace host_math { } } + template + static constexpr HOST_INLINE void multiply_mont_32(const uint32_t* a, const uint32_t* b, const uint32_t* q, const uint32_t* p, uint32_t* r) + { + for (unsigned i = 0; i < NLIMBS_B; i++) { + uint32_t A = 0, C = 0; + r[0] = host_math::madc_cc(a[0], b[i], r[0], A); + uint32_t m = host_math::madc_cc(r[0], q[0], 0, C); //TODO - multiply inst + C = 0; + host_math::madc_cc(m, p[0], r[0], C); + for (unsigned j = 1; j < NLIMBS_A; j++) { + r[j] = host_math::madc_cc(a[j], b[i], r[j], A); + r[j - 1] = host_math::madc_cc(m, p[j], r[j], C); + } + r[NLIMBS_A - 1] = C + A; + } + } + template static HOST_INLINE void multiply_mont_64(const uint64_t* a, const uint64_t* b, const uint64_t* q, const uint64_t* p, uint64_t* r) { diff --git a/icicle/include/icicle/fields/params_gen.h b/icicle/include/icicle/fields/params_gen.h index 37975e91a..d2ea93bd9 100644 --- a/icicle/include/icicle/fields/params_gen.h +++ b/icicle/include/icicle/fields/params_gen.h @@ -49,6 +49,18 @@ namespace params_gen { return rs; } + template + static constexpr HOST_INLINE storage get_montgomery_constant_sqr(const storage& modulus) + { + storage rs = {1}; + for (int i = 0; i < 32 * NLIMBS * 2; i++) { + rs = host_math::template left_shift(rs); + storage temp = {}; + rs = host_math::template add_sub_limbs(rs, modulus, temp) ? rs : temp; + } + return rs; + } + template static constexpr HOST_INLINE storage get_lower(const storage<2*NLIMBS>& xs) { @@ -68,13 +80,10 @@ namespace params_gen { storage<2*NLIMBS> w_tmp = {}; host_math::template multiply_raw(modulus, modulus, w_tmp); tmp = params_gen::template get_lower(w_tmp); - for (int i = 0; i < NLIMBS; i++) - { - rs.limbs[i] = modulus.limbs[i]; - } + rs = modulus; host_math::template multiply_raw(tmp, rs, w_rs); rs = params_gen::template get_lower(w_rs); - for (int i = 0; i < 252; i++) { + for (int i = 0; i < sizeof(modulus.limbs[0])*8*NLIMBS - 4; i++) { storage<2*NLIMBS> w_tmp2 = {}; host_math::template multiply_raw(tmp, tmp, w_tmp2); tmp = params_gen::template get_lower(w_tmp2); @@ -83,10 +92,6 @@ namespace params_gen { rs = params_gen::template get_lower(w_rs2); } storage mont_r = {}; - for (int i = 0; i < NLIMBS; i++) - { - mont_r.limbs[i] = 0; - } host_math::template add_sub_limbs(mont_r, rs, rs); return rs; } @@ -124,25 +129,29 @@ namespace params_gen { } template - constexpr storage_array get_invs(const storage& modulus) + constexpr storage_array get_invs(const storage& modulus, const storage& mont_r_sqr, const storage& mont_inv) { storage_array invs = {}; storage rs = {1}; for (int i = 0; i < TWO_ADICITY; i++) { if (rs.limbs[0] & 1) host_math::template add_sub_limbs(rs, modulus, rs); rs = host_math::template right_shift(rs); + // host_math::template multiply_mont_32(rs.limbs, mont_r_sqr.limbs, mont_inv.limbs, modulus.limbs, rs.limbs); invs.storages[i] = rs; } return invs; } } // namespace params_gen +//do we still need modulus_2, 3, 4 when using montgomery? smae for num_of_reductions + #define PARAMS(modulus) \ static constexpr unsigned limbs_count = modulus.LC; \ static constexpr unsigned modulus_bit_count = \ 32 * (limbs_count - 1) + params_gen::floorlog2(modulus.limbs[limbs_count - 1]) + 1; \ static constexpr storage zero = {}; \ - static constexpr storage one = {1}; \ + static constexpr storage one = \ + params_gen::template get_montgomery_constant(modulus); \ static constexpr storage modulus_2 = host_math::template left_shift(modulus); \ static constexpr storage modulus_4 = host_math::template left_shift(modulus_2); \ static constexpr storage neg_modulus = \ @@ -156,6 +165,8 @@ namespace params_gen { static constexpr storage m = params_gen::template get_m(modulus); \ static constexpr storage montgomery_r = \ params_gen::template get_montgomery_constant(modulus); \ + static constexpr storage montgomery_r_sqr = \ + params_gen::template get_montgomery_constant_sqr(modulus); \ static constexpr storage montgomery_r_inv = \ params_gen::template get_montgomery_constant(modulus); \ static constexpr storage mont_inv_modulus = \ @@ -166,4 +177,4 @@ namespace params_gen { #define TWIDDLES(modulus, rou) \ static constexpr unsigned omegas_count = params_gen::template two_adicity(modulus); \ static constexpr storage_array inv = \ - params_gen::template get_invs(modulus); + params_gen::template get_invs(modulus, montgomery_r_sqr, mont_inv_modulus); diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 3e867d6b5..372f322b6 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -22,7 +22,7 @@ using FpMicroseconds = std::chrono::duration(N); auto in_b = std::make_unique(N); FieldApiTest::random_samples(in_a.get(), N); @@ -194,6 +205,11 @@ TYPED_TEST(FieldApiTest, vectorOps) // mul run(s_reference_target, out_ref.get(), VERBOSE /*=measure*/, vector_mul, "vector mul", ITERS); run(s_main_target, out_main.get(), VERBOSE /*=measure*/, vector_mul, "vector mul", ITERS); + + // std::cout << in_a[0] << ", " << in_b[0] << ", " << out_main[0] << ", " << out_ref[0] << std::endl; + // std::cout << in_a[1] << ", " << in_b[1] << ", " << out_main[1] << ", " << out_ref[1] << std::endl; + + ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); } @@ -352,9 +368,9 @@ TYPED_TEST(FieldApiTest, ntt) int seed = time(0); srand(seed); const bool inplace = 0; - const int logn = 16; + const int logn = 20; const uint64_t N = 1 << logn; - const int log_ntt_domain_size = logn + 1; + const int log_ntt_domain_size = logn; const int log_batch_size = 0; const int batch_size = 1 << log_batch_size; const Ordering ordering = static_cast(0); @@ -376,6 +392,9 @@ TYPED_TEST(FieldApiTest, ntt) const int total_size = N * batch_size; auto scalars = std::make_unique(total_size); FieldApiTest::random_samples(scalars.get(), total_size); + // for (uint32_t i=0; i(total_size); auto out_ref = std::make_unique(total_size); auto run = [&](const std::string& dev_type, TypeParam* out, const char* msg, bool measure, int iters) { @@ -427,9 +446,20 @@ TYPED_TEST(FieldApiTest, ntt) ICICLE_CHECK(icicle_destroy_stream(stream)); ICICLE_CHECK(ntt_release_domain()); }; - run(s_main_target, out_main.get(), "ntt", false /*=measure*/, 10 /*=iters*/); // warmup - run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 10 /*=iters*/); - run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 10 /*=iters*/); + run(s_main_target, out_main.get(), "ntt", false /*=measure*/, 1 /*=iters*/); // warmup + run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/); + run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/); + + + // std::cout << "\n"; + // for (int i=0;i CONFIG::omegas_count) { THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid omega index"); } - + #ifdef BARRET + Field omega = Field{CONFIG::rou}; + #else Field omega = to_montgomery(Field{CONFIG::rou}); - // Field omega = Field{CONFIG::rou}; + #endif for (int i = 0; i < CONFIG::omegas_count - logn; i++) omega = sqr(omega); - std::cout << "omega: " << omega < CONFIG::omegas_count) { THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid omega_inv index"); } - + #ifdef BARRET + Field omega = inverse(Field{CONFIG::rou}); + #else Field omega = inverse(to_montgomery(Field{CONFIG::rou})); + #endif for (int i = 0; i < CONFIG::omegas_count - logn; i++) omega = sqr(omega); return omega; @@ -96,7 +103,7 @@ class Field static HOST_DEVICE_INLINE Field inv_log_size(uint32_t logn) { - if (logn == 0) { return Field{CONFIG::one}; } + if (logn == 0) { return one(); } #ifndef __CUDA_ARCH__ if (logn > CONFIG::omegas_count) THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid inv index"); #else @@ -107,7 +114,11 @@ class Field } #endif // __CUDA_ARCH__ storage_array const inv = CONFIG::inv; + #ifdef BARRET + return Field{inv.storages[logn - 1]}; + #else return to_montgomery(Field{inv.storages[logn - 1]}); + #endif } static constexpr HOST_INLINE unsigned get_omegas_count() @@ -1185,7 +1196,11 @@ template (xs.limbs_storage.limbs64, ys.limbs_storage.limbs64, get_mont_inv_modulus().limbs64, get_modulus<1>().limbs64, r.limbs_storage.limbs64); return mont_reduce(r); - // return Wide::get_lower(r); } #else @@ -1694,13 +1714,13 @@ template static constexpr HOST_DEVICE Field neg(const Field& xs) @@ -1750,7 +1770,11 @@ template zero = {}; \ - static constexpr storage one = \ - params_gen::template get_montgomery_constant(modulus); \ + static constexpr storage one = {1}; \ static constexpr storage modulus_2 = host_math::template left_shift(modulus); \ static constexpr storage modulus_4 = host_math::template left_shift(modulus_2); \ static constexpr storage neg_modulus = \ diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 372f322b6..2aa9ce206 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -71,6 +71,11 @@ TYPED_TEST_SUITE(FieldApiTest, FTImplementations); // Note: this is testing host arithmetic. Other tests against CPU backend should guarantee correct device arithmetic too TYPED_TEST(FieldApiTest, FieldSanityTest) { + #ifdef BARRET + printf("USING BARRET MULT\n"); + #else + printf("USING MONTGOMERY MULT\n"); + #endif auto a = TypeParam::rand_host(); std::cout< Date: Thu, 14 Nov 2024 16:10:58 +0200 Subject: [PATCH 09/22] montgomery SOS reduction added Signed-off-by: Koren-Brand --- icicle/include/icicle/fields/field.h | 503 +++++++++++++---------- icicle/include/icicle/fields/host_math.h | 49 ++- icicle/tests/test_curve_api.cpp | 30 ++ 3 files changed, 354 insertions(+), 228 deletions(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 8a779b2bb..e36f07f2a 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -62,11 +62,11 @@ class Field for (int i = 1; i < TLC; i++) { scalar.limbs[i] = 0; } - #ifdef BARRET +#ifdef BARRET return Field{scalar}; - #else +#else return to_montgomery(Field{scalar}); - #endif +#endif } static HOST_INLINE Field omega(uint32_t logn) @@ -74,11 +74,11 @@ class Field if (logn == 0) { return one(); } if (logn > CONFIG::omegas_count) { THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid omega index"); } - #ifdef BARRET +#ifdef BARRET Field omega = Field{CONFIG::rou}; - #else +#else Field omega = to_montgomery(Field{CONFIG::rou}); - #endif +#endif for (int i = 0; i < CONFIG::omegas_count - logn; i++) omega = sqr(omega); return omega; @@ -91,11 +91,11 @@ class Field if (logn > CONFIG::omegas_count) { THROW_ICICLE_ERR(eIcicleError::INVALID_ARGUMENT, "Field: Invalid omega_inv index"); } - #ifdef BARRET +#ifdef BARRET Field omega = inverse(Field{CONFIG::rou}); - #else +#else Field omega = inverse(to_montgomery(Field{CONFIG::rou})); - #endif +#endif for (int i = 0; i < CONFIG::omegas_count - logn; i++) omega = sqr(omega); return omega; @@ -114,11 +114,11 @@ class Field } #endif // __CUDA_ARCH__ storage_array const inv = CONFIG::inv; - #ifdef BARRET +#ifdef BARRET return Field{inv.storages[logn - 1]}; - #else +#else return to_montgomery(Field{inv.storages[logn - 1]}); - #endif +#endif } static constexpr HOST_INLINE unsigned get_omegas_count() @@ -329,194 +329,212 @@ class Field #ifdef __CUDACC__ -template struct carry_chainz { - unsigned index; - - constexpr __device__ __forceinline__ carry_chainz() : index(0) {} - - __device__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::add(x, y); - else if (index == 1 && !CARRY_IN) - return ptx::add_cc(x, y); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::addc_cc(x, y); - else - return ptx::addc(x, y); - } + template + struct carry_chainz { + unsigned index; - __device__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::sub(x, y); - else if (index == 1 && !CARRY_IN) - return ptx::sub_cc(x, y); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::subc_cc(x, y); - else - return ptx::subc(x, y); - } + constexpr __device__ __forceinline__ carry_chainz() : index(0) {} - __device__ __forceinline__ uint32_t mad_lo(const uint32_t x, const uint32_t y, const uint32_t z) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::mad_lo(x, y, z); - else if (index == 1 && !CARRY_IN) - return ptx::mad_lo_cc(x, y, z); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::madc_lo_cc(x, y, z); - else - return ptx::madc_lo(x, y, z); - } + __device__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y) + { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::add(x, y); + else if (index == 1 && !CARRY_IN) + return ptx::add_cc(x, y); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::addc_cc(x, y); + else + return ptx::addc(x, y); + } - __device__ __forceinline__ uint32_t mad_hi(const uint32_t x, const uint32_t y, const uint32_t z) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::mad_hi(x, y, z); - else if (index == 1 && !CARRY_IN) - return ptx::mad_hi_cc(x, y, z); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::madc_hi_cc(x, y, z); - else - return ptx::madc_hi(x, y, z); - } -}; + __device__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y) + { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::sub(x, y); + else if (index == 1 && !CARRY_IN) + return ptx::sub_cc(x, y); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::subc_cc(x, y); + else + return ptx::subc(x, y); + } + __device__ __forceinline__ uint32_t mad_lo(const uint32_t x, const uint32_t y, const uint32_t z) + { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::mad_lo(x, y, z); + else if (index == 1 && !CARRY_IN) + return ptx::mad_lo_cc(x, y, z); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::madc_lo_cc(x, y, z); + else + return ptx::madc_lo(x, y, z); + } + __device__ __forceinline__ uint32_t mad_hi(const uint32_t x, const uint32_t y, const uint32_t z) + { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::mad_hi(x, y, z); + else if (index == 1 && !CARRY_IN) + return ptx::mad_hi_cc(x, y, z); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::madc_hi_cc(x, y, z); + else + return ptx::madc_hi(x, y, z); + } + }; - // add or subtract limbs - template static constexpr DEVICE_INLINE uint32_t add_sub_limbs_devicez(const ff_storage &xs, const ff_storage &ys, ff_storage &rs) { - const uint32_t *x = xs.limbs; - const uint32_t *y = ys.limbs; - uint32_t *r = rs.limbs; + // add or subtract limbs + template + static constexpr DEVICE_INLINE uint32_t + add_sub_limbs_devicez(const ff_storage& xs, const ff_storage& ys, ff_storage& rs) + { + const uint32_t* x = xs.limbs; + const uint32_t* y = ys.limbs; + uint32_t* r = rs.limbs; carry_chainz chain; -#pragma unroll + #pragma unroll for (unsigned i = 0; i < TLC; i++) r[i] = SUBTRACT ? chain.sub(x[i], y[i]) : chain.add(x[i], y[i]); - if (!CARRY_OUT) - return 0; + if (!CARRY_OUT) return 0; return SUBTRACT ? chain.sub(0, 0) : chain.add(0, 0); } - // If we want, we could make "2*TLC" a template parameter to deduplicate with "ff_storage" overload, but that's a minor issue. + // If we want, we could make "2*TLC" a template parameter to deduplicate with "ff_storage" overload, but that's a + // minor issue. template - static constexpr DEVICE_INLINE uint32_t add_sub_limbs_devicez(const ff_wide_storage &xs, const ff_wide_storage &ys, ff_wide_storage &rs) { - const uint32_t *x = xs.limbs; - const uint32_t *y = ys.limbs; - uint32_t *r = rs.limbs; + static constexpr DEVICE_INLINE uint32_t + add_sub_limbs_devicez(const ff_wide_storage& xs, const ff_wide_storage& ys, ff_wide_storage& rs) + { + const uint32_t* x = xs.limbs; + const uint32_t* y = ys.limbs; + uint32_t* r = rs.limbs; carry_chainz chain; -#pragma unroll + #pragma unroll for (unsigned i = 0; i < 2 * TLC; i++) { r[i] = SUBTRACT ? chain.sub(x[i], y[i]) : chain.add(x[i], y[i]); } - if (!CARRY_OUT) - return 0; + if (!CARRY_OUT) return 0; return SUBTRACT ? chain.sub(0, 0) : chain.add(0, 0); } - template static constexpr DEVICE_INLINE uint32_t add_sub_limbsz(const T &xs, const T &ys, T &rs) { + template + static constexpr DEVICE_INLINE uint32_t add_sub_limbsz(const T& xs, const T& ys, T& rs) + { // No need for static_assert(std::is_same::value || std::is_same::value). // Instantiation will fail if appropriate add_sub_limbs_device overload does not exist. return add_sub_limbs_devicez(xs, ys, rs); } - template static constexpr DEVICE_INLINE uint32_t add_limbsz(const T &xs, const T &ys, T &rs) { + template + static constexpr DEVICE_INLINE uint32_t add_limbsz(const T& xs, const T& ys, T& rs) + { return add_sub_limbsz(xs, ys, rs); } - template static constexpr DEVICE_INLINE uint32_t sub_limbsz(const T &xs, const T &ys, T &rs) { + template + static constexpr DEVICE_INLINE uint32_t sub_limbsz(const T& xs, const T& ys, T& rs) + { return add_sub_limbsz(xs, ys, rs); } // return xs == 0 with field operands - static constexpr DEVICE_INLINE bool is_zero_devicez(const ff_storage &xs) { - const uint32_t *x = xs.limbs; + static constexpr DEVICE_INLINE bool is_zero_devicez(const ff_storage& xs) + { + const uint32_t* x = xs.limbs; uint32_t limbs_or = x[0]; -#pragma unroll + #pragma unroll for (unsigned i = 1; i < TLC; i++) limbs_or |= x[i]; return limbs_or == 0; } - static constexpr DEVICE_INLINE bool is_zeroz(const ff_storage &xs) { - return is_zero_devicez(xs); - } + static constexpr DEVICE_INLINE bool is_zeroz(const ff_storage& xs) { return is_zero_devicez(xs); } // return xs == ys with field operands - static constexpr DEVICE_INLINE bool eq_devicez(const ff_storage &xs, const ff_storage &ys) { - const uint32_t *x = xs.limbs; - const uint32_t *y = ys.limbs; + static constexpr DEVICE_INLINE bool eq_devicez(const ff_storage& xs, const ff_storage& ys) + { + const uint32_t* x = xs.limbs; + const uint32_t* y = ys.limbs; uint32_t limbs_or = x[0] ^ y[0]; -#pragma unroll + #pragma unroll for (unsigned i = 1; i < TLC; i++) limbs_or |= x[i] ^ y[i]; return limbs_or == 0; } - static constexpr DEVICE_INLINE bool eqz(const ff_storage &xs, const ff_storage &ys) { - return eq_devicez(xs, ys); - } + static constexpr DEVICE_INLINE bool eqz(const ff_storage& xs, const ff_storage& ys) { return eq_devicez(xs, ys); } - template static constexpr DEVICE_INLINE ff_storage reducez(const ff_storage &xs) { - if (REDUCTION_SIZE == 0) - return xs; + template + static constexpr DEVICE_INLINE ff_storage reducez(const ff_storage& xs) + { + if (REDUCTION_SIZE == 0) return xs; const ff_storage modulus = get_modulus(); ff_storage rs = {}; return sub_limbsz(xs, modulus, rs) ? xs : rs; } - template static constexpr DEVICE_INLINE ff_wide_storage reduce_widez(const ff_wide_storage &xs) { - if (REDUCTION_SIZE == 0) - return xs; + template + static constexpr DEVICE_INLINE ff_wide_storage reduce_widez(const ff_wide_storage& xs) + { + if (REDUCTION_SIZE == 0) return xs; const ff_wide_storage modulus_squared = get_modulus_squared(); ff_wide_storage rs = {}; return sub_limbsz(xs, modulus_squared, rs) ? xs : rs; } // return xs + ys with field operands - template static constexpr DEVICE_INLINE ff_storage addz(const ff_storage &xs, const ff_storage &ys) { + template + static constexpr DEVICE_INLINE ff_storage addz(const ff_storage& xs, const ff_storage& ys) + { ff_storage rs = {}; add_limbsz(xs, ys, rs); return reducez(rs); } - template static constexpr DEVICE_INLINE ff_wide_storage add_widez(const ff_wide_storage &xs, const ff_wide_storage &ys) { + template + static constexpr DEVICE_INLINE ff_wide_storage add_widez(const ff_wide_storage& xs, const ff_wide_storage& ys) + { ff_wide_storage rs = {}; add_limbsz(xs, ys, rs); return reduce_widez(rs); } // return xs - ys with field operands - template static DEVICE_INLINE ff_storage subz(const ff_storage &xs, const ff_storage &ys) { + template + static DEVICE_INLINE ff_storage subz(const ff_storage& xs, const ff_storage& ys) + { ff_storage rs = {}; if (REDUCTION_SIZE == 0) { sub_limbsz(xs, ys, rs); } else { uint32_t carry = sub_limbsz(xs, ys, rs); - if (carry == 0) - return rs; + if (carry == 0) return rs; const ff_storage modulus = get_modulus(); add_limbsz(rs, modulus, rs); } return rs; } - template static DEVICE_INLINE ff_wide_storage sub_widez(const ff_wide_storage &xs, const ff_wide_storage &ys) { + template + static DEVICE_INLINE ff_wide_storage sub_widez(const ff_wide_storage& xs, const ff_wide_storage& ys) + { ff_wide_storage rs = {}; if (REDUCTION_SIZE == 0) { sub_limbsz(xs, ys, rs); } else { uint32_t carry = sub_limbsz(xs, ys, rs); - if (carry == 0) - return rs; + if (carry == 0) return rs; const ff_wide_storage modulus_squared = get_modulus_squared(); add_limbsz(rs, modulus_squared, rs); } return rs; } - // The following algorithms are adaptations of // http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, // taken from https://github.com/z-prize/test-msm-gpu (under Apache 2.0 license) @@ -525,20 +543,22 @@ template epilogue. If you want to further reduce the result, - // call reduce(xs.get_lo()) after the call to redc_wide_inplace. - static DEVICE_INLINE void redc_wide_inplacez(ff_wide_storage &xs) { - uint32_t *even = xs.limbs; + // add + // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). + // Does not implement an in-place reduce epilogue. If you want to further reduce the result, + // call reduce(xs.get_lo()) after the call to redc_wide_inplace. + static DEVICE_INLINE void redc_wide_inplacez(ff_wide_storage& xs) + { + uint32_t* even = xs.limbs; // Yields montmul of lo TLC limbs * 1. // Since the hi TLC limbs don't participate in computing the "mi" factor at each mul-and-rightshift stage, // it's ok to ignore the hi TLC limbs during this process and just add them in afterward. uint32_t odd[TLC]; size_t i; -#pragma unroll + #pragma unroll for (i = 0; i < TLC; i += 2) { mul_by_1_rowz(&even[0], &odd[0], i == 0); mul_by_1_rowz(&odd[0], &even[0]); } even[0] = ptx::add_cc(even[0], odd[1]); -#pragma unroll + #pragma unroll for (i = 1; i < TLC - 1; i++) even[i] = ptx::addc_cc(even[i], odd[i + 1]); even[i] = ptx::addc(even[i], 0); // Adds in (hi TLC limbs), implicitly right-shifting them by TLC limbs as if they had participated in the // add-and-rightshift stages above. xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[TLC]); -#pragma unroll + #pragma unroll for (i = 1; i < TLC - 1; i++) xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + TLC]); xs.limbs[TLC - 1] = ptx::addc(xs.limbs[TLC - 1], xs.limbs[2 * TLC - 1]); } -//add - static DEVICE_INLINE void montmul_rawz(const ff_storage &a_in, const ff_storage &b_in, ff_storage &r_in) { + // add + static DEVICE_INLINE void montmul_rawz(const ff_storage& a_in, const ff_storage& b_in, ff_storage& r_in) + { constexpr uint32_t n = TLC; constexpr auto modulus = CONFIG::modulus; - const uint32_t *const MOD = modulus.limbs; - const uint32_t *a = a_in.limbs; - const uint32_t *b = b_in.limbs; - uint32_t *even = r_in.limbs; + const uint32_t* const MOD = modulus.limbs; + const uint32_t* a = a_in.limbs; + const uint32_t* b = b_in.limbs; + uint32_t* even = r_in.limbs; __align__(8) uint32_t odd[n + 1]; size_t i; -#pragma unroll + #pragma unroll for (i = 0; i < n; i += 2) { mad_n_redcz(&even[0], &odd[0], a, b[i], i == 0); mad_n_redcz(&odd[0], &even[0], a, b[i + 1]); } // merge |even| and |odd| even[0] = ptx::add_cc(even[0], odd[1]); -#pragma unroll + #pragma unroll for (i = 1; i < n - 1; i++) even[i] = ptx::addc_cc(even[i], odd[i + 1]); even[i] = ptx::addc(even[i], 0); // final reduction from [0, 2*mod) to [0, mod) not done here, instead performed optionally in mul_device wrapper } -//change - // Returns xs * ys without Montgomery reduction. - template static constexpr DEVICE_INLINE ff_wide_storage mul_widez(const ff_storage &xs, const ff_storage &ys) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + // change + // Returns xs * ys without Montgomery reduction. + template + static constexpr DEVICE_INLINE ff_wide_storage mul_widez(const ff_storage& xs, const ff_storage& ys) + { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes + // of slack static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); ff_wide_storage rs = {0}; multiply_rawz(xs, ys, rs); return reduce_widez(rs); } -//add - // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). - template static constexpr DEVICE_INLINE ff_storage redc_widez(const ff_wide_storage &xs) { + // add + // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). + template + static constexpr DEVICE_INLINE ff_storage redc_widez(const ff_wide_storage& xs) + { ff_wide_storage tmp{xs}; redc_wide_inplacez(tmp); // after reduce_twopass, tmp's low TLC limbs should represent a value in [0, 2*mod) return reducez(tmp.get_lo()); } -//add - template static constexpr DEVICE_INLINE ff_storage mul_devicez(const ff_storage &xs, const ff_storage &ys) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + // add + template + static constexpr DEVICE_INLINE ff_storage mul_devicez(const ff_storage& xs, const ff_storage& ys) + { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes + // of slack static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); // printf(" "); ff_storage rs = {0}; @@ -773,8 +812,11 @@ template (rs); } - template static constexpr DEVICE_INLINE ff_storage sqr_devicez(const ff_storage &xs) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + template + static constexpr DEVICE_INLINE ff_storage sqr_devicez(const ff_storage& xs) + { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes + // of slack static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); ff_wide_storage rs = {0}; sqr_rawz(xs, rs); @@ -782,22 +824,16 @@ template (rs.get_lo()); } -//add - // return xs * ys with field operands - // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. - // Host path uses CIOS. - template static constexpr DEVICE_INLINE ff_storage mulz(const ff_storage &xs, const ff_storage &ys) { + // add + // return xs * ys with field operands + // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. + // Host path uses CIOS. + template + static constexpr DEVICE_INLINE ff_storage mulz(const ff_storage& xs, const ff_storage& ys) + { return mul_devicez(xs, ys); } - - - - - - - - static DEVICE_INLINE void mul_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC) { UNROLL @@ -941,7 +977,7 @@ template 1) { const uint32_t* a = as.limbs; @@ -1088,7 +1124,7 @@ template - static constexpr HOST_DEVICE_INLINE Field mont_reduce(const Wide& xs) - { - // Field xs_lo = Wide::get_lower(xs); - // Field xs_hi = Wide::get_higher(xs); - // Wide l1 = {}; - // Wide l2 = {}; - // host_math::template multiply_raw(xs_lo.limbs_storage, get_m(), l1.limbs_storage); - // Field l1_lo = Wide::get_lower(l1); - // host_math::template multiply_raw(l1_lo.limbs_storage, get_modulus<1>(), l2.limbs_storage); - // Field l2_hi = Wide::get_higher(l2); - // Field r = {}; - // add_limbs(l2_hi.limbs_storage, xs_hi.limbs_storage, r.limbs_storage); - - Field r = Wide::get_lower(xs); + static constexpr HOST_DEVICE_INLINE Field mont_reduce(const Wide& xs, bool get_higher_half = false) + { + // Field xs_lo = Wide::get_lower(xs); + // Field xs_hi = Wide::get_higher(xs); + // Wide l1 = {}; + // Wide l2 = {}; + // host_math::template multiply_raw(xs_lo.limbs_storage, get_m(), l1.limbs_storage); + // Field l1_lo = Wide::get_lower(l1); + // host_math::template multiply_raw(l1_lo.limbs_storage, get_modulus<1>(), l2.limbs_storage); + // Field l2_hi = Wide::get_higher(l2); + // Field r = {}; + // add_limbs(l2_hi.limbs_storage, xs_hi.limbs_storage, r.limbs_storage); + + Field r = get_higher_half ? Wide::get_higher(xs) : Wide::get_lower(xs); Field p = Field{get_modulus<1>()}; - if (p.limbs_storage.limbs[TLC-1] > r.limbs_storage.limbs[TLC-1]) - return r; + if (p.limbs_storage.limbs[TLC - 1] > r.limbs_storage.limbs[TLC - 1]) return r; ff_storage r_reduced = {}; uint64_t carry = 0; carry = sub_limbs(r.limbs_storage, get_modulus<1>(), r_reduced); @@ -1328,7 +1363,6 @@ template (xs.limbs_storage.limbs64, ys.limbs_storage.limbs64, get_mont_inv_modulus().limbs64, get_modulus<1>().limbs64, r.limbs_storage.limbs64); + host_math::multiply_mont_64( + xs.limbs_storage.limbs64, ys.limbs_storage.limbs64, get_mont_inv_modulus().limbs64, get_modulus<1>().limbs64, + r.limbs_storage.limbs64); return mont_reduce(r); } + + /** + * @brief Perform SOS reduction on a number in montgomery representation \p t in range [0,p^2-1] (p is the field's + * modulus) limiting it to the range [0,2p-1]. + * @param t Number to be reduced. Must be in montgomery rep, and in range [0,p^2-1]. + * @return \p t mod p + */ + static constexpr HOST_INLINE Field sos_mont_reduce(Wide& t) + { + Wide r = {}; + host_math::sos_mont_reduction_64( + t.limbs_storage.limbs64, get_modulus<1>().limbs64, get_mont_inv_modulus().limbs64, r.limbs_storage.limbs64); + return mont_reduce(r, /* get_higher_half = */ true); + } + #else // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) @@ -1716,11 +1767,17 @@ template static constexpr HOST_DEVICE Field neg(const Field& xs) @@ -1770,11 +1827,11 @@ template - static constexpr HOST_INLINE void multiply_mont_32(const uint32_t* a, const uint32_t* b, const uint32_t* q, const uint32_t* p, uint32_t* r) + static constexpr HOST_INLINE void + multiply_mont_32(const uint32_t* a, const uint32_t* b, const uint32_t* q, const uint32_t* p, uint32_t* r) { for (unsigned i = 0; i < NLIMBS_B; i++) { uint32_t A = 0, C = 0; r[0] = host_math::madc_cc(a[0], b[i], r[0], A); - uint32_t m = host_math::madc_cc(r[0], q[0], 0, C); //TODO - multiply inst + uint32_t m = host_math::madc_cc(r[0], q[0], 0, C); // TODO - multiply inst C = 0; host_math::madc_cc(m, p[0], r[0], C); for (unsigned j = 1; j < NLIMBS_A; j++) { @@ -246,8 +247,9 @@ namespace host_math { } } - template - static HOST_INLINE void multiply_mont_64(const uint64_t* a, const uint64_t* b, const uint64_t* q, const uint64_t* p, uint64_t* r) + template + static HOST_INLINE void + multiply_mont_64(const uint64_t* a, const uint64_t* b, const uint64_t* q, const uint64_t* p, uint64_t* r) { // printf("r0: "); // for (unsigned i = 0; i < NLIMBS_B / 2; i++) { @@ -262,7 +264,7 @@ namespace host_math { // printf("q0 %lu\n",q[0]); // printf("p0 %lu\n",p[0]); // printf("A %lu\n",A); - uint64_t m = host_math::madc_cc_64(r[0], q[0], 0, C); //TODO - multiply inst + uint64_t m = host_math::madc_cc_64(r[0], q[0], 0, C); // TODO - multiply inst // printf("m %lu\n",m); C = 0; host_math::madc_cc_64(m, p[0], r[0], C); @@ -280,6 +282,43 @@ namespace host_math { // printf("\n"); } + /** + * @brief Perform SOS reduction on a number in montgomery representation \p t in range [0, \p n ^2-1] limiting it to + * the range [0,2 \p n -1]. + * @param t Number to be reduced. Must be in montgomery rep, and in range [0, \p n ^2-1]. + * @param n Field modulus. + * @param n_tag Number such that \p n * \p n_tag modR = -1 + * @param r Array in which to store the result in its upper half (Lower half is data that would be removed by + * dividing by R = shifting NLIMBS down). + * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). + */ + template + static HOST_INLINE void + sos_mont_reduction_64(const uint64_t* t, const uint64_t* n, const uint64_t* n_tag, uint64_t* r) + { + const unsigned s = NLIMBS / 2; // Divide by 2 because NLIMBS is 32bit and this function is 64bit + + // Copy t to r as t is read-only + for (int i = 0; i < 2 * s; i++) { + r[i] = t[i]; + } + + for (int i = 0; i < s; i++) { + uint64_t c = 0; + uint64_t m = r[i] * n_tag[0]; + + for (int j = 0; j < s; j++) { + // r[i+j] = addc_cc(r[i+j], m * n[j], c); + r[i + j] = madc_cc_64(m, n[j], r[i + j], c); + } + // Propagate the carry to the remaining sublimbs + for (int carry_idx = s + i; carry_idx < 2 * s; carry_idx++) { + if (c == 0) { break; } + r[carry_idx] = add_cc(r[carry_idx], c, c); + } + } + } + template static HOST_INLINE void multiply_raw_64(const storage& as, const storage& bs, storage& rs) diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index 05e22fc46..961a58a61 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -480,6 +480,36 @@ TYPED_TEST(CurveSanity, u64Mul) END_TIMER(u64Mult_asm, "U64-MULT-asm", true); } +#ifndef BARRET +TYPED_TEST(CurveSanity, MontSosReduction) +{ + // SOS reduction currently only in CPU + if (s_ref_target == "CPU") { + const unsigned n = 1 << 10; + auto as = std::make_unique(n); + auto bs = std::make_unique(n); + auto abs = std::make_unique(n); + + scalar_t::rand_host_many(as.get(), n); + scalar_t::rand_host_many(bs.get(), n); + + icicle_set_device(Device{s_ref_target, 0}); + + START_TIMER(mont_sos_reduction); + for (int i = 0; i < n; i++) { + auto ab_no_mod = scalar_t::mul_wide(as[i], bs[i]); + abs[i] = scalar_t::sos_mont_reduce(ab_no_mod); + } + END_TIMER(mont_sos_reduction, "CPU-Montgomery SOS reduction", true); + + // Assert reduction in comparison with Montgomery multiplier + for (int i = 0; i < n; i++) { + ICICLE_ASSERT(abs[i] == as[i] * bs[i]); + } + } +} +#endif + int main(int argc, char** argv) { ::testing::InitGoogleTest(&argc, argv); From b2533cd967fa202be45183fa54c21e0c0f80689d Mon Sep 17 00:00:00 2001 From: Koren-Brand Date: Mon, 18 Nov 2024 13:45:13 +0200 Subject: [PATCH 10/22] EC addition and MSM on CPU modified to work with montgomery representation as well Signed-off-by: Koren-Brand --- icicle/backend/cpu/src/curve/cpu_msm.hpp | 21 ++++++++++-- icicle/include/icicle/curves/projective.h | 32 +++++++++++++++++++ .../include/icicle/fields/complex_extension.h | 4 +-- icicle/include/icicle/fields/field.h | 2 +- icicle/include/icicle/msm.h | 8 +++++ icicle/tests/test_curve_api.cpp | 6 +++- 6 files changed, 66 insertions(+), 7 deletions(-) diff --git a/icicle/backend/cpu/src/curve/cpu_msm.hpp b/icicle/backend/cpu/src/curve/cpu_msm.hpp index 7cfd1032d..129318795 100644 --- a/icicle/backend/cpu/src/curve/cpu_msm.hpp +++ b/icicle/backend/cpu/src/curve/cpu_msm.hpp @@ -451,9 +451,13 @@ void Msm::phase1_bucket_accumulator(const scalar_t* scalars, const A* base bool negate_p_and_s = scalar.get_scalar_digit(scalar_t::NBITS - 1, 1) > 0; if (negate_p_and_s) { scalar = scalar_t::neg(scalar); } for (int j = 0; j < m_precompute_factor; j++) { - // Handle required preprocess of base P + // Handle required preprocess of base P according to the version of Field/Ec adder (accepting Barret / Montgomery) A base = + #ifdef BARRET m_are_points_mont ? A::from_montgomery(bases[m_precompute_factor * i + j]) : bases[m_precompute_factor * i + j]; + #else + m_are_points_mont ? bases[m_precompute_factor * i + j] : A::to_montgomery(bases[m_precompute_factor * i + j]); + #endif if (base == A::zero()) { continue; } if (negate_p_and_s) { base = A::neg(base); } @@ -780,12 +784,23 @@ eIcicleError cpu_msm_precompute_bases( const unsigned int shift = c * ((num_bms_no_precomp - 1) / precompute_factor + 1); for (int i = 0; i < nof_bases; i++) { output_bases[precompute_factor * i] = input_bases[i]; - P point = P::from_affine(is_mont ? A::from_montgomery(input_bases[i]) : input_bases[i]); + // Handle required preprocess of base P according to the version of Field/Ec adder (accepting Barret / Montgomery) + P point = + #ifdef BARRET + P::from_affine(is_mont ? A::from_montgomery(input_bases[i]) : input_bases[i]); + #else + P::from_affine(is_mont ? input_bases[i] : A::to_montgomery(input_bases[i])); + #endif for (int j = 1; j < precompute_factor; j++) { for (int k = 0; k < shift; k++) { point = P::dbl(point); } - output_bases[precompute_factor * i + j] = is_mont ? A::to_montgomery(P::to_affine(point)) : P::to_affine(point); + output_bases[precompute_factor * i + j] = + #ifdef BARRET + is_mont ? A::to_montgomery(P::to_affine(point)) : P::to_affine(point); + #else + is_mont ? P::to_affine(point) : A::from_montgomery(P::to_affine(point)); + #endif } } return eIcicleError::SUCCESS; diff --git a/icicle/include/icicle/curves/projective.h b/icicle/include/icicle/curves/projective.h index 0a2feb323..86e47a7c7 100644 --- a/icicle/include/icicle/curves/projective.h +++ b/icicle/include/icicle/curves/projective.h @@ -47,7 +47,11 @@ class Projective return {FF::from_montgomery(point.x), FF::from_montgomery(point.y), FF::from_montgomery(point.z)}; } + #ifdef BARRET static HOST_DEVICE_INLINE Projective generator() { return {Gen::gen_x, Gen::gen_y, FF::one()}; } + #else + static HOST_DEVICE_INLINE Projective generator() { return {FF::to_montgomery(Gen::gen_x), FF::to_montgomery(Gen::gen_y), FF::one()}; } + #endif static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; } @@ -115,13 +119,25 @@ class Projective FF::template mul_unsigned<3>(FF::template mul_const(t17)); // t23 ← b3 · t17 < 2 const auto t24 = FF::mul_wide(t12, t23); // t24 ← t12 · t23 < 2 const auto t25 = FF::mul_wide(t07, t22); // t25 ← t07 · t22 < 2 + #ifdef BARRET const FF X3 = FF::reduce(t25 - t24); // X3 ← t25 − t24 < 2 + #else + const FF X3 = FF::sos_mont_reduce(t25 - t24); // X3 ← t25 − t24 < 2 + #endif const auto t27 = FF::mul_wide(t23, t19); // t27 ← t23 · t19 < 2 const auto t28 = FF::mul_wide(t22, t21); // t28 ← t22 · t21 < 2 + #ifdef BARRET const FF Y3 = FF::reduce(t28 + t27); // Y3 ← t28 + t27 < 2 + #else + const FF Y3 = FF::sos_mont_reduce(t28 + t27); // Y3 ← t28 + t27 < 2 + #endif const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 + #ifdef BARRET const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + #else + const FF Z3 = FF::sos_mont_reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + #endif return {X3, Y3, Z3}; } @@ -162,13 +178,25 @@ class Projective FF::template mul_unsigned<3>(FF::template mul_const(t17)); // t23 ← b3 · t17 < 2 const auto t24 = FF::mul_wide(t12, t23); // t24 ← t12 · t23 < 2 const auto t25 = FF::mul_wide(t07, t22); // t25 ← t07 · t22 < 2 + #ifdef BARRET const FF X3 = FF::reduce(t25 - t24); // X3 ← t25 − t24 < 2 + #else + const FF X3 = FF::sos_mont_reduce(t25 - t24); // X3 ← t25 − t24 < 2 + #endif const auto t27 = FF::mul_wide(t23, t19); // t27 ← t23 · t19 < 2 const auto t28 = FF::mul_wide(t22, t21); // t28 ← t22 · t21 < 2 + #ifdef BARRET const FF Y3 = FF::reduce(t28 + t27); // Y3 ← t28 + t27 < 2 + #else + const FF Y3 = FF::sos_mont_reduce(t28 + t27); // Y3 ← t28 + t27 < 2 + #endif const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 + #ifdef BARRET const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + #else + const FF Z3 = FF::sos_mont_reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + #endif // const auto t24 = FF::mul_widez(t12.limbs_storage, t23.limbs_storage); // t24 ← t12 · t23 < 2 // const auto t25 = FF::mul_widez(t07.limbs_storage, t22.limbs_storage); // t25 ← t07 · t22 < 2 // typename FF::Wide W3 = typename FF::Wide{t25} - typename FF::Wide{t24}; // X3 ← t25 − t24 < 2 @@ -194,6 +222,10 @@ class Projective friend HOST_DEVICE Projective operator*(SCALAR_FF scalar, const Projective& point) { + #ifndef BARRET + scalar = SCALAR_FF::from_montgomery(scalar); + #endif + // Precompute points: P, 2P, ..., (2^window_size - 1)P constexpr unsigned window_size = 4; // 4 seems fastest. Optimum is minimizing EC add and depends on the field size. for 256b it's 4. diff --git a/icicle/include/icicle/fields/complex_extension.h b/icicle/include/icicle/fields/complex_extension.h index a6d116a3a..4c1d00298 100644 --- a/icicle/include/icicle/fields/complex_extension.h +++ b/icicle/include/icicle/fields/complex_extension.h @@ -48,12 +48,12 @@ class ComplexExtensionField static constexpr HOST_DEVICE_INLINE ComplexExtensionField to_montgomery(const ComplexExtensionField& xs) { - return ComplexExtensionField{xs.real * FF{CONFIG::montgomery_r}, xs.imaginary * FF{CONFIG::montgomery_r}}; + return ComplexExtensionField{FF::to_montgomery(xs.real), FF::to_montgomery(xs.imaginary)}; } static constexpr HOST_DEVICE_INLINE ComplexExtensionField from_montgomery(const ComplexExtensionField& xs) { - return ComplexExtensionField{xs.real * FF{CONFIG::montgomery_r_inv}, xs.imaginary * FF{CONFIG::montgomery_r_inv}}; + return ComplexExtensionField{FF::from_montgomery(xs.real), FF::from_montgomery(xs.imaginary)}; } static HOST_INLINE ComplexExtensionField rand_host() diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index e36f07f2a..a3b6e1af0 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -1407,7 +1407,7 @@ class Field * @param t Number to be reduced. Must be in montgomery rep, and in range [0,p^2-1]. * @return \p t mod p */ - static constexpr HOST_INLINE Field sos_mont_reduce(Wide& t) + static constexpr HOST_INLINE Field sos_mont_reduce(const Wide& t) { Wide r = {}; host_math::sos_mont_reduction_64( diff --git a/icicle/include/icicle/msm.h b/icicle/include/icicle/msm.h index b146074b1..e73de1ac6 100644 --- a/icicle/include/icicle/msm.h +++ b/icicle/include/icicle/msm.h @@ -67,9 +67,17 @@ namespace icicle { 1, // batch_size true, // are_points_shared_in_batch false, // are_scalars_on_device + #ifdef BARRET false, // are_scalars_montgomery_form + #else + true, // are_scalars_montgomery_form + #endif false, // are_points_on_device + #ifdef BARRET false, // are_points_montgomery_form + #else + true, // are_points_montgomery_form + #endif false, // are_results_on_device false, // is_async nullptr, // ext diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index 961a58a61..0914e17fb 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -368,12 +368,16 @@ TYPED_TEST(CurveSanity, CurveSanityTest) TYPED_TEST(CurveSanity, ScalarMultTest) { const auto point = TypeParam::rand_host(); - const auto scalar = scalar_t::rand_host(); + auto scalar = scalar_t::rand_host(); START_TIMER(main) const auto mult = scalar * point; END_TIMER(main, "scalar mult window method", true); + #ifndef BARRET + scalar = scalar_t::from_montgomery(scalar); + #endif + auto expected_mult = TypeParam::zero(); START_TIMER(ref) for (int i = 0; i < scalar_t::NBITS; i++) { From c979161e1d3c0b4d1f2c92858ebf4b8543398a29 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Mon, 18 Nov 2024 14:50:54 +0200 Subject: [PATCH 11/22] split device math from field --- icicle/include/icicle/curves/projective.h | 29 +- icicle/include/icicle/fields/device_math.h | 657 +++++++++++++ icicle/include/icicle/fields/field.h | 866 +----------------- icicle/include/icicle/fields/params_gen.h | 2 +- .../fields/snark_fields/bls12_377_base.h | 2 +- .../fields/snark_fields/bls12_377_scalar.h | 2 +- .../fields/snark_fields/bls12_381_scalar.h | 2 +- .../icicle/fields/snark_fields/bn254_scalar.h | 2 +- .../icicle/fields/stark_fields/babybear.h | 2 +- .../icicle/fields/stark_fields/stark252.h | 2 +- icicle/tests/test_curve_api.cpp | 14 +- icicle/tests/test_device_api.cpp | 5 + icicle/tests/test_field_api.cpp | 31 +- icicle/tests/test_hash_api.cpp | 5 + icicle/tests/test_polynomial_api.cpp | 5 + 15 files changed, 746 insertions(+), 880 deletions(-) create mode 100644 icicle/include/icicle/fields/device_math.h diff --git a/icicle/include/icicle/curves/projective.h b/icicle/include/icicle/curves/projective.h index 0a2feb323..a2470b7f7 100644 --- a/icicle/include/icicle/curves/projective.h +++ b/icicle/include/icicle/curves/projective.h @@ -46,8 +46,11 @@ class Projective { return {FF::from_montgomery(point.x), FF::from_montgomery(point.y), FF::from_montgomery(point.z)}; } - +#ifdef BARRET static HOST_DEVICE_INLINE Projective generator() { return {Gen::gen_x, Gen::gen_y, FF::one()}; } + #else + static HOST_DEVICE_INLINE Projective generator() { return {FF::to_montgomery(Gen::gen_x), FF::to_montgomery(Gen::gen_y), FF::one()}; } + #endif static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; } @@ -160,15 +163,7 @@ class Projective const FF t22 = t01 - t20; // t22 ← t01 − t20 < 2 const FF t23 = FF::template mul_unsigned<3>(FF::template mul_const(t17)); // t23 ← b3 · t17 < 2 - const auto t24 = FF::mul_wide(t12, t23); // t24 ← t12 · t23 < 2 - const auto t25 = FF::mul_wide(t07, t22); // t25 ← t07 · t22 < 2 - const FF X3 = FF::reduce(t25 - t24); // X3 ← t25 − t24 < 2 - const auto t27 = FF::mul_wide(t23, t19); // t27 ← t23 · t19 < 2 - const auto t28 = FF::mul_wide(t22, t21); // t28 ← t22 · t21 < 2 - const FF Y3 = FF::reduce(t28 + t27); // Y3 ← t28 + t27 < 2 - const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 - const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 - const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + // #ifdef __CUDA_ARCH__ // const auto t24 = FF::mul_widez(t12.limbs_storage, t23.limbs_storage); // t24 ← t12 · t23 < 2 // const auto t25 = FF::mul_widez(t07.limbs_storage, t22.limbs_storage); // t25 ← t07 · t22 < 2 // typename FF::Wide W3 = typename FF::Wide{t25} - typename FF::Wide{t24}; // X3 ← t25 − t24 < 2 @@ -184,6 +179,17 @@ class Projective // W3 = typename FF::Wide{t31} + typename FF::Wide{t30}; // Z3 ← t31 + t30 < 2 // FF::redc_wide_inplacez(W3.limbs_storage); // Z3 ← t31 + t30 < 2 // const auto Z3 = FF::Wide::get_lower(W3); + // #else + const auto t24 = FF::mul_wide(t12, t23); // t24 ← t12 · t23 < 2 + const auto t25 = FF::mul_wide(t07, t22); // t25 ← t07 · t22 < 2 + const FF X3 = FF::reduce(t25 - t24); // X3 ← t25 − t24 < 2 + const auto t27 = FF::mul_wide(t23, t19); // t27 ← t23 · t19 < 2 + const auto t28 = FF::mul_wide(t22, t21); // t28 ← t22 · t21 < 2 + const FF Y3 = FF::reduce(t28 + t27); // Y3 ← t28 + t27 < 2 + const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 + const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 + const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 + // #endif return {X3, Y3, Z3}; } @@ -194,6 +200,9 @@ class Projective friend HOST_DEVICE Projective operator*(SCALAR_FF scalar, const Projective& point) { + #ifndef BARRET + scalar = SCALAR_FF::from_montgomery(scalar); + #endif // Precompute points: P, 2P, ..., (2^window_size - 1)P constexpr unsigned window_size = 4; // 4 seems fastest. Optimum is minimizing EC add and depends on the field size. for 256b it's 4. diff --git a/icicle/include/icicle/fields/device_math.h b/icicle/include/icicle/fields/device_math.h new file mode 100644 index 000000000..fba40f399 --- /dev/null +++ b/icicle/include/icicle/fields/device_math.h @@ -0,0 +1,657 @@ +#ifdef __CUDACC__ + +#pragma once + +#include +#include "icicle/utils/modifiers.h" +#include "icicle/fields/storage.h" +#include "ptx.h" + +namespace device_math { + +template +struct carry_chain { + unsigned index; + + constexpr __device__ __forceinline__ carry_chain() : index(0) {} + + __device__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::add(x, y); + else if (index == 1 && !CARRY_IN) + return ptx::add_cc(x, y); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::addc_cc(x, y); + else + return ptx::addc(x, y); + } + + __device__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::sub(x, y); + else if (index == 1 && !CARRY_IN) + return ptx::sub_cc(x, y); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::subc_cc(x, y); + else + return ptx::subc(x, y); + } + + __device__ __forceinline__ uint32_t mad_lo(const uint32_t x, const uint32_t y, const uint32_t z) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::mad_lo(x, y, z); + else if (index == 1 && !CARRY_IN) + return ptx::mad_lo_cc(x, y, z); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::madc_lo_cc(x, y, z); + else + return ptx::madc_lo(x, y, z); + } + + __device__ __forceinline__ uint32_t mad_hi(const uint32_t x, const uint32_t y, const uint32_t z) { + index++; + if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) + return ptx::mad_hi(x, y, z); + else if (index == 1 && !CARRY_IN) + return ptx::mad_hi_cc(x, y, z); + else if (index < OPS_COUNT || CARRY_OUT) + return ptx::madc_hi_cc(x, y, z); + else + return ptx::madc_hi(x, y, z); + } +}; + +template + static constexpr DEVICE_INLINE uint32_t add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r) + { + r[0] = SUBTRACT ? ptx::sub_cc(x[0], y[0]) : ptx::add_cc(x[0], y[0]); + for (unsigned i = 1; i < NLIMBS; i++) + r[i] = SUBTRACT ? ptx::subc_cc(x[i], y[i]) : ptx::addc_cc(x[i], y[i]); + if (!CARRY_OUT) { + ptx::addc(0, 0); + return 0; + } + return SUBTRACT ? ptx::subc(0, 0) : ptx::addc(0, 0); + } + + template + static constexpr DEVICE_INLINE uint32_t + add_sub_limbs_device(const storage& xs, const storage& ys, storage& rs) + { + const uint32_t* x = xs.limbs; + const uint32_t* y = ys.limbs; + uint32_t* r = rs.limbs; + return add_sub_u32_device(x, y, r); + } + + template + static DEVICE_INLINE void mul_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) + { + UNROLL + for (size_t i = 0; i < n; i += 2) { + acc[i] = ptx::mul_lo(a[i], bi); + acc[i + 1] = ptx::mul_hi(a[i], bi); + } + } + + template + static DEVICE_INLINE void mul_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS, size_t start_i = 0) + { + UNROLL + for (size_t i = start_i; i < n; i += 2) { + acc[i] = ptx::mul_lo(a[i], bi); + acc[i + 1] = ptx::mul_hi(a[i], bi); + } + } + + template + static DEVICE_INLINE void + cmad_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS, uint32_t optional_carry = 0) + { + if (CARRY_IN) ptx::add_cc(UINT32_MAX, optional_carry); + acc[0] = CARRY_IN ? ptx::madc_lo_cc(a[0], bi, acc[0]) : ptx::mad_lo_cc(a[0], bi, acc[0]); + acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + + UNROLL + for (size_t i = 2; i < n; i += 2) { + acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); + acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); + } + } + + template + static DEVICE_INLINE void cmad_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) + { + if (EVEN_PHASE) { + acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); + acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + } else { + acc[1] = ptx::mad_hi_cc(a[0], bi, acc[1]); + } + + UNROLL + for (size_t i = 2; i < n; i += 2) { + acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); + acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); + } + } + + template + static DEVICE_INLINE void cmad_n_lsb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) + { + if (n > 1) + acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); + else + acc[0] = ptx::mad_lo(a[0], bi, acc[0]); + + size_t i; + UNROLL + for (i = 1; i < n - 1; i += 2) { + acc[i] = ptx::madc_hi_cc(a[i - 1], bi, acc[i]); + if (i == n - 2) + acc[i + 1] = ptx::madc_lo(a[i + 1], bi, acc[i + 1]); + else + acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, acc[i + 1]); + } + if (i == n - 1) acc[i] = ptx::madc_hi(a[i - 1], bi, acc[i]); + } + + template + static DEVICE_INLINE uint32_t mad_row( + uint32_t* odd, + uint32_t* even, + const uint32_t* a, + uint32_t bi, + size_t n = NLIMBS, + uint32_t ci = 0, + uint32_t di = 0, + uint32_t carry_for_high = 0, + uint32_t carry_for_low = 0) + { + cmad_n(odd, a + 1, bi, n - 2, carry_for_low); + odd[n - 2] = ptx::madc_lo_cc(a[n - 1], bi, ci); + odd[n - 1] = CARRY_OUT ? ptx::madc_hi_cc(a[n - 1], bi, di) : ptx::madc_hi(a[n - 1], bi, di); + uint32_t cr = CARRY_OUT ? ptx::addc(0, 0) : 0; + cmad_n(even, a, bi, n); + if (CARRY_OUT) { + odd[n - 1] = ptx::addc_cc(odd[n - 1], carry_for_high); + cr = ptx::addc(cr, 0); + } else + odd[n - 1] = ptx::addc(odd[n - 1], carry_for_high); + return cr; + } + + template + static DEVICE_INLINE void mad_row_msb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) + { + cmad_n_msb(odd, EVEN_PHASE ? a : (a + 1), bi, n - 2); + odd[EVEN_PHASE ? (n - 1) : (n - 2)] = ptx::madc_lo_cc(a[n - 1], bi, 0); + odd[EVEN_PHASE ? n : (n - 1)] = ptx::madc_hi(a[n - 1], bi, 0); + cmad_n_msb(even, EVEN_PHASE ? (a + 1) : a, bi, n - 1); + odd[EVEN_PHASE ? n : (n - 1)] = ptx::addc(odd[EVEN_PHASE ? n : (n - 1)], 0); + } + + template + static DEVICE_INLINE void mad_row_lsb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) + { + // bi here is constant so we can do a compile-time check for zero (which does happen once for bls12-381 scalar field + // modulus) + if (bi != 0) { + if (n > 1) cmad_n_lsb(odd, a + 1, bi, n - 1); + cmad_n_lsb(even, a, bi, n); + } + return; + } + + template + static DEVICE_INLINE uint32_t + mul_n_and_add(uint32_t* acc, const uint32_t* a, uint32_t bi, uint32_t* extra, size_t n = (NLIMBS >> 1)) + { + acc[0] = ptx::mad_lo_cc(a[0], bi, extra[0]); + + UNROLL + for (size_t i = 1; i < n - 1; i += 2) { + acc[i] = ptx::madc_hi_cc(a[i - 1], bi, extra[i]); + acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, extra[i + 1]); + } + + acc[n - 1] = ptx::madc_hi_cc(a[n - 2], bi, extra[n - 1]); + return ptx::addc(0, 0); + } + + /** + * This method multiplies `a` and `b` (both assumed to have NLIMBS / 2 limbs) and adds `in1` and `in2` (NLIMBS limbs each) + * to the result which is written to `even`. + * + * It is used to compute the "middle" part of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} = + * (a_{hi} - a_{lo})(b_{lo} - b_{hi}) + a_{lo} \cdot b_{lo} + a_{hi} \cdot b_{hi} \f$. Currently this method assumes + * that the top bit of \f$ a_{hi} \f$ and \f$ b_{hi} \f$ are unset. This ensures correctness by allowing to keep the + * result inside NLIMBS limbs and ignore the carries from the highest limb. + */ + template + static DEVICE_INLINE void + multiply_and_add_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even, uint32_t* in1, uint32_t* in2) + { + __align__(16) uint32_t odd[NLIMBS - 2]; + uint32_t first_row_carry = mul_n_and_add(even, a, b[0], in1); + uint32_t carry = mul_n_and_add(odd, a + 1, b[0], &in2[1]); + + size_t i; + UNROLL + for (i = 2; i < ((NLIMBS >> 1) - 1); i += 2) { + carry = mad_row( + &even[i], &odd[i - 2], a, b[i - 1], NLIMBS >> 1, in1[(NLIMBS >> 1) + i - 2], in1[(NLIMBS >> 1) + i - 1], carry); + carry = + mad_row(&odd[i], &even[i], a, b[i], NLIMBS >> 1, in2[(NLIMBS >> 1) + i - 1], in2[(NLIMBS >> 1) + i], carry); + } + mad_row( + &even[NLIMBS >> 1], &odd[(NLIMBS >> 1) - 2], a, b[(NLIMBS >> 1) - 1], NLIMBS >> 1, in1[NLIMBS - 2], in1[NLIMBS - 1], carry, + first_row_carry); + // merge |even| and |odd| plus the parts of `in2` we haven't added yet (first and last limbs) + even[0] = ptx::add_cc(even[0], in2[0]); + for (i = 0; i < (NLIMBS - 2); i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], in2[i + 1]); + } + + + + /** + * This method multiplies `a` and `b` and writes the result into `even`. It assumes that `a` and `b` are NLIMBS/2 limbs + * long. The usual schoolbook algorithm is used. + */ + template + static DEVICE_INLINE void multiply_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even) + { + __align__(16) uint32_t odd[NLIMBS - 2]; + mul_n(even, a, b[0], NLIMBS >> 1); + mul_n(odd, a + 1, b[0], NLIMBS >> 1); + mad_row(&even[2], &odd[0], a, b[1], NLIMBS >> 1); + + size_t i; + UNROLL + for (i = 2; i < ((NLIMBS >> 1) - 1); i += 2) { + mad_row(&odd[i], &even[i], a, b[i], NLIMBS >> 1); + mad_row(&even[i + 2], &odd[i], a, b[i + 1], NLIMBS >> 1); + } + // merge |even| and |odd| + even[1] = ptx::add_cc(even[1], odd[0]); + for (i = 1; i < NLIMBS - 2; i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], 0); + } + + /** + * This method multiplies `as` and `bs` and writes the (wide) result into `rs`. + * + * It is assumed that the highest bits of `as` and `bs` are unset which is true for all the numbers icicle had to deal + * with so far. This method implements [subtractive + * Karatsuba](https://en.wikipedia.org/wiki/Karatsuba_algorithm#Implementation). + */ + template + static DEVICE_INLINE void multiply_raw_device(const storage& as, const storage& bs, storage<2*NLIMBS>& rs) + { + const uint32_t* a = as.limbs; + const uint32_t* b = bs.limbs; + uint32_t* r = rs.limbs; + if constexpr (NLIMBS > 2) { + // Next two lines multiply high and low halves of operands (\f$ a_{lo} \cdot b_{lo}; a_{hi} \cdot b_{hi} \$f) and + // write the results into `r`. + multiply_short_raw_device(a, b, r); + multiply_short_raw_device(&a[NLIMBS >> 1], &b[NLIMBS >> 1], &r[NLIMBS]); + __align__(16) uint32_t middle_part[NLIMBS]; + __align__(16) uint32_t diffs[NLIMBS]; + // Differences of halves \f$ a_{hi} - a_{lo}; b_{lo} - b_{hi} \$f are written into `diffs`, signs written to + // `carry1` and `carry2`. + uint32_t carry1 = add_sub_u32_device<(NLIMBS >> 1), true, true>(&a[NLIMBS >> 1], a, diffs); + uint32_t carry2 = add_sub_u32_device<(NLIMBS >> 1), true, true>(b, &b[NLIMBS >> 1], &diffs[NLIMBS >> 1]); + // Compute the "middle part" of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} \f$. + // This is where the assumption about unset high bit of `a` and `b` is relevant. + multiply_and_add_short_raw_device(diffs, &diffs[NLIMBS >> 1], middle_part, r, &r[NLIMBS]); + // Corrections that need to be performed when differences are negative. + // Again, carry doesn't need to be propagated due to unset high bits of `a` and `b`. + if (carry1) + add_sub_u32_device<(NLIMBS >> 1), true, false>(&middle_part[NLIMBS >> 1], &diffs[NLIMBS >> 1], &middle_part[NLIMBS >> 1]); + if (carry2) add_sub_u32_device<(NLIMBS >> 1), true, false>(&middle_part[NLIMBS >> 1], diffs, &middle_part[NLIMBS >> 1]); + // Now that middle part is fully correct, it can be added to the result. + add_sub_u32_device(&r[NLIMBS >> 1], middle_part, &r[NLIMBS >> 1]); + + // Carry from adding middle part has to be propagated to the highest limb. + for (size_t i = NLIMBS + (NLIMBS >> 1); i < 2 * NLIMBS; i++) + r[i] = ptx::addc_cc(r[i], 0); + } else if (NLIMBS == 2) { + __align__(8) uint32_t odd[2]; + r[0] = ptx::mul_lo(a[0], b[0]); + r[1] = ptx::mul_hi(a[0], b[0]); + r[2] = ptx::mul_lo(a[1], b[1]); + r[3] = ptx::mul_hi(a[1], b[1]); + odd[0] = ptx::mul_lo(a[0], b[1]); + odd[1] = ptx::mul_hi(a[0], b[1]); + odd[0] = ptx::mad_lo(a[1], b[0], odd[0]); + odd[1] = ptx::mad_hi(a[1], b[0], odd[1]); + r[1] = ptx::add_cc(r[1], odd[0]); + r[2] = ptx::addc_cc(r[2], odd[1]); + r[3] = ptx::addc(r[3], 0); + } else if (NLIMBS == 1) { + r[0] = ptx::mul_lo(a[0], b[0]); + r[1] = ptx::mul_hi(a[0], b[0]); + } + } + + /** + * A function that computes wide product \f$ rs = as \cdot bs \f$ that's correct for the higher NLIMBS + 1 limbs with a + * small maximum error. + * + * The way this function saves computations (as compared to regular school-book multiplication) is by not including + * terms that are too small. Namely, limb product \f$ a_i \cdot b_j \f$ is excluded if \f$ i + j < NLIMBS - 2 \f$ and + * only the higher half is included if \f$ i + j = NLIMBS - 2 \f$. All other limb products are included. So, the error + * i.e. difference between true product and the result of this function written to `rs` is exactly the sum of all + * dropped limbs products, which we can bound: \f$ a_0 \cdot b_0 + 2^{32}(a_0 \cdot b_1 + a_1 \cdot b_0) + \dots + + * 2^{32(NLIMBS - 3)}(a_{NLIMBS - 3} \cdot b_0 + \dots + a_0 \cdot b_{NLIMBS - 3}) + 2^{32(NLIMBS - 2)}(\floor{\frac{a_{NLIMBS - 2} + * \cdot b_0}{2^{32}}} + \dots + \floor{\frac{a_0 \cdot b_{NLIMBS - 2}}{2^{32}}}) \leq 2^{64} + 2\cdot 2^{96} + \dots + + * (NLIMBS - 2) \cdot 2^{32(NLIMBS - 1)} + (NLIMBS - 1) \cdot 2^{32(NLIMBS - 1)} \leq 2(NLIMBS - 1) \cdot 2^{32(NLIMBS - 1)}\f$. + */ + template + static DEVICE_INLINE void multiply_msb_raw_device(const storage& as, const storage& bs, storage<2*NLIMBS>& rs) + { + if constexpr (NLIMBS > 1) { + const uint32_t* a = as.limbs; + const uint32_t* b = bs.limbs; + uint32_t* even = rs.limbs; + __align__(16) uint32_t odd[2 * NLIMBS - 2]; + + even[NLIMBS - 1] = ptx::mul_hi(a[NLIMBS - 2], b[0]); + odd[NLIMBS - 2] = ptx::mul_lo(a[NLIMBS - 1], b[0]); + odd[NLIMBS - 1] = ptx::mul_hi(a[NLIMBS - 1], b[0]); + size_t i; + UNROLL + for (i = 2; i < NLIMBS - 1; i += 2) { + mad_row_msb(&even[NLIMBS - 2], &odd[NLIMBS - 2], &a[NLIMBS - i - 1], b[i - 1], i + 1); + mad_row_msb(&odd[NLIMBS - 2], &even[NLIMBS - 2], &a[NLIMBS - i - 2], b[i], i + 2); + } + mad_row(&even[NLIMBS], &odd[NLIMBS - 2], a, b[NLIMBS - 1]); + + // merge |even| and |odd| + ptx::add_cc(even[NLIMBS - 1], odd[NLIMBS - 2]); + for (i = NLIMBS - 1; i < 2 * NLIMBS - 2; i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], 0); + } else { + multiply_raw_device(as, bs, rs); + } + } + + /** + * A function that computes the low half of the fused multiply-and-add \f$ rs = as \cdot bs + cs \f$ where + * \f$ bs = 2^{32*nof_limbs} \f$. + * + * For efficiency, this method does not include terms that are too large. Namely, limb product \f$ a_i \cdot b_j \f$ + * is excluded if \f$ i + j > NLIMBS - 1 \f$ and only the lower half is included if \f$ i + j = NLIMBS - 1 \f$. All other + * limb products are included. + */ + template + static DEVICE_INLINE void + multiply_and_add_lsb_neg_modulus_raw_device(const storage& as, const storage& bs, storage& cs, storage& rs) + { + const uint32_t* a = as.limbs; + const uint32_t* b = bs.limbs; + uint32_t* c = cs.limbs; + uint32_t* even = rs.limbs; + + if constexpr (NLIMBS > 2) { + __align__(16) uint32_t odd[NLIMBS - 1]; + size_t i; + // `b[0]` is \f$ 2^{32} \f$ minus the last limb of prime modulus. Because most scalar (and some base) primes + // are necessarily NTT-friendly, `b[0]` often turns out to be \f$ 2^{32} - 1 \f$. This actually leads to + // less efficient SASS generated by nvcc, so this case needed separate handling. + if (b[0] == UINT32_MAX) { + add_sub_u32_device(c, a, even); + for (i = 0; i < NLIMBS - 1; i++) + odd[i] = a[i]; + } else { + mul_n_and_add(even, a, b[0], c, NLIMBS); + mul_n(odd, a + 1, b[0], NLIMBS - 1); + } + mad_row_lsb(&even[2], &odd[0], a, b[1], NLIMBS - 1); + UNROLL + for (i = 2; i < NLIMBS - 1; i += 2) { + mad_row_lsb(&odd[i], &even[i], a, b[i], NLIMBS - i); + mad_row_lsb(&even[i + 2], &odd[i], a, b[i + 1], NLIMBS - i - 1); + } + + // merge |even| and |odd| + even[1] = ptx::add_cc(even[1], odd[0]); + for (i = 1; i < NLIMBS - 2; i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], odd[i]); + } else if (NLIMBS == 2) { + even[0] = ptx::mad_lo(a[0], b[0], c[0]); + even[1] = ptx::mad_hi(a[0], b[0], c[0]); + even[1] = ptx::mad_lo(a[0], b[1], even[1]); + even[1] = ptx::mad_lo(a[1], b[0], even[1]); + } else if (NLIMBS == 1) { + even[0] = ptx::mad_lo(a[0], b[0], c[0]); + } + } + + + // The following algorithms are adaptations of + // http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, + // taken from https://github.com/z-prize/test-msm-gpu (under Apache 2.0 license) + // and modified to use our datatypes. + // We had our own implementation of http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, + // but the sppark versions achieved lower instruction count thanks to clever carry handling, + // so we decided to just use theirs. + template + static DEVICE_INLINE void madc_n_rshift(uint32_t *odd, const uint32_t *a, uint32_t bi) { +#pragma unroll + for (size_t i = 0; i < NLIMBS - 2; i += 2) { + odd[i] = ptx::madc_lo_cc(a[i], bi, odd[i + 2]); + odd[i + 1] = ptx::madc_hi_cc(a[i], bi, odd[i + 3]); + } + odd[NLIMBS - 2] = ptx::madc_lo_cc(a[NLIMBS - 2], bi, 0); + odd[NLIMBS - 1] = ptx::madc_hi(a[NLIMBS - 2], bi, 0); + } + + template + static DEVICE_INLINE void mad_n_redc(uint32_t *even, uint32_t *odd, const uint32_t *a, uint32_t bi, const uint32_t *modulus, const uint32_t *mont_inv_modulus, bool first = false) { + if (first) { + mul_n(odd, a + 1, bi); + mul_n(even, a, bi); + } else { + even[0] = ptx::add_cc(even[0], odd[1]); + madc_n_rshift(odd, a + 1, bi); + cmad_n(even, a, bi); + odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); + } + uint32_t mi = even[0] * mont_inv_modulus[0]; + cmad_n(odd, modulus + 1, mi); + cmad_n(even, modulus, mi); + odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); + } + + template + static DEVICE_INLINE void qad_row(uint32_t *odd, uint32_t *even, const uint32_t *a, uint32_t bi, size_t n = NLIMBS) { + cmad_n(odd, a, bi, n - 2); + odd[n - 2] = ptx::madc_lo_cc(a[n - 2], bi, 0); + odd[n - 1] = ptx::madc_hi(a[n - 2], bi, 0); + cmad_n(even, a + 1, bi, n - 2); + odd[n - 1] = ptx::addc(odd[n - 1], 0); + } + + //TODO: test if beeter than karatsuba + template + static DEVICE_INLINE void multiply_raw_sb(const storage &as, const storage &bs, storage<2*NLIMBS> &rs) { + const uint32_t *a = as.limbs; + const uint32_t *b = bs.limbs; + uint32_t *even = rs.limbs; + __align__(8) uint32_t odd[2 * NLIMBS - 2]; + mul_n(even, a, b[0]); + mul_n(odd, a + 1, b[0]); + mad_row(&even[2], &odd[0], a, b[1]); + size_t i; +#pragma unroll + for (i = 2; i < NLIMBS - 1; i += 2) { + mad_row(&odd[i], &even[i], a, b[i]); + mad_row(&even[i + 2], &odd[i], a, b[i + 1]); + } + // merge |even| and |odd| + even[1] = ptx::add_cc(even[1], odd[0]); + for (i = 1; i < 2 * NLIMBS - 2; i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], 0); + } + + template + static DEVICE_INLINE void sqr_raw(const storage &as, storage<2*NLIMBS> &rs) { + const uint32_t *a = as.limbs; + uint32_t *even = rs.limbs; + size_t i = 0, j; + __align__(8) uint32_t odd[2 * NLIMBS - 2]; + + // perform |a[i]|*|a[j]| for all j>i + mul_n(even + 2, a + 2, a[0], NLIMBS - 2); + mul_n(odd, a + 1, a[0], NLIMBS); + +#pragma unroll + while (i < NLIMBS - 4) { + ++i; + mad_row(&even[2 * i + 2], &odd[2 * i], &a[i + 1], a[i], NLIMBS - i - 1); + ++i; + qad_row(&odd[2 * i], &even[2 * i + 2], &a[i + 1], a[i], NLIMBS - i); + } + + even[2 * NLIMBS - 4] = ptx::mul_lo(a[NLIMBS - 1], a[NLIMBS - 3]); + even[2 * NLIMBS - 3] = ptx::mul_hi(a[NLIMBS - 1], a[NLIMBS - 3]); + odd[2 * NLIMBS - 6] = ptx::mad_lo_cc(a[NLIMBS - 2], a[NLIMBS - 3], odd[2 * NLIMBS - 6]); + odd[2 * NLIMBS - 5] = ptx::madc_hi_cc(a[NLIMBS - 2], a[NLIMBS - 3], odd[2 * NLIMBS - 5]); + even[2 * NLIMBS - 3] = ptx::addc(even[2 * NLIMBS - 3], 0); + + odd[2 * NLIMBS - 4] = ptx::mul_lo(a[NLIMBS - 1], a[NLIMBS - 2]); + odd[2 * NLIMBS - 3] = ptx::mul_hi(a[NLIMBS - 1], a[NLIMBS - 2]); + + // merge |even[2:]| and |odd[1:]| + even[2] = ptx::add_cc(even[2], odd[1]); + for (j = 2; j < 2 * NLIMBS - 3; j++) + even[j + 1] = ptx::addc_cc(even[j + 1], odd[j]); + even[j + 1] = ptx::addc(odd[j], 0); + + // double |even| + even[0] = 0; + even[1] = ptx::add_cc(odd[0], odd[0]); + for (j = 2; j < 2 * NLIMBS - 1; j++) + even[j] = ptx::addc_cc(even[j], even[j]); + even[j] = ptx::addc(0, 0); + + // accumulate "diagonal" |a[i]|*|a[i]| product + i = 0; + even[2 * i] = ptx::mad_lo_cc(a[i], a[i], even[2 * i]); + even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); + for (++i; i < NLIMBS; i++) { + even[2 * i] = ptx::madc_lo_cc(a[i], a[i], even[2 * i]); + even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); + } + } + + template + static DEVICE_INLINE void mul_by_1_row(uint32_t *even, uint32_t *odd, const uint32_t *modulus, const uint32_t *mont_inv_modulus, bool first = false) { + uint32_t mi; + if (first) { + mi = even[0] * mont_inv_modulus[0]; + mul_n(odd, modulus + 1, mi); + cmad_n(even, modulus, mi); + odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); + } else { + even[0] = ptx::add_cc(even[0], odd[1]); + // we trust the compiler to *not* touch the carry flag here + // this code sits in between two "asm volatile" instructions witch should guarantee that nothing else interferes wit the carry flag + mi = even[0] * mont_inv_modulus[0]; + madc_n_rshift(odd, modulus + 1, mi); + cmad_n(even, modulus, mi); + odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); + } + } + + // Performs Montgomery reduction on a storage<2*NLIMBS> input. Input value must be in the range [0, mod*2^(32*NLIMBS)). + // Does not implement an in-place reduce epilogue. If you want to further reduce the result, + // call reduce(xs.get_lo()) after the call to redc_wide_inplace. + template + static DEVICE_INLINE void reduce_mont_inplace(storage<2*NLIMBS> &xs, const storage &modulus, const storage &mont_inv_modulus) { + uint32_t *even = xs.limbs; + // Yields montmul of lo NLIMBS limbs * 1. + // Since the hi NLIMBS limbs don't participate in computing the "mi" factor at each mul-and-rightshift stage, + // it's ok to ignore the hi NLIMBS limbs during this process and just add them in afterward. + uint32_t odd[NLIMBS]; + size_t i; +#pragma unroll + for (i = 0; i < NLIMBS; i += 2) { + mul_by_1_row(&even[0], &odd[0], modulus.limbs, mont_inv_modulus.limbs, i == 0); + mul_by_1_row(&odd[0], &even[0], modulus.limbs, mont_inv_modulus.limbs); + } + even[0] = ptx::add_cc(even[0], odd[1]); +#pragma unroll + for (i = 1; i < NLIMBS - 1; i++) + even[i] = ptx::addc_cc(even[i], odd[i + 1]); + even[i] = ptx::addc(even[i], 0); + // Adds in (hi NLIMBS limbs), implicitly right-shifting them by NLIMBS limbs as if they had participated in the + // add-and-rightshift stages above. + xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[NLIMBS]); +#pragma unroll + for (i = 1; i < NLIMBS - 1; i++) + xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + NLIMBS]); + xs.limbs[NLIMBS - 1] = ptx::addc(xs.limbs[NLIMBS - 1], xs.limbs[2 * NLIMBS - 1]); + } + + template + static DEVICE_INLINE void montmul_raw(const storage &a_in, const storage &b_in, const storage &modulus, const storage &mont_inv_modulus, storage &r_in) { + const uint32_t *a = a_in.limbs; + const uint32_t *b = b_in.limbs; + uint32_t *even = r_in.limbs; + __align__(8) uint32_t odd[NLIMBS + 1]; + size_t i; +#pragma unroll + for (i = 0; i < NLIMBS; i += 2) { + mad_n_redc(&even[0], &odd[0], a, b[i], modulus.limbs, mont_inv_modulus.limbs, i == 0); + mad_n_redc(&odd[0], &even[0], a, b[i + 1], modulus.limbs, mont_inv_modulus.limbs); + } + // merge |even| and |odd| + even[0] = ptx::add_cc(even[0], odd[1]); +#pragma unroll + for (i = 1; i < NLIMBS - 1; i++) + even[i] = ptx::addc_cc(even[i], odd[i + 1]); + even[i] = ptx::addc(even[i], 0); + // final reduction from [0, 2*mod) to [0, mod) not done here, instead performed optionally in mul_device wrapper + } + + + // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. + template static constexpr DEVICE_INLINE storage mulmont_device(const storage &xs, const storage &ys, const storage &modulus, const storage &mont_inv_modulus) { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + // static_assert(!(CONFIG::modulus.limbs[NLIMBS - 1] >> 30)); + // printf(" "); + storage rs = {0}; + montmul_raw(xs, ys, modulus, mont_inv_modulus, rs); + return rs; + } + + template static constexpr DEVICE_INLINE storage sqrmont_device(const storage &xs) { + // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack + // static_assert(!(CONFIG::modulus.limbs[NLIMBS - 1] >> 30)); + storage<2*NLIMBS> rs = {0}; + sqr_raw(xs, rs); + redc_wide_inplace(rs); // after reduce_twopass, tmp's low NLIMBS limbs should represent a value in [0, 2*mod) + return rs.get_lo(); + } +// //add +// // return xs * ys with field operands +// // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. +// // Host path uses CIOS. +// template static constexpr DEVICE_INLINE storage mulz(const storage &xs, const storage &ys) { +// return mul_devicez(xs, ys); +// } +} + +#endif \ No newline at end of file diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 8a779b2bb..1e2f211a2 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -25,6 +25,7 @@ #include "icicle/errors.h" #include "host_math.h" +#include "device_math.h" #include "storage.h" #include @@ -35,11 +36,6 @@ using namespace icicle; -// #ifdef __CUDA_ARCH__ -// __device__ __location__(constant) uint32_t FF_BLS12377_INV = 0xffffffff; -// #else -// static constexpr uint32_t FF_BLS12377_INV = 0xffffffff; -// #endif // __CUDA_ARCH__ template class Field { @@ -281,36 +277,12 @@ class Field } } -#ifdef __CUDACC__ - template - static constexpr DEVICE_INLINE uint32_t add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r) - { - r[0] = SUBTRACT ? ptx::sub_cc(x[0], y[0]) : ptx::add_cc(x[0], y[0]); - for (unsigned i = 1; i < NLIMBS; i++) - r[i] = SUBTRACT ? ptx::subc_cc(x[i], y[i]) : ptx::addc_cc(x[i], y[i]); - if (!CARRY_OUT) { - ptx::addc(0, 0); - return 0; - } - return SUBTRACT ? ptx::subc(0, 0) : ptx::addc(0, 0); - } - - template - static constexpr DEVICE_INLINE uint32_t - add_sub_limbs_device(const storage& xs, const storage& ys, storage& rs) - { - const uint32_t* x = xs.limbs; - const uint32_t* y = ys.limbs; - uint32_t* r = rs.limbs; - return add_sub_u32_device(x, y, r); - } -#endif // __CUDACC__ template static constexpr HOST_DEVICE_INLINE uint32_t add_limbs(const storage& xs, const storage& ys, storage& rs) { #ifdef __CUDA_ARCH__ - return add_sub_limbs_device(xs, ys, rs); + return device_math::template add_sub_limbs_device(xs, ys, rs); #else return host_math::template add_sub_limbs(xs, ys, rs); #endif @@ -321,827 +293,16 @@ class Field sub_limbs(const storage& xs, const storage& ys, storage& rs) { #ifdef __CUDA_ARCH__ - return add_sub_limbs_device(xs, ys, rs); + return device_math::template add_sub_limbs_device(xs, ys, rs); #else return host_math::template add_sub_limbs(xs, ys, rs); #endif } -#ifdef __CUDACC__ - -template struct carry_chainz { - unsigned index; - - constexpr __device__ __forceinline__ carry_chainz() : index(0) {} - - __device__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::add(x, y); - else if (index == 1 && !CARRY_IN) - return ptx::add_cc(x, y); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::addc_cc(x, y); - else - return ptx::addc(x, y); - } - - __device__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::sub(x, y); - else if (index == 1 && !CARRY_IN) - return ptx::sub_cc(x, y); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::subc_cc(x, y); - else - return ptx::subc(x, y); - } - - __device__ __forceinline__ uint32_t mad_lo(const uint32_t x, const uint32_t y, const uint32_t z) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::mad_lo(x, y, z); - else if (index == 1 && !CARRY_IN) - return ptx::mad_lo_cc(x, y, z); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::madc_lo_cc(x, y, z); - else - return ptx::madc_lo(x, y, z); - } - - __device__ __forceinline__ uint32_t mad_hi(const uint32_t x, const uint32_t y, const uint32_t z) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::mad_hi(x, y, z); - else if (index == 1 && !CARRY_IN) - return ptx::mad_hi_cc(x, y, z); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::madc_hi_cc(x, y, z); - else - return ptx::madc_hi(x, y, z); - } -}; - - - - // add or subtract limbs - template static constexpr DEVICE_INLINE uint32_t add_sub_limbs_devicez(const ff_storage &xs, const ff_storage &ys, ff_storage &rs) { - const uint32_t *x = xs.limbs; - const uint32_t *y = ys.limbs; - uint32_t *r = rs.limbs; - carry_chainz chain; -#pragma unroll - for (unsigned i = 0; i < TLC; i++) - r[i] = SUBTRACT ? chain.sub(x[i], y[i]) : chain.add(x[i], y[i]); - if (!CARRY_OUT) - return 0; - return SUBTRACT ? chain.sub(0, 0) : chain.add(0, 0); - } - - // If we want, we could make "2*TLC" a template parameter to deduplicate with "ff_storage" overload, but that's a minor issue. - template - static constexpr DEVICE_INLINE uint32_t add_sub_limbs_devicez(const ff_wide_storage &xs, const ff_wide_storage &ys, ff_wide_storage &rs) { - const uint32_t *x = xs.limbs; - const uint32_t *y = ys.limbs; - uint32_t *r = rs.limbs; - carry_chainz chain; -#pragma unroll - for (unsigned i = 0; i < 2 * TLC; i++) { - r[i] = SUBTRACT ? chain.sub(x[i], y[i]) : chain.add(x[i], y[i]); - } - if (!CARRY_OUT) - return 0; - return SUBTRACT ? chain.sub(0, 0) : chain.add(0, 0); - } - - template static constexpr DEVICE_INLINE uint32_t add_sub_limbsz(const T &xs, const T &ys, T &rs) { - // No need for static_assert(std::is_same::value || std::is_same::value). - // Instantiation will fail if appropriate add_sub_limbs_device overload does not exist. - return add_sub_limbs_devicez(xs, ys, rs); - } - - template static constexpr DEVICE_INLINE uint32_t add_limbsz(const T &xs, const T &ys, T &rs) { - return add_sub_limbsz(xs, ys, rs); - } - - template static constexpr DEVICE_INLINE uint32_t sub_limbsz(const T &xs, const T &ys, T &rs) { - return add_sub_limbsz(xs, ys, rs); - } - - // return xs == 0 with field operands - static constexpr DEVICE_INLINE bool is_zero_devicez(const ff_storage &xs) { - const uint32_t *x = xs.limbs; - uint32_t limbs_or = x[0]; -#pragma unroll - for (unsigned i = 1; i < TLC; i++) - limbs_or |= x[i]; - return limbs_or == 0; - } - - static constexpr DEVICE_INLINE bool is_zeroz(const ff_storage &xs) { - return is_zero_devicez(xs); - } - - // return xs == ys with field operands - static constexpr DEVICE_INLINE bool eq_devicez(const ff_storage &xs, const ff_storage &ys) { - const uint32_t *x = xs.limbs; - const uint32_t *y = ys.limbs; - uint32_t limbs_or = x[0] ^ y[0]; -#pragma unroll - for (unsigned i = 1; i < TLC; i++) - limbs_or |= x[i] ^ y[i]; - return limbs_or == 0; - } - - static constexpr DEVICE_INLINE bool eqz(const ff_storage &xs, const ff_storage &ys) { - return eq_devicez(xs, ys); - } - - template static constexpr DEVICE_INLINE ff_storage reducez(const ff_storage &xs) { - if (REDUCTION_SIZE == 0) - return xs; - const ff_storage modulus = get_modulus(); - ff_storage rs = {}; - return sub_limbsz(xs, modulus, rs) ? xs : rs; - } - - template static constexpr DEVICE_INLINE ff_wide_storage reduce_widez(const ff_wide_storage &xs) { - if (REDUCTION_SIZE == 0) - return xs; - const ff_wide_storage modulus_squared = get_modulus_squared(); - ff_wide_storage rs = {}; - return sub_limbsz(xs, modulus_squared, rs) ? xs : rs; - } - - // return xs + ys with field operands - template static constexpr DEVICE_INLINE ff_storage addz(const ff_storage &xs, const ff_storage &ys) { - ff_storage rs = {}; - add_limbsz(xs, ys, rs); - return reducez(rs); - } - - template static constexpr DEVICE_INLINE ff_wide_storage add_widez(const ff_wide_storage &xs, const ff_wide_storage &ys) { - ff_wide_storage rs = {}; - add_limbsz(xs, ys, rs); - return reduce_widez(rs); - } - - // return xs - ys with field operands - template static DEVICE_INLINE ff_storage subz(const ff_storage &xs, const ff_storage &ys) { - ff_storage rs = {}; - if (REDUCTION_SIZE == 0) { - sub_limbsz(xs, ys, rs); - } else { - uint32_t carry = sub_limbsz(xs, ys, rs); - if (carry == 0) - return rs; - const ff_storage modulus = get_modulus(); - add_limbsz(rs, modulus, rs); - } - return rs; - } - - template static DEVICE_INLINE ff_wide_storage sub_widez(const ff_wide_storage &xs, const ff_wide_storage &ys) { - ff_wide_storage rs = {}; - if (REDUCTION_SIZE == 0) { - sub_limbsz(xs, ys, rs); - } else { - uint32_t carry = sub_limbsz(xs, ys, rs); - if (carry == 0) - return rs; - const ff_wide_storage modulus_squared = get_modulus_squared(); - add_limbsz(rs, modulus_squared, rs); - } - return rs; - } - - - // The following algorithms are adaptations of - // http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, - // taken from https://github.com/z-prize/test-msm-gpu (under Apache 2.0 license) - // and modified to use our datatypes. - // We had our own implementation of http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, - // but the sppark versions achieved lower instruction count thanks to clever carry handling, - // so we decided to just use theirs. - -//change - static DEVICE_INLINE void mul_nz(uint32_t *acc, const uint32_t *a, uint32_t bi, size_t n = TLC) { -#pragma unroll - for (size_t i = 0; i < n; i += 2) { - acc[i] = ptx::mul_lo(a[i], bi); - acc[i + 1] = ptx::mul_hi(a[i], bi); - } - } - -//change - static DEVICE_INLINE void cmad_nz(uint32_t *acc, const uint32_t *a, uint32_t bi, size_t n = TLC) { - acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); -#pragma unroll - for (size_t i = 2; i < n; i += 2) { - acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); - acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); - } - // return carry flag - } - - //add - static DEVICE_INLINE void madc_n_rshiftz(uint32_t *odd, const uint32_t *a, uint32_t bi) { - constexpr uint32_t n = TLC; -#pragma unroll - for (size_t i = 0; i < n - 2; i += 2) { - odd[i] = ptx::madc_lo_cc(a[i], bi, odd[i + 2]); - odd[i + 1] = ptx::madc_hi_cc(a[i], bi, odd[i + 3]); - } - odd[n - 2] = ptx::madc_lo_cc(a[n - 2], bi, 0); - odd[n - 1] = ptx::madc_hi(a[n - 2], bi, 0); - } - - //add - static DEVICE_INLINE void mad_n_redcz(uint32_t *even, uint32_t *odd, const uint32_t *a, uint32_t bi, bool first = false) { - constexpr uint32_t n = TLC; - constexpr auto modulus = CONFIG::modulus; - const uint32_t *const MOD = modulus.limbs; - constexpr auto mont_inv_modulus = CONFIG::mont_inv_modulus; - if (first) { - mul_nz(odd, a + 1, bi); - mul_nz(even, a, bi); - } else { - even[0] = ptx::add_cc(even[0], odd[1]); - madc_n_rshiftz(odd, a + 1, bi); - cmad_nz(even, a, bi); - odd[n - 1] = ptx::addc(odd[n - 1], 0); - } - uint32_t mi = even[0] * mont_inv_modulus.limbs[0]; - cmad_nz(odd, MOD + 1, mi); - cmad_nz(even, MOD, mi); - odd[n - 1] = ptx::addc(odd[n - 1], 0); - } - -//change - static DEVICE_INLINE void mad_rowz(uint32_t *odd, uint32_t *even, const uint32_t *a, uint32_t bi, size_t n = TLC) { - cmad_nz(odd, a + 1, bi, n - 2); - odd[n - 2] = ptx::madc_lo_cc(a[n - 1], bi, 0); - odd[n - 1] = ptx::madc_hi(a[n - 1], bi, 0); - cmad_nz(even, a, bi, n); - odd[n - 1] = ptx::addc(odd[n - 1], 0); - } - -//add - static DEVICE_INLINE void qad_rowz(uint32_t *odd, uint32_t *even, const uint32_t *a, uint32_t bi, size_t n = TLC) { - cmad_nz(odd, a, bi, n - 2); - odd[n - 2] = ptx::madc_lo_cc(a[n - 2], bi, 0); - odd[n - 1] = ptx::madc_hi(a[n - 2], bi, 0); - cmad_nz(even, a + 1, bi, n - 2); - odd[n - 1] = ptx::addc(odd[n - 1], 0); - } - -//change - static DEVICE_INLINE void multiply_rawz(const ff_storage &as, const ff_storage &bs, ff_wide_storage &rs) { - const uint32_t *a = as.limbs; - const uint32_t *b = bs.limbs; - uint32_t *even = rs.limbs; - __align__(8) uint32_t odd[2 * TLC - 2]; - mul_nz(even, a, b[0]); - mul_nz(odd, a + 1, b[0]); - mad_rowz(&even[2], &odd[0], a, b[1]); - size_t i; -#pragma unroll - for (i = 2; i < TLC - 1; i += 2) { - mad_rowz(&odd[i], &even[i], a, b[i]); - mad_rowz(&even[i + 2], &odd[i], a, b[i + 1]); - } - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < 2 * TLC - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); - } - - static DEVICE_INLINE void sqr_rawz(const ff_storage &as, ff_wide_storage &rs) { - const uint32_t *a = as.limbs; - uint32_t *even = rs.limbs; - size_t i = 0, j; - __align__(8) uint32_t odd[2 * TLC - 2]; - - // perform |a[i]|*|a[j]| for all j>i - mul_nz(even + 2, a + 2, a[0], TLC - 2); - mul_nz(odd, a + 1, a[0], TLC); - -#pragma unroll - while (i < TLC - 4) { - ++i; - mad_rowz(&even[2 * i + 2], &odd[2 * i], &a[i + 1], a[i], TLC - i - 1); - ++i; - qad_rowz(&odd[2 * i], &even[2 * i + 2], &a[i + 1], a[i], TLC - i); - } - - even[2 * TLC - 4] = ptx::mul_lo(a[TLC - 1], a[TLC - 3]); - even[2 * TLC - 3] = ptx::mul_hi(a[TLC - 1], a[TLC - 3]); - odd[2 * TLC - 6] = ptx::mad_lo_cc(a[TLC - 2], a[TLC - 3], odd[2 * TLC - 6]); - odd[2 * TLC - 5] = ptx::madc_hi_cc(a[TLC - 2], a[TLC - 3], odd[2 * TLC - 5]); - even[2 * TLC - 3] = ptx::addc(even[2 * TLC - 3], 0); - - odd[2 * TLC - 4] = ptx::mul_lo(a[TLC - 1], a[TLC - 2]); - odd[2 * TLC - 3] = ptx::mul_hi(a[TLC - 1], a[TLC - 2]); - - // merge |even[2:]| and |odd[1:]| - even[2] = ptx::add_cc(even[2], odd[1]); - for (j = 2; j < 2 * TLC - 3; j++) - even[j + 1] = ptx::addc_cc(even[j + 1], odd[j]); - even[j + 1] = ptx::addc(odd[j], 0); - - // double |even| - even[0] = 0; - even[1] = ptx::add_cc(odd[0], odd[0]); - for (j = 2; j < 2 * TLC - 1; j++) - even[j] = ptx::addc_cc(even[j], even[j]); - even[j] = ptx::addc(0, 0); - - // accumulate "diagonal" |a[i]|*|a[i]| product - i = 0; - even[2 * i] = ptx::mad_lo_cc(a[i], a[i], even[2 * i]); - even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); - for (++i; i < TLC; i++) { - even[2 * i] = ptx::madc_lo_cc(a[i], a[i], even[2 * i]); - even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); - } - } - -//add - static DEVICE_INLINE void mul_by_1_rowz(uint32_t *even, uint32_t *odd, bool first = false) { - uint32_t mi; - constexpr auto modulus = CONFIG::modulus; - const uint32_t *const MOD = modulus.limbs; - constexpr auto mont_inv_modulus = CONFIG::mont_inv_modulus; - if (first) { - mi = even[0] * mont_inv_modulus.limbs[0]; - mul_nz(odd, MOD + 1, mi); - cmad_nz(even, MOD, mi); - odd[TLC - 1] = ptx::addc(odd[TLC - 1], 0); - } else { - even[0] = ptx::add_cc(even[0], odd[1]); - // we trust the compiler to *not* touch the carry flag here - // this code sits in between two "asm volatile" instructions witch should guarantee that nothing else interferes wit the carry flag - mi = even[0] * mont_inv_modulus.limbs[0]; - madc_n_rshiftz(odd, MOD + 1, mi); - cmad_nz(even, MOD, mi); - odd[TLC - 1] = ptx::addc(odd[TLC - 1], 0); - } - } - -//add - // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). - // Does not implement an in-place reduce epilogue. If you want to further reduce the result, - // call reduce(xs.get_lo()) after the call to redc_wide_inplace. - static DEVICE_INLINE void redc_wide_inplacez(ff_wide_storage &xs) { - uint32_t *even = xs.limbs; - // Yields montmul of lo TLC limbs * 1. - // Since the hi TLC limbs don't participate in computing the "mi" factor at each mul-and-rightshift stage, - // it's ok to ignore the hi TLC limbs during this process and just add them in afterward. - uint32_t odd[TLC]; - size_t i; -#pragma unroll - for (i = 0; i < TLC; i += 2) { - mul_by_1_rowz(&even[0], &odd[0], i == 0); - mul_by_1_rowz(&odd[0], &even[0]); - } - even[0] = ptx::add_cc(even[0], odd[1]); -#pragma unroll - for (i = 1; i < TLC - 1; i++) - even[i] = ptx::addc_cc(even[i], odd[i + 1]); - even[i] = ptx::addc(even[i], 0); - // Adds in (hi TLC limbs), implicitly right-shifting them by TLC limbs as if they had participated in the - // add-and-rightshift stages above. - xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[TLC]); -#pragma unroll - for (i = 1; i < TLC - 1; i++) - xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + TLC]); - xs.limbs[TLC - 1] = ptx::addc(xs.limbs[TLC - 1], xs.limbs[2 * TLC - 1]); - } - -//add - static DEVICE_INLINE void montmul_rawz(const ff_storage &a_in, const ff_storage &b_in, ff_storage &r_in) { - constexpr uint32_t n = TLC; - constexpr auto modulus = CONFIG::modulus; - const uint32_t *const MOD = modulus.limbs; - const uint32_t *a = a_in.limbs; - const uint32_t *b = b_in.limbs; - uint32_t *even = r_in.limbs; - __align__(8) uint32_t odd[n + 1]; - size_t i; -#pragma unroll - for (i = 0; i < n; i += 2) { - mad_n_redcz(&even[0], &odd[0], a, b[i], i == 0); - mad_n_redcz(&odd[0], &even[0], a, b[i + 1]); - } - // merge |even| and |odd| - even[0] = ptx::add_cc(even[0], odd[1]); -#pragma unroll - for (i = 1; i < n - 1; i++) - even[i] = ptx::addc_cc(even[i], odd[i + 1]); - even[i] = ptx::addc(even[i], 0); - // final reduction from [0, 2*mod) to [0, mod) not done here, instead performed optionally in mul_device wrapper - } - -//change - // Returns xs * ys without Montgomery reduction. - template static constexpr DEVICE_INLINE ff_wide_storage mul_widez(const ff_storage &xs, const ff_storage &ys) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack - static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); - ff_wide_storage rs = {0}; - multiply_rawz(xs, ys, rs); - return reduce_widez(rs); - } - -//add - // Performs Montgomery reduction on a ff_wide_storage input. Input value must be in the range [0, mod*2^(32*TLC)). - template static constexpr DEVICE_INLINE ff_storage redc_widez(const ff_wide_storage &xs) { - ff_wide_storage tmp{xs}; - redc_wide_inplacez(tmp); // after reduce_twopass, tmp's low TLC limbs should represent a value in [0, 2*mod) - return reducez(tmp.get_lo()); - } - -//add - template static constexpr DEVICE_INLINE ff_storage mul_devicez(const ff_storage &xs, const ff_storage &ys) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack - static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); - // printf(" "); - ff_storage rs = {0}; - montmul_rawz(xs, ys, rs); - return reducez(rs); - } - - template static constexpr DEVICE_INLINE ff_storage sqr_devicez(const ff_storage &xs) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack - static_assert(!(CONFIG::modulus.limbs[TLC - 1] >> 30)); - ff_wide_storage rs = {0}; - sqr_rawz(xs, rs); - redc_wide_inplacez(rs); // after reduce_twopass, tmp's low TLC limbs should represent a value in [0, 2*mod) - return reducez(rs.get_lo()); - } - -//add - // return xs * ys with field operands - // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. - // Host path uses CIOS. - template static constexpr DEVICE_INLINE ff_storage mulz(const ff_storage &xs, const ff_storage &ys) { - return mul_devicez(xs, ys); - } - - - - - - - - - - static DEVICE_INLINE void mul_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC) - { - UNROLL - for (size_t i = 0; i < n; i += 2) { - acc[i] = ptx::mul_lo(a[i], bi); - acc[i + 1] = ptx::mul_hi(a[i], bi); - } - } - - static DEVICE_INLINE void mul_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC, size_t start_i = 0) - { - UNROLL - for (size_t i = start_i; i < n; i += 2) { - acc[i] = ptx::mul_lo(a[i], bi); - acc[i + 1] = ptx::mul_hi(a[i], bi); - } - } - - template - static DEVICE_INLINE void - cmad_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC, uint32_t optional_carry = 0) - { - if (CARRY_IN) ptx::add_cc(UINT32_MAX, optional_carry); - acc[0] = CARRY_IN ? ptx::madc_lo_cc(a[0], bi, acc[0]) : ptx::mad_lo_cc(a[0], bi, acc[0]); - acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - - UNROLL - for (size_t i = 2; i < n; i += 2) { - acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); - acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); - } - } - - template - static DEVICE_INLINE void cmad_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC) - { - if (EVEN_PHASE) { - acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - } else { - acc[1] = ptx::mad_hi_cc(a[0], bi, acc[1]); - } - - UNROLL - for (size_t i = 2; i < n; i += 2) { - acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); - acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); - } - } - - static DEVICE_INLINE void cmad_n_lsb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = TLC) - { - if (n > 1) - acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - else - acc[0] = ptx::mad_lo(a[0], bi, acc[0]); - - size_t i; - UNROLL - for (i = 1; i < n - 1; i += 2) { - acc[i] = ptx::madc_hi_cc(a[i - 1], bi, acc[i]); - if (i == n - 2) - acc[i + 1] = ptx::madc_lo(a[i + 1], bi, acc[i + 1]); - else - acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, acc[i + 1]); - } - if (i == n - 1) acc[i] = ptx::madc_hi(a[i - 1], bi, acc[i]); - } - - template - static DEVICE_INLINE uint32_t mad_row( - uint32_t* odd, - uint32_t* even, - const uint32_t* a, - uint32_t bi, - size_t n = TLC, - uint32_t ci = 0, - uint32_t di = 0, - uint32_t carry_for_high = 0, - uint32_t carry_for_low = 0) - { - cmad_n(odd, a + 1, bi, n - 2, carry_for_low); - odd[n - 2] = ptx::madc_lo_cc(a[n - 1], bi, ci); - odd[n - 1] = CARRY_OUT ? ptx::madc_hi_cc(a[n - 1], bi, di) : ptx::madc_hi(a[n - 1], bi, di); - uint32_t cr = CARRY_OUT ? ptx::addc(0, 0) : 0; - cmad_n(even, a, bi, n); - if (CARRY_OUT) { - odd[n - 1] = ptx::addc_cc(odd[n - 1], carry_for_high); - cr = ptx::addc(cr, 0); - } else - odd[n - 1] = ptx::addc(odd[n - 1], carry_for_high); - return cr; - } - - template - static DEVICE_INLINE void mad_row_msb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = TLC) - { - cmad_n_msb(odd, EVEN_PHASE ? a : (a + 1), bi, n - 2); - odd[EVEN_PHASE ? (n - 1) : (n - 2)] = ptx::madc_lo_cc(a[n - 1], bi, 0); - odd[EVEN_PHASE ? n : (n - 1)] = ptx::madc_hi(a[n - 1], bi, 0); - cmad_n_msb(even, EVEN_PHASE ? (a + 1) : a, bi, n - 1); - odd[EVEN_PHASE ? n : (n - 1)] = ptx::addc(odd[EVEN_PHASE ? n : (n - 1)], 0); - } - - static DEVICE_INLINE void mad_row_lsb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = TLC) - { - // bi here is constant so we can do a compile-time check for zero (which does happen once for bls12-381 scalar field - // modulus) - if (bi != 0) { - if (n > 1) cmad_n_lsb(odd, a + 1, bi, n - 1); - cmad_n_lsb(even, a, bi, n); - } - return; - } - - static DEVICE_INLINE uint32_t - mul_n_and_add(uint32_t* acc, const uint32_t* a, uint32_t bi, uint32_t* extra, size_t n = (TLC >> 1)) - { - acc[0] = ptx::mad_lo_cc(a[0], bi, extra[0]); - - UNROLL - for (size_t i = 1; i < n - 1; i += 2) { - acc[i] = ptx::madc_hi_cc(a[i - 1], bi, extra[i]); - acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, extra[i + 1]); - } - - acc[n - 1] = ptx::madc_hi_cc(a[n - 2], bi, extra[n - 1]); - return ptx::addc(0, 0); - } - - /** - * A function that computes wide product \f$ rs = as \cdot bs \f$ that's correct for the higher TLC + 1 limbs with a - * small maximum error. - * - * The way this function saves computations (as compared to regular school-book multiplication) is by not including - * terms that are too small. Namely, limb product \f$ a_i \cdot b_j \f$ is excluded if \f$ i + j < TLC - 2 \f$ and - * only the higher half is included if \f$ i + j = TLC - 2 \f$. All other limb products are included. So, the error - * i.e. difference between true product and the result of this function written to `rs` is exactly the sum of all - * dropped limbs products, which we can bound: \f$ a_0 \cdot b_0 + 2^{32}(a_0 \cdot b_1 + a_1 \cdot b_0) + \dots + - * 2^{32(TLC - 3)}(a_{TLC - 3} \cdot b_0 + \dots + a_0 \cdot b_{TLC - 3}) + 2^{32(TLC - 2)}(\floor{\frac{a_{TLC - 2} - * \cdot b_0}{2^{32}}} + \dots + \floor{\frac{a_0 \cdot b_{TLC - 2}}{2^{32}}}) \leq 2^{64} + 2\cdot 2^{96} + \dots + - * (TLC - 2) \cdot 2^{32(TLC - 1)} + (TLC - 1) \cdot 2^{32(TLC - 1)} \leq 2(TLC - 1) \cdot 2^{32(TLC - 1)}\f$. - */ - static DEVICE_INLINE void multiply_msb_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) - { - if constexpr (TLC > 1) { - const uint32_t* a = as.limbs; - const uint32_t* b = bs.limbs; - uint32_t* even = rs.limbs; - __align__(16) uint32_t odd[2 * TLC - 2]; - - even[TLC - 1] = ptx::mul_hi(a[TLC - 2], b[0]); - odd[TLC - 2] = ptx::mul_lo(a[TLC - 1], b[0]); - odd[TLC - 1] = ptx::mul_hi(a[TLC - 1], b[0]); - size_t i; - UNROLL - for (i = 2; i < TLC - 1; i += 2) { - mad_row_msb(&even[TLC - 2], &odd[TLC - 2], &a[TLC - i - 1], b[i - 1], i + 1); - mad_row_msb(&odd[TLC - 2], &even[TLC - 2], &a[TLC - i - 2], b[i], i + 2); - } - mad_row(&even[TLC], &odd[TLC - 2], a, b[TLC - 1]); - - // merge |even| and |odd| - ptx::add_cc(even[TLC - 1], odd[TLC - 2]); - for (i = TLC - 1; i < 2 * TLC - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); - } else { - multiply_raw_device(as, bs, rs); - } - } - - /** - * A function that computes the low half of the fused multiply-and-add \f$ rs = as \cdot bs + cs \f$ where - * \f$ bs = 2^{32*nof_limbs} \f$. - * - * For efficiency, this method does not include terms that are too large. Namely, limb product \f$ a_i \cdot b_j \f$ - * is excluded if \f$ i + j > TLC - 1 \f$ and only the lower half is included if \f$ i + j = TLC - 1 \f$. All other - * limb products are included. - */ - static DEVICE_INLINE void - multiply_and_add_lsb_neg_modulus_raw_device(const ff_storage& as, ff_storage& cs, ff_storage& rs) - { - ff_storage bs = get_neg_modulus(); - const uint32_t* a = as.limbs; - const uint32_t* b = bs.limbs; - uint32_t* c = cs.limbs; - uint32_t* even = rs.limbs; - - if constexpr (TLC > 2) { - __align__(16) uint32_t odd[TLC - 1]; - size_t i; - // `b[0]` is \f$ 2^{32} \f$ minus the last limb of prime modulus. Because most scalar (and some base) primes - // are necessarily NTT-friendly, `b[0]` often turns out to be \f$ 2^{32} - 1 \f$. This actually leads to - // less efficient SASS generated by nvcc, so this case needed separate handling. - if (b[0] == UINT32_MAX) { - add_sub_u32_device(c, a, even); - for (i = 0; i < TLC - 1; i++) - odd[i] = a[i]; - } else { - mul_n_and_add(even, a, b[0], c, TLC); - mul_n(odd, a + 1, b[0], TLC - 1); - } - mad_row_lsb(&even[2], &odd[0], a, b[1], TLC - 1); - UNROLL - for (i = 2; i < TLC - 1; i += 2) { - mad_row_lsb(&odd[i], &even[i], a, b[i], TLC - i); - mad_row_lsb(&even[i + 2], &odd[i], a, b[i + 1], TLC - i - 1); - } - - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < TLC - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], odd[i]); - } else if (TLC == 2) { - even[0] = ptx::mad_lo(a[0], b[0], c[0]); - even[1] = ptx::mad_hi(a[0], b[0], c[0]); - even[1] = ptx::mad_lo(a[0], b[1], even[1]); - even[1] = ptx::mad_lo(a[1], b[0], even[1]); - } else if (TLC == 1) { - even[0] = ptx::mad_lo(a[0], b[0], c[0]); - } - } - - /** - * This method multiplies `a` and `b` (both assumed to have TLC / 2 limbs) and adds `in1` and `in2` (TLC limbs each) - * to the result which is written to `even`. - * - * It is used to compute the "middle" part of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} = - * (a_{hi} - a_{lo})(b_{lo} - b_{hi}) + a_{lo} \cdot b_{lo} + a_{hi} \cdot b_{hi} \f$. Currently this method assumes - * that the top bit of \f$ a_{hi} \f$ and \f$ b_{hi} \f$ are unset. This ensures correctness by allowing to keep the - * result inside TLC limbs and ignore the carries from the highest limb. - */ - static DEVICE_INLINE void - multiply_and_add_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even, uint32_t* in1, uint32_t* in2) - { - __align__(16) uint32_t odd[TLC - 2]; - uint32_t first_row_carry = mul_n_and_add(even, a, b[0], in1); - uint32_t carry = mul_n_and_add(odd, a + 1, b[0], &in2[1]); - - size_t i; - UNROLL - for (i = 2; i < ((TLC >> 1) - 1); i += 2) { - carry = mad_row( - &even[i], &odd[i - 2], a, b[i - 1], TLC >> 1, in1[(TLC >> 1) + i - 2], in1[(TLC >> 1) + i - 1], carry); - carry = - mad_row(&odd[i], &even[i], a, b[i], TLC >> 1, in2[(TLC >> 1) + i - 1], in2[(TLC >> 1) + i], carry); - } - mad_row( - &even[TLC >> 1], &odd[(TLC >> 1) - 2], a, b[(TLC >> 1) - 1], TLC >> 1, in1[TLC - 2], in1[TLC - 1], carry, - first_row_carry); - // merge |even| and |odd| plus the parts of `in2` we haven't added yet (first and last limbs) - even[0] = ptx::add_cc(even[0], in2[0]); - for (i = 0; i < (TLC - 2); i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], in2[i + 1]); - } - - /** - * This method multiplies `a` and `b` and writes the result into `even`. It assumes that `a` and `b` are TLC/2 limbs - * long. The usual schoolbook algorithm is used. - */ - static DEVICE_INLINE void multiply_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even) - { - __align__(16) uint32_t odd[TLC - 2]; - mul_n(even, a, b[0], TLC >> 1); - mul_n(odd, a + 1, b[0], TLC >> 1); - mad_row(&even[2], &odd[0], a, b[1], TLC >> 1); - - size_t i; - UNROLL - for (i = 2; i < ((TLC >> 1) - 1); i += 2) { - mad_row(&odd[i], &even[i], a, b[i], TLC >> 1); - mad_row(&even[i + 2], &odd[i], a, b[i + 1], TLC >> 1); - } - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < TLC - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); - } - - /** - * This method multiplies `as` and `bs` and writes the (wide) result into `rs`. - * - * It is assumed that the highest bits of `as` and `bs` are unset which is true for all the numbers icicle had to deal - * with so far. This method implements [subtractive - * Karatsuba](https://en.wikipedia.org/wiki/Karatsuba_algorithm#Implementation). - */ - static DEVICE_INLINE void multiply_raw_device(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) - { - const uint32_t* a = as.limbs; - const uint32_t* b = bs.limbs; - uint32_t* r = rs.limbs; - if constexpr (TLC > 2) { - // Next two lines multiply high and low halves of operands (\f$ a_{lo} \cdot b_{lo}; a_{hi} \cdot b_{hi} \$f) and - // write the results into `r`. - multiply_short_raw_device(a, b, r); - multiply_short_raw_device(&a[TLC >> 1], &b[TLC >> 1], &r[TLC]); - __align__(16) uint32_t middle_part[TLC]; - __align__(16) uint32_t diffs[TLC]; - // Differences of halves \f$ a_{hi} - a_{lo}; b_{lo} - b_{hi} \$f are written into `diffs`, signs written to - // `carry1` and `carry2`. - uint32_t carry1 = add_sub_u32_device<(TLC >> 1), true, true>(&a[TLC >> 1], a, diffs); - uint32_t carry2 = add_sub_u32_device<(TLC >> 1), true, true>(b, &b[TLC >> 1], &diffs[TLC >> 1]); - // Compute the "middle part" of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} \f$. - // This is where the assumption about unset high bit of `a` and `b` is relevant. - multiply_and_add_short_raw_device(diffs, &diffs[TLC >> 1], middle_part, r, &r[TLC]); - // Corrections that need to be performed when differences are negative. - // Again, carry doesn't need to be propagated due to unset high bits of `a` and `b`. - if (carry1) - add_sub_u32_device<(TLC >> 1), true, false>(&middle_part[TLC >> 1], &diffs[TLC >> 1], &middle_part[TLC >> 1]); - if (carry2) add_sub_u32_device<(TLC >> 1), true, false>(&middle_part[TLC >> 1], diffs, &middle_part[TLC >> 1]); - // Now that middle part is fully correct, it can be added to the result. - add_sub_u32_device(&r[TLC >> 1], middle_part, &r[TLC >> 1]); - - // Carry from adding middle part has to be propagated to the highest limb. - for (size_t i = TLC + (TLC >> 1); i < 2 * TLC; i++) - r[i] = ptx::addc_cc(r[i], 0); - } else if (TLC == 2) { - __align__(8) uint32_t odd[2]; - r[0] = ptx::mul_lo(a[0], b[0]); - r[1] = ptx::mul_hi(a[0], b[0]); - r[2] = ptx::mul_lo(a[1], b[1]); - r[3] = ptx::mul_hi(a[1], b[1]); - odd[0] = ptx::mul_lo(a[0], b[1]); - odd[1] = ptx::mul_hi(a[0], b[1]); - odd[0] = ptx::mad_lo(a[1], b[0], odd[0]); - odd[1] = ptx::mad_hi(a[1], b[0], odd[1]); - r[1] = ptx::add_cc(r[1], odd[0]); - r[2] = ptx::addc_cc(r[2], odd[1]); - r[3] = ptx::addc(r[3], 0); - } else if (TLC == 1) { - r[0] = ptx::mul_lo(a[0], b[0]); - r[1] = ptx::mul_hi(a[0], b[0]); - } - } - -#endif // __CUDACC__ static HOST_DEVICE_INLINE void multiply_raw(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) { #ifdef __CUDA_ARCH__ - return multiply_raw_device(as, bs, rs); + return device_math::template multiply_raw_device(as, bs, rs); #else return host_math::template multiply_raw(as, bs, rs); #endif @@ -1151,7 +312,7 @@ template (as, get_neg_modulus(), r_wide.limbs_storage); @@ -1163,7 +324,7 @@ template (as, bs, rs); #endif @@ -1278,7 +439,7 @@ template - static constexpr HOST_DEVICE_INLINE Field reduce(const Wide& xs) + static constexpr HOST_DEVICE_INLINE Field reduce(const Wide& xs) //TODO = add reduce_mont_inplace { // `xs` is left-shifted by `2 * slack_bits` and higher half is written to `xs_hi` Field xs_hi = Wide::get_higher_with_slack(xs); @@ -1345,14 +506,16 @@ template (Field{device_math::template mulmont_device(xs.limbs_storage,ys.limbs_storage,get_modulus<1>(),get_mont_inv_modulus())}); + #endif #else #ifdef BARRET Wide xy = mul_wide(xs, ys); // full mult return reduce(xy); // reduce mod p - #else + // return Wide::get_lower(xy); + #else return mont_mult(xs,ys); #endif #endif @@ -1670,6 +833,7 @@ template (xs); + #endif + #ifndef BARRET + mul = to_montgomery(mul); + #endif return mul * xs; } diff --git a/icicle/include/icicle/fields/params_gen.h b/icicle/include/icicle/fields/params_gen.h index 194c9662b..e174a94d6 100644 --- a/icicle/include/icicle/fields/params_gen.h +++ b/icicle/include/icicle/fields/params_gen.h @@ -173,7 +173,7 @@ namespace params_gen { static constexpr unsigned num_of_reductions = \ params_gen::template num_of_reductions(modulus, m); -#define TWIDDLES(modulus, rou) \ +#define TWIDDLES(modulus) \ static constexpr unsigned omegas_count = params_gen::template two_adicity(modulus); \ static constexpr storage_array inv = \ params_gen::template get_invs(modulus, montgomery_r_sqr, mont_inv_modulus); diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h index 3feccb729..1666517dd 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h @@ -13,7 +13,7 @@ namespace bls12_377 { static constexpr storage<12> rou = {0xc563b9a1, 0x7eca603c, 0x06fe0bc3, 0x06df0a43, 0x0ddff8c6, 0xb44d994a, 0x4512a3d4, 0x40fbe05b, 0x8aeffc9b, 0x30f15248, 0x05198a80, 0x0036a92e}; - TWIDDLES(modulus, rou) + TWIDDLES(modulus) // nonresidue to generate the extension field static constexpr uint32_t nonresidue = 5; diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h index 2e366ae53..e047144a0 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h @@ -13,7 +13,7 @@ namespace bls12_377 { static constexpr storage<8> rou = {0xec2a895e, 0x476ef4a4, 0x63e3f04a, 0x9b506ee3, 0xd1a8a12f, 0x60c69477, 0x0cb92cc1, 0x11d4b7f6}; - TWIDDLES(modulus, rou) + TWIDDLES(modulus) }; /** diff --git a/icicle/include/icicle/fields/snark_fields/bls12_381_scalar.h b/icicle/include/icicle/fields/snark_fields/bls12_381_scalar.h index 3d1027534..0fce73617 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_381_scalar.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_381_scalar.h @@ -14,7 +14,7 @@ namespace bls12_381 { static constexpr storage<8> rou = {0x0b912f1f, 0x1b788f50, 0x70b3e094, 0xc4024ff2, 0xd168d6c0, 0x0fd56dc8, 0x5b416b6f, 0x0212d79e}; - TWIDDLES(modulus, rou) + TWIDDLES(modulus) }; /** diff --git a/icicle/include/icicle/fields/snark_fields/bn254_scalar.h b/icicle/include/icicle/fields/snark_fields/bn254_scalar.h index edaecee20..957779f31 100644 --- a/icicle/include/icicle/fields/snark_fields/bn254_scalar.h +++ b/icicle/include/icicle/fields/snark_fields/bn254_scalar.h @@ -12,7 +12,7 @@ namespace bn254 { static constexpr storage<8> rou = {0x725b19f0, 0x9bd61b6e, 0x41112ed4, 0x402d111e, 0x8ef62abc, 0x00e0a7eb, 0xa58a7e85, 0x2a3c09f0}; - TWIDDLES(modulus, rou) + TWIDDLES(modulus) }; /** diff --git a/icicle/include/icicle/fields/stark_fields/babybear.h b/icicle/include/icicle/fields/stark_fields/babybear.h index 8000431aa..8ca7b0c39 100644 --- a/icicle/include/icicle/fields/stark_fields/babybear.h +++ b/icicle/include/icicle/fields/stark_fields/babybear.h @@ -11,7 +11,7 @@ namespace babybear { PARAMS(modulus) static constexpr storage<1> rou = {0x00000089}; - TWIDDLES(modulus, rou) + TWIDDLES(modulus) // nonresidue to generate the extension field static constexpr uint32_t nonresidue = 11; diff --git a/icicle/include/icicle/fields/stark_fields/stark252.h b/icicle/include/icicle/fields/stark_fields/stark252.h index 6fb3da62e..521fb4de2 100644 --- a/icicle/include/icicle/fields/stark_fields/stark252.h +++ b/icicle/include/icicle/fields/stark_fields/stark252.h @@ -14,7 +14,7 @@ namespace stark252 { static constexpr storage<8> rou = {0x42f8ef94, 0x6070024f, 0xe11a6161, 0xad187148, 0x9c8b0fa5, 0x3f046451, 0x87529cfa, 0x005282db}; - TWIDDLES(modulus, rou) + TWIDDLES(modulus) }; /** diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index 05e22fc46..e7e97d899 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -46,6 +46,11 @@ class CurveApiTest : public ::testing::Test if (!is_cuda_registered) { ICICLE_LOG_ERROR << "CUDA device not found. Testing CPU vs CPU"; } s_main_target = is_cuda_registered ? "CUDA" : "CPU"; s_ref_target = "CPU"; + #ifdef BARRET + ICICLE_LOG_INFO << "USING BARRET MULT\n"; + #else + ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; + #endif } static void TearDownTestSuite() { @@ -355,6 +360,8 @@ TYPED_TEST(CurveSanity, CurveSanityTest) { auto a = TypeParam::rand_host(); auto b = TypeParam::rand_host(); + ICICLE_LOG_INFO << "a: "< 0) { expected_mult = TypeParam::dbl(expected_mult); } - if (scalar.get_scalar_digit(scalar_t::NBITS - i - 1, 1)) { expected_mult = expected_mult + point; } + if (barret_scalar.get_scalar_digit(scalar_t::NBITS - i - 1, 1)) { expected_mult = expected_mult + point; } } END_TIMER(ref, "scalar mult double-and-add", true); diff --git a/icicle/tests/test_device_api.cpp b/icicle/tests/test_device_api.cpp index 98e2dfadb..71546b381 100644 --- a/icicle/tests/test_device_api.cpp +++ b/icicle/tests/test_device_api.cpp @@ -28,6 +28,11 @@ class DeviceApiTest : public ::testing::Test icicle_load_backend_from_env_or_default(); s_registered_devices = get_registered_devices_list(); ASSERT_GT(s_registered_devices.size(), 0); + #ifdef BARRET + ICICLE_LOG_INFO << "USING BARRET MULT\n"; + #else + ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; + #endif } static void TearDownTestSuite() {} diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 2aa9ce206..743e95e00 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -42,6 +42,11 @@ class FieldApiTest : public ::testing::Test if (!is_cuda_registered) { ICICLE_LOG_ERROR << "CUDA device not found. Testing CPU vs CPU"; } s_main_target = is_cuda_registered ? "CUDA" : "CPU"; s_reference_target = "CPU"; + #ifdef BARRET + ICICLE_LOG_INFO << "USING BARRET MULT\n"; + #else + ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; + #endif } static void TearDownTestSuite() { @@ -71,11 +76,6 @@ TYPED_TEST_SUITE(FieldApiTest, FTImplementations); // Note: this is testing host arithmetic. Other tests against CPU backend should guarantee correct device arithmetic too TYPED_TEST(FieldApiTest, FieldSanityTest) { - #ifdef BARRET - printf("USING BARRET MULT\n"); - #else - printf("USING MONTGOMERY MULT\n"); - #endif auto a = TypeParam::rand_host(); std::cout<(N); auto in_b = std::make_unique(N); FieldApiTest::random_samples(in_a.get(), N); @@ -198,14 +198,14 @@ TYPED_TEST(FieldApiTest, vectorOps) ASSERT_EQ(0, memcmp(in_a.get(), temp_result.get(), N * sizeof(TypeParam))); // add - run(s_reference_target, out_ref.get(), VERBOSE /*=measure*/, vector_add, "vector add", ITERS); - run(s_main_target, out_main.get(), VERBOSE /*=measure*/, vector_add, "vector add", ITERS); - ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); + // run(s_reference_target, out_ref.get(), VERBOSE /*=measure*/, vector_add, "vector add", ITERS); + // run(s_main_target, out_main.get(), VERBOSE /*=measure*/, vector_add, "vector add", ITERS); + // ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); - // sub - run(s_reference_target, out_ref.get(), VERBOSE /*=measure*/, vector_sub, "vector sub", ITERS); - run(s_main_target, out_main.get(), VERBOSE /*=measure*/, vector_sub, "vector sub", ITERS); - ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); + // // sub + // run(s_reference_target, out_ref.get(), VERBOSE /*=measure*/, vector_sub, "vector sub", ITERS); + // run(s_main_target, out_main.get(), VERBOSE /*=measure*/, vector_sub, "vector sub", ITERS); + // ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); // mul run(s_reference_target, out_ref.get(), VERBOSE /*=measure*/, vector_mul, "vector mul", ITERS); @@ -213,6 +213,11 @@ TYPED_TEST(FieldApiTest, vectorOps) // std::cout << in_a[0] << ", " << in_b[0] << ", " << out_main[0] << ", " << out_ref[0] << std::endl; // std::cout << in_a[1] << ", " << in_b[1] << ", " << out_main[1] << ", " << out_ref[1] << std::endl; + for (int i = 0; i < N; i++) + { + std::cout << in_a[i] << ", " << in_b[i] << ", " << out_main[i] << ", " << out_ref[i] << std::endl; + } + ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); diff --git a/icicle/tests/test_hash_api.cpp b/icicle/tests/test_hash_api.cpp index 3c6092c68..2bfb6a64d 100644 --- a/icicle/tests/test_hash_api.cpp +++ b/icicle/tests/test_hash_api.cpp @@ -45,6 +45,11 @@ class HashApiTest : public ::testing::Test s_reference_target = "CPU"; s_registered_devices = get_registered_devices_list(); ASSERT_GE(s_registered_devices.size(), 1); + #ifdef BARRET + ICICLE_LOG_INFO << "USING BARRET MULT\n"; + #else + ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; + #endif } static void TearDownTestSuite() { diff --git a/icicle/tests/test_polynomial_api.cpp b/icicle/tests/test_polynomial_api.cpp index c859cefd4..36e97ca49 100644 --- a/icicle/tests/test_polynomial_api.cpp +++ b/icicle/tests/test_polynomial_api.cpp @@ -53,6 +53,11 @@ class PolynomialTest : public ::testing::Test const int dev_idx = rand() % s_registered_devices.size(); icicle_set_device(s_registered_devices.at(dev_idx)); ICICLE_LOG_INFO << "setting device " << s_registered_devices.at(dev_idx) << " for polynomial tests"; + #ifdef BARRET + ICICLE_LOG_INFO << "USING BARRET MULT\n"; + #else + ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; + #endif } static void TearDownTestSuite() {} From 2c8e2ad5fc08b899c31f402f5c25f27e37678045 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Mon, 18 Nov 2024 15:24:29 +0200 Subject: [PATCH 12/22] field tests pass after merge --- icicle/include/icicle/fields/field.h | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 5f59eb688..2b50dd325 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -543,8 +543,6 @@ class Field return mont_reduce(r, /* get_higher_half = */ true); } -#else - // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) // #pragma GCC optimize("no-strict-aliasing") // #endif @@ -555,6 +553,8 @@ class Field return reduce(xy); // reduce mod p } +#ifdef GARBAGE + // #include /* GNARK CODE START*/ @@ -817,13 +817,12 @@ class Field #endif return z; } +#endif // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) // #pragma GCC reset_options // #endif -#endif // __CUDACC__ - /*GNARK CODE END*/ friend HOST_DEVICE bool operator==(const Field& xs, const Field& ys) From c53e3db998fc6e9593a539877b94859d1d65658f Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Tue, 19 Nov 2024 14:18:00 +0200 Subject: [PATCH 13/22] support 32 bit mont (babybear stiil fails) --- icicle/include/icicle/curves/projective.h | 24 --------- icicle/include/icicle/fields/device_math.h | 9 +++- icicle/include/icicle/fields/field.h | 61 ++++++++++++---------- icicle/include/icicle/fields/host_math.h | 35 +++++++------ icicle/tests/test_curve_api.cpp | 2 +- 5 files changed, 61 insertions(+), 70 deletions(-) diff --git a/icicle/include/icicle/curves/projective.h b/icicle/include/icicle/curves/projective.h index 14e649538..8dfc04eb4 100644 --- a/icicle/include/icicle/curves/projective.h +++ b/icicle/include/icicle/curves/projective.h @@ -119,25 +119,13 @@ class Projective FF::template mul_unsigned<3>(FF::template mul_const(t17)); // t23 ← b3 · t17 < 2 const auto t24 = FF::mul_wide(t12, t23); // t24 ← t12 · t23 < 2 const auto t25 = FF::mul_wide(t07, t22); // t25 ← t07 · t22 < 2 - #ifdef BARRET const FF X3 = FF::reduce(t25 - t24); // X3 ← t25 − t24 < 2 - #else - const FF X3 = FF::sos_mont_reduce(t25 - t24); // X3 ← t25 − t24 < 2 - #endif const auto t27 = FF::mul_wide(t23, t19); // t27 ← t23 · t19 < 2 const auto t28 = FF::mul_wide(t22, t21); // t28 ← t22 · t21 < 2 - #ifdef BARRET const FF Y3 = FF::reduce(t28 + t27); // Y3 ← t28 + t27 < 2 - #else - const FF Y3 = FF::sos_mont_reduce(t28 + t27); // Y3 ← t28 + t27 < 2 - #endif const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 - #ifdef BARRET const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 - #else - const FF Z3 = FF::sos_mont_reduce(t31 + t30); // Z3 ← t31 + t30 < 2 - #endif return {X3, Y3, Z3}; } @@ -178,25 +166,13 @@ class Projective FF::template mul_unsigned<3>(FF::template mul_const(t17)); // t23 ← b3 · t17 < 2 const auto t24 = FF::mul_wide(t12, t23); // t24 ← t12 · t23 < 2 const auto t25 = FF::mul_wide(t07, t22); // t25 ← t07 · t22 < 2 - #ifdef BARRET const FF X3 = FF::reduce(t25 - t24); // X3 ← t25 − t24 < 2 - #else - const FF X3 = FF::sos_mont_reduce(t25 - t24); // X3 ← t25 − t24 < 2 - #endif const auto t27 = FF::mul_wide(t23, t19); // t27 ← t23 · t19 < 2 const auto t28 = FF::mul_wide(t22, t21); // t28 ← t22 · t21 < 2 - #ifdef BARRET const FF Y3 = FF::reduce(t28 + t27); // Y3 ← t28 + t27 < 2 - #else - const FF Y3 = FF::sos_mont_reduce(t28 + t27); // Y3 ← t28 + t27 < 2 - #endif const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 - #ifdef BARRET const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 - #else - const FF Z3 = FF::sos_mont_reduce(t31 + t30); // Z3 ← t31 + t30 < 2 - #endif // const auto t24 = FF::mul_widez(t12.limbs_storage, t23.limbs_storage); // t24 ← t12 · t23 < 2 // const auto t25 = FF::mul_widez(t07.limbs_storage, t22.limbs_storage); // t25 ← t07 · t22 < 2 // typename FF::Wide W3 = typename FF::Wide{t25} - typename FF::Wide{t24}; // X3 ← t25 − t24 < 2 diff --git a/icicle/include/icicle/fields/device_math.h b/icicle/include/icicle/fields/device_math.h index fba40f399..3926a4bc5 100644 --- a/icicle/include/icicle/fields/device_math.h +++ b/icicle/include/icicle/fields/device_math.h @@ -184,6 +184,8 @@ template return cr; } + #ifdef BARRET + template static DEVICE_INLINE void mad_row_msb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) { @@ -206,6 +208,8 @@ template return; } + #endif + template static DEVICE_INLINE uint32_t mul_n_and_add(uint32_t* acc, const uint32_t* a, uint32_t bi, uint32_t* extra, size_t n = (NLIMBS >> 1)) @@ -341,6 +345,8 @@ template } } + #ifdef BARRET + /** * A function that computes wide product \f$ rs = as \cdot bs \f$ that's correct for the higher NLIMBS + 1 limbs with a * small maximum error. @@ -437,6 +443,7 @@ template } } +#endif // The following algorithms are adaptations of // http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, @@ -642,7 +649,7 @@ template // static_assert(!(CONFIG::modulus.limbs[NLIMBS - 1] >> 30)); storage<2*NLIMBS> rs = {0}; sqr_raw(xs, rs); - redc_wide_inplace(rs); // after reduce_twopass, tmp's low NLIMBS limbs should represent a value in [0, 2*mod) + reduce_mont_inplace(rs); // after reduce_twopass, tmp's low NLIMBS limbs should represent a value in [0, 2*mod) return rs.get_lo(); } // //add diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 2b50dd325..343156a3a 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -308,6 +308,8 @@ class Field #endif } + #ifdef BARRET + static HOST_DEVICE_INLINE void multiply_and_add_lsb_neg_modulus_raw(const ff_storage& as, ff_storage& cs, ff_storage& rs) { @@ -330,6 +332,8 @@ class Field #endif } + #endif + public: ff_storage limbs_storage; @@ -417,6 +421,8 @@ class Field return rs; } + #ifdef BARRET + /** * This method reduces a Wide number `xs` modulo `p` and returns the result as a Field element. * @@ -438,8 +444,7 @@ class Field * cases it's less than 1, so setting the [num_of_reductions](@ref num_of_reductions) variable for a field equal to 1 * will cause only 1 reduction to be performed. */ - template - static constexpr HOST_DEVICE_INLINE Field reduce(const Wide& xs) //TODO = add reduce_mont_inplace + static constexpr HOST_DEVICE_INLINE Field barret_reduce(const Wide& xs) //TODO = add reduce_mont_inplace { // `xs` is left-shifted by `2 * slack_bits` and higher half is written to `xs_hi` Field xs_hi = Wide::get_higher_with_slack(xs); @@ -464,21 +469,11 @@ class Field return r; } - template - static constexpr HOST_DEVICE_INLINE Field mont_reduce(const Wide& xs, bool get_higher_half = false) + #endif + + static constexpr HOST_DEVICE_INLINE Field mont_sub_modulus(const Wide& xs, bool get_higher = false) { - // Field xs_lo = Wide::get_lower(xs); - // Field xs_hi = Wide::get_higher(xs); - // Wide l1 = {}; - // Wide l2 = {}; - // host_math::template multiply_raw(xs_lo.limbs_storage, get_m(), l1.limbs_storage); - // Field l1_lo = Wide::get_lower(l1); - // host_math::template multiply_raw(l1_lo.limbs_storage, get_modulus<1>(), l2.limbs_storage); - // Field l2_hi = Wide::get_higher(l2); - // Field r = {}; - // add_limbs(l2_hi.limbs_storage, xs_hi.limbs_storage, r.limbs_storage); - - Field r = get_higher_half ? Wide::get_higher(xs) : Wide::get_lower(xs); + Field r = get_higher? Wide::get_higher(xs) : Wide::get_lower(xs); Field p = Field{get_modulus<1>()}; if (p.limbs_storage.limbs[TLC - 1] > r.limbs_storage.limbs[TLC - 1]) return r; ff_storage r_reduced = {}; @@ -488,6 +483,15 @@ class Field return r; } + static constexpr HOST_DEVICE_INLINE Field reduce(const Wide& xs) + { + #ifdef BARRET + return barret_reduce(xs); + #else + return mont_reduce(xs); + #endif + } + HOST_DEVICE Field& operator=(Field const& other) { for (int i = 0; i < TLC; i++) { @@ -503,7 +507,7 @@ class Field #ifdef __CUDA_ARCH__ // cuda #ifdef BARRET Wide xy = mul_wide(xs, ys); // full mult - return reduce(xy); // reduce mod p + return barret_reduce(xy); // reduce mod p // return Wide::get_lower(xy); // reduce mod p #else return sub_modulus<1>(Field{device_math::template mulmont_device(xs.limbs_storage,ys.limbs_storage,get_modulus<1>(),get_mont_inv_modulus())}); @@ -511,7 +515,7 @@ class Field #else #ifdef BARRET Wide xy = mul_wide(xs, ys); // full mult - return reduce(xy); // reduce mod p + return barret_reduce(xy); // reduce mod p // return Wide::get_lower(xy); #else return mont_mult(xs,ys); @@ -523,10 +527,8 @@ class Field static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) { Wide r = {}; - host_math::multiply_mont_64( - xs.limbs_storage.limbs64, ys.limbs_storage.limbs64, get_mont_inv_modulus().limbs64, get_modulus<1>().limbs64, - r.limbs_storage.limbs64); - return mont_reduce(r); + host_math::multiply_mont(xs.limbs_storage, ys.limbs_storage, get_mont_inv_modulus(), get_modulus<1>(), r.limbs_storage); + return mont_sub_modulus(r); } /** @@ -535,12 +537,17 @@ class Field * @param t Number to be reduced. Must be in montgomery rep, and in range [0,p^2-1]. * @return \p t mod p */ - static constexpr HOST_INLINE Field sos_mont_reduce(const Wide& t) + static constexpr HOST_INLINE Field mont_reduce(const Wide& t) { + #ifdef __CUDA_ARCH__ + Wide r = t; + device_math::template reduce_mont_inplace(r.limbs_storage, get_modulus<1>(), get_mont_inv_modulus()); + #else Wide r = {}; - host_math::sos_mont_reduction_64( - t.limbs_storage.limbs64, get_modulus<1>().limbs64, get_mont_inv_modulus().limbs64, r.limbs_storage.limbs64); - return mont_reduce(r, /* get_higher_half = */ true); + host_math::template sos_mont_reduction_64( + t.limbs_storage.limbs64, get_modulus<1>().limbs64, get_mont_inv_modulus().limbs64, r.limbs_storage.limbs64); //TODO enable 32 + #endif + return mont_sub_modulus(r, true); } // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) @@ -553,7 +560,7 @@ class Field return reduce(xy); // reduce mod p } -#ifdef GARBAGE +#if 0 // #include diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index 63b368e92..46f9c9787 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -251,35 +251,18 @@ namespace host_math { static HOST_INLINE void multiply_mont_64(const uint64_t* a, const uint64_t* b, const uint64_t* q, const uint64_t* p, uint64_t* r) { - // printf("r0: "); - // for (unsigned i = 0; i < NLIMBS_B / 2; i++) { - // printf(" %lu,",r[i]); - // } - // printf("\n"); for (unsigned i = 0; i < NLIMBS_B / 2; i++) { - // printf("i %d\n", i); uint64_t A = 0, C = 0; r[0] = host_math::madc_cc_64(a[0], b[i], r[0], A); - // printf("r0 %lu\n",r[0]); - // printf("q0 %lu\n",q[0]); - // printf("p0 %lu\n",p[0]); - // printf("A %lu\n",A); uint64_t m = host_math::madc_cc_64(r[0], q[0], 0, C); // TODO - multiply inst - // printf("m %lu\n",m); C = 0; host_math::madc_cc_64(m, p[0], r[0], C); - // printf("c %lu\n",C); for (unsigned j = 1; j < NLIMBS_A / 2; j++) { r[j] = host_math::madc_cc_64(a[j], b[i], r[j], A); r[j - 1] = host_math::madc_cc_64(m, p[j], r[j], C); } r[NLIMBS_A / 2 - 1] = C + A; } - // printf("rf: "); - // for (unsigned i = 0; i < NLIMBS_B / 2; i++) { - // printf(" %lu,",r[i]); - // } - // printf("\n"); } /** @@ -319,6 +302,24 @@ namespace host_math { } } +template + static constexpr HOST_INLINE void + multiply_mont(const storage& as, const storage& bs, const storage& qs, const storage& ps, storage& rs) + { + static_assert( + (NLIMBS_A % 2 == 0 || NLIMBS_A == 1) && (NLIMBS_B % 2 == 0 || NLIMBS_B == 1), + "odd number of limbs is not supported\n"); + if constexpr (USE_32) { + multiply_mont_32(as.limbs, bs.limbs, qs.limbs, ps.limbs, rs.limbs); + return; + } else if constexpr (NLIMBS_A == 1 || NLIMBS_B == 1) { + multiply_mont_32(as.limbs, bs.limbs, qs.limbs, ps.limbs, rs.limbs); + return; + } else { + multiply_mont_64(as.limbs64, bs.limbs64, qs.limbs64, ps.limbs64, rs.limbs64); + } + } + template static HOST_INLINE void multiply_raw_64(const storage& as, const storage& bs, storage& rs) diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index 3794bb752..a3302f273 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -510,7 +510,7 @@ TYPED_TEST(CurveSanity, MontSosReduction) START_TIMER(mont_sos_reduction); for (int i = 0; i < n; i++) { auto ab_no_mod = scalar_t::mul_wide(as[i], bs[i]); - abs[i] = scalar_t::sos_mont_reduce(ab_no_mod); + abs[i] = scalar_t::reduce(ab_no_mod); } END_TIMER(mont_sos_reduction, "CPU-Montgomery SOS reduction", true); From 4282114a932a019f93fadb9a2786059109951566 Mon Sep 17 00:00:00 2001 From: Koren-Brand Date: Tue, 19 Nov 2024 15:20:01 +0200 Subject: [PATCH 14/22] SOS mont reduction now implemented for 32bits as well Signed-off-by: Koren-Brand --- icicle/include/icicle/fields/field.h | 4 +- icicle/include/icicle/fields/host_math.h | 78 ++++++++++++++++++++++-- icicle/tests/test_curve_api.cpp | 2 +- 3 files changed, 75 insertions(+), 9 deletions(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 343156a3a..aaf5af145 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -544,8 +544,8 @@ class Field device_math::template reduce_mont_inplace(r.limbs_storage, get_modulus<1>(), get_mont_inv_modulus()); #else Wide r = {}; - host_math::template sos_mont_reduction_64( - t.limbs_storage.limbs64, get_modulus<1>().limbs64, get_mont_inv_modulus().limbs64, r.limbs_storage.limbs64); //TODO enable 32 + host_math::template sos_mont_reduction( + t.limbs_storage, get_modulus<1>(), get_mont_inv_modulus(), r.limbs_storage); #endif return mont_sub_modulus(r, true); } diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index 46f9c9787..ad9d5ec6e 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -84,12 +84,12 @@ namespace host_math { return result; } - static inline __host__ __uint128_t mul64(uint64_t x, uint64_t y) - { - uint64_t high, low; - asm("mulq %3" : "=d"(high), "=a"(low) : "a"(x), "r"(y) : "cc"); - return (static_cast<__uint128_t>(high) << 64) | low; - } + // static inline __host__ __uint128_t mul64(uint64_t x, uint64_t y) + // { + // uint64_t high, low; + // asm("mulq %3" : "=d"(high), "=a"(low) : "a"(x), "r"(y) : "cc"); + // return (static_cast<__uint128_t>(high) << 64) | low; + // } static __host__ uint64_t madc_cc_64(const uint64_t x, const uint64_t y, const uint64_t z, uint64_t& carry) { @@ -265,6 +265,43 @@ namespace host_math { } } + /** + * @brief Perform SOS reduction on a number in montgomery representation \p t in range [0, \p n ^2-1] limiting it to + * the range [0,2 \p n -1]. + * @param t Number to be reduced. Must be in montgomery rep, and in range [0, \p n ^2-1]. + * @param n Field modulus. + * @param n_tag Number such that \p n * \p n_tag modR = -1 + * @param r Array in which to store the result in its upper half (Lower half is data that would be removed by + * dividing by R = shifting NLIMBS down). + * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). + */ + template + static HOST_INLINE void + sos_mont_reduction_32(const uint32_t* t, const uint32_t* n, const uint32_t* n_tag, uint32_t* r) + { + const unsigned s = NLIMBS; // For similarity to the original algorithm + + // Copy t to r as t is read-only + for (int i = 0; i < 2 * s; i++) { + r[i] = t[i]; + } + + for (int i = 0; i < s; i++) { + uint32_t c = 0; + uint32_t m = r[i] * n_tag[0]; + + for (int j = 0; j < s; j++) { + // r[i+j] = addc_cc(r[i+j], m * n[j], c); + r[i + j] = madc_cc(m, n[j], r[i + j], c); + } + // Propagate the carry to the remaining sublimbs + for (int carry_idx = s + i; carry_idx < 2 * s; carry_idx++) { + if (c == 0) { break; } + r[carry_idx] = add_cc(r[carry_idx], c, c); + } + } + } + /** * @brief Perform SOS reduction on a number in montgomery representation \p t in range [0, \p n ^2-1] limiting it to * the range [0,2 \p n -1]. @@ -302,6 +339,35 @@ namespace host_math { } } + /** + * @brief Perform SOS reduction on a number in montgomery representation \p t in range [0, \p n ^2-1] limiting it to + * the range [0,2 \p n -1]. + * @param t Number to be reduced. Must be in montgomery rep, and in range [0, \p n ^2-1]. + * @param n Field modulus. + * @param n_tag Number such that \p n * \p n_tag modR = -1 + * @param r Array in which to store the result in its upper half (Lower half is data that would be removed by + * dividing by R = shifting NLIMBS down). + * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). + */ + template + static HOST_INLINE void + sos_mont_reduction( + const storage<2*NLIMBS>& t, const storage& n, const storage& n_tag, storage<2*NLIMBS>& r) + { + static_assert( + NLIMBS % 2 == 0 || NLIMBS == 1, + "Odd number of limbs (That is not 1) is not supported\n"); + if constexpr (USE_32) { + sos_mont_reduction_32(t.limbs, n.limbs, n_tag.limbs, r.limbs); + return; + } else if constexpr (NLIMBS == 1) { + sos_mont_reduction_32(t.limbs, n.limbs, n_tag.limbs, r.limbs); + return; + } else { + sos_mont_reduction_64(t.limbs64, n.limbs64, n_tag.limbs64, r.limbs64); + } + } + template static constexpr HOST_INLINE void multiply_mont(const storage& as, const storage& bs, const storage& qs, const storage& ps, storage& rs) diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index a3302f273..9fb78b1dc 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -486,7 +486,7 @@ TYPED_TEST(CurveSanity, u64Mul) // #pragma unroll uint64_t high, low; for (int i = 0; i < n; ++i) { - asm("mulq %3" : "=d"(high), "=a"(low) : "a"(scalars[i]), "r"(scalars2[i]) : "cc"); + // asm("mulq %3" : "=d"(high), "=a"(low) : "a"(scalars[i]), "r"(scalars2[i]) : "cc"); scalars_res_128[i] = (static_cast<__uint128_t>(high) << 64) | low; } END_TIMER(u64Mult_asm, "U64-MULT-asm", true); From 7f5345b00cacb11661f00b2508df1d366d04b2aa Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Wed, 20 Nov 2024 11:59:49 +0200 Subject: [PATCH 15/22] all c++ tests pass --- icicle/include/icicle/fields/field.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index aaf5af145..05632fdbe 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -542,12 +542,13 @@ class Field #ifdef __CUDA_ARCH__ Wide r = t; device_math::template reduce_mont_inplace(r.limbs_storage, get_modulus<1>(), get_mont_inv_modulus()); + return mont_sub_modulus(r); #else Wide r = {}; host_math::template sos_mont_reduction( t.limbs_storage, get_modulus<1>(), get_mont_inv_modulus(), r.limbs_storage); - #endif return mont_sub_modulus(r, true); + #endif } // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) From a738a4121fbb1c82e31408fcfc7e3c6cfd852a69 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Wed, 20 Nov 2024 12:10:26 +0200 Subject: [PATCH 16/22] fix extention --- icicle/include/icicle/fields/quartic_extension.h | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/icicle/include/icicle/fields/quartic_extension.h b/icicle/include/icicle/fields/quartic_extension.h index 43038d588..cd7b79751 100644 --- a/icicle/include/icicle/fields/quartic_extension.h +++ b/icicle/include/icicle/fields/quartic_extension.h @@ -153,19 +153,17 @@ class QuarticExtensionField FF::mul_wide(xs.real, ys), FF::mul_wide(xs.im1, ys), FF::mul_wide(xs.im2, ys), FF::mul_wide(xs.im3, ys)}; } - template static constexpr HOST_DEVICE_INLINE ExtensionWide mul_wide(const FF& xs, const QuarticExtensionField& ys) { return ExtensionWide{ FF::mul_wide(xs, ys.real), FF::mul_wide(xs, ys.im1), FF::mul_wide(xs, ys.im2), FF::mul_wide(xs, ys.im3)}; } - template static constexpr HOST_DEVICE_INLINE QuarticExtensionField reduce(const ExtensionWide& xs) { return QuarticExtensionField{ - FF::template reduce(xs.real), FF::template reduce(xs.im1), - FF::template reduce(xs.im2), FF::template reduce(xs.im3)}; + FF::reduce(xs.real), FF::reduce(xs.im1), + FF::reduce(xs.im2), FF::reduce(xs.im3)}; } template @@ -207,7 +205,6 @@ class QuarticExtensionField return xs * xs; } - template static constexpr HOST_DEVICE_INLINE QuarticExtensionField neg(const QuarticExtensionField& xs) { return {FF::neg(xs.real), FF::neg(xs.im1), FF::neg(xs.im2), FF::neg(xs.im3)}; From 7ae18a5ae79e4677ea6bc825f8cece182aa953ec Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Wed, 20 Nov 2024 22:51:28 +0200 Subject: [PATCH 17/22] all tests pass --- .../include/icicle/fields/complex_extension.h | 4 +- icicle/include/icicle/fields/device_math.h | 197 ++++++++++++++---- icicle/tests/test_field_api.cpp | 11 +- 3 files changed, 163 insertions(+), 49 deletions(-) diff --git a/icicle/include/icicle/fields/complex_extension.h b/icicle/include/icicle/fields/complex_extension.h index 4c1d00298..b9fd27bcc 100644 --- a/icicle/include/icicle/fields/complex_extension.h +++ b/icicle/include/icicle/fields/complex_extension.h @@ -128,17 +128,15 @@ class ComplexExtensionField return ExtensionWide{FF::mul_wide(xs.real, ys), FF::mul_wide(xs.imaginary, ys)}; } - template static constexpr HOST_DEVICE_INLINE ExtensionWide mul_wide(const FF& xs, const ComplexExtensionField& ys) { return mul_wide(ys, xs); } - template static constexpr HOST_DEVICE_INLINE ComplexExtensionField reduce(const ExtensionWide& xs) { return ComplexExtensionField{ - FF::template reduce(xs.real), FF::template reduce(xs.imaginary)}; + FF::reduce(xs.real), FF::reduce(xs.imaginary)}; } template diff --git a/icicle/include/icicle/fields/device_math.h b/icicle/include/icicle/fields/device_math.h index 3926a4bc5..4e3aa2b13 100644 --- a/icicle/include/icicle/fields/device_math.h +++ b/icicle/include/icicle/fields/device_math.h @@ -115,6 +115,8 @@ template acc[0] = CARRY_IN ? ptx::madc_lo_cc(a[0], bi, acc[0]) : ptx::mad_lo_cc(a[0], bi, acc[0]); acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + // printf("even %d, mod %d, mi %d\n", acc[0], a[0],bi); + // printf("even %d, mod %d, mi %d\n", acc[0], a[0],bi); UNROLL for (size_t i = 2; i < n; i += 2) { acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); @@ -494,22 +496,43 @@ template static DEVICE_INLINE void multiply_raw_sb(const storage &as, const storage &bs, storage<2*NLIMBS> &rs) { const uint32_t *a = as.limbs; const uint32_t *b = bs.limbs; - uint32_t *even = rs.limbs; - __align__(8) uint32_t odd[2 * NLIMBS - 2]; - mul_n(even, a, b[0]); - mul_n(odd, a + 1, b[0]); - mad_row(&even[2], &odd[0], a, b[1]); - size_t i; -#pragma unroll - for (i = 2; i < NLIMBS - 1; i += 2) { - mad_row(&odd[i], &even[i], a, b[i]); - mad_row(&even[i + 2], &odd[i], a, b[i + 1]); + if constexpr (NLIMBS > 2){ + uint32_t *even = rs.limbs; + __align__(8) uint32_t odd[2 * NLIMBS - 2]; + mul_n(even, a, b[0]); + mul_n(odd, a + 1, b[0]); + mad_row(&even[2], &odd[0], a, b[1]); + size_t i; + #pragma unroll + for (i = 2; i < NLIMBS - 1; i += 2) { + mad_row(&odd[i], &even[i], a, b[i]); + mad_row(&even[i + 2], &odd[i], a, b[i + 1]); + } + // merge |even| and |odd| + even[1] = ptx::add_cc(even[1], odd[0]); + for (i = 1; i < 2 * NLIMBS - 2; i++) + even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); + even[i + 1] = ptx::addc(even[i + 1], 0); + } + else if (NLIMBS == 2) { + uint32_t *r = rs.limbs; + __align__(8) uint32_t odd[2]; + r[0] = ptx::mul_lo(a[0], b[0]); + r[1] = ptx::mul_hi(a[0], b[0]); + r[2] = ptx::mul_lo(a[1], b[1]); + r[3] = ptx::mul_hi(a[1], b[1]); + odd[0] = ptx::mul_lo(a[0], b[1]); + odd[1] = ptx::mul_hi(a[0], b[1]); + odd[0] = ptx::mad_lo(a[1], b[0], odd[0]); + odd[1] = ptx::mad_hi(a[1], b[0], odd[1]); + r[1] = ptx::add_cc(r[1], odd[0]); + r[2] = ptx::addc_cc(r[2], odd[1]); + r[3] = ptx::addc(r[3], 0); + } else if (NLIMBS == 1) { + uint32_t *r = rs.limbs; + r[0] = ptx::mul_lo(a[0], b[0]); + r[1] = ptx::mul_hi(a[0], b[0]); } - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < 2 * NLIMBS - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); } template @@ -592,24 +615,58 @@ template // Since the hi NLIMBS limbs don't participate in computing the "mi" factor at each mul-and-rightshift stage, // it's ok to ignore the hi NLIMBS limbs during this process and just add them in afterward. uint32_t odd[NLIMBS]; - size_t i; -#pragma unroll - for (i = 0; i < NLIMBS; i += 2) { - mul_by_1_row(&even[0], &odd[0], modulus.limbs, mont_inv_modulus.limbs, i == 0); - mul_by_1_row(&odd[0], &even[0], modulus.limbs, mont_inv_modulus.limbs); + if constexpr (NLIMBS == 1) { + uint32_t mi = even[0] * mont_inv_modulus.limbs[0]; //m + // printf("m %u\n", mi); + // printf("p %u\n", modulus.limbs[0]); + // printf("t0 %u\n", odd[0]); + // cmad_n(odd, modulus + 1, mi); + // acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); + // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + // cmad_n(even, modulus.limbs, mi); + even[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, even[0]); + odd[0] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); + + // odd[0] = ptx::mad_lo(modulus.limbs[0], mi, odd[0]); //C + // even[0] = ptx::mul_hi(modulus.limbs[0], mi, odd[0]); //C + // even[0] = ptx::mad_hi(modulus.limbs[0], mi, 0); //C + // even[0] = ptx::mad_hi(modulus.limbs[0], 0, odd[0]); //C + // printf("C %u\n", odd[1]); + // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + // odd[0] = ptx::addc(odd[0], 0); + + // odd[0] = ptx::add_cc(odd[0], even[1]); + // madc_n_rshift(even, a + 1, bi); + // cmad_n(odd, a, bi); + // even[NLIMBS - 1] = ptx::addc(even[NLIMBS - 1], 0); + // mi = odd[0] * mont_inv_modulus.limbs[0]; + // cmad_n(even, modulus + 1, mi); + // cmad_n(odd, modulus.limbs, mi); + // odd[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, 0); + // odd[1] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); + // even[0] = ptx::addc(even[0], 0); + even[0] = ptx::add_cc(even[1], odd[0]); + } + else { + size_t i; + #pragma unroll + for (i = 0; i < NLIMBS; i += 2) { + mul_by_1_row(&even[0], &odd[0], modulus.limbs, mont_inv_modulus.limbs, i == 0); + mul_by_1_row(&odd[0], &even[0], modulus.limbs, mont_inv_modulus.limbs); + } + even[0] = ptx::add_cc(even[0], odd[1]); + #pragma unroll + for (i = 1; i < NLIMBS - 1; i++) + even[i] = ptx::addc_cc(even[i], odd[i + 1]); + even[i] = ptx::addc(even[i], 0); + // Adds in (hi NLIMBS limbs), implicitly right-shifting them by NLIMBS limbs as if they had participated in the + // add-and-rightshift stages above. + xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[NLIMBS]); + #pragma unroll + for (i = 1; i < NLIMBS - 1; i++) + xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + NLIMBS]); + xs.limbs[NLIMBS - 1] = ptx::addc(xs.limbs[NLIMBS - 1], xs.limbs[2 * NLIMBS - 1]); } - even[0] = ptx::add_cc(even[0], odd[1]); -#pragma unroll - for (i = 1; i < NLIMBS - 1; i++) - even[i] = ptx::addc_cc(even[i], odd[i + 1]); - even[i] = ptx::addc(even[i], 0); - // Adds in (hi NLIMBS limbs), implicitly right-shifting them by NLIMBS limbs as if they had participated in the - // add-and-rightshift stages above. - xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[NLIMBS]); -#pragma unroll - for (i = 1; i < NLIMBS - 1; i++) - xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + NLIMBS]); - xs.limbs[NLIMBS - 1] = ptx::addc(xs.limbs[NLIMBS - 1], xs.limbs[2 * NLIMBS - 1]); } template @@ -619,17 +676,71 @@ template uint32_t *even = r_in.limbs; __align__(8) uint32_t odd[NLIMBS + 1]; size_t i; -#pragma unroll - for (i = 0; i < NLIMBS; i += 2) { - mad_n_redc(&even[0], &odd[0], a, b[i], modulus.limbs, mont_inv_modulus.limbs, i == 0); - mad_n_redc(&odd[0], &even[0], a, b[i + 1], modulus.limbs, mont_inv_modulus.limbs); + if constexpr (NLIMBS == 1) { + // mad_n_redc(&even[0], &odd[0], a, b[0], modulus.limbs, mont_inv_modulus.limbs, true); + // printf("even0 b %d\n", even[0]); + // mul_n(even, a, b[0]); + // printf("even0 a %d\n", even[0]); + // uint32_t mi = even[0] * mont_inv_modulus.limbs[0]; + // printf("mi %d\n", mi); + // cmad_n(even, modulus.limbs, mi); + // printf("even0 c %d\n", even[0]); + + // mul_n(odd, a + 1, bi); + // odd[0] = ptx::mul_lo(a[1], b[0]); + // odd[1] = ptx::mul_hi(a[1], b[0]); + // mul_n(even, a, b[0]); + // printf("a %u\n", a[0]); + // printf("b %u\n", b[0]); + odd[0] = ptx::mul_lo(a[0], b[0]); //t[0] + // printf("t0 %u\n", odd[0]); + even[0] = ptx::mul_hi(a[0], b[0]); //A + // printf("A %u\n", even[0]); + uint32_t mi = odd[0] * mont_inv_modulus.limbs[0]; //m + // printf("m %u\n", mi); + // printf("p %u\n", modulus.limbs[0]); + // printf("t0 %u\n", odd[0]); + // cmad_n(odd, modulus + 1, mi); + // acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); + // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + // cmad_n(even, modulus.limbs, mi); + odd[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, odd[0]); + odd[1] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); + + // odd[0] = ptx::mad_lo(modulus.limbs[0], mi, odd[0]); //C + // even[0] = ptx::mul_hi(modulus.limbs[0], mi, odd[0]); //C + // even[0] = ptx::mad_hi(modulus.limbs[0], mi, 0); //C + // even[0] = ptx::mad_hi(modulus.limbs[0], 0, odd[0]); //C + // printf("C %u\n", odd[1]); + // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); + // odd[0] = ptx::addc(odd[0], 0); + + // odd[0] = ptx::add_cc(odd[0], even[1]); + // madc_n_rshift(even, a + 1, bi); + // cmad_n(odd, a, bi); + // even[NLIMBS - 1] = ptx::addc(even[NLIMBS - 1], 0); + // mi = odd[0] * mont_inv_modulus.limbs[0]; + // cmad_n(even, modulus + 1, mi); + // cmad_n(odd, modulus.limbs, mi); + // odd[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, 0); + // odd[1] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); + // even[0] = ptx::addc(even[0], 0); + even[0] = ptx::add_cc(even[0], odd[1]); + // printf("A+C %u\n", even[0]); + } + else { + #pragma unroll + for (i = 0; i < NLIMBS; i += 2) { + mad_n_redc(&even[0], &odd[0], a, b[i], modulus.limbs, mont_inv_modulus.limbs, i == 0); + mad_n_redc(&odd[0], &even[0], a, b[i + 1], modulus.limbs, mont_inv_modulus.limbs); + } + // merge |even| and |odd| + even[0] = ptx::add_cc(even[0], odd[1]); + #pragma unroll + for (i = 1; i < NLIMBS - 1; i++) + even[i] = ptx::addc_cc(even[i], odd[i + 1]); + even[i] = ptx::addc(even[i], 0); } - // merge |even| and |odd| - even[0] = ptx::add_cc(even[0], odd[1]); -#pragma unroll - for (i = 1; i < NLIMBS - 1; i++) - even[i] = ptx::addc_cc(even[i], odd[i + 1]); - even[i] = ptx::addc(even[i], 0); // final reduction from [0, 2*mod) to [0, mod) not done here, instead performed optionally in mul_device wrapper } @@ -638,9 +749,9 @@ template template static constexpr DEVICE_INLINE storage mulmont_device(const storage &xs, const storage &ys, const storage &modulus, const storage &mont_inv_modulus) { // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack // static_assert(!(CONFIG::modulus.limbs[NLIMBS - 1] >> 30)); - // printf(" "); storage rs = {0}; montmul_raw(xs, ys, modulus, mont_inv_modulus, rs); + // printf("rs %d\n",rs.limbs[0]); return rs; } diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index 743e95e00..49e249110 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -73,13 +73,16 @@ typedef testing::Types FTImplementations; TYPED_TEST_SUITE(FieldApiTest, FTImplementations); + // Note: this is testing host arithmetic. Other tests against CPU backend should guarantee correct device arithmetic too TYPED_TEST(FieldApiTest, FieldSanityTest) { auto a = TypeParam::rand_host(); + // a.limbs_storage.limbs[0] = 1089097490; std::cout<(N); auto in_b = std::make_unique(N); FieldApiTest::random_samples(in_a.get(), N); FieldApiTest::random_samples(in_b.get(), N); + // in_a[0].limbs_storage.limbs[0] = 1089097490; + // in_b[0].limbs_storage.limbs[0] = 1691855643; auto out_main = std::make_unique(N); auto out_ref = std::make_unique(N); From c923e6e5744c8932e418e63b2278bdeee64a8a2f Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Thu, 21 Nov 2024 18:51:05 +0200 Subject: [PATCH 18/22] formatting --- icicle/backend/cpu/include/ntt_cpu.h | 3 +- icicle/backend/cpu/src/curve/cpu_msm.hpp | 22 +- icicle/include/icicle/curves/projective.h | 34 +- .../include/icicle/fields/complex_extension.h | 3 +- icicle/include/icicle/fields/device_math.h | 775 ------------------ icicle/include/icicle/fields/field.h | 485 ++--------- icicle/include/icicle/fields/host_math.h | 64 +- icicle/include/icicle/fields/params_gen.h | 26 +- .../include/icicle/fields/quartic_extension.h | 4 +- .../fields/snark_fields/bls12_377_base.h | 1 - .../fields/snark_fields/bls12_377_scalar.h | 1 - .../include/icicle/fields/stark_fields/m31.h | 2 + icicle/include/icicle/msm.h | 22 +- icicle/tests/test_curve_api.cpp | 20 +- icicle/tests/test_device_api.cpp | 10 +- icicle/tests/test_field_api.cpp | 107 ++- icicle/tests/test_hash_api.cpp | 10 +- icicle/tests/test_polynomial_api.cpp | 10 +- 18 files changed, 229 insertions(+), 1370 deletions(-) delete mode 100644 icicle/include/icicle/fields/device_math.h diff --git a/icicle/backend/cpu/include/ntt_cpu.h b/icicle/backend/cpu/include/ntt_cpu.h index ccf076d1a..725de780e 100644 --- a/icicle/backend/cpu/include/ntt_cpu.h +++ b/icicle/backend/cpu/include/ntt_cpu.h @@ -27,8 +27,7 @@ namespace ntt_cpu { public: NttCpu(uint32_t logn, NTTDir direction, const NTTConfig& config, const E* input, E* output) : input(input), ntt_data(logn, output, config, direction), ntt_tasks_manager(ntt_data.ntt_sub_logn, logn), - // tasks_manager(std::make_unique>>(std::thread::hardware_concurrency() - 1)) - tasks_manager(std::make_unique>>(1)) + tasks_manager(std::make_unique>>(std::thread::hardware_concurrency() - 1)) { } eIcicleError run(); diff --git a/icicle/backend/cpu/src/curve/cpu_msm.hpp b/icicle/backend/cpu/src/curve/cpu_msm.hpp index 129318795..ba8b30e9c 100644 --- a/icicle/backend/cpu/src/curve/cpu_msm.hpp +++ b/icicle/backend/cpu/src/curve/cpu_msm.hpp @@ -453,11 +453,11 @@ void Msm::phase1_bucket_accumulator(const scalar_t* scalars, const A* base for (int j = 0; j < m_precompute_factor; j++) { // Handle required preprocess of base P according to the version of Field/Ec adder (accepting Barret / Montgomery) A base = - #ifdef BARRET +#ifdef BARRET m_are_points_mont ? A::from_montgomery(bases[m_precompute_factor * i + j]) : bases[m_precompute_factor * i + j]; - #else +#else m_are_points_mont ? bases[m_precompute_factor * i + j] : A::to_montgomery(bases[m_precompute_factor * i + j]); - #endif +#endif if (base == A::zero()) { continue; } if (negate_p_and_s) { base = A::neg(base); } @@ -785,22 +785,22 @@ eIcicleError cpu_msm_precompute_bases( for (int i = 0; i < nof_bases; i++) { output_bases[precompute_factor * i] = input_bases[i]; // Handle required preprocess of base P according to the version of Field/Ec adder (accepting Barret / Montgomery) - P point = - #ifdef BARRET + P point = +#ifdef BARRET P::from_affine(is_mont ? A::from_montgomery(input_bases[i]) : input_bases[i]); - #else +#else P::from_affine(is_mont ? input_bases[i] : A::to_montgomery(input_bases[i])); - #endif +#endif for (int j = 1; j < precompute_factor; j++) { for (int k = 0; k < shift; k++) { point = P::dbl(point); } - output_bases[precompute_factor * i + j] = - #ifdef BARRET + output_bases[precompute_factor * i + j] = +#ifdef BARRET is_mont ? A::to_montgomery(P::to_affine(point)) : P::to_affine(point); - #else +#else is_mont ? P::to_affine(point) : A::from_montgomery(P::to_affine(point)); - #endif +#endif } } return eIcicleError::SUCCESS; diff --git a/icicle/include/icicle/curves/projective.h b/icicle/include/icicle/curves/projective.h index 8dfc04eb4..b5ae43cfb 100644 --- a/icicle/include/icicle/curves/projective.h +++ b/icicle/include/icicle/curves/projective.h @@ -47,11 +47,14 @@ class Projective return {FF::from_montgomery(point.x), FF::from_montgomery(point.y), FF::from_montgomery(point.z)}; } - #ifdef BARRET +#ifdef BARRET static HOST_DEVICE_INLINE Projective generator() { return {Gen::gen_x, Gen::gen_y, FF::one()}; } - #else - static HOST_DEVICE_INLINE Projective generator() { return {FF::to_montgomery(Gen::gen_x), FF::to_montgomery(Gen::gen_y), FF::one()}; } - #endif +#else + static HOST_DEVICE_INLINE Projective generator() + { + return {FF::to_montgomery(Gen::gen_x), FF::to_montgomery(Gen::gen_y), FF::one()}; + } +#endif static HOST_DEVICE_INLINE Projective neg(const Projective& point) { return {point.x, FF::neg(point.y), point.z}; } @@ -173,23 +176,6 @@ class Projective const auto t30 = FF::mul_wide(t19, t07); // t30 ← t19 · t07 < 2 const auto t31 = FF::mul_wide(t21, t12); // t31 ← t21 · t12 < 2 const FF Z3 = FF::reduce(t31 + t30); // Z3 ← t31 + t30 < 2 - // const auto t24 = FF::mul_widez(t12.limbs_storage, t23.limbs_storage); // t24 ← t12 · t23 < 2 - // const auto t25 = FF::mul_widez(t07.limbs_storage, t22.limbs_storage); // t25 ← t07 · t22 < 2 - // typename FF::Wide W3 = typename FF::Wide{t25} - typename FF::Wide{t24}; // X3 ← t25 − t24 < 2 - // FF::redc_wide_inplacez(W3.limbs_storage); // X3 ← t25 − t24 < 2 - // const auto X3 = FF::Wide::get_lower(W3); - // const auto t27 = FF::mul_widez(t23.limbs_storage, t19.limbs_storage); // t27 ← t23 · t19 < 2 - // const auto t28 = FF::mul_widez(t22.limbs_storage, t21.limbs_storage); // t28 ← t22 · t21 < 2 - // W3 = typename FF::Wide{t28} + typename FF::Wide{t27}; // Y3 ← t28 + t27 < 2 - // FF::redc_wide_inplacez(W3.limbs_storage); // Y3 ← t28 + t27 < 2 - // const auto Y3 = FF::Wide::get_lower(W3); - // const auto t30 = FF::mul_widez(t19.limbs_storage, t07.limbs_storage); // t30 ← t19 · t07 < 2 - // const auto t31 = FF::mul_widez(t21.limbs_storage, t12.limbs_storage); // t31 ← t21 · t12 < 2 - // W3 = typename FF::Wide{t31} + typename FF::Wide{t30}; // Z3 ← t31 + t30 < 2 - // FF::redc_wide_inplacez(W3.limbs_storage); // Z3 ← t31 + t30 < 2 - // const auto Z3 = FF::Wide::get_lower(W3); - // #else - // #endif return {X3, Y3, Z3}; } @@ -200,10 +186,10 @@ class Projective friend HOST_DEVICE Projective operator*(SCALAR_FF scalar, const Projective& point) { - #ifndef BARRET +#ifndef BARRET scalar = SCALAR_FF::from_montgomery(scalar); - #endif - +#endif + // Precompute points: P, 2P, ..., (2^window_size - 1)P constexpr unsigned window_size = 4; // 4 seems fastest. Optimum is minimizing EC add and depends on the field size. for 256b it's 4. diff --git a/icicle/include/icicle/fields/complex_extension.h b/icicle/include/icicle/fields/complex_extension.h index b9fd27bcc..32fa1b950 100644 --- a/icicle/include/icicle/fields/complex_extension.h +++ b/icicle/include/icicle/fields/complex_extension.h @@ -135,8 +135,7 @@ class ComplexExtensionField static constexpr HOST_DEVICE_INLINE ComplexExtensionField reduce(const ExtensionWide& xs) { - return ComplexExtensionField{ - FF::reduce(xs.real), FF::reduce(xs.imaginary)}; + return ComplexExtensionField{FF::reduce(xs.real), FF::reduce(xs.imaginary)}; } template diff --git a/icicle/include/icicle/fields/device_math.h b/icicle/include/icicle/fields/device_math.h deleted file mode 100644 index 4e3aa2b13..000000000 --- a/icicle/include/icicle/fields/device_math.h +++ /dev/null @@ -1,775 +0,0 @@ -#ifdef __CUDACC__ - -#pragma once - -#include -#include "icicle/utils/modifiers.h" -#include "icicle/fields/storage.h" -#include "ptx.h" - -namespace device_math { - -template -struct carry_chain { - unsigned index; - - constexpr __device__ __forceinline__ carry_chain() : index(0) {} - - __device__ __forceinline__ uint32_t add(const uint32_t x, const uint32_t y) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::add(x, y); - else if (index == 1 && !CARRY_IN) - return ptx::add_cc(x, y); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::addc_cc(x, y); - else - return ptx::addc(x, y); - } - - __device__ __forceinline__ uint32_t sub(const uint32_t x, const uint32_t y) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::sub(x, y); - else if (index == 1 && !CARRY_IN) - return ptx::sub_cc(x, y); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::subc_cc(x, y); - else - return ptx::subc(x, y); - } - - __device__ __forceinline__ uint32_t mad_lo(const uint32_t x, const uint32_t y, const uint32_t z) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::mad_lo(x, y, z); - else if (index == 1 && !CARRY_IN) - return ptx::mad_lo_cc(x, y, z); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::madc_lo_cc(x, y, z); - else - return ptx::madc_lo(x, y, z); - } - - __device__ __forceinline__ uint32_t mad_hi(const uint32_t x, const uint32_t y, const uint32_t z) { - index++; - if (index == 1 && OPS_COUNT == 1 && !CARRY_IN && !CARRY_OUT) - return ptx::mad_hi(x, y, z); - else if (index == 1 && !CARRY_IN) - return ptx::mad_hi_cc(x, y, z); - else if (index < OPS_COUNT || CARRY_OUT) - return ptx::madc_hi_cc(x, y, z); - else - return ptx::madc_hi(x, y, z); - } -}; - -template - static constexpr DEVICE_INLINE uint32_t add_sub_u32_device(const uint32_t* x, const uint32_t* y, uint32_t* r) - { - r[0] = SUBTRACT ? ptx::sub_cc(x[0], y[0]) : ptx::add_cc(x[0], y[0]); - for (unsigned i = 1; i < NLIMBS; i++) - r[i] = SUBTRACT ? ptx::subc_cc(x[i], y[i]) : ptx::addc_cc(x[i], y[i]); - if (!CARRY_OUT) { - ptx::addc(0, 0); - return 0; - } - return SUBTRACT ? ptx::subc(0, 0) : ptx::addc(0, 0); - } - - template - static constexpr DEVICE_INLINE uint32_t - add_sub_limbs_device(const storage& xs, const storage& ys, storage& rs) - { - const uint32_t* x = xs.limbs; - const uint32_t* y = ys.limbs; - uint32_t* r = rs.limbs; - return add_sub_u32_device(x, y, r); - } - - template - static DEVICE_INLINE void mul_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) - { - UNROLL - for (size_t i = 0; i < n; i += 2) { - acc[i] = ptx::mul_lo(a[i], bi); - acc[i + 1] = ptx::mul_hi(a[i], bi); - } - } - - template - static DEVICE_INLINE void mul_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS, size_t start_i = 0) - { - UNROLL - for (size_t i = start_i; i < n; i += 2) { - acc[i] = ptx::mul_lo(a[i], bi); - acc[i + 1] = ptx::mul_hi(a[i], bi); - } - } - - template - static DEVICE_INLINE void - cmad_n(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS, uint32_t optional_carry = 0) - { - if (CARRY_IN) ptx::add_cc(UINT32_MAX, optional_carry); - acc[0] = CARRY_IN ? ptx::madc_lo_cc(a[0], bi, acc[0]) : ptx::mad_lo_cc(a[0], bi, acc[0]); - acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - - // printf("even %d, mod %d, mi %d\n", acc[0], a[0],bi); - // printf("even %d, mod %d, mi %d\n", acc[0], a[0],bi); - UNROLL - for (size_t i = 2; i < n; i += 2) { - acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); - acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); - } - } - - template - static DEVICE_INLINE void cmad_n_msb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) - { - if (EVEN_PHASE) { - acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - } else { - acc[1] = ptx::mad_hi_cc(a[0], bi, acc[1]); - } - - UNROLL - for (size_t i = 2; i < n; i += 2) { - acc[i] = ptx::madc_lo_cc(a[i], bi, acc[i]); - acc[i + 1] = ptx::madc_hi_cc(a[i], bi, acc[i + 1]); - } - } - - template - static DEVICE_INLINE void cmad_n_lsb(uint32_t* acc, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) - { - if (n > 1) - acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - else - acc[0] = ptx::mad_lo(a[0], bi, acc[0]); - - size_t i; - UNROLL - for (i = 1; i < n - 1; i += 2) { - acc[i] = ptx::madc_hi_cc(a[i - 1], bi, acc[i]); - if (i == n - 2) - acc[i + 1] = ptx::madc_lo(a[i + 1], bi, acc[i + 1]); - else - acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, acc[i + 1]); - } - if (i == n - 1) acc[i] = ptx::madc_hi(a[i - 1], bi, acc[i]); - } - - template - static DEVICE_INLINE uint32_t mad_row( - uint32_t* odd, - uint32_t* even, - const uint32_t* a, - uint32_t bi, - size_t n = NLIMBS, - uint32_t ci = 0, - uint32_t di = 0, - uint32_t carry_for_high = 0, - uint32_t carry_for_low = 0) - { - cmad_n(odd, a + 1, bi, n - 2, carry_for_low); - odd[n - 2] = ptx::madc_lo_cc(a[n - 1], bi, ci); - odd[n - 1] = CARRY_OUT ? ptx::madc_hi_cc(a[n - 1], bi, di) : ptx::madc_hi(a[n - 1], bi, di); - uint32_t cr = CARRY_OUT ? ptx::addc(0, 0) : 0; - cmad_n(even, a, bi, n); - if (CARRY_OUT) { - odd[n - 1] = ptx::addc_cc(odd[n - 1], carry_for_high); - cr = ptx::addc(cr, 0); - } else - odd[n - 1] = ptx::addc(odd[n - 1], carry_for_high); - return cr; - } - - #ifdef BARRET - - template - static DEVICE_INLINE void mad_row_msb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) - { - cmad_n_msb(odd, EVEN_PHASE ? a : (a + 1), bi, n - 2); - odd[EVEN_PHASE ? (n - 1) : (n - 2)] = ptx::madc_lo_cc(a[n - 1], bi, 0); - odd[EVEN_PHASE ? n : (n - 1)] = ptx::madc_hi(a[n - 1], bi, 0); - cmad_n_msb(even, EVEN_PHASE ? (a + 1) : a, bi, n - 1); - odd[EVEN_PHASE ? n : (n - 1)] = ptx::addc(odd[EVEN_PHASE ? n : (n - 1)], 0); - } - - template - static DEVICE_INLINE void mad_row_lsb(uint32_t* odd, uint32_t* even, const uint32_t* a, uint32_t bi, size_t n = NLIMBS) - { - // bi here is constant so we can do a compile-time check for zero (which does happen once for bls12-381 scalar field - // modulus) - if (bi != 0) { - if (n > 1) cmad_n_lsb(odd, a + 1, bi, n - 1); - cmad_n_lsb(even, a, bi, n); - } - return; - } - - #endif - - template - static DEVICE_INLINE uint32_t - mul_n_and_add(uint32_t* acc, const uint32_t* a, uint32_t bi, uint32_t* extra, size_t n = (NLIMBS >> 1)) - { - acc[0] = ptx::mad_lo_cc(a[0], bi, extra[0]); - - UNROLL - for (size_t i = 1; i < n - 1; i += 2) { - acc[i] = ptx::madc_hi_cc(a[i - 1], bi, extra[i]); - acc[i + 1] = ptx::madc_lo_cc(a[i + 1], bi, extra[i + 1]); - } - - acc[n - 1] = ptx::madc_hi_cc(a[n - 2], bi, extra[n - 1]); - return ptx::addc(0, 0); - } - - /** - * This method multiplies `a` and `b` (both assumed to have NLIMBS / 2 limbs) and adds `in1` and `in2` (NLIMBS limbs each) - * to the result which is written to `even`. - * - * It is used to compute the "middle" part of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} = - * (a_{hi} - a_{lo})(b_{lo} - b_{hi}) + a_{lo} \cdot b_{lo} + a_{hi} \cdot b_{hi} \f$. Currently this method assumes - * that the top bit of \f$ a_{hi} \f$ and \f$ b_{hi} \f$ are unset. This ensures correctness by allowing to keep the - * result inside NLIMBS limbs and ignore the carries from the highest limb. - */ - template - static DEVICE_INLINE void - multiply_and_add_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even, uint32_t* in1, uint32_t* in2) - { - __align__(16) uint32_t odd[NLIMBS - 2]; - uint32_t first_row_carry = mul_n_and_add(even, a, b[0], in1); - uint32_t carry = mul_n_and_add(odd, a + 1, b[0], &in2[1]); - - size_t i; - UNROLL - for (i = 2; i < ((NLIMBS >> 1) - 1); i += 2) { - carry = mad_row( - &even[i], &odd[i - 2], a, b[i - 1], NLIMBS >> 1, in1[(NLIMBS >> 1) + i - 2], in1[(NLIMBS >> 1) + i - 1], carry); - carry = - mad_row(&odd[i], &even[i], a, b[i], NLIMBS >> 1, in2[(NLIMBS >> 1) + i - 1], in2[(NLIMBS >> 1) + i], carry); - } - mad_row( - &even[NLIMBS >> 1], &odd[(NLIMBS >> 1) - 2], a, b[(NLIMBS >> 1) - 1], NLIMBS >> 1, in1[NLIMBS - 2], in1[NLIMBS - 1], carry, - first_row_carry); - // merge |even| and |odd| plus the parts of `in2` we haven't added yet (first and last limbs) - even[0] = ptx::add_cc(even[0], in2[0]); - for (i = 0; i < (NLIMBS - 2); i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], in2[i + 1]); - } - - - - /** - * This method multiplies `a` and `b` and writes the result into `even`. It assumes that `a` and `b` are NLIMBS/2 limbs - * long. The usual schoolbook algorithm is used. - */ - template - static DEVICE_INLINE void multiply_short_raw_device(const uint32_t* a, const uint32_t* b, uint32_t* even) - { - __align__(16) uint32_t odd[NLIMBS - 2]; - mul_n(even, a, b[0], NLIMBS >> 1); - mul_n(odd, a + 1, b[0], NLIMBS >> 1); - mad_row(&even[2], &odd[0], a, b[1], NLIMBS >> 1); - - size_t i; - UNROLL - for (i = 2; i < ((NLIMBS >> 1) - 1); i += 2) { - mad_row(&odd[i], &even[i], a, b[i], NLIMBS >> 1); - mad_row(&even[i + 2], &odd[i], a, b[i + 1], NLIMBS >> 1); - } - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < NLIMBS - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); - } - - /** - * This method multiplies `as` and `bs` and writes the (wide) result into `rs`. - * - * It is assumed that the highest bits of `as` and `bs` are unset which is true for all the numbers icicle had to deal - * with so far. This method implements [subtractive - * Karatsuba](https://en.wikipedia.org/wiki/Karatsuba_algorithm#Implementation). - */ - template - static DEVICE_INLINE void multiply_raw_device(const storage& as, const storage& bs, storage<2*NLIMBS>& rs) - { - const uint32_t* a = as.limbs; - const uint32_t* b = bs.limbs; - uint32_t* r = rs.limbs; - if constexpr (NLIMBS > 2) { - // Next two lines multiply high and low halves of operands (\f$ a_{lo} \cdot b_{lo}; a_{hi} \cdot b_{hi} \$f) and - // write the results into `r`. - multiply_short_raw_device(a, b, r); - multiply_short_raw_device(&a[NLIMBS >> 1], &b[NLIMBS >> 1], &r[NLIMBS]); - __align__(16) uint32_t middle_part[NLIMBS]; - __align__(16) uint32_t diffs[NLIMBS]; - // Differences of halves \f$ a_{hi} - a_{lo}; b_{lo} - b_{hi} \$f are written into `diffs`, signs written to - // `carry1` and `carry2`. - uint32_t carry1 = add_sub_u32_device<(NLIMBS >> 1), true, true>(&a[NLIMBS >> 1], a, diffs); - uint32_t carry2 = add_sub_u32_device<(NLIMBS >> 1), true, true>(b, &b[NLIMBS >> 1], &diffs[NLIMBS >> 1]); - // Compute the "middle part" of Karatsuba: \f$ a_{lo} \cdot b_{hi} + b_{lo} \cdot a_{hi} \f$. - // This is where the assumption about unset high bit of `a` and `b` is relevant. - multiply_and_add_short_raw_device(diffs, &diffs[NLIMBS >> 1], middle_part, r, &r[NLIMBS]); - // Corrections that need to be performed when differences are negative. - // Again, carry doesn't need to be propagated due to unset high bits of `a` and `b`. - if (carry1) - add_sub_u32_device<(NLIMBS >> 1), true, false>(&middle_part[NLIMBS >> 1], &diffs[NLIMBS >> 1], &middle_part[NLIMBS >> 1]); - if (carry2) add_sub_u32_device<(NLIMBS >> 1), true, false>(&middle_part[NLIMBS >> 1], diffs, &middle_part[NLIMBS >> 1]); - // Now that middle part is fully correct, it can be added to the result. - add_sub_u32_device(&r[NLIMBS >> 1], middle_part, &r[NLIMBS >> 1]); - - // Carry from adding middle part has to be propagated to the highest limb. - for (size_t i = NLIMBS + (NLIMBS >> 1); i < 2 * NLIMBS; i++) - r[i] = ptx::addc_cc(r[i], 0); - } else if (NLIMBS == 2) { - __align__(8) uint32_t odd[2]; - r[0] = ptx::mul_lo(a[0], b[0]); - r[1] = ptx::mul_hi(a[0], b[0]); - r[2] = ptx::mul_lo(a[1], b[1]); - r[3] = ptx::mul_hi(a[1], b[1]); - odd[0] = ptx::mul_lo(a[0], b[1]); - odd[1] = ptx::mul_hi(a[0], b[1]); - odd[0] = ptx::mad_lo(a[1], b[0], odd[0]); - odd[1] = ptx::mad_hi(a[1], b[0], odd[1]); - r[1] = ptx::add_cc(r[1], odd[0]); - r[2] = ptx::addc_cc(r[2], odd[1]); - r[3] = ptx::addc(r[3], 0); - } else if (NLIMBS == 1) { - r[0] = ptx::mul_lo(a[0], b[0]); - r[1] = ptx::mul_hi(a[0], b[0]); - } - } - - #ifdef BARRET - - /** - * A function that computes wide product \f$ rs = as \cdot bs \f$ that's correct for the higher NLIMBS + 1 limbs with a - * small maximum error. - * - * The way this function saves computations (as compared to regular school-book multiplication) is by not including - * terms that are too small. Namely, limb product \f$ a_i \cdot b_j \f$ is excluded if \f$ i + j < NLIMBS - 2 \f$ and - * only the higher half is included if \f$ i + j = NLIMBS - 2 \f$. All other limb products are included. So, the error - * i.e. difference between true product and the result of this function written to `rs` is exactly the sum of all - * dropped limbs products, which we can bound: \f$ a_0 \cdot b_0 + 2^{32}(a_0 \cdot b_1 + a_1 \cdot b_0) + \dots + - * 2^{32(NLIMBS - 3)}(a_{NLIMBS - 3} \cdot b_0 + \dots + a_0 \cdot b_{NLIMBS - 3}) + 2^{32(NLIMBS - 2)}(\floor{\frac{a_{NLIMBS - 2} - * \cdot b_0}{2^{32}}} + \dots + \floor{\frac{a_0 \cdot b_{NLIMBS - 2}}{2^{32}}}) \leq 2^{64} + 2\cdot 2^{96} + \dots + - * (NLIMBS - 2) \cdot 2^{32(NLIMBS - 1)} + (NLIMBS - 1) \cdot 2^{32(NLIMBS - 1)} \leq 2(NLIMBS - 1) \cdot 2^{32(NLIMBS - 1)}\f$. - */ - template - static DEVICE_INLINE void multiply_msb_raw_device(const storage& as, const storage& bs, storage<2*NLIMBS>& rs) - { - if constexpr (NLIMBS > 1) { - const uint32_t* a = as.limbs; - const uint32_t* b = bs.limbs; - uint32_t* even = rs.limbs; - __align__(16) uint32_t odd[2 * NLIMBS - 2]; - - even[NLIMBS - 1] = ptx::mul_hi(a[NLIMBS - 2], b[0]); - odd[NLIMBS - 2] = ptx::mul_lo(a[NLIMBS - 1], b[0]); - odd[NLIMBS - 1] = ptx::mul_hi(a[NLIMBS - 1], b[0]); - size_t i; - UNROLL - for (i = 2; i < NLIMBS - 1; i += 2) { - mad_row_msb(&even[NLIMBS - 2], &odd[NLIMBS - 2], &a[NLIMBS - i - 1], b[i - 1], i + 1); - mad_row_msb(&odd[NLIMBS - 2], &even[NLIMBS - 2], &a[NLIMBS - i - 2], b[i], i + 2); - } - mad_row(&even[NLIMBS], &odd[NLIMBS - 2], a, b[NLIMBS - 1]); - - // merge |even| and |odd| - ptx::add_cc(even[NLIMBS - 1], odd[NLIMBS - 2]); - for (i = NLIMBS - 1; i < 2 * NLIMBS - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); - } else { - multiply_raw_device(as, bs, rs); - } - } - - /** - * A function that computes the low half of the fused multiply-and-add \f$ rs = as \cdot bs + cs \f$ where - * \f$ bs = 2^{32*nof_limbs} \f$. - * - * For efficiency, this method does not include terms that are too large. Namely, limb product \f$ a_i \cdot b_j \f$ - * is excluded if \f$ i + j > NLIMBS - 1 \f$ and only the lower half is included if \f$ i + j = NLIMBS - 1 \f$. All other - * limb products are included. - */ - template - static DEVICE_INLINE void - multiply_and_add_lsb_neg_modulus_raw_device(const storage& as, const storage& bs, storage& cs, storage& rs) - { - const uint32_t* a = as.limbs; - const uint32_t* b = bs.limbs; - uint32_t* c = cs.limbs; - uint32_t* even = rs.limbs; - - if constexpr (NLIMBS > 2) { - __align__(16) uint32_t odd[NLIMBS - 1]; - size_t i; - // `b[0]` is \f$ 2^{32} \f$ minus the last limb of prime modulus. Because most scalar (and some base) primes - // are necessarily NTT-friendly, `b[0]` often turns out to be \f$ 2^{32} - 1 \f$. This actually leads to - // less efficient SASS generated by nvcc, so this case needed separate handling. - if (b[0] == UINT32_MAX) { - add_sub_u32_device(c, a, even); - for (i = 0; i < NLIMBS - 1; i++) - odd[i] = a[i]; - } else { - mul_n_and_add(even, a, b[0], c, NLIMBS); - mul_n(odd, a + 1, b[0], NLIMBS - 1); - } - mad_row_lsb(&even[2], &odd[0], a, b[1], NLIMBS - 1); - UNROLL - for (i = 2; i < NLIMBS - 1; i += 2) { - mad_row_lsb(&odd[i], &even[i], a, b[i], NLIMBS - i); - mad_row_lsb(&even[i + 2], &odd[i], a, b[i + 1], NLIMBS - i - 1); - } - - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < NLIMBS - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], odd[i]); - } else if (NLIMBS == 2) { - even[0] = ptx::mad_lo(a[0], b[0], c[0]); - even[1] = ptx::mad_hi(a[0], b[0], c[0]); - even[1] = ptx::mad_lo(a[0], b[1], even[1]); - even[1] = ptx::mad_lo(a[1], b[0], even[1]); - } else if (NLIMBS == 1) { - even[0] = ptx::mad_lo(a[0], b[0], c[0]); - } - } - -#endif - - // The following algorithms are adaptations of - // http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, - // taken from https://github.com/z-prize/test-msm-gpu (under Apache 2.0 license) - // and modified to use our datatypes. - // We had our own implementation of http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf, - // but the sppark versions achieved lower instruction count thanks to clever carry handling, - // so we decided to just use theirs. - template - static DEVICE_INLINE void madc_n_rshift(uint32_t *odd, const uint32_t *a, uint32_t bi) { -#pragma unroll - for (size_t i = 0; i < NLIMBS - 2; i += 2) { - odd[i] = ptx::madc_lo_cc(a[i], bi, odd[i + 2]); - odd[i + 1] = ptx::madc_hi_cc(a[i], bi, odd[i + 3]); - } - odd[NLIMBS - 2] = ptx::madc_lo_cc(a[NLIMBS - 2], bi, 0); - odd[NLIMBS - 1] = ptx::madc_hi(a[NLIMBS - 2], bi, 0); - } - - template - static DEVICE_INLINE void mad_n_redc(uint32_t *even, uint32_t *odd, const uint32_t *a, uint32_t bi, const uint32_t *modulus, const uint32_t *mont_inv_modulus, bool first = false) { - if (first) { - mul_n(odd, a + 1, bi); - mul_n(even, a, bi); - } else { - even[0] = ptx::add_cc(even[0], odd[1]); - madc_n_rshift(odd, a + 1, bi); - cmad_n(even, a, bi); - odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); - } - uint32_t mi = even[0] * mont_inv_modulus[0]; - cmad_n(odd, modulus + 1, mi); - cmad_n(even, modulus, mi); - odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); - } - - template - static DEVICE_INLINE void qad_row(uint32_t *odd, uint32_t *even, const uint32_t *a, uint32_t bi, size_t n = NLIMBS) { - cmad_n(odd, a, bi, n - 2); - odd[n - 2] = ptx::madc_lo_cc(a[n - 2], bi, 0); - odd[n - 1] = ptx::madc_hi(a[n - 2], bi, 0); - cmad_n(even, a + 1, bi, n - 2); - odd[n - 1] = ptx::addc(odd[n - 1], 0); - } - - //TODO: test if beeter than karatsuba - template - static DEVICE_INLINE void multiply_raw_sb(const storage &as, const storage &bs, storage<2*NLIMBS> &rs) { - const uint32_t *a = as.limbs; - const uint32_t *b = bs.limbs; - if constexpr (NLIMBS > 2){ - uint32_t *even = rs.limbs; - __align__(8) uint32_t odd[2 * NLIMBS - 2]; - mul_n(even, a, b[0]); - mul_n(odd, a + 1, b[0]); - mad_row(&even[2], &odd[0], a, b[1]); - size_t i; - #pragma unroll - for (i = 2; i < NLIMBS - 1; i += 2) { - mad_row(&odd[i], &even[i], a, b[i]); - mad_row(&even[i + 2], &odd[i], a, b[i + 1]); - } - // merge |even| and |odd| - even[1] = ptx::add_cc(even[1], odd[0]); - for (i = 1; i < 2 * NLIMBS - 2; i++) - even[i + 1] = ptx::addc_cc(even[i + 1], odd[i]); - even[i + 1] = ptx::addc(even[i + 1], 0); - } - else if (NLIMBS == 2) { - uint32_t *r = rs.limbs; - __align__(8) uint32_t odd[2]; - r[0] = ptx::mul_lo(a[0], b[0]); - r[1] = ptx::mul_hi(a[0], b[0]); - r[2] = ptx::mul_lo(a[1], b[1]); - r[3] = ptx::mul_hi(a[1], b[1]); - odd[0] = ptx::mul_lo(a[0], b[1]); - odd[1] = ptx::mul_hi(a[0], b[1]); - odd[0] = ptx::mad_lo(a[1], b[0], odd[0]); - odd[1] = ptx::mad_hi(a[1], b[0], odd[1]); - r[1] = ptx::add_cc(r[1], odd[0]); - r[2] = ptx::addc_cc(r[2], odd[1]); - r[3] = ptx::addc(r[3], 0); - } else if (NLIMBS == 1) { - uint32_t *r = rs.limbs; - r[0] = ptx::mul_lo(a[0], b[0]); - r[1] = ptx::mul_hi(a[0], b[0]); - } - } - - template - static DEVICE_INLINE void sqr_raw(const storage &as, storage<2*NLIMBS> &rs) { - const uint32_t *a = as.limbs; - uint32_t *even = rs.limbs; - size_t i = 0, j; - __align__(8) uint32_t odd[2 * NLIMBS - 2]; - - // perform |a[i]|*|a[j]| for all j>i - mul_n(even + 2, a + 2, a[0], NLIMBS - 2); - mul_n(odd, a + 1, a[0], NLIMBS); - -#pragma unroll - while (i < NLIMBS - 4) { - ++i; - mad_row(&even[2 * i + 2], &odd[2 * i], &a[i + 1], a[i], NLIMBS - i - 1); - ++i; - qad_row(&odd[2 * i], &even[2 * i + 2], &a[i + 1], a[i], NLIMBS - i); - } - - even[2 * NLIMBS - 4] = ptx::mul_lo(a[NLIMBS - 1], a[NLIMBS - 3]); - even[2 * NLIMBS - 3] = ptx::mul_hi(a[NLIMBS - 1], a[NLIMBS - 3]); - odd[2 * NLIMBS - 6] = ptx::mad_lo_cc(a[NLIMBS - 2], a[NLIMBS - 3], odd[2 * NLIMBS - 6]); - odd[2 * NLIMBS - 5] = ptx::madc_hi_cc(a[NLIMBS - 2], a[NLIMBS - 3], odd[2 * NLIMBS - 5]); - even[2 * NLIMBS - 3] = ptx::addc(even[2 * NLIMBS - 3], 0); - - odd[2 * NLIMBS - 4] = ptx::mul_lo(a[NLIMBS - 1], a[NLIMBS - 2]); - odd[2 * NLIMBS - 3] = ptx::mul_hi(a[NLIMBS - 1], a[NLIMBS - 2]); - - // merge |even[2:]| and |odd[1:]| - even[2] = ptx::add_cc(even[2], odd[1]); - for (j = 2; j < 2 * NLIMBS - 3; j++) - even[j + 1] = ptx::addc_cc(even[j + 1], odd[j]); - even[j + 1] = ptx::addc(odd[j], 0); - - // double |even| - even[0] = 0; - even[1] = ptx::add_cc(odd[0], odd[0]); - for (j = 2; j < 2 * NLIMBS - 1; j++) - even[j] = ptx::addc_cc(even[j], even[j]); - even[j] = ptx::addc(0, 0); - - // accumulate "diagonal" |a[i]|*|a[i]| product - i = 0; - even[2 * i] = ptx::mad_lo_cc(a[i], a[i], even[2 * i]); - even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); - for (++i; i < NLIMBS; i++) { - even[2 * i] = ptx::madc_lo_cc(a[i], a[i], even[2 * i]); - even[2 * i + 1] = ptx::madc_hi_cc(a[i], a[i], even[2 * i + 1]); - } - } - - template - static DEVICE_INLINE void mul_by_1_row(uint32_t *even, uint32_t *odd, const uint32_t *modulus, const uint32_t *mont_inv_modulus, bool first = false) { - uint32_t mi; - if (first) { - mi = even[0] * mont_inv_modulus[0]; - mul_n(odd, modulus + 1, mi); - cmad_n(even, modulus, mi); - odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); - } else { - even[0] = ptx::add_cc(even[0], odd[1]); - // we trust the compiler to *not* touch the carry flag here - // this code sits in between two "asm volatile" instructions witch should guarantee that nothing else interferes wit the carry flag - mi = even[0] * mont_inv_modulus[0]; - madc_n_rshift(odd, modulus + 1, mi); - cmad_n(even, modulus, mi); - odd[NLIMBS - 1] = ptx::addc(odd[NLIMBS - 1], 0); - } - } - - // Performs Montgomery reduction on a storage<2*NLIMBS> input. Input value must be in the range [0, mod*2^(32*NLIMBS)). - // Does not implement an in-place reduce epilogue. If you want to further reduce the result, - // call reduce(xs.get_lo()) after the call to redc_wide_inplace. - template - static DEVICE_INLINE void reduce_mont_inplace(storage<2*NLIMBS> &xs, const storage &modulus, const storage &mont_inv_modulus) { - uint32_t *even = xs.limbs; - // Yields montmul of lo NLIMBS limbs * 1. - // Since the hi NLIMBS limbs don't participate in computing the "mi" factor at each mul-and-rightshift stage, - // it's ok to ignore the hi NLIMBS limbs during this process and just add them in afterward. - uint32_t odd[NLIMBS]; - if constexpr (NLIMBS == 1) { - uint32_t mi = even[0] * mont_inv_modulus.limbs[0]; //m - // printf("m %u\n", mi); - // printf("p %u\n", modulus.limbs[0]); - // printf("t0 %u\n", odd[0]); - // cmad_n(odd, modulus + 1, mi); - // acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - // cmad_n(even, modulus.limbs, mi); - even[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, even[0]); - odd[0] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); - - // odd[0] = ptx::mad_lo(modulus.limbs[0], mi, odd[0]); //C - // even[0] = ptx::mul_hi(modulus.limbs[0], mi, odd[0]); //C - // even[0] = ptx::mad_hi(modulus.limbs[0], mi, 0); //C - // even[0] = ptx::mad_hi(modulus.limbs[0], 0, odd[0]); //C - // printf("C %u\n", odd[1]); - // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - // odd[0] = ptx::addc(odd[0], 0); - - // odd[0] = ptx::add_cc(odd[0], even[1]); - // madc_n_rshift(even, a + 1, bi); - // cmad_n(odd, a, bi); - // even[NLIMBS - 1] = ptx::addc(even[NLIMBS - 1], 0); - // mi = odd[0] * mont_inv_modulus.limbs[0]; - // cmad_n(even, modulus + 1, mi); - // cmad_n(odd, modulus.limbs, mi); - // odd[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, 0); - // odd[1] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); - // even[0] = ptx::addc(even[0], 0); - even[0] = ptx::add_cc(even[1], odd[0]); - } - else { - size_t i; - #pragma unroll - for (i = 0; i < NLIMBS; i += 2) { - mul_by_1_row(&even[0], &odd[0], modulus.limbs, mont_inv_modulus.limbs, i == 0); - mul_by_1_row(&odd[0], &even[0], modulus.limbs, mont_inv_modulus.limbs); - } - even[0] = ptx::add_cc(even[0], odd[1]); - #pragma unroll - for (i = 1; i < NLIMBS - 1; i++) - even[i] = ptx::addc_cc(even[i], odd[i + 1]); - even[i] = ptx::addc(even[i], 0); - // Adds in (hi NLIMBS limbs), implicitly right-shifting them by NLIMBS limbs as if they had participated in the - // add-and-rightshift stages above. - xs.limbs[0] = ptx::add_cc(xs.limbs[0], xs.limbs[NLIMBS]); - #pragma unroll - for (i = 1; i < NLIMBS - 1; i++) - xs.limbs[i] = ptx::addc_cc(xs.limbs[i], xs.limbs[i + NLIMBS]); - xs.limbs[NLIMBS - 1] = ptx::addc(xs.limbs[NLIMBS - 1], xs.limbs[2 * NLIMBS - 1]); - } - } - - template - static DEVICE_INLINE void montmul_raw(const storage &a_in, const storage &b_in, const storage &modulus, const storage &mont_inv_modulus, storage &r_in) { - const uint32_t *a = a_in.limbs; - const uint32_t *b = b_in.limbs; - uint32_t *even = r_in.limbs; - __align__(8) uint32_t odd[NLIMBS + 1]; - size_t i; - if constexpr (NLIMBS == 1) { - // mad_n_redc(&even[0], &odd[0], a, b[0], modulus.limbs, mont_inv_modulus.limbs, true); - // printf("even0 b %d\n", even[0]); - // mul_n(even, a, b[0]); - // printf("even0 a %d\n", even[0]); - // uint32_t mi = even[0] * mont_inv_modulus.limbs[0]; - // printf("mi %d\n", mi); - // cmad_n(even, modulus.limbs, mi); - // printf("even0 c %d\n", even[0]); - - // mul_n(odd, a + 1, bi); - // odd[0] = ptx::mul_lo(a[1], b[0]); - // odd[1] = ptx::mul_hi(a[1], b[0]); - // mul_n(even, a, b[0]); - // printf("a %u\n", a[0]); - // printf("b %u\n", b[0]); - odd[0] = ptx::mul_lo(a[0], b[0]); //t[0] - // printf("t0 %u\n", odd[0]); - even[0] = ptx::mul_hi(a[0], b[0]); //A - // printf("A %u\n", even[0]); - uint32_t mi = odd[0] * mont_inv_modulus.limbs[0]; //m - // printf("m %u\n", mi); - // printf("p %u\n", modulus.limbs[0]); - // printf("t0 %u\n", odd[0]); - // cmad_n(odd, modulus + 1, mi); - // acc[0] = ptx::mad_lo_cc(a[0], bi, acc[0]); - // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - // cmad_n(even, modulus.limbs, mi); - odd[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, odd[0]); - odd[1] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); - - // odd[0] = ptx::mad_lo(modulus.limbs[0], mi, odd[0]); //C - // even[0] = ptx::mul_hi(modulus.limbs[0], mi, odd[0]); //C - // even[0] = ptx::mad_hi(modulus.limbs[0], mi, 0); //C - // even[0] = ptx::mad_hi(modulus.limbs[0], 0, odd[0]); //C - // printf("C %u\n", odd[1]); - // acc[1] = ptx::madc_hi_cc(a[0], bi, acc[1]); - // odd[0] = ptx::addc(odd[0], 0); - - // odd[0] = ptx::add_cc(odd[0], even[1]); - // madc_n_rshift(even, a + 1, bi); - // cmad_n(odd, a, bi); - // even[NLIMBS - 1] = ptx::addc(even[NLIMBS - 1], 0); - // mi = odd[0] * mont_inv_modulus.limbs[0]; - // cmad_n(even, modulus + 1, mi); - // cmad_n(odd, modulus.limbs, mi); - // odd[0] = ptx::mad_lo_cc(modulus.limbs[0], mi, 0); - // odd[1] = ptx::madc_hi_cc(modulus.limbs[0], mi, 0); - // even[0] = ptx::addc(even[0], 0); - even[0] = ptx::add_cc(even[0], odd[1]); - // printf("A+C %u\n", even[0]); - } - else { - #pragma unroll - for (i = 0; i < NLIMBS; i += 2) { - mad_n_redc(&even[0], &odd[0], a, b[i], modulus.limbs, mont_inv_modulus.limbs, i == 0); - mad_n_redc(&odd[0], &even[0], a, b[i + 1], modulus.limbs, mont_inv_modulus.limbs); - } - // merge |even| and |odd| - even[0] = ptx::add_cc(even[0], odd[1]); - #pragma unroll - for (i = 1; i < NLIMBS - 1; i++) - even[i] = ptx::addc_cc(even[i], odd[i + 1]); - even[i] = ptx::addc(even[i], 0); - } - // final reduction from [0, 2*mod) to [0, mod) not done here, instead performed optionally in mul_device wrapper - } - - - // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. - template static constexpr DEVICE_INLINE storage mulmont_device(const storage &xs, const storage &ys, const storage &modulus, const storage &mont_inv_modulus) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack - // static_assert(!(CONFIG::modulus.limbs[NLIMBS - 1] >> 30)); - storage rs = {0}; - montmul_raw(xs, ys, modulus, mont_inv_modulus, rs); - // printf("rs %d\n",rs.limbs[0]); - return rs; - } - - template static constexpr DEVICE_INLINE storage sqrmont_device(const storage &xs) { - // Forces us to think more carefully about the last carry bit if we use a modulus with fewer than 2 leading zeroes of slack - // static_assert(!(CONFIG::modulus.limbs[NLIMBS - 1] >> 30)); - storage<2*NLIMBS> rs = {0}; - sqr_raw(xs, rs); - reduce_mont_inplace(rs); // after reduce_twopass, tmp's low NLIMBS limbs should represent a value in [0, 2*mod) - return rs.get_lo(); - } -// //add -// // return xs * ys with field operands -// // Device path adapts http://www.acsel-lab.com/arithmetic/arith23/data/1616a047.pdf to use IMAD.WIDE. -// // Host path uses CIOS. -// template static constexpr DEVICE_INLINE storage mulz(const storage &xs, const storage &ys) { -// return mul_devicez(xs, ys); -// } -} - -#endif \ No newline at end of file diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 05632fdbe..89749af2a 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -20,12 +20,11 @@ #ifdef __CUDACC__ #include "gpu-utils/sharedmem.h" - #include "ptx.h" + #include "device_math.h" #endif // __CUDACC__ +#include "host_math.h" #include "icicle/errors.h" -#include "host_math.h" -#include "device_math.h" #include "storage.h" #include @@ -36,6 +35,12 @@ using namespace icicle; +#ifdef __CUDA_ARCH__ +namespace base_math = device_math; +#else +namespace base_math = host_math; +#endif + template class Field { @@ -191,18 +196,7 @@ class Field static constexpr Field HOST_DEVICE_INLINE get_higher_with_slack(const Wide& xs) { Field out{}; -#ifdef __CUDA_ARCH__ - UNROLL -#endif - for (unsigned i = 0; i < TLC; i++) { -#ifdef __CUDA_ARCH__ - out.limbs_storage.limbs[i] = - __funnelshift_lc(xs.limbs_storage.limbs[i + TLC - 1], xs.limbs_storage.limbs[i + TLC], 2 * slack_bits); -#else - out.limbs_storage.limbs[i] = (xs.limbs_storage.limbs[i + TLC] << 2 * slack_bits) + - (xs.limbs_storage.limbs[i + TLC - 1] >> (32 - 2 * slack_bits)); -#endif - } + base_math::get_higher_with_slack(xs.limbs_storage, out.limbs_storage, slack_bits); return out; } @@ -281,58 +275,21 @@ class Field static constexpr HOST_DEVICE_INLINE uint32_t add_limbs(const storage& xs, const storage& ys, storage& rs) { -#ifdef __CUDA_ARCH__ - return device_math::template add_sub_limbs_device(xs, ys, rs); -#else - return host_math::template add_sub_limbs(xs, ys, rs); -#endif + return base_math::template add_sub_limbs(xs, ys, rs); } template static constexpr HOST_DEVICE_INLINE uint32_t sub_limbs(const storage& xs, const storage& ys, storage& rs) { -#ifdef __CUDA_ARCH__ - return device_math::template add_sub_limbs_device(xs, ys, rs); -#else - return host_math::template add_sub_limbs(xs, ys, rs); -#endif + return base_math::template add_sub_limbs(xs, ys, rs); } static HOST_DEVICE_INLINE void multiply_raw(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) { -#ifdef __CUDA_ARCH__ - return device_math::template multiply_raw_device(as, bs, rs); -#else - return host_math::template multiply_raw(as, bs, rs); -#endif + return base_math::template multiply_raw(as, bs, rs); } - #ifdef BARRET - - static HOST_DEVICE_INLINE void - multiply_and_add_lsb_neg_modulus_raw(const ff_storage& as, ff_storage& cs, ff_storage& rs) - { -#ifdef __CUDA_ARCH__ - return device_math::template multiply_and_add_lsb_neg_modulus_raw_device(as, get_neg_modulus(), cs, rs); -#else - Wide r_wide = {}; - host_math::template multiply_raw(as, get_neg_modulus(), r_wide.limbs_storage); - Field r = Wide::get_lower(r_wide); - add_limbs(cs, r.limbs_storage, rs); -#endif - } - - static HOST_DEVICE_INLINE void multiply_msb_raw(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) - { -#ifdef __CUDA_ARCH__ - return device_math::template multiply_msb_raw_device(as, bs, rs); -#else - return host_math::template multiply_raw(as, bs, rs); -#endif - } - - #endif public: ff_storage limbs_storage; @@ -421,7 +378,29 @@ class Field return rs; } - #ifdef BARRET +#ifdef BARRET + + static HOST_DEVICE_INLINE void + multiply_and_add_lsb_neg_modulus_raw(const ff_storage& as, ff_storage& cs, ff_storage& rs) + { + #ifdef __CUDA_ARCH__ + return base_math::template multiply_and_add_lsb_neg_modulus_raw(as, get_neg_modulus(), cs, rs); + #else + Wide r_wide = {}; + base_math::template multiply_raw(as, get_neg_modulus(), r_wide.limbs_storage); + Field r = Wide::get_lower(r_wide); + add_limbs(cs, r.limbs_storage, rs); + #endif + } + + static HOST_DEVICE_INLINE void multiply_msb_raw(const ff_storage& as, const ff_storage& bs, ff_wide_storage& rs) + { + #ifdef __CUDA_ARCH__ + return base_math::template multiply_msb_raw(as, bs, rs); + #else + return base_math::template multiply_raw(as, bs, rs); + #endif + } /** * This method reduces a Wide number `xs` modulo `p` and returns the result as a Field element. @@ -444,7 +423,7 @@ class Field * cases it's less than 1, so setting the [num_of_reductions](@ref num_of_reductions) variable for a field equal to 1 * will cause only 1 reduction to be performed. */ - static constexpr HOST_DEVICE_INLINE Field barret_reduce(const Wide& xs) //TODO = add reduce_mont_inplace + static constexpr HOST_DEVICE_INLINE Field barret_reduce(const Wide& xs) // TODO = add reduce_mont_inplace { // `xs` is left-shifted by `2 * slack_bits` and higher half is written to `xs_hi` Field xs_hi = Wide::get_higher_with_slack(xs); @@ -469,11 +448,17 @@ class Field return r; } - #endif + static constexpr HOST_INLINE Field barret_mult(const Field& xs, const Field& ys) + { + Wide xy = mul_wide(xs, ys); // full mult + return reduce(xy); // reduce mod p + } - static constexpr HOST_DEVICE_INLINE Field mont_sub_modulus(const Wide& xs, bool get_higher = false) +#endif + + static constexpr HOST_DEVICE_INLINE Field mont_sub_modulus(const Field& xs, bool get_higher = false) { - Field r = get_higher? Wide::get_higher(xs) : Wide::get_lower(xs); + Field r = xs; Field p = Field{get_modulus<1>()}; if (p.limbs_storage.limbs[TLC - 1] > r.limbs_storage.limbs[TLC - 1]) return r; ff_storage r_reduced = {}; @@ -483,371 +468,59 @@ class Field return r; } - static constexpr HOST_DEVICE_INLINE Field reduce(const Wide& xs) - { - #ifdef BARRET - return barret_reduce(xs); - #else - return mont_reduce(xs); - #endif - } - - HOST_DEVICE Field& operator=(Field const& other) - { - for (int i = 0; i < TLC; i++) { - this->limbs_storage.limbs[i] = other.limbs_storage.limbs[i]; - } - return *this; - } - - // #if defined(__CUDACC__) -#if 1 - friend HOST_DEVICE Field operator*(const Field& xs, const Field& ys) - { - #ifdef __CUDA_ARCH__ // cuda - #ifdef BARRET - Wide xy = mul_wide(xs, ys); // full mult - return barret_reduce(xy); // reduce mod p - // return Wide::get_lower(xy); // reduce mod p - #else - return sub_modulus<1>(Field{device_math::template mulmont_device(xs.limbs_storage,ys.limbs_storage,get_modulus<1>(),get_mont_inv_modulus())}); - #endif - #else - #ifdef BARRET - Wide xy = mul_wide(xs, ys); // full mult - return barret_reduce(xy); // reduce mod p - // return Wide::get_lower(xy); - #else - return mont_mult(xs,ys); - #endif - #endif - #endif - } - static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) { - Wide r = {}; - host_math::multiply_mont(xs.limbs_storage, ys.limbs_storage, get_mont_inv_modulus(), get_modulus<1>(), r.limbs_storage); + Field r = {}; + base_math::template multiply_mont(xs.limbs_storage, ys.limbs_storage, get_mont_inv_modulus(), get_modulus<1>(), r.limbs_storage); return mont_sub_modulus(r); } + + - /** - * @brief Perform SOS reduction on a number in montgomery representation \p t in range [0,p^2-1] (p is the field's - * modulus) limiting it to the range [0,2p-1]. - * @param t Number to be reduced. Must be in montgomery rep, and in range [0,p^2-1]. - * @return \p t mod p - */ static constexpr HOST_INLINE Field mont_reduce(const Wide& t) { - #ifdef __CUDA_ARCH__ +#ifdef __CUDA_ARCH__ Wide r = t; - device_math::template reduce_mont_inplace(r.limbs_storage, get_modulus<1>(), get_mont_inv_modulus()); - return mont_sub_modulus(r); - #else + base_math::template reduce_mont_inplace(r.limbs_storage, get_modulus<1>(), get_mont_inv_modulus()); + return mont_sub_modulus(Wide::get_lower(r)); +#else Wide r = {}; - host_math::template sos_mont_reduction( + base_math::template sos_mont_reduction( t.limbs_storage, get_modulus<1>(), get_mont_inv_modulus(), r.limbs_storage); - return mont_sub_modulus(r, true); - #endif - } - - // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) - // #pragma GCC optimize("no-strict-aliasing") - // #endif - - friend HOST_DEVICE_INLINE Field original_multiplier(const Field& xs, const Field& ys) - { - Wide xy = mul_wide(xs, ys); // full mult - return reduce(xy); // reduce mod p - } - -#if 0 - - // #include - - /* GNARK CODE START*/ - // those two funcs are copied from bits.go implementation (/usr/local/go/src/math/bits/bits.go) - static HOST_DEVICE_INLINE void Mul64(uint64_t x, uint64_t y, uint64_t& hi, uint64_t& lo) - { - // constexpr uint64_t mask32 = 4294967295ULL; // 2^32 - 1 - // uint64_t x0 = x & mask32; - // uint64_t x1 = x >> 32; - // uint64_t y0 = y & mask32; - // uint64_t y1 = y >> 32; - // uint64_t w0 = x0 * y0; - // uint64_t t = x1 * y0 + w0 >> 32; - // uint64_t w1 = t & mask32; - // uint64_t w2 = t >> 32; - // w1 += x0 * y1; - // hi = x1 * y1 + w2 + w1 >> 32; - // lo = x * y; - - // #if defined(__GNUC__) || defined(__clang__) - // lo = _umul128(x, y, &hi); - // #else - __uint128_t result = static_cast<__uint128_t>(x) * y; - hi = static_cast(result >> 64); - lo = static_cast(result); - // #endif + return mont_sub_modulus(Wide::get_higher(r)); +#endif } - // #if defined(__GNUC__) || defined(__clang__) - // #include - // #endif - - static HOST_DEVICE_INLINE void Add64(uint64_t x, uint64_t y, uint64_t carry, uint64_t& sum, uint64_t& carry_out) + static constexpr HOST_DEVICE_INLINE Field reduce(const Wide& xs) { - // #if defined(__GNUC__) || defined(__clang__) - // carry_out = _addcarry_u64(carry, x, y, &sum); - // #else - sum = x + y + carry; - carry_out = ((x & y) | ((x | y) & ~sum)) >> 63; - // #endif +#ifdef BARRET + return barret_reduce(xs); +#else + return mont_reduce(xs); +#endif } - static HOST_DEVICE_INLINE void Sub64(uint64_t x, uint64_t y, uint64_t borrow, uint64_t& diff, uint64_t& borrowOut) + HOST_DEVICE Field& operator=(Field const& other) { - // #if defined(__GNUC__) || defined(__clang__) - // borrowOut = _subborrow_u64(borrow, x, y, &diff); - // #else - diff = x - y - borrow; - // See Sub32 for the bit logic. - borrowOut = ((~x & y) | (~(x ^ y) & diff)) >> 63; - // #endif + for (int i = 0; i < TLC; i++) { + this->limbs_storage.limbs[i] = other.limbs_storage.limbs[i]; + } + return *this; } - static HOST_DEVICE_INLINE bool smallerThanModulus(const Field& z) + friend HOST_DEVICE Field operator*(const Field& xs, const Field& ys) { - // for bn254 specifically - constexpr uint64_t q0 = 4891460686036598785ULL; - constexpr uint64_t q1 = 2896914383306846353ULL; - constexpr uint64_t q2 = 13281191951274694749ULL; - constexpr uint64_t q3 = 3486998266802970665ULL; - return ( - z.limbs_storage.limbs64[3] < q3 || - (z.limbs_storage.limbs64[3] == q3 && - (z.limbs_storage.limbs64[2] < q2 || - (z.limbs_storage.limbs64[2] == q2 && - (z.limbs_storage.limbs64[1] < q1 || - (z.limbs_storage.limbs64[1] == q1 && (z.limbs_storage.limbs64[0] < q0))))))); - } - - // #define WITH_MONT_CONVERSIONS - - #ifdef WITH_MONT_CONVERSIONS - friend HOST_DEVICE Field operator*(const Field& x_orig, const Field& y_orig) + #ifdef BARRET + return barret_mult(xs, ys); #else - friend HOST_DEVICE Field operator*(const Field& x, const Field& y) + return mont_mult(xs, ys); #endif - { - // for bn254 specifically - constexpr uint64_t qInvNeg = 14042775128853446655ULL; - constexpr uint64_t q0 = 4891460686036598785ULL; - constexpr uint64_t q1 = 2896914383306846353ULL; - constexpr uint64_t q2 = 13281191951274694749ULL; - constexpr uint64_t q3 = 3486998266802970665ULL; - - #ifdef WITH_MONT_CONVERSIONS - // auto x = original_multiplier(x_orig, original_multiplier(Field{CONFIG::montgomery_r}, - // Field{CONFIG::montgomery_r})); auto y = original_multiplier(y_orig, - // original_multiplier(Field{CONFIG::montgomery_r}, Field{CONFIG::montgomery_r})); - auto x = original_multiplier(x_orig, Field{CONFIG::montgomery_r}); - auto y = original_multiplier(y_orig, Field{CONFIG::montgomery_r}); - #endif - - Field z{}; - uint64_t t0, t1, t2, t3; - uint64_t u0, u1, u2, u3; - - { - uint64_t c0, c1, c2, _; - uint64_t v = x.limbs_storage.limbs64[0]; - Mul64(v, y.limbs_storage.limbs64[0], u0, t0); - Mul64(v, y.limbs_storage.limbs64[1], u1, t1); - Mul64(v, y.limbs_storage.limbs64[2], u2, t2); - Mul64(v, y.limbs_storage.limbs64[3], u3, t3); - Add64(u0, t1, 0, t1, c0); - Add64(u1, t2, c0, t2, c0); - Add64(u2, t3, c0, t3, c0); - Add64(u3, 0, c0, c2, _); - - uint64_t m = qInvNeg * t0; - - Mul64(m, q0, u0, c1); - Add64(t0, c1, 0, _, c0); - Mul64(m, q1, u1, c1); - Add64(t1, c1, c0, t0, c0); - Mul64(m, q2, u2, c1); - Add64(t2, c1, c0, t1, c0); - Mul64(m, q3, u3, c1); - - Add64(0, c1, c0, t2, c0); - Add64(u3, 0, c0, u3, _); - Add64(u0, t0, 0, t0, c0); - Add64(u1, t1, c0, t1, c0); - Add64(u2, t2, c0, t2, c0); - Add64(c2, 0, c0, c2, _); - Add64(t3, t2, 0, t2, c0); - Add64(u3, c2, c0, t3, _); - } - - { - uint64_t c0, c1, c2, _; - uint64_t v = x.limbs_storage.limbs64[1]; - Mul64(v, y.limbs_storage.limbs64[0], u0, c1); - Add64(c1, t0, 0, t0, c0); - Mul64(v, y.limbs_storage.limbs64[1], u1, c1); - Add64(c1, t1, c0, t1, c0); - Mul64(v, y.limbs_storage.limbs64[2], u2, c1); - Add64(c1, t2, c0, t2, c0); - Mul64(v, y.limbs_storage.limbs64[3], u3, c1); - Add64(c1, t3, c0, t3, c0); - - Add64(0, 0, c0, c2, _); - Add64(u0, t1, 0, t1, c0); - Add64(u1, t2, c0, t2, c0); - Add64(u2, t3, c0, t3, c0); - Add64(u3, c2, c0, c2, _); - - uint64_t m = qInvNeg * t0; - - Mul64(m, q0, u0, c1); - Add64(t0, c1, 0, _, c0); - Mul64(m, q1, u1, c1); - Add64(t1, c1, c0, t0, c0); - Mul64(m, q2, u2, c1); - Add64(t2, c1, c0, t1, c0); - Mul64(m, q3, u3, c1); - - Add64(0, c1, c0, t2, c0); - Add64(u3, 0, c0, u3, _); - Add64(u0, t0, 0, t0, c0); - Add64(u1, t1, c0, t1, c0); - Add64(u2, t2, c0, t2, c0); - Add64(c2, 0, c0, c2, _); - Add64(t3, t2, 0, t2, c0); - Add64(u3, c2, c0, t3, _); - } - - { - uint64_t c0, c1, c2, _; - uint64_t v = x.limbs_storage.limbs64[2]; - Mul64(v, y.limbs_storage.limbs64[0], u0, c1); - Add64(c1, t0, 0, t0, c0); - Mul64(v, y.limbs_storage.limbs64[1], u1, c1); - Add64(c1, t1, c0, t1, c0); - Mul64(v, y.limbs_storage.limbs64[2], u2, c1); - Add64(c1, t2, c0, t2, c0); - Mul64(v, y.limbs_storage.limbs64[3], u3, c1); - Add64(c1, t3, c0, t3, c0); - - Add64(0, 0, c0, c2, _); - Add64(u0, t1, 0, t1, c0); - Add64(u1, t2, c0, t2, c0); - Add64(u2, t3, c0, t3, c0); - Add64(u3, c2, c0, c2, _); - - uint64_t m = qInvNeg * t0; - - Mul64(m, q0, u0, c1); - Add64(t0, c1, 0, _, c0); - Mul64(m, q1, u1, c1); - Add64(t1, c1, c0, t0, c0); - Mul64(m, q2, u2, c1); - Add64(t2, c1, c0, t1, c0); - Mul64(m, q3, u3, c1); - - Add64(0, c1, c0, t2, c0); - Add64(u3, 0, c0, u3, _); - Add64(u0, t0, 0, t0, c0); - Add64(u1, t1, c0, t1, c0); - Add64(u2, t2, c0, t2, c0); - Add64(c2, 0, c0, c2, _); - Add64(t3, t2, 0, t2, c0); - Add64(u3, c2, c0, t3, _); - } - - { - uint64_t c0, c1, c2, _; - uint64_t v = x.limbs_storage.limbs64[3]; - Mul64(v, y.limbs_storage.limbs64[0], u0, c1); - Add64(c1, t0, 0, t0, c0); - Mul64(v, y.limbs_storage.limbs64[1], u1, c1); - Add64(c1, t1, c0, t1, c0); - Mul64(v, y.limbs_storage.limbs64[2], u2, c1); - Add64(c1, t2, c0, t2, c0); - Mul64(v, y.limbs_storage.limbs64[3], u3, c1); - Add64(c1, t3, c0, t3, c0); - - Add64(0, 0, c0, c2, _); - Add64(u0, t1, 0, t1, c0); - Add64(u1, t2, c0, t2, c0); - Add64(u2, t3, c0, t3, c0); - Add64(u3, c2, c0, c2, _); - - uint64_t m = qInvNeg * t0; - - Mul64(m, q0, u0, c1); - Add64(t0, c1, 0, _, c0); - Mul64(m, q1, u1, c1); - Add64(t1, c1, c0, t0, c0); - Mul64(m, q2, u2, c1); - Add64(t2, c1, c0, t1, c0); - Mul64(m, q3, u3, c1); - - Add64(0, c1, c0, t2, c0); - Add64(u3, 0, c0, u3, _); - Add64(u0, t0, 0, t0, c0); - Add64(u1, t1, c0, t1, c0); - Add64(u2, t2, c0, t2, c0); - Add64(c2, 0, c0, c2, _); - Add64(t3, t2, 0, t2, c0); - Add64(u3, c2, c0, t3, _); - } - - z.limbs_storage.limbs64[0] = t0; - z.limbs_storage.limbs64[1] = t1; - z.limbs_storage.limbs64[2] = t2; - z.limbs_storage.limbs64[3] = t3; - - if (smallerThanModulus(z)) { - uint64_t b, _; - Sub64(z.limbs_storage.limbs64[0], q0, 0, z.limbs_storage.limbs64[0], b); - Sub64(z.limbs_storage.limbs64[1], q1, b, z.limbs_storage.limbs64[1], b); - Sub64(z.limbs_storage.limbs64[2], q2, b, z.limbs_storage.limbs64[2], b); - Sub64(z.limbs_storage.limbs64[3], q3, b, z.limbs_storage.limbs64[3], _); - } - - #ifdef WITH_MONT_CONVERSIONS - z = original_multiplier(z, Field{CONFIG::montgomery_r_inv}); - // z = original_multiplier(z, original_multiplier(Field{CONFIG::montgomery_r_inv}, - // Field{CONFIG::montgomery_r_inv})); - #endif - return z; } -#endif - // #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) - // #pragma GCC reset_options - // #endif - - /*GNARK CODE END*/ friend HOST_DEVICE bool operator==(const Field& xs, const Field& ys) { -#ifdef __CUDA_ARCH__ - const uint32_t* x = xs.limbs_storage.limbs; - const uint32_t* y = ys.limbs_storage.limbs; - uint32_t limbs_or = x[0] ^ y[0]; - UNROLL - for (unsigned i = 1; i < TLC; i++) - limbs_or |= x[i] ^ y[i]; - return limbs_or == 0; -#else - for (unsigned i = 0; i < TLC; i++) - if (xs.limbs_storage.limbs[i] != ys.limbs_storage.limbs[i]) return false; - return true; -#endif + return base_math::template is_equal(xs.limbs_storage, ys.limbs_storage); } friend HOST_DEVICE bool operator!=(const Field& xs, const Field& ys) { return !(xs == ys); } @@ -856,19 +529,19 @@ class Field static HOST_DEVICE_INLINE Field mul_const(const Field& xs) { Field mul = multiplier; - #ifdef BARRET +#ifdef BARRET static bool is_u32 = true; -#ifdef __CUDA_ARCH__ + #ifdef __CUDA_ARCH__ UNROLL -#endif + #endif for (unsigned i = 1; i < TLC; i++) is_u32 &= (mul.limbs_storage.limbs[i] == 0); if (is_u32) return mul_unsigned(xs); - #endif - #ifndef BARRET - mul = to_montgomery(mul); - #endif +#endif +#ifndef BARRET + mul = to_montgomery(mul); // TODO - optimize +#endif return mul * xs; } diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index ad9d5ec6e..0e7cb0c36 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -84,17 +84,9 @@ namespace host_math { return result; } - // static inline __host__ __uint128_t mul64(uint64_t x, uint64_t y) - // { - // uint64_t high, low; - // asm("mulq %3" : "=d"(high), "=a"(low) : "a"(x), "r"(y) : "cc"); - // return (static_cast<__uint128_t>(high) << 64) | low; - // } - static __host__ uint64_t madc_cc_64(const uint64_t x, const uint64_t y, const uint64_t z, uint64_t& carry) { __uint128_t r = static_cast<__uint128_t>(x) * y + z + carry; - // __uint128_t r = mul64(x, y) + z + carry; carry = (uint64_t)(r >> 64); uint64_t result = r & 0xffffffffffffffff; @@ -103,27 +95,6 @@ namespace host_math { #include - // static inline __host__ uint64_t madc_cc_64(const uint64_t x, const uint64_t y, const uint64_t z, uint64_t& carry) - // { - // uint64_t high, low; - - // // Perform multiplication of x * y - // asm("mulq %3\n\t" // x * y -> result in RDX:RAX - // "addq %4, %%rax\n\t" // Add z to the low 64 bits (RAX), setting flags - // "adcq $0, %%rdx\n\t" // Propagate carry to high 64 bits (RDX) - // "addq %5, %%rax\n\t" // Add the input carry to RAX, setting flags - // "adcq $0, %%rdx" // Propagate any carry to RDX - // : "=a"(low), "=d"(high) // Output operands - // : "a"(x), "r"(y), "r"(z), "r"(carry) // Input operands - // : "cc"); // Clobbers - - // // Set carry to the high 64 bits of the result - // carry = high; - - // // Return the low 64 bits of the result - // return low; - // } - template struct carry_chain { unsigned index; @@ -350,13 +321,10 @@ namespace host_math { * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). */ template - static HOST_INLINE void - sos_mont_reduction( - const storage<2*NLIMBS>& t, const storage& n, const storage& n_tag, storage<2*NLIMBS>& r) + static HOST_INLINE void sos_mont_reduction( + const storage<2 * NLIMBS>& t, const storage& n, const storage& n_tag, storage<2 * NLIMBS>& r) { - static_assert( - NLIMBS % 2 == 0 || NLIMBS == 1, - "Odd number of limbs (That is not 1) is not supported\n"); + static_assert(NLIMBS % 2 == 0 || NLIMBS == 1, "Odd number of limbs (That is not 1) is not supported\n"); if constexpr (USE_32) { sos_mont_reduction_32(t.limbs, n.limbs, n_tag.limbs, r.limbs); return; @@ -368,9 +336,13 @@ namespace host_math { } } -template - static constexpr HOST_INLINE void - multiply_mont(const storage& as, const storage& bs, const storage& qs, const storage& ps, storage& rs) + template + static constexpr HOST_INLINE void multiply_mont( + const storage& as, + const storage& bs, + const storage& qs, + const storage& ps, + storage& rs) { static_assert( (NLIMBS_A % 2 == 0 || NLIMBS_A == 1) && (NLIMBS_B % 2 == 0 || NLIMBS_B == 1), @@ -482,6 +454,22 @@ template } } } + template + static constexpr void get_higher_with_slack(const storage<2*NLIMBS>& xs, storage& out, unsigned slack_bits) + { + for (unsigned i = 0; i < NLIMBS; i++) { + out.limbs[i] = (xs.limbs[i + NLIMBS] << 2 * slack_bits) + + (xs.limbs[i + NLIMBS - 1] >> (32 - 2 * slack_bits)); + } + } + + template + static constexpr DEVICE_INLINE bool is_equal(const storage& xs, const storage& ys) + { + for (unsigned i = 0; i < NLIMBS; i++) + if (xs.limbs[i] != ys.limbs[i]) return false; + return true; + } } // namespace host_math #if defined(__GNUC__) && !defined(__NVCC__) && !defined(__clang__) diff --git a/icicle/include/icicle/fields/params_gen.h b/icicle/include/icicle/fields/params_gen.h index e174a94d6..fa26fcdb2 100644 --- a/icicle/include/icicle/fields/params_gen.h +++ b/icicle/include/icicle/fields/params_gen.h @@ -62,7 +62,7 @@ namespace params_gen { } template - static constexpr HOST_INLINE storage get_lower(const storage<2*NLIMBS>& xs) + static constexpr HOST_INLINE storage get_lower(const storage<2 * NLIMBS>& xs) { storage rs = {}; for (unsigned i = 0; i < NLIMBS; i++) @@ -73,21 +73,21 @@ namespace params_gen { template static constexpr HOST_INLINE storage get_montgomery_mult_constant(const storage& modulus) { - //p^R-1 without carry (this is mod r) and then r-res; + // p^R-1 without carry (this is mod r) and then r-res; storage rs = {}; - storage<2*NLIMBS> w_rs = {}; + storage<2 * NLIMBS> w_rs = {}; storage tmp = {}; - storage<2*NLIMBS> w_tmp = {}; + storage<2 * NLIMBS> w_tmp = {}; host_math::template multiply_raw(modulus, modulus, w_tmp); tmp = params_gen::template get_lower(w_tmp); rs = modulus; host_math::template multiply_raw(tmp, rs, w_rs); rs = params_gen::template get_lower(w_rs); - for (int i = 0; i < sizeof(modulus.limbs[0])*8*NLIMBS - 4; i++) { - storage<2*NLIMBS> w_tmp2 = {}; + for (int i = 0; i < sizeof(modulus.limbs[0]) * 8 * NLIMBS - 4; i++) { + storage<2 * NLIMBS> w_tmp2 = {}; host_math::template multiply_raw(tmp, tmp, w_tmp2); tmp = params_gen::template get_lower(w_tmp2); - storage<2*NLIMBS> w_rs2 = {}; + storage<2 * NLIMBS> w_rs2 = {}; host_math::template multiply_raw(tmp, rs, w_rs2); rs = params_gen::template get_lower(w_rs2); } @@ -129,22 +129,20 @@ namespace params_gen { } template - constexpr storage_array get_invs(const storage& modulus, const storage& mont_r_sqr, const storage& mont_inv) + constexpr storage_array + get_invs(const storage& modulus, const storage& mont_r_sqr, const storage& mont_inv) { storage_array invs = {}; storage rs = {1}; for (int i = 0; i < TWO_ADICITY; i++) { if (rs.limbs[0] & 1) host_math::template add_sub_limbs(rs, modulus, rs); rs = host_math::template right_shift(rs); - // host_math::template multiply_mont_32(rs.limbs, mont_r_sqr.limbs, mont_inv.limbs, modulus.limbs, rs.limbs); invs.storages[i] = rs; } return invs; } } // namespace params_gen -//do we still need modulus_2, 3, 4 when using montgomery? smae for num_of_reductions - #define PARAMS(modulus) \ static constexpr unsigned limbs_count = modulus.LC; \ static constexpr unsigned modulus_bit_count = \ @@ -164,8 +162,8 @@ namespace params_gen { static constexpr storage m = params_gen::template get_m(modulus); \ static constexpr storage montgomery_r = \ params_gen::template get_montgomery_constant(modulus); \ - static constexpr storage montgomery_r_sqr = \ - params_gen::template get_montgomery_constant_sqr(modulus); \ + static constexpr storage montgomery_r_sqr = \ + params_gen::template get_montgomery_constant_sqr(modulus); \ static constexpr storage montgomery_r_inv = \ params_gen::template get_montgomery_constant(modulus); \ static constexpr storage mont_inv_modulus = \ @@ -173,7 +171,7 @@ namespace params_gen { static constexpr unsigned num_of_reductions = \ params_gen::template num_of_reductions(modulus, m); -#define TWIDDLES(modulus) \ +#define TWIDDLES(modulus) \ static constexpr unsigned omegas_count = params_gen::template two_adicity(modulus); \ static constexpr storage_array inv = \ params_gen::template get_invs(modulus, montgomery_r_sqr, mont_inv_modulus); diff --git a/icicle/include/icicle/fields/quartic_extension.h b/icicle/include/icicle/fields/quartic_extension.h index cd7b79751..7b199a575 100644 --- a/icicle/include/icicle/fields/quartic_extension.h +++ b/icicle/include/icicle/fields/quartic_extension.h @@ -161,9 +161,7 @@ class QuarticExtensionField static constexpr HOST_DEVICE_INLINE QuarticExtensionField reduce(const ExtensionWide& xs) { - return QuarticExtensionField{ - FF::reduce(xs.real), FF::reduce(xs.im1), - FF::reduce(xs.im2), FF::reduce(xs.im3)}; + return QuarticExtensionField{FF::reduce(xs.real), FF::reduce(xs.im1), FF::reduce(xs.im2), FF::reduce(xs.im3)}; } template diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h index 1666517dd..2282e09c8 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_base.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_base.h @@ -7,7 +7,6 @@ namespace bls12_377 { struct fq_config { static constexpr storage<12> modulus = {0x00000001, 0x8508c000, 0x30000000, 0x170b5d44, 0xba094800, 0x1ef3622f, 0x00f5138f, 0x1a22d9f3, 0x6ca1493b, 0xc63b05c0, 0x17c510ea, 0x01ae3a46}; - // static constexpr storage<12> mont_inv_modulus = {0xffffffff, 0x8508bfff, 0xa0000000, 0xd1e94577, 0x970debff, 0x35ed1347, 0xcced7a13, 0x5b245b86, 0x806a3cec, 0x22f80141, 0xeec82e3d, 0xbfa5205f}; PARAMS(modulus) diff --git a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h index e047144a0..adb8fd00d 100644 --- a/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h +++ b/icicle/include/icicle/fields/snark_fields/bls12_377_scalar.h @@ -8,7 +8,6 @@ namespace bls12_377 { struct fp_config { static constexpr storage<8> modulus = {0x00000001, 0x0a118000, 0xd0000001, 0x59aa76fe, 0x5c37b001, 0x60b44d1e, 0x9a2ca556, 0x12ab655e}; - // static constexpr storage<8> mont_inv_modulus = {0xffffffff, 0xa117fff, 0x90000001, 0x452217cc, 0x4790a000, 0x249765c3, 0x68b29556, 0x6992d0fa}; PARAMS(modulus) static constexpr storage<8> rou = {0xec2a895e, 0x476ef4a4, 0x63e3f04a, 0x9b506ee3, diff --git a/icicle/include/icicle/fields/stark_fields/m31.h b/icicle/include/icicle/fields/stark_fields/m31.h index d2b831a00..4f559f2ff 100644 --- a/icicle/include/icicle/fields/stark_fields/m31.h +++ b/icicle/include/icicle/fields/stark_fields/m31.h @@ -200,7 +200,9 @@ namespace m31 { static constexpr storage one = {0x00000001}; static constexpr storage zero = {0x00000000}; static constexpr storage montgomery_r = {0x00000001}; + static constexpr storage montgomery_r_sqr = {0x00000001}; static constexpr storage montgomery_r_inv = {0x00000001}; + static constexpr storage mont_inv_modulus = {0x80000001}; static constexpr storage_array omega = {{{0x7ffffffe}}}; diff --git a/icicle/include/icicle/msm.h b/icicle/include/icicle/msm.h index e73de1ac6..ff132cba2 100644 --- a/icicle/include/icicle/msm.h +++ b/icicle/include/icicle/msm.h @@ -67,17 +67,17 @@ namespace icicle { 1, // batch_size true, // are_points_shared_in_batch false, // are_scalars_on_device - #ifdef BARRET - false, // are_scalars_montgomery_form - #else - true, // are_scalars_montgomery_form - #endif - false, // are_points_on_device - #ifdef BARRET - false, // are_points_montgomery_form - #else - true, // are_points_montgomery_form - #endif +#ifdef BARRET + false, // are_scalars_montgomery_form +#else + true, // are_scalars_montgomery_form +#endif + false, // are_points_on_device +#ifdef BARRET + false, // are_points_montgomery_form +#else + true, // are_points_montgomery_form +#endif false, // are_results_on_device false, // is_async nullptr, // ext diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index 9fb78b1dc..f6acd52f0 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -46,11 +46,11 @@ class CurveApiTest : public ::testing::Test if (!is_cuda_registered) { ICICLE_LOG_ERROR << "CUDA device not found. Testing CPU vs CPU"; } s_main_target = is_cuda_registered ? "CUDA" : "CPU"; s_ref_target = "CPU"; - #ifdef BARRET - ICICLE_LOG_INFO << "USING BARRET MULT\n"; - #else - ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; - #endif +#ifdef BARRET + ICICLE_LOG_INFO << "USING BARRET MULT\n"; +#else + ICICLE_LOG_INFO << "USING MONTGOMERY MULT\n"; +#endif } static void TearDownTestSuite() { @@ -360,8 +360,8 @@ TYPED_TEST(CurveSanity, CurveSanityTest) { auto a = TypeParam::rand_host(); auto b = TypeParam::rand_host(); - ICICLE_LOG_INFO << "a: "< FTImplementations; TYPED_TEST_SUITE(FieldApiTest, FTImplementations); - // Note: this is testing host arithmetic. Other tests against CPU backend should guarantee correct device arithmetic too TYPED_TEST(FieldApiTest, FieldSanityTest) { auto a = TypeParam::rand_host(); // a.limbs_storage.limbs[0] = 1089097490; - std::cout<(a.limbs_storage, a.limbs_storage, r_wide.limbs_storage); - // a = TypeParam::reduce(r_wide); - ar = TypeParam::mont_mult(ar,ar); + // host_math::template multiply_raw(a.limbs_storage, a.limbs_storage, + // r_wide.limbs_storage); a = TypeParam::reduce(r_wide); + // ar = TypeParam::mont_mult(ar, ar); // a = TypeParam::mont_reduce(r_wide); - // host_math::template multiply_raw(b.limbs_storage, b.limbs_storage, r2_wide.limbs_storage); - // host_math::template add_sub_limbs(a.limbs_storage, a.limbs_storage, a.limbs_storage); - // host_math::template add_sub_limbs(b.limbs_storage, b.limbs_storage, b.limbs_storage); - // a = TypeParam::Wide::get_lower(r_wide); - // b = TypeParam::Wide::get_lower(r2_wide); - // a = a + a; - // a = a * a; - } - END_TIMER(MULT_sync, oss.str().c_str(), true); + // host_math::template multiply_raw(b.limbs_storage, b.limbs_storage, + // r2_wide.limbs_storage); host_math::template add_sub_limbs(a.limbs_storage, + // a.limbs_storage, a.limbs_storage); host_math::template add_sub_limbs(b.limbs_storage, b.limbs_storage, b.limbs_storage); a = TypeParam::Wide::get_lower(r_wide); b = + // TypeParam::Wide::get_lower(r2_wide); a = a + a; a = a * a; + // } + // END_TIMER(MULT_sync, oss.str().c_str(), true); // ASSERT_EQ(TypeParam::from_montgomery(ar), a); } -#endif - +// #endif TYPED_TEST(FieldApiTest, vectorOps) { @@ -218,12 +215,9 @@ TYPED_TEST(FieldApiTest, vectorOps) // std::cout << in_a[0] << ", " << in_b[0] << ", " << out_main[0] << ", " << out_ref[0] << std::endl; // std::cout << in_a[1] << ", " << in_b[1] << ", " << out_main[1] << ", " << out_ref[1] << std::endl; - for (int i = 0; i < N; i++) - { + for (int i = 0; i < N; i++) { std::cout << in_a[i] << ", " << in_b[i] << ", " << out_main[i] << ", " << out_ref[i] << std::endl; } - - ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), N * sizeof(TypeParam))); } @@ -465,7 +459,6 @@ TYPED_TEST(FieldApiTest, ntt) run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/); run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/); - // std::cout << "\n"; // for (int i=0;i Date: Thu, 21 Nov 2024 20:36:59 +0200 Subject: [PATCH 19/22] formatting --- icicle/include/icicle/fields/field.h | 15 ++++++--------- icicle/include/icicle/fields/host_math.h | 15 +++++++-------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 89749af2a..998f63db9 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -290,7 +290,6 @@ class Field return base_math::template multiply_raw(as, bs, rs); } - public: ff_storage limbs_storage; @@ -451,7 +450,7 @@ class Field static constexpr HOST_INLINE Field barret_mult(const Field& xs, const Field& ys) { Wide xy = mul_wide(xs, ys); // full mult - return reduce(xy); // reduce mod p + return reduce(xy); // reduce mod p } #endif @@ -471,11 +470,10 @@ class Field static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) { Field r = {}; - base_math::template multiply_mont(xs.limbs_storage, ys.limbs_storage, get_mont_inv_modulus(), get_modulus<1>(), r.limbs_storage); + base_math::template multiply_mont( + xs.limbs_storage, ys.limbs_storage, get_mont_inv_modulus(), get_modulus<1>(), r.limbs_storage); return mont_sub_modulus(r); } - - static constexpr HOST_INLINE Field mont_reduce(const Wide& t) { @@ -510,14 +508,13 @@ class Field friend HOST_DEVICE Field operator*(const Field& xs, const Field& ys) { - #ifdef BARRET +#ifdef BARRET return barret_mult(xs, ys); - #else +#else return mont_mult(xs, ys); - #endif +#endif } - friend HOST_DEVICE bool operator==(const Field& xs, const Field& ys) { return base_math::template is_equal(xs.limbs_storage, ys.limbs_storage); diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index 0e7cb0c36..02e4ba109 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -244,7 +244,7 @@ namespace host_math { * @param n_tag Number such that \p n * \p n_tag modR = -1 * @param r Array in which to store the result in its upper half (Lower half is data that would be removed by * dividing by R = shifting NLIMBS down). - * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). + * @tparam NLIMBS Number of 32bit limbs required to represent a number in the field defined by n. R is 2^(NLIMBS*32). */ template static HOST_INLINE void @@ -281,7 +281,7 @@ namespace host_math { * @param n_tag Number such that \p n * \p n_tag modR = -1 * @param r Array in which to store the result in its upper half (Lower half is data that would be removed by * dividing by R = shifting NLIMBS down). - * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). + * @tparam NLIMBS Number of 32bit limbs required to represent a number in the field defined by n. R is 2^(NLIMBS*32). */ template static HOST_INLINE void @@ -318,7 +318,7 @@ namespace host_math { * @param n_tag Number such that \p n * \p n_tag modR = -1 * @param r Array in which to store the result in its upper half (Lower half is data that would be removed by * dividing by R = shifting NLIMBS down). - * @tparam NLIMBS Number of 32bit limbs required to represend a number in the field defined by n. R is 2^(NLIMBS*32). + * @tparam NLIMBS Number of 32bit limbs required to represent a number in the field defined by n. R is 2^(NLIMBS*32). */ template static HOST_INLINE void sos_mont_reduction( @@ -455,12 +455,11 @@ namespace host_math { } } template - static constexpr void get_higher_with_slack(const storage<2*NLIMBS>& xs, storage& out, unsigned slack_bits) + static constexpr void get_higher_with_slack(const storage<2 * NLIMBS>& xs, storage& out, unsigned slack_bits) { - for (unsigned i = 0; i < NLIMBS; i++) { - out.limbs[i] = (xs.limbs[i + NLIMBS] << 2 * slack_bits) + - (xs.limbs[i + NLIMBS - 1] >> (32 - 2 * slack_bits)); - } + for (unsigned i = 0; i < NLIMBS; i++) { + out.limbs[i] = (xs.limbs[i + NLIMBS] << 2 * slack_bits) + (xs.limbs[i + NLIMBS - 1] >> (32 - 2 * slack_bits)); + } } template From 01c51434691f7b76b2a7ef01255a5f0717f5b14f Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Thu, 21 Nov 2024 20:48:28 +0200 Subject: [PATCH 20/22] small fix --- icicle/backend/cpu/include/cpu_ntt_domain.h | 1 - 1 file changed, 1 deletion(-) diff --git a/icicle/backend/cpu/include/cpu_ntt_domain.h b/icicle/backend/cpu/include/cpu_ntt_domain.h index d3c83ebd0..3cfe88018 100644 --- a/icicle/backend/cpu/include/cpu_ntt_domain.h +++ b/icicle/backend/cpu/include/cpu_ntt_domain.h @@ -94,7 +94,6 @@ namespace ntt_cpu { if (found_logn) break; } } - // s_ntt_domain.max_log_size = 21; s_ntt_domain.max_size = (int)pow(2, s_ntt_domain.max_log_size); if (omega != S::one()) { From 9d4394c6b34cca598084935ffe428ee012b247c4 Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Thu, 21 Nov 2024 20:57:32 +0200 Subject: [PATCH 21/22] small fix to tests --- icicle/tests/test_curve_api.cpp | 8 +------- icicle/tests/test_field_api.cpp | 12 +++++------- 2 files changed, 6 insertions(+), 14 deletions(-) diff --git a/icicle/tests/test_curve_api.cpp b/icicle/tests/test_curve_api.cpp index f6acd52f0..bf544e577 100644 --- a/icicle/tests/test_curve_api.cpp +++ b/icicle/tests/test_curve_api.cpp @@ -310,10 +310,6 @@ TEST_F(CurveApiTest, ecnttDeviceMem) std::ostringstream oss; oss << dev_type << " " << msg; - // std::cout << "press any key to proceed..."; - // int a; - // std::cin >> a; - // std::cout << "proceeding\n"; START_TIMER(NTT_sync) for (int i = 0; i < iters; ++i) { ICICLE_CHECK(ntt(d_in, N, dir, config, inplace ? d_in : d_out)); @@ -329,7 +325,7 @@ TEST_F(CurveApiTest, ecnttDeviceMem) ICICLE_CHECK(ntt_release_domain()); }; - // run(s_main_target, out_main.get(), "ecntt", true /*=measure*/, 1 /*=iters*/); // warmup + run(s_main_target, out_main.get(), "ecntt", true /*=measure*/, 1 /*=iters*/); // warmup run(s_ref_target, out_ref.get(), "ecntt", VERBOSE /*=measure*/, 1 /*=iters*/); run(s_main_target, out_main.get(), "ecntt", VERBOSE /*=measure*/, 1 /*=iters*/); // note that memcmp is tricky here because projetive points can have many representations @@ -360,8 +356,6 @@ TYPED_TEST(CurveSanity, CurveSanityTest) { auto a = TypeParam::rand_host(); auto b = TypeParam::rand_host(); - ICICLE_LOG_INFO << "a: " << a; - ICICLE_LOG_INFO << "b: " << b; ASSERT_EQ(true, TypeParam::is_on_curve(a) && TypeParam::is_on_curve(b)); // rand is on curve ASSERT_EQ(a + TypeParam::zero(), a); // zero addition ASSERT_EQ(a + b - a, b); // addition,subtraction cancel diff --git a/icicle/tests/test_field_api.cpp b/icicle/tests/test_field_api.cpp index bf1fc435e..87036b9bb 100644 --- a/icicle/tests/test_field_api.cpp +++ b/icicle/tests/test_field_api.cpp @@ -853,8 +853,8 @@ TYPED_TEST(FieldApiTest, ntt) const bool inplace = rand() % 2; const int logn = rand() % 15 + 3; const uint64_t N = 1 << logn; - const int log_ntt_domain_size = logn; - const int log_batch_size = 0; + const int log_ntt_domain_size = logn + 1; + const int log_batch_size = rand() % 3; const int batch_size = 1 << log_batch_size; const int _ordering = rand() % 4; const Ordering ordering = static_cast(_ordering); @@ -907,7 +907,6 @@ TYPED_TEST(FieldApiTest, ntt) config.are_outputs_on_device = true; config.is_async = false; ICICLE_CHECK(ntt_init_domain(scalar_t::omega(log_ntt_domain_size), init_domain_config)); - // ntt_init_domain(scalar_t::omega(log_ntt_domain_size), init_domain_config); TypeParam *d_in, *d_out; ICICLE_CHECK(icicle_malloc_async((void**)&d_in, total_size * sizeof(TypeParam), config.stream)); ICICLE_CHECK(icicle_malloc_async((void**)&d_out, total_size * sizeof(TypeParam), config.stream)); @@ -935,10 +934,9 @@ TYPED_TEST(FieldApiTest, ntt) ICICLE_CHECK(icicle_destroy_stream(stream)); ICICLE_CHECK(ntt_release_domain()); }; - run(s_main_target, out_main.get(), "ntt", false /*=measure*/, 1 /*=iters*/); // warmup - run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/); - run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 1 /*=iters*/); - + run(s_main_target, out_main.get(), "ntt", false /*=measure*/, 10 /*=iters*/); // warmup + run(s_reference_target, out_ref.get(), "ntt", VERBOSE /*=measure*/, 10 /*=iters*/); + run(s_main_target, out_main.get(), "ntt", VERBOSE /*=measure*/, 10 /*=iters*/); ASSERT_EQ(0, memcmp(out_main.get(), out_ref.get(), total_size * sizeof(scalar_t))); } #endif // NTT From 1172cecc7776dace948c5d8f96b0b0108b6c1aef Mon Sep 17 00:00:00 2001 From: hadaringonyama Date: Thu, 21 Nov 2024 21:51:48 +0200 Subject: [PATCH 22/22] bug fix --- icicle/include/icicle/fields/field.h | 4 ++-- icicle/include/icicle/fields/host_math.h | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/icicle/include/icicle/fields/field.h b/icicle/include/icicle/fields/field.h index 998f63db9..a6ca93527 100644 --- a/icicle/include/icicle/fields/field.h +++ b/icicle/include/icicle/fields/field.h @@ -467,7 +467,7 @@ class Field return r; } - static constexpr HOST_INLINE Field mont_mult(const Field& xs, const Field& ys) + static constexpr HOST_DEVICE_INLINE Field mont_mult(const Field& xs, const Field& ys) { Field r = {}; base_math::template multiply_mont( @@ -475,7 +475,7 @@ class Field return mont_sub_modulus(r); } - static constexpr HOST_INLINE Field mont_reduce(const Wide& t) + static constexpr HOST_DEVICE_INLINE Field mont_reduce(const Wide& t) { #ifdef __CUDA_ARCH__ Wide r = t; diff --git a/icicle/include/icicle/fields/host_math.h b/icicle/include/icicle/fields/host_math.h index 02e4ba109..4ebf0d83e 100644 --- a/icicle/include/icicle/fields/host_math.h +++ b/icicle/include/icicle/fields/host_math.h @@ -463,7 +463,7 @@ namespace host_math { } template - static constexpr DEVICE_INLINE bool is_equal(const storage& xs, const storage& ys) + static constexpr bool is_equal(const storage& xs, const storage& ys) { for (unsigned i = 0; i < NLIMBS; i++) if (xs.limbs[i] != ys.limbs[i]) return false;