diff --git a/benchmarks/bench_blueprint.nim b/benchmarks/bench_blueprint.nim
index ce3c10e68..eabdd8b3b 100644
--- a/benchmarks/bench_blueprint.nim
+++ b/benchmarks/bench_blueprint.nim
@@ -88,7 +88,6 @@ proc notes*() =
   echo "    Bench on specific compiler with assembler: \"nimble bench_ec_g1_gcc\" or \"nimble bench_ec_g1_clang\"."
   echo "    Bench on specific compiler without assembler: \"nimble bench_ec_g1_gcc_noasm\" or \"nimble bench_ec_g1_clang_noasm\"."
   echo "  - The simplest operations might be optimized away by the compiler."
-  echo "  - Fast Squaring and Fast Multiplication are possible if there are spare bits in the prime representation (i.e. the prime uses 254 bits out of 256 bits)"
 
 template measure*(iters: int,
                   startTime, stopTime: untyped,
diff --git a/benchmarks/bench_fp_double_width.nim b/benchmarks/bench_fp_double_precision.nim
similarity index 69%
rename from benchmarks/bench_fp_double_width.nim
rename to benchmarks/bench_fp_double_precision.nim
index 2315de77d..ba05464df 100644
--- a/benchmarks/bench_fp_double_width.nim
+++ b/benchmarks/bench_fp_double_precision.nim
@@ -89,11 +89,13 @@ proc notes*() =
   echo "Notes:"
   echo "  - Compilers:"
   echo "    Compilers are severely limited on multiprecision arithmetic."
-  echo "    Inline Assembly is used by default (nimble bench_fp)."
-  echo "    Bench without assembly can use \"nimble bench_fp_gcc\" or \"nimble bench_fp_clang\"."
+  echo "    Constantine compile-time assembler is used by default (nimble bench_fpdbl)."
   echo "    GCC is significantly slower than Clang on multiprecision arithmetic due to catastrophic handling of carries."
+  echo "    GCC also seems to have issues with large temporaries and register spilling."
+  echo "    This is somewhat alleviated by Constantine compile-time assembler."
+  echo "    Bench on specific compiler with assembler: \"nimble bench_fpdbl_gcc\" or \"nimble bench_fpdbl_clang\"."
+  echo "    Bench on specific compiler without assembler: \"nimble bench_fpdbl_gcc_noasm\" or \"nimble bench_fpdbl_clang_noasm\"."
   echo "  - The simplest operations might be optimized away by the compiler."
-  echo "  - Fast Squaring and Fast Multiplication are possible if there are spare bits in the prime representation (i.e. 
the prime uses 254 bits out of 256 bits)"
 
 template bench(op: string, desc: string, iters: int, body: untyped): untyped =
   let start = getMonotime()
@@ -121,12 +123,12 @@ func random_unsafe(rng: var RngState, a: var FpDbl, Base: typedesc) =
   for i in 0 ..< aHi.mres.limbs.len:
     a.limbs2x[aLo.mres.limbs.len+i] = aHi.mres.limbs[i]
 
-proc sumNoReduce(T: typedesc, iters: int) =
+proc sumUnr(T: typedesc, iters: int) =
   var r: T
   let a = rng.random_unsafe(T)
   let b = rng.random_unsafe(T)
-  bench("Addition no reduce", $T, iters):
-    r.sumNoReduce(a, b)
+  bench("Addition unreduced", $T, iters):
+    r.sumUnr(a, b)
 
 proc sum(T: typedesc, iters: int) =
   var r: T
@@ -135,12 +137,12 @@ proc sum(T: typedesc, iters: int) =
   bench("Addition", $T, iters):
     r.sum(a, b)
 
-proc diffNoReduce(T: typedesc, iters: int) =
+proc diffUnr(T: typedesc, iters: int) =
   var r: T
   let a = rng.random_unsafe(T)
   let b = rng.random_unsafe(T)
-  bench("Substraction no reduce", $T, iters):
-    r.diffNoReduce(a, b)
+  bench("Subtraction unreduced", $T, iters):
+    r.diffUnr(a, b)
 
 proc diff(T: typedesc, iters: int) =
   var r: T
@@ -149,52 +151,86 @@ proc diff(T: typedesc, iters: int) =
   bench("Substraction", $T, iters):
     r.diff(a, b)
 
-proc diff2xNoReduce(T: typedesc, iters: int) =
-  var r, a, b: doubleWidth(T)
+proc neg(T: typedesc, iters: int) =
+  var r: T
+  let a = rng.random_unsafe(T)
+  bench("Negation", $T, iters):
+    r.neg(a)
+
+proc sum2xUnreduce(T: typedesc, iters: int) =
+  var r, a, b: doublePrec(T)
+  rng.random_unsafe(r, T)
+  rng.random_unsafe(a, T)
+  rng.random_unsafe(b, T)
+  bench("Addition 2x unreduced", $doublePrec(T), iters):
+    r.sum2xUnr(a, b)
+
+proc sum2x(T: typedesc, iters: int) =
+  var r, a, b: doublePrec(T)
+  rng.random_unsafe(r, T)
+  rng.random_unsafe(a, T)
+  rng.random_unsafe(b, T)
+  bench("Addition 2x reduced", $doublePrec(T), iters):
+    r.sum2xMod(a, b)
+
+proc diff2xUnreduce(T: typedesc, iters: int) =
+  var r, a, b: doublePrec(T)
   rng.random_unsafe(r, T)
   rng.random_unsafe(a, T)
   rng.random_unsafe(b, T)
-  bench("Substraction 2x no reduce", $doubleWidth(T), iters):
-    r.diffNoReduce(a, b)
+  bench("Subtraction 2x unreduced", $doublePrec(T), iters):
+    r.diff2xUnr(a, b)
 
 proc diff2x(T: typedesc, iters: int) =
-  var r, a, b: doubleWidth(T)
+  var r, a, b: doublePrec(T)
   rng.random_unsafe(r, T)
   rng.random_unsafe(a, T)
   rng.random_unsafe(b, T)
-  bench("Substraction 2x", $doubleWidth(T), iters):
-    r.diff(a, b)
+  bench("Subtraction 2x reduced", $doublePrec(T), iters):
+    r.diff2xMod(a, b)
 
-proc mul2xBench*(rLen, aLen, bLen: static int, iters: int) =
+proc neg2x(T: typedesc, iters: int) =
+  var r, a: doublePrec(T)
+  rng.random_unsafe(a, T)
+  bench("Negation 2x reduced", $doublePrec(T), iters):
+    r.neg2xMod(a)
+
+proc prod2xBench*(rLen, aLen, bLen: static int, iters: int) =
   var r: BigInt[rLen]
   let a = rng.random_unsafe(BigInt[aLen])
   let b = rng.random_unsafe(BigInt[bLen])
-  bench("Multiplication", $rLen & " <- " & $aLen & " x " & $bLen, iters):
+  bench("Multiplication 2x", $rLen & " <- " & $aLen & " x " & $bLen, iters):
     r.prod(a, b)
 
 proc square2xBench*(rLen, aLen: static int, iters: int) =
   var r: BigInt[rLen]
   let a = rng.random_unsafe(BigInt[aLen])
-  bench("Squaring", $rLen & " <- " & $aLen & "²", iters):
+  bench("Squaring 2x", $rLen & " <- " & $aLen & "²", iters):
     r.square(a)
 
 proc reduce2x*(T: typedesc, iters: int) =
   var r: T
-  var t: doubleWidth(T)
+  var t: doublePrec(T)
   rng.random_unsafe(t, T)
-  bench("Reduce 2x-width", $T & " <- " & $doubleWidth(T), iters):
-    r.reduce(t)
+  bench("Redc 2x", $T & " <- " & $doublePrec(T), iters):
+    r.redc2x(t)
 
 proc 
main() = separator() - sumNoReduce(Fp[BLS12_381], iters = 10_000_000) - diffNoReduce(Fp[BLS12_381], iters = 10_000_000) sum(Fp[BLS12_381], iters = 10_000_000) + sumUnr(Fp[BLS12_381], iters = 10_000_000) diff(Fp[BLS12_381], iters = 10_000_000) + diffUnr(Fp[BLS12_381], iters = 10_000_000) + neg(Fp[BLS12_381], iters = 10_000_000) + separator() + sum2x(Fp[BLS12_381], iters = 10_000_000) + sum2xUnreduce(Fp[BLS12_381], iters = 10_000_000) diff2x(Fp[BLS12_381], iters = 10_000_000) - diff2xNoReduce(Fp[BLS12_381], iters = 10_000_000) - mul2xBench(768, 384, 384, iters = 10_000_000) + diff2xUnreduce(Fp[BLS12_381], iters = 10_000_000) + neg2x(Fp[BLS12_381], iters = 10_000_000) + separator() + prod2xBench(768, 384, 384, iters = 10_000_000) square2xBench(768, 384, iters = 10_000_000) reduce2x(Fp[BLS12_381], iters = 10_000_000) separator() diff --git a/benchmarks/bench_pairing_template.nim b/benchmarks/bench_pairing_template.nim index c327c8bc2..7096ae1db 100644 --- a/benchmarks/bench_pairing_template.nim +++ b/benchmarks/bench_pairing_template.nim @@ -32,15 +32,15 @@ import ./bench_blueprint export notes -proc separator*() = separator(177) +proc separator*() = separator(132) proc report(op, curve: string, startTime, stopTime: MonoTime, startClk, stopClk: int64, iters: int) = let ns = inNanoseconds((stopTime-startTime) div iters) let throughput = 1e9 / float64(ns) when SupportsGetTicks: - echo &"{op:<60} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" + echo &"{op:<40} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op {(stopClk - startClk) div iters:>9} CPU cycles (approx)" else: - echo &"{op:<60} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op" + echo &"{op:<40} {curve:<15} {throughput:>15.3f} ops/s {ns:>9} ns/op" template bench(op: string, C: static Curve, iters: int, body: untyped): untyped = measure(iters, startTime, stopTime, startClk, stopClk, body) diff --git a/constantine.nimble b/constantine.nimble index bc42e1434..687a17e93 100644 --- a/constantine.nimble +++ b/constantine.nimble @@ -43,13 +43,14 @@ const testDesc: seq[tuple[path: string, useGMP: bool]] = @[ ("tests/t_finite_fields_powinv.nim", false), ("tests/t_finite_fields_vs_gmp.nim", true), ("tests/t_fp_cubic_root.nim", false), - # Double-width finite fields + # Double-precision finite fields # ---------------------------------------------------------- - ("tests/t_finite_fields_double_width.nim", false), + ("tests/t_finite_fields_double_precision.nim", false), # Towers of extension fields # ---------------------------------------------------------- ("tests/t_fp2.nim", false), ("tests/t_fp2_sqrt.nim", false), + ("tests/t_fp4.nim", false), ("tests/t_fp6_bn254_snarks.nim", false), ("tests/t_fp6_bls12_377.nim", false), ("tests/t_fp6_bls12_381.nim", false), @@ -259,7 +260,7 @@ proc buildAllBenches() = echo "\n\n------------------------------------------------------\n" echo "Building benchmarks to ensure they stay relevant ..." 
buildBench("bench_fp")
-  buildBench("bench_fp_double_width")
+  buildBench("bench_fp_double_precision")
   buildBench("bench_fp2")
   buildBench("bench_fp6")
   buildBench("bench_fp12")
@@ -400,19 +401,19 @@ task bench_fp_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly":
   runBench("bench_fp", "clang", useAsm = false)
 
 task bench_fpdbl, "Run benchmark 𝔽pDbl with your default compiler":
-  runBench("bench_fp_double_width")
+  runBench("bench_fp_double_precision")
 
 task bench_fpdbl_gcc, "Run benchmark 𝔽p with gcc":
-  runBench("bench_fp_double_width", "gcc")
+  runBench("bench_fp_double_precision", "gcc")
 
 task bench_fpdbl_clang, "Run benchmark 𝔽p with clang":
-  runBench("bench_fp_double_width", "clang")
+  runBench("bench_fp_double_precision", "clang")
 
 task bench_fpdbl_gcc_noasm, "Run benchmark 𝔽p with gcc - no Assembly":
-  runBench("bench_fp_double_width", "gcc", useAsm = false)
+  runBench("bench_fp_double_precision", "gcc", useAsm = false)
 
 task bench_fpdbl_clang_noasm, "Run benchmark 𝔽p with clang - no Assembly":
-  runBench("bench_fp_double_width", "clang", useAsm = false)
+  runBench("bench_fp_double_precision", "clang", useAsm = false)
 
 task bench_fp2, "Run benchmark with 𝔽p2 your default compiler":
   runBench("bench_fp2")
diff --git a/constantine/arithmetic.nim b/constantine/arithmetic.nim
index 7e5339a35..540cc6e20 100644
--- a/constantine/arithmetic.nim
+++ b/constantine/arithmetic.nim
@@ -12,7 +12,7 @@ import
   finite_fields,
   finite_fields_inversion,
   finite_fields_square_root,
-  finite_fields_double_width
+  finite_fields_double_precision
 ]
 
 export
@@ -21,4 +21,4 @@ export
   finite_fields,
   finite_fields_inversion,
   finite_fields_square_root,
-  finite_fields_double_width
+  finite_fields_double_precision
diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim
new file mode 100644
index 000000000..c19d4e94b
--- /dev/null
+++ b/constantine/arithmetic/assembly/limbs_asm_modular_dbl_prec_x86.nim
@@ -0,0 +1,244 @@
+# Constantine
+# Copyright (c) 2018-2019    Status Research & Development GmbH
+# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
+# Licensed and distributed under either of
+#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
+#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
+# at your option. This file may not be copied, modified, or distributed except according to those terms.
+
+import
+  # Standard library
+  std/macros,
+  # Internal
+  ../../config/common,
+  ../../primitives
+
+# ############################################################
+#                                                            #
+#           Assembly implementation of FpDbl                 #
+#                                                            #
+# ############################################################
+
+# An FpDbl is a partially-reduced double-precision element of Fp
+# The allowed range is [0, 2ⁿp)
+# with n = w*WordBitSize
+# and w the number of words necessary to represent p on the machine.
+# Concretely a 381-bit p needs 6*64 bits limbs (hence 384 bits total)
+# and so FpDbl would be 768 bits.
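+#
+# Worked example (BLS12-381, following the numbers above):
+# p is 381-bit, w = 6 limbs of 64 bits, so n = 384 and a partially-reduced
+# value lies in [0, 2³⁸⁴·p). The product of two reduced inputs always fits,
+# since a·b < p² < 2³⁸⁴·p given p < 2³⁸⁴.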
+
+static: doAssert UseASM_X86_64
+{.localPassC:"-fomit-frame-pointer".} # Needed so that the compiler finds enough registers
+
+# Double-precision field addition
+# ------------------------------------------------------------
+
+macro addmod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped =
+  ## Generate an optimized out-of-place double-precision addition kernel
+
+  result = newStmtList()
+
+  var ctx = init(Assembler_x86, BaseType)
+  let
+    H = N div 2
+
+    r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput)
+    # We reuse the reg used for b for overflow detection
+    b = init(OperandArray, nimSymbol = B, N, PointerInReg, InputOutput)
+    # We could force m as immediate by specializing per moduli
+    M = init(OperandArray, nimSymbol = m, N, PointerInReg, Input)
+    # If N is too big, we need to spill registers. TODO.
+    u = init(OperandArray, nimSymbol = ident"U", H, ElemsInReg, InputOutput)
+    v = init(OperandArray, nimSymbol = ident"V", H, ElemsInReg, InputOutput)
+
+  let usym = u.nimSymbol
+  let vsym = v.nimSymbol
+  result.add quote do:
+    var `usym`{.noinit.}, `vsym` {.noInit.}: typeof(`A`)
+    staticFor i, 0, `H`:
+      `usym`[i] = `A`[i]
+    staticFor i, `H`, `N`:
+      `vsym`[i-`H`] = `A`[i]
+
+  # Addition
+  # u = a[0..<H] + b[0..<H], v = a[H..<N] + b[H..<N]
+  for i in 0 ..< N:
+    if i == 0:
+      ctx.add u[0], b[0]
+    elif i < H:
+      ctx.adc u[i], b[i]
+    else:
+      ctx.adc v[i-H], b[i]
+    # Interleaved: store the low half in the result,
+    # keep a copy of the unreduced high half in u
+    if i < H:
+      ctx.mov r[i], u[i]
+    else:
+      ctx.mov u[i-H], v[i-H]
+
+  # TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1"
+  let overflowed = b.reuseRegister()
+  ctx.sbb overflowed, overflowed
+
+  # Now subtract the modulus to test a < 2ⁿp
+  for i in 0 ..< H:
+    if i == 0:
+      ctx.sub v[0], M[0]
+    else:
+      ctx.sbb v[i], M[i]
+
+  # If it overflows here, it means that it was
+  # smaller than the modulus and we don't need v
+  ctx.sbb overflowed, 0
+
+  # Conditional mov
+  # and store result
+  for i in 0 ..< H:
+    ctx.cmovnc u[i], v[i]
+    ctx.mov r[i+H], u[i]
+
+  result.add ctx.generate
+
+func addmod2x_asm*[N: static int](r: var Limbs[N], a, b: Limbs[N], M: Limbs[N div 2]) =
+  ## Constant-time double-precision addition
+  ## Output is conditionally reduced by 2ⁿp
+  ## to stay in the [0, 2ⁿp) range
+  addmod2x_gen(r, a, b, M)
+
+# Double-precision field subtraction
+# ------------------------------------------------------------
+
+macro submod2x_gen[N: static int](R: var Limbs[N], A, B: Limbs[N], m: Limbs[N div 2]): untyped =
+  ## Generate an optimized out-of-place double-precision subtraction kernel
+
+  result = newStmtList()
+
+  var ctx = init(Assembler_x86, BaseType)
+  let
+    H = N div 2
+
+    r = init(OperandArray, nimSymbol = R, N, PointerInReg, InputOutput)
+    # We reuse the reg used for b for overflow detection
+    b = init(OperandArray, nimSymbol = B, N, PointerInReg, InputOutput)
+    # We could force m as immediate by specializing per moduli
+    M = init(OperandArray, nimSymbol = m, N, PointerInReg, Input)
+    # If N is too big, we need to spill registers. TODO.
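+    # U holds a's low half a[0..<H] and V its high half a[H..<N],
+    # loaded by the `quote do` block below.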
+    u = init(OperandArray, nimSymbol = ident"U", H, ElemsInReg, InputOutput)
+    v = init(OperandArray, nimSymbol = ident"V", H, ElemsInReg, InputOutput)
+
+  let usym = u.nimSymbol
+  let vsym = v.nimSymbol
+  result.add quote do:
+    var `usym`{.noinit.}, `vsym` {.noInit.}: typeof(`A`)
+    staticFor i, 0, `H`:
+      `usym`[i] = `A`[i]
+    staticFor i, `H`, `N`:
+      `vsym`[i-`H`] = `A`[i]
+
+  # Subtraction
+  # u = a[0..<H] - b[0..<H], v = a[H..<N] - b[H..<N]
+  [...]

diff --git a/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim b/constantine/arithmetic/assembly/limbs_asm_modular_x86.nim
[...]
   # TODO: unnecessary if MSB never set, i.e. "Field.getSpareBits >= 1"
   let overflowed = b.reuseRegister()
   ctx.sbb overflowed, overflowed
 
-  # Now substract the modulus
+  # Now subtract the modulus to test a < p
   for i in 0 ..< N:
     if i == 0:
       ctx.sub v[0], M[0]
     else:
@@ -81,7 +81,7 @@ macro addmod_gen[N: static int](R: var Limbs[N], A, B, m: Limbs[N]): untyped =
       ctx.sbb v[i], M[i]
 
   # If it overflows here, it means that it was
-  # smaller than the modulus and we don'u need V
+  # smaller than the modulus and we don't need V
   ctx.sbb overflowed, 0
 
   # Conditional Mov and
diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim
index 42e74e86a..8e48de73e 100644
--- a/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim
+++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86.nim
@@ -83,12 +83,12 @@ proc finalSubCanOverflow*(
 # Montgomery reduction
 # ------------------------------------------------------------
 
-macro montyRed_gen[N: static int](
+macro montyRedc2x_gen[N: static int](
        r_MR: var array[N, SecretWord],
        a_MR: array[N*2, SecretWord],
        M_MR: array[N, SecretWord],
        m0ninv_MR: BaseType,
-       canUseNoCarryMontyMul: static bool
+       spareBits: static int
   ) =
   # TODO, slower than Clang, in particular due to the shadowing
 
@@ -236,7 +236,7 @@ macro montyRed_gen[N: static int](
   let reuse = repackRegisters(t, scratch[N], scratch[N+1])
 
-  if canUseNoCarryMontyMul:
+  if spareBits >= 1:
     ctx.finalSubNoCarry(r, scratch, M, reuse)
   else:
     ctx.finalSubCanOverflow(r, scratch, M, reuse, rRAX)
@@ -249,7 +249,7 @@ func montRed_asm*[N: static int](
        a: array[N*2, SecretWord],
        M: array[N, SecretWord],
        m0ninv: BaseType,
-       canUseNoCarryMontyMul: static bool
+       spareBits: static int
   ) =
   ## Constant-time Montgomery reduction
-  montyRed_gen(r, a, M, m0ninv, canUseNoCarryMontyMul)
+  montyRedc2x_gen(r, a, M, m0ninv, spareBits)
diff --git a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim
index afc26a15c..c4e652019 100644
--- a/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim
+++ b/constantine/arithmetic/assembly/limbs_asm_montred_x86_adx_bmi2.nim
@@ -35,12 +35,12 @@ static: doAssert UseASM_X86_64
 # Montgomery reduction
 # ------------------------------------------------------------
 
-macro montyRedx_gen[N: static int](
+macro montyRedc2xx_gen[N: static int](
        r_MR: var array[N, SecretWord],
        a_MR: array[N*2, SecretWord],
        M_MR: array[N, SecretWord],
        m0ninv_MR: BaseType,
-       canUseNoCarryMontyMul: static bool
+       spareBits: static int
   ) =
   # TODO, slower than Clang, in particular due to the shadowing
 
@@ -175,7 +175,7 @@ macro montyRedx_gen[N: static int](
   let reuse = repackRegisters(t, scratch[N])
 
-  if canUseNoCarryMontyMul:
+  if spareBits >= 1:
     ctx.finalSubNoCarry(r, scratch, M, reuse)
   else:
     ctx.finalSubCanOverflow(r, scratch, M, reuse, hi)
@@ -188,7 +188,7 @@ func montRed_asm_adx_bmi2*[N: static int](
        a: array[N*2, SecretWord],
        M: array[N, SecretWord],
        m0ninv: BaseType,
-       canUseNoCarryMontyMul: static bool
+       spareBits: static int
   ) =
   ## Constant-time Montgomery reduction
-  montyRedx_gen(r, a, M, m0ninv, canUseNoCarryMontyMul)
+  
montyRedc2xx_gen(r, a, M, m0ninv, spareBits)
diff --git a/constantine/arithmetic/assembly/limbs_asm_x86.nim b/constantine/arithmetic/assembly/limbs_asm_x86.nim
index 8a73176c6..105bd2ecd 100644
--- a/constantine/arithmetic/assembly/limbs_asm_x86.nim
+++ b/constantine/arithmetic/assembly/limbs_asm_x86.nim
@@ -138,14 +138,16 @@ macro add_gen[N: static int](carry: var Carry, r: var Limbs[N], a, b: Limbs[N]):
     var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType
 
   # Algorithm
-  for i in 0 ..< N:
-    ctx.mov t0, arrA[i]
-    if i == 0:
-      ctx.add t0, arrB[0]
-    else:
-      ctx.adc t0, arrB[i]
-    ctx.mov arrR[i], t0
-    swap(t0, t1)
+  ctx.mov t0, arrA[0]     # Prologue
+  ctx.add t0, arrB[0]
+
+  for i in 1 ..< N:
+    ctx.mov t1, arrA[i]   # Prepare the next iteration
+    ctx.mov arrR[i-1], t0 # Save the previous result in an interleaved manner
+    ctx.adc t1, arrB[i]   # Compute
+    swap(t0, t1)          # Break dependency chain
+
+  ctx.mov arrR[N-1], t0   # Epilogue
   ctx.setToCarryFlag(carry)
 
   # Codegen
@@ -197,14 +199,16 @@ macro sub_gen[N: static int](borrow: var Borrow, r: var Limbs[N], a, b: Limbs[N]
     var `t0sym`{.noinit.}, `t1sym`{.noinit.}: BaseType
 
   # Algorithm
-  for i in 0 ..< N:
-    ctx.mov t0, arrA[i]
-    if i == 0:
-      ctx.sub t0, arrB[0]
-    else:
-      ctx.sbb t0, arrB[i]
-    ctx.mov arrR[i], t0
-    swap(t0, t1)
+  ctx.mov t0, arrA[0]     # Prologue
+  ctx.sub t0, arrB[0]
+
+  for i in 1 ..< N:
+    ctx.mov t1, arrA[i]   # Prepare the next iteration
+    ctx.mov arrR[i-1], t0 # Save the previous result in an interleaved manner
+    ctx.sbb t1, arrB[i]   # Compute
+    swap(t0, t1)          # Break dependency chain
+
+  ctx.mov arrR[N-1], t0   # Epilogue
   ctx.setToCarryFlag(borrow)
 
   # Codegen
diff --git a/constantine/arithmetic/bigints_montgomery.nim b/constantine/arithmetic/bigints_montgomery.nim
index 489d610b1..b28e9a630 100644
--- a/constantine/arithmetic/bigints_montgomery.nim
+++ b/constantine/arithmetic/bigints_montgomery.nim
@@ -25,7 +25,7 @@ import
 #
 # ############################################################
 
-func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) =
+func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseType, spareBits: static int) =
   ## Convert a BigInt from its natural representation
   ## to the Montgomery n-residue form
   ##
@@ -40,9 +40,9 @@ func montyResidue*(mres: var BigInt, a, N, r2modM: BigInt, m0ninv: static BaseTy
   ## - `r2modM` is R² (mod M)
   ## with W = M.len
   ## and R = (2^WordBitWidth)^W
-  montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, canUseNoCarryMontyMul)
+  montyResidue(mres.limbs, a.limbs, N.limbs, r2modM.limbs, m0ninv, spareBits)
 
-func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) =
+func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static BaseType, spareBits: static int) =
   ## Convert a BigInt from its Montgomery n-residue form
   ## to the natural representation
   ##
@@ -54,26 +54,26 @@ func redc*[mBits](r: var BigInt[mBits], a, M: BigInt[mBits], m0ninv: static Base
   var one {.noInit.}: BigInt[mBits]
   one.setOne()
-  redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, canUseNoCarryMontyMul)
+  redc(r.limbs, a.limbs, one.limbs, M.limbs, m0ninv, spareBits)
 
-func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, canUseNoCarryMontyMul: static bool) =
+func montyMul*(r: var BigInt, a, b, M: BigInt, negInvModWord: static BaseType, spareBits: static int) =
   ## Compute r <- a*b (mod M) in the Montgomery domain
   ##
   ## This resets r to zero 
before processing. Use {.noInit.} ## to avoid duplicating with Nim zero-init policy - montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, canUseNoCarryMontyMul) + montyMul(r.limbs, a.limbs, b.limbs, M.limbs, negInvModWord, spareBits) -func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, canUseNoCarryMontyMul: static bool) = +func montySquare*(r: var BigInt, a, M: BigInt, negInvModWord: static BaseType, spareBits: static int) = ## Compute r <- a^2 (mod M) in the Montgomery domain ## ## This resets r to zero before processing. Use {.noInit.} ## to avoid duplicating with Nim zero-init policy - montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, canUseNoCarryMontyMul) + montySquare(r.limbs, a.limbs, M.limbs, negInvModWord, spareBits) func montyPow*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -92,12 +92,12 @@ func montyPow*[mBits: static int]( const scratchLen = if windowSize == 1: 2 else: (1 shl windowSize) + 1 var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] - montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPow(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) func montyPowUnsafeExponent*[mBits: static int]( a: var BigInt[mBits], exponent: openarray[byte], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -116,7 +116,7 @@ func montyPowUnsafeExponent*[mBits: static int]( const scratchLen = if windowSize == 1: 2 else: (1 shl windowSize) + 1 var scratchSpace {.noInit.}: array[scratchLen, Limbs[mBits.wordsRequired]] - montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPowUnsafeExponent(a.limbs, exponent, M.limbs, one.limbs, negInvModWord, scratchSpace, spareBits) from ../io/io_bigints import exportRawUint # Workaround recursive dependencies @@ -124,7 +124,7 @@ from ../io/io_bigints import exportRawUint func montyPow*[mBits, eBits: static int]( a: var BigInt[mBits], exponent: BigInt[eBits], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -138,12 +138,12 @@ func montyPow*[mBits, eBits: static int]( var expBE {.noInit.}: array[(ebits + 7) div 8, byte] expBE.exportRawUint(exponent, bigEndian) - montyPow(a, expBE, M, one, negInvModWord, windowSize, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPow(a, expBE, M, one, negInvModWord, windowSize, spareBits) func montyPowUnsafeExponent*[mBits, eBits: static int]( a: var BigInt[mBits], exponent: BigInt[eBits], M, one: BigInt[mBits], negInvModWord: static BaseType, windowSize: static int, - canUseNoCarryMontyMul, canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Compute a <- a^exponent (mod M) ## ``a`` in the Montgomery domain @@ -161,7 +161,7 @@ func montyPowUnsafeExponent*[mBits, eBits: static int]( var expBE {.noInit.}: 
array[(ebits + 7) div 8, byte] expBE.exportRawUint(exponent, bigEndian) - montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, canUseNoCarryMontyMul, canUseNoCarryMontySquare) + montyPowUnsafeExponent(a, expBE, M, one, negInvModWord, windowSize, spareBits) {.pop.} # inline {.pop.} # raises no exceptions diff --git a/constantine/arithmetic/finite_fields.nim b/constantine/arithmetic/finite_fields.nim index 3ed4c1236..a219051fb 100644 --- a/constantine/arithmetic/finite_fields.nim +++ b/constantine/arithmetic/finite_fields.nim @@ -56,7 +56,7 @@ func fromBig*(dst: var FF, src: BigInt) = when nimvm: dst.mres.montyResidue_precompute(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord()) else: - dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + dst.mres.montyResidue(src, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} = ## Convert a BigInt to its Montgomery form @@ -65,7 +65,7 @@ func fromBig*[C: static Curve](T: type FF[C], src: BigInt): FF[C] {.noInit.} = func toBig*(src: FF): auto {.noInit, inline.} = ## Convert a finite-field element to a BigInt in natural representation var r {.noInit.}: typeof(src.mres) - r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + r.redc(src.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) return r # Copy @@ -169,7 +169,7 @@ func sum*(r: var FF, a, b: FF) {.meter.} = overflowed = overflowed or not(r.mres < FF.fieldMod()) discard csub(r.mres, FF.fieldMod(), overflowed) -func sumNoReduce*(r: var FF, a, b: FF) {.meter.} = +func sumUnr*(r: var FF, a, b: FF) {.meter.} = ## Sum ``a`` and ``b`` into ``r`` without reduction discard r.mres.sum(a.mres, b.mres) @@ -183,7 +183,7 @@ func diff*(r: var FF, a, b: FF) {.meter.} = var underflowed = r.mres.diff(a.mres, b.mres) discard cadd(r.mres, FF.fieldMod(), underflowed) -func diffNoReduce*(r: var FF, a, b: FF) {.meter.} = +func diffUnr*(r: var FF, a, b: FF) {.meter.} = ## Substract `b` from `a` and store the result into `r` ## without reduction discard r.mres.diff(a.mres, b.mres) @@ -201,11 +201,11 @@ func double*(r: var FF, a: FF) {.meter.} = func prod*(r: var FF, a, b: FF) {.meter.} = ## Store the product of ``a`` by ``b`` modulo p into ``r`` ## ``r`` is initialized / overwritten - r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + r.mres.montyMul(a.mres, b.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) func square*(r: var FF, a: FF) {.meter.} = ## Squaring modulo p - r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontySquare()) + r.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) func neg*(r: var FF, a: FF) {.meter.} = ## Negate modulo p @@ -279,8 +279,7 @@ func pow*(a: var FF, exponent: BigInt) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) func pow*(a: var FF, exponent: openarray[byte]) = @@ -292,8 +291,7 @@ func pow*(a: var FF, exponent: openarray[byte]) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) func powUnsafeExponent*(a: var FF, exponent: BigInt) = @@ -312,8 +310,7 @@ func powUnsafeExponent*(a: var FF, exponent: BigInt) = exponent, 
FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) = @@ -332,8 +329,7 @@ func powUnsafeExponent*(a: var FF, exponent: openarray[byte]) = exponent, FF.fieldMod(), FF.getMontyOne(), FF.getNegInvModWord(), windowSize, - FF.canUseNoCarryMontyMul(), - FF.canUseNoCarryMontySquare() + FF.getSpareBits() ) # ############################################################ @@ -350,7 +346,7 @@ func `*=`*(a: var FF, b: FF) {.meter.} = func square*(a: var FF) {.meter.} = ## Squaring modulo p - a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.canUseNoCarryMontySquare()) + a.mres.montySquare(a.mres, FF.fieldMod(), FF.getNegInvModWord(), FF.getSpareBits()) func square_repeated*(r: var FF, num: int) {.meter.} = ## Repeated squarings @@ -389,59 +385,57 @@ func `*=`*(a: var FF, b: static int) = elif b == 2: a.double() elif b == 3: - let t1 = a - a.double() - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + a += t elif b == 4: a.double() a.double() elif b == 5: - let t1 = a - a.double() - a.double() - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + a += t elif b == 6: - a.double() - let t2 = a - a.double() # 4 - a += t2 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + a.double(t) elif b == 7: - let t1 = a - a.double() - let t2 = a - a.double() # 4 - a += t2 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + t.double() + a.diff(t, a) elif b == 8: a.double() a.double() a.double() elif b == 9: - let t1 = a - a.double() - a.double() - a.double() # 8 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + t.double() + a.sum(t, a) elif b == 10: + var t {.noInit.}: typeof(a) + t.double(a) + t.double() + a += t # 5 a.double() - let t2 = a - a.double() - a.double() # 8 - a += t2 elif b == 11: - let t1 = a - a.double() - let t2 = a - a.double() - a.double() # 8 - a += t2 - a += t1 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + t.double() # 6 + t.double() # 12 + a.diff(t, a) # 11 elif b == 12: - a.double() - a.double() # 4 - let t4 = a - a.double() # 8 - a += t4 + var t {.noInit.}: typeof(a) + t.double(a) + t += a # 3 + t.double() # 6 + a.double(t) # 12 else: {.error: "Multiplication by this small int not implemented".} diff --git a/constantine/arithmetic/finite_fields_double_precision.nim b/constantine/arithmetic/finite_fields_double_precision.nim new file mode 100644 index 000000000..f3eba513f --- /dev/null +++ b/constantine/arithmetic/finite_fields_double_precision.nim @@ -0,0 +1,243 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. 
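+#
+# Usage sketch (illustrative only, not part of this file's API): the intended
+# pattern is several double-precision operations followed by one reduction.
+# For the real part of an 𝔽p2 product, with hypothetical names
+# a0, a1, b0, b1: Fp[BLS12_381] and r0: var Fp[BLS12_381]:
+#
+#   var d0 {.noInit.}, d1 {.noInit.}: FpDbl[BLS12_381]
+#   d0.prod2x(a0, b0)      # a0·b0, double-precision, unreduced
+#   d1.prod2x(a1, b1)      # a1·b1, double-precision, unreduced
+#   d0.diff2xMod(d0, d1)   # a0·b0 - a1·b1, kept in [0, 2ⁿp)
+#   r0.redc2x(d0)          # a single Montgomery reduction at the end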
+
+import
+  ../config/[common, curves, type_ff],
+  ../primitives,
+  ./bigints,
+  ./finite_fields,
+  ./limbs,
+  ./limbs_extmul,
+  ./limbs_montgomery
+
+when UseASM_X86_64:
+  import assembly/limbs_asm_modular_dbl_prec_x86
+
+type FpDbl*[C: static Curve] = object
+  ## Double-precision Fp element
+  ## An FpDbl is a partially-reduced double-precision element of Fp
+  ## The allowed range is [0, 2ⁿp)
+  ## with n = w*WordBitSize
+  ## and w the number of words necessary to represent p on the machine.
+  ## Concretely a 381-bit p needs 6*64 bits limbs (hence 384 bits total)
+  ## and so FpDbl would be 768 bits.
+  # We directly work with double the number of limbs,
+  # instead of BigInt indirection.
+  limbs2x*: matchingLimbs2x(C)
+
+template doublePrec*(T: type Fp): type =
+  ## Return the double-precision type matching with Fp
+  FpDbl[T.C]
+
+# No exceptions allowed
+{.push raises: [].}
+{.push inline.}
+
+func `==`*(a, b: FpDbl): SecretBool =
+  a.limbs2x == b.limbs2x
+
+func isZero*(a: FpDbl): SecretBool =
+  a.limbs2x.isZero()
+
+func setZero*(a: var FpDbl) =
+  a.limbs2x.setZero()
+
+func prod2x*(r: var FpDbl, a, b: Fp) =
+  ## Double-precision multiplication
+  ## Store the product of ``a`` by ``b`` into ``r``
+  ##
+  ## If a and b are in [0, p)
+  ## Output is in [0, p²)
+  ##
+  ## Output can be up to [0, 2ⁿp) range
+  ## provided spare bits are available in Fp representation
+  r.limbs2x.prod(a.mres.limbs, b.mres.limbs)
+
+func square2x*(r: var FpDbl, a: Fp) =
+  ## Double-precision squaring
+  ## Store the square of ``a`` into ``r``
+  ##
+  ## If a is in [0, p)
+  ## Output is in [0, p²)
+  ##
+  ## Output can be up to [0, 2ⁿp) range
+  ## provided spare bits are available in Fp representation
+  r.limbs2x.square(a.mres.limbs)
+
+func redc2x*(r: var Fp, a: FpDbl) =
+  ## Reduce a double-precision field element into r
+  ## from [0, 2ⁿp) range to [0, p) range
+  const N = r.mres.limbs.len
+  montyRedc2x(
+    r.mres.limbs,
+    a.limbs2x,
+    Fp.C.Mod.limbs,
+    Fp.getNegInvModWord(),
+    Fp.getSpareBits()
+  )
+
+func diff2xUnr*(r: var FpDbl, a, b: FpDbl) =
+  ## Double-precision subtraction without reduction
+  ##
+  ## If the result is negative, fully reduced addition/subtraction
+  ## are necessary afterwards to guarantee the [0, 2ⁿp) range
+  discard r.limbs2x.diff(a.limbs2x, b.limbs2x)
+
+func diff2xMod*(r: var FpDbl, a, b: FpDbl) =
+  ## Double-precision modular subtraction
+  ## Output is conditionally reduced by 2ⁿp
+  ## to stay in the [0, 2ⁿp) range
+  when UseASM_X86_64:
+    submod2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
+  else:
+    # Subtraction step
+    var underflowed = SecretBool r.limbs2x.diff(a.limbs2x, b.limbs2x)
+
+    # Conditional reduction by 2ⁿp
+    const N = r.limbs2x.len div 2
+    const M = FpDbl.C.Mod
+    var carry = Carry(0)
+    var sum: SecretWord
+    staticFor i, 0, N:
+      addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry)
+      underflowed.ccopy(r.limbs2x[i+N], sum)
+
+func sum2xUnr*(r: var FpDbl, a, b: FpDbl) =
+  ## Double-precision addition without reduction
+  ##
+  ## If the result is bigger than 2ⁿp, fully reduced addition/subtraction
+  ## are necessary afterwards to guarantee the [0, 2ⁿp) range
+  discard r.limbs2x.sum(a.limbs2x, b.limbs2x)
+
+func sum2xMod*(r: var FpDbl, a, b: FpDbl) =
+  ## Double-precision modular addition
+  ## Output is conditionally reduced by 2ⁿp
+  ## to stay in the [0, 2ⁿp) range
+  when UseASM_X86_64:
+    addmod2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs)
+  else:
+    # Addition step
+    var overflowed = SecretBool r.limbs2x.sum(a.limbs2x, b.limbs2x)
+
+    const N = r.limbs2x.len 
div 2
+    const M = FpDbl.C.Mod
+    # Test >= 2ⁿp
+    var borrow = Borrow(0)
+    var t{.noInit.}: Limbs[N]
+    staticFor i, 0, N:
+      subB(borrow, t[i], r.limbs2x[i+N], M.limbs[i], borrow)
+
+    # If no borrow occurred, r was at least 2ⁿp
+    overflowed = overflowed or not(SecretBool borrow)
+
+    # Conditional reduction by 2ⁿp
+    staticFor i, 0, N:
+      SecretBool(overflowed).ccopy(r.limbs2x[i+N], t[i])
+
+func neg2xMod*(r: var FpDbl, a: FpDbl) =
+  ## Double-precision modular negation
+  ## Negate modulo 2ⁿp
+  when UseASM_X86_64:
+    negmod2x_asm(r.limbs2x, a.limbs2x, FpDbl.C.Mod.limbs)
+  else:
+    # If a = 0 we need r = 0 and not r = M
+    # as comparison operators assume unicity
+    # of the modular representation.
+    # Also make sure to handle aliasing where r.addr = a.addr
+    var t {.noInit.}: FpDbl
+    let isZero = a.isZero()
+    const N = r.limbs2x.len div 2
+    const M = FpDbl.C.Mod
+    var borrow = Borrow(0)
+    # 2ⁿp is filled with 0 in the first half
+    staticFor i, 0, N:
+      subB(borrow, t.limbs2x[i], Zero, a.limbs2x[i], borrow)
+    # 2ⁿp has p (shifted) for the rest of the limbs
+    staticFor i, N, r.limbs2x.len:
+      subB(borrow, t.limbs2x[i], M.limbs[i-N], a.limbs2x[i], borrow)
+
+    # Zero the result if input was zero
+    t.limbs2x.czero(isZero)
+    r = t
+
+func prod2xImpl(
+       r {.noAlias.}: var FpDbl,
+       a {.noAlias.}: FpDbl, b: static int) =
+  ## Multiplication by a small integer known at compile-time
+  ## Requires no aliasing and b positive
+  static: doAssert b >= 0
+
+  when b == 0:
+    r.setZero()
+  elif b == 1:
+    r = a
+  elif b == 2:
+    r.sum2xMod(a, a)
+  elif b == 3:
+    r.sum2xMod(a, a)
+    r.sum2xMod(a, r)
+  elif b == 4:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r)
+  elif b == 5:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r)
+    r.sum2xMod(r, a)
+  elif b == 6:
+    r.sum2xMod(a, a)
+    let t2 = r
+    r.sum2xMod(r, r) # 4
+    r.sum2xMod(r, t2)
+  elif b == 7:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r) # 4
+    r.sum2xMod(r, r)
+    r.diff2xMod(r, a)
+  elif b == 8:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r)
+    r.sum2xMod(r, r)
+  elif b == 9:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r)
+    r.sum2xMod(r, r) # 8
+    r.sum2xMod(r, a)
+  elif b == 10:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r)
+    r.sum2xMod(r, a) # 5
+    r.sum2xMod(r, r)
+  elif b == 11:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r)
+    r.sum2xMod(r, a) # 5
+    r.sum2xMod(r, r)
+    r.sum2xMod(r, a)
+  elif b == 12:
+    r.sum2xMod(a, a)
+    r.sum2xMod(r, r) # 4
+    let t4 = r
+    r.sum2xMod(r, r) # 8
+    r.sum2xMod(r, t4)
+  else:
+    {.error: "Multiplication by this small int not implemented".}
+
+func prod2x*(r: var FpDbl, a: FpDbl, b: static int) =
+  ## Multiplication by a small integer known at compile-time
+  const negate = b < 0
+  const b = if negate: -b
+            else: b
+  when negate:
+    var t {.noInit.}: typeof(r)
+    t.neg2xMod(a)
+  else:
+    let t = a
+  prod2xImpl(r, t, b)
+
+{.pop.} # inline
+{.pop.} # raises no exceptions
diff --git a/constantine/arithmetic/finite_fields_double_width.nim b/constantine/arithmetic/finite_fields_double_width.nim
deleted file mode 100644
index a9a94dbf5..000000000
--- a/constantine/arithmetic/finite_fields_double_width.nim
+++ /dev/null
@@ -1,81 +0,0 @@
-# Constantine
-# Copyright (c) 2018-2019    Status Research & Development GmbH
-# Copyright (c) 2020-Present Mamy André-Ratsimbazafy
-# Licensed and distributed under either of
-#   * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT).
-#   * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0).
-# at your option. 
This file may not be copied, modified, or distributed except according to those terms. - -import - ../config/[common, curves, type_ff], - ../primitives, - ./bigints, - ./finite_fields, - ./limbs, - ./limbs_extmul, - ./limbs_montgomery - -when UseASM_X86_64: - import assembly/limbs_asm_modular_dbl_width_x86 - -type FpDbl*[C: static Curve] = object - ## Double-width Fp element - ## This allows saving on reductions - # We directly work with double the number of limbs - limbs2x*: matchingLimbs2x(C) - -template doubleWidth*(T: typedesc[Fp]): typedesc = - ## Return the double-width type matching with Fp - FpDbl[T.C] - -# No exceptions allowed -{.push raises: [].} -{.push inline.} - -func `==`*(a, b: FpDbl): SecretBool = - a.limbs2x == b.limbs2x - -func mulNoReduce*(r: var FpDbl, a, b: Fp) = - ## Store the product of ``a`` by ``b`` into ``r`` - r.limbs2x.prod(a.mres.limbs, b.mres.limbs) - -func squareNoReduce*(r: var FpDbl, a: Fp) = - ## Store the square of ``a`` into ``r`` - r.limbs2x.square(a.mres.limbs) - -func reduce*(r: var Fp, a: FpDbl) = - ## Reduce a double-width field element into r - const N = r.mres.limbs.len - montyRed( - r.mres.limbs, - a.limbs2x, - Fp.C.Mod.limbs, - Fp.getNegInvModWord(), - Fp.canUseNoCarryMontyMul() - ) - -func diffNoReduce*(r: var FpDbl, a, b: FpDbl) = - ## Double-width substraction without reduction - discard r.limbs2x.diff(a.limbs2x, b.limbs2x) - -func diff*(r: var FpDbl, a, b: FpDbl) = - ## Double-width modular substraction - when UseASM_X86_64: - sub2x_asm(r.limbs2x, a.limbs2x, b.limbs2x, FpDbl.C.Mod.limbs) - else: - var underflowed = SecretBool r.limbs2x.diff(a.limbs2x, b.limbs2x) - - const N = r.limbs2x.len div 2 - const M = FpDbl.C.Mod - var carry = Carry(0) - var sum: SecretWord - for i in 0 ..< N: - addC(carry, sum, r.limbs2x[i+N], M.limbs[i], carry) - underflowed.ccopy(r.limbs2x[i+N], sum) - -func `-=`*(a: var FpDbl, b: FpDbl) = - ## Double-width modular substraction - a.diff(a, b) - -{.pop.} # inline -{.pop.} # raises no exceptions diff --git a/constantine/arithmetic/limbs_extmul.nim b/constantine/arithmetic/limbs_extmul.nim index a674cb8bc..979de7a05 100644 --- a/constantine/arithmetic/limbs_extmul.nim +++ b/constantine/arithmetic/limbs_extmul.nim @@ -72,7 +72,8 @@ func prod*[rLen, aLen, bLen: static int](r: var Limbs[rLen], a: Limbs[aLen], b: ## `r` must not alias ``a`` or ``b`` when UseASM_X86_64 and aLen <= 6: - if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): mul_asm_adx_bmi2(r, a, b) else: mul_asm(r, a, b) diff --git a/constantine/arithmetic/limbs_montgomery.nim b/constantine/arithmetic/limbs_montgomery.nim index e45964c2c..ec90b46f3 100644 --- a/constantine/arithmetic/limbs_montgomery.nim +++ b/constantine/arithmetic/limbs_montgomery.nim @@ -281,12 +281,12 @@ func montySquare_CIOS(r: var Limbs, a, M: Limbs, m0ninv: BaseType) {.used.}= # Montgomery Reduction # ------------------------------------------------------------ -func montyRed_CIOS[N: static int]( +func montyRedc2x_CIOS[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], m0ninv: BaseType) = - ## Montgomery reduce a double-width bigint modulo M + ## Montgomery reduce a double-precision bigint modulo M # - Analyzing and Comparing Montgomery Multiplication Algorithms # Cetin Kaya Koc and Tolga Acar and Burton S. Kaliski Jr. 
# http://pdfs.semanticscholar.org/5e39/41ff482ec3ee41dc53c3298f0be085c69483.pdf @@ -299,7 +299,7 @@ func montyRed_CIOS[N: static int]( # Algorithm # Inputs: # - N number of limbs - # - a[0 ..< 2N] (double-width input to reduce) + # - a[0 ..< 2N] (double-precision input to reduce) # - M[0 ..< N] The field modulus (must be odd for Montgomery reduction) # - m0ninv: Montgomery Reduction magic number = -1/M[0] # Output: @@ -343,12 +343,12 @@ func montyRed_CIOS[N: static int]( discard res.csub(M, SecretWord(carry).isNonZero() or not(res < M)) r = res -func montyRed_Comba[N: static int]( +func montyRedc2x_Comba[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], m0ninv: BaseType) = - ## Montgomery reduce a double-width bigint modulo M + ## Montgomery reduce a double-precision bigint modulo M # We use Product Scanning / Comba multiplication var t, u, v = Zero var carry: Carry @@ -392,7 +392,7 @@ func montyRed_Comba[N: static int]( func montyMul*( r: var Limbs, a, b, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = + m0ninv: static BaseType, spareBits: static int) {.inline.} = ## Compute r <- a*b (mod M) in the Montgomery domain ## `m0ninv` = -1/M (mod SecretWord). Our words are 2^32 or 2^64 ## @@ -419,9 +419,10 @@ func montyMul*( # The implementation is visible from here, the compiler can make decision whether to: # - specialize/duplicate code for m0ninv == 1 (especially if only 1 curve is needed) # - keep it generic and optimize code size - when canUseNoCarryMontyMul: + when spareBits >= 1: when UseASM_X86_64 and a.len in {2 .. 6}: # TODO: handle spilling - if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): montMul_CIOS_nocarry_asm_adx_bmi2(r, a, b, M, m0ninv) else: montMul_CIOS_nocarry_asm(r, a, b, M, m0ninv) @@ -431,14 +432,14 @@ func montyMul*( montyMul_FIPS(r, a, b, M, m0ninv) func montySquare*(r: var Limbs, a, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontySquare: static bool) {.inline.} = + m0ninv: static BaseType, spareBits: static int) {.inline.} = ## Compute r <- a^2 (mod M) in the Montgomery domain ## `m0ninv` = -1/M (mod SecretWord). 
Our words are 2^31 or 2^63 # TODO: needs optimization similar to multiplication - montyMul(r, a, a, M, m0ninv, canUseNoCarryMontySquare) + montyMul(r, a, a, M, m0ninv, spareBits) - # when canUseNoCarryMontySquare: + # when spareBits >= 2: # # TODO: Deactivated # # Off-by one on 32-bit on the least significant bit # # for Fp[BLS12-381] with inputs @@ -459,26 +460,27 @@ func montySquare*(r: var Limbs, a, M: Limbs, # montyMul_FIPS(r, a, a, M, m0ninv) # TODO upstream, using Limbs[N] breaks semcheck -func montyRed*[N: static int]( +func montyRedc2x*[N: static int]( r: var array[N, SecretWord], a: array[N*2, SecretWord], M: array[N, SecretWord], - m0ninv: BaseType, canUseNoCarryMontyMul: static bool) {.inline.} = - ## Montgomery reduce a double-width bigint modulo M + m0ninv: BaseType, spareBits: static int) {.inline.} = + ## Montgomery reduce a double-precision bigint modulo M when UseASM_X86_64 and r.len <= 6: - if ({.noSideEffect.}: hasBmi2()) and ({.noSideEffect.}: hasAdx()): - montRed_asm_adx_bmi2(r, a, M, m0ninv, canUseNoCarryMontyMul) + # ADX implies BMI2 + if ({.noSideEffect.}: hasAdx()): + montRed_asm_adx_bmi2(r, a, M, m0ninv, spareBits) else: - montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) + montRed_asm(r, a, M, m0ninv, spareBits) elif UseASM_X86_32 and r.len <= 6: # TODO: Assembly faster than GCC but slower than Clang - montRed_asm(r, a, M, m0ninv, canUseNoCarryMontyMul) + montRed_asm(r, a, M, m0ninv, spareBits) else: - montyRed_CIOS(r, a, M, m0ninv) - # montyRed_Comba(r, a, M, m0ninv) + montyRedc2x_CIOS(r, a, M, m0ninv) + # montyRedc2x_Comba(r, a, M, m0ninv) func redc*(r: var Limbs, a, one, M: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = + m0ninv: static BaseType, spareBits: static int) = ## Transform a bigint ``a`` from it's Montgomery N-residue representation (mod N) ## to the regular natural representation (mod N) ## @@ -497,10 +499,10 @@ func redc*(r: var Limbs, a, one, M: Limbs, # - http://langevin.univ-tln.fr/cours/MLC/extra/montgomery.pdf # Montgomery original paper # - montyMul(r, a, one, M, m0ninv, canUseNoCarryMontyMul) + montyMul(r, a, one, M, m0ninv, spareBits) func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, - m0ninv: static BaseType, canUseNoCarryMontyMul: static bool) = + m0ninv: static BaseType, spareBits: static int) = ## Transform a bigint ``a`` from it's natural representation (mod N) ## to a the Montgomery n-residue representation ## @@ -518,7 +520,7 @@ func montyResidue*(r: var Limbs, a, M, r2modM: Limbs, ## Important: `r` is overwritten ## The result `r` buffer size MUST be at least the size of `M` buffer # Reference: https://eprint.iacr.org/2017/1057.pdf - montyMul(r, a, r2ModM, M, m0ninv, canUseNoCarryMontyMul) + montyMul(r, a, r2ModM, M, m0ninv, spareBits) # Montgomery Modular Exponentiation # ------------------------------------------ @@ -565,7 +567,7 @@ func montyPowPrologue( a: var Limbs, M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], - canUseNoCarryMontyMul: static bool + spareBits: static int ): uint = ## Setup the scratchspace ## Returns the fixed-window size for exponentiation with window optimization. 
@@ -579,7 +581,7 @@ func montyPowPrologue( else: scratchspace[2] = a for k in 2 ..< 1 shl result: - scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, canUseNoCarryMontyMul) + scratchspace[k+1].montyMul(scratchspace[k], a, M, m0ninv, spareBits) # Set a to one a = one @@ -593,7 +595,7 @@ func montyPowSquarings( window: uint, acc, acc_len: var uint, e: var int, - canUseNoCarryMontySquare: static bool + spareBits: static int ): tuple[k, bits: uint] {.inline.}= ## Squaring step of exponentiation by squaring ## Get the next k bits in range [1, window) @@ -629,7 +631,7 @@ func montyPowSquarings( # We have k bits and can do k squaring for i in 0 ..< k: - tmp.montySquare(a, M, m0ninv, canUseNoCarryMontySquare) + tmp.montySquare(a, M, m0ninv, spareBits) a = tmp return (k, bits) @@ -640,8 +642,7 @@ func montyPow*( M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], - canUseNoCarryMontyMul: static bool, - canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Modular exponentiation r = a^exponent mod M ## in the Montgomery domain @@ -669,7 +670,7 @@ func montyPow*( ## A window of size 5 requires (2^5 + 1)*(381 + 7)/8 = 33 * 48 bytes = 1584 bytes ## of scratchspace (on the stack). - let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, canUseNoCarryMontyMul) + let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits) # We process bits with from most to least significant. # At each loop iteration with have acc_len bits in acc. @@ -684,7 +685,7 @@ func montyPow*( a, exponent, M, m0ninv, scratchspace[0], window, acc, acc_len, e, - canUseNoCarryMontySquare + spareBits ) # Window lookup: we set scratchspace[1] to the lookup value. @@ -699,7 +700,7 @@ func montyPow*( # Multiply with the looked-up value # we keep the product only if the exponent bits are not all zeroes - scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, canUseNoCarryMontyMul) + scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits) a.ccopy(scratchspace[0], SecretWord(bits).isNonZero()) func montyPowUnsafeExponent*( @@ -708,8 +709,7 @@ func montyPowUnsafeExponent*( M, one: Limbs, m0ninv: static BaseType, scratchspace: var openarray[Limbs], - canUseNoCarryMontyMul: static bool, - canUseNoCarryMontySquare: static bool + spareBits: static int ) = ## Modular exponentiation r = a^exponent mod M ## in the Montgomery domain @@ -723,7 +723,7 @@ func montyPowUnsafeExponent*( # TODO: scratchspace[1] is unused when window > 1 - let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, canUseNoCarryMontyMul) + let window = montyPowPrologue(a, M, one, m0ninv, scratchspace, spareBits) var acc, acc_len: uint @@ -733,16 +733,16 @@ func montyPowUnsafeExponent*( a, exponent, M, m0ninv, scratchspace[0], window, acc, acc_len, e, - canUseNoCarryMontySquare + spareBits ) ## Warning ⚠️: Exposes the exponent bits if bits != 0: if window > 1: - scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, canUseNoCarryMontyMul) + scratchspace[0].montyMul(a, scratchspace[1+bits], M, m0ninv, spareBits) else: # scratchspace[1] holds the original `a` - scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, canUseNoCarryMontyMul) + scratchspace[0].montyMul(a, scratchspace[1], M, m0ninv, spareBits) a = scratchspace[0] {.pop.} # raises no exceptions diff --git a/constantine/config/curves_derived.nim b/constantine/config/curves_derived.nim index 87530a76e..21ffa39d3 100644 --- a/constantine/config/curves_derived.nim +++ b/constantine/config/curves_derived.nim @@ -51,18 +51,10 @@ 
macro genDerivedConstants*(mode: static DerivedConstantMode): untyped = let M = if mode == kModulus: bindSym(curve & "_Modulus") else: bindSym(curve & "_Order") - # const MyCurve_CanUseNoCarryMontyMul = useNoCarryMontyMul(MyCurve_Modulus) + # const MyCurve_SpareBits = countSpareBits(MyCurve_Modulus) result.add newConstStmt( - used(curve & ff & "_CanUseNoCarryMontyMul"), newCall( - bindSym"useNoCarryMontyMul", - M - ) - ) - - # const MyCurve_CanUseNoCarryMontySquare = useNoCarryMontySquare(MyCurve_Modulus) - result.add newConstStmt( - used(curve & ff & "_CanUseNoCarryMontySquare"), newCall( - bindSym"useNoCarryMontySquare", + used(curve & ff & "_SpareBits"), newCall( + bindSym"countSpareBits", M ) ) diff --git a/constantine/config/curves_prop_derived.nim b/constantine/config/curves_prop_derived.nim index 2439f4ebf..1bac65f77 100644 --- a/constantine/config/curves_prop_derived.nim +++ b/constantine/config/curves_prop_derived.nim @@ -61,15 +61,18 @@ template fieldMod*(Field: type FF): auto = else: Field.C.getCurveOrder() -macro canUseNoCarryMontyMul*(ff: type FF): untyped = - ## Returns true if the Modulus is compatible with a fast - ## Montgomery multiplication that avoids many carries - result = bindConstant(ff, "CanUseNoCarryMontyMul") - -macro canUseNoCarryMontySquare*(ff: type FF): untyped = - ## Returns true if the Modulus is compatible with a fast - ## Montgomery squaring that avoids many carries - result = bindConstant(ff, "CanUseNoCarryMontySquare") +macro getSpareBits*(ff: type FF): untyped = + ## Returns the number of extra bits + ## in the modulus M representation. + ## + ## This is used for no-carry operations + ## or lazily reduced operations by allowing + ## output in range: + ## - [0, 2p) if 1 bit is available + ## - [0, 4p) if 2 bits are available + ## - [0, 8p) if 3 bits are available + ## - ... + result = bindConstant(ff, "SpareBits") macro getR2modP*(ff: type FF): untyped = ## Get the Montgomery "R^2 mod P" constant associated to a curve field modulus diff --git a/constantine/config/precompute.nim b/constantine/config/precompute.nim index a5cf9f5d2..7a478bcfd 100644 --- a/constantine/config/precompute.nim +++ b/constantine/config/precompute.nim @@ -238,21 +238,20 @@ func checkValidModulus(M: BigInt) = doAssert msb == expectedMsb, "Internal Error: the modulus must use all declared bits and only those" -func useNoCarryMontyMul*(M: BigInt): bool = - ## Returns if the modulus is compatible - ## with the no-carry Montgomery Multiplication - ## from https://hackmd.io/@zkteam/modular_multiplication - # Indirection needed because static object are buggy - # https://github.com/nim-lang/Nim/issues/9679 - BaseType(M.limbs[^1]) < high(BaseType) shr 1 - -func useNoCarryMontySquare*(M: BigInt): bool = - ## Returns if the modulus is compatible - ## with the no-carry Montgomery Squaring - ## from https://hackmd.io/@zkteam/modular_multiplication - # Indirection needed because static object are buggy - # https://github.com/nim-lang/Nim/issues/9679 - BaseType(M.limbs[^1]) < high(BaseType) shr 2 +func countSpareBits*(M: BigInt): int = + ## Count the number of extra bits + ## in the modulus M representation. + ## + ## This is used for no-carry operations + ## or lazily reduced operations by allowing + ## output in range: + ## - [0, 2p) if 1 bit is available + ## - [0, 4p) if 2 bits are available + ## - [0, 8p) if 3 bits are available + ## - ... 
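+  ##
+  ## Example: a 381-bit modulus stored in 6×64-bit words has the
+  ## most significant bit of its top limb at position 60,
+  ## leaving WordBitWidth - 1 - 60 = 3 spare bits.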
+  checkValidModulus(M)
+  let msb = log2(BaseType(M.limbs[^1]))
+  result = WordBitWidth - 1 - msb.int
 
 func invModBitwidth[T: SomeUnsignedInt](a: T): T =
   # We use BaseType for return value because static distinct type
diff --git a/constantine/tower_field_extensions/README.md b/constantine/tower_field_extensions/README.md
index baa5463fd..7b5927b84 100644
--- a/constantine/tower_field_extensions/README.md
+++ b/constantine/tower_field_extensions/README.md
@@ -87,6 +87,16 @@ From Ben Edgington, https://hackmd.io/@benjaminion/bls12-381
   Jean-Luc Beuchat and Jorge Enrique González Díaz and Shigeo Mitsunari and Eiji Okamoto and Francisco Rodríguez-Henríquez and Tadanori Teruya, 2010\
   https://eprint.iacr.org/2010/354
 
+- Faster Explicit Formulas for Computing Pairings over Ordinary Curves\
+  Diego F. Aranha and Koray Karabina and Patrick Longa and Catherine H. Gebotys and Julio López, 2010\
+  https://eprint.iacr.org/2010/526.pdf\
+  https://www.iacr.org/archive/eurocrypt2011/66320047/66320047.pdf
+
+- Efficient Implementation of Bilinear Pairings on ARM Processors\
+  Gurleen Grewal, Reza Azarderakhsh,\
+  Patrick Longa, Shi Hu, and David Jao, 2012\
+  https://eprint.iacr.org/2012/408.pdf
+
 - Choosing and generating parameters for low level pairing implementation on BN curves\
   Sylvain Duquesne and Nadia El Mrabet and Safia Haloui and Franck Rondepierre, 2015\
   https://eprint.iacr.org/2015/1212
diff --git a/constantine/tower_field_extensions/extension_fields.nim b/constantine/tower_field_extensions/extension_fields.nim
index edd2acd58..1d7dafcf0 100644
--- a/constantine/tower_field_extensions/extension_fields.nim
+++ b/constantine/tower_field_extensions/extension_fields.nim
@@ -47,11 +47,11 @@ template c1*(a: ExtensionField): auto =
 template c2*(a: CubicExt): auto =
   a.coords[2]
 
-template `c0=`*(a: ExtensionField, v: auto) =
+template `c0=`*(a: var ExtensionField, v: auto) =
   a.coords[0] = v
-template `c1=`*(a: ExtensionField, v: auto) =
+template `c1=`*(a: var ExtensionField, v: auto) =
   a.coords[1] = v
-template `c2=`*(a: CubicExt, v: auto) =
+template `c2=`*(a: var CubicExt, v: auto) =
   a.coords[2] = v
 
 template C*(E: type ExtensionField): Curve =
@@ -222,88 +222,343 @@ func csub*(a: var ExtensionField, b: ExtensionField, ctl: SecretBool) =
 
 func `*=`*(a: var ExtensionField, b: static int) =
   ## Multiplication by a small integer known at compile-time
-
-  const negate = b < 0
-  const b = if negate: -b
-            else: b
-  when negate:
-    a.neg(a)
-  when b == 0:
-    a.setZero()
-  elif b == 1:
-    return
-  elif b == 2:
-    a.double()
-  elif b == 3:
-    let t1 = a
-    a.double()
-    a += t1
-  elif b == 4:
-    a.double()
-    a.double()
-  elif b == 5:
-    let t1 = a
-    a.double()
-    a.double()
-    a += t1
-  elif b == 6:
-    a.double()
-    let t2 = a
-    a.double() # 4
-    a += t2
-  elif b == 7:
-    let t1 = a
-    a.double()
-    let t2 = a
-    a.double() # 4
-    a += t2
-    a += t1
-  elif b == 8:
-    a.double()
-    a.double()
-    a.double()
-  elif b == 9:
-    let t1 = a
-    a.double()
-    a.double()
-    a.double() # 8
-    a += t1
-  elif b == 10:
-    a.double()
-    let t2 = a
-    a.double()
-    a.double() # 8
-    a += t2
-  elif b == 11:
-    let t1 = a
-    a.double()
-    let t2 = a
-    a.double()
-    a.double() # 8
-    a += t2
-    a += t1
-  elif b == 12:
-    a.double()
-    a.double() # 4
-    let t4 = a
-    a.double() # 8
-    a += t4
-  else:
-    {.error: "Multiplication by this small int not implemented".}
+  for i in 0 ..< a.coords.len:
+    a.coords[i] *= b
 
 func prod*(r: var ExtensionField, a: ExtensionField, b: static int) =
   ## Multiplication by a small integer known at compile-time
-  const negate = b < 0
-  const b = if 
negate: -b
-            else: b
-  when negate:
-    r.neg(a)
-  else:
-    r = a
+  r = a
   r *= b
 
 {.pop.} # inline
 
+# ############################################################
+#                                                            #
+#              Lazy reduced extension fields                 #
+#                                                            #
+# ############################################################
+
+type
+  QuadraticExt2x[F] = object
+    ## Quadratic Extension field for lazy reduced fields
+    coords: array[2, F]
+
+  CubicExt2x[F] = object
+    ## Cubic Extension field for lazy reduced fields
+    coords: array[3, F]
+
+  ExtensionField2x[F] = QuadraticExt2x[F] or CubicExt2x[F]
+
+template doublePrec(T: type ExtensionField): type =
+  # For now naive unrolling, recursive templates don't match
+  # and I don't want to deal with types in macros
+  when T is QuadraticExt:
+    when T.F is QuadraticExt: # Fp4Dbl
+      QuadraticExt2x[QuadraticExt2x[doublePrec(T.F.F)]]
+    elif T.F is Fp: # Fp2Dbl
+      QuadraticExt2x[doublePrec(T.F)]
+  elif T is CubicExt:
+    when T.F is QuadraticExt: # Fp6Dbl
+      CubicExt2x[QuadraticExt2x[doublePrec(T.F.F)]]
+
+func has1extraBit(F: type Fp): bool =
+  ## We construct extensions only on Fp (and not Fr)
+  getSpareBits(F) >= 1
+
+func has2extraBits(F: type Fp): bool =
+  ## We construct extensions only on Fp (and not Fr)
+  getSpareBits(F) >= 2
+
+func has1extraBit(E: type ExtensionField): bool =
+  ## We construct extensions only on Fp (and not Fr)
+  getSpareBits(Fp[E.F.C]) >= 1
+
+func has2extraBits(E: type ExtensionField): bool =
+  ## We construct extensions only on Fp (and not Fr)
+  getSpareBits(Fp[E.F.C]) >= 2
+
+template C(E: type ExtensionField2x): Curve =
+  E.F.C
+
+template c0(a: ExtensionField2x): auto =
+  a.coords[0]
+template c1(a: ExtensionField2x): auto =
+  a.coords[1]
+template c2(a: CubicExt2x): auto =
+  a.coords[2]
+
+template `c0=`(a: var ExtensionField2x, v: auto) =
+  a.coords[0] = v
+template `c1=`(a: var ExtensionField2x, v: auto) =
+  a.coords[1] = v
+template `c2=`(a: var CubicExt2x, v: auto) =
+  a.coords[2] = v
+
+# Initialization
+# -------------------------------------------------------------------
+
+func setZero*(a: var ExtensionField2x) =
+  ## Set ``a`` to 0 in the extension field
+  staticFor i, 0, a.coords.len:
+    a.coords[i].setZero()
+
+# Abelian group
+# -------------------------------------------------------------------
+
+func sumUnr(r: var ExtensionField, a, b: ExtensionField) =
+  ## Sum ``a`` and ``b`` into ``r``
+  staticFor i, 0, a.coords.len:
+    r.coords[i].sumUnr(a.coords[i], b.coords[i])
+
+func diff2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) =
+  ## Double-precision substraction without reduction
+  staticFor i, 0, a.coords.len:
+    r.coords[i].diff2xUnr(a.coords[i], b.coords[i])
+
+func diff2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) =
+  ## Double-precision modular substraction
+  staticFor i, 0, a.coords.len:
+    r.coords[i].diff2xMod(a.coords[i], b.coords[i])
+
+func sum2xUnr(r: var ExtensionField2x, a, b: ExtensionField2x) =
+  ## Double-precision addition without reduction
+  staticFor i, 0, a.coords.len:
+    r.coords[i].sum2xUnr(a.coords[i], b.coords[i])
+
+func sum2xMod(r: var ExtensionField2x, a, b: ExtensionField2x) =
+  ## Double-precision modular addition
+  staticFor i, 0, a.coords.len:
+    r.coords[i].sum2xMod(a.coords[i], b.coords[i])
+
+func neg2xMod(r: var ExtensionField2x, a: ExtensionField2x) =
+  ## Double-precision modular negation
+  staticFor i, 0, a.coords.len:
+    r.coords[i].neg2xMod(a.coords[i])
+
+# Reductions
+# -------------------------------------------------------------------
+
+func redc2x(r: var ExtensionField, a: ExtensionField2x) =
+  ## 
Reduction + staticFor i, 0, a.coords.len: + r.coords[i].redc2x(a.coords[i]) + +# Multiplication by a small integer known at compile-time +# ------------------------------------------------------------------- + +func prod2x(r: var ExtensionField2x, a: ExtensionField2x, b: static int) = + ## Multiplication by a small integer known at compile-time + for i in 0 ..< a.coords.len: + r.coords[i].prod2x(a.coords[i], b) + +# NonResidue +# ---------------------------------------------------------------------- + +func prod2x(r: var FpDbl, a: FpDbl, _: type NonResidue){.inline.} = + ## Multiply an element of 𝔽p by the quadratic non-residue + ## chosen to construct 𝔽p2 + static: doAssert FpDbl.C.getNonResidueFp() != -1, "𝔽p2 should be specialized for complex extension" + r.prod2x(a, FpDbl.C.getNonResidueFp()) + +func prod2x[C: static Curve]( + r {.noalias.}: var QuadraticExt2x[FpDbl[C]], + a {.noalias.}: QuadraticExt2x[FpDbl[C]], + _: type NonResidue) {.inline.} = + ## Multiplication by non-residue + ## ! no aliasing! + const complex = C.getNonResidueFp() == -1 + const U = C.getNonResidueFp2()[0] + const V = C.getNonResidueFp2()[1] + const Beta {.used.} = C.getNonResidueFp() + + when complex and U == 1 and V == 1: + r.c0.diff2xMod(a.c0, a.c1) + r.c1.sum2xMod(a.c0, a.c1) + else: + # Case: + # - BN254_Snarks, QNR_Fp: -1, SNR_Fp2: 9+1𝑖 (𝑖 = √-1) + # - BLS12_377, QNR_Fp: -5, SNR_Fp2: 0+1j (j = √-5) + # - BW6_761, SNR_Fp: -4, CNR_Fp2: 0+1j (j = √-4) + when U == 0: + # mul_sparse_by_0v + # r0 = β a1 v + # r1 = a0 v + # r and a don't alias, we use `r` as a temp location + r.c1.prod2x(a.c1, V) + r.c0.prod2x(r.c1, NonResidue) + r.c1.prod2x(a.c0, V) + else: + # ξ = u + v x + # and x² = β + # + # (c0 + c1 x) (u + v x) => u c0 + (u c0 + u c1)x + v c1 x² + # => u c0 + β v c1 + (v c0 + u c1) x + var t {.noInit.}: FpDbl[C] + + r.c0.prod2x(a.c0, U) + when V == 1 and Beta == -1: # Case BN254_Snarks + r.c0.diff2xMod(r.c0, a.c1) # r0 = u c0 + β v c1 + else: + {.error: "Unimplemented".} + + r.c1.prod2x(a.c0, V) + t.prod2x(a.c1, U) + r.c1.sum2xMod(r.c1, t) # r1 = v c0 + u c1 + +# ############################################################ +# # +# Quadratic extensions - Lazy Reductions # +# # +# ############################################################ + +# Forward declarations +# ---------------------------------------------------------------------- + +func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) +func square2x(r: var QuadraticExt2x, a: QuadraticExt) + +# Commutative ring implementation for complex quadratic extension fields +# ---------------------------------------------------------------------- + +func prod2x_complex(r: var QuadraticExt2x, a, b: QuadraticExt) = + ## Double-precision unreduced complex multiplication + # r and a or b cannot alias + + mixin fromComplexExtension + static: doAssert a.fromComplexExtension() + + var D {.noInit.}: typeof(r.c0) + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) + + r.c0.prod2x(a.c0, b.c0) # r0 = a0 b0 + D.prod2x(a.c1, b.c1) # d = a1 b1 + when QuadraticExt.has1extraBit(): + t0.sumUnr(a.c0, a.c1) + t1.sumUnr(b.c0, b.c1) + else: + t0.sum(a.c0, a.c1) + t1.sum(b.c0, b.c1) + r.c1.prod2x(t0, t1) # r1 = (b0 + b1)(a0 + a1) + when QuadraticExt.has1extraBit(): + r.c1.diff2xUnr(r.c1, r.c0) # r1 = (b0 + b1)(a0 + a1) - a0 b0 + r.c1.diff2xUnr(r.c1, D) # r1 = (b0 + b1)(a0 + a1) - a0 b0 - a1b1 + else: + r.c1.diff2xMod(r.c1, r.c0) + r.c1.diff2xMod(r.c1, D) + r.c0.diff2xMod(r.c0, D) # r0 = a0 b0 - a1 b1 + +func square2x_complex(r: var QuadraticExt2x, a: QuadraticExt) = + ## 
Double-precision unreduced complex squaring + + mixin fromComplexExtension + static: doAssert a.fromComplexExtension() + + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) + + # Require 2 extra bits + when QuadraticExt.has2extraBits(): + t0.sumUnr(a.c1, a.c1) + t1.sum(a.c0, a.c1) + else: + t0.double(a.c1) + t1.sum(a.c0, a.c1) + + r.c1.prod2x(t0, a.c0) # r1 = 2a0a1 + t0.diff(a.c0, a.c1) + r.c0.prod2x(t0, t1) # r0 = (a0 + a1)(a0 - a1) + +# Commutative ring implementation for generic quadratic extension fields +# ---------------------------------------------------------------------- +# +# Some sparse functions, reconstruct a Fp4 from disjoint pieces +# to limit copies, we provide versions with disjoint elements +# prod2x_disjoint: +# - 2 products in mul_sparse_by_line_xyz000 (Fp4) +# - 2 products in mul_sparse_by_line_xy000z (Fp4) +# - mul_by_line_xy0 in mul_sparse_by_line_xy00z0 (Fp6) +# +# square2x_disjoint: +# - cyclotomic square in Fp2 -> Fp6 -> Fp12 towering +# needs Fp4 as special case + +func prod2x_disjoint[Fdbl, F]( + r: var QuadraticExt2x[FDbl], + a: QuadraticExt[F], + b0, b1: F) = + ## Return a * (b0, b1) in r + static: doAssert Fdbl is doublePrec(F) + + var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-precision + var t0 {.noInit.}, t1 {.noInit.}: typeof(a.c0) # Single-width + + # Require 2 extra bits + V0.prod2x(a.c0, b0) # v0 = a0b0 + V1.prod2x(a.c1, b1) # v1 = a1b1 + when F.has1extraBit(): + t0.sumUnr(a.c0, a.c1) + t1.sumUnr(b0, b1) + else: + t0.sum(a.c0, a.c1) + t1.sum(b0, b1) + + r.c1.prod2x(t0, t1) # r1 = (a0 + a1)(b0 + b1) + when F.has1extraBit(): + r.c1.diff2xMod(r.c1, V0) + r.c1.diff2xMod(r.c1, V1) + else: + r.c1.diff2xMod(r.c1, V0) # r1 = (a0 + a1)(b0 + b1) - a0b0 + r.c1.diff2xMod(r.c1, V1) # r1 = (a0 + a1)(b0 + b1) - a0b0 - a1b1 + + r.c0.prod2x(V1, NonResidue) # r0 = β a1 b1 + r.c0.sum2xMod(r.c0, V0) # r0 = a0 b0 + β a1 b1 + +func square2x_disjoint[Fdbl, F]( + r: var QuadraticExt2x[FDbl], + a0, a1: F) = + ## Return (a0, a1)² in r + var V0 {.noInit.}, V1 {.noInit.}: typeof(r.c0) # Double-precision + var t {.noInit.}: F # Single-width + + # TODO: which is the best formulation? 3 squarings or 2 Mul? + # It seems like the higher the tower the better squarings are + # So for Fp12 = 2xFp6, prefer squarings. 
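+  # Below (option 2 for r1): V0 = a0², V1 = a1², t = a0 + a1, then
+  #   r0 = β V1 + V0
+  #   r1 = t² - V0 - V1 = 2 a0 a1
+  # i.e. 3 squarings and 1 multiplication by the non-residue in total.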
+ V0.square2x(a0) + V1.square2x(a1) + t.sum(a0, a1) + + # r0 = a0² + β a1² (option 1) <=> (a0 + a1)(a0 + β a1) - β a0a1 - a0a1 (option 2) + r.c0.prod2x(V1, NonResidue) + r.c0.sum2xMod(r.c0, V0) + + # r1 = 2 a0 a1 (option 1) = (a0 + a1)² - a0² - a1² (option 2) + r.c1.square2x(t) + r.c1.diff2xMod(r.c1, V0) + r.c1.diff2xMod(r.c1, V1) + +# Dispatch +# ---------------------------------------------------------------------- + +func prod2x(r: var QuadraticExt2x, a, b: QuadraticExt) = + mixin fromComplexExtension + when a.fromComplexExtension(): + r.prod2x_complex(a, b) + else: + r.prod2x_disjoint(a, b.c0, b.c1) + +func square2x(r: var QuadraticExt2x, a: QuadraticExt) = + mixin fromComplexExtension + when a.fromComplexExtension(): + r.square2x_complex(a) + else: + r.square2x_disjoint(a.c0, a.c1) + +# ############################################################ +# # +# Cubic extensions - Lazy Reductions # +# # +# ############################################################ + + # ############################################################ # # # Quadratic extensions # @@ -386,60 +641,18 @@ func prod_complex(r: var QuadraticExt, a, b: QuadraticExt) = mixin fromComplexExtension static: doAssert r.fromComplexExtension() - # TODO: GCC is adding an unexplainable 30 cycles tax to this function (~10% slow down) - # for seemingly no reason - - when false: # Single-width implementation - BLS12-381 - # Clang 348 cycles on i9-9980XE @3.9 GHz - var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0) - a0b0.prod(a.c0, b.c0) # [1 Mul] - a1b1.prod(a.c1, b.c1) # [2 Mul] - - r.c0.sum(a.c0, a.c1) # r0 = (a0 + a1) # [2 Mul, 1 Add] - r.c1.sum(b.c0, b.c1) # r1 = (b0 + b1) # [2 Mul, 2 Add] - # aliasing: a and b unneeded now - r.c1 *= r.c0 # r1 = (b0 + b1)(a0 + a1) # [3 Mul, 2 Add] - 𝔽p temporary - - r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub] - r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub] - r.c1 -= a1b1 # r1 = (b0 + b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub] - - else: # Double-width implementation with lazy reduction - # Clang 341 cycles on i9-9980XE @3.9 GHz - var a0b0 {.noInit.}, a1b1 {.noInit.}: doubleWidth(typeof(r.c0)) - var d {.noInit.}: doubleWidth(typeof(r.c0)) - const msbSet = r.c0.typeof.canUseNoCarryMontyMul() - - a0b0.mulNoReduce(a.c0, b.c0) # 44 cycles - cumul 44 - a1b1.mulNoReduce(a.c1, b.c1) # 44 cycles - cumul 88 - when msbSet: - r.c0.sum(a.c0, a.c1) - r.c1.sum(b.c0, b.c1) - else: - r.c0.sumNoReduce(a.c0, a.c1) # 5 cycles - cumul 93 - r.c1.sumNoReduce(b.c0, b.c1) # 5 cycles - cumul 98 - # aliasing: a and b unneeded now - d.mulNoReduce(r.c0, r.c1) # 44 cycles - cumul 142 - when msbSet: - d -= a0b0 - d -= a1b1 - else: - d.diffNoReduce(d, a0b0) # 11 cycles - cumul 153 - d.diffNoReduce(d, a1b1) # 11 cycles - cumul 164 - a0b0.diff(a0b0, a1b1) # 19 cycles - cumul 183 - r.c0.reduce(a0b0) # 50 cycles - cumul 233 - r.c1.reduce(d) # 50 cycles - cumul 288 - - # Single-width [3 Mul, 2 Add, 3 Sub] - # 3*88 + 2*14 + 3*14 = 334 theoretical cycles - # 348 measured - # Double-Width - # 288 theoretical cycles - # 329 measured - # Unexplained 40 cycles diff between theo and measured - # and unexplained 30 cycles between Clang and GCC - # - Function calls? - # - push/pop stack? 
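+  # Karatsuba-style complex multiplication (i² = -1):
+  # (a0 + i a1)(b0 + i b1) = (a0b0 - a1b1) + i ((a0+a1)(b0+b1) - a0b0 - a1b1)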
+ var a0b0 {.noInit.}, a1b1 {.noInit.}: typeof(r.c0) + a0b0.prod(a.c0, b.c0) # [1 Mul] + a1b1.prod(a.c1, b.c1) # [2 Mul] + + r.c0.sum(a.c0, a.c1) # r0 = (a0 + a1) # [2 Mul, 1 Add] + r.c1.sum(b.c0, b.c1) # r1 = (b0 + b1) # [2 Mul, 2 Add] + # aliasing: a and b unneeded now + r.c1 *= r.c0 # r1 = (b0 + b1)(a0 + a1) # [3 Mul, 2 Add] - 𝔽p temporary + + r.c0.diff(a0b0, a1b1) # r0 = a0 b0 - a1 b1 # [3 Mul, 2 Add, 1 Sub] + r.c1 -= a0b0 # r1 = (b0 + b1)(a0 + a1) - a0b0 # [3 Mul, 2 Add, 2 Sub] + r.c1 -= a1b1 # r1 = (b0 + b1)(a0 + a1) - a0b0 - a1b1 # [3 Mul, 2 Add, 3 Sub] func mul_sparse_complex_by_0y( r: var QuadraticExt, a: QuadraticExt, @@ -497,31 +710,67 @@ func square_generic(r: var QuadraticExt, a: QuadraticExt) = # # Alternative 2: # c0² + β c1² <=> (c0 + c1)(c0 + β c1) - β c0c1 - c0c1 - mixin prod - var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) + # + # This gives us 2 Mul and 2 mul-nonresidue (which is costly for BN254_Snarks) + # + # We can also reframe the 2nd term with only squarings + # which might be significantly faster on higher tower degrees + # + # 2 c0 c1 <=> (a0 + a1)² - a0² - a1² + # + # This gives us 3 Sqr and 1 Mul-non-residue + const costlyMul = block: + # No shortcutting in the VM :/ + when a.c0 is ExtensionField: + when a.c0.c0 is ExtensionField: + true + else: + false + else: + false - # v1 <- (c0 + β c1) - v1.prod(a.c1, NonResidue) - v1 += a.c0 + when QuadraticExt.C == BN254_Snarks or costlyMul: + var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) + v0.square(a.c0) + v1.square(a.c1) - # v0 <- (c0 + c1)(c0 + β c1) - v0.sum(a.c0, a.c1) - v0 *= v1 + # Aliasing: a unneeded now + r.c1.sum(a.c0, a.c1) - # v1 <- c0 c1 - v1.prod(a.c0, a.c1) + # r0 = c0² + β c1² + r.c0.prod(v1, NonResidue) + r.c0 += v0 - # aliasing: a unneeded now + # r1 = (a0 + a1)² - a0² - a1² + r.c1.square() + r.c1 -= v0 + r.c1 -= v1 - # r0 = (c0 + c1)(c0 + β c1) - c0c1 - v0 -= v1 + else: + var v0 {.noInit.}, v1 {.noInit.}: typeof(r.c0) - # r1 = 2 c0c1 - r.c1.double(v1) + # v1 <- (c0 + β c1) + v1.prod(a.c1, NonResidue) + v1 += a.c0 - # r0 = (c0 + c1)(c0 + β c1) - c0c1 - β c0c1 - v1 *= NonResidue - r.c0.diff(v0, v1) + # v0 <- (c0 + c1)(c0 + β c1) + v0.sum(a.c0, a.c1) + v0 *= v1 + + # v1 <- c0 c1 + v1.prod(a.c0, a.c1) + + # aliasing: a unneeded now + + # r0 = (c0 + c1)(c0 + β c1) - c0c1 + v0 -= v1 + + # r1 = 2 c0c1 + r.c1.double(v1) + + # r0 = (c0 + c1)(c0 + β c1) - c0c1 - β c0c1 + v1 *= NonResidue + r.c0.diff(v0, v1) func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) = ## Returns r = a * b @@ -529,7 +778,6 @@ func prod_generic(r: var QuadraticExt, a, b: QuadraticExt) = # # r0 = a0 b0 + β a1 b1 # r1 = (a0 + a1) (b0 + b1) - a0 b0 - a1 b1 (Karatsuba) - mixin prod var v0 {.noInit.}, v1 {.noInit.}, v2 {.noInit.}: typeof(r.c0) # v2 <- (a0 + a1)(b0 + b1) @@ -564,7 +812,6 @@ func mul_sparse_generic_by_x0(r: var QuadraticExt, a, sparseB: QuadraticExt) = # # r0 = a0 b0 # r1 = (a0 + a1) b0 - a0 b0 = a1 b0 - mixin prod template b(): untyped = sparseB r.c0.prod(a.c0, b.c0) @@ -658,21 +905,52 @@ func invImpl(r: var QuadraticExt, a: QuadraticExt) = # Exported quadratic symbols # ------------------------------------------------------------------- -{.push inline.} - func square*(r: var QuadraticExt, a: QuadraticExt) = mixin fromComplexExtension when r.fromComplexExtension(): - r.square_complex(a) + when true: + r.square_complex(a) + else: # slower + var d {.noInit.}: doublePrec(typeof(r)) + d.square2x_complex(a) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) else: - r.square_generic(a) + when true: # r.typeof.F.C in {BLS12_377, 
BW6_761}: + # BW6-761 requires too many registers for Dbl width path + r.square_generic(a) + else: + # TODO understand why Fp4[BLS12_377] + # is so slow in the branch + # TODO: + # - On Fp4, we can have a.c0.c0 off by p + # a reduction is missing + var d {.noInit.}: doublePrec(typeof(r)) + d.square2x_disjoint(a.c0, a.c1) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) func prod*(r: var QuadraticExt, a, b: QuadraticExt) = mixin fromComplexExtension when r.fromComplexExtension(): - r.prod_complex(a, b) + when false: + r.prod_complex(a, b) + else: # faster + var d {.noInit.}: doublePrec(typeof(r)) + d.prod2x_complex(a, b) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) else: - r.prod_generic(a, b) + when r.typeof.F.C == BW6_761 or typeof(r.c0) is Fp: + # BW6-761 requires too many registers for Dbl width path + r.prod_generic(a, b) + else: + var d {.noInit.}: doublePrec(typeof(r)) + d.prod2x_disjoint(a, b.c0, b.c1) + r.c0.redc2x(d.c0) + r.c1.redc2x(d.c1) + +{.push inline.} func inv*(r: var QuadraticExt, a: QuadraticExt) = ## Compute the multiplicative inverse of ``a`` @@ -765,7 +1043,6 @@ func mul_sparse_by_x0*(a: var QuadraticExt, sparseB: QuadraticExt) = func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}= ## Returns r = a² - mixin prod, square, sum var s0{.noInit.}, m01{.noInit.}, m12{.noInit.}: typeof(r.c0) # precomputations that use a @@ -801,7 +1078,6 @@ func square_Chung_Hasan_SQR2(r: var CubicExt, a: CubicExt) {.used.}= func square_Chung_Hasan_SQR3(r: var CubicExt, a: CubicExt) = ## Returns r = a² - mixin prod, square, sum var s0{.noInit.}, t{.noInit.}, m12{.noInit.}: typeof(r.c0) # s₀ = (a₀ + a₁ + a₂)² diff --git a/constantine/tower_field_extensions/tower_instantiation.nim b/constantine/tower_field_extensions/tower_instantiation.nim index cdcbc1a71..bb337d1db 100644 --- a/constantine/tower_field_extensions/tower_instantiation.nim +++ b/constantine/tower_field_extensions/tower_instantiation.nim @@ -116,17 +116,24 @@ func prod*(r: var Fp2, a: Fp2, _: type NonResidue) {.inline.} = # BLS12_377 and BW6_761, use small addition chain r.mul_sparse_by_0y(a, v) else: - # BN254_Snarks, u = 9 - # Full 𝔽p2 multiplication is cheaper than addition chains - # for u*c0 and u*c1 - static: - doAssert u >= 0 and uint64(u) <= uint64(high(BaseType)) - doAssert v >= 0 and uint64(v) <= uint64(high(BaseType)) - # TODO: compile-time - var NR {.noInit.}: Fp2 - NR.c0.fromUint(uint u) - NR.c1.fromUint(uint v) - r.prod(a, NR) + # BN254_Snarks, u = 9, v = 1, β = -1 + # Even with u = 9, the 2x9 addition chains (8 additions total) + # are cheaper than full Fp2 multiplication + var t {.noInit.}: typeof(a.c0) + + t.prod(a.c0, u) + when v == 1 and Beta == -1: # Case BN254_Snarks + t -= a.c1 # r0 = u c0 + β v c1 + else: + {.error: "Unimplemented".} + + r.c1.prod(a.c1, u) + when v == 1: # r1 = v c0 + u c1 + r.c1 += a.c0 + # aliasing: a.c0 is unused + r.c0 = t + else: + {.error: "Unimplemented".} func `*=`*(a: var Fp2, _: type NonResidue) {.inline.} = ## Multiply an element of 𝔽p2 by the non-residue diff --git a/docs/optimizations.md b/docs/optimizations.md index 023119b3e..e0a939dfa 100644 --- a/docs/optimizations.md +++ b/docs/optimizations.md @@ -26,10 +26,10 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n - [x] x86: MULX, ADCX, ADOX instructions - [x] Fused Multiply + Shift-right by word (for Barrett Reduction and approximating multiplication by fractional constant) - Squaring - - [ ] Dedicated squaring functions - - [ ] int128 + - [x] Dedicated squaring functions + - [x] int128 - [ 
] loop unrolling - - [ ] x86: Full Assembly implementation + - [x] x86: Full Assembly implementation - [ ] x86: MULX, ADCX, ADOX instructions ## Finite Fields & Modular Arithmetic @@ -107,13 +107,13 @@ The optimizations can be of algebraic, algorithmic or "implementation details" n ## Extension Fields -- [ ] Lazy reduction via double-width base fields +- [x] Lazy reduction via double-precision base fields - [x] Sparse multiplication - Fp2 - [x] complex multiplication - [x] complex squaring - [x] sqrt via the constant-time complex method (Adj et al) - - [ ] sqrt using addition chain + - [x] sqrt using addition chain - [x] fused complex method sqrt by rotating in complex plane - Cubic extension fields - [x] Toom-Cook polynomial multiplication (Chung-Hasan) diff --git a/helpers/prng_unsafe.nim b/helpers/prng_unsafe.nim index 38740ac48..80b4a99a6 100644 --- a/helpers/prng_unsafe.nim +++ b/helpers/prng_unsafe.nim @@ -146,7 +146,7 @@ func random_unsafe(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_unsafe(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element @@ -177,7 +177,7 @@ func random_highHammingWeight(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_highHammingWeight(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element @@ -222,7 +222,7 @@ func random_long01Seq(rng: var RngState, a: var FF) = # Note: a simple modulo will be biaised but it's simple and "fast" reduced.reduce(unreduced, FF.fieldMod()) - a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.canUseNoCarryMontyMul()) + a.mres.montyResidue(reduced, FF.fieldMod(), FF.getR2modP(), FF.getNegInvModWord(), FF.getSpareBits()) func random_long01Seq(rng: var RngState, a: var ExtensionField) = ## Recursively initialize an extension Field element diff --git a/tests/t_finite_fields_double_precision.nim b/tests/t_finite_fields_double_precision.nim new file mode 100644 index 000000000..60eaf5b58 --- /dev/null +++ b/tests/t_finite_fields_double_precision.nim @@ -0,0 +1,228 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. 
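+
+# These tests check that the double-precision (lazy reduced) API
+# (prod2x/square2x/redc2x, sum2xMod/diff2xMod/neg2xMod) agrees with
+# the single-precision reference operations.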
+ +import + # Standard library + std/[unittest, times], + # Internal + ../constantine/arithmetic, + ../constantine/io/[io_bigints, io_fields], + ../constantine/config/[curves, common, type_bigint], + # Test utilities + ../helpers/prng_unsafe + +const Iters = 24 + +var rng: RngState +let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32 +rng.seed(seed) +echo "\n------------------------------------------------------\n" +echo "test_finite_fields_double_precision xoshiro512** seed: ", seed + +template addsubnegTest(rng_gen: untyped): untyped = + proc `addsubneg _ rng_gen`(C: static Curve) = + # Try to exercise all code paths for in-place/out-of-place add/sum/sub/diff/double/neg + # (1 - (-a) - b + (-a) - 2a) + (2a + 2b + (-b)) == 1 + let aFp = rng_gen(rng, Fp[C]) + let bFp = rng_gen(rng, Fp[C]) + var accumFp {.noInit.}: Fp[C] + var OneFp {.noInit.}: Fp[C] + var accum {.noInit.}, One {.noInit.}, a{.noInit.}, na{.noInit.}, b{.noInit.}, nb{.noInit.}, a2 {.noInit.}, b2 {.noInit.}: FpDbl[C] + + OneFp.setOne() + One.prod2x(OneFp, OneFp) + a.prod2x(aFp, OneFp) + b.prod2x(bFp, OneFp) + + block: # sanity check + var t: Fp[C] + t.redc2x(One) + doAssert bool t.isOne() + + a2.sum2xMod(a, a) + na.neg2xMod(a) + + block: # sanity check + var t0, t1: Fp[C] + t0.redc2x(na) + t1.neg(aFp) + doAssert bool(t0 == t1), + "Beware, if the hex are the same, it means the outputs are the same (mod p),\n" & + "but one might not be completely reduced\n" & + " t0: " & t0.toHex() & "\n" & + " t1: " & t1.toHex() & "\n" + + block: # sanity check + var t0, t1: Fp[C] + t0.redc2x(a2) + t1.double(aFp) + doAssert bool(t0 == t1), + "Beware, if the hex are the same, it means the outputs are the same (mod p),\n" & + "but one might not be completely reduced\n" & + " t0: " & t0.toHex() & "\n" & + " t1: " & t1.toHex() & "\n" + + b2.sum2xMod(b, b) + nb.neg2xMod(b) + + accum.diff2xMod(One, na) + accum.diff2xMod(accum, b) + accum.sum2xMod(accum, na) + accum.diff2xMod(accum, a2) + + var t{.noInit.}: FpDbl[C] + t.sum2xMod(a2, b2) + t.sum2xMod(t, nb) + + accum.sum2xMod(accum, t) + accumFp.redc2x(accum) + doAssert bool accumFp.isOne(), + "Beware, if the hex are the same, it means the outputs are the same (mod p),\n" & + "but one might not be completely reduced\n" & + " accumFp: " & accumFp.toHex() + +template mulTest(rng_gen: untyped): untyped = + proc `mul _ rng_gen`(C: static Curve) = + let a = rng_gen(rng, Fp[C]) + let b = rng_gen(rng, Fp[C]) + + var r_fp{.noInit.}, r_fpDbl{.noInit.}: Fp[C] + var tmpDbl{.noInit.}: FpDbl[C] + + r_fp.prod(a, b) + tmpDbl.prod2x(a, b) + r_fpDbl.redc2x(tmpDbl) + + doAssert bool(r_fp == r_fpDbl) + +template sqrTest(rng_gen: untyped): untyped = + proc `sqr _ rng_gen`(C: static Curve) = + let a = rng_gen(rng, Fp[C]) + + var mulDbl{.noInit.}, sqrDbl{.noInit.}: FpDbl[C] + + mulDbl.prod2x(a, a) + sqrDbl.square2x(a) + + doAssert bool(mulDbl == sqrDbl) + +addsubnegTest(random_unsafe) +addsubnegTest(randomHighHammingWeight) +addsubnegTest(random_long01Seq) +mulTest(random_unsafe) +mulTest(randomHighHammingWeight) +mulTest(random_long01Seq) +sqrTest(random_unsafe) +sqrTest(randomHighHammingWeight) +sqrTest(random_long01Seq) + +suite "Field Addition/Substraction/Negation via double-precision field elements" & " [" & $WordBitwidth & "-bit mode]": + test "With P-224 field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(P224) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(P224) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(P224) + + test "With P-256 field modulus": + 
for _ in 0 ..< Iters: + addsubneg_random_unsafe(P256) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(P256) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(P256) + + test "With BN254_Snarks field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(BN254_Snarks) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(BN254_Snarks) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(BN254_Snarks) + + test "With BLS12_381 field modulus": + for _ in 0 ..< Iters: + addsubneg_random_unsafe(BLS12_381) + for _ in 0 ..< Iters: + addsubneg_randomHighHammingWeight(BLS12_381) + for _ in 0 ..< Iters: + addsubneg_random_long01Seq(BLS12_381) + + test "Negate 0 returns 0 (unique Montgomery repr)": + var a: FpDbl[BN254_Snarks] + var r {.noInit.}: FpDbl[BN254_Snarks] + r.neg2xMod(a) + + check: bool r.isZero() + +suite "Field Multiplication via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]": + test "With P-224 field modulus": + for _ in 0 ..< Iters: + mul_random_unsafe(P224) + for _ in 0 ..< Iters: + mul_randomHighHammingWeight(P224) + for _ in 0 ..< Iters: + mul_random_long01Seq(P224) + + test "With P-256 field modulus": + for _ in 0 ..< Iters: + mul_random_unsafe(P256) + for _ in 0 ..< Iters: + mul_randomHighHammingWeight(P256) + for _ in 0 ..< Iters: + mul_random_long01Seq(P256) + + test "With BN254_Snarks field modulus": + for _ in 0 ..< Iters: + mul_random_unsafe(BN254_Snarks) + for _ in 0 ..< Iters: + mul_randomHighHammingWeight(BN254_Snarks) + for _ in 0 ..< Iters: + mul_random_long01Seq(BN254_Snarks) + + test "With BLS12_381 field modulus": + for _ in 0 ..< Iters: + mul_random_unsafe(BLS12_381) + for _ in 0 ..< Iters: + mul_randomHighHammingWeight(BLS12_381) + for _ in 0 ..< Iters: + mul_random_long01Seq(BLS12_381) + +suite "Field Squaring via double-precision field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]": + test "With P-224 field modulus": + for _ in 0 ..< Iters: + sqr_random_unsafe(P224) + for _ in 0 ..< Iters: + sqr_randomHighHammingWeight(P224) + for _ in 0 ..< Iters: + sqr_random_long01Seq(P224) + + test "With P-256 field modulus": + for _ in 0 ..< Iters: + sqr_random_unsafe(P256) + for _ in 0 ..< Iters: + sqr_randomHighHammingWeight(P256) + for _ in 0 ..< Iters: + sqr_random_long01Seq(P256) + + test "With BN254_Snarks field modulus": + for _ in 0 ..< Iters: + sqr_random_unsafe(BN254_Snarks) + for _ in 0 ..< Iters: + sqr_randomHighHammingWeight(BN254_Snarks) + for _ in 0 ..< Iters: + sqr_random_long01Seq(BN254_Snarks) + + test "With BLS12_381 field modulus": + for _ in 0 ..< Iters: + sqr_random_unsafe(BLS12_381) + for _ in 0 ..< Iters: + sqr_randomHighHammingWeight(BLS12_381) + for _ in 0 ..< Iters: + sqr_random_long01Seq(BLS12_381) diff --git a/tests/t_finite_fields_double_width.nim b/tests/t_finite_fields_double_width.nim deleted file mode 100644 index 0ca027ae3..000000000 --- a/tests/t_finite_fields_double_width.nim +++ /dev/null @@ -1,123 +0,0 @@ -# Constantine -# Copyright (c) 2018-2019 Status Research & Development GmbH -# Copyright (c) 2020-Present Mamy André-Ratsimbazafy -# Licensed and distributed under either of -# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). -# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). -# at your option. This file may not be copied, modified, or distributed except according to those terms. 
- -import - # Standard library - std/[unittest, times], - # Internal - ../constantine/arithmetic, - ../constantine/io/[io_bigints, io_fields], - ../constantine/config/[curves, common, type_bigint], - # Test utilities - ../helpers/prng_unsafe - -const Iters = 24 - -var rng: RngState -let seed = uint32(getTime().toUnix() and (1'i64 shl 32 - 1)) # unixTime mod 2^32 -rng.seed(seed) -echo "\n------------------------------------------------------\n" -echo "test_finite_fields_double_width xoshiro512** seed: ", seed - -template mulTest(rng_gen: untyped): untyped = - proc `mul _ rng_gen`(C: static Curve) = - let a = rng_gen(rng, Fp[C]) - let b = rng.random_unsafe(Fp[C]) - - var r_fp{.noInit.}, r_fpDbl{.noInit.}: Fp[C] - var tmpDbl{.noInit.}: FpDbl[C] - - r_fp.prod(a, b) - tmpDbl.mulNoReduce(a, b) - r_fpDbl.reduce(tmpDbl) - - doAssert bool(r_fp == r_fpDbl) - -template sqrTest(rng_gen: untyped): untyped = - proc `sqr _ rng_gen`(C: static Curve) = - let a = rng_gen(rng, Fp[C]) - - var mulDbl{.noInit.}, sqrDbl{.noInit.}: FpDbl[C] - - mulDbl.mulNoReduce(a, a) - sqrDbl.squareNoReduce(a) - - doAssert bool(mulDbl == sqrDbl) - -mulTest(random_unsafe) -mulTest(randomHighHammingWeight) -mulTest(random_long01Seq) -sqrTest(random_unsafe) -sqrTest(randomHighHammingWeight) -sqrTest(random_long01Seq) - -suite "Field Multiplication via double-width field elements is consistent with single-width." & " [" & $WordBitwidth & "-bit mode]": - test "With P-224 field modulus": - for _ in 0 ..< Iters: - mul_random_unsafe(P224) - for _ in 0 ..< Iters: - mul_randomHighHammingWeight(P224) - for _ in 0 ..< Iters: - mul_random_long01Seq(P224) - - test "With P-256 field modulus": - for _ in 0 ..< Iters: - mul_random_unsafe(P256) - for _ in 0 ..< Iters: - mul_randomHighHammingWeight(P256) - for _ in 0 ..< Iters: - mul_random_long01Seq(P256) - - test "With BN254_Snarks field modulus": - for _ in 0 ..< Iters: - mul_random_unsafe(BN254_Snarks) - for _ in 0 ..< Iters: - mul_randomHighHammingWeight(BN254_Snarks) - for _ in 0 ..< Iters: - mul_random_long01Seq(BN254_Snarks) - - test "With BLS12_381 field modulus": - for _ in 0 ..< Iters: - mul_random_unsafe(BLS12_381) - for _ in 0 ..< Iters: - mul_randomHighHammingWeight(BLS12_381) - for _ in 0 ..< Iters: - mul_random_long01Seq(BLS12_381) - -suite "Field Squaring via double-width field elements is consistent with single-width." 
& " [" & $WordBitwidth & "-bit mode]": - test "With P-224 field modulus": - for _ in 0 ..< Iters: - sqr_random_unsafe(P224) - for _ in 0 ..< Iters: - sqr_randomHighHammingWeight(P224) - for _ in 0 ..< Iters: - sqr_random_long01Seq(P224) - - test "With P-256 field modulus": - for _ in 0 ..< Iters: - sqr_random_unsafe(P256) - for _ in 0 ..< Iters: - sqr_randomHighHammingWeight(P256) - for _ in 0 ..< Iters: - sqr_random_long01Seq(P256) - - test "With BN254_Snarks field modulus": - for _ in 0 ..< Iters: - sqr_random_unsafe(BN254_Snarks) - for _ in 0 ..< Iters: - sqr_randomHighHammingWeight(BN254_Snarks) - for _ in 0 ..< Iters: - sqr_random_long01Seq(BN254_Snarks) - - test "With BLS12_381 field modulus": - for _ in 0 ..< Iters: - sqr_random_unsafe(BLS12_381) - for _ in 0 ..< Iters: - sqr_randomHighHammingWeight(BLS12_381) - for _ in 0 ..< Iters: - sqr_random_long01Seq(BLS12_381) diff --git a/tests/t_finite_fields_mulsquare.nim b/tests/t_finite_fields_mulsquare.nim index 49cd86bda..722cef424 100644 --- a/tests/t_finite_fields_mulsquare.nim +++ b/tests/t_finite_fields_mulsquare.nim @@ -27,7 +27,7 @@ echo "test_finite_fields_mulsquare xoshiro512** seed: ", seed static: doAssert defined(testingCurves), "This modules requires the -d:testingCurves compile option" proc sanity(C: static Curve) = - test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $Fp[C].canUseNoCarryMontySquare & "]": + test "Squaring 0,1,2 with "& $Curve(C) & " [FastSquaring = " & $(Fp[C].getSpareBits() >= 2) & "]": block: # 0² mod var n: Fp[C] @@ -89,7 +89,7 @@ mainSanity() proc mainSelectCases() = suite "Modular Squaring: selected tricky cases" & " [" & $WordBitwidth & "-bit mode]": - test "P-256 [FastSquaring = " & $Fp[P256].canUseNoCarryMontySquare & "]": + test "P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]": block: # Triggered an issue in the (t[N+1], t[N]) = t[N] + (A1, A0) # between the squaring and reduction step, with t[N+1] and A1 being carry bits. 
@@ -136,7 +136,7 @@ proc random_long01Seq(C: static Curve) = doAssert bool(r_mul == r_sqr) suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]": - test "Random squaring mod P-224 [FastSquaring = " & $Fp[P224].canUseNoCarryMontySquare & "]": + test "Random squaring mod P-224 [FastSquaring = " & $(Fp[P224].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(P224) for _ in 0 ..< Iters: @@ -144,7 +144,8 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" for _ in 0 ..< Iters: random_long01Seq(P224) - test "Random squaring mod P-256 [FastSquaring = " & $Fp[P256].canUseNoCarryMontySquare & "]": + test "Random squaring mod P-256 [FastSquaring = " & $(Fp[P256].getSpareBits() >= 2) & "]": + echo "Fp[P256].getSpareBits(): ", Fp[P256].getSpareBits() for _ in 0 ..< Iters: randomCurve(P256) for _ in 0 ..< Iters: @@ -152,7 +153,7 @@ suite "Random Modular Squaring is consistent with Modular Multiplication" & " [" for _ in 0 ..< Iters: random_long01Seq(P256) - test "Random squaring mod BLS12_381 [FastSquaring = " & $Fp[BLS12_381].canUseNoCarryMontySquare & "]": + test "Random squaring mod BLS12_381 [FastSquaring = " & $(Fp[BLS12_381].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BLS12_381) for _ in 0 ..< Iters: diff --git a/tests/t_fp4.nim b/tests/t_fp4.nim new file mode 100644 index 000000000..8fc7d436a --- /dev/null +++ b/tests/t_fp4.nim @@ -0,0 +1,129 @@ +# Constantine +# Copyright (c) 2018-2019 Status Research & Development GmbH +# Copyright (c) 2020-Present Mamy André-Ratsimbazafy +# Licensed and distributed under either of +# * MIT license (license terms in the root directory or at http://opensource.org/licenses/MIT). +# * Apache v2 license (license terms in the root directory or at http://www.apache.org/licenses/LICENSE-2.0). +# at your option. This file may not be copied, modified, or distributed except according to those terms. 
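+
+# These cases exercise double-precision operands that are only partially
+# reduced (off by up to p): square(x) and prod(x, x) must still match.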
+ +import + std/unittest, + # Internals + ../constantine/towers, + ../constantine/io/io_towers, + ../constantine/config/curves, + # Test utilities + ./t_fp_tower_template + +const TestCurves = [ + BN254_Nogami, + BN254_Snarks, + BLS12_377, + BLS12_381, + BW6_761 + ] + +runTowerTests( + ExtDegree = 4, + Iters = 12, + TestCurves = TestCurves, + moduleName = "test_fp4", + testSuiteDesc = "𝔽p4 = 𝔽p2[v]" +) + +# Fuzzing failure +# Issue when using Fp4Dbl + +suite "𝔽p4 - Anti-regression": + test "Partial reduction (off by p) on double-precision field": + proc partred1() = + type F = Fp4[BN254_Snarks] + var x: F + x.fromHex( + "0x0000000000000000000fffffffffffffffffe000000fffffffffcffffff80000", + "0x000000000000007ffffffffff800000001fffe000000000007ffffffffffffe0", + "0x000000c0ff0300fcffffffff7f00000000f0ffffffffffffffff00000000e0ff", + "0x0e0a77c19a07df27e5eea36f7879462c0a7ceb28e5c70b3dd35d438dc58f4d9c" + ) + + # echo "x: ", x.toHex() + # echo "\n----------------------" + + var s: F + s.square(x) + + # echo "s: ", s.toHex() + # echo "\ns raw: ", s + + # echo "\n----------------------" + var p: F + p.prod(x, x) + + # echo "p: ", p.toHex() + # echo "\np raw: ", p + + check: bool(p == s) + + partred1() + + proc partred2() = + type F = Fp4[BN254_Snarks] + var x: F + x.fromHex( + "0x0660df54c75b67a0c32fc6208f08b13d8cc86cd93084180725a04884e7f45849", + "0x094185b0915ce1aa3bd3c63d33fd6d9cf3f04ea30fc88efe1e6e9b59117513bb", + "0x26c20beee711e46406372ab4f0e6d0069c67ded0a494bc0301bbfde48f7a4073", + "0x23c60254946def07120e46155466cc9b883b5c3d1c17d1d6516a6268a41dcc5d" + ) + + # echo "x: ", x.toHex() + # echo "\n----------------------" + + var s: F + s.square(x) + + # echo "s: ", s.toHex() + # echo "\ns raw: ", s + + # echo "\n----------------------" + var p: F + p.prod(x, x) + + # echo "p: ", p.toHex() + # echo "\np raw: ", p + + check: bool(p == s) + + partred2() + + + proc partred3() = + type F = Fp4[BN254_Snarks] + var x: F + x.fromHex( + "0x233066f735efcf7a0ad6e3ffa3afe4ed39bdfeffffb3f7d8b1fd7eeabfddfb36", + "0x1caba0b27fdfdfd512bdecf3fffbfebdb939fffffffbff8a14e663f7fef7fc85", + "0x212a64f0efefff1b7abe2ebe2bffbfc1b9335fb73ffd7c8815ffffffffffff8d", + "0x212ba4b1ff8feff552a61efff5ffffc5b839f7ffffffff71f477dffe7ffc7e08" + ) + + # echo "x: ", x.toHex() + # echo "\n----------------------" + + var s: F + s.square(x) + + # echo "s: ", s.toHex() + # echo "\ns raw: ", s + + # echo "\n----------------------" + var n, s2: F + n.neg(x) + s2.prod(n, n) + + # echo "s2: ", s2.toHex() + # echo "\ns2 raw: ", s2 + + check: bool(s == s2) + + partred3() diff --git a/tests/t_fp_tower_template.nim b/tests/t_fp_tower_template.nim index a9fa4e4e8..1e8541c59 100644 --- a/tests/t_fp_tower_template.nim +++ b/tests/t_fp_tower_template.nim @@ -20,6 +20,7 @@ import ../constantine/towers, ../constantine/config/[common, curves], ../constantine/arithmetic, + ../constantine/io/io_towers, # Test utilities ../helpers/[prng_unsafe, static_for] @@ -28,6 +29,8 @@ echo "\n------------------------------------------------------\n" template ExtField(degree: static int, curve: static Curve): untyped = when degree == 2: Fp2[curve] + elif degree == 4: + Fp4[curve] elif degree == 6: Fp6[curve] elif degree == 12: @@ -273,7 +276,7 @@ proc runTowerTests*[N]( rMul.prod(a, a) rSqr.square(a) - check: bool(rMul == rSqr) + doAssert bool(rMul == rSqr), "Failure with a (" & $Field & "): " & a.toHex() staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) @@ -292,7 +295,7 @@ proc runTowerTests*[N]( rSqr.square(a) 
rNegSqr.square(na) - check: bool(rSqr == rNegSqr) + doAssert bool(rSqr == rNegSqr), "Failure with a (" & $Field & "): " & a.toHex() staticFor(curve, TestCurves): test(ExtField(ExtDegree, curve), Iters, gen = Uniform) diff --git a/tests/t_fr.nim b/tests/t_fr.nim index d6eb349a4..de1e47a4e 100644 --- a/tests/t_fr.nim +++ b/tests/t_fr.nim @@ -25,7 +25,7 @@ echo "\n------------------------------------------------------\n" echo "test_fr xoshiro512** seed: ", seed proc sanity(C: static Curve) = - test "Fr: Squaring 0,1,2 with "& $Fr[C] & " [FastSquaring = " & $Fr[C].canUseNoCarryMontySquare & "]": + test "Fr: Squaring 0,1,2 with "& $Fr[C] & " [FastSquaring = " & $(Fr[C].getSpareBits() >= 2) & "]": block: # 0² mod var n: Fr[C] @@ -112,7 +112,7 @@ proc random_long01Seq(C: static Curve) = doAssert bool(r_mul == r_sqr) suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" & " [" & $WordBitwidth & "-bit mode]": - test "Random squaring mod r_BN254_Snarks [FastSquaring = " & $Fr[BN254_Snarks].canUseNoCarryMontySquare & "]": + test "Random squaring mod r_BN254_Snarks [FastSquaring = " & $(Fr[BN254_Snarks].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BN254_Snarks) for _ in 0 ..< Iters: @@ -120,7 +120,7 @@ suite "Fr: Random Modular Squaring is consistent with Modular Multiplication" & for _ in 0 ..< Iters: random_long01Seq(BN254_Snarks) - test "Random squaring mod r_BLS12_381 [FastSquaring = " & $Fr[BLS12_381].canUseNoCarryMontySquare & "]": + test "Random squaring mod r_BLS12_381 [FastSquaring = " & $(Fr[BLS12_381].getSpareBits() >= 2) & "]": for _ in 0 ..< Iters: randomCurve(BLS12_381) for _ in 0 ..< Iters: